diff --git a/.gitattributes b/.gitattributes index c633db7b9f3d1313572d0a0e47f5bc8c69a3e5e0..a4062ffa5f9dc16e30bdb8b07bc0d922f7f249f3 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,7 @@ phivenv/Lib/site-packages/numpy/ma/tests/__pycache__/test_core.cpython-39.pyc fi phivenv/Lib/site-packages/numpy/ma/__pycache__/core.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/numpy/random/bit_generator.cp39-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/numpy/random/mtrand.cp39-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/numpy/random/_common.cp39-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/numpy/random/_bounded_integers.cp39-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/numpy/random/lib/npyrandom.lib filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/numpy/random/_generator.cp39-win_amd64.pyd filter=lfs diff=lfs merge=lfs -text diff --git a/phivenv/Lib/site-packages/numpy/random/_bounded_integers.cp39-win_amd64.pyd b/phivenv/Lib/site-packages/numpy/random/_bounded_integers.cp39-win_amd64.pyd new file mode 100644 index 0000000000000000000000000000000000000000..59e45b2e796617faca486d153e96a3c29cca2d35 --- /dev/null +++ b/phivenv/Lib/site-packages/numpy/random/_bounded_integers.cp39-win_amd64.pyd @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5e38cea16c127bbfb7291b60468e9500f45bbdc6b63f83d3240109d02fd00e44 +size 252928 diff --git a/phivenv/Lib/site-packages/numpy/random/_common.cp39-win_amd64.pyd b/phivenv/Lib/site-packages/numpy/random/_common.cp39-win_amd64.pyd new file mode 100644 index 0000000000000000000000000000000000000000..5671ce67853c2ff3e4dd41314c461295505a97de --- /dev/null +++ b/phivenv/Lib/site-packages/numpy/random/_common.cp39-win_amd64.pyd @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a4eceb8cd34b3ecf695c4252df85452c31f3e4e0a2cb43cfc62f70f21e35d46c +size 176640 diff --git a/phivenv/Lib/site-packages/numpy/random/_generator.cp39-win_amd64.pyd b/phivenv/Lib/site-packages/numpy/random/_generator.cp39-win_amd64.pyd new file mode 100644 index 0000000000000000000000000000000000000000..3165cc4148e43bd285eff35ad990d8a636b3d7ff --- /dev/null +++ b/phivenv/Lib/site-packages/numpy/random/_generator.cp39-win_amd64.pyd @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0831522b7e15b6b0888586a5bb296d093c4f771b314ec812caa261f96189d464 +size 754688 diff --git a/phivenv/Lib/site-packages/numpy/random/lib/npyrandom.lib b/phivenv/Lib/site-packages/numpy/random/lib/npyrandom.lib new file mode 100644 index 0000000000000000000000000000000000000000..353e3ff06d39e8dd3fe370e54fa13a20ead29464 --- /dev/null +++ b/phivenv/Lib/site-packages/numpy/random/lib/npyrandom.lib @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6046451e469e3f9257f302b8a7b90e1fe9c20b4635113e0ab4497da3e36dc71f +size 147796 diff --git a/phivenv/Lib/site-packages/sympy/logic/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b06247ff94a1b37a68b07c5099077187e2253ebf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/__pycache__/inference.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/__pycache__/inference.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2b7e84245f33bd6493352f290a681d5dc9d73be Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/__pycache__/inference.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/algorithms/__init__.py b/phivenv/Lib/site-packages/sympy/logic/algorithms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..535dba47fdfffc618a970f69d96d6bd5ba26d512 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/dpll.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/dpll.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc3b668f65e03a9af2f6fe7ff12944d38dcb1eae Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/dpll.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/dpll2.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/dpll2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aef697cafba8043325bc332c3ad4b82d97ff804d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/dpll2.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/lra_theory.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/lra_theory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a52288cb0a570b76c5a24c8419d146b94a2f1ea Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/lra_theory.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/minisat22_wrapper.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/minisat22_wrapper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30884d2ee1909d6d9c49d0832f7fa09ece6b4554 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/minisat22_wrapper.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/pycosat_wrapper.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/pycosat_wrapper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd14459acd68863167ced1ac86051ab48de39212 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/pycosat_wrapper.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/z3_wrapper.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/z3_wrapper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..396c54f8c559901f2dfc63b5bba003b66cd1ec23 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/algorithms/__pycache__/z3_wrapper.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/algorithms/dpll2.py b/phivenv/Lib/site-packages/sympy/logic/algorithms/dpll2.py new file mode 100644 index 0000000000000000000000000000000000000000..4f18c81189d6be565dc9b7caa3f0bf48e978bb56 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/logic/algorithms/dpll2.py @@ -0,0 +1,688 @@ +"""Implementation of DPLL algorithm + +Features: + - Clause learning + - Watch literal scheme + - VSIDS heuristic + +References: + - https://en.wikipedia.org/wiki/DPLL_algorithm +""" + +from collections import defaultdict +from heapq import heappush, heappop + +from sympy.core.sorting import ordered +from sympy.assumptions.cnf import EncodedCNF + +from sympy.logic.algorithms.lra_theory import LRASolver + + +def dpll_satisfiable(expr, all_models=False, use_lra_theory=False): + """ + Check satisfiability of a propositional sentence. + It returns a model rather than True when it succeeds. + Returns a generator of all models if all_models is True. + + Examples + ======== + + >>> from sympy.abc import A, B + >>> from sympy.logic.algorithms.dpll2 import dpll_satisfiable + >>> dpll_satisfiable(A & ~B) + {A: True, B: False} + >>> dpll_satisfiable(A & ~A) + False + + """ + if not isinstance(expr, EncodedCNF): + exprs = EncodedCNF() + exprs.add_prop(expr) + expr = exprs + + # Return UNSAT when False (encoded as 0) is present in the CNF + if {0} in expr.data: + if all_models: + return (f for f in [False]) + return False + + if use_lra_theory: + lra, immediate_conflicts = LRASolver.from_encoded_cnf(expr) + else: + lra = None + immediate_conflicts = [] + solver = SATSolver(expr.data + immediate_conflicts, expr.variables, set(), expr.symbols, lra_theory=lra) + models = solver._find_model() + + if all_models: + return _all_models(models) + + try: + return next(models) + except StopIteration: + return False + + # Uncomment to confirm the solution is valid (hitting set for the clauses) + #else: + #for cls in clauses_int_repr: + #assert solver.var_settings.intersection(cls) + + +def _all_models(models): + satisfiable = False + try: + while True: + yield next(models) + satisfiable = True + except StopIteration: + if not satisfiable: + yield False + + +class SATSolver: + """ + Class for representing a SAT solver capable of + finding a model to a boolean theory in conjunctive + normal form. + """ + + def __init__(self, clauses, variables, var_settings, symbols=None, + heuristic='vsids', clause_learning='none', INTERVAL=500, + lra_theory = None): + + self.var_settings = var_settings + self.heuristic = heuristic + self.is_unsatisfied = False + self._unit_prop_queue = [] + self.update_functions = [] + self.INTERVAL = INTERVAL + + if symbols is None: + self.symbols = list(ordered(variables)) + else: + self.symbols = symbols + + self._initialize_variables(variables) + self._initialize_clauses(clauses) + + if 'vsids' == heuristic: + self._vsids_init() + self.heur_calculate = self._vsids_calculate + self.heur_lit_assigned = self._vsids_lit_assigned + self.heur_lit_unset = self._vsids_lit_unset + self.heur_clause_added = self._vsids_clause_added + + # Note: Uncomment this if/when clause learning is enabled + #self.update_functions.append(self._vsids_decay) + + else: + raise NotImplementedError + + if 'simple' == clause_learning: + self.add_learned_clause = self._simple_add_learned_clause + self.compute_conflict = self._simple_compute_conflict + self.update_functions.append(self._simple_clean_clauses) + elif 'none' == clause_learning: + self.add_learned_clause = lambda x: None + self.compute_conflict = lambda: None + else: + raise NotImplementedError + + # Create the base level + self.levels = [Level(0)] + self._current_level.varsettings = var_settings + + # Keep stats + self.num_decisions = 0 + self.num_learned_clauses = 0 + self.original_num_clauses = len(self.clauses) + + self.lra = lra_theory + + def _initialize_variables(self, variables): + """Set up the variable data structures needed.""" + self.sentinels = defaultdict(set) + self.occurrence_count = defaultdict(int) + self.variable_set = [False] * (len(variables) + 1) + + def _initialize_clauses(self, clauses): + """Set up the clause data structures needed. + + For each clause, the following changes are made: + - Unit clauses are queued for propagation right away. + - Non-unit clauses have their first and last literals set as sentinels. + - The number of clauses a literal appears in is computed. + """ + self.clauses = [list(clause) for clause in clauses] + + for i, clause in enumerate(self.clauses): + + # Handle the unit clauses + if 1 == len(clause): + self._unit_prop_queue.append(clause[0]) + continue + + self.sentinels[clause[0]].add(i) + self.sentinels[clause[-1]].add(i) + + for lit in clause: + self.occurrence_count[lit] += 1 + + def _find_model(self): + """ + Main DPLL loop. Returns a generator of models. + + Variables are chosen successively, and assigned to be either + True or False. If a solution is not found with this setting, + the opposite is chosen and the search continues. The solver + halts when every variable has a setting. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> list(l._find_model()) + [{1: True, 2: False, 3: False}, {1: True, 2: True, 3: True}] + + >>> from sympy.abc import A, B, C + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set(), [A, B, C]) + >>> list(l._find_model()) + [{A: True, B: False, C: False}, {A: True, B: True, C: True}] + + """ + + # We use this variable to keep track of if we should flip a + # variable setting in successive rounds + flip_var = False + + # Check if unit prop says the theory is unsat right off the bat + self._simplify() + if self.is_unsatisfied: + return + + # While the theory still has clauses remaining + while True: + # Perform cleanup / fixup at regular intervals + if self.num_decisions % self.INTERVAL == 0: + for func in self.update_functions: + func() + + if flip_var: + # We have just backtracked and we are trying to opposite literal + flip_var = False + lit = self._current_level.decision + + else: + # Pick a literal to set + lit = self.heur_calculate() + self.num_decisions += 1 + + # Stopping condition for a satisfying theory + if 0 == lit: + + # check if assignment satisfies lra theory + if self.lra: + for enc_var in self.var_settings: + res = self.lra.assert_lit(enc_var) + if res is not None: + break + res = self.lra.check() + self.lra.reset_bounds() + else: + res = None + if res is None or res[0]: + yield {self.symbols[abs(lit) - 1]: + lit > 0 for lit in self.var_settings} + else: + self._simple_add_learned_clause(res[1]) + + # backtrack until we unassign one of the literals causing the conflict + while not any(-lit in res[1] for lit in self._current_level.var_settings): + self._undo() + + while self._current_level.flipped: + self._undo() + if len(self.levels) == 1: + return + flip_lit = -self._current_level.decision + self._undo() + self.levels.append(Level(flip_lit, flipped=True)) + flip_var = True + continue + + # Start the new decision level + self.levels.append(Level(lit)) + + # Assign the literal, updating the clauses it satisfies + self._assign_literal(lit) + + # _simplify the theory + self._simplify() + + # Check if we've made the theory unsat + if self.is_unsatisfied: + + self.is_unsatisfied = False + + # We unroll all of the decisions until we can flip a literal + while self._current_level.flipped: + self._undo() + + # If we've unrolled all the way, the theory is unsat + if 1 == len(self.levels): + return + + # Detect and add a learned clause + self.add_learned_clause(self.compute_conflict()) + + # Try the opposite setting of the most recent decision + flip_lit = -self._current_level.decision + self._undo() + self.levels.append(Level(flip_lit, flipped=True)) + flip_var = True + + ######################## + # Helper Methods # + ######################## + @property + def _current_level(self): + """The current decision level data structure + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{1}, {2}], {1, 2}, set()) + >>> next(l._find_model()) + {1: True, 2: True} + >>> l._current_level.decision + 0 + >>> l._current_level.flipped + False + >>> l._current_level.var_settings + {1, 2} + + """ + return self.levels[-1] + + def _clause_sat(self, cls): + """Check if a clause is satisfied by the current variable setting. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{1}, {-1}], {1}, set()) + >>> try: + ... next(l._find_model()) + ... except StopIteration: + ... pass + >>> l._clause_sat(0) + False + >>> l._clause_sat(1) + True + + """ + for lit in self.clauses[cls]: + if lit in self.var_settings: + return True + return False + + def _is_sentinel(self, lit, cls): + """Check if a literal is a sentinel of a given clause. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> next(l._find_model()) + {1: True, 2: False, 3: False} + >>> l._is_sentinel(2, 3) + True + >>> l._is_sentinel(-3, 1) + False + + """ + return cls in self.sentinels[lit] + + def _assign_literal(self, lit): + """Make a literal assignment. + + The literal assignment must be recorded as part of the current + decision level. Additionally, if the literal is marked as a + sentinel of any clause, then a new sentinel must be chosen. If + this is not possible, then unit propagation is triggered and + another literal is added to the queue to be set in the future. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> next(l._find_model()) + {1: True, 2: False, 3: False} + >>> l.var_settings + {-3, -2, 1} + + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> l._assign_literal(-1) + >>> try: + ... next(l._find_model()) + ... except StopIteration: + ... pass + >>> l.var_settings + {-1} + + """ + self.var_settings.add(lit) + self._current_level.var_settings.add(lit) + self.variable_set[abs(lit)] = True + self.heur_lit_assigned(lit) + + sentinel_list = list(self.sentinels[-lit]) + + for cls in sentinel_list: + if not self._clause_sat(cls): + other_sentinel = None + for newlit in self.clauses[cls]: + if newlit != -lit: + if self._is_sentinel(newlit, cls): + other_sentinel = newlit + elif not self.variable_set[abs(newlit)]: + self.sentinels[-lit].remove(cls) + self.sentinels[newlit].add(cls) + other_sentinel = None + break + + # Check if no sentinel update exists + if other_sentinel: + self._unit_prop_queue.append(other_sentinel) + + def _undo(self): + """ + _undo the changes of the most recent decision level. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> next(l._find_model()) + {1: True, 2: False, 3: False} + >>> level = l._current_level + >>> level.decision, level.var_settings, level.flipped + (-3, {-3, -2}, False) + >>> l._undo() + >>> level = l._current_level + >>> level.decision, level.var_settings, level.flipped + (0, {1}, False) + + """ + # Undo the variable settings + for lit in self._current_level.var_settings: + self.var_settings.remove(lit) + self.heur_lit_unset(lit) + self.variable_set[abs(lit)] = False + + # Pop the level off the stack + self.levels.pop() + + ######################### + # Propagation # + ######################### + """ + Propagation methods should attempt to soundly simplify the boolean + theory, and return True if any simplification occurred and False + otherwise. + """ + def _simplify(self): + """Iterate over the various forms of propagation to simplify the theory. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> l.variable_set + [False, False, False, False] + >>> l.sentinels + {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}} + + >>> l._simplify() + + >>> l.variable_set + [False, True, False, False] + >>> l.sentinels + {-3: {0, 2}, -2: {3, 4}, -1: set(), 2: {0, 3}, + ...3: {2, 4}} + + """ + changed = True + while changed: + changed = False + changed |= self._unit_prop() + changed |= self._pure_literal() + + def _unit_prop(self): + """Perform unit propagation on the current theory.""" + result = len(self._unit_prop_queue) > 0 + while self._unit_prop_queue: + next_lit = self._unit_prop_queue.pop() + if -next_lit in self.var_settings: + self.is_unsatisfied = True + self._unit_prop_queue = [] + return False + else: + self._assign_literal(next_lit) + + return result + + def _pure_literal(self): + """Look for pure literals and assign them when found.""" + return False + + ######################### + # Heuristics # + ######################### + def _vsids_init(self): + """Initialize the data structures needed for the VSIDS heuristic.""" + self.lit_heap = [] + self.lit_scores = {} + + for var in range(1, len(self.variable_set)): + self.lit_scores[var] = float(-self.occurrence_count[var]) + self.lit_scores[-var] = float(-self.occurrence_count[-var]) + heappush(self.lit_heap, (self.lit_scores[var], var)) + heappush(self.lit_heap, (self.lit_scores[-var], -var)) + + def _vsids_decay(self): + """Decay the VSIDS scores for every literal. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + + >>> l.lit_scores + {-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0} + + >>> l._vsids_decay() + + >>> l.lit_scores + {-3: -1.0, -2: -1.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -1.0} + + """ + # We divide every literal score by 2 for a decay factor + # Note: This doesn't change the heap property + for lit in self.lit_scores.keys(): + self.lit_scores[lit] /= 2.0 + + def _vsids_calculate(self): + """ + VSIDS Heuristic Calculation + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + + >>> l.lit_heap + [(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)] + + >>> l._vsids_calculate() + -3 + + >>> l.lit_heap + [(-2.0, -2), (-2.0, 2), (0.0, -1), (0.0, 1), (-2.0, 3)] + + """ + if len(self.lit_heap) == 0: + return 0 + + # Clean out the front of the heap as long the variables are set + while self.variable_set[abs(self.lit_heap[0][1])]: + heappop(self.lit_heap) + if len(self.lit_heap) == 0: + return 0 + + return heappop(self.lit_heap)[1] + + def _vsids_lit_assigned(self, lit): + """Handle the assignment of a literal for the VSIDS heuristic.""" + pass + + def _vsids_lit_unset(self, lit): + """Handle the unsetting of a literal for the VSIDS heuristic. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> l.lit_heap + [(-2.0, -3), (-2.0, 2), (-2.0, -2), (0.0, 1), (-2.0, 3), (0.0, -1)] + + >>> l._vsids_lit_unset(2) + + >>> l.lit_heap + [(-2.0, -3), (-2.0, -2), (-2.0, -2), (-2.0, 2), (-2.0, 3), (0.0, -1), + ...(-2.0, 2), (0.0, 1)] + + """ + var = abs(lit) + heappush(self.lit_heap, (self.lit_scores[var], var)) + heappush(self.lit_heap, (self.lit_scores[-var], -var)) + + def _vsids_clause_added(self, cls): + """Handle the addition of a new clause for the VSIDS heuristic. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + + >>> l.num_learned_clauses + 0 + >>> l.lit_scores + {-3: -2.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -2.0, 3: -2.0} + + >>> l._vsids_clause_added({2, -3}) + + >>> l.num_learned_clauses + 1 + >>> l.lit_scores + {-3: -1.0, -2: -2.0, -1: 0.0, 1: 0.0, 2: -1.0, 3: -2.0} + + """ + self.num_learned_clauses += 1 + for lit in cls: + self.lit_scores[lit] += 1 + + ######################## + # Clause Learning # + ######################## + def _simple_add_learned_clause(self, cls): + """Add a new clause to the theory. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + + >>> l.num_learned_clauses + 0 + >>> l.clauses + [[2, -3], [1], [3, -3], [2, -2], [3, -2]] + >>> l.sentinels + {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4}} + + >>> l._simple_add_learned_clause([3]) + + >>> l.clauses + [[2, -3], [1], [3, -3], [2, -2], [3, -2], [3]] + >>> l.sentinels + {-3: {0, 2}, -2: {3, 4}, 2: {0, 3}, 3: {2, 4, 5}} + + """ + cls_num = len(self.clauses) + self.clauses.append(cls) + + for lit in cls: + self.occurrence_count[lit] += 1 + + self.sentinels[cls[0]].add(cls_num) + self.sentinels[cls[-1]].add(cls_num) + + self.heur_clause_added(cls) + + def _simple_compute_conflict(self): + """ Build a clause representing the fact that at least one decision made + so far is wrong. + + Examples + ======== + + >>> from sympy.logic.algorithms.dpll2 import SATSolver + >>> l = SATSolver([{2, -3}, {1}, {3, -3}, {2, -2}, + ... {3, -2}], {1, 2, 3}, set()) + >>> next(l._find_model()) + {1: True, 2: False, 3: False} + >>> l._simple_compute_conflict() + [3] + + """ + return [-(level.decision) for level in self.levels[1:]] + + def _simple_clean_clauses(self): + """Clean up learned clauses.""" + pass + + +class Level: + """ + Represents a single level in the DPLL algorithm, and contains + enough information for a sound backtracking procedure. + """ + + def __init__(self, decision, flipped=False): + self.decision = decision + self.var_settings = set() + self.flipped = flipped diff --git a/phivenv/Lib/site-packages/sympy/logic/algorithms/lra_theory.py b/phivenv/Lib/site-packages/sympy/logic/algorithms/lra_theory.py new file mode 100644 index 0000000000000000000000000000000000000000..1690760d36003aed6866f593120c05a5b8f92c83 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/logic/algorithms/lra_theory.py @@ -0,0 +1,912 @@ +"""Implements "A Fast Linear-Arithmetic Solver for DPLL(T)" + +The LRASolver class defined in this file can be used +in conjunction with a SAT solver to check the +satisfiability of formulas involving inequalities. + +Here's an example of how that would work: + + Suppose you want to check the satisfiability of + the following formula: + + >>> from sympy.core.relational import Eq + >>> from sympy.abc import x, y + >>> f = ((x > 0) | (x < 0)) & (Eq(x, 0) | Eq(y, 1)) & (~Eq(y, 1) | Eq(1, 2)) + + First a preprocessing step should be done on f. During preprocessing, + f should be checked for any predicates such as `Q.prime` that can't be + handled. Also unequality like `~Eq(y, 1)` should be split. + + I should mention that the paper says to split both equalities and + unequality, but this implementation only requires that unequality + be split. + + >>> f = ((x > 0) | (x < 0)) & (Eq(x, 0) | Eq(y, 1)) & ((y < 1) | (y > 1) | Eq(1, 2)) + + Then an LRASolver instance needs to be initialized with this formula. + + >>> from sympy.assumptions.cnf import CNF, EncodedCNF + >>> from sympy.assumptions.ask import Q + >>> from sympy.logic.algorithms.lra_theory import LRASolver + >>> cnf = CNF.from_prop(f) + >>> enc = EncodedCNF() + >>> enc.add_from_cnf(cnf) + >>> lra, conflicts = LRASolver.from_encoded_cnf(enc) + + Any immediate one-lital conflicts clauses will be detected here. + In this example, `~Eq(1, 2)` is one such conflict clause. We'll + want to add it to `f` so that the SAT solver is forced to + assign Eq(1, 2) to False. + + >>> f = f & ~Eq(1, 2) + + Now that the one-literal conflict clauses have been added + and an lra object has been initialized, we can pass `f` + to a SAT solver. The SAT solver will give us a satisfying + assignment such as: + + (1 = 2): False + (y = 1): True + (y < 1): True + (y > 1): True + (x = 0): True + (x < 0): True + (x > 0): True + + Next you would pass this assignment to the LRASolver + which will be able to determine that this particular + assignment is satisfiable or not. + + Note that since EncodedCNF is inherently non-deterministic, + the int each predicate is encoded as is not consistent. As a + result, the code below likely does not reflect the assignment + given above. + + >>> lra.assert_lit(-1) #doctest: +SKIP + >>> lra.assert_lit(2) #doctest: +SKIP + >>> lra.assert_lit(3) #doctest: +SKIP + >>> lra.assert_lit(4) #doctest: +SKIP + >>> lra.assert_lit(5) #doctest: +SKIP + >>> lra.assert_lit(6) #doctest: +SKIP + >>> lra.assert_lit(7) #doctest: +SKIP + >>> is_sat, conflict_or_assignment = lra.check() + + As the particular assignment suggested is not satisfiable, + the LRASolver will return unsat and a conflict clause when + given that assignment. The conflict clause will always be + minimal, but there can be multiple minimal conflict clauses. + One possible conflict clause could be `~(x < 0) | ~(x > 0)`. + + We would then add whatever conflict clause is given to + `f` to prevent the SAT solver from coming up with an + assignment with the same conflicting literals. In this case, + the conflict clause `~(x < 0) | ~(x > 0)` would prevent + any assignment where both (x < 0) and (x > 0) were both + true. + + The SAT solver would then find another assignment + and we would check that assignment with the LRASolver + and so on. Eventually either a satisfying assignment + that the SAT solver and LRASolver agreed on would be found + or enough conflict clauses would be added so that the + boolean formula was unsatisfiable. + + +This implementation is based on [1]_, which includes a +detailed explanation of the algorithm and pseudocode +for the most important functions. + +[1]_ also explains how backtracking and theory propagation +could be implemented to speed up the current implementation, +but these are not currently implemented. + +TODO: + - Handle non-rational real numbers + - Handle positive and negative infinity + - Implement backtracking and theory proposition + - Simplify matrix by removing unused variables using Gaussian elimination + +References +========== + +.. [1] Dutertre, B., de Moura, L.: + A Fast Linear-Arithmetic Solver for DPLL(T) + https://link.springer.com/chapter/10.1007/11817963_11 +""" +from sympy.solvers.solveset import linear_eq_to_matrix +from sympy.matrices.dense import eye +from sympy.assumptions import Predicate +from sympy.assumptions.assume import AppliedPredicate +from sympy.assumptions.ask import Q +from sympy.core import Dummy +from sympy.core.mul import Mul +from sympy.core.add import Add +from sympy.core.relational import Eq, Ne +from sympy.core.sympify import sympify +from sympy.core.singleton import S +from sympy.core.numbers import Rational, oo +from sympy.matrices.dense import Matrix + +class UnhandledInput(Exception): + """ + Raised while creating an LRASolver if non-linearity + or non-rational numbers are present. + """ + +# predicates that LRASolver understands and makes use of +ALLOWED_PRED = {Q.eq, Q.gt, Q.lt, Q.le, Q.ge} + +# if true ~Q.gt(x, y) implies Q.le(x, y) +HANDLE_NEGATION = True + +class LRASolver(): + """ + Linear Arithmetic Solver for DPLL(T) implemented with an algorithm based on + the Dual Simplex method. Uses Bland's pivoting rule to avoid cycling. + + References + ========== + + .. [1] Dutertre, B., de Moura, L.: + A Fast Linear-Arithmetic Solver for DPLL(T) + https://link.springer.com/chapter/10.1007/11817963_11 + """ + + def __init__(self, A, slack_variables, nonslack_variables, enc_to_boundary, s_subs, testing_mode): + """ + Use the "from_encoded_cnf" method to create a new LRASolver. + """ + self.run_checks = testing_mode + self.s_subs = s_subs # used only for test_lra_theory.test_random_problems + + if any(not isinstance(a, Rational) for a in A): + raise UnhandledInput("Non-rational numbers are not handled") + if any(not isinstance(b.bound, Rational) for b in enc_to_boundary.values()): + raise UnhandledInput("Non-rational numbers are not handled") + m, n = len(slack_variables), len(slack_variables)+len(nonslack_variables) + if m != 0: + assert A.shape == (m, n) + if self.run_checks: + assert A[:, n-m:] == -eye(m) + + self.enc_to_boundary = enc_to_boundary # mapping of int to Boundary objects + self.boundary_to_enc = {value: key for key, value in enc_to_boundary.items()} + self.A = A + self.slack = slack_variables + self.nonslack = nonslack_variables + self.all_var = nonslack_variables + slack_variables + + self.slack_set = set(slack_variables) + + self.is_sat = True # While True, all constraints asserted so far are satisfiable + self.result = None # always one of: (True, assignment), (False, conflict clause), None + + @staticmethod + def from_encoded_cnf(encoded_cnf, testing_mode=False): + """ + Creates an LRASolver from an EncodedCNF object + and a list of conflict clauses for propositions + that can be simplified to True or False. + + Parameters + ========== + + encoded_cnf : EncodedCNF + + testing_mode : bool + Setting testing_mode to True enables some slow assert statements + and sorting to reduce nonterministic behavior. + + Returns + ======= + + (lra, conflicts) + + lra : LRASolver + + conflicts : list + Contains a one-literal conflict clause for each proposition + that can be simplified to True or False. + + Example + ======= + + >>> from sympy.core.relational import Eq + >>> from sympy.assumptions.cnf import CNF, EncodedCNF + >>> from sympy.assumptions.ask import Q + >>> from sympy.logic.algorithms.lra_theory import LRASolver + >>> from sympy.abc import x, y, z + >>> phi = (x >= 0) & ((x + y <= 2) | (x + 2 * y - z >= 6)) + >>> phi = phi & (Eq(x + y, 2) | (x + 2 * y - z > 4)) + >>> phi = phi & Q.gt(2, 1) + >>> cnf = CNF.from_prop(phi) + >>> enc = EncodedCNF() + >>> enc.from_cnf(cnf) + >>> lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True) + >>> lra #doctest: +SKIP + + >>> conflicts #doctest: +SKIP + [[4]] + """ + # This function has three main jobs: + # - raise errors if the input formula is not handled + # - preprocesses the formula into a matrix and single variable constraints + # - create one-literal conflict clauses from predicates that are always True + # or always False such as Q.gt(3, 2) + # + # See the preprocessing section of "A Fast Linear-Arithmetic Solver for DPLL(T)" + # for an explanation of how the formula is converted into a matrix + # and a set of single variable constraints. + + encoding = {} # maps int to boundary + A = [] + + basic = [] + s_count = 0 + s_subs = {} + nonbasic = [] + + if testing_mode: + # sort to reduce nondeterminism + encoded_cnf_items = sorted(encoded_cnf.encoding.items(), key=lambda x: str(x)) + else: + encoded_cnf_items = encoded_cnf.encoding.items() + + empty_var = Dummy() + var_to_lra_var = {} + conflicts = [] + + for prop, enc in encoded_cnf_items: + if isinstance(prop, Predicate): + prop = prop(empty_var) + if not isinstance(prop, AppliedPredicate): + if prop == True: + conflicts.append([enc]) + continue + if prop == False: + conflicts.append([-enc]) + continue + + raise ValueError(f"Unhandled Predicate: {prop}") + + assert prop.function in ALLOWED_PRED + if prop.lhs == S.NaN or prop.rhs == S.NaN: + raise ValueError(f"{prop} contains nan") + if prop.lhs.is_imaginary or prop.rhs.is_imaginary: + raise UnhandledInput(f"{prop} contains an imaginary component") + if prop.lhs == oo or prop.rhs == oo: + raise UnhandledInput(f"{prop} contains infinity") + + prop = _eval_binrel(prop) # simplify variable-less quantities to True / False if possible + if prop == True: + conflicts.append([enc]) + continue + elif prop == False: + conflicts.append([-enc]) + continue + elif prop is None: + raise UnhandledInput(f"{prop} could not be simplified") + + expr = prop.lhs - prop.rhs + if prop.function in [Q.ge, Q.gt]: + expr = -expr + + # expr should be less than (or equal to) 0 + # otherwise prop is False + if prop.function in [Q.le, Q.ge]: + bool = (expr <= 0) + elif prop.function in [Q.lt, Q.gt]: + bool = (expr < 0) + else: + assert prop.function == Q.eq + bool = Eq(expr, 0) + + if bool == True: + conflicts.append([enc]) + continue + elif bool == False: + conflicts.append([-enc]) + continue + + + vars, const = _sep_const_terms(expr) # example: (2x + 3y + 2) --> (2x + 3y), (2) + vars, var_coeff = _sep_const_coeff(vars) # examples: (2x) --> (x, 2); (2x + 3y) --> (2x + 3y), (1) + const = const / var_coeff + + terms = _list_terms(vars) # example: (2x + 3y) --> [2x, 3y] + for term in terms: + term, _ = _sep_const_coeff(term) + assert len(term.free_symbols) > 0 + if term not in var_to_lra_var: + var_to_lra_var[term] = LRAVariable(term) + nonbasic.append(term) + + if len(terms) > 1: + if vars not in s_subs: + s_count += 1 + d = Dummy(f"s{s_count}") + var_to_lra_var[d] = LRAVariable(d) + basic.append(d) + s_subs[vars] = d + A.append(vars - d) + var = s_subs[vars] + else: + var = terms[0] + + assert var_coeff != 0 + + equality = prop.function == Q.eq + upper = var_coeff > 0 if not equality else None + strict = prop.function in [Q.gt, Q.lt] + b = Boundary(var_to_lra_var[var], -const, upper, equality, strict) + encoding[enc] = b + + fs = [v.free_symbols for v in nonbasic + basic] + assert all(len(syms) > 0 for syms in fs) + fs_count = sum(len(syms) for syms in fs) + if len(fs) > 0 and len(set.union(*fs)) < fs_count: + raise UnhandledInput("Nonlinearity is not handled") + + A, _ = linear_eq_to_matrix(A, nonbasic + basic) + nonbasic = [var_to_lra_var[nb] for nb in nonbasic] + basic = [var_to_lra_var[b] for b in basic] + for idx, var in enumerate(nonbasic + basic): + var.col_idx = idx + + return LRASolver(A, basic, nonbasic, encoding, s_subs, testing_mode), conflicts + + def reset_bounds(self): + """ + Resets the state of the LRASolver to before + anything was asserted. + """ + self.result = None + for var in self.all_var: + var.lower = LRARational(-float("inf"), 0) + var.lower_from_eq = False + var.lower_from_neg = False + var.upper = LRARational(float("inf"), 0) + var.upper_from_eq= False + var.lower_from_neg = False + var.assign = LRARational(0, 0) + + def assert_lit(self, enc_constraint): + """ + Assert a literal representing a constraint + and update the internal state accordingly. + + Note that due to peculiarities of this implementation + asserting ~(x > 0) will assert (x <= 0) but asserting + ~Eq(x, 0) will not do anything. + + Parameters + ========== + + enc_constraint : int + A mapping of encodings to constraints + can be found in `self.enc_to_boundary`. + + Returns + ======= + + None or (False, explanation) + + explanation : set of ints + A conflict clause that "explains" why + the literals asserted so far are unsatisfiable. + """ + if abs(enc_constraint) not in self.enc_to_boundary: + return None + + if not HANDLE_NEGATION and enc_constraint < 0: + return None + + boundary = self.enc_to_boundary[abs(enc_constraint)] + sym, c, negated = boundary.var, boundary.bound, enc_constraint < 0 + + if boundary.equality and negated: + return None # negated equality is not handled and should only appear in conflict clauses + + upper = boundary.upper != negated + if boundary.strict != negated: + delta = -1 if upper else 1 + c = LRARational(c, delta) + else: + c = LRARational(c, 0) + + if boundary.equality: + res1 = self._assert_lower(sym, c, from_equality=True, from_neg=negated) + if res1 and res1[0] == False: + res = res1 + else: + res2 = self._assert_upper(sym, c, from_equality=True, from_neg=negated) + res = res2 + elif upper: + res = self._assert_upper(sym, c, from_neg=negated) + else: + res = self._assert_lower(sym, c, from_neg=negated) + + if self.is_sat and sym not in self.slack_set: + self.is_sat = res is None + else: + self.is_sat = False + + return res + + def _assert_upper(self, xi, ci, from_equality=False, from_neg=False): + """ + Adjusts the upper bound on variable xi if the new upper bound is + more limiting. The assignment of variable xi is adjusted to be + within the new bound if needed. + + Also calls `self._update` to update the assignment for slack variables + to keep all equalities satisfied. + """ + if self.result: + assert self.result[0] != False + self.result = None + if ci >= xi.upper: + return None + if ci < xi.lower: + assert (xi.lower[1] >= 0) is True + assert (ci[1] <= 0) is True + + lit1, neg1 = Boundary.from_lower(xi) + + lit2 = Boundary(var=xi, const=ci[0], strict=ci[1] != 0, upper=True, equality=from_equality) + if from_neg: + lit2 = lit2.get_negated() + neg2 = -1 if from_neg else 1 + + conflict = [-neg1*self.boundary_to_enc[lit1], -neg2*self.boundary_to_enc[lit2]] + self.result = False, conflict + return self.result + xi.upper = ci + xi.upper_from_eq = from_equality + xi.upper_from_neg = from_neg + if xi in self.nonslack and xi.assign > ci: + self._update(xi, ci) + + if self.run_checks and all(v.assign[0] != float("inf") and v.assign[0] != -float("inf") + for v in self.all_var): + M = self.A + X = Matrix([v.assign[0] for v in self.all_var]) + assert all(abs(val) < 10 ** (-10) for val in M * X) + + return None + + def _assert_lower(self, xi, ci, from_equality=False, from_neg=False): + """ + Adjusts the lower bound on variable xi if the new lower bound is + more limiting. The assignment of variable xi is adjusted to be + within the new bound if needed. + + Also calls `self._update` to update the assignment for slack variables + to keep all equalities satisfied. + """ + if self.result: + assert self.result[0] != False + self.result = None + if ci <= xi.lower: + return None + if ci > xi.upper: + assert (xi.upper[1] <= 0) is True + assert (ci[1] >= 0) is True + + lit1, neg1 = Boundary.from_upper(xi) + + lit2 = Boundary(var=xi, const=ci[0], strict=ci[1] != 0, upper=False, equality=from_equality) + if from_neg: + lit2 = lit2.get_negated() + neg2 = -1 if from_neg else 1 + + conflict = [-neg1*self.boundary_to_enc[lit1],-neg2*self.boundary_to_enc[lit2]] + self.result = False, conflict + return self.result + xi.lower = ci + xi.lower_from_eq = from_equality + xi.lower_from_neg = from_neg + if xi in self.nonslack and xi.assign < ci: + self._update(xi, ci) + + if self.run_checks and all(v.assign[0] != float("inf") and v.assign[0] != -float("inf") + for v in self.all_var): + M = self.A + X = Matrix([v.assign[0] for v in self.all_var]) + assert all(abs(val) < 10 ** (-10) for val in M * X) + + return None + + def _update(self, xi, v): + """ + Updates all slack variables that have equations that contain + variable xi so that they stay satisfied given xi is equal to v. + """ + i = xi.col_idx + for j, b in enumerate(self.slack): + aji = self.A[j, i] + b.assign = b.assign + (v - xi.assign)*aji + xi.assign = v + + def check(self): + """ + Searches for an assignment that satisfies all constraints + or determines that no such assignment exists and gives + a minimal conflict clause that "explains" why the + constraints are unsatisfiable. + + Returns + ======= + + (True, assignment) or (False, explanation) + + assignment : dict of LRAVariables to values + Assigned values are tuples that represent a rational number + plus some infinatesimal delta. + + explanation : set of ints + """ + if self.is_sat: + return True, {var: var.assign for var in self.all_var} + if self.result: + return self.result + + from sympy.matrices.dense import Matrix + M = self.A.copy() + basic = {s: i for i, s in enumerate(self.slack)} # contains the row index associated with each basic variable + nonbasic = set(self.nonslack) + while True: + if self.run_checks: + # nonbasic variables must always be within bounds + assert all(((nb.assign >= nb.lower) == True) and ((nb.assign <= nb.upper) == True) for nb in nonbasic) + + # assignments for x must always satisfy Ax = 0 + # probably have to turn this off when dealing with strict ineq + if all(v.assign[0] != float("inf") and v.assign[0] != -float("inf") + for v in self.all_var): + X = Matrix([v.assign[0] for v in self.all_var]) + assert all(abs(val) < 10**(-10) for val in M*X) + + # check upper and lower match this format: + # x <= rat + delta iff x < rat + # x >= rat - delta iff x > rat + # this wouldn't make sense: + # x <= rat - delta + # x >= rat + delta + assert all(x.upper[1] <= 0 for x in self.all_var) + assert all(x.lower[1] >= 0 for x in self.all_var) + + cand = [b for b in basic if b.assign < b.lower or b.assign > b.upper] + + if len(cand) == 0: + return True, {var: var.assign for var in self.all_var} + + xi = min(cand, key=lambda v: v.col_idx) # Bland's rule + i = basic[xi] + + if xi.assign < xi.lower: + cand = [nb for nb in nonbasic + if (M[i, nb.col_idx] > 0 and nb.assign < nb.upper) + or (M[i, nb.col_idx] < 0 and nb.assign > nb.lower)] + if len(cand) == 0: + N_plus = [nb for nb in nonbasic if M[i, nb.col_idx] > 0] + N_minus = [nb for nb in nonbasic if M[i, nb.col_idx] < 0] + + conflict = [] + conflict += [Boundary.from_upper(nb) for nb in N_plus] + conflict += [Boundary.from_lower(nb) for nb in N_minus] + conflict.append(Boundary.from_lower(xi)) + conflict = [-neg*self.boundary_to_enc[c] for c, neg in conflict] + return False, conflict + xj = min(cand, key=str) + M = self._pivot_and_update(M, basic, nonbasic, xi, xj, xi.lower) + + if xi.assign > xi.upper: + cand = [nb for nb in nonbasic + if (M[i, nb.col_idx] < 0 and nb.assign < nb.upper) + or (M[i, nb.col_idx] > 0 and nb.assign > nb.lower)] + + if len(cand) == 0: + N_plus = [nb for nb in nonbasic if M[i, nb.col_idx] > 0] + N_minus = [nb for nb in nonbasic if M[i, nb.col_idx] < 0] + + conflict = [] + conflict += [Boundary.from_upper(nb) for nb in N_minus] + conflict += [Boundary.from_lower(nb) for nb in N_plus] + conflict.append(Boundary.from_upper(xi)) + + conflict = [-neg*self.boundary_to_enc[c] for c, neg in conflict] + return False, conflict + xj = min(cand, key=lambda v: v.col_idx) + M = self._pivot_and_update(M, basic, nonbasic, xi, xj, xi.upper) + + def _pivot_and_update(self, M, basic, nonbasic, xi, xj, v): + """ + Pivots basic variable xi with nonbasic variable xj, + and sets value of xi to v and adjusts the values of all basic variables + to keep equations satisfied. + """ + i, j = basic[xi], xj.col_idx + assert M[i, j] != 0 + theta = (v - xi.assign)*(1/M[i, j]) + xi.assign = v + xj.assign = xj.assign + theta + for xk in basic: + if xk != xi: + k = basic[xk] + akj = M[k, j] + xk.assign = xk.assign + theta*akj + # pivot + basic[xj] = basic[xi] + del basic[xi] + nonbasic.add(xi) + nonbasic.remove(xj) + return self._pivot(M, i, j) + + @staticmethod + def _pivot(M, i, j): + """ + Performs a pivot operation about entry i, j of M by performing + a series of row operations on a copy of M and returning the result. + The original M is left unmodified. + + Conceptually, M represents a system of equations and pivoting + can be thought of as rearranging equation i to be in terms of + variable j and then substituting in the rest of the equations + to get rid of other occurances of variable j. + + Example + ======= + + >>> from sympy.matrices.dense import Matrix + >>> from sympy.logic.algorithms.lra_theory import LRASolver + >>> from sympy import var + >>> Matrix(3, 3, var('a:i')) + Matrix([ + [a, b, c], + [d, e, f], + [g, h, i]]) + + This matrix is equivalent to: + 0 = a*x + b*y + c*z + 0 = d*x + e*y + f*z + 0 = g*x + h*y + i*z + + >>> LRASolver._pivot(_, 1, 0) + Matrix([ + [ 0, -a*e/d + b, -a*f/d + c], + [-1, -e/d, -f/d], + [ 0, h - e*g/d, i - f*g/d]]) + + We rearrange equation 1 in terms of variable 0 (x) + and substitute to remove x from the other equations. + + 0 = 0 + (-a*e/d + b)*y + (-a*f/d + c)*z + 0 = -x + (-e/d)*y + (-f/d)*z + 0 = 0 + (h - e*g/d)*y + (i - f*g/d)*z + """ + _, _, Mij = M[i, :], M[:, j], M[i, j] + if Mij == 0: + raise ZeroDivisionError("Tried to pivot about zero-valued entry.") + A = M.copy() + A[i, :] = -A[i, :]/Mij + for row in range(M.shape[0]): + if row != i: + A[row, :] = A[row, :] + A[row, j] * A[i, :] + + return A + + +def _sep_const_coeff(expr): + """ + Example + ======= + + >>> from sympy.logic.algorithms.lra_theory import _sep_const_coeff + >>> from sympy.abc import x, y + >>> _sep_const_coeff(2*x) + (x, 2) + >>> _sep_const_coeff(2*x + 3*y) + (2*x + 3*y, 1) + """ + if isinstance(expr, Add): + return expr, sympify(1) + + if isinstance(expr, Mul): + coeffs = expr.args + else: + coeffs = [expr] + + var, const = [], [] + for c in coeffs: + c = sympify(c) + if len(c.free_symbols)==0: + const.append(c) + else: + var.append(c) + return Mul(*var), Mul(*const) + + +def _list_terms(expr): + if not isinstance(expr, Add): + return [expr] + + return expr.args + + +def _sep_const_terms(expr): + """ + Example + ======= + + >>> from sympy.logic.algorithms.lra_theory import _sep_const_terms + >>> from sympy.abc import x, y + >>> _sep_const_terms(2*x + 3*y + 2) + (2*x + 3*y, 2) + """ + if isinstance(expr, Add): + terms = expr.args + else: + terms = [expr] + + var, const = [], [] + for t in terms: + if len(t.free_symbols) == 0: + const.append(t) + else: + var.append(t) + return sum(var), sum(const) + + +def _eval_binrel(binrel): + """ + Simplify binary relation to True / False if possible. + """ + if not (len(binrel.lhs.free_symbols) == 0 and len(binrel.rhs.free_symbols) == 0): + return binrel + if binrel.function == Q.lt: + res = binrel.lhs < binrel.rhs + elif binrel.function == Q.gt: + res = binrel.lhs > binrel.rhs + elif binrel.function == Q.le: + res = binrel.lhs <= binrel.rhs + elif binrel.function == Q.ge: + res = binrel.lhs >= binrel.rhs + elif binrel.function == Q.eq: + res = Eq(binrel.lhs, binrel.rhs) + elif binrel.function == Q.ne: + res = Ne(binrel.lhs, binrel.rhs) + + if res == True or res == False: + return res + else: + return None + + +class Boundary: + """ + Represents an upper or lower bound or an equality between a symbol + and some constant. + """ + def __init__(self, var, const, upper, equality, strict=None): + if not equality in [True, False]: + assert equality in [True, False] + + + self.var = var + if isinstance(const, tuple): + s = const[1] != 0 + if strict: + assert s == strict + self.bound = const[0] + self.strict = s + else: + self.bound = const + self.strict = strict + self.upper = upper if not equality else None + self.equality = equality + self.strict = strict + assert self.strict is not None + + @staticmethod + def from_upper(var): + neg = -1 if var.upper_from_neg else 1 + b = Boundary(var, var.upper[0], True, var.upper_from_eq, var.upper[1] != 0) + if neg < 0: + b = b.get_negated() + return b, neg + + @staticmethod + def from_lower(var): + neg = -1 if var.lower_from_neg else 1 + b = Boundary(var, var.lower[0], False, var.lower_from_eq, var.lower[1] != 0) + if neg < 0: + b = b.get_negated() + return b, neg + + def get_negated(self): + return Boundary(self.var, self.bound, not self.upper, self.equality, not self.strict) + + def get_inequality(self): + if self.equality: + return Eq(self.var.var, self.bound) + elif self.upper and self.strict: + return self.var.var < self.bound + elif not self.upper and self.strict: + return self.var.var > self.bound + elif self.upper: + return self.var.var <= self.bound + else: + return self.var.var >= self.bound + + def __repr__(self): + return repr("Boundary(" + repr(self.get_inequality()) + ")") + + def __eq__(self, other): + other = (other.var, other.bound, other.strict, other.upper, other.equality) + return (self.var, self.bound, self.strict, self.upper, self.equality) == other + + def __hash__(self): + return hash((self.var, self.bound, self.strict, self.upper, self.equality)) + + +class LRARational(): + """ + Represents a rational plus or minus some amount + of arbitrary small deltas. + """ + def __init__(self, rational, delta): + self.value = (rational, delta) + + def __lt__(self, other): + return self.value < other.value + + def __le__(self, other): + return self.value <= other.value + + def __eq__(self, other): + return self.value == other.value + + def __add__(self, other): + return LRARational(self.value[0] + other.value[0], self.value[1] + other.value[1]) + + def __sub__(self, other): + return LRARational(self.value[0] - other.value[0], self.value[1] - other.value[1]) + + def __mul__(self, other): + assert not isinstance(other, LRARational) + return LRARational(self.value[0] * other, self.value[1] * other) + + def __getitem__(self, index): + return self.value[index] + + def __repr__(self): + return repr(self.value) + + +class LRAVariable(): + """ + Object to keep track of upper and lower bounds + on `self.var`. + """ + def __init__(self, var): + self.upper = LRARational(float("inf"), 0) + self.upper_from_eq = False + self.upper_from_neg = False + self.lower = LRARational(-float("inf"), 0) + self.lower_from_eq = False + self.lower_from_neg = False + self.assign = LRARational(0,0) + self.var = var + self.col_idx = None + + def __repr__(self): + return repr(self.var) + + def __eq__(self, other): + if not isinstance(other, LRAVariable): + return False + return other.var == self.var + + def __hash__(self): + return hash(self.var) diff --git a/phivenv/Lib/site-packages/sympy/logic/algorithms/minisat22_wrapper.py b/phivenv/Lib/site-packages/sympy/logic/algorithms/minisat22_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..1d5c1f8f14f04309f7cb8197cc05d01a3c108545 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/logic/algorithms/minisat22_wrapper.py @@ -0,0 +1,46 @@ +from sympy.assumptions.cnf import EncodedCNF + +def minisat22_satisfiable(expr, all_models=False, minimal=False): + + if not isinstance(expr, EncodedCNF): + exprs = EncodedCNF() + exprs.add_prop(expr) + expr = exprs + + from pysat.solvers import Minisat22 + + # Return UNSAT when False (encoded as 0) is present in the CNF + if {0} in expr.data: + if all_models: + return (f for f in [False]) + return False + + r = Minisat22(expr.data) + + if minimal: + r.set_phases([-(i+1) for i in range(r.nof_vars())]) + + if not r.solve(): + return False + + if not all_models: + return {expr.symbols[abs(lit) - 1]: lit > 0 for lit in r.get_model()} + + else: + # Make solutions SymPy compatible by creating a generator + def _gen(results): + satisfiable = False + while results.solve(): + sol = results.get_model() + yield {expr.symbols[abs(lit) - 1]: lit > 0 for lit in sol} + if minimal: + results.add_clause([-i for i in sol if i>0]) + else: + results.add_clause([-i for i in sol]) + satisfiable = True + if not satisfiable: + yield False + raise StopIteration + + + return _gen(r) diff --git a/phivenv/Lib/site-packages/sympy/logic/algorithms/pycosat_wrapper.py b/phivenv/Lib/site-packages/sympy/logic/algorithms/pycosat_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..5ff498b7e3f6b73d95e9b949598ef32df4ecf226 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/logic/algorithms/pycosat_wrapper.py @@ -0,0 +1,41 @@ +from sympy.assumptions.cnf import EncodedCNF + + +def pycosat_satisfiable(expr, all_models=False): + import pycosat + if not isinstance(expr, EncodedCNF): + exprs = EncodedCNF() + exprs.add_prop(expr) + expr = exprs + + # Return UNSAT when False (encoded as 0) is present in the CNF + if {0} in expr.data: + if all_models: + return (f for f in [False]) + return False + + if not all_models: + r = pycosat.solve(expr.data) + result = (r != "UNSAT") + if not result: + return result + return {expr.symbols[abs(lit) - 1]: lit > 0 for lit in r} + else: + r = pycosat.itersolve(expr.data) + result = (r != "UNSAT") + if not result: + return result + + # Make solutions SymPy compatible by creating a generator + def _gen(results): + satisfiable = False + try: + while True: + sol = next(results) + yield {expr.symbols[abs(lit) - 1]: lit > 0 for lit in sol} + satisfiable = True + except StopIteration: + if not satisfiable: + yield False + + return _gen(r) diff --git a/phivenv/Lib/site-packages/sympy/logic/algorithms/z3_wrapper.py b/phivenv/Lib/site-packages/sympy/logic/algorithms/z3_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..fe44f713a2edfd5286c0f81b737212146766b11b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/logic/algorithms/z3_wrapper.py @@ -0,0 +1,115 @@ +from sympy.printing.smtlib import smtlib_code +from sympy.assumptions.assume import AppliedPredicate +from sympy.assumptions.cnf import EncodedCNF +from sympy.assumptions.ask import Q + +from sympy.core import Add, Mul +from sympy.core.relational import Equality, LessThan, GreaterThan, StrictLessThan, StrictGreaterThan +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import Pow +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.logic.boolalg import And, Or, Xor, Implies +from sympy.logic.boolalg import Not, ITE +from sympy.assumptions.relation.equality import StrictGreaterThanPredicate, StrictLessThanPredicate, GreaterThanPredicate, LessThanPredicate, EqualityPredicate +from sympy.external import import_module + +def z3_satisfiable(expr, all_models=False): + if not isinstance(expr, EncodedCNF): + exprs = EncodedCNF() + exprs.add_prop(expr) + expr = exprs + + z3 = import_module("z3") + if z3 is None: + raise ImportError("z3 is not installed") + + s = encoded_cnf_to_z3_solver(expr, z3) + + res = str(s.check()) + if res == "unsat": + return False + elif res == "sat": + return z3_model_to_sympy_model(s.model(), expr) + else: + return None + + +def z3_model_to_sympy_model(z3_model, enc_cnf): + rev_enc = {value : key for key, value in enc_cnf.encoding.items()} + return {rev_enc[int(var.name()[1:])] : bool(z3_model[var]) for var in z3_model} + + +def clause_to_assertion(clause): + clause_strings = [f"d{abs(lit)}" if lit > 0 else f"(not d{abs(lit)})" for lit in clause] + return "(assert (or " + " ".join(clause_strings) + "))" + + +def encoded_cnf_to_z3_solver(enc_cnf, z3): + def dummify_bool(pred): + return False + assert isinstance(pred, AppliedPredicate) + + if pred.function in [Q.positive, Q.negative, Q.zero]: + return pred + else: + return False + + s = z3.Solver() + + declarations = [f"(declare-const d{var} Bool)" for var in enc_cnf.variables] + assertions = [clause_to_assertion(clause) for clause in enc_cnf.data] + + symbols = set() + for pred, enc in enc_cnf.encoding.items(): + if not isinstance(pred, AppliedPredicate): + continue + if pred.function not in (Q.gt, Q.lt, Q.ge, Q.le, Q.ne, Q.eq, Q.positive, Q.negative, Q.extended_negative, Q.extended_positive, Q.zero, Q.nonzero, Q.nonnegative, Q.nonpositive, Q.extended_nonzero, Q.extended_nonnegative, Q.extended_nonpositive): + continue + + pred_str = smtlib_code(pred, auto_declare=False, auto_assert=False, known_functions=known_functions) + + symbols |= pred.free_symbols + pred = pred_str + clause = f"(implies d{enc} {pred})" + assertion = "(assert " + clause + ")" + assertions.append(assertion) + + for sym in symbols: + declarations.append(f"(declare-const {sym} Real)") + + declarations = "\n".join(declarations) + assertions = "\n".join(assertions) + s.from_string(declarations) + s.from_string(assertions) + + return s + + +known_functions = { + Add: '+', + Mul: '*', + + Equality: '=', + LessThan: '<=', + GreaterThan: '>=', + StrictLessThan: '<', + StrictGreaterThan: '>', + + EqualityPredicate(): '=', + LessThanPredicate(): '<=', + GreaterThanPredicate(): '>=', + StrictLessThanPredicate(): '<', + StrictGreaterThanPredicate(): '>', + + Abs: 'abs', + Min: 'min', + Max: 'max', + Pow: '^', + + And: 'and', + Or: 'or', + Xor: 'xor', + Not: 'not', + ITE: 'ite', + Implies: '=>', + } diff --git a/phivenv/Lib/site-packages/sympy/logic/tests/__init__.py b/phivenv/Lib/site-packages/sympy/logic/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70d155a42c52874fa14edd8ce1e6d19d571c583d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/test_boolalg.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/test_boolalg.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2534fe730c77f301131ba24150d30903833355fb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/test_boolalg.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/test_dimacs.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/test_dimacs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25b6a711fdccfe363467846c609dadec38fbccb8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/test_dimacs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/test_inference.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/test_inference.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35296f510c272573c13b8a1aa49eac4b26b310ba Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/test_inference.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/test_lra_theory.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/test_lra_theory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0bf025f3c9132773d40463b017681b6ca4d39f3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/tests/__pycache__/test_lra_theory.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/tests/test_boolalg.py b/phivenv/Lib/site-packages/sympy/logic/tests/test_boolalg.py new file mode 100644 index 0000000000000000000000000000000000000000..88cdd647fdcc723faee328f71df96030841a3edb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/logic/tests/test_boolalg.py @@ -0,0 +1,1367 @@ +from sympy.assumptions.ask import Q +from sympy.assumptions.refine import refine +from sympy.core.numbers import oo +from sympy.core.relational import Equality, Eq, Ne +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions import Piecewise +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.sets.sets import Interval, Union +from sympy.sets.contains import Contains +from sympy.simplify.simplify import simplify +from sympy.logic.boolalg import ( + And, Boolean, Equivalent, ITE, Implies, Nand, Nor, Not, Or, + POSform, SOPform, Xor, Xnor, conjuncts, disjuncts, + distribute_or_over_and, distribute_and_over_or, + eliminate_implications, is_nnf, is_cnf, is_dnf, simplify_logic, + to_nnf, to_cnf, to_dnf, to_int_repr, bool_map, true, false, + BooleanAtom, is_literal, term_to_integer, + truth_table, as_Boolean, to_anf, is_anf, distribute_xor_over_and, + anf_coeffs, ANFform, bool_minterm, bool_maxterm, bool_monomial, + _check_pair, _convert_to_varsSOP, _convert_to_varsPOS, Exclusive, + gateinputcount) +from sympy.assumptions.cnf import CNF + +from sympy.testing.pytest import raises, XFAIL, slow + +from itertools import combinations, permutations, product + +A, B, C, D = symbols('A:D') +a, b, c, d, e, w, x, y, z = symbols('a:e w:z') + + +def test_overloading(): + """Test that |, & are overloaded as expected""" + + assert A & B == And(A, B) + assert A | B == Or(A, B) + assert (A & B) | C == Or(And(A, B), C) + assert A >> B == Implies(A, B) + assert A << B == Implies(B, A) + assert ~A == Not(A) + assert A ^ B == Xor(A, B) + + +def test_And(): + assert And() is true + assert And(A) == A + assert And(True) is true + assert And(False) is false + assert And(True, True) is true + assert And(True, False) is false + assert And(False, False) is false + assert And(True, A) == A + assert And(False, A) is false + assert And(True, True, True) is true + assert And(True, True, A) == A + assert And(True, False, A) is false + assert And(1, A) == A + raises(TypeError, lambda: And(2, A)) + assert And(A < 1, A >= 1) is false + e = A > 1 + assert And(e, e.canonical) == e.canonical + g, l, ge, le = A > B, B < A, A >= B, B <= A + assert And(g, l, ge, le) == And(ge, g) + assert {And(*i) for i in permutations((l, g, le, ge))} == {And(ge, g)} + assert And(And(Eq(a, 0), Eq(b, 0)), And(Ne(a, 0), Eq(c, 0))) is false + + +def test_Or(): + assert Or() is false + assert Or(A) == A + assert Or(True) is true + assert Or(False) is false + assert Or(True, True) is true + assert Or(True, False) is true + assert Or(False, False) is false + assert Or(True, A) is true + assert Or(False, A) == A + assert Or(True, False, False) is true + assert Or(True, False, A) is true + assert Or(False, False, A) == A + assert Or(1, A) is true + raises(TypeError, lambda: Or(2, A)) + assert Or(A < 1, A >= 1) is true + e = A > 1 + assert Or(e, e.canonical) == e + g, l, ge, le = A > B, B < A, A >= B, B <= A + assert Or(g, l, ge, le) == Or(g, ge) + + +def test_Xor(): + assert Xor() is false + assert Xor(A) == A + assert Xor(A, A) is false + assert Xor(True, A, A) is true + assert Xor(A, A, A, A, A) == A + assert Xor(True, False, False, A, B) == ~Xor(A, B) + assert Xor(True) is true + assert Xor(False) is false + assert Xor(True, True) is false + assert Xor(True, False) is true + assert Xor(False, False) is false + assert Xor(True, A) == ~A + assert Xor(False, A) == A + assert Xor(True, False, False) is true + assert Xor(True, False, A) == ~A + assert Xor(False, False, A) == A + assert isinstance(Xor(A, B), Xor) + assert Xor(A, B, Xor(C, D)) == Xor(A, B, C, D) + assert Xor(A, B, Xor(B, C)) == Xor(A, C) + assert Xor(A < 1, A >= 1, B) == Xor(0, 1, B) == Xor(1, 0, B) + e = A > 1 + assert Xor(e, e.canonical) == Xor(0, 0) == Xor(1, 1) + + +def test_rewrite_as_And(): + expr = x ^ y + assert expr.rewrite(And) == (x | y) & (~x | ~y) + + +def test_rewrite_as_Or(): + expr = x ^ y + assert expr.rewrite(Or) == (x & ~y) | (y & ~x) + + +def test_rewrite_as_Nand(): + expr = (y & z) | (z & ~w) + assert expr.rewrite(Nand) == ~(~(y & z) & ~(z & ~w)) + + +def test_rewrite_as_Nor(): + expr = z & (y | ~w) + assert expr.rewrite(Nor) == ~(~z | ~(y | ~w)) + + +def test_Not(): + raises(TypeError, lambda: Not(True, False)) + assert Not(True) is false + assert Not(False) is true + assert Not(0) is true + assert Not(1) is false + assert Not(2) is false + + +def test_Nand(): + assert Nand() is false + assert Nand(A) == ~A + assert Nand(True) is false + assert Nand(False) is true + assert Nand(True, True) is false + assert Nand(True, False) is true + assert Nand(False, False) is true + assert Nand(True, A) == ~A + assert Nand(False, A) is true + assert Nand(True, True, True) is false + assert Nand(True, True, A) == ~A + assert Nand(True, False, A) is true + + +def test_Nor(): + assert Nor() is true + assert Nor(A) == ~A + assert Nor(True) is false + assert Nor(False) is true + assert Nor(True, True) is false + assert Nor(True, False) is false + assert Nor(False, False) is true + assert Nor(True, A) is false + assert Nor(False, A) == ~A + assert Nor(True, True, True) is false + assert Nor(True, True, A) is false + assert Nor(True, False, A) is false + + +def test_Xnor(): + assert Xnor() is true + assert Xnor(A) == ~A + assert Xnor(A, A) is true + assert Xnor(True, A, A) is false + assert Xnor(A, A, A, A, A) == ~A + assert Xnor(True) is false + assert Xnor(False) is true + assert Xnor(True, True) is true + assert Xnor(True, False) is false + assert Xnor(False, False) is true + assert Xnor(True, A) == A + assert Xnor(False, A) == ~A + assert Xnor(True, False, False) is false + assert Xnor(True, False, A) == A + assert Xnor(False, False, A) == ~A + + +def test_Implies(): + raises(ValueError, lambda: Implies(A, B, C)) + assert Implies(True, True) is true + assert Implies(True, False) is false + assert Implies(False, True) is true + assert Implies(False, False) is true + assert Implies(0, A) is true + assert Implies(1, 1) is true + assert Implies(1, 0) is false + assert A >> B == B << A + assert (A < 1) >> (A >= 1) == (A >= 1) + assert (A < 1) >> (S.One > A) is true + assert A >> A is true + + +def test_Equivalent(): + assert Equivalent(A, B) == Equivalent(B, A) == Equivalent(A, B, A) + assert Equivalent() is true + assert Equivalent(A, A) == Equivalent(A) is true + assert Equivalent(True, True) == Equivalent(False, False) is true + assert Equivalent(True, False) == Equivalent(False, True) is false + assert Equivalent(A, True) == A + assert Equivalent(A, False) == Not(A) + assert Equivalent(A, B, True) == A & B + assert Equivalent(A, B, False) == ~A & ~B + assert Equivalent(1, A) == A + assert Equivalent(0, A) == Not(A) + assert Equivalent(A, Equivalent(B, C)) != Equivalent(Equivalent(A, B), C) + assert Equivalent(A < 1, A >= 1) is false + assert Equivalent(A < 1, A >= 1, 0) is false + assert Equivalent(A < 1, A >= 1, 1) is false + assert Equivalent(A < 1, S.One > A) == Equivalent(1, 1) == Equivalent(0, 0) + assert Equivalent(Equality(A, B), Equality(B, A)) is true + + +def test_Exclusive(): + assert Exclusive(False, False, False) is true + assert Exclusive(True, False, False) is true + assert Exclusive(True, True, False) is false + assert Exclusive(True, True, True) is false + + +def test_equals(): + assert Not(Or(A, B)).equals(And(Not(A), Not(B))) is True + assert Equivalent(A, B).equals((A >> B) & (B >> A)) is True + assert ((A | ~B) & (~A | B)).equals((~A & ~B) | (A & B)) is True + assert (A >> B).equals(~A >> ~B) is False + assert (A >> (B >> A)).equals(A >> (C >> A)) is False + raises(NotImplementedError, lambda: (A & B).equals(A > B)) + + +def test_simplification_boolalg(): + """ + Test working of simplification methods. + """ + set1 = [[0, 0, 1], [0, 1, 1], [1, 0, 0], [1, 1, 0]] + set2 = [[0, 0, 0], [0, 1, 0], [1, 0, 1], [1, 1, 1]] + assert SOPform([x, y, z], set1) == Or(And(Not(x), z), And(Not(z), x)) + assert Not(SOPform([x, y, z], set2)) == \ + Not(Or(And(Not(x), Not(z)), And(x, z))) + assert POSform([x, y, z], set1 + set2) is true + assert SOPform([x, y, z], set1 + set2) is true + assert SOPform([Dummy(), Dummy(), Dummy()], set1 + set2) is true + + minterms = [[0, 0, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1], [1, 0, 1, 1], + [1, 1, 1, 1]] + dontcares = [[0, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 1]] + assert ( + SOPform([w, x, y, z], minterms, dontcares) == + Or(And(y, z), And(Not(w), Not(x)))) + assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z) + + minterms = [1, 3, 7, 11, 15] + dontcares = [0, 2, 5] + assert ( + SOPform([w, x, y, z], minterms, dontcares) == + Or(And(y, z), And(Not(w), Not(x)))) + assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z) + + minterms = [1, [0, 0, 1, 1], 7, [1, 0, 1, 1], + [1, 1, 1, 1]] + dontcares = [0, [0, 0, 1, 0], 5] + assert ( + SOPform([w, x, y, z], minterms, dontcares) == + Or(And(y, z), And(Not(w), Not(x)))) + assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z) + + minterms = [1, {y: 1, z: 1}] + dontcares = [0, [0, 0, 1, 0], 5] + assert ( + SOPform([w, x, y, z], minterms, dontcares) == + Or(And(y, z), And(Not(w), Not(x)))) + assert POSform([w, x, y, z], minterms, dontcares) == And(Or(Not(w), y), z) + + minterms = [{y: 1, z: 1}, 1] + dontcares = [[0, 0, 0, 0]] + + minterms = [[0, 0, 0]] + raises(ValueError, lambda: SOPform([w, x, y, z], minterms)) + raises(ValueError, lambda: POSform([w, x, y, z], minterms)) + + raises(TypeError, lambda: POSform([w, x, y, z], ["abcdefg"])) + + # test simplification + ans = And(A, Or(B, C)) + assert simplify_logic(A & (B | C)) == ans + assert simplify_logic((A & B) | (A & C)) == ans + assert simplify_logic(Implies(A, B)) == Or(Not(A), B) + assert simplify_logic(Equivalent(A, B)) == \ + Or(And(A, B), And(Not(A), Not(B))) + assert simplify_logic(And(Equality(A, 2), C)) == And(Equality(A, 2), C) + assert simplify_logic(And(Equality(A, 2), A)) == And(Equality(A, 2), A) + assert simplify_logic(And(Equality(A, B), C)) == And(Equality(A, B), C) + assert simplify_logic(Or(And(Equality(A, 3), B), And(Equality(A, 3), C))) \ + == And(Equality(A, 3), Or(B, C)) + b = (~x & ~y & ~z) | (~x & ~y & z) + e = And(A, b) + assert simplify_logic(e) == A & ~x & ~y + raises(ValueError, lambda: simplify_logic(A & (B | C), form='blabla')) + assert simplify(Or(x <= y, And(x < y, z))) == (x <= y) + assert simplify(Or(x <= y, And(y > x, z))) == (x <= y) + assert simplify(Or(x >= y, And(y < x, z))) == (x >= y) + + # Check that expressions with nine variables or more are not simplified + # (without the force-flag) + a, b, c, d, e, f, g, h, j = symbols('a b c d e f g h j') + expr = a & b & c & d & e & f & g & h & j | \ + a & b & c & d & e & f & g & h & ~j + # This expression can be simplified to get rid of the j variables + assert simplify_logic(expr) == expr + + # Test dontcare + assert simplify_logic((a & b) | c | d, dontcare=(a & b)) == c | d + + # check input + ans = SOPform([x, y], [[1, 0]]) + assert SOPform([x, y], [[1, 0]]) == ans + assert POSform([x, y], [[1, 0]]) == ans + + raises(ValueError, lambda: SOPform([x], [[1]], [[1]])) + assert SOPform([x], [[1]], [[0]]) is true + assert SOPform([x], [[0]], [[1]]) is true + assert SOPform([x], [], []) is false + + raises(ValueError, lambda: POSform([x], [[1]], [[1]])) + assert POSform([x], [[1]], [[0]]) is true + assert POSform([x], [[0]], [[1]]) is true + assert POSform([x], [], []) is false + + # check working of simplify + assert simplify((A & B) | (A & C)) == And(A, Or(B, C)) + assert simplify(And(x, Not(x))) == False + assert simplify(Or(x, Not(x))) == True + assert simplify(And(Eq(x, 0), Eq(x, y))) == And(Eq(x, 0), Eq(y, 0)) + assert And(Eq(x - 1, 0), Eq(x, y)).simplify() == And(Eq(x, 1), Eq(y, 1)) + assert And(Ne(x - 1, 0), Ne(x, y)).simplify() == And(Ne(x, 1), Ne(x, y)) + assert And(Eq(x - 1, 0), Ne(x, y)).simplify() == And(Eq(x, 1), Ne(y, 1)) + assert And(Eq(x - 1, 0), Eq(x, z + y), Eq(y + x, 0)).simplify( + ) == And(Eq(x, 1), Eq(y, -1), Eq(z, 2)) + assert And(Eq(x - 1, 0), Eq(x + 2, 3)).simplify() == Eq(x, 1) + assert And(Ne(x - 1, 0), Ne(x + 2, 3)).simplify() == Ne(x, 1) + assert And(Eq(x - 1, 0), Eq(x + 2, 2)).simplify() == False + assert And(Ne(x - 1, 0), Ne(x + 2, 2)).simplify( + ) == And(Ne(x, 1), Ne(x, 0)) + assert simplify(Xor(x, ~x)) == True + + +def test_bool_map(): + """ + Test working of bool_map function. + """ + + minterms = [[0, 0, 0, 1], [0, 0, 1, 1], [0, 1, 1, 1], [1, 0, 1, 1], + [1, 1, 1, 1]] + assert bool_map(Not(Not(a)), a) == (a, {a: a}) + assert bool_map(SOPform([w, x, y, z], minterms), + POSform([w, x, y, z], minterms)) == \ + (And(Or(Not(w), y), Or(Not(x), y), z), {x: x, w: w, z: z, y: y}) + assert bool_map(SOPform([x, z, y], [[1, 0, 1]]), + SOPform([a, b, c], [[1, 0, 1]])) != False + function1 = SOPform([x, z, y], [[1, 0, 1], [0, 0, 1]]) + function2 = SOPform([a, b, c], [[1, 0, 1], [1, 0, 0]]) + assert bool_map(function1, function2) == \ + (function1, {y: a, z: b}) + assert bool_map(Xor(x, y), ~Xor(x, y)) == False + assert bool_map(And(x, y), Or(x, y)) is None + assert bool_map(And(x, y), And(x, y, z)) is None + # issue 16179 + assert bool_map(Xor(x, y, z), ~Xor(x, y, z)) == False + assert bool_map(Xor(a, x, y, z), ~Xor(a, x, y, z)) == False + + +def test_bool_symbol(): + """Test that mixing symbols with boolean values + works as expected""" + + assert And(A, True) == A + assert And(A, True, True) == A + assert And(A, False) is false + assert And(A, True, False) is false + assert Or(A, True) is true + assert Or(A, False) == A + + +def test_is_boolean(): + assert isinstance(True, Boolean) is False + assert isinstance(true, Boolean) is True + assert 1 == True + assert 1 != true + assert (1 == true) is False + assert 0 == False + assert 0 != false + assert (0 == false) is False + assert true.is_Boolean is True + assert (A & B).is_Boolean + assert (A | B).is_Boolean + assert (~A).is_Boolean + assert (A ^ B).is_Boolean + assert A.is_Boolean != isinstance(A, Boolean) + assert isinstance(A, Boolean) + + +def test_subs(): + assert (A & B).subs(A, True) == B + assert (A & B).subs(A, False) is false + assert (A & B).subs(B, True) == A + assert (A & B).subs(B, False) is false + assert (A & B).subs({A: True, B: True}) is true + assert (A | B).subs(A, True) is true + assert (A | B).subs(A, False) == B + assert (A | B).subs(B, True) is true + assert (A | B).subs(B, False) == A + assert (A | B).subs({A: True, B: True}) is true + + +""" +we test for axioms of boolean algebra +see https://en.wikipedia.org/wiki/Boolean_algebra_(structure) +""" + + +def test_commutative(): + """Test for commutativity of And and Or""" + A, B = map(Boolean, symbols('A,B')) + + assert A & B == B & A + assert A | B == B | A + + +def test_and_associativity(): + """Test for associativity of And""" + + assert (A & B) & C == A & (B & C) + + +def test_or_assicativity(): + assert ((A | B) | C) == (A | (B | C)) + + +def test_double_negation(): + a = Boolean() + assert ~(~a) == a + + +# test methods + +def test_eliminate_implications(): + assert eliminate_implications(Implies(A, B, evaluate=False)) == (~A) | B + assert eliminate_implications( + A >> (C >> Not(B))) == Or(Or(Not(B), Not(C)), Not(A)) + assert eliminate_implications(Equivalent(A, B, C, D)) == \ + (~A | B) & (~B | C) & (~C | D) & (~D | A) + + +def test_conjuncts(): + assert conjuncts(A & B & C) == {A, B, C} + assert conjuncts((A | B) & C) == {A | B, C} + assert conjuncts(A) == {A} + assert conjuncts(True) == {True} + assert conjuncts(False) == {False} + + +def test_disjuncts(): + assert disjuncts(A | B | C) == {A, B, C} + assert disjuncts((A | B) & C) == {(A | B) & C} + assert disjuncts(A) == {A} + assert disjuncts(True) == {True} + assert disjuncts(False) == {False} + + +def test_distribute(): + assert distribute_and_over_or(Or(And(A, B), C)) == And(Or(A, C), Or(B, C)) + assert distribute_or_over_and(And(A, Or(B, C))) == Or(And(A, B), And(A, C)) + assert distribute_xor_over_and(And(A, Xor(B, C))) == Xor(And(A, B), And(A, C)) + + +def test_to_anf(): + x, y, z = symbols('x,y,z') + assert to_anf(And(x, y)) == And(x, y) + assert to_anf(Or(x, y)) == Xor(x, y, And(x, y)) + assert to_anf(Or(Implies(x, y), And(x, y), y)) == \ + Xor(x, True, x & y, remove_true=False) + assert to_anf(Or(Nand(x, y), Nor(x, y), Xnor(x, y), Implies(x, y))) == True + assert to_anf(Or(x, Not(y), Nor(x, z), And(x, y), Nand(y, z))) == \ + Xor(True, And(y, z), And(x, y, z), remove_true=False) + assert to_anf(Xor(x, y)) == Xor(x, y) + assert to_anf(Not(x)) == Xor(x, True, remove_true=False) + assert to_anf(Nand(x, y)) == Xor(True, And(x, y), remove_true=False) + assert to_anf(Nor(x, y)) == Xor(x, y, True, And(x, y), remove_true=False) + assert to_anf(Implies(x, y)) == Xor(x, True, And(x, y), remove_true=False) + assert to_anf(Equivalent(x, y)) == Xor(x, y, True, remove_true=False) + assert to_anf(Nand(x | y, x >> y), deep=False) == \ + Xor(True, And(Or(x, y), Implies(x, y)), remove_true=False) + assert to_anf(Nor(x ^ y, x & y), deep=False) == \ + Xor(True, Or(Xor(x, y), And(x, y)), remove_true=False) + # issue 25218 + assert to_anf(x ^ ~(x ^ y ^ ~y)) == False + + +def test_to_nnf(): + assert to_nnf(true) is true + assert to_nnf(false) is false + assert to_nnf(A) == A + assert to_nnf(A | ~A | B) is true + assert to_nnf(A & ~A & B) is false + assert to_nnf(A >> B) == ~A | B + assert to_nnf(Equivalent(A, B, C)) == (~A | B) & (~B | C) & (~C | A) + assert to_nnf(A ^ B ^ C) == \ + (A | B | C) & (~A | ~B | C) & (A | ~B | ~C) & (~A | B | ~C) + assert to_nnf(ITE(A, B, C)) == (~A | B) & (A | C) + assert to_nnf(Not(A | B | C)) == ~A & ~B & ~C + assert to_nnf(Not(A & B & C)) == ~A | ~B | ~C + assert to_nnf(Not(A >> B)) == A & ~B + assert to_nnf(Not(Equivalent(A, B, C))) == And(Or(A, B, C), Or(~A, ~B, ~C)) + assert to_nnf(Not(A ^ B ^ C)) == \ + (~A | B | C) & (A | ~B | C) & (A | B | ~C) & (~A | ~B | ~C) + assert to_nnf(Not(ITE(A, B, C))) == (~A | ~B) & (A | ~C) + assert to_nnf((A >> B) ^ (B >> A)) == (A & ~B) | (~A & B) + assert to_nnf((A >> B) ^ (B >> A), False) == \ + (~A | ~B | A | B) & ((A & ~B) | (~A & B)) + assert ITE(A, 1, 0).to_nnf() == A + assert ITE(A, 0, 1).to_nnf() == ~A + # although ITE can hold non-Boolean, it will complain if + # an attempt is made to convert the ITE to Boolean nnf + raises(TypeError, lambda: ITE(A < 1, [1], B).to_nnf()) + + +def test_to_cnf(): + assert to_cnf(~(B | C)) == And(Not(B), Not(C)) + assert to_cnf((A & B) | C) == And(Or(A, C), Or(B, C)) + assert to_cnf(A >> B) == (~A) | B + assert to_cnf(A >> (B & C)) == (~A | B) & (~A | C) + assert to_cnf(A & (B | C) | ~A & (B | C), True) == B | C + assert to_cnf(A & B) == And(A, B) + + assert to_cnf(Equivalent(A, B)) == And(Or(A, Not(B)), Or(B, Not(A))) + assert to_cnf(Equivalent(A, B & C)) == \ + (~A | B) & (~A | C) & (~B | ~C | A) + assert to_cnf(Equivalent(A, B | C), True) == \ + And(Or(Not(B), A), Or(Not(C), A), Or(B, C, Not(A))) + assert to_cnf(A + 1) == A + 1 + + +def test_issue_18904(): + x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15 = symbols('x1:16') + eq = ((x1 & x2 & x3 & x4 & x5 & x6 & x7 & x8 & x9) | + (x1 & x2 & x3 & x4 & x5 & x6 & x7 & x10 & x9) | + (x1 & x11 & x3 & x12 & x5 & x13 & x14 & x15 & x9)) + assert is_cnf(to_cnf(eq)) + raises(ValueError, lambda: to_cnf(eq, simplify=True)) + for f, t in zip((And, Or), (to_cnf, to_dnf)): + eq = f(x1, x2, x3, x4, x5, x6, x7, x8, x9) + raises(ValueError, lambda: to_cnf(eq, simplify=True)) + assert t(eq, simplify=True, force=True) == eq + + +def test_issue_9949(): + assert is_cnf(to_cnf((b > -5) | (a > 2) & (a < 4))) + + +def test_to_CNF(): + assert CNF.CNF_to_cnf(CNF.to_CNF(~(B | C))) == to_cnf(~(B | C)) + assert CNF.CNF_to_cnf(CNF.to_CNF((A & B) | C)) == to_cnf((A & B) | C) + assert CNF.CNF_to_cnf(CNF.to_CNF(A >> B)) == to_cnf(A >> B) + assert CNF.CNF_to_cnf(CNF.to_CNF(A >> (B & C))) == to_cnf(A >> (B & C)) + assert CNF.CNF_to_cnf(CNF.to_CNF(A & (B | C) | ~A & (B | C))) == to_cnf(A & (B | C) | ~A & (B | C)) + assert CNF.CNF_to_cnf(CNF.to_CNF(A & B)) == to_cnf(A & B) + + +def test_to_dnf(): + assert to_dnf(~(B | C)) == And(Not(B), Not(C)) + assert to_dnf(A & (B | C)) == Or(And(A, B), And(A, C)) + assert to_dnf(A >> B) == (~A) | B + assert to_dnf(A >> (B & C)) == (~A) | (B & C) + assert to_dnf(A | B) == A | B + + assert to_dnf(Equivalent(A, B), True) == \ + Or(And(A, B), And(Not(A), Not(B))) + assert to_dnf(Equivalent(A, B & C), True) == \ + Or(And(A, B, C), And(Not(A), Not(B)), And(Not(A), Not(C))) + assert to_dnf(A + 1) == A + 1 + + +def test_to_int_repr(): + x, y, z = map(Boolean, symbols('x,y,z')) + + def sorted_recursive(arg): + try: + return sorted(sorted_recursive(x) for x in arg) + except TypeError: # arg is not a sequence + return arg + + assert sorted_recursive(to_int_repr([x | y, z | x], [x, y, z])) == \ + sorted_recursive([[1, 2], [1, 3]]) + assert sorted_recursive(to_int_repr([x | y, z | ~x], [x, y, z])) == \ + sorted_recursive([[1, 2], [3, -1]]) + + +def test_is_anf(): + x, y = symbols('x,y') + assert is_anf(true) is True + assert is_anf(false) is True + assert is_anf(x) is True + assert is_anf(And(x, y)) is True + assert is_anf(Xor(x, y, And(x, y))) is True + assert is_anf(Xor(x, y, Or(x, y))) is False + assert is_anf(Xor(Not(x), y)) is False + + +def test_is_nnf(): + assert is_nnf(true) is True + assert is_nnf(A) is True + assert is_nnf(~A) is True + assert is_nnf(A & B) is True + assert is_nnf((A & B) | (~A & A) | (~B & B) | (~A & ~B), False) is True + assert is_nnf((A | B) & (~A | ~B)) is True + assert is_nnf(Not(Or(A, B))) is False + assert is_nnf(A ^ B) is False + assert is_nnf((A & B) | (~A & A) | (~B & B) | (~A & ~B), True) is False + + +def test_is_cnf(): + assert is_cnf(x) is True + assert is_cnf(x | y | z) is True + assert is_cnf(x & y & z) is True + assert is_cnf((x | y) & z) is True + assert is_cnf((x & y) | z) is False + assert is_cnf(~(x & y) | z) is False + + +def test_is_dnf(): + assert is_dnf(x) is True + assert is_dnf(x | y | z) is True + assert is_dnf(x & y & z) is True + assert is_dnf((x & y) | z) is True + assert is_dnf((x | y) & z) is False + assert is_dnf(~(x | y) & z) is False + + +def test_ITE(): + A, B, C = symbols('A:C') + assert ITE(True, False, True) is false + assert ITE(True, True, False) is true + assert ITE(False, True, False) is false + assert ITE(False, False, True) is true + assert isinstance(ITE(A, B, C), ITE) + + A = True + assert ITE(A, B, C) == B + A = False + assert ITE(A, B, C) == C + B = True + assert ITE(And(A, B), B, C) == C + assert ITE(Or(A, False), And(B, True), False) is false + assert ITE(x, A, B) == Not(x) + assert ITE(x, B, A) == x + assert ITE(1, x, y) == x + assert ITE(0, x, y) == y + raises(TypeError, lambda: ITE(2, x, y)) + raises(TypeError, lambda: ITE(1, [], y)) + raises(TypeError, lambda: ITE(1, (), y)) + raises(TypeError, lambda: ITE(1, y, [])) + assert ITE(1, 1, 1) is S.true + assert isinstance(ITE(1, 1, 1, evaluate=False), ITE) + + assert ITE(Eq(x, True), y, x) == ITE(x, y, x) + assert ITE(Eq(x, False), y, x) == ITE(~x, y, x) + assert ITE(Ne(x, True), y, x) == ITE(~x, y, x) + assert ITE(Ne(x, False), y, x) == ITE(x, y, x) + assert ITE(Eq(S.true, x), y, x) == ITE(x, y, x) + assert ITE(Eq(S.false, x), y, x) == ITE(~x, y, x) + assert ITE(Ne(S.true, x), y, x) == ITE(~x, y, x) + assert ITE(Ne(S.false, x), y, x) == ITE(x, y, x) + # 0 and 1 in the context are not treated as True/False + # so the equality must always be False since dissimilar + # objects cannot be equal + assert ITE(Eq(x, 0), y, x) == x + assert ITE(Eq(x, 1), y, x) == x + assert ITE(Ne(x, 0), y, x) == y + assert ITE(Ne(x, 1), y, x) == y + assert ITE(Eq(x, 0), y, z).subs(x, 0) == y + assert ITE(Eq(x, 0), y, z).subs(x, 1) == z + raises(ValueError, lambda: ITE(x > 1, y, x, z)) + + +def test_is_literal(): + assert is_literal(True) is True + assert is_literal(False) is True + assert is_literal(A) is True + assert is_literal(~A) is True + assert is_literal(Or(A, B)) is False + assert is_literal(Q.zero(A)) is True + assert is_literal(Not(Q.zero(A))) is True + assert is_literal(Or(A, B)) is False + assert is_literal(And(Q.zero(A), Q.zero(B))) is False + assert is_literal(x < 3) + assert not is_literal(x + y < 3) + + +def test_operators(): + # Mostly test __and__, __rand__, and so on + assert True & A == A & True == A + assert False & A == A & False == False + assert A & B == And(A, B) + assert True | A == A | True == True + assert False | A == A | False == A + assert A | B == Or(A, B) + assert ~A == Not(A) + assert True >> A == A << True == A + assert False >> A == A << False == True + assert A >> True == True << A == True + assert A >> False == False << A == ~A + assert A >> B == B << A == Implies(A, B) + assert True ^ A == A ^ True == ~A + assert False ^ A == A ^ False == A + assert A ^ B == Xor(A, B) + + +def test_true_false(): + assert true is S.true + assert false is S.false + assert true is not True + assert false is not False + assert true + assert not false + assert true == True + assert false == False + assert not (true == False) + assert not (false == True) + assert not (true == false) + + assert hash(true) == hash(True) + assert hash(false) == hash(False) + assert len({true, True}) == len({false, False}) == 1 + + assert isinstance(true, BooleanAtom) + assert isinstance(false, BooleanAtom) + # We don't want to subclass from bool, because bool subclasses from + # int. But operators like &, |, ^, <<, >>, and ~ act differently on 0 and + # 1 then we want them to on true and false. See the docstrings of the + # various And, Or, etc. functions for examples. + assert not isinstance(true, bool) + assert not isinstance(false, bool) + + # Note: using 'is' comparison is important here. We want these to return + # true and false, not True and False + + assert Not(true) is false + assert Not(True) is false + assert Not(false) is true + assert Not(False) is true + assert ~true is false + assert ~false is true + + for T, F in product((True, true), (False, false)): + assert And(T, F) is false + assert And(F, T) is false + assert And(F, F) is false + assert And(T, T) is true + assert And(T, x) == x + assert And(F, x) is false + if not (T is True and F is False): + assert T & F is false + assert F & T is false + if F is not False: + assert F & F is false + if T is not True: + assert T & T is true + + assert Or(T, F) is true + assert Or(F, T) is true + assert Or(F, F) is false + assert Or(T, T) is true + assert Or(T, x) is true + assert Or(F, x) == x + if not (T is True and F is False): + assert T | F is true + assert F | T is true + if F is not False: + assert F | F is false + if T is not True: + assert T | T is true + + assert Xor(T, F) is true + assert Xor(F, T) is true + assert Xor(F, F) is false + assert Xor(T, T) is false + assert Xor(T, x) == ~x + assert Xor(F, x) == x + if not (T is True and F is False): + assert T ^ F is true + assert F ^ T is true + if F is not False: + assert F ^ F is false + if T is not True: + assert T ^ T is false + + assert Nand(T, F) is true + assert Nand(F, T) is true + assert Nand(F, F) is true + assert Nand(T, T) is false + assert Nand(T, x) == ~x + assert Nand(F, x) is true + + assert Nor(T, F) is false + assert Nor(F, T) is false + assert Nor(F, F) is true + assert Nor(T, T) is false + assert Nor(T, x) is false + assert Nor(F, x) == ~x + + assert Implies(T, F) is false + assert Implies(F, T) is true + assert Implies(F, F) is true + assert Implies(T, T) is true + assert Implies(T, x) == x + assert Implies(F, x) is true + assert Implies(x, T) is true + assert Implies(x, F) == ~x + if not (T is True and F is False): + assert T >> F is false + assert F << T is false + assert F >> T is true + assert T << F is true + if F is not False: + assert F >> F is true + assert F << F is true + if T is not True: + assert T >> T is true + assert T << T is true + + assert Equivalent(T, F) is false + assert Equivalent(F, T) is false + assert Equivalent(F, F) is true + assert Equivalent(T, T) is true + assert Equivalent(T, x) == x + assert Equivalent(F, x) == ~x + assert Equivalent(x, T) == x + assert Equivalent(x, F) == ~x + + assert ITE(T, T, T) is true + assert ITE(T, T, F) is true + assert ITE(T, F, T) is false + assert ITE(T, F, F) is false + assert ITE(F, T, T) is true + assert ITE(F, T, F) is false + assert ITE(F, F, T) is true + assert ITE(F, F, F) is false + + assert all(i.simplify(1, 2) is i for i in (S.true, S.false)) + + +def test_bool_as_set(): + assert ITE(y <= 0, False, y >= 1).as_set() == Interval(1, oo) + assert And(x <= 2, x >= -2).as_set() == Interval(-2, 2) + assert Or(x >= 2, x <= -2).as_set() == Interval(-oo, -2) + Interval(2, oo) + assert Not(x > 2).as_set() == Interval(-oo, 2) + # issue 10240 + assert Not(And(x > 2, x < 3)).as_set() == \ + Union(Interval(-oo, 2), Interval(3, oo)) + assert true.as_set() == S.UniversalSet + assert false.as_set() is S.EmptySet + assert x.as_set() == S.UniversalSet + assert And(Or(x < 1, x > 3), x < 2).as_set() == Interval.open(-oo, 1) + assert And(x < 1, sin(x) < 3).as_set() == (x < 1).as_set() + raises(NotImplementedError, lambda: (sin(x) < 1).as_set()) + # watch for object morph in as_set + assert Eq(-1, cos(2 * x) ** 2 / sin(2 * x) ** 2).as_set() is S.EmptySet + + +@XFAIL +def test_multivariate_bool_as_set(): + x, y = symbols('x,y') + + assert And(x >= 0, y >= 0).as_set() == Interval(0, oo) * Interval(0, oo) + assert Or(x >= 0, y >= 0).as_set() == S.Reals * S.Reals - \ + Interval(-oo, 0, True, True) * Interval(-oo, 0, True, True) + + +def test_all_or_nothing(): + x = symbols('x', extended_real=True) + args = x >= -oo, x <= oo + v = And(*args) + if v.func is And: + assert len(v.args) == len(args) - args.count(S.true) + else: + assert v == True + v = Or(*args) + if v.func is Or: + assert len(v.args) == 2 + else: + assert v == True + + +def test_canonical_atoms(): + assert true.canonical == true + assert false.canonical == false + + +def test_negated_atoms(): + assert true.negated == false + assert false.negated == true + + +def test_issue_8777(): + assert And(x > 2, x < oo).as_set() == Interval(2, oo, left_open=True) + assert And(x >= 1, x < oo).as_set() == Interval(1, oo) + assert (x < oo).as_set() == Interval(-oo, oo) + assert (x > -oo).as_set() == Interval(-oo, oo) + + +def test_issue_8975(): + assert Or(And(-oo < x, x <= -2), And(2 <= x, x < oo)).as_set() == \ + Interval(-oo, -2) + Interval(2, oo) + + +def test_term_to_integer(): + assert term_to_integer([1, 0, 1, 0, 0, 1, 0]) == 82 + assert term_to_integer('0010101000111001') == 10809 + + +def test_issue_21971(): + a, b, c, d = symbols('a b c d') + f = a & b & c | a & c + assert f.subs(a & c, d) == b & d | d + assert f.subs(a & b & c, d) == a & c | d + + f = (a | b | c) & (a | c) + assert f.subs(a | c, d) == (b | d) & d + assert f.subs(a | b | c, d) == (a | c) & d + + f = (a ^ b ^ c) & (a ^ c) + assert f.subs(a ^ c, d) == (b ^ d) & d + assert f.subs(a ^ b ^ c, d) == (a ^ c) & d + + +def test_truth_table(): + assert list(truth_table(And(x, y), [x, y], input=False)) == \ + [False, False, False, True] + assert list(truth_table(x | y, [x, y], input=False)) == \ + [False, True, True, True] + assert list(truth_table(x >> y, [x, y], input=False)) == \ + [True, True, False, True] + assert list(truth_table(And(x, y), [x, y])) == \ + [([0, 0], False), ([0, 1], False), ([1, 0], False), ([1, 1], True)] + + +def test_issue_8571(): + for t in (S.true, S.false): + raises(TypeError, lambda: +t) + raises(TypeError, lambda: -t) + raises(TypeError, lambda: abs(t)) + # use int(bool(t)) to get 0 or 1 + raises(TypeError, lambda: int(t)) + + for o in [S.Zero, S.One, x]: + for _ in range(2): + raises(TypeError, lambda: o + t) + raises(TypeError, lambda: o - t) + raises(TypeError, lambda: o % t) + raises(TypeError, lambda: o * t) + raises(TypeError, lambda: o / t) + raises(TypeError, lambda: o ** t) + o, t = t, o # do again in reversed order + + +def test_expand_relational(): + n = symbols('n', negative=True) + p, q = symbols('p q', positive=True) + r = ((n + q * (-n / q + 1)) / (q * (-n / q + 1)) < 0) + assert r is not S.false + assert r.expand() is S.false + assert (q > 0).expand() is S.true + + +def test_issue_12717(): + assert S.true.is_Atom == True + assert S.false.is_Atom == True + + +def test_as_Boolean(): + nz = symbols('nz', nonzero=True) + assert all(as_Boolean(i) is S.true for i in (True, S.true, 1, nz)) + z = symbols('z', zero=True) + assert all(as_Boolean(i) is S.false for i in (False, S.false, 0, z)) + assert all(as_Boolean(i) == i for i in (x, x < 0)) + for i in (2, S(2), x + 1, []): + raises(TypeError, lambda: as_Boolean(i)) + + +def test_binary_symbols(): + assert ITE(x < 1, y, z).binary_symbols == {y, z} + for f in (Eq, Ne): + assert f(x, 1).binary_symbols == set() + assert f(x, True).binary_symbols == {x} + assert f(x, False).binary_symbols == {x} + assert S.true.binary_symbols == set() + assert S.false.binary_symbols == set() + assert x.binary_symbols == {x} + assert And(x, Eq(y, False), Eq(z, 1)).binary_symbols == {x, y} + assert Q.prime(x).binary_symbols == set() + assert Q.lt(x, 1).binary_symbols == set() + assert Q.is_true(x).binary_symbols == {x} + assert Q.eq(x, True).binary_symbols == {x} + assert Q.prime(x).binary_symbols == set() + + +def test_BooleanFunction_diff(): + assert And(x, y).diff(x) == Piecewise((0, Eq(y, False)), (1, True)) + + +def test_issue_14700(): + A, B, C, D, E, F, G, H = symbols('A B C D E F G H') + q = ((B & D & H & ~F) | (B & H & ~C & ~D) | (B & H & ~C & ~F) | + (B & H & ~D & ~G) | (B & H & ~F & ~G) | (C & G & ~B & ~D) | + (C & G & ~D & ~H) | (C & G & ~F & ~H) | (D & F & H & ~B) | + (D & F & ~G & ~H) | (B & D & F & ~C & ~H) | (D & E & F & ~B & ~C) | + (D & F & ~A & ~B & ~C) | (D & F & ~A & ~C & ~H) | + (A & B & D & F & ~E & ~H)) + soldnf = ((B & D & H & ~F) | (D & F & H & ~B) | (B & H & ~C & ~D) | + (B & H & ~D & ~G) | (C & G & ~B & ~D) | (C & G & ~D & ~H) | + (C & G & ~F & ~H) | (D & F & ~G & ~H) | (D & E & F & ~C & ~H) | + (D & F & ~A & ~C & ~H) | (A & B & D & F & ~E & ~H)) + solcnf = ((B | C | D) & (B | D | G) & (C | D | H) & (C | F | H) & + (D | G | H) & (F | G | H) & (B | F | ~D | ~H) & + (~B | ~D | ~F | ~H) & (D | ~B | ~C | ~G | ~H) & + (A | H | ~C | ~D | ~F | ~G) & (H | ~C | ~D | ~E | ~F | ~G) & + (B | E | H | ~A | ~D | ~F | ~G)) + assert simplify_logic(q, "dnf") == soldnf + assert simplify_logic(q, "cnf") == solcnf + + minterms = [[0, 1, 0, 0], [0, 1, 0, 1], [0, 1, 1, 0], [0, 1, 1, 1], + [0, 0, 1, 1], [1, 0, 1, 1]] + dontcares = [[1, 0, 0, 0], [1, 0, 0, 1], [1, 1, 0, 0], [1, 1, 0, 1]] + assert SOPform([w, x, y, z], minterms) == (x & ~w) | (y & z & ~x) + # Should not be more complicated with don't cares + assert SOPform([w, x, y, z], minterms, dontcares) == \ + (x & ~w) | (y & z & ~x) + + +def test_issue_25115(): + cond = Contains(x, S.Integers) + # Previously this raised an exception: + assert simplify_logic(cond) == cond + + +def test_relational_simplification(): + w, x, y, z = symbols('w x y z', real=True) + d, e = symbols('d e', real=False) + # Test all combinations or sign and order + assert Or(x >= y, x < y).simplify() == S.true + assert Or(x >= y, y > x).simplify() == S.true + assert Or(x >= y, -x > -y).simplify() == S.true + assert Or(x >= y, -y < -x).simplify() == S.true + assert Or(-x <= -y, x < y).simplify() == S.true + assert Or(-x <= -y, -x > -y).simplify() == S.true + assert Or(-x <= -y, y > x).simplify() == S.true + assert Or(-x <= -y, -y < -x).simplify() == S.true + assert Or(y <= x, x < y).simplify() == S.true + assert Or(y <= x, y > x).simplify() == S.true + assert Or(y <= x, -x > -y).simplify() == S.true + assert Or(y <= x, -y < -x).simplify() == S.true + assert Or(-y >= -x, x < y).simplify() == S.true + assert Or(-y >= -x, y > x).simplify() == S.true + assert Or(-y >= -x, -x > -y).simplify() == S.true + assert Or(-y >= -x, -y < -x).simplify() == S.true + + assert Or(x < y, x >= y).simplify() == S.true + assert Or(y > x, x >= y).simplify() == S.true + assert Or(-x > -y, x >= y).simplify() == S.true + assert Or(-y < -x, x >= y).simplify() == S.true + assert Or(x < y, -x <= -y).simplify() == S.true + assert Or(-x > -y, -x <= -y).simplify() == S.true + assert Or(y > x, -x <= -y).simplify() == S.true + assert Or(-y < -x, -x <= -y).simplify() == S.true + assert Or(x < y, y <= x).simplify() == S.true + assert Or(y > x, y <= x).simplify() == S.true + assert Or(-x > -y, y <= x).simplify() == S.true + assert Or(-y < -x, y <= x).simplify() == S.true + assert Or(x < y, -y >= -x).simplify() == S.true + assert Or(y > x, -y >= -x).simplify() == S.true + assert Or(-x > -y, -y >= -x).simplify() == S.true + assert Or(-y < -x, -y >= -x).simplify() == S.true + + # Some other tests + assert Or(x >= y, w < z, x <= y).simplify() == S.true + assert And(x >= y, x < y).simplify() == S.false + assert Or(x >= y, Eq(y, x)).simplify() == (x >= y) + assert And(x >= y, Eq(y, x)).simplify() == Eq(x, y) + assert And(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y).simplify() == \ + (Eq(x, y) & (x >= 1) & (y >= 5) & (y > z)) + assert Or(Eq(x, y), x >= y, w < y, z < y).simplify() == \ + (x >= y) | (y > z) | (w < y) + assert And(Eq(x, y), x >= y, w < y, y >= z, z < y).simplify() == \ + Eq(x, y) & (y > z) & (w < y) + # assert And(Eq(x, y), x >= y, w < y, y >= z, z < y).simplify(relational_minmax=True) == \ + # And(Eq(x, y), y > Max(w, z)) + # assert Or(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y).simplify(relational_minmax=True) == \ + # (Eq(x, y) | (x >= 1) | (y > Min(2, z))) + assert And(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y).simplify() == \ + (Eq(x, y) & (x >= 1) & (y >= 5) & (y > z)) + assert (Eq(x, y) & Eq(d, e) & (x >= y) & (d >= e)).simplify() == \ + (Eq(x, y) & Eq(d, e) & (d >= e)) + assert And(Eq(x, y), Eq(x, -y)).simplify() == And(Eq(x, 0), Eq(y, 0)) + assert Xor(x >= y, x <= y).simplify() == Ne(x, y) + assert And(x > 1, x < -1, Eq(x, y)).simplify() == S.false + # From #16690 + assert And(x >= y, Eq(y, 0)).simplify() == And(x >= 0, Eq(y, 0)) + assert Or(Ne(x, 1), Ne(x, 2)).simplify() == S.true + assert And(Eq(x, 1), Ne(2, x)).simplify() == Eq(x, 1) + assert Or(Eq(x, 1), Ne(2, x)).simplify() == Ne(x, 2) + + +def test_issue_8373(): + x = symbols('x', real=True) + assert Or(x < 1, x > -1).simplify() == S.true + assert Or(x < 1, x >= 1).simplify() == S.true + assert And(x < 1, x >= 1).simplify() == S.false + assert Or(x <= 1, x >= 1).simplify() == S.true + + +def test_issue_7950(): + x = symbols('x', real=True) + assert And(Eq(x, 1), Eq(x, 2)).simplify() == S.false + + +@slow +def test_relational_simplification_numerically(): + def test_simplification_numerically_function(original, simplified): + symb = original.free_symbols + n = len(symb) + valuelist = list(set(combinations(list(range(-(n - 1), n)) * n, n))) + for values in valuelist: + sublist = dict(zip(symb, values)) + originalvalue = original.subs(sublist) + simplifiedvalue = simplified.subs(sublist) + assert originalvalue == simplifiedvalue, "Original: {}\nand" \ + " simplified: {}\ndo not evaluate to the same value for {}" \ + "".format(original, simplified, sublist) + + w, x, y, z = symbols('w x y z', real=True) + d, e = symbols('d e', real=False) + + expressions = (And(Eq(x, y), x >= y, w < y, y >= z, z < y), + And(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y), + Or(Eq(x, y), x >= 1, 2 < y, y >= 5, z < y), + And(x >= y, Eq(y, x)), + Or(And(Eq(x, y), x >= y, w < y, Or(y >= z, z < y)), + And(Eq(x, y), x >= 1, 2 < y, y >= -1, z < y)), + (Eq(x, y) & Eq(d, e) & (x >= y) & (d >= e)), + ) + + for expression in expressions: + test_simplification_numerically_function(expression, + expression.simplify()) + + +def test_relational_simplification_patterns_numerically(): + from sympy.core import Wild + from sympy.logic.boolalg import _simplify_patterns_and, \ + _simplify_patterns_or, _simplify_patterns_xor + a = Wild('a') + b = Wild('b') + c = Wild('c') + symb = [a, b, c] + patternlists = [[And, _simplify_patterns_and()], + [Or, _simplify_patterns_or()], + [Xor, _simplify_patterns_xor()]] + valuelist = list(set(combinations(list(range(-2, 3)) * 3, 3))) + # Skip combinations of +/-2 and 0, except for all 0 + valuelist = [v for v in valuelist if any(w % 2 for w in v) or not any(v)] + for func, patternlist in patternlists: + for pattern in patternlist: + original = func(*pattern[0].args) + simplified = pattern[1] + for values in valuelist: + sublist = dict(zip(symb, values)) + originalvalue = original.xreplace(sublist) + simplifiedvalue = simplified.xreplace(sublist) + assert originalvalue == simplifiedvalue, "Original: {}\nand" \ + " simplified: {}\ndo not evaluate to the same value for" \ + "{}".format(pattern[0], simplified, sublist) + + +def test_issue_16803(): + n = symbols('n') + # No simplification done, but should not raise an exception + assert ((n > 3) | (n < 0) | ((n > 0) & (n < 3))).simplify() == \ + (n > 3) | (n < 0) | ((n > 0) & (n < 3)) + + +def test_issue_17530(): + r = {x: oo, y: oo} + assert Or(x + y > 0, x - y < 0).subs(r) + assert not And(x + y < 0, x - y < 0).subs(r) + raises(TypeError, lambda: Or(x + y < 0, x - y < 0).subs(r)) + raises(TypeError, lambda: And(x + y > 0, x - y < 0).subs(r)) + raises(TypeError, lambda: And(x + y > 0, x - y < 0).subs(r)) + + +def test_anf_coeffs(): + assert anf_coeffs([1, 0]) == [1, 1] + assert anf_coeffs([0, 0, 0, 1]) == [0, 0, 0, 1] + assert anf_coeffs([0, 1, 1, 1]) == [0, 1, 1, 1] + assert anf_coeffs([1, 1, 1, 0]) == [1, 0, 0, 1] + assert anf_coeffs([1, 0, 0, 0]) == [1, 1, 1, 1] + assert anf_coeffs([1, 0, 0, 1]) == [1, 1, 1, 0] + assert anf_coeffs([1, 1, 0, 1]) == [1, 0, 1, 1] + + +def test_ANFform(): + x, y = symbols('x,y') + assert ANFform([x], [1, 1]) == True + assert ANFform([x], [0, 0]) == False + assert ANFform([x], [1, 0]) == Xor(x, True, remove_true=False) + assert ANFform([x, y], [1, 1, 1, 0]) == \ + Xor(True, And(x, y), remove_true=False) + + +def test_bool_minterm(): + x, y = symbols('x,y') + assert bool_minterm(3, [x, y]) == And(x, y) + assert bool_minterm([1, 0], [x, y]) == And(Not(y), x) + + +def test_bool_maxterm(): + x, y = symbols('x,y') + assert bool_maxterm(2, [x, y]) == Or(Not(x), y) + assert bool_maxterm([0, 1], [x, y]) == Or(Not(y), x) + + +def test_bool_monomial(): + x, y = symbols('x,y') + assert bool_monomial(1, [x, y]) == y + assert bool_monomial([1, 1], [x, y]) == And(x, y) + + +def test_check_pair(): + assert _check_pair([0, 1, 0], [0, 1, 1]) == 2 + assert _check_pair([0, 1, 0], [1, 1, 1]) == -1 + + +def test_issue_19114(): + expr = (B & C) | (A & ~C) | (~A & ~B) + # Expression is minimal, but there are multiple minimal forms possible + res1 = (A & B) | (C & ~A) | (~B & ~C) + result = to_dnf(expr, simplify=True) + assert result in (expr, res1) + + +def test_issue_20870(): + result = SOPform([a, b, c, d], [1, 2, 3, 4, 5, 6, 8, 9, 11, 12, 14, 15]) + expected = ((d & ~b) | (a & b & c) | (a & ~c & ~d) | + (b & ~a & ~c) | (c & ~a & ~d)) + assert result == expected + + +def test_convert_to_varsSOP(): + assert _convert_to_varsSOP([0, 1, 0], [x, y, z]) == And(Not(x), y, Not(z)) + assert _convert_to_varsSOP([3, 1, 0], [x, y, z]) == And(y, Not(z)) + + +def test_convert_to_varsPOS(): + assert _convert_to_varsPOS([0, 1, 0], [x, y, z]) == Or(x, Not(y), z) + assert _convert_to_varsPOS([3, 1, 0], [x, y, z]) == Or(Not(y), z) + + +def test_gateinputcount(): + a, b, c, d, e = symbols('a:e') + assert gateinputcount(And(a, b)) == 2 + assert gateinputcount(a | b & c & d ^ (e | a)) == 9 + assert gateinputcount(And(a, True)) == 0 + raises(TypeError, lambda: gateinputcount(a * b)) + + +def test_refine(): + # relational + assert not refine(x < 0, ~(x < 0)) + assert refine(x < 0, (x < 0)) + assert refine(x < 0, (0 > x)) is S.true + assert refine(x < 0, (y < 0)) == (x < 0) + assert not refine(x <= 0, ~(x <= 0)) + assert refine(x <= 0, (x <= 0)) + assert refine(x <= 0, (0 >= x)) is S.true + assert refine(x <= 0, (y <= 0)) == (x <= 0) + assert not refine(x > 0, ~(x > 0)) + assert refine(x > 0, (x > 0)) + assert refine(x > 0, (0 < x)) is S.true + assert refine(x > 0, (y > 0)) == (x > 0) + assert not refine(x >= 0, ~(x >= 0)) + assert refine(x >= 0, (x >= 0)) + assert refine(x >= 0, (0 <= x)) is S.true + assert refine(x >= 0, (y >= 0)) == (x >= 0) + assert not refine(Eq(x, 0), ~(Eq(x, 0))) + assert refine(Eq(x, 0), (Eq(x, 0))) + assert refine(Eq(x, 0), (Eq(0, x))) is S.true + assert refine(Eq(x, 0), (Eq(y, 0))) == Eq(x, 0) + assert not refine(Ne(x, 0), ~(Ne(x, 0))) + assert refine(Ne(x, 0), (Ne(0, x))) is S.true + assert refine(Ne(x, 0), (Ne(x, 0))) + assert refine(Ne(x, 0), (Ne(y, 0))) == (Ne(x, 0)) + + # boolean functions + assert refine(And(x > 0, y > 0), (x > 0)) == (y > 0) + assert refine(And(x > 0, y > 0), (x > 0) & (y > 0)) is S.true + + # predicates + assert refine(Q.positive(x), Q.positive(x)) is S.true + assert refine(Q.positive(x), Q.negative(x)) is S.false + assert refine(Q.positive(x), Q.real(x)) == Q.positive(x) + + +def test_relational_threeterm_simplification_patterns_numerically(): + from sympy.core import Wild + from sympy.logic.boolalg import _simplify_patterns_and3 + a = Wild('a') + b = Wild('b') + c = Wild('c') + symb = [a, b, c] + patternlists = [[And, _simplify_patterns_and3()]] + valuelist = list(set(combinations(list(range(-2, 3)) * 3, 3))) + # Skip combinations of +/-2 and 0, except for all 0 + valuelist = [v for v in valuelist if any(w % 2 for w in v) or not any(v)] + for func, patternlist in patternlists: + for pattern in patternlist: + original = func(*pattern[0].args) + simplified = pattern[1] + for values in valuelist: + sublist = dict(zip(symb, values)) + originalvalue = original.xreplace(sublist) + simplifiedvalue = simplified.xreplace(sublist) + assert originalvalue == simplifiedvalue, "Original: {}\nand" \ + " simplified: {}\ndo not evaluate to the same value for" \ + "{}".format(pattern[0], simplified, sublist) + + +def test_issue_25451(): + x = Or(And(a, c), Eq(a, b)) + assert isinstance(x, Or) + assert set(x.args) == {And(a, c), Eq(a, b)} + + +def test_issue_26985(): + a, b, c, d = symbols('a b c d') + + # Expression before applying to_anf + x = Xor(c, And(a, b), And(a, c)) + y = Xor(a, b, And(a, c)) + + # Applying to_anf + result = Xor(Xor(d, And(x, y)), And(x, y)) + result_anf = to_anf(Xor(to_anf(Xor(d, And(x, y))), And(x, y))) + + assert result_anf == d + assert result == d diff --git a/phivenv/Lib/site-packages/sympy/logic/tests/test_dimacs.py b/phivenv/Lib/site-packages/sympy/logic/tests/test_dimacs.py new file mode 100644 index 0000000000000000000000000000000000000000..3a9a51a39d33fb807688614cb5809b621ce21a2c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/logic/tests/test_dimacs.py @@ -0,0 +1,234 @@ +"""Various tests on satisfiability using dimacs cnf file syntax +You can find lots of cnf files in +ftp://dimacs.rutgers.edu/pub/challenge/satisfiability/benchmarks/cnf/ +""" + +from sympy.logic.utilities.dimacs import load +from sympy.logic.algorithms.dpll import dpll_satisfiable + + +def test_f1(): + assert bool(dpll_satisfiable(load(f1))) + + +def test_f2(): + assert bool(dpll_satisfiable(load(f2))) + + +def test_f3(): + assert bool(dpll_satisfiable(load(f3))) + + +def test_f4(): + assert not bool(dpll_satisfiable(load(f4))) + + +def test_f5(): + assert bool(dpll_satisfiable(load(f5))) + +f1 = """c simple example +c Resolution: SATISFIABLE +c +p cnf 3 2 +1 -3 0 +2 3 -1 0 +""" + + +f2 = """c an example from Quinn's text, 16 variables and 18 clauses. +c Resolution: SATISFIABLE +c +p cnf 16 18 + 1 2 0 + -2 -4 0 + 3 4 0 + -4 -5 0 + 5 -6 0 + 6 -7 0 + 6 7 0 + 7 -16 0 + 8 -9 0 + -8 -14 0 + 9 10 0 + 9 -10 0 +-10 -11 0 + 10 12 0 + 11 12 0 + 13 14 0 + 14 -15 0 + 15 16 0 +""" + +f3 = """c +p cnf 6 9 +-1 0 +-3 0 +2 -1 0 +2 -4 0 +5 -4 0 +-1 -3 0 +-4 -6 0 +1 3 -2 0 +4 6 -2 -5 0 +""" + +f4 = """c +c file: hole6.cnf [http://people.sc.fsu.edu/~jburkardt/data/cnf/hole6.cnf] +c +c SOURCE: John Hooker (jh38+@andrew.cmu.edu) +c +c DESCRIPTION: Pigeon hole problem of placing n (for file 'holen.cnf') pigeons +c in n+1 holes without placing 2 pigeons in the same hole +c +c NOTE: Part of the collection at the Forschungsinstitut fuer +c anwendungsorientierte Wissensverarbeitung in Ulm Germany. +c +c NOTE: Not satisfiable +c +p cnf 42 133 +-1 -7 0 +-1 -13 0 +-1 -19 0 +-1 -25 0 +-1 -31 0 +-1 -37 0 +-7 -13 0 +-7 -19 0 +-7 -25 0 +-7 -31 0 +-7 -37 0 +-13 -19 0 +-13 -25 0 +-13 -31 0 +-13 -37 0 +-19 -25 0 +-19 -31 0 +-19 -37 0 +-25 -31 0 +-25 -37 0 +-31 -37 0 +-2 -8 0 +-2 -14 0 +-2 -20 0 +-2 -26 0 +-2 -32 0 +-2 -38 0 +-8 -14 0 +-8 -20 0 +-8 -26 0 +-8 -32 0 +-8 -38 0 +-14 -20 0 +-14 -26 0 +-14 -32 0 +-14 -38 0 +-20 -26 0 +-20 -32 0 +-20 -38 0 +-26 -32 0 +-26 -38 0 +-32 -38 0 +-3 -9 0 +-3 -15 0 +-3 -21 0 +-3 -27 0 +-3 -33 0 +-3 -39 0 +-9 -15 0 +-9 -21 0 +-9 -27 0 +-9 -33 0 +-9 -39 0 +-15 -21 0 +-15 -27 0 +-15 -33 0 +-15 -39 0 +-21 -27 0 +-21 -33 0 +-21 -39 0 +-27 -33 0 +-27 -39 0 +-33 -39 0 +-4 -10 0 +-4 -16 0 +-4 -22 0 +-4 -28 0 +-4 -34 0 +-4 -40 0 +-10 -16 0 +-10 -22 0 +-10 -28 0 +-10 -34 0 +-10 -40 0 +-16 -22 0 +-16 -28 0 +-16 -34 0 +-16 -40 0 +-22 -28 0 +-22 -34 0 +-22 -40 0 +-28 -34 0 +-28 -40 0 +-34 -40 0 +-5 -11 0 +-5 -17 0 +-5 -23 0 +-5 -29 0 +-5 -35 0 +-5 -41 0 +-11 -17 0 +-11 -23 0 +-11 -29 0 +-11 -35 0 +-11 -41 0 +-17 -23 0 +-17 -29 0 +-17 -35 0 +-17 -41 0 +-23 -29 0 +-23 -35 0 +-23 -41 0 +-29 -35 0 +-29 -41 0 +-35 -41 0 +-6 -12 0 +-6 -18 0 +-6 -24 0 +-6 -30 0 +-6 -36 0 +-6 -42 0 +-12 -18 0 +-12 -24 0 +-12 -30 0 +-12 -36 0 +-12 -42 0 +-18 -24 0 +-18 -30 0 +-18 -36 0 +-18 -42 0 +-24 -30 0 +-24 -36 0 +-24 -42 0 +-30 -36 0 +-30 -42 0 +-36 -42 0 + 6 5 4 3 2 1 0 + 12 11 10 9 8 7 0 + 18 17 16 15 14 13 0 + 24 23 22 21 20 19 0 + 30 29 28 27 26 25 0 + 36 35 34 33 32 31 0 + 42 41 40 39 38 37 0 +""" + +f5 = """c simple example requiring variable selection +c +c NOTE: Satisfiable +c +p cnf 5 5 +1 2 3 0 +1 -2 3 0 +4 5 -3 0 +1 -4 -3 0 +-1 -5 0 +""" diff --git a/phivenv/Lib/site-packages/sympy/logic/tests/test_inference.py b/phivenv/Lib/site-packages/sympy/logic/tests/test_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..ff37b1b104f6f106ec5df7809fd34959bce35917 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/logic/tests/test_inference.py @@ -0,0 +1,396 @@ +"""For more tests on satisfiability, see test_dimacs""" + +from sympy.assumptions.ask import Q +from sympy.core.symbol import symbols +from sympy.core.relational import Unequality +from sympy.logic.boolalg import And, Or, Implies, Equivalent, true, false +from sympy.logic.inference import literal_symbol, \ + pl_true, satisfiable, valid, entails, PropKB +from sympy.logic.algorithms.dpll import dpll, dpll_satisfiable, \ + find_pure_symbol, find_unit_clause, unit_propagate, \ + find_pure_symbol_int_repr, find_unit_clause_int_repr, \ + unit_propagate_int_repr +from sympy.logic.algorithms.dpll2 import dpll_satisfiable as dpll2_satisfiable + +from sympy.logic.algorithms.z3_wrapper import z3_satisfiable +from sympy.assumptions.cnf import CNF, EncodedCNF +from sympy.logic.tests.test_lra_theory import make_random_problem +from sympy.core.random import randint + +from sympy.testing.pytest import raises, skip +from sympy.external import import_module + + +def test_literal(): + A, B = symbols('A,B') + assert literal_symbol(True) is True + assert literal_symbol(False) is False + assert literal_symbol(A) is A + assert literal_symbol(~A) is A + + +def test_find_pure_symbol(): + A, B, C = symbols('A,B,C') + assert find_pure_symbol([A], [A]) == (A, True) + assert find_pure_symbol([A, B], [~A | B, ~B | A]) == (None, None) + assert find_pure_symbol([A, B, C], [ A | ~B, ~B | ~C, C | A]) == (A, True) + assert find_pure_symbol([A, B, C], [~A | B, B | ~C, C | A]) == (B, True) + assert find_pure_symbol([A, B, C], [~A | ~B, ~B | ~C, C | A]) == (B, False) + assert find_pure_symbol( + [A, B, C], [~A | B, ~B | ~C, C | A]) == (None, None) + + +def test_find_pure_symbol_int_repr(): + assert find_pure_symbol_int_repr([1], [{1}]) == (1, True) + assert find_pure_symbol_int_repr([1, 2], + [{-1, 2}, {-2, 1}]) == (None, None) + assert find_pure_symbol_int_repr([1, 2, 3], + [{1, -2}, {-2, -3}, {3, 1}]) == (1, True) + assert find_pure_symbol_int_repr([1, 2, 3], + [{-1, 2}, {2, -3}, {3, 1}]) == (2, True) + assert find_pure_symbol_int_repr([1, 2, 3], + [{-1, -2}, {-2, -3}, {3, 1}]) == (2, False) + assert find_pure_symbol_int_repr([1, 2, 3], + [{-1, 2}, {-2, -3}, {3, 1}]) == (None, None) + + +def test_unit_clause(): + A, B, C = symbols('A,B,C') + assert find_unit_clause([A], {}) == (A, True) + assert find_unit_clause([A, ~A], {}) == (A, True) # Wrong ?? + assert find_unit_clause([A | B], {A: True}) == (B, True) + assert find_unit_clause([A | B], {B: True}) == (A, True) + assert find_unit_clause( + [A | B | C, B | ~C, A | ~B], {A: True}) == (B, False) + assert find_unit_clause([A | B | C, B | ~C, A | B], {A: True}) == (B, True) + assert find_unit_clause([A | B | C, B | ~C, A ], {}) == (A, True) + + +def test_unit_clause_int_repr(): + assert find_unit_clause_int_repr(map(set, [[1]]), {}) == (1, True) + assert find_unit_clause_int_repr(map(set, [[1], [-1]]), {}) == (1, True) + assert find_unit_clause_int_repr([{1, 2}], {1: True}) == (2, True) + assert find_unit_clause_int_repr([{1, 2}], {2: True}) == (1, True) + assert find_unit_clause_int_repr(map(set, + [[1, 2, 3], [2, -3], [1, -2]]), {1: True}) == (2, False) + assert find_unit_clause_int_repr(map(set, + [[1, 2, 3], [3, -3], [1, 2]]), {1: True}) == (2, True) + + A, B, C = symbols('A,B,C') + assert find_unit_clause([A | B | C, B | ~C, A ], {}) == (A, True) + + +def test_unit_propagate(): + A, B, C = symbols('A,B,C') + assert unit_propagate([A | B], A) == [] + assert unit_propagate([A | B, ~A | C, ~C | B, A], A) == [C, ~C | B, A] + + +def test_unit_propagate_int_repr(): + assert unit_propagate_int_repr([{1, 2}], 1) == [] + assert unit_propagate_int_repr(map(set, + [[1, 2], [-1, 3], [-3, 2], [1]]), 1) == [{3}, {-3, 2}] + + +def test_dpll(): + """This is also tested in test_dimacs""" + A, B, C = symbols('A,B,C') + assert dpll([A | B], [A, B], {A: True, B: True}) == {A: True, B: True} + + +def test_dpll_satisfiable(): + A, B, C = symbols('A,B,C') + assert dpll_satisfiable( A & ~A ) is False + assert dpll_satisfiable( A & ~B ) == {A: True, B: False} + assert dpll_satisfiable( + A | B ) in ({A: True}, {B: True}, {A: True, B: True}) + assert dpll_satisfiable( + (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False}) + assert dpll_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False}, + {A: True, C: True}, {B: True, C: True}) + assert dpll_satisfiable( A & B & C ) == {A: True, B: True, C: True} + assert dpll_satisfiable( (A | B) & (A >> B) ) == {B: True} + assert dpll_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True} + assert dpll_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False} + + +def test_dpll2_satisfiable(): + A, B, C = symbols('A,B,C') + assert dpll2_satisfiable( A & ~A ) is False + assert dpll2_satisfiable( A & ~B ) == {A: True, B: False} + assert dpll2_satisfiable( + A | B ) in ({A: True}, {B: True}, {A: True, B: True}) + assert dpll2_satisfiable( + (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False}) + assert dpll2_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True}, + {A: True, B: True, C: True}) + assert dpll2_satisfiable( A & B & C ) == {A: True, B: True, C: True} + assert dpll2_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False}, + {B: True, A: True}) + assert dpll2_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True} + assert dpll2_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False} + + +def test_minisat22_satisfiable(): + A, B, C = symbols('A,B,C') + minisat22_satisfiable = lambda expr: satisfiable(expr, algorithm="minisat22") + assert minisat22_satisfiable( A & ~A ) is False + assert minisat22_satisfiable( A & ~B ) == {A: True, B: False} + assert minisat22_satisfiable( + A | B ) in ({A: True}, {B: False}, {A: False, B: True}, {A: True, B: True}, {A: True, B: False}) + assert minisat22_satisfiable( + (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False}) + assert minisat22_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True}, + {A: True, B: True, C: True}, {A: False, B: True, C: True}, {A: True, B: False, C: False}) + assert minisat22_satisfiable( A & B & C ) == {A: True, B: True, C: True} + assert minisat22_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False}, + {B: True, A: True}) + assert minisat22_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True} + assert minisat22_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False} + +def test_minisat22_minimal_satisfiable(): + A, B, C = symbols('A,B,C') + minisat22_satisfiable = lambda expr, minimal=True: satisfiable(expr, algorithm="minisat22", minimal=True) + assert minisat22_satisfiable( A & ~A ) is False + assert minisat22_satisfiable( A & ~B ) == {A: True, B: False} + assert minisat22_satisfiable( + A | B ) in ({A: True}, {B: False}, {A: False, B: True}, {A: True, B: True}, {A: True, B: False}) + assert minisat22_satisfiable( + (~A | B) & (~B | A) ) in ({A: True, B: True}, {A: False, B: False}) + assert minisat22_satisfiable( (A | B) & (~B | C) ) in ({A: True, B: False, C: True}, + {A: True, B: True, C: True}, {A: False, B: True, C: True}, {A: True, B: False, C: False}) + assert minisat22_satisfiable( A & B & C ) == {A: True, B: True, C: True} + assert minisat22_satisfiable( (A | B) & (A >> B) ) in ({B: True, A: False}, + {B: True, A: True}) + assert minisat22_satisfiable( Equivalent(A, B) & A ) == {A: True, B: True} + assert minisat22_satisfiable( Equivalent(A, B) & ~A ) == {A: False, B: False} + g = satisfiable((A | B | C),algorithm="minisat22",minimal=True,all_models=True) + sol = next(g) + first_solution = {key for key, value in sol.items() if value} + sol=next(g) + second_solution = {key for key, value in sol.items() if value} + sol=next(g) + third_solution = {key for key, value in sol.items() if value} + assert not first_solution <= second_solution + assert not second_solution <= third_solution + assert not first_solution <= third_solution + +def test_satisfiable(): + A, B, C = symbols('A,B,C') + assert satisfiable(A & (A >> B) & ~B) is False + + +def test_valid(): + A, B, C = symbols('A,B,C') + assert valid(A >> (B >> A)) is True + assert valid((A >> (B >> C)) >> ((A >> B) >> (A >> C))) is True + assert valid((~B >> ~A) >> (A >> B)) is True + assert valid(A | B | C) is False + assert valid(A >> B) is False + + +def test_pl_true(): + A, B, C = symbols('A,B,C') + assert pl_true(True) is True + assert pl_true( A & B, {A: True, B: True}) is True + assert pl_true( A | B, {A: True}) is True + assert pl_true( A | B, {B: True}) is True + assert pl_true( A | B, {A: None, B: True}) is True + assert pl_true( A >> B, {A: False}) is True + assert pl_true( A | B | ~C, {A: False, B: True, C: True}) is True + assert pl_true(Equivalent(A, B), {A: False, B: False}) is True + + # test for false + assert pl_true(False) is False + assert pl_true( A & B, {A: False, B: False}) is False + assert pl_true( A & B, {A: False}) is False + assert pl_true( A & B, {B: False}) is False + assert pl_true( A | B, {A: False, B: False}) is False + + #test for None + assert pl_true(B, {B: None}) is None + assert pl_true( A & B, {A: True, B: None}) is None + assert pl_true( A >> B, {A: True, B: None}) is None + assert pl_true(Equivalent(A, B), {A: None}) is None + assert pl_true(Equivalent(A, B), {A: True, B: None}) is None + + # Test for deep + assert pl_true(A | B, {A: False}, deep=True) is None + assert pl_true(~A & ~B, {A: False}, deep=True) is None + assert pl_true(A | B, {A: False, B: False}, deep=True) is False + assert pl_true(A & B & (~A | ~B), {A: True}, deep=True) is False + assert pl_true((C >> A) >> (B >> A), {C: True}, deep=True) is True + + +def test_pl_true_wrong_input(): + from sympy.core.numbers import pi + raises(ValueError, lambda: pl_true('John Cleese')) + raises(ValueError, lambda: pl_true(42 + pi + pi ** 2)) + raises(ValueError, lambda: pl_true(42)) + + +def test_entails(): + A, B, C = symbols('A, B, C') + assert entails(A, [A >> B, ~B]) is False + assert entails(B, [Equivalent(A, B), A]) is True + assert entails((A >> B) >> (~A >> ~B)) is False + assert entails((A >> B) >> (~B >> ~A)) is True + + +def test_PropKB(): + A, B, C = symbols('A,B,C') + kb = PropKB() + assert kb.ask(A >> B) is False + assert kb.ask(A >> (B >> A)) is True + kb.tell(A >> B) + kb.tell(B >> C) + assert kb.ask(A) is False + assert kb.ask(B) is False + assert kb.ask(C) is False + assert kb.ask(~A) is False + assert kb.ask(~B) is False + assert kb.ask(~C) is False + assert kb.ask(A >> C) is True + kb.tell(A) + assert kb.ask(A) is True + assert kb.ask(B) is True + assert kb.ask(C) is True + assert kb.ask(~C) is False + kb.retract(A) + assert kb.ask(C) is False + + +def test_propKB_tolerant(): + """"tolerant to bad input""" + kb = PropKB() + A, B, C = symbols('A,B,C') + assert kb.ask(B) is False + +def test_satisfiable_non_symbols(): + x, y = symbols('x y') + assumptions = Q.zero(x*y) + facts = Implies(Q.zero(x*y), Q.zero(x) | Q.zero(y)) + query = ~Q.zero(x) & ~Q.zero(y) + refutations = [ + {Q.zero(x): True, Q.zero(x*y): True}, + {Q.zero(y): True, Q.zero(x*y): True}, + {Q.zero(x): True, Q.zero(y): True, Q.zero(x*y): True}, + {Q.zero(x): True, Q.zero(y): False, Q.zero(x*y): True}, + {Q.zero(x): False, Q.zero(y): True, Q.zero(x*y): True}] + assert not satisfiable(And(assumptions, facts, query), algorithm='dpll') + assert satisfiable(And(assumptions, facts, ~query), algorithm='dpll') in refutations + assert not satisfiable(And(assumptions, facts, query), algorithm='dpll2') + assert satisfiable(And(assumptions, facts, ~query), algorithm='dpll2') in refutations + +def test_satisfiable_bool(): + from sympy.core.singleton import S + assert satisfiable(true) == {true: true} + assert satisfiable(S.true) == {true: true} + assert satisfiable(false) is False + assert satisfiable(S.false) is False + + +def test_satisfiable_all_models(): + from sympy.abc import A, B + assert next(satisfiable(False, all_models=True)) is False + assert list(satisfiable((A >> ~A) & A, all_models=True)) == [False] + assert list(satisfiable(True, all_models=True)) == [{true: true}] + + models = [{A: True, B: False}, {A: False, B: True}] + result = satisfiable(A ^ B, all_models=True) + models.remove(next(result)) + models.remove(next(result)) + raises(StopIteration, lambda: next(result)) + assert not models + + assert list(satisfiable(Equivalent(A, B), all_models=True)) == \ + [{A: False, B: False}, {A: True, B: True}] + + models = [{A: False, B: False}, {A: False, B: True}, {A: True, B: True}] + for model in satisfiable(A >> B, all_models=True): + models.remove(model) + assert not models + + # This is a santiy test to check that only the required number + # of solutions are generated. The expr below has 2**100 - 1 models + # which would time out the test if all are generated at once. + from sympy.utilities.iterables import numbered_symbols + from sympy.logic.boolalg import Or + sym = numbered_symbols() + X = [next(sym) for i in range(100)] + result = satisfiable(Or(*X), all_models=True) + for i in range(10): + assert next(result) + + +def test_z3(): + z3 = import_module("z3") + + if not z3: + skip("z3 not installed.") + A, B, C = symbols('A,B,C') + x, y, z = symbols('x,y,z') + assert z3_satisfiable((x >= 2) & (x < 1)) is False + assert z3_satisfiable( A & ~A ) is False + + model = z3_satisfiable(A & (~A | B | C)) + assert bool(model) is True + assert model[A] is True + + # test nonlinear function + assert z3_satisfiable((x ** 2 >= 2) & (x < 1) & (x > -1)) is False + + +def test_z3_vs_lra_dpll2(): + z3 = import_module("z3") + if z3 is None: + skip("z3 not installed.") + + def boolean_formula_to_encoded_cnf(bf): + cnf = CNF.from_prop(bf) + enc = EncodedCNF() + enc.from_cnf(cnf) + return enc + + def make_random_cnf(num_clauses=5, num_constraints=10, num_var=2): + assert num_clauses <= num_constraints + constraints = make_random_problem(num_variables=num_var, num_constraints=num_constraints, rational=False) + clauses = [[cons] for cons in constraints[:num_clauses]] + for cons in constraints[num_clauses:]: + if isinstance(cons, Unequality): + cons = ~cons + i = randint(0, num_clauses-1) + clauses[i].append(cons) + + clauses = [Or(*clause) for clause in clauses] + cnf = And(*clauses) + return boolean_formula_to_encoded_cnf(cnf) + + lra_dpll2_satisfiable = lambda x: dpll2_satisfiable(x, use_lra_theory=True) + + for _ in range(50): + cnf = make_random_cnf(num_clauses=10, num_constraints=15, num_var=2) + + try: + z3_sat = z3_satisfiable(cnf) + except z3.z3types.Z3Exception: + continue + + lra_dpll2_sat = lra_dpll2_satisfiable(cnf) is not False + + assert z3_sat == lra_dpll2_sat + +def test_issue_27733(): + x, y = symbols('x,y') + clauses = [[1, -3, -2], [5, 7, -8, -6, -4], [-10, -9, 10, 11, -4], [-12, 13, 14], [-10, 9, -6, 11, -4], + [16, -15, 18, -19, -17], [11, -6, 10, -9], [9, 11, -10, -9], [2, -3, -1], [-13, 12], [-15, 3, -17], + [-16, -15, 19, -17], [-6, -9, 10, 11, -4], [20, -1, -2], [-23, -22, -21], [10, 11, -10, -9], + [9, 11, -4, -10], [24, -6, -4], [-14, 12], [-10, -9, 9, -6, 11], [25, -27, -26], [-15, 19, -18, -17], + [5, 8, -7, -6, -4], [-30, -29, 28], [12], [14]] + + encoding = {Q.gt(y, i): i for i in range(1, 31) if i != 11 and i != 12} + encoding[Q.gt(x, 0)] = 11 + encoding[Q.lt(x, 0)] = 12 + + cnf = EncodedCNF(clauses, encoding) + assert satisfiable(cnf, use_lra_theory=True) is False diff --git a/phivenv/Lib/site-packages/sympy/logic/tests/test_lra_theory.py b/phivenv/Lib/site-packages/sympy/logic/tests/test_lra_theory.py new file mode 100644 index 0000000000000000000000000000000000000000..207a3c5ba2c1b16ee5323382deee0863a5dfb595 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/logic/tests/test_lra_theory.py @@ -0,0 +1,440 @@ +from sympy.core.numbers import Rational, I, oo +from sympy.core.relational import Eq +from sympy.core.symbol import symbols +from sympy.core.singleton import S +from sympy.matrices.dense import Matrix +from sympy.matrices.dense import randMatrix +from sympy.assumptions.ask import Q +from sympy.logic.boolalg import And +from sympy.abc import x, y, z +from sympy.assumptions.cnf import CNF, EncodedCNF +from sympy.functions.elementary.trigonometric import cos +from sympy.external import import_module + +from sympy.logic.algorithms.lra_theory import LRASolver, UnhandledInput, LRARational, HANDLE_NEGATION +from sympy.core.random import random, choice, randint +from sympy.core.sympify import sympify +from sympy.ntheory.generate import randprime +from sympy.core.relational import StrictLessThan, StrictGreaterThan +import itertools + +from sympy.testing.pytest import raises, XFAIL, skip + +def make_random_problem(num_variables=2, num_constraints=2, sparsity=.1, rational=True, + disable_strict = False, disable_nonstrict=False, disable_equality=False): + def rand(sparsity=sparsity): + if random() < sparsity: + return sympify(0) + if rational: + int1, int2 = [randprime(0, 50) for _ in range(2)] + return Rational(int1, int2) * choice([-1, 1]) + else: + return randint(1, 10) * choice([-1, 1]) + + variables = symbols('x1:%s' % (num_variables + 1)) + constraints = [] + for _ in range(num_constraints): + lhs, rhs = sum(rand() * x for x in variables), rand(sparsity=0) # sparsity=0 bc of bug with smtlib_code + options = [] + if not disable_equality: + options += [Eq(lhs, rhs)] + if not disable_nonstrict: + options += [lhs <= rhs, lhs >= rhs] + if not disable_strict: + options += [lhs < rhs, lhs > rhs] + + constraints.append(choice(options)) + + return constraints + +def check_if_satisfiable_with_z3(constraints): + from sympy.external.importtools import import_module + from sympy.printing.smtlib import smtlib_code + from sympy.logic.boolalg import And + boolean_formula = And(*constraints) + z3 = import_module("z3") + if z3: + smtlib_string = smtlib_code(boolean_formula) + s = z3.Solver() + s.from_string(smtlib_string) + res = str(s.check()) + if res == 'sat': + return True + elif res == 'unsat': + return False + else: + raise ValueError(f"z3 was not able to check the satisfiability of {boolean_formula}") + +def find_rational_assignment(constr, assignment, iter=20): + eps = sympify(1) + + for _ in range(iter): + assign = {key: val[0] + val[1]*eps for key, val in assignment.items()} + try: + for cons in constr: + assert cons.subs(assign) == True + return assign + except AssertionError: + eps = eps/2 + + return None + +def boolean_formula_to_encoded_cnf(bf): + cnf = CNF.from_prop(bf) + enc = EncodedCNF() + enc.from_cnf(cnf) + return enc + + +def test_from_encoded_cnf(): + s1, s2 = symbols("s1 s2") + + # Test preprocessing + # Example is from section 3 of paper. + phi = (x >= 0) & ((x + y <= 2) | (x + 2 * y - z >= 6)) & (Eq(x + y, 2) | (x + 2 * y - z > 4)) + enc = boolean_formula_to_encoded_cnf(phi) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + assert lra.A.shape == (2, 5) + assert str(lra.slack) == '[_s1, _s2]' + assert str(lra.nonslack) == '[x, y, z]' + assert lra.A == Matrix([[ 1, 1, 0, -1, 0], + [-1, -2, 1, 0, -1]]) + assert {(str(b.var), b.bound, b.upper, b.equality, b.strict) for b in lra.enc_to_boundary.values()} == {('_s1', 2, None, True, False), + ('_s1', 2, True, False, False), + ('_s2', -4, True, False, True), + ('_s2', -6, True, False, False), + ('x', 0, False, False, False)} + + +def test_problem(): + from sympy.logic.algorithms.lra_theory import LRASolver + from sympy.assumptions.cnf import CNF, EncodedCNF + cons = [-2 * x - 2 * y >= 7, -9 * y >= 7, -6 * y >= 5] + cnf = CNF().from_prop(And(*cons)) + enc = EncodedCNF() + enc.from_cnf(cnf) + lra, _ = LRASolver.from_encoded_cnf(enc) + lra.assert_lit(1) + lra.assert_lit(2) + lra.assert_lit(3) + is_sat, assignment = lra.check() + assert is_sat is True + + +def test_random_problems(): + z3 = import_module("z3") + if z3 is None: + skip("z3 is not installed") + + special_cases = []; x1, x2, x3 = symbols("x1 x2 x3") + special_cases.append([x1 - 3 * x2 <= -5, 6 * x1 + 4 * x2 <= 0, -7 * x1 + 3 * x2 <= 3]) + special_cases.append([-3 * x1 >= 3, Eq(4 * x1, -1)]) + special_cases.append([-4 * x1 < 4, 6 * x1 <= -6]) + special_cases.append([-3 * x2 >= 7, 6 * x1 <= -5, -3 * x2 <= -4]) + special_cases.append([x + y >= 2, x + y <= 1]) + special_cases.append([x >= 0, x + y <= 2, x + 2 * y - z >= 6]) # from paper example + special_cases.append([-2 * x1 - 2 * x2 >= 7, -9 * x1 >= 7, -6 * x1 >= 5]) + special_cases.append([2 * x1 > -3, -9 * x1 < -6, 9 * x1 <= 6]) + special_cases.append([-2*x1 < -4, 9*x1 > -9]) + special_cases.append([-6*x1 >= -1, -8*x1 + x2 >= 5, -8*x1 + 7*x2 < 4, x1 > 7]) + special_cases.append([Eq(x1, 2), Eq(5*x1, -2), Eq(-7*x2, -6), Eq(9*x1 + 10*x2, 9)]) + special_cases.append([Eq(3*x1, 6), Eq(x1 - 8*x2, -9), Eq(-7*x1 + 5*x2, 3), Eq(3*x2, 7)]) + special_cases.append([-4*x1 < 4, 6*x1 <= -6]) + special_cases.append([-3*x1 + 8*x2 >= -8, -10*x2 > 9, 8*x1 - 4*x2 < 8, 10*x1 - 9*x2 >= -9]) + special_cases.append([x1 + 5*x2 >= -6, 9*x1 - 3*x2 >= -9, 6*x1 + 6*x2 < -10, -3*x1 + 3*x2 < -7]) + special_cases.append([-9*x1 < 7, -5*x1 - 7*x2 < -1, 3*x1 + 7*x2 > 1, -6*x1 - 6*x2 > 9]) + special_cases.append([9*x1 - 6*x2 >= -7, 9*x1 + 4*x2 < -8, -7*x2 <= 1, 10*x2 <= -7]) + + feasible_count = 0 + for i in range(50): + if i % 8 == 0: + constraints = make_random_problem(num_variables=1, num_constraints=2, rational=False) + elif i % 8 == 1: + constraints = make_random_problem(num_variables=2, num_constraints=4, rational=False, disable_equality=True, + disable_nonstrict=True) + elif i % 8 == 2: + constraints = make_random_problem(num_variables=2, num_constraints=4, rational=False, disable_strict=True) + elif i % 8 == 3: + constraints = make_random_problem(num_variables=3, num_constraints=12, rational=False) + else: + constraints = make_random_problem(num_variables=3, num_constraints=6, rational=False) + + if i < len(special_cases): + constraints = special_cases[i] + + if False in constraints or True in constraints: + continue + + phi = And(*constraints) + if phi == False: + continue + cnf = CNF.from_prop(phi); enc = EncodedCNF() + enc.from_cnf(cnf) + assert all(0 not in clause for clause in enc.data) + + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + s_subs = lra.s_subs + + lra.run_checks = True + s_subs_rev = {value: key for key, value in s_subs.items()} + lits = {lit for clause in enc.data for lit in clause} + + bounds = [(lra.enc_to_boundary[l], l) for l in lits if l in lra.enc_to_boundary] + bounds = sorted(bounds, key=lambda x: (str(x[0].var), x[0].bound, str(x[0].upper))) # to remove nondeterminism + + for b, l in bounds: + if lra.result and lra.result[0] == False: + break + lra.assert_lit(l) + + feasible = lra.check() + + if feasible[0] == True: + feasible_count += 1 + assert check_if_satisfiable_with_z3(constraints) is True + cons_funcs = [cons.func for cons in constraints] + assignment = feasible[1] + assignment = {key.var : value for key, value in assignment.items()} + if not (StrictLessThan in cons_funcs or StrictGreaterThan in cons_funcs): + assignment = {key: value[0] for key, value in assignment.items()} + for cons in constraints: + assert cons.subs(assignment) == True + + else: + rat_assignment = find_rational_assignment(constraints, assignment) + assert rat_assignment is not None + else: + assert check_if_satisfiable_with_z3(constraints) is False + + conflict = feasible[1] + assert len(conflict) >= 2 + conflict = {lra.enc_to_boundary[-l].get_inequality() for l in conflict} + conflict = {clause.subs(s_subs_rev) for clause in conflict} + assert check_if_satisfiable_with_z3(conflict) is False + + # check that conflict clause is probably minimal + for subset in itertools.combinations(conflict, len(conflict)-1): + assert check_if_satisfiable_with_z3(subset) is True + + +@XFAIL +def test_pos_neg_zero(): + bf = Q.positive(x) & Q.negative(x) & Q.zero(y) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == False + + bf = Q.positive(x) & Q.lt(x, -1) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + bf = Q.positive(x) & Q.zero(x) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + bf = Q.positive(x) & Q.zero(y) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == True + + +@XFAIL +def test_pos_neg_infinite(): + bf = Q.positive_infinite(x) & Q.lt(x, 10000000) & Q.positive_infinite(y) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == False + + bf = Q.positive_infinite(x) & Q.gt(x, 10000000) & Q.positive_infinite(y) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == True + + bf = Q.positive_infinite(x) & Q.negative_infinite(x) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in enc.encoding.values(): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + +def test_binrel_evaluation(): + bf = Q.gt(3, 2) + enc = boolean_formula_to_encoded_cnf(bf) + lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True) + assert len(lra.enc_to_boundary) == 0 + assert conflicts == [[1]] + + bf = Q.lt(3, 2) + enc = boolean_formula_to_encoded_cnf(bf) + lra, conflicts = LRASolver.from_encoded_cnf(enc, testing_mode=True) + assert len(lra.enc_to_boundary) == 0 + assert conflicts == [[-1]] + + +def test_negation(): + assert HANDLE_NEGATION is True + bf = Q.gt(x, 1) & ~Q.gt(x, 0) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + assert sorted(lra.check()[1]) in [[-1, 2], [-2, 1]] + + bf = ~Q.gt(x, 1) & ~Q.lt(x, 0) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == True + + bf = ~Q.gt(x, 0) & ~Q.lt(x, 1) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + bf = ~Q.gt(x, 0) & ~Q.le(x, 0) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + bf = ~Q.le(x+y, 2) & ~Q.ge(x-y, 2) & ~Q.ge(y, 0) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == False + assert len(lra.check()[1]) == 3 + assert all(i > 0 for i in lra.check()[1]) + + +def test_unhandled_input(): + nan = S.NaN + bf = Q.gt(3, nan) & Q.gt(x, nan) + enc = boolean_formula_to_encoded_cnf(bf) + raises(ValueError, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + bf = Q.gt(3, I) & Q.gt(x, I) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + bf = Q.gt(3, float("inf")) & Q.gt(x, float("inf")) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + bf = Q.gt(3, oo) & Q.gt(x, oo) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + # test non-linearity + bf = Q.gt(x**2 + x, 2) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + + bf = Q.gt(cos(x) + x, 2) + enc = boolean_formula_to_encoded_cnf(bf) + raises(UnhandledInput, lambda: LRASolver.from_encoded_cnf(enc, testing_mode=True)) + +@XFAIL +def test_infinite_strict_inequalities(): + # Extensive testing of the interaction between strict inequalities + # and constraints containing infinity is needed because + # the paper's rule for strict inequalities don't work when + # infinite numbers are allowed. Using the paper's rules you + # can end up with situations where oo + delta > oo is considered + # True when oo + delta should be equal to oo. + # See https://math.stackexchange.com/questions/4757069/can-this-method-of-converting-strict-inequalities-to-equisatisfiable-nonstrict-i + bf = (-x - y >= -float("inf")) & (x > 0) & (y >= float("inf")) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for lit in sorted(enc.encoding.values()): + if lra.assert_lit(lit) is not None: + break + assert len(lra.enc_to_boundary) == 3 + assert lra.check()[0] == True + + +def test_pivot(): + for _ in range(10): + m = randMatrix(5) + rref = m.rref() + for _ in range(5): + i, j = randint(0, 4), randint(0, 4) + if m[i, j] != 0: + assert LRASolver._pivot(m, i, j).rref() == rref + + +def test_reset_bounds(): + bf = Q.ge(x, 1) & Q.lt(x, 1) + enc = boolean_formula_to_encoded_cnf(bf) + lra, _ = LRASolver.from_encoded_cnf(enc, testing_mode=True) + for clause in enc.data: + for lit in clause: + lra.assert_lit(lit) + assert len(lra.enc_to_boundary) == 2 + assert lra.check()[0] == False + + lra.reset_bounds() + assert lra.check()[0] == True + for var in lra.all_var: + assert var.upper == LRARational(float("inf"), 0) + assert var.upper_from_eq == False + assert var.upper_from_neg == False + assert var.lower == LRARational(-float("inf"), 0) + assert var.lower_from_eq == False + assert var.lower_from_neg == False + assert var.assign == LRARational(0, 0) + assert var.var is not None + assert var.col_idx is not None + + +def test_empty_cnf(): + cnf = CNF() + enc = EncodedCNF() + enc.from_cnf(cnf) + lra, conflict = LRASolver.from_encoded_cnf(enc) + assert len(conflict) == 0 + assert lra.check() == (True, {}) diff --git a/phivenv/Lib/site-packages/sympy/logic/utilities/__init__.py b/phivenv/Lib/site-packages/sympy/logic/utilities/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3526c3e53d624bc70afe2df05f123c835781364c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/logic/utilities/__init__.py @@ -0,0 +1,3 @@ +from .dimacs import load_file + +__all__ = ['load_file'] diff --git a/phivenv/Lib/site-packages/sympy/logic/utilities/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/utilities/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6a209568bf1acb5ada50b911422dddb5172a249 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/utilities/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/utilities/__pycache__/dimacs.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/logic/utilities/__pycache__/dimacs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14838eb9decd6bcdf29206d4a1900193b7a6c8dc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/logic/utilities/__pycache__/dimacs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/logic/utilities/dimacs.py b/phivenv/Lib/site-packages/sympy/logic/utilities/dimacs.py new file mode 100644 index 0000000000000000000000000000000000000000..51302d8052c8ed8443239c1e21a2f063cf34e4ab --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/logic/utilities/dimacs.py @@ -0,0 +1,69 @@ +"""For reading in DIMACS file format + +www.cs.ubc.ca/~hoos/SATLIB/Benchmarks/SAT/satformat.ps + +""" + +from sympy.core import Symbol +from sympy.logic.boolalg import And, Or +import re +from pathlib import Path + + +def load(s): + """Loads a boolean expression from a string. + + Examples + ======== + + >>> from sympy.logic.utilities.dimacs import load + >>> load('1') + cnf_1 + >>> load('1 2') + cnf_1 | cnf_2 + >>> load('1 \\n 2') + cnf_1 & cnf_2 + >>> load('1 2 \\n 3') + cnf_3 & (cnf_1 | cnf_2) + """ + clauses = [] + + lines = s.split('\n') + + pComment = re.compile(r'c.*') + pStats = re.compile(r'p\s*cnf\s*(\d*)\s*(\d*)') + + while len(lines) > 0: + line = lines.pop(0) + + # Only deal with lines that aren't comments + if not pComment.match(line): + m = pStats.match(line) + + if not m: + nums = line.rstrip('\n').split(' ') + list = [] + for lit in nums: + if lit != '': + if int(lit) == 0: + continue + num = abs(int(lit)) + sign = True + if int(lit) < 0: + sign = False + + if sign: + list.append(Symbol("cnf_%s" % num)) + else: + list.append(~Symbol("cnf_%s" % num)) + + if len(list) > 0: + clauses.append(Or(*list)) + + return And(*clauses) + + +def load_file(location): + """Loads a boolean expression from a file.""" + s = Path(location).read_text() + return load(s) diff --git a/phivenv/Lib/site-packages/sympy/matrices/__init__.py b/phivenv/Lib/site-packages/sympy/matrices/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..37b558f3f03f149dae6af20254e9b88192f7f1ed --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/__init__.py @@ -0,0 +1,72 @@ +"""A module that handles matrices. + +Includes functions for fast creating matrices like zero, one/eye, random +matrix, etc. +""" +from .exceptions import ShapeError, NonSquareMatrixError +from .kind import MatrixKind +from .dense import ( + GramSchmidt, casoratian, diag, eye, hessian, jordan_cell, + list2numpy, matrix2numpy, matrix_multiply_elementwise, ones, + randMatrix, rot_axis1, rot_axis2, rot_axis3, rot_ccw_axis1, + rot_ccw_axis2, rot_ccw_axis3, rot_givens, + symarray, wronskian, zeros) +from .dense import MutableDenseMatrix +from .matrixbase import DeferredVector, MatrixBase + +MutableMatrix = MutableDenseMatrix +Matrix = MutableMatrix + +from .sparse import MutableSparseMatrix +from .sparsetools import banded +from .immutable import ImmutableDenseMatrix, ImmutableSparseMatrix + +ImmutableMatrix = ImmutableDenseMatrix +SparseMatrix = MutableSparseMatrix + +from .expressions import ( + MatrixSlice, BlockDiagMatrix, BlockMatrix, FunctionMatrix, Identity, + Inverse, MatAdd, MatMul, MatPow, MatrixExpr, MatrixSymbol, Trace, + Transpose, ZeroMatrix, OneMatrix, blockcut, block_collapse, matrix_symbols, Adjoint, + hadamard_product, HadamardProduct, HadamardPower, Determinant, det, + diagonalize_vector, DiagMatrix, DiagonalMatrix, DiagonalOf, trace, + DotProduct, kronecker_product, KroneckerProduct, + PermutationMatrix, MatrixPermute, MatrixSet, Permanent, per) + +from .utilities import dotprodsimp + +__all__ = [ + 'ShapeError', 'NonSquareMatrixError', 'MatrixKind', + + 'GramSchmidt', 'casoratian', 'diag', 'eye', 'hessian', 'jordan_cell', + 'list2numpy', 'matrix2numpy', 'matrix_multiply_elementwise', 'ones', + 'randMatrix', 'rot_axis1', 'rot_axis2', 'rot_axis3', 'symarray', + 'wronskian', 'zeros', 'rot_ccw_axis1', 'rot_ccw_axis2', 'rot_ccw_axis3', + 'rot_givens', + + 'MutableDenseMatrix', + + 'DeferredVector', 'MatrixBase', + + 'Matrix', 'MutableMatrix', + + 'MutableSparseMatrix', + + 'banded', + + 'ImmutableDenseMatrix', 'ImmutableSparseMatrix', + + 'ImmutableMatrix', 'SparseMatrix', + + 'MatrixSlice', 'BlockDiagMatrix', 'BlockMatrix', 'FunctionMatrix', + 'Identity', 'Inverse', 'MatAdd', 'MatMul', 'MatPow', 'MatrixExpr', + 'MatrixSymbol', 'Trace', 'Transpose', 'ZeroMatrix', 'OneMatrix', + 'blockcut', 'block_collapse', 'matrix_symbols', 'Adjoint', + 'hadamard_product', 'HadamardProduct', 'HadamardPower', 'Determinant', + 'det', 'diagonalize_vector', 'DiagMatrix', 'DiagonalMatrix', + 'DiagonalOf', 'trace', 'DotProduct', 'kronecker_product', + 'KroneckerProduct', 'PermutationMatrix', 'MatrixPermute', 'MatrixSet', + 'Permanent', 'per', + + 'dotprodsimp', +] diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba5fc33a1a1e419a60ef7006657441979876f475 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/decompositions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/decompositions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1adbc9decbce600e3bfebf7e8b2b0ed6bf3f171e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/decompositions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/dense.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/dense.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ceb17a7d06a503abf8a51a2701f81e13a6523e06 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/dense.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/determinant.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/determinant.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b174aa3c0f9457157774eb0ad015b9cda2915f82 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/determinant.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/eigen.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/eigen.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69353c6b955fc7c43de05f156d4414a1860f44d8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/eigen.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/exceptions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/exceptions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ad78915ca2916bc52df9fec22d05091b0c7a77f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/exceptions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/graph.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7df958ca45855fb32c46c5b9c6d51b61a77d653d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/immutable.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/immutable.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6cd9373bdd7cc40471f5f826d46ebb9a7a691fe Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/immutable.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/inverse.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/inverse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e15904aa25b22049564076a1a6c35327471b193d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/inverse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/kind.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/kind.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b74c09a3697b32e0bf49a3bf48174d94c48c8b6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/kind.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/matrices.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/matrices.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7503d9ec6b1fe91660b227b4bfa72f85e3741653 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/matrices.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/normalforms.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/normalforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ae643aeb2bdf4affb73d41431bfee1568512faf1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/normalforms.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/reductions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/reductions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..587a37b545952aeab142765ac0c1dfefb5e410ad Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/reductions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/repmatrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/repmatrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca28317dc71a9d001b026a17b689c748cbf0bfee Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/repmatrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/solvers.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/solvers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..758005791bc0177c36a69a1e134eaf8141a44b6a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/solvers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/sparse.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/sparse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bff0791c58582e8029b62c5e2d10beee025bd864 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/sparse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/sparsetools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/sparsetools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..290b493ec7e8acae2f32ad8515a23fae62688f69 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/sparsetools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/subspaces.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/subspaces.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75a5a5c6632cab6983675ad793e76056f62b653a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/subspaces.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/__pycache__/utilities.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/utilities.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f7f2cc19004ff400c99a01721cb19e0e5a0110c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/__pycache__/utilities.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/benchmarks/__init__.py b/phivenv/Lib/site-packages/sympy/matrices/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/matrices/benchmarks/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/benchmarks/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2dcdb9b1c4e38c05a8ee6b8e987c68fc67d4a3f2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/benchmarks/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/benchmarks/__pycache__/bench_matrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/benchmarks/__pycache__/bench_matrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9d8cb887d16793fc3fe0daf0e0f8d8e62880d76 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/benchmarks/__pycache__/bench_matrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/benchmarks/bench_matrix.py b/phivenv/Lib/site-packages/sympy/matrices/benchmarks/bench_matrix.py new file mode 100644 index 0000000000000000000000000000000000000000..4fb845600533c4c6fef196fe5a45b98890f4ad78 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/benchmarks/bench_matrix.py @@ -0,0 +1,21 @@ +from sympy.core.numbers import Integer +from sympy.matrices.dense import (eye, zeros) + +i3 = Integer(3) +M = eye(100) + + +def timeit_Matrix__getitem_ii(): + M[3, 3] + + +def timeit_Matrix__getitem_II(): + M[i3, i3] + + +def timeit_Matrix__getslice(): + M[:, :] + + +def timeit_Matrix_zeronm(): + zeros(100, 100) diff --git a/phivenv/Lib/site-packages/sympy/matrices/common.py b/phivenv/Lib/site-packages/sympy/matrices/common.py new file mode 100644 index 0000000000000000000000000000000000000000..bcb54726fe1a0c36658d8bf63b974db5a3ce8bad --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/common.py @@ -0,0 +1,3258 @@ +""" +A module containing deprecated matrix mixin classes. + +The classes in this module are deprecated and will be removed in a future +release. They are kept here for backwards compatibility in case downstream +code was subclassing them. + +Importing anything else from this module is deprecated so anything here +should either not be used or should be imported from somewhere else. +""" +from __future__ import annotations +from collections import defaultdict +from collections.abc import Iterable +from inspect import isfunction +from functools import reduce + +from sympy.assumptions.refine import refine +from sympy.core import SympifyError, Add +from sympy.core.basic import Atom +from sympy.core.decorators import call_highest_priority +from sympy.core.logic import fuzzy_and, FuzzyBool +from sympy.core.numbers import Integer +from sympy.core.mod import Mod +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import Abs, re, im +from sympy.utilities.exceptions import sympy_deprecation_warning +from .utilities import _dotprodsimp, _simplify +from sympy.polys.polytools import Poly +from sympy.utilities.iterables import flatten, is_sequence +from sympy.utilities.misc import as_int, filldedent +from sympy.tensor.array import NDimArray + +from .utilities import _get_intermediate_simp_bool + + +# These exception types were previously defined in this module but were moved +# to exceptions.py. We reimport them here for backwards compatibility in case +# downstream code was importing them from here. +from .exceptions import ( # noqa: F401 + MatrixError, ShapeError, NonSquareMatrixError, NonInvertibleMatrixError, + NonPositiveDefiniteMatrixError +) + + +_DEPRECATED_MIXINS = ( + 'MatrixShaping', + 'MatrixSpecial', + 'MatrixProperties', + 'MatrixOperations', + 'MatrixArithmetic', + 'MatrixCommon', + 'MatrixDeterminant', + 'MatrixReductions', + 'MatrixSubspaces', + 'MatrixEigen', + 'MatrixCalculus', + 'MatrixDeprecated', +) + + +class _MatrixDeprecatedMeta(type): + + # + # Override the default __instancecheck__ implementation to ensure that + # e.g. isinstance(M, MatrixCommon) still works when M is one of the + # matrix classes. Matrix no longer inherits from MatrixCommon so + # isinstance(M, MatrixCommon) would now return False by default. + # + # There were lots of places in the codebase where this was being done + # so it seems likely that downstream code may be doing it too. All use + # of these mixins is deprecated though so we give a deprecation warning + # unconditionally if they are being used with isinstance. + # + # Any code seeing this deprecation warning should be changed to use + # isinstance(M, MatrixBase) instead which also works in previous versions + # of SymPy. + # + + def __instancecheck__(cls, instance): + + sympy_deprecation_warning( + f""" + Checking whether an object is an instance of {cls.__name__} is + deprecated. + + Use `isinstance(obj, Matrix)` instead of `isinstance(obj, {cls.__name__})`. + """, + deprecated_since_version="1.13", + active_deprecations_target="deprecated-matrix-mixins", + stacklevel=3, + ) + + from sympy.matrices.matrixbase import MatrixBase + from sympy.matrices.matrices import ( + MatrixDeterminant, + MatrixReductions, + MatrixSubspaces, + MatrixEigen, + MatrixCalculus, + MatrixDeprecated + ) + + all_mixins = ( + MatrixRequired, + MatrixShaping, + MatrixSpecial, + MatrixProperties, + MatrixOperations, + MatrixArithmetic, + MatrixCommon, + MatrixDeterminant, + MatrixReductions, + MatrixSubspaces, + MatrixEigen, + MatrixCalculus, + MatrixDeprecated + ) + + if cls in all_mixins and isinstance(instance, MatrixBase): + return True + else: + return super().__instancecheck__(instance) + + +class MatrixRequired(metaclass=_MatrixDeprecatedMeta): + """Deprecated mixin class for making matrix classes.""" + + rows: int + cols: int + _simplify = None + + def __init_subclass__(cls, **kwargs): + + # Warn if any downstream code is subclassing this class or any of the + # deprecated mixin classes that are all ultimately subclasses of this + # class. + # + # We don't want to warn about the deprecated mixins themselves being + # created, but only about them being used as mixins by downstream code. + # Otherwise just importing this module would trigger a warning. + # Ultimately the whole module should be deprecated and removed but for + # SymPy 1.13 it is premature to do that given that this module was the + # main way to import matrix exception types in all previous versions. + + if cls.__name__ not in _DEPRECATED_MIXINS: + sympy_deprecation_warning( + f""" + Inheriting from the Matrix mixin classes is deprecated. + + The class {cls.__name__} is subclassing a deprecated mixin. + """, + deprecated_since_version="1.13", + active_deprecations_target="deprecated-matrix-mixins", + stacklevel=3, + ) + + super().__init_subclass__(**kwargs) + + @classmethod + def _new(cls, *args, **kwargs): + """`_new` must, at minimum, be callable as + `_new(rows, cols, mat) where mat is a flat list of the + elements of the matrix.""" + raise NotImplementedError("Subclasses must implement this.") + + def __eq__(self, other): + raise NotImplementedError("Subclasses must implement this.") + + def __getitem__(self, key): + """Implementations of __getitem__ should accept ints, in which + case the matrix is indexed as a flat list, tuples (i,j) in which + case the (i,j) entry is returned, slices, or mixed tuples (a,b) + where a and b are any combination of slices and integers.""" + raise NotImplementedError("Subclasses must implement this.") + + def __len__(self): + """The total number of entries in the matrix.""" + raise NotImplementedError("Subclasses must implement this.") + + @property + def shape(self): + raise NotImplementedError("Subclasses must implement this.") + + +class MatrixShaping(MatrixRequired): + """Provides basic matrix shaping and extracting of submatrices""" + + def _eval_col_del(self, col): + def entry(i, j): + return self[i, j] if j < col else self[i, j + 1] + return self._new(self.rows, self.cols - 1, entry) + + def _eval_col_insert(self, pos, other): + + def entry(i, j): + if j < pos: + return self[i, j] + elif pos <= j < pos + other.cols: + return other[i, j - pos] + return self[i, j - other.cols] + + return self._new(self.rows, self.cols + other.cols, entry) + + def _eval_col_join(self, other): + rows = self.rows + + def entry(i, j): + if i < rows: + return self[i, j] + return other[i - rows, j] + + return classof(self, other)._new(self.rows + other.rows, self.cols, + entry) + + def _eval_extract(self, rowsList, colsList): + mat = list(self) + cols = self.cols + indices = (i * cols + j for i in rowsList for j in colsList) + return self._new(len(rowsList), len(colsList), + [mat[i] for i in indices]) + + def _eval_get_diag_blocks(self): + sub_blocks = [] + + def recurse_sub_blocks(M): + for i in range(1, M.shape[0] + 1): + if i == 1: + to_the_right = M[0, i:] + to_the_bottom = M[i:, 0] + else: + to_the_right = M[:i, i:] + to_the_bottom = M[i:, :i] + if any(to_the_right) or any(to_the_bottom): + continue + sub_blocks.append(M[:i, :i]) + if M.shape != M[:i, :i].shape: + recurse_sub_blocks(M[i:, i:]) + return + + recurse_sub_blocks(self) + return sub_blocks + + def _eval_row_del(self, row): + def entry(i, j): + return self[i, j] if i < row else self[i + 1, j] + return self._new(self.rows - 1, self.cols, entry) + + def _eval_row_insert(self, pos, other): + entries = list(self) + insert_pos = pos * self.cols + entries[insert_pos:insert_pos] = list(other) + return self._new(self.rows + other.rows, self.cols, entries) + + def _eval_row_join(self, other): + cols = self.cols + + def entry(i, j): + if j < cols: + return self[i, j] + return other[i, j - cols] + + return classof(self, other)._new(self.rows, self.cols + other.cols, + entry) + + def _eval_tolist(self): + return [list(self[i,:]) for i in range(self.rows)] + + def _eval_todok(self): + dok = {} + rows, cols = self.shape + for i in range(rows): + for j in range(cols): + val = self[i, j] + if val != self.zero: + dok[i, j] = val + return dok + + def _eval_vec(self): + rows = self.rows + + def entry(n, _): + # we want to read off the columns first + j = n // rows + i = n - j * rows + return self[i, j] + + return self._new(len(self), 1, entry) + + def _eval_vech(self, diagonal): + c = self.cols + v = [] + if diagonal: + for j in range(c): + for i in range(j, c): + v.append(self[i, j]) + else: + for j in range(c): + for i in range(j + 1, c): + v.append(self[i, j]) + return self._new(len(v), 1, v) + + def col_del(self, col): + """Delete the specified column.""" + if col < 0: + col += self.cols + if not 0 <= col < self.cols: + raise IndexError("Column {} is out of range.".format(col)) + return self._eval_col_del(col) + + def col_insert(self, pos, other): + """Insert one or more columns at the given column position. + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(3, 1) + >>> M.col_insert(1, V) + Matrix([ + [0, 1, 0, 0], + [0, 1, 0, 0], + [0, 1, 0, 0]]) + + See Also + ======== + + col + row_insert + """ + # Allows you to build a matrix even if it is null matrix + if not self: + return type(self)(other) + + pos = as_int(pos) + + if pos < 0: + pos = self.cols + pos + if pos < 0: + pos = 0 + elif pos > self.cols: + pos = self.cols + + if self.rows != other.rows: + raise ShapeError( + "The matrices have incompatible number of rows ({} and {})" + .format(self.rows, other.rows)) + + return self._eval_col_insert(pos, other) + + def col_join(self, other): + """Concatenates two matrices along self's last and other's first row. + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(1, 3) + >>> M.col_join(V) + Matrix([ + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [1, 1, 1]]) + + See Also + ======== + + col + row_join + """ + # A null matrix can always be stacked (see #10770) + if self.rows == 0 and self.cols != other.cols: + return self._new(0, other.cols, []).col_join(other) + + if self.cols != other.cols: + raise ShapeError( + "The matrices have incompatible number of columns ({} and {})" + .format(self.cols, other.cols)) + return self._eval_col_join(other) + + def col(self, j): + """Elementary column selector. + + Examples + ======== + + >>> from sympy import eye + >>> eye(2).col(0) + Matrix([ + [1], + [0]]) + + See Also + ======== + + row + col_del + col_join + col_insert + """ + return self[:, j] + + def extract(self, rowsList, colsList): + r"""Return a submatrix by specifying a list of rows and columns. + Negative indices can be given. All indices must be in the range + $-n \le i < n$ where $n$ is the number of rows or columns. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(4, 3, range(12)) + >>> m + Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11]]) + >>> m.extract([0, 1, 3], [0, 1]) + Matrix([ + [0, 1], + [3, 4], + [9, 10]]) + + Rows or columns can be repeated: + + >>> m.extract([0, 0, 1], [-1]) + Matrix([ + [2], + [2], + [5]]) + + Every other row can be taken by using range to provide the indices: + + >>> m.extract(range(0, m.rows, 2), [-1]) + Matrix([ + [2], + [8]]) + + RowsList or colsList can also be a list of booleans, in which case + the rows or columns corresponding to the True values will be selected: + + >>> m.extract([0, 1, 2, 3], [True, False, True]) + Matrix([ + [0, 2], + [3, 5], + [6, 8], + [9, 11]]) + """ + + if not is_sequence(rowsList) or not is_sequence(colsList): + raise TypeError("rowsList and colsList must be iterable") + # ensure rowsList and colsList are lists of integers + if rowsList and all(isinstance(i, bool) for i in rowsList): + rowsList = [index for index, item in enumerate(rowsList) if item] + if colsList and all(isinstance(i, bool) for i in colsList): + colsList = [index for index, item in enumerate(colsList) if item] + + # ensure everything is in range + rowsList = [a2idx(k, self.rows) for k in rowsList] + colsList = [a2idx(k, self.cols) for k in colsList] + + return self._eval_extract(rowsList, colsList) + + def get_diag_blocks(self): + """Obtains the square sub-matrices on the main diagonal of a square matrix. + + Useful for inverting symbolic matrices or solving systems of + linear equations which may be decoupled by having a block diagonal + structure. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y, z + >>> A = Matrix([[1, 3, 0, 0], [y, z*z, 0, 0], [0, 0, x, 0], [0, 0, 0, 0]]) + >>> a1, a2, a3 = A.get_diag_blocks() + >>> a1 + Matrix([ + [1, 3], + [y, z**2]]) + >>> a2 + Matrix([[x]]) + >>> a3 + Matrix([[0]]) + + """ + return self._eval_get_diag_blocks() + + @classmethod + def hstack(cls, *args): + """Return a matrix formed by joining args horizontally (i.e. + by repeated application of row_join). + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> Matrix.hstack(eye(2), 2*eye(2)) + Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2]]) + """ + if len(args) == 0: + return cls._new() + + kls = type(args[0]) + return reduce(kls.row_join, args) + + def reshape(self, rows, cols): + """Reshape the matrix. Total number of elements must remain the same. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 3, lambda i, j: 1) + >>> m + Matrix([ + [1, 1, 1], + [1, 1, 1]]) + >>> m.reshape(1, 6) + Matrix([[1, 1, 1, 1, 1, 1]]) + >>> m.reshape(3, 2) + Matrix([ + [1, 1], + [1, 1], + [1, 1]]) + + """ + if self.rows * self.cols != rows * cols: + raise ValueError("Invalid reshape parameters %d %d" % (rows, cols)) + return self._new(rows, cols, lambda i, j: self[i * cols + j]) + + def row_del(self, row): + """Delete the specified row.""" + if row < 0: + row += self.rows + if not 0 <= row < self.rows: + raise IndexError("Row {} is out of range.".format(row)) + + return self._eval_row_del(row) + + def row_insert(self, pos, other): + """Insert one or more rows at the given row position. + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(1, 3) + >>> M.row_insert(1, V) + Matrix([ + [0, 0, 0], + [1, 1, 1], + [0, 0, 0], + [0, 0, 0]]) + + See Also + ======== + + row + col_insert + """ + # Allows you to build a matrix even if it is null matrix + if not self: + return self._new(other) + + pos = as_int(pos) + + if pos < 0: + pos = self.rows + pos + if pos < 0: + pos = 0 + elif pos > self.rows: + pos = self.rows + + if self.cols != other.cols: + raise ShapeError( + "The matrices have incompatible number of columns ({} and {})" + .format(self.cols, other.cols)) + + return self._eval_row_insert(pos, other) + + def row_join(self, other): + """Concatenates two matrices along self's last and rhs's first column + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(3, 1) + >>> M.row_join(V) + Matrix([ + [0, 0, 0, 1], + [0, 0, 0, 1], + [0, 0, 0, 1]]) + + See Also + ======== + + row + col_join + """ + # A null matrix can always be stacked (see #10770) + if self.cols == 0 and self.rows != other.rows: + return self._new(other.rows, 0, []).row_join(other) + + if self.rows != other.rows: + raise ShapeError( + "The matrices have incompatible number of rows ({} and {})" + .format(self.rows, other.rows)) + return self._eval_row_join(other) + + def diagonal(self, k=0): + """Returns the kth diagonal of self. The main diagonal + corresponds to `k=0`; diagonals above and below correspond to + `k > 0` and `k < 0`, respectively. The values of `self[i, j]` + for which `j - i = k`, are returned in order of increasing + `i + j`, starting with `i + j = |k|`. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(3, 3, lambda i, j: j - i); m + Matrix([ + [ 0, 1, 2], + [-1, 0, 1], + [-2, -1, 0]]) + >>> _.diagonal() + Matrix([[0, 0, 0]]) + >>> m.diagonal(1) + Matrix([[1, 1]]) + >>> m.diagonal(-2) + Matrix([[-2]]) + + Even though the diagonal is returned as a Matrix, the element + retrieval can be done with a single index: + + >>> Matrix.diag(1, 2, 3).diagonal()[1] # instead of [0, 1] + 2 + + See Also + ======== + + diag + """ + rv = [] + k = as_int(k) + r = 0 if k > 0 else -k + c = 0 if r else k + while True: + if r == self.rows or c == self.cols: + break + rv.append(self[r, c]) + r += 1 + c += 1 + if not rv: + raise ValueError(filldedent(''' + The %s diagonal is out of range [%s, %s]''' % ( + k, 1 - self.rows, self.cols - 1))) + return self._new(1, len(rv), rv) + + def row(self, i): + """Elementary row selector. + + Examples + ======== + + >>> from sympy import eye + >>> eye(2).row(0) + Matrix([[1, 0]]) + + See Also + ======== + + col + row_del + row_join + row_insert + """ + return self[i, :] + + @property + def shape(self): + """The shape (dimensions) of the matrix as the 2-tuple (rows, cols). + + Examples + ======== + + >>> from sympy import zeros + >>> M = zeros(2, 3) + >>> M.shape + (2, 3) + >>> M.rows + 2 + >>> M.cols + 3 + """ + return (self.rows, self.cols) + + def todok(self): + """Return the matrix as dictionary of keys. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix.eye(3) + >>> M.todok() + {(0, 0): 1, (1, 1): 1, (2, 2): 1} + """ + return self._eval_todok() + + def tolist(self): + """Return the Matrix as a nested Python list. + + Examples + ======== + + >>> from sympy import Matrix, ones + >>> m = Matrix(3, 3, range(9)) + >>> m + Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> m.tolist() + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + >>> ones(3, 0).tolist() + [[], [], []] + + When there are no rows then it will not be possible to tell how + many columns were in the original matrix: + + >>> ones(0, 3).tolist() + [] + + """ + if not self.rows: + return [] + if not self.cols: + return [[] for i in range(self.rows)] + return self._eval_tolist() + + def todod(M): + """Returns matrix as dict of dicts containing non-zero elements of the Matrix + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[0, 1],[0, 3]]) + >>> A + Matrix([ + [0, 1], + [0, 3]]) + >>> A.todod() + {0: {1: 1}, 1: {1: 3}} + + + """ + rowsdict = {} + Mlol = M.tolist() + for i, Mi in enumerate(Mlol): + row = {j: Mij for j, Mij in enumerate(Mi) if Mij} + if row: + rowsdict[i] = row + return rowsdict + + def vec(self): + """Return the Matrix converted into a one column matrix by stacking columns + + Examples + ======== + + >>> from sympy import Matrix + >>> m=Matrix([[1, 3], [2, 4]]) + >>> m + Matrix([ + [1, 3], + [2, 4]]) + >>> m.vec() + Matrix([ + [1], + [2], + [3], + [4]]) + + See Also + ======== + + vech + """ + return self._eval_vec() + + def vech(self, diagonal=True, check_symmetry=True): + """Reshapes the matrix into a column vector by stacking the + elements in the lower triangle. + + Parameters + ========== + + diagonal : bool, optional + If ``True``, it includes the diagonal elements. + + check_symmetry : bool, optional + If ``True``, it checks whether the matrix is symmetric. + + Examples + ======== + + >>> from sympy import Matrix + >>> m=Matrix([[1, 2], [2, 3]]) + >>> m + Matrix([ + [1, 2], + [2, 3]]) + >>> m.vech() + Matrix([ + [1], + [2], + [3]]) + >>> m.vech(diagonal=False) + Matrix([[2]]) + + Notes + ===== + + This should work for symmetric matrices and ``vech`` can + represent symmetric matrices in vector form with less size than + ``vec``. + + See Also + ======== + + vec + """ + if not self.is_square: + raise NonSquareMatrixError + + if check_symmetry and not self.is_symmetric(): + raise ValueError("The matrix is not symmetric.") + + return self._eval_vech(diagonal) + + @classmethod + def vstack(cls, *args): + """Return a matrix formed by joining args vertically (i.e. + by repeated application of col_join). + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> Matrix.vstack(eye(2), 2*eye(2)) + Matrix([ + [1, 0], + [0, 1], + [2, 0], + [0, 2]]) + """ + if len(args) == 0: + return cls._new() + + kls = type(args[0]) + return reduce(kls.col_join, args) + + +class MatrixSpecial(MatrixRequired): + """Construction of special matrices""" + + @classmethod + def _eval_diag(cls, rows, cols, diag_dict): + """diag_dict is a defaultdict containing + all the entries of the diagonal matrix.""" + def entry(i, j): + return diag_dict[(i, j)] + return cls._new(rows, cols, entry) + + @classmethod + def _eval_eye(cls, rows, cols): + vals = [cls.zero]*(rows*cols) + vals[::cols+1] = [cls.one]*min(rows, cols) + return cls._new(rows, cols, vals, copy=False) + + @classmethod + def _eval_jordan_block(cls, size: int, eigenvalue, band='upper'): + if band == 'lower': + def entry(i, j): + if i == j: + return eigenvalue + elif j + 1 == i: + return cls.one + return cls.zero + else: + def entry(i, j): + if i == j: + return eigenvalue + elif i + 1 == j: + return cls.one + return cls.zero + return cls._new(size, size, entry) + + @classmethod + def _eval_ones(cls, rows, cols): + def entry(i, j): + return cls.one + return cls._new(rows, cols, entry) + + @classmethod + def _eval_zeros(cls, rows, cols): + return cls._new(rows, cols, [cls.zero]*(rows*cols), copy=False) + + @classmethod + def _eval_wilkinson(cls, n): + def entry(i, j): + return cls.one if i + 1 == j else cls.zero + + D = cls._new(2*n + 1, 2*n + 1, entry) + + wminus = cls.diag(list(range(-n, n + 1)), unpack=True) + D + D.T + wplus = abs(cls.diag(list(range(-n, n + 1)), unpack=True)) + D + D.T + + return wminus, wplus + + @classmethod + def diag(kls, *args, strict=False, unpack=True, rows=None, cols=None, **kwargs): + """Returns a matrix with the specified diagonal. + If matrices are passed, a block-diagonal matrix + is created (i.e. the "direct sum" of the matrices). + + kwargs + ====== + + rows : rows of the resulting matrix; computed if + not given. + + cols : columns of the resulting matrix; computed if + not given. + + cls : class for the resulting matrix + + unpack : bool which, when True (default), unpacks a single + sequence rather than interpreting it as a Matrix. + + strict : bool which, when False (default), allows Matrices to + have variable-length rows. + + Examples + ======== + + >>> from sympy import Matrix + >>> Matrix.diag(1, 2, 3) + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + + The current default is to unpack a single sequence. If this is + not desired, set `unpack=False` and it will be interpreted as + a matrix. + + >>> Matrix.diag([1, 2, 3]) == Matrix.diag(1, 2, 3) + True + + When more than one element is passed, each is interpreted as + something to put on the diagonal. Lists are converted to + matrices. Filling of the diagonal always continues from + the bottom right hand corner of the previous item: this + will create a block-diagonal matrix whether the matrices + are square or not. + + >>> col = [1, 2, 3] + >>> row = [[4, 5]] + >>> Matrix.diag(col, row) + Matrix([ + [1, 0, 0], + [2, 0, 0], + [3, 0, 0], + [0, 4, 5]]) + + When `unpack` is False, elements within a list need not all be + of the same length. Setting `strict` to True would raise a + ValueError for the following: + + >>> Matrix.diag([[1, 2, 3], [4, 5], [6]], unpack=False) + Matrix([ + [1, 2, 3], + [4, 5, 0], + [6, 0, 0]]) + + The type of the returned matrix can be set with the ``cls`` + keyword. + + >>> from sympy import ImmutableMatrix + >>> from sympy.utilities.misc import func_name + >>> func_name(Matrix.diag(1, cls=ImmutableMatrix)) + 'ImmutableDenseMatrix' + + A zero dimension matrix can be used to position the start of + the filling at the start of an arbitrary row or column: + + >>> from sympy import ones + >>> r2 = ones(0, 2) + >>> Matrix.diag(r2, 1, 2) + Matrix([ + [0, 0, 1, 0], + [0, 0, 0, 2]]) + + See Also + ======== + eye + diagonal + .dense.diag + .expressions.blockmatrix.BlockMatrix + .sparsetools.banded + """ + from sympy.matrices.matrixbase import MatrixBase + from sympy.matrices.dense import Matrix + from sympy.matrices import SparseMatrix + klass = kwargs.get('cls', kls) + if unpack and len(args) == 1 and is_sequence(args[0]) and \ + not isinstance(args[0], MatrixBase): + args = args[0] + + # fill a default dict with the diagonal entries + diag_entries = defaultdict(int) + rmax = cmax = 0 # keep track of the biggest index seen + for m in args: + if isinstance(m, list): + if strict: + # if malformed, Matrix will raise an error + _ = Matrix(m) + r, c = _.shape + m = _.tolist() + else: + r, c, smat = SparseMatrix._handle_creation_inputs(m) + for (i, j), _ in smat.items(): + diag_entries[(i + rmax, j + cmax)] = _ + m = [] # to skip process below + elif hasattr(m, 'shape'): # a Matrix + # convert to list of lists + r, c = m.shape + m = m.tolist() + else: # in this case, we're a single value + diag_entries[(rmax, cmax)] = m + rmax += 1 + cmax += 1 + continue + # process list of lists + for i, mi in enumerate(m): + for j, _ in enumerate(mi): + diag_entries[(i + rmax, j + cmax)] = _ + rmax += r + cmax += c + if rows is None: + rows, cols = cols, rows + if rows is None: + rows, cols = rmax, cmax + else: + cols = rows if cols is None else cols + if rows < rmax or cols < cmax: + raise ValueError(filldedent(''' + The constructed matrix is {} x {} but a size of {} x {} + was specified.'''.format(rmax, cmax, rows, cols))) + return klass._eval_diag(rows, cols, diag_entries) + + @classmethod + def eye(kls, rows, cols=None, **kwargs): + """Returns an identity matrix. + + Parameters + ========== + + rows : rows of the matrix + cols : cols of the matrix (if None, cols=rows) + + kwargs + ====== + cls : class of the returned matrix + """ + if cols is None: + cols = rows + if rows < 0 or cols < 0: + raise ValueError("Cannot create a {} x {} matrix. " + "Both dimensions must be positive".format(rows, cols)) + klass = kwargs.get('cls', kls) + rows, cols = as_int(rows), as_int(cols) + + return klass._eval_eye(rows, cols) + + @classmethod + def jordan_block(kls, size=None, eigenvalue=None, *, band='upper', **kwargs): + """Returns a Jordan block + + Parameters + ========== + + size : Integer, optional + Specifies the shape of the Jordan block matrix. + + eigenvalue : Number or Symbol + Specifies the value for the main diagonal of the matrix. + + .. note:: + The keyword ``eigenval`` is also specified as an alias + of this keyword, but it is not recommended to use. + + We may deprecate the alias in later release. + + band : 'upper' or 'lower', optional + Specifies the position of the off-diagonal to put `1` s on. + + cls : Matrix, optional + Specifies the matrix class of the output form. + + If it is not specified, the class type where the method is + being executed on will be returned. + + Returns + ======= + + Matrix + A Jordan block matrix. + + Raises + ====== + + ValueError + If insufficient arguments are given for matrix size + specification, or no eigenvalue is given. + + Examples + ======== + + Creating a default Jordan block: + + >>> from sympy import Matrix + >>> from sympy.abc import x + >>> Matrix.jordan_block(4, x) + Matrix([ + [x, 1, 0, 0], + [0, x, 1, 0], + [0, 0, x, 1], + [0, 0, 0, x]]) + + Creating an alternative Jordan block matrix where `1` is on + lower off-diagonal: + + >>> Matrix.jordan_block(4, x, band='lower') + Matrix([ + [x, 0, 0, 0], + [1, x, 0, 0], + [0, 1, x, 0], + [0, 0, 1, x]]) + + Creating a Jordan block with keyword arguments + + >>> Matrix.jordan_block(size=4, eigenvalue=x) + Matrix([ + [x, 1, 0, 0], + [0, x, 1, 0], + [0, 0, x, 1], + [0, 0, 0, x]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Jordan_matrix + """ + klass = kwargs.pop('cls', kls) + + eigenval = kwargs.get('eigenval', None) + if eigenvalue is None and eigenval is None: + raise ValueError("Must supply an eigenvalue") + elif eigenvalue != eigenval and None not in (eigenval, eigenvalue): + raise ValueError( + "Inconsistent values are given: 'eigenval'={}, " + "'eigenvalue'={}".format(eigenval, eigenvalue)) + else: + if eigenval is not None: + eigenvalue = eigenval + + if size is None: + raise ValueError("Must supply a matrix size") + + size = as_int(size) + return klass._eval_jordan_block(size, eigenvalue, band) + + @classmethod + def ones(kls, rows, cols=None, **kwargs): + """Returns a matrix of ones. + + Parameters + ========== + + rows : rows of the matrix + cols : cols of the matrix (if None, cols=rows) + + kwargs + ====== + cls : class of the returned matrix + """ + if cols is None: + cols = rows + klass = kwargs.get('cls', kls) + rows, cols = as_int(rows), as_int(cols) + + return klass._eval_ones(rows, cols) + + @classmethod + def zeros(kls, rows, cols=None, **kwargs): + """Returns a matrix of zeros. + + Parameters + ========== + + rows : rows of the matrix + cols : cols of the matrix (if None, cols=rows) + + kwargs + ====== + cls : class of the returned matrix + """ + if cols is None: + cols = rows + if rows < 0 or cols < 0: + raise ValueError("Cannot create a {} x {} matrix. " + "Both dimensions must be positive".format(rows, cols)) + klass = kwargs.get('cls', kls) + rows, cols = as_int(rows), as_int(cols) + + return klass._eval_zeros(rows, cols) + + @classmethod + def companion(kls, poly): + """Returns a companion matrix of a polynomial. + + Examples + ======== + + >>> from sympy import Matrix, Poly, Symbol, symbols + >>> x = Symbol('x') + >>> c0, c1, c2, c3, c4 = symbols('c0:5') + >>> p = Poly(c0 + c1*x + c2*x**2 + c3*x**3 + c4*x**4 + x**5, x) + >>> Matrix.companion(p) + Matrix([ + [0, 0, 0, 0, -c0], + [1, 0, 0, 0, -c1], + [0, 1, 0, 0, -c2], + [0, 0, 1, 0, -c3], + [0, 0, 0, 1, -c4]]) + """ + poly = kls._sympify(poly) + if not isinstance(poly, Poly): + raise ValueError("{} must be a Poly instance.".format(poly)) + if not poly.is_monic: + raise ValueError("{} must be a monic polynomial.".format(poly)) + if not poly.is_univariate: + raise ValueError( + "{} must be a univariate polynomial.".format(poly)) + + size = poly.degree() + if not size >= 1: + raise ValueError( + "{} must have degree not less than 1.".format(poly)) + + coeffs = poly.all_coeffs() + def entry(i, j): + if j == size - 1: + return -coeffs[-1 - i] + elif i == j + 1: + return kls.one + return kls.zero + return kls._new(size, size, entry) + + + @classmethod + def wilkinson(kls, n, **kwargs): + """Returns two square Wilkinson Matrix of size 2*n + 1 + $W_{2n + 1}^-, W_{2n + 1}^+ =$ Wilkinson(n) + + Examples + ======== + + >>> from sympy import Matrix + >>> wminus, wplus = Matrix.wilkinson(3) + >>> wminus + Matrix([ + [-3, 1, 0, 0, 0, 0, 0], + [ 1, -2, 1, 0, 0, 0, 0], + [ 0, 1, -1, 1, 0, 0, 0], + [ 0, 0, 1, 0, 1, 0, 0], + [ 0, 0, 0, 1, 1, 1, 0], + [ 0, 0, 0, 0, 1, 2, 1], + [ 0, 0, 0, 0, 0, 1, 3]]) + >>> wplus + Matrix([ + [3, 1, 0, 0, 0, 0, 0], + [1, 2, 1, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 2, 1], + [0, 0, 0, 0, 0, 1, 3]]) + + References + ========== + + .. [1] https://blogs.mathworks.com/cleve/2013/04/15/wilkinsons-matrices-2/ + .. [2] J. H. Wilkinson, The Algebraic Eigenvalue Problem, Claredon Press, Oxford, 1965, 662 pp. + + """ + klass = kwargs.get('cls', kls) + n = as_int(n) + return klass._eval_wilkinson(n) + +class MatrixProperties(MatrixRequired): + """Provides basic properties of a matrix.""" + + def _eval_atoms(self, *types): + result = set() + for i in self: + result.update(i.atoms(*types)) + return result + + def _eval_free_symbols(self): + return set().union(*(i.free_symbols for i in self if i)) + + def _eval_has(self, *patterns): + return any(a.has(*patterns) for a in self) + + def _eval_is_anti_symmetric(self, simpfunc): + if not all(simpfunc(self[i, j] + self[j, i]).is_zero for i in range(self.rows) for j in range(self.cols)): + return False + return True + + def _eval_is_diagonal(self): + for i in range(self.rows): + for j in range(self.cols): + if i != j and self[i, j]: + return False + return True + + # _eval_is_hermitian is called by some general SymPy + # routines and has a different *args signature. Make + # sure the names don't clash by adding `_matrix_` in name. + def _eval_is_matrix_hermitian(self, simpfunc): + mat = self._new(self.rows, self.cols, lambda i, j: simpfunc(self[i, j] - self[j, i].conjugate())) + return mat.is_zero_matrix + + def _eval_is_Identity(self) -> FuzzyBool: + def dirac(i, j): + if i == j: + return 1 + return 0 + + return all(self[i, j] == dirac(i, j) + for i in range(self.rows) + for j in range(self.cols)) + + def _eval_is_lower_hessenberg(self): + return all(self[i, j].is_zero + for i in range(self.rows) + for j in range(i + 2, self.cols)) + + def _eval_is_lower(self): + return all(self[i, j].is_zero + for i in range(self.rows) + for j in range(i + 1, self.cols)) + + def _eval_is_symbolic(self): + return self.has(Symbol) + + def _eval_is_symmetric(self, simpfunc): + mat = self._new(self.rows, self.cols, lambda i, j: simpfunc(self[i, j] - self[j, i])) + return mat.is_zero_matrix + + def _eval_is_zero_matrix(self): + if any(i.is_zero == False for i in self): + return False + if any(i.is_zero is None for i in self): + return None + return True + + def _eval_is_upper_hessenberg(self): + return all(self[i, j].is_zero + for i in range(2, self.rows) + for j in range(min(self.cols, (i - 1)))) + + def _eval_values(self): + return [i for i in self if not i.is_zero] + + def _has_positive_diagonals(self): + diagonal_entries = (self[i, i] for i in range(self.rows)) + return fuzzy_and(x.is_positive for x in diagonal_entries) + + def _has_nonnegative_diagonals(self): + diagonal_entries = (self[i, i] for i in range(self.rows)) + return fuzzy_and(x.is_nonnegative for x in diagonal_entries) + + def atoms(self, *types): + """Returns the atoms that form the current object. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import Matrix + >>> Matrix([[x]]) + Matrix([[x]]) + >>> _.atoms() + {x} + >>> Matrix([[x, y], [y, x]]) + Matrix([ + [x, y], + [y, x]]) + >>> _.atoms() + {x, y} + """ + + types = tuple(t if isinstance(t, type) else type(t) for t in types) + if not types: + types = (Atom,) + return self._eval_atoms(*types) + + @property + def free_symbols(self): + """Returns the free symbols within the matrix. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import Matrix + >>> Matrix([[x], [1]]).free_symbols + {x} + """ + return self._eval_free_symbols() + + def has(self, *patterns): + """Test whether any subexpression matches any of the patterns. + + Examples + ======== + + >>> from sympy import Matrix, SparseMatrix, Float + >>> from sympy.abc import x, y + >>> A = Matrix(((1, x), (0.2, 3))) + >>> B = SparseMatrix(((1, x), (0.2, 3))) + >>> A.has(x) + True + >>> A.has(y) + False + >>> A.has(Float) + True + >>> B.has(x) + True + >>> B.has(y) + False + >>> B.has(Float) + True + """ + return self._eval_has(*patterns) + + def is_anti_symmetric(self, simplify=True): + """Check if matrix M is an antisymmetric matrix, + that is, M is a square matrix with all M[i, j] == -M[j, i]. + + When ``simplify=True`` (default), the sum M[i, j] + M[j, i] is + simplified before testing to see if it is zero. By default, + the SymPy simplify function is used. To use a custom function + set simplify to a function that accepts a single argument which + returns a simplified expression. To skip simplification, set + simplify to False but note that although this will be faster, + it may induce false negatives. + + Examples + ======== + + >>> from sympy import Matrix, symbols + >>> m = Matrix(2, 2, [0, 1, -1, 0]) + >>> m + Matrix([ + [ 0, 1], + [-1, 0]]) + >>> m.is_anti_symmetric() + True + >>> x, y = symbols('x y') + >>> m = Matrix(2, 3, [0, 0, x, -y, 0, 0]) + >>> m + Matrix([ + [ 0, 0, x], + [-y, 0, 0]]) + >>> m.is_anti_symmetric() + False + + >>> from sympy.abc import x, y + >>> m = Matrix(3, 3, [0, x**2 + 2*x + 1, y, + ... -(x + 1)**2, 0, x*y, + ... -y, -x*y, 0]) + + Simplification of matrix elements is done by default so even + though two elements which should be equal and opposite would not + pass an equality test, the matrix is still reported as + anti-symmetric: + + >>> m[0, 1] == -m[1, 0] + False + >>> m.is_anti_symmetric() + True + + If ``simplify=False`` is used for the case when a Matrix is already + simplified, this will speed things up. Here, we see that without + simplification the matrix does not appear anti-symmetric: + + >>> print(m.is_anti_symmetric(simplify=False)) + None + + But if the matrix were already expanded, then it would appear + anti-symmetric and simplification in the is_anti_symmetric routine + is not needed: + + >>> m = m.expand() + >>> m.is_anti_symmetric(simplify=False) + True + """ + # accept custom simplification + simpfunc = simplify + if not isfunction(simplify): + simpfunc = _simplify if simplify else lambda x: x + + if not self.is_square: + return False + return self._eval_is_anti_symmetric(simpfunc) + + def is_diagonal(self): + """Check if matrix is diagonal, + that is matrix in which the entries outside the main diagonal are all zero. + + Examples + ======== + + >>> from sympy import Matrix, diag + >>> m = Matrix(2, 2, [1, 0, 0, 2]) + >>> m + Matrix([ + [1, 0], + [0, 2]]) + >>> m.is_diagonal() + True + + >>> m = Matrix(2, 2, [1, 1, 0, 2]) + >>> m + Matrix([ + [1, 1], + [0, 2]]) + >>> m.is_diagonal() + False + + >>> m = diag(1, 2, 3) + >>> m + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + >>> m.is_diagonal() + True + + See Also + ======== + + is_lower + is_upper + sympy.matrices.matrixbase.MatrixCommon.is_diagonalizable + diagonalize + """ + return self._eval_is_diagonal() + + @property + def is_weakly_diagonally_dominant(self): + r"""Tests if the matrix is row weakly diagonally dominant. + + Explanation + =========== + + A $n, n$ matrix $A$ is row weakly diagonally dominant if + + .. math:: + \left|A_{i, i}\right| \ge \sum_{j = 0, j \neq i}^{n-1} + \left|A_{i, j}\right| \quad {\text{for all }} + i \in \{ 0, ..., n-1 \} + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[3, -2, 1], [1, -3, 2], [-1, 2, 4]]) + >>> A.is_weakly_diagonally_dominant + True + + >>> A = Matrix([[-2, 2, 1], [1, 3, 2], [1, -2, 0]]) + >>> A.is_weakly_diagonally_dominant + False + + >>> A = Matrix([[-4, 2, 1], [1, 6, 2], [1, -2, 5]]) + >>> A.is_weakly_diagonally_dominant + True + + Notes + ===== + + If you want to test whether a matrix is column diagonally + dominant, you can apply the test after transposing the matrix. + """ + if not self.is_square: + return False + + rows, cols = self.shape + + def test_row(i): + summation = self.zero + for j in range(cols): + if i != j: + summation += Abs(self[i, j]) + return (Abs(self[i, i]) - summation).is_nonnegative + + return fuzzy_and(test_row(i) for i in range(rows)) + + @property + def is_strongly_diagonally_dominant(self): + r"""Tests if the matrix is row strongly diagonally dominant. + + Explanation + =========== + + A $n, n$ matrix $A$ is row strongly diagonally dominant if + + .. math:: + \left|A_{i, i}\right| > \sum_{j = 0, j \neq i}^{n-1} + \left|A_{i, j}\right| \quad {\text{for all }} + i \in \{ 0, ..., n-1 \} + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[3, -2, 1], [1, -3, 2], [-1, 2, 4]]) + >>> A.is_strongly_diagonally_dominant + False + + >>> A = Matrix([[-2, 2, 1], [1, 3, 2], [1, -2, 0]]) + >>> A.is_strongly_diagonally_dominant + False + + >>> A = Matrix([[-4, 2, 1], [1, 6, 2], [1, -2, 5]]) + >>> A.is_strongly_diagonally_dominant + True + + Notes + ===== + + If you want to test whether a matrix is column diagonally + dominant, you can apply the test after transposing the matrix. + """ + if not self.is_square: + return False + + rows, cols = self.shape + + def test_row(i): + summation = self.zero + for j in range(cols): + if i != j: + summation += Abs(self[i, j]) + return (Abs(self[i, i]) - summation).is_positive + + return fuzzy_and(test_row(i) for i in range(rows)) + + @property + def is_hermitian(self): + """Checks if the matrix is Hermitian. + + In a Hermitian matrix element i,j is the complex conjugate of + element j,i. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy import I + >>> from sympy.abc import x + >>> a = Matrix([[1, I], [-I, 1]]) + >>> a + Matrix([ + [ 1, I], + [-I, 1]]) + >>> a.is_hermitian + True + >>> a[0, 0] = 2*I + >>> a.is_hermitian + False + >>> a[0, 0] = x + >>> a.is_hermitian + >>> a[0, 1] = a[1, 0]*I + >>> a.is_hermitian + False + """ + if not self.is_square: + return False + + return self._eval_is_matrix_hermitian(_simplify) + + @property + def is_Identity(self) -> FuzzyBool: + if not self.is_square: + return False + return self._eval_is_Identity() + + @property + def is_lower_hessenberg(self): + r"""Checks if the matrix is in the lower-Hessenberg form. + + The lower hessenberg matrix has zero entries + above the first superdiagonal. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[1, 2, 0, 0], [5, 2, 3, 0], [3, 4, 3, 7], [5, 6, 1, 1]]) + >>> a + Matrix([ + [1, 2, 0, 0], + [5, 2, 3, 0], + [3, 4, 3, 7], + [5, 6, 1, 1]]) + >>> a.is_lower_hessenberg + True + + See Also + ======== + + is_upper_hessenberg + is_lower + """ + return self._eval_is_lower_hessenberg() + + @property + def is_lower(self): + """Check if matrix is a lower triangular matrix. True can be returned + even if the matrix is not square. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, [1, 0, 0, 1]) + >>> m + Matrix([ + [1, 0], + [0, 1]]) + >>> m.is_lower + True + + >>> m = Matrix(4, 3, [0, 0, 0, 2, 0, 0, 1, 4, 0, 6, 6, 5]) + >>> m + Matrix([ + [0, 0, 0], + [2, 0, 0], + [1, 4, 0], + [6, 6, 5]]) + >>> m.is_lower + True + + >>> from sympy.abc import x, y + >>> m = Matrix(2, 2, [x**2 + y, y**2 + x, 0, x + y]) + >>> m + Matrix([ + [x**2 + y, x + y**2], + [ 0, x + y]]) + >>> m.is_lower + False + + See Also + ======== + + is_upper + is_diagonal + is_lower_hessenberg + """ + return self._eval_is_lower() + + @property + def is_square(self): + """Checks if a matrix is square. + + A matrix is square if the number of rows equals the number of columns. + The empty matrix is square by definition, since the number of rows and + the number of columns are both zero. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> b = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> c = Matrix([]) + >>> a.is_square + False + >>> b.is_square + True + >>> c.is_square + True + """ + return self.rows == self.cols + + def is_symbolic(self): + """Checks if any elements contain Symbols. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.is_symbolic() + True + + """ + return self._eval_is_symbolic() + + def is_symmetric(self, simplify=True): + """Check if matrix is symmetric matrix, + that is square matrix and is equal to its transpose. + + By default, simplifications occur before testing symmetry. + They can be skipped using 'simplify=False'; while speeding things a bit, + this may however induce false negatives. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, [0, 1, 1, 2]) + >>> m + Matrix([ + [0, 1], + [1, 2]]) + >>> m.is_symmetric() + True + + >>> m = Matrix(2, 2, [0, 1, 2, 0]) + >>> m + Matrix([ + [0, 1], + [2, 0]]) + >>> m.is_symmetric() + False + + >>> m = Matrix(2, 3, [0, 0, 0, 0, 0, 0]) + >>> m + Matrix([ + [0, 0, 0], + [0, 0, 0]]) + >>> m.is_symmetric() + False + + >>> from sympy.abc import x, y + >>> m = Matrix(3, 3, [1, x**2 + 2*x + 1, y, (x + 1)**2, 2, 0, y, 0, 3]) + >>> m + Matrix([ + [ 1, x**2 + 2*x + 1, y], + [(x + 1)**2, 2, 0], + [ y, 0, 3]]) + >>> m.is_symmetric() + True + + If the matrix is already simplified, you may speed-up is_symmetric() + test by using 'simplify=False'. + + >>> bool(m.is_symmetric(simplify=False)) + False + >>> m1 = m.expand() + >>> m1.is_symmetric(simplify=False) + True + """ + simpfunc = simplify + if not isfunction(simplify): + simpfunc = _simplify if simplify else lambda x: x + + if not self.is_square: + return False + + return self._eval_is_symmetric(simpfunc) + + @property + def is_upper_hessenberg(self): + """Checks if the matrix is the upper-Hessenberg form. + + The upper hessenberg matrix has zero entries + below the first subdiagonal. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[1, 4, 2, 3], [3, 4, 1, 7], [0, 2, 3, 4], [0, 0, 1, 3]]) + >>> a + Matrix([ + [1, 4, 2, 3], + [3, 4, 1, 7], + [0, 2, 3, 4], + [0, 0, 1, 3]]) + >>> a.is_upper_hessenberg + True + + See Also + ======== + + is_lower_hessenberg + is_upper + """ + return self._eval_is_upper_hessenberg() + + @property + def is_upper(self): + """Check if matrix is an upper triangular matrix. True can be returned + even if the matrix is not square. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, [1, 0, 0, 1]) + >>> m + Matrix([ + [1, 0], + [0, 1]]) + >>> m.is_upper + True + + >>> m = Matrix(4, 3, [5, 1, 9, 0, 4, 6, 0, 0, 5, 0, 0, 0]) + >>> m + Matrix([ + [5, 1, 9], + [0, 4, 6], + [0, 0, 5], + [0, 0, 0]]) + >>> m.is_upper + True + + >>> m = Matrix(2, 3, [4, 2, 5, 6, 1, 1]) + >>> m + Matrix([ + [4, 2, 5], + [6, 1, 1]]) + >>> m.is_upper + False + + See Also + ======== + + is_lower + is_diagonal + is_upper_hessenberg + """ + return all(self[i, j].is_zero + for i in range(1, self.rows) + for j in range(min(i, self.cols))) + + @property + def is_zero_matrix(self): + """Checks if a matrix is a zero matrix. + + A matrix is zero if every element is zero. A matrix need not be square + to be considered zero. The empty matrix is zero by the principle of + vacuous truth. For a matrix that may or may not be zero (e.g. + contains a symbol), this will be None + + Examples + ======== + + >>> from sympy import Matrix, zeros + >>> from sympy.abc import x + >>> a = Matrix([[0, 0], [0, 0]]) + >>> b = zeros(3, 4) + >>> c = Matrix([[0, 1], [0, 0]]) + >>> d = Matrix([]) + >>> e = Matrix([[x, 0], [0, 0]]) + >>> a.is_zero_matrix + True + >>> b.is_zero_matrix + True + >>> c.is_zero_matrix + False + >>> d.is_zero_matrix + True + >>> e.is_zero_matrix + """ + return self._eval_is_zero_matrix() + + def values(self): + """Return non-zero values of self.""" + return self._eval_values() + + +class MatrixOperations(MatrixRequired): + """Provides basic matrix shape and elementwise + operations. Should not be instantiated directly.""" + + def _eval_adjoint(self): + return self.transpose().conjugate() + + def _eval_applyfunc(self, f): + out = self._new(self.rows, self.cols, [f(x) for x in self]) + return out + + def _eval_as_real_imag(self): # type: ignore + return (self.applyfunc(re), self.applyfunc(im)) + + def _eval_conjugate(self): + return self.applyfunc(lambda x: x.conjugate()) + + def _eval_permute_cols(self, perm): + # apply the permutation to a list + mapping = list(perm) + + def entry(i, j): + return self[i, mapping[j]] + + return self._new(self.rows, self.cols, entry) + + def _eval_permute_rows(self, perm): + # apply the permutation to a list + mapping = list(perm) + + def entry(i, j): + return self[mapping[i], j] + + return self._new(self.rows, self.cols, entry) + + def _eval_trace(self): + return sum(self[i, i] for i in range(self.rows)) + + def _eval_transpose(self): + return self._new(self.cols, self.rows, lambda i, j: self[j, i]) + + def adjoint(self): + """Conjugate transpose or Hermitian conjugation.""" + return self._eval_adjoint() + + def applyfunc(self, f): + """Apply a function to each element of the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, lambda i, j: i*2+j) + >>> m + Matrix([ + [0, 1], + [2, 3]]) + >>> m.applyfunc(lambda i: 2*i) + Matrix([ + [0, 2], + [4, 6]]) + + """ + if not callable(f): + raise TypeError("`f` must be callable.") + + return self._eval_applyfunc(f) + + def as_real_imag(self, deep=True, **hints): + """Returns a tuple containing the (real, imaginary) part of matrix.""" + # XXX: Ignoring deep and hints... + return self._eval_as_real_imag() + + def conjugate(self): + """Return the by-element conjugation. + + Examples + ======== + + >>> from sympy import SparseMatrix, I + >>> a = SparseMatrix(((1, 2 + I), (3, 4), (I, -I))) + >>> a + Matrix([ + [1, 2 + I], + [3, 4], + [I, -I]]) + >>> a.C + Matrix([ + [ 1, 2 - I], + [ 3, 4], + [-I, I]]) + + See Also + ======== + + transpose: Matrix transposition + H: Hermite conjugation + sympy.matrices.matrixbase.MatrixBase.D: Dirac conjugation + """ + return self._eval_conjugate() + + def doit(self, **hints): + return self.applyfunc(lambda x: x.doit(**hints)) + + def evalf(self, n=15, subs=None, maxn=100, chop=False, strict=False, quad=None, verbose=False): + """Apply evalf() to each element of self.""" + options = {'subs':subs, 'maxn':maxn, 'chop':chop, 'strict':strict, + 'quad':quad, 'verbose':verbose} + return self.applyfunc(lambda i: i.evalf(n, **options)) + + def expand(self, deep=True, modulus=None, power_base=True, power_exp=True, + mul=True, log=True, multinomial=True, basic=True, **hints): + """Apply core.function.expand to each entry of the matrix. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import Matrix + >>> Matrix(1, 1, [x*(x+1)]) + Matrix([[x*(x + 1)]]) + >>> _.expand() + Matrix([[x**2 + x]]) + + """ + return self.applyfunc(lambda x: x.expand( + deep, modulus, power_base, power_exp, mul, log, multinomial, basic, + **hints)) + + @property + def H(self): + """Return Hermite conjugate. + + Examples + ======== + + >>> from sympy import Matrix, I + >>> m = Matrix((0, 1 + I, 2, 3)) + >>> m + Matrix([ + [ 0], + [1 + I], + [ 2], + [ 3]]) + >>> m.H + Matrix([[0, 1 - I, 2, 3]]) + + See Also + ======== + + conjugate: By-element conjugation + sympy.matrices.matrixbase.MatrixBase.D: Dirac conjugation + """ + return self.T.C + + def permute(self, perm, orientation='rows', direction='forward'): + r"""Permute the rows or columns of a matrix by the given list of + swaps. + + Parameters + ========== + + perm : Permutation, list, or list of lists + A representation for the permutation. + + If it is ``Permutation``, it is used directly with some + resizing with respect to the matrix size. + + If it is specified as list of lists, + (e.g., ``[[0, 1], [0, 2]]``), then the permutation is formed + from applying the product of cycles. The direction how the + cyclic product is applied is described in below. + + If it is specified as a list, the list should represent + an array form of a permutation. (e.g., ``[1, 2, 0]``) which + would would form the swapping function + `0 \mapsto 1, 1 \mapsto 2, 2\mapsto 0`. + + orientation : 'rows', 'cols' + A flag to control whether to permute the rows or the columns + + direction : 'forward', 'backward' + A flag to control whether to apply the permutations from + the start of the list first, or from the back of the list + first. + + For example, if the permutation specification is + ``[[0, 1], [0, 2]]``, + + If the flag is set to ``'forward'``, the cycle would be + formed as `0 \mapsto 2, 2 \mapsto 1, 1 \mapsto 0`. + + If the flag is set to ``'backward'``, the cycle would be + formed as `0 \mapsto 1, 1 \mapsto 2, 2 \mapsto 0`. + + If the argument ``perm`` is not in a form of list of lists, + this flag takes no effect. + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.permute([[0, 1], [0, 2]], orientation='rows', direction='forward') + Matrix([ + [0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + + >>> from sympy import eye + >>> M = eye(3) + >>> M.permute([[0, 1], [0, 2]], orientation='rows', direction='backward') + Matrix([ + [0, 1, 0], + [0, 0, 1], + [1, 0, 0]]) + + Notes + ===== + + If a bijective function + `\sigma : \mathbb{N}_0 \rightarrow \mathbb{N}_0` denotes the + permutation. + + If the matrix `A` is the matrix to permute, represented as + a horizontal or a vertical stack of vectors: + + .. math:: + A = + \begin{bmatrix} + a_0 \\ a_1 \\ \vdots \\ a_{n-1} + \end{bmatrix} = + \begin{bmatrix} + \alpha_0 & \alpha_1 & \cdots & \alpha_{n-1} + \end{bmatrix} + + If the matrix `B` is the result, the permutation of matrix rows + is defined as: + + .. math:: + B := \begin{bmatrix} + a_{\sigma(0)} \\ a_{\sigma(1)} \\ \vdots \\ a_{\sigma(n-1)} + \end{bmatrix} + + And the permutation of matrix columns is defined as: + + .. math:: + B := \begin{bmatrix} + \alpha_{\sigma(0)} & \alpha_{\sigma(1)} & + \cdots & \alpha_{\sigma(n-1)} + \end{bmatrix} + """ + from sympy.combinatorics import Permutation + + # allow british variants and `columns` + if direction == 'forwards': + direction = 'forward' + if direction == 'backwards': + direction = 'backward' + if orientation == 'columns': + orientation = 'cols' + + if direction not in ('forward', 'backward'): + raise TypeError("direction='{}' is an invalid kwarg. " + "Try 'forward' or 'backward'".format(direction)) + if orientation not in ('rows', 'cols'): + raise TypeError("orientation='{}' is an invalid kwarg. " + "Try 'rows' or 'cols'".format(orientation)) + + if not isinstance(perm, (Permutation, Iterable)): + raise ValueError( + "{} must be a list, a list of lists, " + "or a SymPy permutation object.".format(perm)) + + # ensure all swaps are in range + max_index = self.rows if orientation == 'rows' else self.cols + if not all(0 <= t <= max_index for t in flatten(list(perm))): + raise IndexError("`swap` indices out of range.") + + if perm and not isinstance(perm, Permutation) and \ + isinstance(perm[0], Iterable): + if direction == 'forward': + perm = list(reversed(perm)) + perm = Permutation(perm, size=max_index+1) + else: + perm = Permutation(perm, size=max_index+1) + + if orientation == 'rows': + return self._eval_permute_rows(perm) + if orientation == 'cols': + return self._eval_permute_cols(perm) + + def permute_cols(self, swaps, direction='forward'): + """Alias for + ``self.permute(swaps, orientation='cols', direction=direction)`` + + See Also + ======== + + permute + """ + return self.permute(swaps, orientation='cols', direction=direction) + + def permute_rows(self, swaps, direction='forward'): + """Alias for + ``self.permute(swaps, orientation='rows', direction=direction)`` + + See Also + ======== + + permute + """ + return self.permute(swaps, orientation='rows', direction=direction) + + def refine(self, assumptions=True): + """Apply refine to each element of the matrix. + + Examples + ======== + + >>> from sympy import Symbol, Matrix, Abs, sqrt, Q + >>> x = Symbol('x') + >>> Matrix([[Abs(x)**2, sqrt(x**2)],[sqrt(x**2), Abs(x)**2]]) + Matrix([ + [ Abs(x)**2, sqrt(x**2)], + [sqrt(x**2), Abs(x)**2]]) + >>> _.refine(Q.real(x)) + Matrix([ + [ x**2, Abs(x)], + [Abs(x), x**2]]) + + """ + return self.applyfunc(lambda x: refine(x, assumptions)) + + def replace(self, F, G, map=False, simultaneous=True, exact=None): + """Replaces Function F in Matrix entries with Function G. + + Examples + ======== + + >>> from sympy import symbols, Function, Matrix + >>> F, G = symbols('F, G', cls=Function) + >>> M = Matrix(2, 2, lambda i, j: F(i+j)) ; M + Matrix([ + [F(0), F(1)], + [F(1), F(2)]]) + >>> N = M.replace(F,G) + >>> N + Matrix([ + [G(0), G(1)], + [G(1), G(2)]]) + """ + return self.applyfunc( + lambda x: x.replace(F, G, map=map, simultaneous=simultaneous, exact=exact)) + + def rot90(self, k=1): + """Rotates Matrix by 90 degrees + + Parameters + ========== + + k : int + Specifies how many times the matrix is rotated by 90 degrees + (clockwise when positive, counter-clockwise when negative). + + Examples + ======== + + >>> from sympy import Matrix, symbols + >>> A = Matrix(2, 2, symbols('a:d')) + >>> A + Matrix([ + [a, b], + [c, d]]) + + Rotating the matrix clockwise one time: + + >>> A.rot90(1) + Matrix([ + [c, a], + [d, b]]) + + Rotating the matrix anticlockwise two times: + + >>> A.rot90(-2) + Matrix([ + [d, c], + [b, a]]) + """ + + mod = k%4 + if mod == 0: + return self + if mod == 1: + return self[::-1, ::].T + if mod == 2: + return self[::-1, ::-1] + if mod == 3: + return self[::, ::-1].T + + def simplify(self, **kwargs): + """Apply simplify to each element of the matrix. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import SparseMatrix, sin, cos + >>> SparseMatrix(1, 1, [x*sin(y)**2 + x*cos(y)**2]) + Matrix([[x*sin(y)**2 + x*cos(y)**2]]) + >>> _.simplify() + Matrix([[x]]) + """ + return self.applyfunc(lambda x: x.simplify(**kwargs)) + + def subs(self, *args, **kwargs): # should mirror core.basic.subs + """Return a new matrix with subs applied to each entry. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import SparseMatrix, Matrix + >>> SparseMatrix(1, 1, [x]) + Matrix([[x]]) + >>> _.subs(x, y) + Matrix([[y]]) + >>> Matrix(_).subs(y, x) + Matrix([[x]]) + """ + + if len(args) == 1 and not isinstance(args[0], (dict, set)) and iter(args[0]) and not is_sequence(args[0]): + args = (list(args[0]),) + + return self.applyfunc(lambda x: x.subs(*args, **kwargs)) + + def trace(self): + """ + Returns the trace of a square matrix i.e. the sum of the + diagonal elements. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(2, 2, [1, 2, 3, 4]) + >>> A.trace() + 5 + + """ + if self.rows != self.cols: + raise NonSquareMatrixError() + return self._eval_trace() + + def transpose(self): + """ + Returns the transpose of the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(2, 2, [1, 2, 3, 4]) + >>> A.transpose() + Matrix([ + [1, 3], + [2, 4]]) + + >>> from sympy import Matrix, I + >>> m=Matrix(((1, 2+I), (3, 4))) + >>> m + Matrix([ + [1, 2 + I], + [3, 4]]) + >>> m.transpose() + Matrix([ + [ 1, 3], + [2 + I, 4]]) + >>> m.T == m.transpose() + True + + See Also + ======== + + conjugate: By-element conjugation + + """ + return self._eval_transpose() + + @property + def T(self): + '''Matrix transposition''' + return self.transpose() + + @property + def C(self): + '''By-element conjugation''' + return self.conjugate() + + def n(self, *args, **kwargs): + """Apply evalf() to each element of self.""" + return self.evalf(*args, **kwargs) + + def xreplace(self, rule): # should mirror core.basic.xreplace + """Return a new matrix with xreplace applied to each entry. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import SparseMatrix, Matrix + >>> SparseMatrix(1, 1, [x]) + Matrix([[x]]) + >>> _.xreplace({x: y}) + Matrix([[y]]) + >>> Matrix(_).xreplace({y: x}) + Matrix([[x]]) + """ + return self.applyfunc(lambda x: x.xreplace(rule)) + + def _eval_simplify(self, **kwargs): + # XXX: We can't use self.simplify here as mutable subclasses will + # override simplify and have it return None + return MatrixOperations.simplify(self, **kwargs) + + def _eval_trigsimp(self, **opts): + from sympy.simplify.trigsimp import trigsimp + return self.applyfunc(lambda x: trigsimp(x, **opts)) + + def upper_triangular(self, k=0): + """Return the elements on and above the kth diagonal of a matrix. + If k is not specified then simply returns upper-triangular portion + of a matrix + + Examples + ======== + + >>> from sympy import ones + >>> A = ones(4) + >>> A.upper_triangular() + Matrix([ + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 1]]) + + >>> A.upper_triangular(2) + Matrix([ + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + + >>> A.upper_triangular(-1) + Matrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1]]) + + """ + + def entry(i, j): + return self[i, j] if i + k <= j else self.zero + + return self._new(self.rows, self.cols, entry) + + + def lower_triangular(self, k=0): + """Return the elements on and below the kth diagonal of a matrix. + If k is not specified then simply returns lower-triangular portion + of a matrix + + Examples + ======== + + >>> from sympy import ones + >>> A = ones(4) + >>> A.lower_triangular() + Matrix([ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]]) + + >>> A.lower_triangular(-2) + Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [1, 1, 0, 0]]) + + >>> A.lower_triangular(1) + Matrix([ + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]]) + + """ + + def entry(i, j): + return self[i, j] if i + k >= j else self.zero + + return self._new(self.rows, self.cols, entry) + + + +class MatrixArithmetic(MatrixRequired): + """Provides basic matrix arithmetic operations. + Should not be instantiated directly.""" + + _op_priority = 10.01 + + def _eval_Abs(self): + return self._new(self.rows, self.cols, lambda i, j: Abs(self[i, j])) + + def _eval_add(self, other): + return self._new(self.rows, self.cols, + lambda i, j: self[i, j] + other[i, j]) + + def _eval_matrix_mul(self, other): + def entry(i, j): + vec = [self[i,k]*other[k,j] for k in range(self.cols)] + try: + return Add(*vec) + except (TypeError, SympifyError): + # Some matrices don't work with `sum` or `Add` + # They don't work with `sum` because `sum` tries to add `0` + # Fall back to a safe way to multiply if the `Add` fails. + return reduce(lambda a, b: a + b, vec) + + return self._new(self.rows, other.cols, entry) + + def _eval_matrix_mul_elementwise(self, other): + return self._new(self.rows, self.cols, lambda i, j: self[i,j]*other[i,j]) + + def _eval_matrix_rmul(self, other): + def entry(i, j): + return sum(other[i,k]*self[k,j] for k in range(other.cols)) + return self._new(other.rows, self.cols, entry) + + def _eval_pow_by_recursion(self, num): + if num == 1: + return self + + if num % 2 == 1: + a, b = self, self._eval_pow_by_recursion(num - 1) + else: + a = b = self._eval_pow_by_recursion(num // 2) + + return a.multiply(b) + + def _eval_pow_by_cayley(self, exp): + from sympy.discrete.recurrences import linrec_coeffs + row = self.shape[0] + p = self.charpoly() + + coeffs = (-p).all_coeffs()[1:] + coeffs = linrec_coeffs(coeffs, exp) + new_mat = self.eye(row) + ans = self.zeros(row) + + for i in range(row): + ans += coeffs[i]*new_mat + new_mat *= self + + return ans + + def _eval_pow_by_recursion_dotprodsimp(self, num, prevsimp=None): + if prevsimp is None: + prevsimp = [True]*len(self) + + if num == 1: + return self + + if num % 2 == 1: + a, b = self, self._eval_pow_by_recursion_dotprodsimp(num - 1, + prevsimp=prevsimp) + else: + a = b = self._eval_pow_by_recursion_dotprodsimp(num // 2, + prevsimp=prevsimp) + + m = a.multiply(b, dotprodsimp=False) + lenm = len(m) + elems = [None]*lenm + + for i in range(lenm): + if prevsimp[i]: + elems[i], prevsimp[i] = _dotprodsimp(m[i], withsimp=True) + else: + elems[i] = m[i] + + return m._new(m.rows, m.cols, elems) + + def _eval_scalar_mul(self, other): + return self._new(self.rows, self.cols, lambda i, j: self[i,j]*other) + + def _eval_scalar_rmul(self, other): + return self._new(self.rows, self.cols, lambda i, j: other*self[i,j]) + + def _eval_Mod(self, other): + return self._new(self.rows, self.cols, lambda i, j: Mod(self[i, j], other)) + + # Python arithmetic functions + def __abs__(self): + """Returns a new matrix with entry-wise absolute values.""" + return self._eval_Abs() + + @call_highest_priority('__radd__') + def __add__(self, other): + """Return self + other, raising ShapeError if shapes do not match.""" + if isinstance(other, NDimArray): # Matrix and array addition is currently not implemented + return NotImplemented + other = _matrixify(other) + # matrix-like objects can have shapes. This is + # our first sanity check. + if hasattr(other, 'shape'): + if self.shape != other.shape: + raise ShapeError("Matrix size mismatch: %s + %s" % ( + self.shape, other.shape)) + + # honest SymPy matrices defer to their class's routine + if getattr(other, 'is_Matrix', False): + # call the highest-priority class's _eval_add + a, b = self, other + if a.__class__ != classof(a, b): + b, a = a, b + return a._eval_add(b) + # Matrix-like objects can be passed to CommonMatrix routines directly. + if getattr(other, 'is_MatrixLike', False): + return MatrixArithmetic._eval_add(self, other) + + raise TypeError('cannot add %s and %s' % (type(self), type(other))) + + @call_highest_priority('__rtruediv__') + def __truediv__(self, other): + return self * (self.one / other) + + @call_highest_priority('__rmatmul__') + def __matmul__(self, other): + other = _matrixify(other) + if not getattr(other, 'is_Matrix', False) and not getattr(other, 'is_MatrixLike', False): + return NotImplemented + + return self.__mul__(other) + + def __mod__(self, other): + return self.applyfunc(lambda x: x % other) + + @call_highest_priority('__rmul__') + def __mul__(self, other): + """Return self*other where other is either a scalar or a matrix + of compatible dimensions. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> 2*A == A*2 == Matrix([[2, 4, 6], [8, 10, 12]]) + True + >>> B = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> A*B + Matrix([ + [30, 36, 42], + [66, 81, 96]]) + >>> B*A + Traceback (most recent call last): + ... + ShapeError: Matrices size mismatch. + >>> + + See Also + ======== + + matrix_multiply_elementwise + """ + + return self.multiply(other) + + def multiply(self, other, dotprodsimp=None): + """Same as __mul__() but with optional simplification. + + Parameters + ========== + + dotprodsimp : bool, optional + Specifies whether intermediate term algebraic simplification is used + during matrix multiplications to control expression blowup and thus + speed up calculation. Default is off. + """ + + isimpbool = _get_intermediate_simp_bool(False, dotprodsimp) + other = _matrixify(other) + # matrix-like objects can have shapes. This is + # our first sanity check. Double check other is not explicitly not a Matrix. + if (hasattr(other, 'shape') and len(other.shape) == 2 and + (getattr(other, 'is_Matrix', True) or + getattr(other, 'is_MatrixLike', True))): + if self.shape[1] != other.shape[0]: + raise ShapeError("Matrix size mismatch: %s * %s." % ( + self.shape, other.shape)) + + # honest SymPy matrices defer to their class's routine + if getattr(other, 'is_Matrix', False): + m = self._eval_matrix_mul(other) + if isimpbool: + return m._new(m.rows, m.cols, [_dotprodsimp(e) for e in m]) + return m + + # Matrix-like objects can be passed to CommonMatrix routines directly. + if getattr(other, 'is_MatrixLike', False): + return MatrixArithmetic._eval_matrix_mul(self, other) + + # if 'other' is not iterable then scalar multiplication. + if not isinstance(other, Iterable): + try: + return self._eval_scalar_mul(other) + except TypeError: + pass + + return NotImplemented + + def multiply_elementwise(self, other): + """Return the Hadamard product (elementwise product) of A and B + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[0, 1, 2], [3, 4, 5]]) + >>> B = Matrix([[1, 10, 100], [100, 10, 1]]) + >>> A.multiply_elementwise(B) + Matrix([ + [ 0, 10, 200], + [300, 40, 5]]) + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.cross + sympy.matrices.matrixbase.MatrixBase.dot + multiply + """ + if self.shape != other.shape: + raise ShapeError("Matrix shapes must agree {} != {}".format(self.shape, other.shape)) + + return self._eval_matrix_mul_elementwise(other) + + def __neg__(self): + return self._eval_scalar_mul(-1) + + @call_highest_priority('__rpow__') + def __pow__(self, exp): + """Return self**exp a scalar or symbol.""" + + return self.pow(exp) + + + def pow(self, exp, method=None): + r"""Return self**exp a scalar or symbol. + + Parameters + ========== + + method : multiply, mulsimp, jordan, cayley + If multiply then it returns exponentiation using recursion. + If jordan then Jordan form exponentiation will be used. + If cayley then the exponentiation is done using Cayley-Hamilton + theorem. + If mulsimp then the exponentiation is done using recursion + with dotprodsimp. This specifies whether intermediate term + algebraic simplification is used during naive matrix power to + control expression blowup and thus speed up calculation. + If None, then it heuristically decides which method to use. + + """ + + if method is not None and method not in ['multiply', 'mulsimp', 'jordan', 'cayley']: + raise TypeError('No such method') + if self.rows != self.cols: + raise NonSquareMatrixError() + a = self + jordan_pow = getattr(a, '_matrix_pow_by_jordan_blocks', None) + exp = sympify(exp) + + if exp.is_zero: + return a._new(a.rows, a.cols, lambda i, j: int(i == j)) + if exp == 1: + return a + + diagonal = getattr(a, 'is_diagonal', None) + if diagonal is not None and diagonal(): + return a._new(a.rows, a.cols, lambda i, j: a[i,j]**exp if i == j else 0) + + if exp.is_Number and exp % 1 == 0: + if a.rows == 1: + return a._new([[a[0]**exp]]) + if exp < 0: + exp = -exp + a = a.inv() + # When certain conditions are met, + # Jordan block algorithm is faster than + # computation by recursion. + if method == 'jordan': + try: + return jordan_pow(exp) + except MatrixError: + if method == 'jordan': + raise + + elif method == 'cayley': + if not exp.is_Number or exp % 1 != 0: + raise ValueError("cayley method is only valid for integer powers") + return a._eval_pow_by_cayley(exp) + + elif method == "mulsimp": + if not exp.is_Number or exp % 1 != 0: + raise ValueError("mulsimp method is only valid for integer powers") + return a._eval_pow_by_recursion_dotprodsimp(exp) + + elif method == "multiply": + if not exp.is_Number or exp % 1 != 0: + raise ValueError("multiply method is only valid for integer powers") + return a._eval_pow_by_recursion(exp) + + elif method is None and exp.is_Number and exp % 1 == 0: + if exp.is_Float: + exp = Integer(exp) + # Decide heuristically which method to apply + if a.rows == 2 and exp > 100000: + return jordan_pow(exp) + elif _get_intermediate_simp_bool(True, None): + return a._eval_pow_by_recursion_dotprodsimp(exp) + elif exp > 10000: + return a._eval_pow_by_cayley(exp) + else: + return a._eval_pow_by_recursion(exp) + + if jordan_pow: + try: + return jordan_pow(exp) + except NonInvertibleMatrixError: + # Raised by jordan_pow on zero determinant matrix unless exp is + # definitely known to be a non-negative integer. + # Here we raise if n is definitely not a non-negative integer + # but otherwise we can leave this as an unevaluated MatPow. + if exp.is_integer is False or exp.is_nonnegative is False: + raise + + from sympy.matrices.expressions import MatPow + return MatPow(a, exp) + + @call_highest_priority('__add__') + def __radd__(self, other): + return self + other + + @call_highest_priority('__matmul__') + def __rmatmul__(self, other): + other = _matrixify(other) + if not getattr(other, 'is_Matrix', False) and not getattr(other, 'is_MatrixLike', False): + return NotImplemented + + return self.__rmul__(other) + + @call_highest_priority('__mul__') + def __rmul__(self, other): + return self.rmultiply(other) + + def rmultiply(self, other, dotprodsimp=None): + """Same as __rmul__() but with optional simplification. + + Parameters + ========== + + dotprodsimp : bool, optional + Specifies whether intermediate term algebraic simplification is used + during matrix multiplications to control expression blowup and thus + speed up calculation. Default is off. + """ + isimpbool = _get_intermediate_simp_bool(False, dotprodsimp) + other = _matrixify(other) + # matrix-like objects can have shapes. This is + # our first sanity check. Double check other is not explicitly not a Matrix. + if (hasattr(other, 'shape') and len(other.shape) == 2 and + (getattr(other, 'is_Matrix', True) or + getattr(other, 'is_MatrixLike', True))): + if self.shape[0] != other.shape[1]: + raise ShapeError("Matrix size mismatch.") + + # honest SymPy matrices defer to their class's routine + if getattr(other, 'is_Matrix', False): + m = self._eval_matrix_rmul(other) + if isimpbool: + return m._new(m.rows, m.cols, [_dotprodsimp(e) for e in m]) + return m + # Matrix-like objects can be passed to CommonMatrix routines directly. + if getattr(other, 'is_MatrixLike', False): + return MatrixArithmetic._eval_matrix_rmul(self, other) + + # if 'other' is not iterable then scalar multiplication. + if not isinstance(other, Iterable): + try: + return self._eval_scalar_rmul(other) + except TypeError: + pass + + return NotImplemented + + @call_highest_priority('__sub__') + def __rsub__(self, a): + return (-self) + a + + @call_highest_priority('__rsub__') + def __sub__(self, a): + return self + (-a) + + +class MatrixCommon(MatrixArithmetic, MatrixOperations, MatrixProperties, + MatrixSpecial, MatrixShaping): + """All common matrix operations including basic arithmetic, shaping, + and special matrices like `zeros`, and `eye`.""" + _diff_wrt: bool = True + + +class _MinimalMatrix: + """Class providing the minimum functionality + for a matrix-like object and implementing every method + required for a `MatrixRequired`. This class does not have everything + needed to become a full-fledged SymPy object, but it will satisfy the + requirements of anything inheriting from `MatrixRequired`. If you wish + to make a specialized matrix type, make sure to implement these + methods and properties with the exception of `__init__` and `__repr__` + which are included for convenience.""" + + is_MatrixLike = True + _sympify = staticmethod(sympify) + _class_priority = 3 + zero = S.Zero + one = S.One + + is_Matrix = True + is_MatrixExpr = False + + @classmethod + def _new(cls, *args, **kwargs): + return cls(*args, **kwargs) + + def __init__(self, rows, cols=None, mat=None, copy=False): + if isfunction(mat): + # if we passed in a function, use that to populate the indices + mat = [mat(i, j) for i in range(rows) for j in range(cols)] + if cols is None and mat is None: + mat = rows + rows, cols = getattr(mat, 'shape', (rows, cols)) + try: + # if we passed in a list of lists, flatten it and set the size + if cols is None and mat is None: + mat = rows + cols = len(mat[0]) + rows = len(mat) + mat = [x for l in mat for x in l] + except (IndexError, TypeError): + pass + self.mat = tuple(self._sympify(x) for x in mat) + self.rows, self.cols = rows, cols + if self.rows is None or self.cols is None: + raise NotImplementedError("Cannot initialize matrix with given parameters") + + def __getitem__(self, key): + def _normalize_slices(row_slice, col_slice): + """Ensure that row_slice and col_slice do not have + `None` in their arguments. Any integers are converted + to slices of length 1""" + if not isinstance(row_slice, slice): + row_slice = slice(row_slice, row_slice + 1, None) + row_slice = slice(*row_slice.indices(self.rows)) + + if not isinstance(col_slice, slice): + col_slice = slice(col_slice, col_slice + 1, None) + col_slice = slice(*col_slice.indices(self.cols)) + + return (row_slice, col_slice) + + def _coord_to_index(i, j): + """Return the index in _mat corresponding + to the (i,j) position in the matrix. """ + return i * self.cols + j + + if isinstance(key, tuple): + i, j = key + if isinstance(i, slice) or isinstance(j, slice): + # if the coordinates are not slices, make them so + # and expand the slices so they don't contain `None` + i, j = _normalize_slices(i, j) + + rowsList, colsList = list(range(self.rows))[i], \ + list(range(self.cols))[j] + indices = (i * self.cols + j for i in rowsList for j in + colsList) + return self._new(len(rowsList), len(colsList), + [self.mat[i] for i in indices]) + + # if the key is a tuple of ints, change + # it to an array index + key = _coord_to_index(i, j) + return self.mat[key] + + def __eq__(self, other): + try: + classof(self, other) + except TypeError: + return False + return ( + self.shape == other.shape and list(self) == list(other)) + + def __len__(self): + return self.rows*self.cols + + def __repr__(self): + return "_MinimalMatrix({}, {}, {})".format(self.rows, self.cols, + self.mat) + + @property + def shape(self): + return (self.rows, self.cols) + + +class _CastableMatrix: # this is needed here ONLY FOR TESTS. + def as_mutable(self): + return self + + def as_immutable(self): + return self + + +class _MatrixWrapper: + """Wrapper class providing the minimum functionality for a matrix-like + object: .rows, .cols, .shape, indexability, and iterability. CommonMatrix + math operations should work on matrix-like objects. This one is intended for + matrix-like objects which use the same indexing format as SymPy with respect + to returning matrix elements instead of rows for non-tuple indexes. + """ + + is_Matrix = False # needs to be here because of __getattr__ + is_MatrixLike = True + + def __init__(self, mat, shape): + self.mat = mat + self.shape = shape + self.rows, self.cols = shape + + def __getitem__(self, key): + if isinstance(key, tuple): + return sympify(self.mat.__getitem__(key)) + + return sympify(self.mat.__getitem__((key // self.rows, key % self.cols))) + + def __iter__(self): # supports numpy.matrix and numpy.array + mat = self.mat + cols = self.cols + + return iter(sympify(mat[r, c]) for r in range(self.rows) for c in range(cols)) + + +def _matrixify(mat): + """If `mat` is a Matrix or is matrix-like, + return a Matrix or MatrixWrapper object. Otherwise + `mat` is passed through without modification.""" + + if getattr(mat, 'is_Matrix', False) or getattr(mat, 'is_MatrixLike', False): + return mat + + if not(getattr(mat, 'is_Matrix', True) or getattr(mat, 'is_MatrixLike', True)): + return mat + + shape = None + + if hasattr(mat, 'shape'): # numpy, scipy.sparse + if len(mat.shape) == 2: + shape = mat.shape + elif hasattr(mat, 'rows') and hasattr(mat, 'cols'): # mpmath + shape = (mat.rows, mat.cols) + + if shape: + return _MatrixWrapper(mat, shape) + + return mat + + +def a2idx(j, n=None): + """Return integer after making positive and validating against n.""" + if not isinstance(j, int): + jindex = getattr(j, '__index__', None) + if jindex is not None: + j = jindex() + else: + raise IndexError("Invalid index a[%r]" % (j,)) + if n is not None: + if j < 0: + j += n + if not (j >= 0 and j < n): + raise IndexError("Index out of range: a[%s]" % (j,)) + return int(j) + + +def classof(A, B): + """ + Get the type of the result when combining matrices of different types. + + Currently the strategy is that immutability is contagious. + + Examples + ======== + + >>> from sympy import Matrix, ImmutableMatrix + >>> from sympy.matrices.matrixbase import classof + >>> M = Matrix([[1, 2], [3, 4]]) # a Mutable Matrix + >>> IM = ImmutableMatrix([[1, 2], [3, 4]]) + >>> classof(M, IM) + + """ + priority_A = getattr(A, '_class_priority', None) + priority_B = getattr(B, '_class_priority', None) + if None not in (priority_A, priority_B): + if A._class_priority > B._class_priority: + return A.__class__ + else: + return B.__class__ + + try: + import numpy + except ImportError: + pass + else: + if isinstance(A, numpy.ndarray): + return B.__class__ + if isinstance(B, numpy.ndarray): + return A.__class__ + + raise TypeError("Incompatible classes %s, %s" % (A.__class__, B.__class__)) diff --git a/phivenv/Lib/site-packages/sympy/matrices/decompositions.py b/phivenv/Lib/site-packages/sympy/matrices/decompositions.py new file mode 100644 index 0000000000000000000000000000000000000000..a8dd466d84c957b870396a050fd25ec21e7113a3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/decompositions.py @@ -0,0 +1,1621 @@ +import copy + +from sympy.core import S +from sympy.core.function import expand_mul +from sympy.functions.elementary.miscellaneous import Min, sqrt +from sympy.functions.elementary.complexes import sign + +from .exceptions import NonSquareMatrixError, NonPositiveDefiniteMatrixError +from .utilities import _get_intermediate_simp, _iszero +from .determinant import _find_reasonable_pivot_naive + + +def _rank_decomposition(M, iszerofunc=_iszero, simplify=False): + r"""Returns a pair of matrices (`C`, `F`) with matching rank + such that `A = C F`. + + Parameters + ========== + + iszerofunc : Function, optional + A function used for detecting whether an element can + act as a pivot. ``lambda x: x.is_zero`` is used by default. + + simplify : Bool or Function, optional + A function used to simplify elements when looking for a + pivot. By default SymPy's ``simplify`` is used. + + Returns + ======= + + (C, F) : Matrices + `C` and `F` are full-rank matrices with rank as same as `A`, + whose product gives `A`. + + See Notes for additional mathematical details. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([ + ... [1, 3, 1, 4], + ... [2, 7, 3, 9], + ... [1, 5, 3, 1], + ... [1, 2, 0, 8] + ... ]) + >>> C, F = A.rank_decomposition() + >>> C + Matrix([ + [1, 3, 4], + [2, 7, 9], + [1, 5, 1], + [1, 2, 8]]) + >>> F + Matrix([ + [1, 0, -2, 0], + [0, 1, 1, 0], + [0, 0, 0, 1]]) + >>> C * F == A + True + + Notes + ===== + + Obtaining `F`, an RREF of `A`, is equivalent to creating a + product + + .. math:: + E_n E_{n-1} ... E_1 A = F + + where `E_n, E_{n-1}, \dots, E_1` are the elimination matrices or + permutation matrices equivalent to each row-reduction step. + + The inverse of the same product of elimination matrices gives + `C`: + + .. math:: + C = \left(E_n E_{n-1} \dots E_1\right)^{-1} + + It is not necessary, however, to actually compute the inverse: + the columns of `C` are those from the original matrix with the + same column indices as the indices of the pivot columns of `F`. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Rank_factorization + + .. [2] Piziak, R.; Odell, P. L. (1 June 1999). + "Full Rank Factorization of Matrices". + Mathematics Magazine. 72 (3): 193. doi:10.2307/2690882 + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.rref + """ + + F, pivot_cols = M.rref(simplify=simplify, iszerofunc=iszerofunc, + pivots=True) + rank = len(pivot_cols) + + C = M.extract(range(M.rows), pivot_cols) + F = F[:rank, :] + + return C, F + + +def _liupc(M): + """Liu's algorithm, for pre-determination of the Elimination Tree of + the given matrix, used in row-based symbolic Cholesky factorization. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> S = SparseMatrix([ + ... [1, 0, 3, 2], + ... [0, 0, 1, 0], + ... [4, 0, 0, 5], + ... [0, 6, 7, 0]]) + >>> S.liupc() + ([[0], [], [0], [1, 2]], [4, 3, 4, 4]) + + References + ========== + + .. [1] Symbolic Sparse Cholesky Factorization using Elimination Trees, + Jeroen Van Grondelle (1999) + https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.39.7582 + """ + # Algorithm 2.4, p 17 of reference + + # get the indices of the elements that are non-zero on or below diag + R = [[] for r in range(M.rows)] + + for r, c, _ in M.row_list(): + if c <= r: + R[r].append(c) + + inf = len(R) # nothing will be this large + parent = [inf]*M.rows + virtual = [inf]*M.rows + + for r in range(M.rows): + for c in R[r][:-1]: + while virtual[c] < r: + t = virtual[c] + virtual[c] = r + c = t + + if virtual[c] == inf: + parent[c] = virtual[c] = r + + return R, parent + +def _row_structure_symbolic_cholesky(M): + """Symbolic cholesky factorization, for pre-determination of the + non-zero structure of the Cholesky factororization. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> S = SparseMatrix([ + ... [1, 0, 3, 2], + ... [0, 0, 1, 0], + ... [4, 0, 0, 5], + ... [0, 6, 7, 0]]) + >>> S.row_structure_symbolic_cholesky() + [[0], [], [0], [1, 2]] + + References + ========== + + .. [1] Symbolic Sparse Cholesky Factorization using Elimination Trees, + Jeroen Van Grondelle (1999) + https://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.39.7582 + """ + + R, parent = M.liupc() + inf = len(R) # this acts as infinity + Lrow = copy.deepcopy(R) + + for k in range(M.rows): + for j in R[k]: + while j != inf and j != k: + Lrow[k].append(j) + j = parent[j] + + Lrow[k] = sorted(set(Lrow[k])) + + return Lrow + + +def _cholesky(M, hermitian=True): + """Returns the Cholesky-type decomposition L of a matrix A + such that L * L.H == A if hermitian flag is True, + or L * L.T == A if hermitian is False. + + A must be a Hermitian positive-definite matrix if hermitian is True, + or a symmetric matrix if it is False. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + >>> A.cholesky() + Matrix([ + [ 5, 0, 0], + [ 3, 3, 0], + [-1, 1, 3]]) + >>> A.cholesky() * A.cholesky().T + Matrix([ + [25, 15, -5], + [15, 18, 0], + [-5, 0, 11]]) + + The matrix can have complex entries: + + >>> from sympy import I + >>> A = Matrix(((9, 3*I), (-3*I, 5))) + >>> A.cholesky() + Matrix([ + [ 3, 0], + [-I, 2]]) + >>> A.cholesky() * A.cholesky().H + Matrix([ + [ 9, 3*I], + [-3*I, 5]]) + + Non-hermitian Cholesky-type decomposition may be useful when the + matrix is not positive-definite. + + >>> A = Matrix([[1, 2], [2, 1]]) + >>> L = A.cholesky(hermitian=False) + >>> L + Matrix([ + [1, 0], + [2, sqrt(3)*I]]) + >>> L*L.T == A + True + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.LDLdecomposition + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + QRdecomposition + """ + + from .dense import MutableDenseMatrix + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if hermitian and not M.is_hermitian: + raise ValueError("Matrix must be Hermitian.") + if not hermitian and not M.is_symmetric(): + raise ValueError("Matrix must be symmetric.") + + L = MutableDenseMatrix.zeros(M.rows, M.rows) + + if hermitian: + for i in range(M.rows): + for j in range(i): + L[i, j] = ((1 / L[j, j])*(M[i, j] - + sum(L[i, k]*L[j, k].conjugate() for k in range(j)))) + + Lii2 = (M[i, i] - + sum(L[i, k]*L[i, k].conjugate() for k in range(i))) + + if Lii2.is_positive is False: + raise NonPositiveDefiniteMatrixError( + "Matrix must be positive-definite") + + L[i, i] = sqrt(Lii2) + + else: + for i in range(M.rows): + for j in range(i): + L[i, j] = ((1 / L[j, j])*(M[i, j] - + sum(L[i, k]*L[j, k] for k in range(j)))) + + L[i, i] = sqrt(M[i, i] - + sum(L[i, k]**2 for k in range(i))) + + return M._new(L) + +def _cholesky_sparse(M, hermitian=True): + """ + Returns the Cholesky decomposition L of a matrix A + such that L * L.T = A + + A must be a square, symmetric, positive-definite + and non-singular matrix + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> A = SparseMatrix(((25,15,-5),(15,18,0),(-5,0,11))) + >>> A.cholesky() + Matrix([ + [ 5, 0, 0], + [ 3, 3, 0], + [-1, 1, 3]]) + >>> A.cholesky() * A.cholesky().T == A + True + + The matrix can have complex entries: + + >>> from sympy import I + >>> A = SparseMatrix(((9, 3*I), (-3*I, 5))) + >>> A.cholesky() + Matrix([ + [ 3, 0], + [-I, 2]]) + >>> A.cholesky() * A.cholesky().H + Matrix([ + [ 9, 3*I], + [-3*I, 5]]) + + Non-hermitian Cholesky-type decomposition may be useful when the + matrix is not positive-definite. + + >>> A = SparseMatrix([[1, 2], [2, 1]]) + >>> L = A.cholesky(hermitian=False) + >>> L + Matrix([ + [1, 0], + [2, sqrt(3)*I]]) + >>> L*L.T == A + True + + See Also + ======== + + sympy.matrices.sparse.SparseMatrix.LDLdecomposition + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + QRdecomposition + """ + + from .dense import MutableDenseMatrix + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if hermitian and not M.is_hermitian: + raise ValueError("Matrix must be Hermitian.") + if not hermitian and not M.is_symmetric(): + raise ValueError("Matrix must be symmetric.") + + dps = _get_intermediate_simp(expand_mul, expand_mul) + Crowstruc = M.row_structure_symbolic_cholesky() + C = MutableDenseMatrix.zeros(M.rows) + + for i in range(len(Crowstruc)): + for j in Crowstruc[i]: + if i != j: + C[i, j] = M[i, j] + summ = 0 + + for p1 in Crowstruc[i]: + if p1 < j: + for p2 in Crowstruc[j]: + if p2 < j: + if p1 == p2: + if hermitian: + summ += C[i, p1]*C[j, p1].conjugate() + else: + summ += C[i, p1]*C[j, p1] + else: + break + else: + break + + C[i, j] = dps((C[i, j] - summ) / C[j, j]) + + else: # i == j + C[j, j] = M[j, j] + summ = 0 + + for k in Crowstruc[j]: + if k < j: + if hermitian: + summ += C[j, k]*C[j, k].conjugate() + else: + summ += C[j, k]**2 + else: + break + + Cjj2 = dps(C[j, j] - summ) + + if hermitian and Cjj2.is_positive is False: + raise NonPositiveDefiniteMatrixError( + "Matrix must be positive-definite") + + C[j, j] = sqrt(Cjj2) + + return M._new(C) + + +def _LDLdecomposition(M, hermitian=True): + """Returns the LDL Decomposition (L, D) of matrix A, + such that L * D * L.H == A if hermitian flag is True, or + L * D * L.T == A if hermitian is False. + This method eliminates the use of square root. + Further this ensures that all the diagonal entries of L are 1. + A must be a Hermitian positive-definite matrix if hermitian is True, + or a symmetric matrix otherwise. + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> A = Matrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + >>> L, D = A.LDLdecomposition() + >>> L + Matrix([ + [ 1, 0, 0], + [ 3/5, 1, 0], + [-1/5, 1/3, 1]]) + >>> D + Matrix([ + [25, 0, 0], + [ 0, 9, 0], + [ 0, 0, 9]]) + >>> L * D * L.T * A.inv() == eye(A.rows) + True + + The matrix can have complex entries: + + >>> from sympy import I + >>> A = Matrix(((9, 3*I), (-3*I, 5))) + >>> L, D = A.LDLdecomposition() + >>> L + Matrix([ + [ 1, 0], + [-I/3, 1]]) + >>> D + Matrix([ + [9, 0], + [0, 4]]) + >>> L*D*L.H == A + True + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.cholesky + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + QRdecomposition + """ + + from .dense import MutableDenseMatrix + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if hermitian and not M.is_hermitian: + raise ValueError("Matrix must be Hermitian.") + if not hermitian and not M.is_symmetric(): + raise ValueError("Matrix must be symmetric.") + + D = MutableDenseMatrix.zeros(M.rows, M.rows) + L = MutableDenseMatrix.eye(M.rows) + + if hermitian: + for i in range(M.rows): + for j in range(i): + L[i, j] = (1 / D[j, j])*(M[i, j] - sum( + L[i, k]*L[j, k].conjugate()*D[k, k] for k in range(j))) + + D[i, i] = (M[i, i] - + sum(L[i, k]*L[i, k].conjugate()*D[k, k] for k in range(i))) + + if D[i, i].is_positive is False: + raise NonPositiveDefiniteMatrixError( + "Matrix must be positive-definite") + + else: + for i in range(M.rows): + for j in range(i): + L[i, j] = (1 / D[j, j])*(M[i, j] - sum( + L[i, k]*L[j, k]*D[k, k] for k in range(j))) + + D[i, i] = M[i, i] - sum(L[i, k]**2*D[k, k] for k in range(i)) + + return M._new(L), M._new(D) + +def _LDLdecomposition_sparse(M, hermitian=True): + """ + Returns the LDL Decomposition (matrices ``L`` and ``D``) of matrix + ``A``, such that ``L * D * L.T == A``. ``A`` must be a square, + symmetric, positive-definite and non-singular. + + This method eliminates the use of square root and ensures that all + the diagonal entries of L are 1. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> A = SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + >>> L, D = A.LDLdecomposition() + >>> L + Matrix([ + [ 1, 0, 0], + [ 3/5, 1, 0], + [-1/5, 1/3, 1]]) + >>> D + Matrix([ + [25, 0, 0], + [ 0, 9, 0], + [ 0, 0, 9]]) + >>> L * D * L.T == A + True + + """ + + from .dense import MutableDenseMatrix + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if hermitian and not M.is_hermitian: + raise ValueError("Matrix must be Hermitian.") + if not hermitian and not M.is_symmetric(): + raise ValueError("Matrix must be symmetric.") + + dps = _get_intermediate_simp(expand_mul, expand_mul) + Lrowstruc = M.row_structure_symbolic_cholesky() + L = MutableDenseMatrix.eye(M.rows) + D = MutableDenseMatrix.zeros(M.rows, M.cols) + + for i in range(len(Lrowstruc)): + for j in Lrowstruc[i]: + if i != j: + L[i, j] = M[i, j] + summ = 0 + + for p1 in Lrowstruc[i]: + if p1 < j: + for p2 in Lrowstruc[j]: + if p2 < j: + if p1 == p2: + if hermitian: + summ += L[i, p1]*L[j, p1].conjugate()*D[p1, p1] + else: + summ += L[i, p1]*L[j, p1]*D[p1, p1] + else: + break + else: + break + + L[i, j] = dps((L[i, j] - summ) / D[j, j]) + + else: # i == j + D[i, i] = M[i, i] + summ = 0 + + for k in Lrowstruc[i]: + if k < i: + if hermitian: + summ += L[i, k]*L[i, k].conjugate()*D[k, k] + else: + summ += L[i, k]**2*D[k, k] + else: + break + + D[i, i] = dps(D[i, i] - summ) + + if hermitian and D[i, i].is_positive is False: + raise NonPositiveDefiniteMatrixError( + "Matrix must be positive-definite") + + return M._new(L), M._new(D) + + +def _LUdecomposition(M, iszerofunc=_iszero, simpfunc=None, rankcheck=False): + """Returns (L, U, perm) where L is a lower triangular matrix with unit + diagonal, U is an upper triangular matrix, and perm is a list of row + swap index pairs. If A is the original matrix, then + ``A = (L*U).permuteBkwd(perm)``, and the row permutation matrix P such + that $P A = L U$ can be computed by ``P = eye(A.rows).permuteFwd(perm)``. + + See documentation for LUCombined for details about the keyword argument + rankcheck, iszerofunc, and simpfunc. + + Parameters + ========== + + rankcheck : bool, optional + Determines if this function should detect the rank + deficiency of the matrixis and should raise a + ``ValueError``. + + iszerofunc : function, optional + A function which determines if a given expression is zero. + + The function should be a callable that takes a single + SymPy expression and returns a 3-valued boolean value + ``True``, ``False``, or ``None``. + + It is internally used by the pivot searching algorithm. + See the notes section for a more information about the + pivot searching algorithm. + + simpfunc : function or None, optional + A function that simplifies the input. + + If this is specified as a function, this function should be + a callable that takes a single SymPy expression and returns + an another SymPy expression that is algebraically + equivalent. + + If ``None``, it indicates that the pivot search algorithm + should not attempt to simplify any candidate pivots. + + It is internally used by the pivot searching algorithm. + See the notes section for a more information about the + pivot searching algorithm. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[4, 3], [6, 3]]) + >>> L, U, _ = a.LUdecomposition() + >>> L + Matrix([ + [ 1, 0], + [3/2, 1]]) + >>> U + Matrix([ + [4, 3], + [0, -3/2]]) + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.cholesky + sympy.matrices.dense.DenseMatrix.LDLdecomposition + QRdecomposition + LUdecomposition_Simple + LUdecompositionFF + LUsolve + """ + + combined, p = M.LUdecomposition_Simple(iszerofunc=iszerofunc, + simpfunc=simpfunc, rankcheck=rankcheck) + + # L is lower triangular ``M.rows x M.rows`` + # U is upper triangular ``M.rows x M.cols`` + # L has unit diagonal. For each column in combined, the subcolumn + # below the diagonal of combined is shared by L. + # If L has more columns than combined, then the remaining subcolumns + # below the diagonal of L are zero. + # The upper triangular portion of L and combined are equal. + def entry_L(i, j): + if i < j: + # Super diagonal entry + return M.zero + elif i == j: + return M.one + elif j < combined.cols: + return combined[i, j] + + # Subdiagonal entry of L with no corresponding + # entry in combined + return M.zero + + def entry_U(i, j): + return M.zero if i > j else combined[i, j] + + L = M._new(combined.rows, combined.rows, entry_L) + U = M._new(combined.rows, combined.cols, entry_U) + + return L, U, p + +def _LUdecomposition_Simple(M, iszerofunc=_iszero, simpfunc=None, + rankcheck=False): + r"""Compute the PLU decomposition of the matrix. + + Parameters + ========== + + rankcheck : bool, optional + Determines if this function should detect the rank + deficiency of the matrixis and should raise a + ``ValueError``. + + iszerofunc : function, optional + A function which determines if a given expression is zero. + + The function should be a callable that takes a single + SymPy expression and returns a 3-valued boolean value + ``True``, ``False``, or ``None``. + + It is internally used by the pivot searching algorithm. + See the notes section for a more information about the + pivot searching algorithm. + + simpfunc : function or None, optional + A function that simplifies the input. + + If this is specified as a function, this function should be + a callable that takes a single SymPy expression and returns + an another SymPy expression that is algebraically + equivalent. + + If ``None``, it indicates that the pivot search algorithm + should not attempt to simplify any candidate pivots. + + It is internally used by the pivot searching algorithm. + See the notes section for a more information about the + pivot searching algorithm. + + Returns + ======= + + (lu, row_swaps) : (Matrix, list) + If the original matrix is a $m, n$ matrix: + + *lu* is a $m, n$ matrix, which contains result of the + decomposition in a compressed form. See the notes section + to see how the matrix is compressed. + + *row_swaps* is a $m$-element list where each element is a + pair of row exchange indices. + + ``A = (L*U).permute_backward(perm)``, and the row + permutation matrix $P$ from the formula $P A = L U$ can be + computed by ``P=eye(A.row).permute_forward(perm)``. + + Raises + ====== + + ValueError + Raised if ``rankcheck=True`` and the matrix is found to + be rank deficient during the computation. + + Notes + ===== + + About the PLU decomposition: + + PLU decomposition is a generalization of a LU decomposition + which can be extended for rank-deficient matrices. + + It can further be generalized for non-square matrices, and this + is the notation that SymPy is using. + + PLU decomposition is a decomposition of a $m, n$ matrix $A$ in + the form of $P A = L U$ where + + * $L$ is a $m, m$ lower triangular matrix with unit diagonal + entries. + * $U$ is a $m, n$ upper triangular matrix. + * $P$ is a $m, m$ permutation matrix. + + So, for a square matrix, the decomposition would look like: + + .. math:: + L = \begin{bmatrix} + 1 & 0 & 0 & \cdots & 0 \\ + L_{1, 0} & 1 & 0 & \cdots & 0 \\ + L_{2, 0} & L_{2, 1} & 1 & \cdots & 0 \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + L_{n-1, 0} & L_{n-1, 1} & L_{n-1, 2} & \cdots & 1 + \end{bmatrix} + + .. math:: + U = \begin{bmatrix} + U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, n-1} \\ + 0 & U_{1, 1} & U_{1, 2} & \cdots & U_{1, n-1} \\ + 0 & 0 & U_{2, 2} & \cdots & U_{2, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + 0 & 0 & 0 & \cdots & U_{n-1, n-1} + \end{bmatrix} + + And for a matrix with more rows than the columns, + the decomposition would look like: + + .. math:: + L = \begin{bmatrix} + 1 & 0 & 0 & \cdots & 0 & 0 & \cdots & 0 \\ + L_{1, 0} & 1 & 0 & \cdots & 0 & 0 & \cdots & 0 \\ + L_{2, 0} & L_{2, 1} & 1 & \cdots & 0 & 0 & \cdots & 0 \\ + \vdots & \vdots & \vdots & \ddots & \vdots & \vdots & \ddots + & \vdots \\ + L_{n-1, 0} & L_{n-1, 1} & L_{n-1, 2} & \cdots & 1 & 0 + & \cdots & 0 \\ + L_{n, 0} & L_{n, 1} & L_{n, 2} & \cdots & L_{n, n-1} & 1 + & \cdots & 0 \\ + \vdots & \vdots & \vdots & \ddots & \vdots & \vdots + & \ddots & \vdots \\ + L_{m-1, 0} & L_{m-1, 1} & L_{m-1, 2} & \cdots & L_{m-1, n-1} + & 0 & \cdots & 1 \\ + \end{bmatrix} + + .. math:: + U = \begin{bmatrix} + U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, n-1} \\ + 0 & U_{1, 1} & U_{1, 2} & \cdots & U_{1, n-1} \\ + 0 & 0 & U_{2, 2} & \cdots & U_{2, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + 0 & 0 & 0 & \cdots & U_{n-1, n-1} \\ + 0 & 0 & 0 & \cdots & 0 \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + 0 & 0 & 0 & \cdots & 0 + \end{bmatrix} + + Finally, for a matrix with more columns than the rows, the + decomposition would look like: + + .. math:: + L = \begin{bmatrix} + 1 & 0 & 0 & \cdots & 0 \\ + L_{1, 0} & 1 & 0 & \cdots & 0 \\ + L_{2, 0} & L_{2, 1} & 1 & \cdots & 0 \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + L_{m-1, 0} & L_{m-1, 1} & L_{m-1, 2} & \cdots & 1 + \end{bmatrix} + + .. math:: + U = \begin{bmatrix} + U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, m-1} + & \cdots & U_{0, n-1} \\ + 0 & U_{1, 1} & U_{1, 2} & \cdots & U_{1, m-1} + & \cdots & U_{1, n-1} \\ + 0 & 0 & U_{2, 2} & \cdots & U_{2, m-1} + & \cdots & U_{2, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots + & \cdots & \vdots \\ + 0 & 0 & 0 & \cdots & U_{m-1, m-1} + & \cdots & U_{m-1, n-1} \\ + \end{bmatrix} + + About the compressed LU storage: + + The results of the decomposition are often stored in compressed + forms rather than returning $L$ and $U$ matrices individually. + + It may be less intiuitive, but it is commonly used for a lot of + numeric libraries because of the efficiency. + + The storage matrix is defined as following for this specific + method: + + * The subdiagonal elements of $L$ are stored in the subdiagonal + portion of $LU$, that is $LU_{i, j} = L_{i, j}$ whenever + $i > j$. + * The elements on the diagonal of $L$ are all 1, and are not + explicitly stored. + * $U$ is stored in the upper triangular portion of $LU$, that is + $LU_{i, j} = U_{i, j}$ whenever $i <= j$. + * For a case of $m > n$, the right side of the $L$ matrix is + trivial to store. + * For a case of $m < n$, the below side of the $U$ matrix is + trivial to store. + + So, for a square matrix, the compressed output matrix would be: + + .. math:: + LU = \begin{bmatrix} + U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, n-1} \\ + L_{1, 0} & U_{1, 1} & U_{1, 2} & \cdots & U_{1, n-1} \\ + L_{2, 0} & L_{2, 1} & U_{2, 2} & \cdots & U_{2, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + L_{n-1, 0} & L_{n-1, 1} & L_{n-1, 2} & \cdots & U_{n-1, n-1} + \end{bmatrix} + + For a matrix with more rows than the columns, the compressed + output matrix would be: + + .. math:: + LU = \begin{bmatrix} + U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, n-1} \\ + L_{1, 0} & U_{1, 1} & U_{1, 2} & \cdots & U_{1, n-1} \\ + L_{2, 0} & L_{2, 1} & U_{2, 2} & \cdots & U_{2, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + L_{n-1, 0} & L_{n-1, 1} & L_{n-1, 2} & \cdots + & U_{n-1, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots \\ + L_{m-1, 0} & L_{m-1, 1} & L_{m-1, 2} & \cdots + & L_{m-1, n-1} \\ + \end{bmatrix} + + For a matrix with more columns than the rows, the compressed + output matrix would be: + + .. math:: + LU = \begin{bmatrix} + U_{0, 0} & U_{0, 1} & U_{0, 2} & \cdots & U_{0, m-1} + & \cdots & U_{0, n-1} \\ + L_{1, 0} & U_{1, 1} & U_{1, 2} & \cdots & U_{1, m-1} + & \cdots & U_{1, n-1} \\ + L_{2, 0} & L_{2, 1} & U_{2, 2} & \cdots & U_{2, m-1} + & \cdots & U_{2, n-1} \\ + \vdots & \vdots & \vdots & \ddots & \vdots + & \cdots & \vdots \\ + L_{m-1, 0} & L_{m-1, 1} & L_{m-1, 2} & \cdots & U_{m-1, m-1} + & \cdots & U_{m-1, n-1} \\ + \end{bmatrix} + + About the pivot searching algorithm: + + When a matrix contains symbolic entries, the pivot search algorithm + differs from the case where every entry can be categorized as zero or + nonzero. + The algorithm searches column by column through the submatrix whose + top left entry coincides with the pivot position. + If it exists, the pivot is the first entry in the current search + column that iszerofunc guarantees is nonzero. + If no such candidate exists, then each candidate pivot is simplified + if simpfunc is not None. + The search is repeated, with the difference that a candidate may be + the pivot if ``iszerofunc()`` cannot guarantee that it is nonzero. + In the second search the pivot is the first candidate that + iszerofunc can guarantee is nonzero. + If no such candidate exists, then the pivot is the first candidate + for which iszerofunc returns None. + If no such candidate exists, then the search is repeated in the next + column to the right. + The pivot search algorithm differs from the one in ``rref()``, which + relies on ``_find_reasonable_pivot()``. + Future versions of ``LUdecomposition_simple()`` may use + ``_find_reasonable_pivot()``. + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + LUdecompositionFF + LUsolve + """ + + if rankcheck: + # https://github.com/sympy/sympy/issues/9796 + pass + + if S.Zero in M.shape: + # Define LU decomposition of a matrix with no entries as a matrix + # of the same dimensions with all zero entries. + return M.zeros(M.rows, M.cols), [] + + dps = _get_intermediate_simp() + lu = M.as_mutable() + row_swaps = [] + + pivot_col = 0 + + for pivot_row in range(0, lu.rows - 1): + # Search for pivot. Prefer entry that iszeropivot determines + # is nonzero, over entry that iszeropivot cannot guarantee + # is zero. + # XXX ``_find_reasonable_pivot`` uses slow zero testing. Blocked by bug #10279 + # Future versions of LUdecomposition_simple can pass iszerofunc and simpfunc + # to _find_reasonable_pivot(). + # In pass 3 of _find_reasonable_pivot(), the predicate in ``if x.equals(S.Zero):`` + # calls sympy.simplify(), and not the simplification function passed in via + # the keyword argument simpfunc. + iszeropivot = True + + while pivot_col != M.cols and iszeropivot: + sub_col = (lu[r, pivot_col] for r in range(pivot_row, M.rows)) + + pivot_row_offset, pivot_value, is_assumed_non_zero, ind_simplified_pairs =\ + _find_reasonable_pivot_naive(sub_col, iszerofunc, simpfunc) + + iszeropivot = pivot_value is None + + if iszeropivot: + # All candidate pivots in this column are zero. + # Proceed to next column. + pivot_col += 1 + + if rankcheck and pivot_col != pivot_row: + # All entries including and below the pivot position are + # zero, which indicates that the rank of the matrix is + # strictly less than min(num rows, num cols) + # Mimic behavior of previous implementation, by throwing a + # ValueError. + raise ValueError("Rank of matrix is strictly less than" + " number of rows or columns." + " Pass keyword argument" + " rankcheck=False to compute" + " the LU decomposition of this matrix.") + + candidate_pivot_row = None if pivot_row_offset is None else pivot_row + pivot_row_offset + + if candidate_pivot_row is None and iszeropivot: + # If candidate_pivot_row is None and iszeropivot is True + # after pivot search has completed, then the submatrix + # below and to the right of (pivot_row, pivot_col) is + # all zeros, indicating that Gaussian elimination is + # complete. + return lu, row_swaps + + # Update entries simplified during pivot search. + for offset, val in ind_simplified_pairs: + lu[pivot_row + offset, pivot_col] = val + + if pivot_row != candidate_pivot_row: + # Row swap book keeping: + # Record which rows were swapped. + # Update stored portion of L factor by multiplying L on the + # left and right with the current permutation. + # Swap rows of U. + row_swaps.append([pivot_row, candidate_pivot_row]) + + # Update L. + lu[pivot_row, 0:pivot_row], lu[candidate_pivot_row, 0:pivot_row] = \ + lu[candidate_pivot_row, 0:pivot_row], lu[pivot_row, 0:pivot_row] + + # Swap pivot row of U with candidate pivot row. + lu[pivot_row, pivot_col:lu.cols], lu[candidate_pivot_row, pivot_col:lu.cols] = \ + lu[candidate_pivot_row, pivot_col:lu.cols], lu[pivot_row, pivot_col:lu.cols] + + # Introduce zeros below the pivot by adding a multiple of the + # pivot row to a row under it, and store the result in the + # row under it. + # Only entries in the target row whose index is greater than + # start_col may be nonzero. + start_col = pivot_col + 1 + + for row in range(pivot_row + 1, lu.rows): + # Store factors of L in the subcolumn below + # (pivot_row, pivot_row). + lu[row, pivot_row] = \ + dps(lu[row, pivot_col]/lu[pivot_row, pivot_col]) + + # Form the linear combination of the pivot row and the current + # row below the pivot row that zeros the entries below the pivot. + # Employing slicing instead of a loop here raises + # NotImplementedError: Cannot add Zero to MutableSparseMatrix + # in sympy/matrices/tests/test_sparse.py. + # c = pivot_row + 1 if pivot_row == pivot_col else pivot_col + for c in range(start_col, lu.cols): + lu[row, c] = dps(lu[row, c] - lu[row, pivot_row]*lu[pivot_row, c]) + + if pivot_row != pivot_col: + # matrix rank < min(num rows, num cols), + # so factors of L are not stored directly below the pivot. + # These entries are zero by construction, so don't bother + # computing them. + for row in range(pivot_row + 1, lu.rows): + lu[row, pivot_col] = M.zero + + pivot_col += 1 + + if pivot_col == lu.cols: + # All candidate pivots are zero implies that Gaussian + # elimination is complete. + return lu, row_swaps + + if rankcheck: + if iszerofunc( + lu[Min(lu.rows, lu.cols) - 1, Min(lu.rows, lu.cols) - 1]): + raise ValueError("Rank of matrix is strictly less than" + " number of rows or columns." + " Pass keyword argument" + " rankcheck=False to compute" + " the LU decomposition of this matrix.") + + return lu, row_swaps + +def _LUdecompositionFF(M): + """Compute a fraction-free LU decomposition. + + Returns 4 matrices P, L, D, U such that PA = L D**-1 U. + If the elements of the matrix belong to some integral domain I, then all + elements of L, D and U are guaranteed to belong to I. + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + LUdecomposition_Simple + LUsolve + + References + ========== + + .. [1] W. Zhou & D.J. Jeffrey, "Fraction-free matrix factors: new forms + for LU and QR factors". Frontiers in Computer Science in China, + Vol 2, no. 1, pp. 67-80, 2008. + """ + + from sympy.matrices import SparseMatrix + + zeros = SparseMatrix.zeros + eye = SparseMatrix.eye + n, m = M.rows, M.cols + U, L, P = M.as_mutable(), eye(n), eye(n) + DD = zeros(n, n) + oldpivot = 1 + + for k in range(n - 1): + if U[k, k] == 0: + for kpivot in range(k + 1, n): + if U[kpivot, k]: + break + else: + raise ValueError("Matrix is not full rank") + + U[k, k:], U[kpivot, k:] = U[kpivot, k:], U[k, k:] + L[k, :k], L[kpivot, :k] = L[kpivot, :k], L[k, :k] + P[k, :], P[kpivot, :] = P[kpivot, :], P[k, :] + + L [k, k] = Ukk = U[k, k] + DD[k, k] = oldpivot * Ukk + + for i in range(k + 1, n): + L[i, k] = Uik = U[i, k] + + for j in range(k + 1, m): + U[i, j] = (Ukk * U[i, j] - U[k, j] * Uik) / oldpivot + + U[i, k] = 0 + + oldpivot = Ukk + + DD[n - 1, n - 1] = oldpivot + + return P, L, DD, U + +def _singular_value_decomposition(A): + r"""Returns a Condensed Singular Value decomposition. + + Explanation + =========== + + A Singular Value decomposition is a decomposition in the form $A = U \Sigma V^H$ + where + + - $U, V$ are column orthogonal matrix. + - $\Sigma$ is a diagonal matrix, where the main diagonal contains singular + values of matrix A. + + A column orthogonal matrix satisfies + $\mathbb{I} = U^H U$ while a full orthogonal matrix satisfies + relation $\mathbb{I} = U U^H = U^H U$ where $\mathbb{I}$ is an identity + matrix with matching dimensions. + + For matrices which are not square or are rank-deficient, it is + sufficient to return a column orthogonal matrix because augmenting + them may introduce redundant computations. + In condensed Singular Value Decomposition we only return column orthogonal + matrices because of this reason + + If you want to augment the results to return a full orthogonal + decomposition, you should use the following procedures. + + - Augment the $U, V$ matrices with columns that are orthogonal to every + other columns and make it square. + - Augment the $\Sigma$ matrix with zero rows to make it have the same + shape as the original matrix. + + The procedure will be illustrated in the examples section. + + Examples + ======== + + we take a full rank matrix first: + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2],[2,1]]) + >>> U, S, V = A.singular_value_decomposition() + >>> U + Matrix([ + [ sqrt(2)/2, sqrt(2)/2], + [-sqrt(2)/2, sqrt(2)/2]]) + >>> S + Matrix([ + [1, 0], + [0, 3]]) + >>> V + Matrix([ + [-sqrt(2)/2, sqrt(2)/2], + [ sqrt(2)/2, sqrt(2)/2]]) + + If a matrix if square and full rank both U, V + are orthogonal in both directions + + >>> U * U.H + Matrix([ + [1, 0], + [0, 1]]) + >>> U.H * U + Matrix([ + [1, 0], + [0, 1]]) + + >>> V * V.H + Matrix([ + [1, 0], + [0, 1]]) + >>> V.H * V + Matrix([ + [1, 0], + [0, 1]]) + >>> A == U * S * V.H + True + + >>> C = Matrix([ + ... [1, 0, 0, 0, 2], + ... [0, 0, 3, 0, 0], + ... [0, 0, 0, 0, 0], + ... [0, 2, 0, 0, 0], + ... ]) + >>> U, S, V = C.singular_value_decomposition() + + >>> V.H * V + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> V * V.H + Matrix([ + [1/5, 0, 0, 0, 2/5], + [ 0, 1, 0, 0, 0], + [ 0, 0, 1, 0, 0], + [ 0, 0, 0, 0, 0], + [2/5, 0, 0, 0, 4/5]]) + + If you want to augment the results to be a full orthogonal + decomposition, you should augment $V$ with an another orthogonal + column. + + You are able to append an arbitrary standard basis that are linearly + independent to every other columns and you can run the Gram-Schmidt + process to make them augmented as orthogonal basis. + + >>> V_aug = V.row_join(Matrix([[0,0,0,0,1], + ... [0,0,0,1,0]]).H) + >>> V_aug = V_aug.QRdecomposition()[0] + >>> V_aug + Matrix([ + [0, sqrt(5)/5, 0, -2*sqrt(5)/5, 0], + [1, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 0, 1], + [0, 2*sqrt(5)/5, 0, sqrt(5)/5, 0]]) + >>> V_aug.H * V_aug + Matrix([ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1]]) + >>> V_aug * V_aug.H + Matrix([ + [1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1]]) + + Similarly we augment U + + >>> U_aug = U.row_join(Matrix([0,0,1,0])) + >>> U_aug = U_aug.QRdecomposition()[0] + >>> U_aug + Matrix([ + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1], + [1, 0, 0, 0]]) + + >>> U_aug.H * U_aug + Matrix([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + >>> U_aug * U_aug.H + Matrix([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + + We add 2 zero columns and one row to S + + >>> S_aug = S.col_join(Matrix([[0,0,0]])) + >>> S_aug = S_aug.row_join(Matrix([[0,0,0,0], + ... [0,0,0,0]]).H) + >>> S_aug + Matrix([ + [2, 0, 0, 0, 0], + [0, sqrt(5), 0, 0, 0], + [0, 0, 3, 0, 0], + [0, 0, 0, 0, 0]]) + + + + >>> U_aug * S_aug * V_aug.H == C + True + + """ + + AH = A.H + m, n = A.shape + if m >= n: + V, S = (AH * A).diagonalize() + + ranked = [] + for i, x in enumerate(S.diagonal()): + if not x.is_zero: + ranked.append(i) + + V = V[:, ranked] + + Singular_vals = [sqrt(S[i, i]) for i in range(S.rows) if i in ranked] + + S = S.diag(*Singular_vals) + V, _ = V.QRdecomposition() + U = A * V * S.inv() + else: + U, S = (A * AH).diagonalize() + + ranked = [] + for i, x in enumerate(S.diagonal()): + if not x.is_zero: + ranked.append(i) + + U = U[:, ranked] + Singular_vals = [sqrt(S[i, i]) for i in range(S.rows) if i in ranked] + + S = S.diag(*Singular_vals) + U, _ = U.QRdecomposition() + V = AH * U * S.inv() + + return U, S, V + +def _QRdecomposition_optional(M, normalize=True): + def dot(u, v): + return u.dot(v, hermitian=True) + + dps = _get_intermediate_simp(expand_mul, expand_mul) + + A = M.as_mutable() + ranked = [] + + Q = A + R = A.zeros(A.cols) + + for j in range(A.cols): + for i in range(j): + if Q[:, i].is_zero_matrix: + continue + + R[i, j] = dot(Q[:, i], Q[:, j]) / dot(Q[:, i], Q[:, i]) + R[i, j] = dps(R[i, j]) + Q[:, j] -= Q[:, i] * R[i, j] + + Q[:, j] = dps(Q[:, j]) + if Q[:, j].is_zero_matrix is not True: + ranked.append(j) + R[j, j] = M.one + + Q = Q.extract(range(Q.rows), ranked) + R = R.extract(ranked, range(R.cols)) + + if normalize: + # Normalization + for i in range(Q.cols): + norm = Q[:, i].norm() + Q[:, i] /= norm + R[i, :] *= norm + + return M.__class__(Q), M.__class__(R) + + +def _QRdecomposition(M): + r"""Returns a QR decomposition. + + Explanation + =========== + + A QR decomposition is a decomposition in the form $A = Q R$ + where + + - $Q$ is a column orthogonal matrix. + - $R$ is a upper triangular (trapezoidal) matrix. + + A column orthogonal matrix satisfies + $\mathbb{I} = Q^H Q$ while a full orthogonal matrix satisfies + relation $\mathbb{I} = Q Q^H = Q^H Q$ where $I$ is an identity + matrix with matching dimensions. + + For matrices which are not square or are rank-deficient, it is + sufficient to return a column orthogonal matrix because augmenting + them may introduce redundant computations. + And an another advantage of this is that you can easily inspect the + matrix rank by counting the number of columns of $Q$. + + If you want to augment the results to return a full orthogonal + decomposition, you should use the following procedures. + + - Augment the $Q$ matrix with columns that are orthogonal to every + other columns and make it square. + - Augment the $R$ matrix with zero rows to make it have the same + shape as the original matrix. + + The procedure will be illustrated in the examples section. + + Examples + ======== + + A full rank matrix example: + + >>> from sympy import Matrix + >>> A = Matrix([[12, -51, 4], [6, 167, -68], [-4, 24, -41]]) + >>> Q, R = A.QRdecomposition() + >>> Q + Matrix([ + [ 6/7, -69/175, -58/175], + [ 3/7, 158/175, 6/175], + [-2/7, 6/35, -33/35]]) + >>> R + Matrix([ + [14, 21, -14], + [ 0, 175, -70], + [ 0, 0, 35]]) + + If the matrix is square and full rank, the $Q$ matrix becomes + orthogonal in both directions, and needs no augmentation. + + >>> Q * Q.H + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> Q.H * Q + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + >>> A == Q*R + True + + A rank deficient matrix example: + + >>> A = Matrix([[12, -51, 0], [6, 167, 0], [-4, 24, 0]]) + >>> Q, R = A.QRdecomposition() + >>> Q + Matrix([ + [ 6/7, -69/175], + [ 3/7, 158/175], + [-2/7, 6/35]]) + >>> R + Matrix([ + [14, 21, 0], + [ 0, 175, 0]]) + + QRdecomposition might return a matrix Q that is rectangular. + In this case the orthogonality condition might be satisfied as + $\mathbb{I} = Q.H*Q$ but not in the reversed product + $\mathbb{I} = Q * Q.H$. + + >>> Q.H * Q + Matrix([ + [1, 0], + [0, 1]]) + >>> Q * Q.H + Matrix([ + [27261/30625, 348/30625, -1914/6125], + [ 348/30625, 30589/30625, 198/6125], + [ -1914/6125, 198/6125, 136/1225]]) + + If you want to augment the results to be a full orthogonal + decomposition, you should augment $Q$ with an another orthogonal + column. + + You are able to append an identity matrix, + and you can run the Gram-Schmidt + process to make them augmented as orthogonal basis. + + >>> Q_aug = Q.row_join(Matrix.eye(3)) + >>> Q_aug = Q_aug.QRdecomposition()[0] + >>> Q_aug + Matrix([ + [ 6/7, -69/175, 58/175], + [ 3/7, 158/175, -6/175], + [-2/7, 6/35, 33/35]]) + >>> Q_aug.H * Q_aug + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> Q_aug * Q_aug.H + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + Augmenting the $R$ matrix with zero row is straightforward. + + >>> R_aug = R.col_join(Matrix([[0, 0, 0]])) + >>> R_aug + Matrix([ + [14, 21, 0], + [ 0, 175, 0], + [ 0, 0, 0]]) + >>> Q_aug * R_aug == A + True + + A zero matrix example: + + >>> from sympy import Matrix + >>> A = Matrix.zeros(3, 4) + >>> Q, R = A.QRdecomposition() + + They may return matrices with zero rows and columns. + + >>> Q + Matrix(3, 0, []) + >>> R + Matrix(0, 4, []) + >>> Q*R + Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + + As the same augmentation rule described above, $Q$ can be augmented + with columns of an identity matrix and $R$ can be augmented with + rows of a zero matrix. + + >>> Q_aug = Q.row_join(Matrix.eye(3)) + >>> R_aug = R.col_join(Matrix.zeros(3, 4)) + >>> Q_aug * Q_aug.T + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> R_aug + Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + >>> Q_aug * R_aug == A + True + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.cholesky + sympy.matrices.dense.DenseMatrix.LDLdecomposition + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + QRsolve + """ + return _QRdecomposition_optional(M, normalize=True) + +def _upper_hessenberg_decomposition(A): + """Converts a matrix into Hessenberg matrix H. + + Returns 2 matrices H, P s.t. + $P H P^{T} = A$, where H is an upper hessenberg matrix + and P is an orthogonal matrix + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([ + ... [1,2,3], + ... [-3,5,6], + ... [4,-8,9], + ... ]) + >>> H, P = A.upper_hessenberg_decomposition() + >>> H + Matrix([ + [1, 6/5, 17/5], + [5, 213/25, -134/25], + [0, 216/25, 137/25]]) + >>> P + Matrix([ + [1, 0, 0], + [0, -3/5, 4/5], + [0, 4/5, 3/5]]) + >>> P * H * P.H == A + True + + + References + ========== + + .. [#] https://mathworld.wolfram.com/HessenbergDecomposition.html + """ + + M = A.as_mutable() + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + + n = M.cols + P = M.eye(n) + H = M + + for j in range(n - 2): + + u = H[j + 1:, j] + + if u[1:, :].is_zero_matrix: + continue + + if sign(u[0]) != 0: + u[0] = u[0] + sign(u[0]) * u.norm() + else: + u[0] = u[0] + u.norm() + + v = u / u.norm() + + H[j + 1:, :] = H[j + 1:, :] - 2 * v * (v.H * H[j + 1:, :]) + H[:, j + 1:] = H[:, j + 1:] - (H[:, j + 1:] * (2 * v)) * v.H + P[:, j + 1:] = P[:, j + 1:] - (P[:, j + 1:] * (2 * v)) * v.H + + return H, P diff --git a/phivenv/Lib/site-packages/sympy/matrices/dense.py b/phivenv/Lib/site-packages/sympy/matrices/dense.py new file mode 100644 index 0000000000000000000000000000000000000000..98bf9931df54f67abfd9c4dc810b46fdcf70288f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/dense.py @@ -0,0 +1,1094 @@ +from __future__ import annotations +import random + +from sympy.core.basic import Basic +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import is_sequence + +from .exceptions import ShapeError +from .decompositions import _cholesky, _LDLdecomposition +from .matrixbase import MatrixBase +from .repmatrix import MutableRepMatrix, RepMatrix +from .solvers import _lower_triangular_solve, _upper_triangular_solve + + +__doctest_requires__ = {('symarray',): ['numpy']} + + +def _iszero(x): + """Returns True if x is zero.""" + return x.is_zero + + +class DenseMatrix(RepMatrix): + """Matrix implementation based on DomainMatrix as the internal representation""" + + # + # DenseMatrix is a superclass for both MutableDenseMatrix and + # ImmutableDenseMatrix. Methods shared by both classes but not for the + # Sparse classes should be implemented here. + # + + is_MatrixExpr: bool = False + + _op_priority = 10.01 + _class_priority = 4 + + @property + def _mat(self): + sympy_deprecation_warning( + """ + The private _mat attribute of Matrix is deprecated. Use the + .flat() method instead. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-private-matrix-attributes" + ) + + return self.flat() + + def _eval_inverse(self, **kwargs): + return self.inv(method=kwargs.get('method', 'GE'), + iszerofunc=kwargs.get('iszerofunc', _iszero), + try_block_diag=kwargs.get('try_block_diag', False)) + + def as_immutable(self): + """Returns an Immutable version of this Matrix + """ + from .immutable import ImmutableDenseMatrix as cls + return cls._fromrep(self._rep.copy()) + + def as_mutable(self): + """Returns a mutable version of this matrix + + Examples + ======== + + >>> from sympy import ImmutableMatrix + >>> X = ImmutableMatrix([[1, 2], [3, 4]]) + >>> Y = X.as_mutable() + >>> Y[1, 1] = 5 # Can set values in Y + >>> Y + Matrix([ + [1, 2], + [3, 5]]) + """ + return Matrix(self) + + def cholesky(self, hermitian=True): + return _cholesky(self, hermitian=hermitian) + + def LDLdecomposition(self, hermitian=True): + return _LDLdecomposition(self, hermitian=hermitian) + + def lower_triangular_solve(self, rhs): + return _lower_triangular_solve(self, rhs) + + def upper_triangular_solve(self, rhs): + return _upper_triangular_solve(self, rhs) + + cholesky.__doc__ = _cholesky.__doc__ + LDLdecomposition.__doc__ = _LDLdecomposition.__doc__ + lower_triangular_solve.__doc__ = _lower_triangular_solve.__doc__ + upper_triangular_solve.__doc__ = _upper_triangular_solve.__doc__ + + +def _force_mutable(x): + """Return a matrix as a Matrix, otherwise return x.""" + if getattr(x, 'is_Matrix', False): + return x.as_mutable() + elif isinstance(x, Basic): + return x + elif hasattr(x, '__array__'): + a = x.__array__() + if len(a.shape) == 0: + return sympify(a) + return Matrix(x) + return x + + +class MutableDenseMatrix(DenseMatrix, MutableRepMatrix): + + def simplify(self, **kwargs): + """Applies simplify to the elements of a matrix in place. + + This is a shortcut for M.applyfunc(lambda x: simplify(x, ratio, measure)) + + See Also + ======== + + sympy.simplify.simplify.simplify + """ + from sympy.simplify.simplify import simplify as _simplify + for (i, j), element in self.todok().items(): + self[i, j] = _simplify(element, **kwargs) + + +MutableMatrix = Matrix = MutableDenseMatrix + +########### +# Numpy Utility Functions: +# list2numpy, matrix2numpy, symmarray +########### + + +def list2numpy(l, dtype=object): # pragma: no cover + """Converts Python list of SymPy expressions to a NumPy array. + + See Also + ======== + + matrix2numpy + """ + from numpy import empty + a = empty(len(l), dtype) + for i, s in enumerate(l): + a[i] = s + return a + + +def matrix2numpy(m, dtype=object): # pragma: no cover + """Converts SymPy's matrix to a NumPy array. + + See Also + ======== + + list2numpy + """ + from numpy import empty + a = empty(m.shape, dtype) + for i in range(m.rows): + for j in range(m.cols): + a[i, j] = m[i, j] + return a + + +########### +# Rotation matrices: +# rot_givens, rot_axis[123], rot_ccw_axis[123] +########### + + +def rot_givens(i, j, theta, dim=3): + r"""Returns a a Givens rotation matrix, a a rotation in the + plane spanned by two coordinates axes. + + Explanation + =========== + + The Givens rotation corresponds to a generalization of rotation + matrices to any number of dimensions, given by: + + .. math:: + G(i, j, \theta) = + \begin{bmatrix} + 1 & \cdots & 0 & \cdots & 0 & \cdots & 0 \\ + \vdots & \ddots & \vdots & & \vdots & & \vdots \\ + 0 & \cdots & c & \cdots & -s & \cdots & 0 \\ + \vdots & & \vdots & \ddots & \vdots & & \vdots \\ + 0 & \cdots & s & \cdots & c & \cdots & 0 \\ + \vdots & & \vdots & & \vdots & \ddots & \vdots \\ + 0 & \cdots & 0 & \cdots & 0 & \cdots & 1 + \end{bmatrix} + + Where $c = \cos(\theta)$ and $s = \sin(\theta)$ appear at the intersections + ``i``\th and ``j``\th rows and columns. + + For fixed ``i > j``\, the non-zero elements of a Givens matrix are + given by: + + - $g_{kk} = 1$ for $k \ne i,\,j$ + - $g_{kk} = c$ for $k = i,\,j$ + - $g_{ji} = -g_{ij} = -s$ + + Parameters + ========== + + i : int between ``0`` and ``dim - 1`` + Represents first axis + j : int between ``0`` and ``dim - 1`` + Represents second axis + dim : int bigger than 1 + Number of dimensions. Defaults to 3. + + Examples + ======== + + >>> from sympy import pi, rot_givens + + A counterclockwise rotation of pi/3 (60 degrees) around + the third axis (z-axis): + + >>> rot_givens(1, 0, pi/3) + Matrix([ + [ 1/2, -sqrt(3)/2, 0], + [sqrt(3)/2, 1/2, 0], + [ 0, 0, 1]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_givens(1, 0, pi/2) + Matrix([ + [0, -1, 0], + [1, 0, 0], + [0, 0, 1]]) + + This can be generalized to any number + of dimensions: + + >>> rot_givens(1, 0, pi/2, dim=4) + Matrix([ + [0, -1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Givens_rotation + + See Also + ======== + + rot_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (clockwise around the x axis) + rot_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (clockwise around the y axis) + rot_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (clockwise around the z axis) + rot_ccw_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (counterclockwise around the x axis) + rot_ccw_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (counterclockwise around the y axis) + rot_ccw_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (counterclockwise around the z axis) + """ + if not isinstance(dim, int) or dim < 2: + raise ValueError('dim must be an integer biggen than one, ' + 'got {}.'.format(dim)) + + if i == j: + raise ValueError('i and j must be different, ' + 'got ({}, {})'.format(i, j)) + + for ij in [i, j]: + if not isinstance(ij, int) or ij < 0 or ij > dim - 1: + raise ValueError('i and j must be integers between 0 and ' + '{}, got i={} and j={}.'.format(dim-1, i, j)) + + theta = sympify(theta) + c = cos(theta) + s = sin(theta) + M = eye(dim) + M[i, i] = c + M[j, j] = c + M[i, j] = s + M[j, i] = -s + return M + + +def rot_axis3(theta): + r"""Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis. + + Explanation + =========== + + For a right-handed coordinate system, this corresponds to a + clockwise rotation around the `z`-axis, given by: + + .. math:: + + R = \begin{bmatrix} + \cos(\theta) & \sin(\theta) & 0 \\ + -\sin(\theta) & \cos(\theta) & 0 \\ + 0 & 0 & 1 + \end{bmatrix} + + Examples + ======== + + >>> from sympy import pi, rot_axis3 + + A rotation of pi/3 (60 degrees): + + >>> theta = pi/3 + >>> rot_axis3(theta) + Matrix([ + [ 1/2, sqrt(3)/2, 0], + [-sqrt(3)/2, 1/2, 0], + [ 0, 0, 1]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_axis3(pi/2) + Matrix([ + [ 0, 1, 0], + [-1, 0, 0], + [ 0, 0, 1]]) + + See Also + ======== + + rot_givens: Returns a Givens rotation matrix (generalized rotation for + any number of dimensions) + rot_ccw_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (counterclockwise around the z axis) + rot_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (clockwise around the x axis) + rot_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (clockwise around the y axis) + """ + return rot_givens(0, 1, theta, dim=3) + + +def rot_axis2(theta): + r"""Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis. + + Explanation + =========== + + For a right-handed coordinate system, this corresponds to a + clockwise rotation around the `y`-axis, given by: + + .. math:: + + R = \begin{bmatrix} + \cos(\theta) & 0 & -\sin(\theta) \\ + 0 & 1 & 0 \\ + \sin(\theta) & 0 & \cos(\theta) + \end{bmatrix} + + Examples + ======== + + >>> from sympy import pi, rot_axis2 + + A rotation of pi/3 (60 degrees): + + >>> theta = pi/3 + >>> rot_axis2(theta) + Matrix([ + [ 1/2, 0, -sqrt(3)/2], + [ 0, 1, 0], + [sqrt(3)/2, 0, 1/2]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_axis2(pi/2) + Matrix([ + [0, 0, -1], + [0, 1, 0], + [1, 0, 0]]) + + See Also + ======== + + rot_givens: Returns a Givens rotation matrix (generalized rotation for + any number of dimensions) + rot_ccw_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (clockwise around the y axis) + rot_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (counterclockwise around the x axis) + rot_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (counterclockwise around the z axis) + """ + return rot_givens(2, 0, theta, dim=3) + + +def rot_axis1(theta): + r"""Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis. + + Explanation + =========== + + For a right-handed coordinate system, this corresponds to a + clockwise rotation around the `x`-axis, given by: + + .. math:: + + R = \begin{bmatrix} + 1 & 0 & 0 \\ + 0 & \cos(\theta) & \sin(\theta) \\ + 0 & -\sin(\theta) & \cos(\theta) + \end{bmatrix} + + Examples + ======== + + >>> from sympy import pi, rot_axis1 + + A rotation of pi/3 (60 degrees): + + >>> theta = pi/3 + >>> rot_axis1(theta) + Matrix([ + [1, 0, 0], + [0, 1/2, sqrt(3)/2], + [0, -sqrt(3)/2, 1/2]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_axis1(pi/2) + Matrix([ + [1, 0, 0], + [0, 0, 1], + [0, -1, 0]]) + + See Also + ======== + + rot_givens: Returns a Givens rotation matrix (generalized rotation for + any number of dimensions) + rot_ccw_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (counterclockwise around the x axis) + rot_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (clockwise around the y axis) + rot_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (clockwise around the z axis) + """ + return rot_givens(1, 2, theta, dim=3) + + +def rot_ccw_axis3(theta): + r"""Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis. + + Explanation + =========== + + For a right-handed coordinate system, this corresponds to a + counterclockwise rotation around the `z`-axis, given by: + + .. math:: + + R = \begin{bmatrix} + \cos(\theta) & -\sin(\theta) & 0 \\ + \sin(\theta) & \cos(\theta) & 0 \\ + 0 & 0 & 1 + \end{bmatrix} + + Examples + ======== + + >>> from sympy import pi, rot_ccw_axis3 + + A rotation of pi/3 (60 degrees): + + >>> theta = pi/3 + >>> rot_ccw_axis3(theta) + Matrix([ + [ 1/2, -sqrt(3)/2, 0], + [sqrt(3)/2, 1/2, 0], + [ 0, 0, 1]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_ccw_axis3(pi/2) + Matrix([ + [0, -1, 0], + [1, 0, 0], + [0, 0, 1]]) + + See Also + ======== + + rot_givens: Returns a Givens rotation matrix (generalized rotation for + any number of dimensions) + rot_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (clockwise around the z axis) + rot_ccw_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (counterclockwise around the x axis) + rot_ccw_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (counterclockwise around the y axis) + """ + return rot_givens(1, 0, theta, dim=3) + + +def rot_ccw_axis2(theta): + r"""Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis. + + Explanation + =========== + + For a right-handed coordinate system, this corresponds to a + counterclockwise rotation around the `y`-axis, given by: + + .. math:: + + R = \begin{bmatrix} + \cos(\theta) & 0 & \sin(\theta) \\ + 0 & 1 & 0 \\ + -\sin(\theta) & 0 & \cos(\theta) + \end{bmatrix} + + Examples + ======== + + >>> from sympy import pi, rot_ccw_axis2 + + A rotation of pi/3 (60 degrees): + + >>> theta = pi/3 + >>> rot_ccw_axis2(theta) + Matrix([ + [ 1/2, 0, sqrt(3)/2], + [ 0, 1, 0], + [-sqrt(3)/2, 0, 1/2]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_ccw_axis2(pi/2) + Matrix([ + [ 0, 0, 1], + [ 0, 1, 0], + [-1, 0, 0]]) + + See Also + ======== + + rot_givens: Returns a Givens rotation matrix (generalized rotation for + any number of dimensions) + rot_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (clockwise around the y axis) + rot_ccw_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (counterclockwise around the x axis) + rot_ccw_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (counterclockwise around the z axis) + """ + return rot_givens(0, 2, theta, dim=3) + + +def rot_ccw_axis1(theta): + r"""Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis. + + Explanation + =========== + + For a right-handed coordinate system, this corresponds to a + counterclockwise rotation around the `x`-axis, given by: + + .. math:: + + R = \begin{bmatrix} + 1 & 0 & 0 \\ + 0 & \cos(\theta) & -\sin(\theta) \\ + 0 & \sin(\theta) & \cos(\theta) + \end{bmatrix} + + Examples + ======== + + >>> from sympy import pi, rot_ccw_axis1 + + A rotation of pi/3 (60 degrees): + + >>> theta = pi/3 + >>> rot_ccw_axis1(theta) + Matrix([ + [1, 0, 0], + [0, 1/2, -sqrt(3)/2], + [0, sqrt(3)/2, 1/2]]) + + If we rotate by pi/2 (90 degrees): + + >>> rot_ccw_axis1(pi/2) + Matrix([ + [1, 0, 0], + [0, 0, -1], + [0, 1, 0]]) + + See Also + ======== + + rot_givens: Returns a Givens rotation matrix (generalized rotation for + any number of dimensions) + rot_axis1: Returns a rotation matrix for a rotation of theta (in radians) + about the 1-axis (clockwise around the x axis) + rot_ccw_axis2: Returns a rotation matrix for a rotation of theta (in radians) + about the 2-axis (counterclockwise around the y axis) + rot_ccw_axis3: Returns a rotation matrix for a rotation of theta (in radians) + about the 3-axis (counterclockwise around the z axis) + """ + return rot_givens(2, 1, theta, dim=3) + + +@doctest_depends_on(modules=('numpy',)) +def symarray(prefix, shape, **kwargs): # pragma: no cover + r"""Create a numpy ndarray of symbols (as an object array). + + The created symbols are named ``prefix_i1_i2_``... You should thus provide a + non-empty prefix if you want your symbols to be unique for different output + arrays, as SymPy symbols with identical names are the same object. + + Parameters + ---------- + + prefix : string + A prefix prepended to the name of every symbol. + + shape : int or tuple + Shape of the created array. If an int, the array is one-dimensional; for + more than one dimension the shape must be a tuple. + + \*\*kwargs : dict + keyword arguments passed on to Symbol + + Examples + ======== + These doctests require numpy. + + >>> from sympy import symarray + >>> symarray('', 3) + [_0 _1 _2] + + If you want multiple symarrays to contain distinct symbols, you *must* + provide unique prefixes: + + >>> a = symarray('', 3) + >>> b = symarray('', 3) + >>> a[0] == b[0] + True + >>> a = symarray('a', 3) + >>> b = symarray('b', 3) + >>> a[0] == b[0] + False + + Creating symarrays with a prefix: + + >>> symarray('a', 3) + [a_0 a_1 a_2] + + For more than one dimension, the shape must be given as a tuple: + + >>> symarray('a', (2, 3)) + [[a_0_0 a_0_1 a_0_2] + [a_1_0 a_1_1 a_1_2]] + >>> symarray('a', (2, 3, 2)) + [[[a_0_0_0 a_0_0_1] + [a_0_1_0 a_0_1_1] + [a_0_2_0 a_0_2_1]] + + [[a_1_0_0 a_1_0_1] + [a_1_1_0 a_1_1_1] + [a_1_2_0 a_1_2_1]]] + + For setting assumptions of the underlying Symbols: + + >>> [s.is_real for s in symarray('a', 2, real=True)] + [True, True] + """ + from numpy import empty, ndindex + arr = empty(shape, dtype=object) + for index in ndindex(shape): + arr[index] = Symbol('%s_%s' % (prefix, '_'.join(map(str, index))), + **kwargs) + return arr + + +############### +# Functions +############### + +def casoratian(seqs, n, zero=True): + """Given linear difference operator L of order 'k' and homogeneous + equation Ly = 0 we want to compute kernel of L, which is a set + of 'k' sequences: a(n), b(n), ... z(n). + + Solutions of L are linearly independent iff their Casoratian, + denoted as C(a, b, ..., z), do not vanish for n = 0. + + Casoratian is defined by k x k determinant:: + + + a(n) b(n) . . . z(n) + + | a(n+1) b(n+1) . . . z(n+1) | + | . . . . | + | . . . . | + | . . . . | + + a(n+k-1) b(n+k-1) . . . z(n+k-1) + + + It proves very useful in rsolve_hyper() where it is applied + to a generating set of a recurrence to factor out linearly + dependent solutions and return a basis: + + >>> from sympy import Symbol, casoratian, factorial + >>> n = Symbol('n', integer=True) + + Exponential and factorial are linearly independent: + + >>> casoratian([2**n, factorial(n)], n) != 0 + True + + """ + + seqs = list(map(sympify, seqs)) + + if not zero: + f = lambda i, j: seqs[j].subs(n, n + i) + else: + f = lambda i, j: seqs[j].subs(n, i) + + k = len(seqs) + + return Matrix(k, k, f).det() + + +def eye(*args, **kwargs): + """Create square identity matrix n x n + + See Also + ======== + + diag + zeros + ones + """ + + return Matrix.eye(*args, **kwargs) + + +def diag(*values, strict=True, unpack=False, **kwargs): + """Returns a matrix with the provided values placed on the + diagonal. If non-square matrices are included, they will + produce a block-diagonal matrix. + + Examples + ======== + + This version of diag is a thin wrapper to Matrix.diag that differs + in that it treats all lists like matrices -- even when a single list + is given. If this is not desired, either put a `*` before the list or + set `unpack=True`. + + >>> from sympy import diag + + >>> diag([1, 2, 3], unpack=True) # = diag(1,2,3) or diag(*[1,2,3]) + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + + >>> diag([1, 2, 3]) # a column vector + Matrix([ + [1], + [2], + [3]]) + + See Also + ======== + .matrixbase.MatrixBase.eye + .matrixbase.MatrixBase.diagonal + .matrixbase.MatrixBase.diag + .expressions.blockmatrix.BlockMatrix + """ + return Matrix.diag(*values, strict=strict, unpack=unpack, **kwargs) + + +def GramSchmidt(vlist, orthonormal=False): + """Apply the Gram-Schmidt process to a set of vectors. + + Parameters + ========== + + vlist : List of Matrix + Vectors to be orthogonalized for. + + orthonormal : Bool, optional + If true, return an orthonormal basis. + + Returns + ======= + + vlist : List of Matrix + Orthogonalized vectors + + Notes + ===== + + This routine is mostly duplicate from ``Matrix.orthogonalize``, + except for some difference that this always raises error when + linearly dependent vectors are found, and the keyword ``normalize`` + has been named as ``orthonormal`` in this function. + + See Also + ======== + + .matrixbase.MatrixBase.orthogonalize + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process + """ + return MutableDenseMatrix.orthogonalize( + *vlist, normalize=orthonormal, rankcheck=True + ) + + +def hessian(f, varlist, constraints=()): + """Compute Hessian matrix for a function f wrt parameters in varlist + which may be given as a sequence or a row/column vector. A list of + constraints may optionally be given. + + Examples + ======== + + >>> from sympy import Function, hessian, pprint + >>> from sympy.abc import x, y + >>> f = Function('f')(x, y) + >>> g1 = Function('g')(x, y) + >>> g2 = x**2 + 3*y + >>> pprint(hessian(f, (x, y), [g1, g2])) + [ d d ] + [ 0 0 --(g(x, y)) --(g(x, y)) ] + [ dx dy ] + [ ] + [ 0 0 2*x 3 ] + [ ] + [ 2 2 ] + [d d d ] + [--(g(x, y)) 2*x ---(f(x, y)) -----(f(x, y))] + [dx 2 dy dx ] + [ dx ] + [ ] + [ 2 2 ] + [d d d ] + [--(g(x, y)) 3 -----(f(x, y)) ---(f(x, y)) ] + [dy dy dx 2 ] + [ dy ] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hessian_matrix + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.jacobian + wronskian + """ + # f is the expression representing a function f, return regular matrix + if isinstance(varlist, MatrixBase): + if 1 not in varlist.shape: + raise ShapeError("`varlist` must be a column or row vector.") + if varlist.cols == 1: + varlist = varlist.T + varlist = varlist.tolist()[0] + if is_sequence(varlist): + n = len(varlist) + if not n: + raise ShapeError("`len(varlist)` must not be zero.") + else: + raise ValueError("Improper variable list in hessian function") + if not getattr(f, 'diff'): + # check differentiability + raise ValueError("Function `f` (%s) is not differentiable" % f) + m = len(constraints) + N = m + n + out = zeros(N) + for k, g in enumerate(constraints): + if not getattr(g, 'diff'): + # check differentiability + raise ValueError("Function `f` (%s) is not differentiable" % f) + for i in range(n): + out[k, i + m] = g.diff(varlist[i]) + for i in range(n): + for j in range(i, n): + out[i + m, j + m] = f.diff(varlist[i]).diff(varlist[j]) + for i in range(N): + for j in range(i + 1, N): + out[j, i] = out[i, j] + return out + + +def jordan_cell(eigenval, n): + """ + Create a Jordan block: + + Examples + ======== + + >>> from sympy import jordan_cell + >>> from sympy.abc import x + >>> jordan_cell(x, 4) + Matrix([ + [x, 1, 0, 0], + [0, x, 1, 0], + [0, 0, x, 1], + [0, 0, 0, x]]) + """ + + return Matrix.jordan_block(size=n, eigenvalue=eigenval) + + +def matrix_multiply_elementwise(A, B): + """Return the Hadamard product (elementwise product) of A and B + + >>> from sympy import Matrix, matrix_multiply_elementwise + >>> A = Matrix([[0, 1, 2], [3, 4, 5]]) + >>> B = Matrix([[1, 10, 100], [100, 10, 1]]) + >>> matrix_multiply_elementwise(A, B) + Matrix([ + [ 0, 10, 200], + [300, 40, 5]]) + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.__mul__ + """ + return A.multiply_elementwise(B) + + +def ones(*args, **kwargs): + """Returns a matrix of ones with ``rows`` rows and ``cols`` columns; + if ``cols`` is omitted a square matrix will be returned. + + See Also + ======== + + zeros + eye + diag + """ + + if 'c' in kwargs: + kwargs['cols'] = kwargs.pop('c') + + return Matrix.ones(*args, **kwargs) + + +def randMatrix(r, c=None, min=0, max=99, seed=None, symmetric=False, + percent=100, prng=None): + """Create random matrix with dimensions ``r`` x ``c``. If ``c`` is omitted + the matrix will be square. If ``symmetric`` is True the matrix must be + square. If ``percent`` is less than 100 then only approximately the given + percentage of elements will be non-zero. + + The pseudo-random number generator used to generate matrix is chosen in the + following way. + + * If ``prng`` is supplied, it will be used as random number generator. + It should be an instance of ``random.Random``, or at least have + ``randint`` and ``shuffle`` methods with same signatures. + * if ``prng`` is not supplied but ``seed`` is supplied, then new + ``random.Random`` with given ``seed`` will be created; + * otherwise, a new ``random.Random`` with default seed will be used. + + Examples + ======== + + >>> from sympy import randMatrix + >>> randMatrix(3) # doctest:+SKIP + [25, 45, 27] + [44, 54, 9] + [23, 96, 46] + >>> randMatrix(3, 2) # doctest:+SKIP + [87, 29] + [23, 37] + [90, 26] + >>> randMatrix(3, 3, 0, 2) # doctest:+SKIP + [0, 2, 0] + [2, 0, 1] + [0, 0, 1] + >>> randMatrix(3, symmetric=True) # doctest:+SKIP + [85, 26, 29] + [26, 71, 43] + [29, 43, 57] + >>> A = randMatrix(3, seed=1) + >>> B = randMatrix(3, seed=2) + >>> A == B + False + >>> A == randMatrix(3, seed=1) + True + >>> randMatrix(3, symmetric=True, percent=50) # doctest:+SKIP + [77, 70, 0], + [70, 0, 0], + [ 0, 0, 88] + """ + # Note that ``Random()`` is equivalent to ``Random(None)`` + prng = prng or random.Random(seed) + + if c is None: + c = r + + if symmetric and r != c: + raise ValueError('For symmetric matrices, r must equal c, but %i != %i' % (r, c)) + + ij = range(r * c) + if percent != 100: + ij = prng.sample(ij, int(len(ij)*percent // 100)) + + m = zeros(r, c) + + if not symmetric: + for ijk in ij: + i, j = divmod(ijk, c) + m[i, j] = prng.randint(min, max) + else: + for ijk in ij: + i, j = divmod(ijk, c) + if i <= j: + m[i, j] = m[j, i] = prng.randint(min, max) + + return m + + +def wronskian(functions, var, method='bareiss'): + """ + Compute Wronskian for [] of functions + + :: + + | f1 f2 ... fn | + | f1' f2' ... fn' | + | . . . . | + W(f1, ..., fn) = | . . . . | + | . . . . | + | (n) (n) (n) | + | D (f1) D (f2) ... D (fn) | + + see: https://en.wikipedia.org/wiki/Wronskian + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.jacobian + hessian + """ + + functions = [sympify(f) for f in functions] + n = len(functions) + if n == 0: + return S.One + W = Matrix(n, n, lambda i, j: functions[i].diff(var, j)) + return W.det(method) + + +def zeros(*args, **kwargs): + """Returns a matrix of zeros with ``rows`` rows and ``cols`` columns; + if ``cols`` is omitted a square matrix will be returned. + + See Also + ======== + + ones + eye + diag + """ + + if 'c' in kwargs: + kwargs['cols'] = kwargs.pop('c') + + return Matrix.zeros(*args, **kwargs) diff --git a/phivenv/Lib/site-packages/sympy/matrices/determinant.py b/phivenv/Lib/site-packages/sympy/matrices/determinant.py new file mode 100644 index 0000000000000000000000000000000000000000..9206c0714999ebe0cde5c4300d9b3293939177df --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/determinant.py @@ -0,0 +1,1021 @@ +from types import FunctionType + +from sympy.core.cache import cacheit +from sympy.core.numbers import Float, Integer +from sympy.core.singleton import S +from sympy.core.symbol import uniquely_named_symbol +from sympy.core.mul import Mul +from sympy.polys import PurePoly, cancel +from sympy.functions.combinatorial.numbers import nC +from sympy.polys.matrices.domainmatrix import DomainMatrix +from sympy.polys.matrices.ddm import DDM + +from .exceptions import NonSquareMatrixError +from .utilities import ( + _get_intermediate_simp, _get_intermediate_simp_bool, + _iszero, _is_zero_after_expand_mul, _dotprodsimp, _simplify) + + +def _find_reasonable_pivot(col, iszerofunc=_iszero, simpfunc=_simplify): + """ Find the lowest index of an item in ``col`` that is + suitable for a pivot. If ``col`` consists only of + Floats, the pivot with the largest norm is returned. + Otherwise, the first element where ``iszerofunc`` returns + False is used. If ``iszerofunc`` does not return false, + items are simplified and retested until a suitable + pivot is found. + + Returns a 4-tuple + (pivot_offset, pivot_val, assumed_nonzero, newly_determined) + where pivot_offset is the index of the pivot, pivot_val is + the (possibly simplified) value of the pivot, assumed_nonzero + is True if an assumption that the pivot was non-zero + was made without being proved, and newly_determined are + elements that were simplified during the process of pivot + finding.""" + + newly_determined = [] + col = list(col) + # a column that contains a mix of floats and integers + # but at least one float is considered a numerical + # column, and so we do partial pivoting + if all(isinstance(x, (Float, Integer)) for x in col) and any( + isinstance(x, Float) for x in col): + col_abs = [abs(x) for x in col] + max_value = max(col_abs) + if iszerofunc(max_value): + # just because iszerofunc returned True, doesn't + # mean the value is numerically zero. Make sure + # to replace all entries with numerical zeros + if max_value != 0: + newly_determined = [(i, 0) for i, x in enumerate(col) if x != 0] + return (None, None, False, newly_determined) + index = col_abs.index(max_value) + return (index, col[index], False, newly_determined) + + # PASS 1 (iszerofunc directly) + possible_zeros = [] + for i, x in enumerate(col): + is_zero = iszerofunc(x) + # is someone wrote a custom iszerofunc, it may return + # BooleanFalse or BooleanTrue instead of True or False, + # so use == for comparison instead of `is` + if is_zero == False: + # we found something that is definitely not zero + return (i, x, False, newly_determined) + possible_zeros.append(is_zero) + + # by this point, we've found no certain non-zeros + if all(possible_zeros): + # if everything is definitely zero, we have + # no pivot + return (None, None, False, newly_determined) + + # PASS 2 (iszerofunc after simplify) + # we haven't found any for-sure non-zeros, so + # go through the elements iszerofunc couldn't + # make a determination about and opportunistically + # simplify to see if we find something + for i, x in enumerate(col): + if possible_zeros[i] is not None: + continue + simped = simpfunc(x) + is_zero = iszerofunc(simped) + if is_zero in (True, False): + newly_determined.append((i, simped)) + if is_zero == False: + return (i, simped, False, newly_determined) + possible_zeros[i] = is_zero + + # after simplifying, some things that were recognized + # as zeros might be zeros + if all(possible_zeros): + # if everything is definitely zero, we have + # no pivot + return (None, None, False, newly_determined) + + # PASS 3 (.equals(0)) + # some expressions fail to simplify to zero, but + # ``.equals(0)`` evaluates to True. As a last-ditch + # attempt, apply ``.equals`` to these expressions + for i, x in enumerate(col): + if possible_zeros[i] is not None: + continue + if x.equals(S.Zero): + # ``.iszero`` may return False with + # an implicit assumption (e.g., ``x.equals(0)`` + # when ``x`` is a symbol), so only treat it + # as proved when ``.equals(0)`` returns True + possible_zeros[i] = True + newly_determined.append((i, S.Zero)) + + if all(possible_zeros): + return (None, None, False, newly_determined) + + # at this point there is nothing that could definitely + # be a pivot. To maintain compatibility with existing + # behavior, we'll assume that an illdetermined thing is + # non-zero. We should probably raise a warning in this case + i = possible_zeros.index(None) + return (i, col[i], True, newly_determined) + + +def _find_reasonable_pivot_naive(col, iszerofunc=_iszero, simpfunc=None): + """ + Helper that computes the pivot value and location from a + sequence of contiguous matrix column elements. As a side effect + of the pivot search, this function may simplify some of the elements + of the input column. A list of these simplified entries and their + indices are also returned. + This function mimics the behavior of _find_reasonable_pivot(), + but does less work trying to determine if an indeterminate candidate + pivot simplifies to zero. This more naive approach can be much faster, + with the trade-off that it may erroneously return a pivot that is zero. + + ``col`` is a sequence of contiguous column entries to be searched for + a suitable pivot. + ``iszerofunc`` is a callable that returns a Boolean that indicates + if its input is zero, or None if no such determination can be made. + ``simpfunc`` is a callable that simplifies its input. It must return + its input if it does not simplify its input. Passing in + ``simpfunc=None`` indicates that the pivot search should not attempt + to simplify any candidate pivots. + + Returns a 4-tuple: + (pivot_offset, pivot_val, assumed_nonzero, newly_determined) + ``pivot_offset`` is the sequence index of the pivot. + ``pivot_val`` is the value of the pivot. + pivot_val and col[pivot_index] are equivalent, but will be different + when col[pivot_index] was simplified during the pivot search. + ``assumed_nonzero`` is a boolean indicating if the pivot cannot be + guaranteed to be zero. If assumed_nonzero is true, then the pivot + may or may not be non-zero. If assumed_nonzero is false, then + the pivot is non-zero. + ``newly_determined`` is a list of index-value pairs of pivot candidates + that were simplified during the pivot search. + """ + + # indeterminates holds the index-value pairs of each pivot candidate + # that is neither zero or non-zero, as determined by iszerofunc(). + # If iszerofunc() indicates that a candidate pivot is guaranteed + # non-zero, or that every candidate pivot is zero then the contents + # of indeterminates are unused. + # Otherwise, the only viable candidate pivots are symbolic. + # In this case, indeterminates will have at least one entry, + # and all but the first entry are ignored when simpfunc is None. + indeterminates = [] + for i, col_val in enumerate(col): + col_val_is_zero = iszerofunc(col_val) + if col_val_is_zero == False: + # This pivot candidate is non-zero. + return i, col_val, False, [] + elif col_val_is_zero is None: + # The candidate pivot's comparison with zero + # is indeterminate. + indeterminates.append((i, col_val)) + + if len(indeterminates) == 0: + # All candidate pivots are guaranteed to be zero, i.e. there is + # no pivot. + return None, None, False, [] + + if simpfunc is None: + # Caller did not pass in a simplification function that might + # determine if an indeterminate pivot candidate is guaranteed + # to be nonzero, so assume the first indeterminate candidate + # is non-zero. + return indeterminates[0][0], indeterminates[0][1], True, [] + + # newly_determined holds index-value pairs of candidate pivots + # that were simplified during the search for a non-zero pivot. + newly_determined = [] + for i, col_val in indeterminates: + tmp_col_val = simpfunc(col_val) + if id(col_val) != id(tmp_col_val): + # simpfunc() simplified this candidate pivot. + newly_determined.append((i, tmp_col_val)) + if iszerofunc(tmp_col_val) == False: + # Candidate pivot simplified to a guaranteed non-zero value. + return i, tmp_col_val, False, newly_determined + + return indeterminates[0][0], indeterminates[0][1], True, newly_determined + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _berkowitz_toeplitz_matrix(M): + """Return (A,T) where T the Toeplitz matrix used in the Berkowitz algorithm + corresponding to ``M`` and A is the first principal submatrix. + """ + + # the 0 x 0 case is trivial + if M.rows == 0 and M.cols == 0: + return M._new(1,1, [M.one]) + + # + # Partition M = [ a_11 R ] + # [ C A ] + # + + a, R = M[0,0], M[0, 1:] + C, A = M[1:, 0], M[1:,1:] + + # + # The Toeplitz matrix looks like + # + # [ 1 ] + # [ -a 1 ] + # [ -RC -a 1 ] + # [ -RAC -RC -a 1 ] + # [ -RA**2C -RAC -RC -a 1 ] + # etc. + + # Compute the diagonal entries. + # Because multiplying matrix times vector is so much + # more efficient than matrix times matrix, recursively + # compute -R * A**n * C. + diags = [C] + for i in range(M.rows - 2): + diags.append(A.multiply(diags[i], dotprodsimp=None)) + diags = [(-R).multiply(d, dotprodsimp=None)[0, 0] for d in diags] + diags = [M.one, -a] + diags + + def entry(i,j): + if j > i: + return M.zero + return diags[i - j] + + toeplitz = M._new(M.cols + 1, M.rows, entry) + return (A, toeplitz) + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _berkowitz_vector(M): + """ Run the Berkowitz algorithm and return a vector whose entries + are the coefficients of the characteristic polynomial of ``M``. + + Given N x N matrix, efficiently compute + coefficients of characteristic polynomials of ``M`` + without division in the ground domain. + + This method is particularly useful for computing determinant, + principal minors and characteristic polynomial when ``M`` + has complicated coefficients e.g. polynomials. Semi-direct + usage of this algorithm is also important in computing + efficiently sub-resultant PRS. + + Assuming that M is a square matrix of dimension N x N and + I is N x N identity matrix, then the Berkowitz vector is + an N x 1 vector whose entries are coefficients of the + polynomial + + charpoly(M) = det(t*I - M) + + As a consequence, all polynomials generated by Berkowitz + algorithm are monic. + + For more information on the implemented algorithm refer to: + + [1] S.J. Berkowitz, On computing the determinant in small + parallel time using a small number of processors, ACM, + Information Processing Letters 18, 1984, pp. 147-150 + + [2] M. Keber, Division-Free computation of sub-resultants + using Bezout matrices, Tech. Report MPI-I-2006-1-006, + Saarbrucken, 2006 + """ + + # handle the trivial cases + if M.rows == 0 and M.cols == 0: + return M._new(1, 1, [M.one]) + elif M.rows == 1 and M.cols == 1: + return M._new(2, 1, [M.one, -M[0,0]]) + + submat, toeplitz = _berkowitz_toeplitz_matrix(M) + + return toeplitz.multiply(_berkowitz_vector(submat), dotprodsimp=None) + + +def _adjugate(M, method="berkowitz"): + """Returns the adjugate, or classical adjoint, of + a matrix. That is, the transpose of the matrix of cofactors. + + https://en.wikipedia.org/wiki/Adjugate + + Parameters + ========== + + method : string, optional + Method to use to find the cofactors, can be "bareiss", "berkowitz", + "bird", "laplace" or "lu". + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> M.adjugate() + Matrix([ + [ 4, -2], + [-3, 1]]) + + See Also + ======== + + cofactor_matrix + sympy.matrices.matrixbase.MatrixBase.transpose + """ + + return M.cofactor_matrix(method=method).transpose() + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _charpoly(M, x='lambda', simplify=_simplify): + """Computes characteristic polynomial det(x*I - M) where I is + the identity matrix. + + A PurePoly is returned, so using different variables for ``x`` does + not affect the comparison or the polynomials: + + Parameters + ========== + + x : string, optional + Name for the "lambda" variable, defaults to "lambda". + + simplify : function, optional + Simplification function to use on the characteristic polynomial + calculated. Defaults to ``simplify``. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[1, 3], [2, 0]]) + >>> M.charpoly() + PurePoly(lambda**2 - lambda - 6, lambda, domain='ZZ') + >>> M.charpoly(x) == M.charpoly(y) + True + >>> M.charpoly(x) == M.charpoly(y) + True + + Specifying ``x`` is optional; a symbol named ``lambda`` is used by + default (which looks good when pretty-printed in unicode): + + >>> M.charpoly().as_expr() + lambda**2 - lambda - 6 + + And if ``x`` clashes with an existing symbol, underscores will + be prepended to the name to make it unique: + + >>> M = Matrix([[1, 2], [x, 0]]) + >>> M.charpoly(x).as_expr() + _x**2 - _x - 2*x + + Whether you pass a symbol or not, the generator can be obtained + with the gen attribute since it may not be the same as the symbol + that was passed: + + >>> M.charpoly(x).gen + _x + >>> M.charpoly(x).gen == x + False + + Notes + ===== + + The Samuelson-Berkowitz algorithm is used to compute + the characteristic polynomial efficiently and without any + division operations. Thus the characteristic polynomial over any + commutative ring without zero divisors can be computed. + + If the determinant det(x*I - M) can be found out easily as + in the case of an upper or a lower triangular matrix, then + instead of Samuelson-Berkowitz algorithm, eigenvalues are computed + and the characteristic polynomial with their help. + + See Also + ======== + + det + """ + + if not M.is_square: + raise NonSquareMatrixError() + + # Use DomainMatrix. We are already going to convert this to a Poly so there + # is no need to worry about expanding powers etc. Also since this algorithm + # does not require division or zero detection it is fine to use EX. + # + # M.to_DM() will fall back on EXRAW rather than EX. EXRAW is a lot faster + # for elementary arithmetic because it does not call cancel for each + # operation but it generates large unsimplified results that are slow in + # the subsequent call to simplify. Using EX instead is faster overall + # but at least in some cases EXRAW+simplify gives a simpler result so we + # preserve that existing behaviour of charpoly for now... + dM = M.to_DM() + + K = dM.domain + + cp = dM.charpoly() + + x = uniquely_named_symbol(x, [M], modify=lambda s: '_' + s) + + if K.is_EXRAW or simplify is not _simplify: + # XXX: Converting back to Expr is expensive. We only do it if the + # caller supplied a custom simplify function for backwards + # compatibility or otherwise if the domain was EX. For any other domain + # there should be no benefit in simplifying at this stage because Poly + # will put everything into canonical form anyway. + berk_vector = [K.to_sympy(c) for c in cp] + berk_vector = [simplify(a) for a in berk_vector] + p = PurePoly(berk_vector, x) + + else: + # Convert from the list of domain elements directly to Poly. + p = PurePoly(cp, x, domain=K) + + return p + + +def _cofactor(M, i, j, method="berkowitz"): + """Calculate the cofactor of an element. + + Parameters + ========== + + method : string, optional + Method to use to find the cofactors, can be "bareiss", "berkowitz", + "bird", "laplace" or "lu". + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> M.cofactor(0, 1) + -3 + + See Also + ======== + + cofactor_matrix + minor + minor_submatrix + """ + + if not M.is_square or M.rows < 1: + raise NonSquareMatrixError() + + return S.NegativeOne**((i + j) % 2) * M.minor(i, j, method) + + +def _cofactor_matrix(M, method="berkowitz"): + """Return a matrix containing the cofactor of each element. + + Parameters + ========== + + method : string, optional + Method to use to find the cofactors, can be "bareiss", "berkowitz", + "bird", "laplace" or "lu". + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> M.cofactor_matrix() + Matrix([ + [ 4, -3], + [-2, 1]]) + + See Also + ======== + + cofactor + minor + minor_submatrix + """ + + if not M.is_square: + raise NonSquareMatrixError() + + return M._new(M.rows, M.cols, + lambda i, j: M.cofactor(i, j, method)) + +def _per(M): + """Returns the permanent of a matrix. Unlike determinant, + permanent is defined for both square and non-square matrices. + + For an m x n matrix, with m less than or equal to n, + it is given as the sum over the permutations s of size + less than or equal to m on [1, 2, . . . n] of the product + from i = 1 to m of M[i, s[i]]. Taking the transpose will + not affect the value of the permanent. + + In the case of a square matrix, this is the same as the permutation + definition of the determinant, but it does not take the sign of the + permutation into account. Computing the permanent with this definition + is quite inefficient, so here the Ryser formula is used. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> M.per() + 450 + >>> M = Matrix([1, 5, 7]) + >>> M.per() + 13 + + References + ========== + + .. [1] Prof. Frank Ben's notes: https://math.berkeley.edu/~bernd/ban275.pdf + .. [2] Wikipedia article on Permanent: https://en.wikipedia.org/wiki/Permanent_%28mathematics%29 + .. [3] https://reference.wolfram.com/language/ref/Permanent.html + .. [4] Permanent of a rectangular matrix : https://arxiv.org/pdf/0904.3251.pdf + """ + import itertools + + m, n = M.shape + if m > n: + M = M.T + m, n = n, m + s = list(range(n)) + + subsets = [] + for i in range(1, m + 1): + subsets += list(map(list, itertools.combinations(s, i))) + + perm = 0 + for subset in subsets: + prod = 1 + sub_len = len(subset) + for i in range(m): + prod *= sum(M[i, j] for j in subset) + perm += prod * S.NegativeOne**sub_len * nC(n - sub_len, m - sub_len) + perm *= S.NegativeOne**m + return perm.simplify() + +def _det_DOM(M): + DOM = DomainMatrix.from_Matrix(M, field=True, extension=True) + K = DOM.domain + return K.to_sympy(DOM.det()) + +# This functions is a candidate for caching if it gets implemented for matrices. +def _det(M, method="bareiss", iszerofunc=None): + """Computes the determinant of a matrix if ``M`` is a concrete matrix object + otherwise return an expressions ``Determinant(M)`` if ``M`` is a + ``MatrixSymbol`` or other expression. + + Parameters + ========== + + method : string, optional + Specifies the algorithm used for computing the matrix determinant. + + If the matrix is at most 3x3, a hard-coded formula is used and the + specified method is ignored. Otherwise, it defaults to + ``'bareiss'``. + + Also, if the matrix is an upper or a lower triangular matrix, determinant + is computed by simple multiplication of diagonal elements, and the + specified method is ignored. + + If it is set to ``'domain-ge'``, then Gaussian elimination method will + be used via using DomainMatrix. + + If it is set to ``'bareiss'``, Bareiss' fraction-free algorithm will + be used. + + If it is set to ``'berkowitz'``, Berkowitz' algorithm will be used. + + If it is set to ``'bird'``, Bird's algorithm will be used [1]_. + + If it is set to ``'laplace'``, Laplace's algorithm will be used. + + Otherwise, if it is set to ``'lu'``, LU decomposition will be used. + + .. note:: + For backward compatibility, legacy keys like "bareis" and + "det_lu" can still be used to indicate the corresponding + methods. + And the keys are also case-insensitive for now. However, it is + suggested to use the precise keys for specifying the method. + + iszerofunc : FunctionType or None, optional + If it is set to ``None``, it will be defaulted to ``_iszero`` if the + method is set to ``'bareiss'``, and ``_is_zero_after_expand_mul`` if + the method is set to ``'lu'``. + + It can also accept any user-specified zero testing function, if it + is formatted as a function which accepts a single symbolic argument + and returns ``True`` if it is tested as zero and ``False`` if it + tested as non-zero, and also ``None`` if it is undecidable. + + Returns + ======= + + det : Basic + Result of determinant. + + Raises + ====== + + ValueError + If unrecognized keys are given for ``method`` or ``iszerofunc``. + + NonSquareMatrixError + If attempted to calculate determinant from a non-square matrix. + + Examples + ======== + + >>> from sympy import Matrix, eye, det + >>> I3 = eye(3) + >>> det(I3) + 1 + >>> M = Matrix([[1, 2], [3, 4]]) + >>> det(M) + -2 + >>> det(M) == M.det() + True + >>> M.det(method="domain-ge") + -2 + + References + ========== + + .. [1] Bird, R. S. (2011). A simple division-free algorithm for computing + determinants. Inf. Process. Lett., 111(21), 1072-1074. doi: + 10.1016/j.ipl.2011.08.006 + """ + + # sanitize `method` + method = method.lower() + + if method == "bareis": + method = "bareiss" + elif method == "det_lu": + method = "lu" + + if method not in ("bareiss", "berkowitz", "lu", "domain-ge", "bird", + "laplace"): + raise ValueError("Determinant method '%s' unrecognized" % method) + + if iszerofunc is None: + if method == "bareiss": + iszerofunc = _is_zero_after_expand_mul + elif method == "lu": + iszerofunc = _iszero + + elif not isinstance(iszerofunc, FunctionType): + raise ValueError("Zero testing method '%s' unrecognized" % iszerofunc) + + n = M.rows + + if n == M.cols: # square check is done in individual method functions + if n == 0: + return M.one + elif n == 1: + return M[0, 0] + elif n == 2: + m = M[0, 0] * M[1, 1] - M[0, 1] * M[1, 0] + return _get_intermediate_simp(_dotprodsimp)(m) + elif n == 3: + m = (M[0, 0] * M[1, 1] * M[2, 2] + + M[0, 1] * M[1, 2] * M[2, 0] + + M[0, 2] * M[1, 0] * M[2, 1] + - M[0, 2] * M[1, 1] * M[2, 0] + - M[0, 0] * M[1, 2] * M[2, 1] + - M[0, 1] * M[1, 0] * M[2, 2]) + return _get_intermediate_simp(_dotprodsimp)(m) + + dets = [] + for b in M.strongly_connected_components(): + if method == "domain-ge": # uses DomainMatrix to evaluate determinant + det = _det_DOM(M[b, b]) + elif method == "bareiss": + det = M[b, b]._eval_det_bareiss(iszerofunc=iszerofunc) + elif method == "berkowitz": + det = M[b, b]._eval_det_berkowitz() + elif method == "lu": + det = M[b, b]._eval_det_lu(iszerofunc=iszerofunc) + elif method == "bird": + det = M[b, b]._eval_det_bird() + elif method == "laplace": + det = M[b, b]._eval_det_laplace() + dets.append(det) + return Mul(*dets) + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _det_bareiss(M, iszerofunc=_is_zero_after_expand_mul): + """Compute matrix determinant using Bareiss' fraction-free + algorithm which is an extension of the well known Gaussian + elimination method. This approach is best suited for dense + symbolic matrices and will result in a determinant with + minimal number of fractions. It means that less term + rewriting is needed on resulting formulae. + + Parameters + ========== + + iszerofunc : function, optional + The function to use to determine zeros when doing an LU decomposition. + Defaults to ``lambda x: x.is_zero``. + + TODO: Implement algorithm for sparse matrices (SFF), + http://www.eecis.udel.edu/~saunders/papers/sffge/it5.ps. + """ + + # Recursively implemented Bareiss' algorithm as per Deanna Richelle Leggett's + # thesis http://www.math.usm.edu/perry/Research/Thesis_DRL.pdf + def bareiss(mat, cumm=1): + if mat.rows == 0: + return mat.one + elif mat.rows == 1: + return mat[0, 0] + + # find a pivot and extract the remaining matrix + # With the default iszerofunc, _find_reasonable_pivot slows down + # the computation by the factor of 2.5 in one test. + # Relevant issues: #10279 and #13877. + pivot_pos, pivot_val, _, _ = _find_reasonable_pivot(mat[:, 0], iszerofunc=iszerofunc) + if pivot_pos is None: + return mat.zero + + # if we have a valid pivot, we'll do a "row swap", so keep the + # sign of the det + sign = (-1) ** (pivot_pos % 2) + + # we want every row but the pivot row and every column + rows = [i for i in range(mat.rows) if i != pivot_pos] + cols = list(range(mat.cols)) + tmp_mat = mat.extract(rows, cols) + + def entry(i, j): + ret = (pivot_val*tmp_mat[i, j + 1] - mat[pivot_pos, j + 1]*tmp_mat[i, 0]) / cumm + if _get_intermediate_simp_bool(True): + return _dotprodsimp(ret) + elif not ret.is_Atom: + return cancel(ret) + return ret + + return sign*bareiss(M._new(mat.rows - 1, mat.cols - 1, entry), pivot_val) + + if not M.is_square: + raise NonSquareMatrixError() + + if M.rows == 0: + return M.one + # sympy/matrices/tests/test_matrices.py contains a test that + # suggests that the determinant of a 0 x 0 matrix is one, by + # convention. + + return bareiss(M) + + +def _det_berkowitz(M): + """ Use the Berkowitz algorithm to compute the determinant.""" + + if not M.is_square: + raise NonSquareMatrixError() + + if M.rows == 0: + return M.one + # sympy/matrices/tests/test_matrices.py contains a test that + # suggests that the determinant of a 0 x 0 matrix is one, by + # convention. + + berk_vector = _berkowitz_vector(M) + return (-1)**(len(berk_vector) - 1) * berk_vector[-1] + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _det_LU(M, iszerofunc=_iszero, simpfunc=None): + """ Computes the determinant of a matrix from its LU decomposition. + This function uses the LU decomposition computed by + LUDecomposition_Simple(). + + The keyword arguments iszerofunc and simpfunc are passed to + LUDecomposition_Simple(). + iszerofunc is a callable that returns a boolean indicating if its + input is zero, or None if it cannot make the determination. + simpfunc is a callable that simplifies its input. + The default is simpfunc=None, which indicate that the pivot search + algorithm should not attempt to simplify any candidate pivots. + If simpfunc fails to simplify its input, then it must return its input + instead of a copy. + + Parameters + ========== + + iszerofunc : function, optional + The function to use to determine zeros when doing an LU decomposition. + Defaults to ``lambda x: x.is_zero``. + + simpfunc : function, optional + The simplification function to use when looking for zeros for pivots. + """ + + if not M.is_square: + raise NonSquareMatrixError() + + if M.rows == 0: + return M.one + # sympy/matrices/tests/test_matrices.py contains a test that + # suggests that the determinant of a 0 x 0 matrix is one, by + # convention. + + lu, row_swaps = M.LUdecomposition_Simple(iszerofunc=iszerofunc, + simpfunc=simpfunc) + # P*A = L*U => det(A) = det(L)*det(U)/det(P) = det(P)*det(U). + # Lower triangular factor L encoded in lu has unit diagonal => det(L) = 1. + # P is a permutation matrix => det(P) in {-1, 1} => 1/det(P) = det(P). + # LUdecomposition_Simple() returns a list of row exchange index pairs, rather + # than a permutation matrix, but det(P) = (-1)**len(row_swaps). + + # Avoid forming the potentially time consuming product of U's diagonal entries + # if the product is zero. + # Bottom right entry of U is 0 => det(A) = 0. + # It may be impossible to determine if this entry of U is zero when it is symbolic. + if iszerofunc(lu[lu.rows-1, lu.rows-1]): + return M.zero + + # Compute det(P) + det = -M.one if len(row_swaps)%2 else M.one + + # Compute det(U) by calculating the product of U's diagonal entries. + # The upper triangular portion of lu is the upper triangular portion of the + # U factor in the LU decomposition. + for k in range(lu.rows): + det *= lu[k, k] + + # return det(P)*det(U) + return det + + +@cacheit +def __det_laplace(M): + """Compute the determinant of a matrix using Laplace expansion. + + This is a recursive function, and it should not be called directly. + Use _det_laplace() instead. The reason for splitting this function + into two is to allow caching of determinants of submatrices. While + one could also define this function inside _det_laplace(), that + would remove the advantage of using caching in Cramer Solve. + """ + n = M.shape[0] + if n == 1: + return M[0] + elif n == 2: + return M[0, 0] * M[1, 1] - M[0, 1] * M[1, 0] + else: + return sum((-1) ** i * M[0, i] * + __det_laplace(M.minor_submatrix(0, i)) for i in range(n)) + + +def _det_laplace(M): + """Compute the determinant of a matrix using Laplace expansion. + + While Laplace expansion is not the most efficient method of computing + a determinant, it is a simple one, and it has the advantage of + being division free. To improve efficiency, this function uses + caching to avoid recomputing determinants of submatrices. + """ + if not M.is_square: + raise NonSquareMatrixError() + if M.shape[0] == 0: + return M.one + # sympy/matrices/tests/test_matrices.py contains a test that + # suggests that the determinant of a 0 x 0 matrix is one, by + # convention. + return __det_laplace(M.as_immutable()) + + +def _det_bird(M): + r"""Compute the determinant of a matrix using Bird's algorithm. + + Bird's algorithm is a simple division-free algorithm for computing, which + is of lower order than the Laplace's algorithm. It is described in [1]_. + + References + ========== + + .. [1] Bird, R. S. (2011). A simple division-free algorithm for computing + determinants. Inf. Process. Lett., 111(21), 1072-1074. doi: + 10.1016/j.ipl.2011.08.006 + """ + def mu(X): + n = X.shape[0] + zero = X.domain.zero + + total = zero + diag_sums = [zero] + for i in reversed(range(1, n)): + total -= X[i][i] + diag_sums.append(total) + diag_sums = diag_sums[::-1] + + elems = [[zero] * i + [diag_sums[i]] + X_i[i + 1:] for i, X_i in + enumerate(X)] + return DDM(elems, X.shape, X.domain) + + Mddm = M._rep.to_ddm() + n = M.shape[0] + if n == 0: + return M.one + # sympy/matrices/tests/test_matrices.py contains a test that + # suggests that the determinant of a 0 x 0 matrix is one, by + # convention. + Fn1 = Mddm + for _ in range(n - 1): + Fn1 = mu(Fn1).matmul(Mddm) + detA = Fn1[0][0] + if n % 2 == 0: + detA = -detA + + return Mddm.domain.to_sympy(detA) + + +def _minor(M, i, j, method="berkowitz"): + """Return the (i,j) minor of ``M``. That is, + return the determinant of the matrix obtained by deleting + the `i`th row and `j`th column from ``M``. + + Parameters + ========== + + i, j : int + The row and column to exclude to obtain the submatrix. + + method : string, optional + Method to use to find the determinant of the submatrix, can be + "bareiss", "berkowitz", "bird", "laplace" or "lu". + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> M.minor(1, 1) + -12 + + See Also + ======== + + minor_submatrix + cofactor + det + """ + + if not M.is_square: + raise NonSquareMatrixError() + + return M.minor_submatrix(i, j).det(method=method) + + +def _minor_submatrix(M, i, j): + """Return the submatrix obtained by removing the `i`th row + and `j`th column from ``M`` (works with Pythonic negative indices). + + Parameters + ========== + + i, j : int + The row and column to exclude to obtain the submatrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> M.minor_submatrix(1, 1) + Matrix([ + [1, 3], + [7, 9]]) + + See Also + ======== + + minor + cofactor + """ + + if i < 0: + i += M.rows + if j < 0: + j += M.cols + + if not 0 <= i < M.rows or not 0 <= j < M.cols: + raise ValueError("`i` and `j` must satisfy 0 <= i < ``M.rows`` " + "(%d)" % M.rows + "and 0 <= j < ``M.cols`` (%d)." % M.cols) + + rows = [a for a in range(M.rows) if a != i] + cols = [a for a in range(M.cols) if a != j] + + return M.extract(rows, cols) diff --git a/phivenv/Lib/site-packages/sympy/matrices/eigen.py b/phivenv/Lib/site-packages/sympy/matrices/eigen.py new file mode 100644 index 0000000000000000000000000000000000000000..87b2418efcece1c0b158ec56995bb011286feb3c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/eigen.py @@ -0,0 +1,1346 @@ +from types import FunctionType +from collections import Counter + +from mpmath import mp, workprec +from mpmath.libmp.libmpf import prec_to_dps + +from sympy.core.sorting import default_sort_key +from sympy.core.evalf import DEFAULT_MAXPREC, PrecisionExhausted +from sympy.core.logic import fuzzy_and, fuzzy_or +from sympy.core.numbers import Float +from sympy.core.sympify import _sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys import roots, CRootOf, ZZ, QQ, EX +from sympy.polys.matrices import DomainMatrix +from sympy.polys.matrices.eigen import dom_eigenvects, dom_eigenvects_to_sympy +from sympy.polys.polytools import gcd + +from .exceptions import MatrixError, NonSquareMatrixError +from .determinant import _find_reasonable_pivot + +from .utilities import _iszero, _simplify + + +__doctest_requires__ = { + ('_is_indefinite', + '_is_negative_definite', + '_is_negative_semidefinite', + '_is_positive_definite', + '_is_positive_semidefinite'): ['matplotlib'], +} + + +def _eigenvals_eigenvects_mpmath(M): + norm2 = lambda v: mp.sqrt(sum(i**2 for i in v)) + + v1 = None + prec = max(x._prec for x in M.atoms(Float)) + eps = 2**-prec + + while prec < DEFAULT_MAXPREC: + with workprec(prec): + A = mp.matrix(M.evalf(n=prec_to_dps(prec))) + E, ER = mp.eig(A) + v2 = norm2([i for e in E for i in (mp.re(e), mp.im(e))]) + if v1 is not None and mp.fabs(v1 - v2) < eps: + return E, ER + v1 = v2 + prec *= 2 + + # we get here because the next step would have taken us + # past MAXPREC or because we never took a step; in case + # of the latter, we refuse to send back a solution since + # it would not have been verified; we also resist taking + # a small step to arrive exactly at MAXPREC since then + # the two calculations might be artificially close. + raise PrecisionExhausted + + +def _eigenvals_mpmath(M, multiple=False): + """Compute eigenvalues using mpmath""" + E, _ = _eigenvals_eigenvects_mpmath(M) + result = [_sympify(x) for x in E] + if multiple: + return result + return dict(Counter(result)) + + +def _eigenvects_mpmath(M): + E, ER = _eigenvals_eigenvects_mpmath(M) + result = [] + for i in range(M.rows): + eigenval = _sympify(E[i]) + eigenvect = _sympify(ER[:, i]) + result.append((eigenval, 1, [eigenvect])) + + return result + + +# This function is a candidate for caching if it gets implemented for matrices. +def _eigenvals( + M, error_when_incomplete=True, *, simplify=False, multiple=False, + rational=False, **flags): + r"""Compute eigenvalues of the matrix. + + Parameters + ========== + + error_when_incomplete : bool, optional + If it is set to ``True``, it will raise an error if not all + eigenvalues are computed. This is caused by ``roots`` not returning + a full list of eigenvalues. + + simplify : bool or function, optional + If it is set to ``True``, it attempts to return the most + simplified form of expressions returned by applying default + simplification method in every routine. + + If it is set to ``False``, it will skip simplification in this + particular routine to save computation resources. + + If a function is passed to, it will attempt to apply + the particular function as simplification method. + + rational : bool, optional + If it is set to ``True``, every floating point numbers would be + replaced with rationals before computation. It can solve some + issues of ``roots`` routine not working well with floats. + + multiple : bool, optional + If it is set to ``True``, the result will be in the form of a + list. + + If it is set to ``False``, the result will be in the form of a + dictionary. + + Returns + ======= + + eigs : list or dict + Eigenvalues of a matrix. The return format would be specified by + the key ``multiple``. + + Raises + ====== + + MatrixError + If not enough roots had got computed. + + NonSquareMatrixError + If attempted to compute eigenvalues from a non-square matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix(3, 3, [0, 1, 1, 1, 0, 0, 1, 1, 1]) + >>> M.eigenvals() + {-1: 1, 0: 1, 2: 1} + + See Also + ======== + + MatrixBase.charpoly + eigenvects + + Notes + ===== + + Eigenvalues of a matrix $A$ can be computed by solving a matrix + equation $\det(A - \lambda I) = 0$ + + It's not always possible to return radical solutions for + eigenvalues for matrices larger than $4, 4$ shape due to + Abel-Ruffini theorem. + + If there is no radical solution is found for the eigenvalue, + it may return eigenvalues in the form of + :class:`sympy.polys.rootoftools.ComplexRootOf`. + """ + if not M: + if multiple: + return [] + return {} + + if not M.is_square: + raise NonSquareMatrixError("{} must be a square matrix.".format(M)) + + if M._rep.domain not in (ZZ, QQ): + # Skip this check for ZZ/QQ because it can be slow + if all(x.is_number for x in M) and M.has(Float): + return _eigenvals_mpmath(M, multiple=multiple) + + if rational: + from sympy.simplify import nsimplify + M = M.applyfunc( + lambda x: nsimplify(x, rational=True) if x.has(Float) else x) + + if multiple: + return _eigenvals_list( + M, error_when_incomplete=error_when_incomplete, simplify=simplify, + **flags) + return _eigenvals_dict( + M, error_when_incomplete=error_when_incomplete, simplify=simplify, + **flags) + + +eigenvals_error_message = \ +"It is not always possible to express the eigenvalues of a matrix " + \ +"of size 5x5 or higher in radicals. " + \ +"We have CRootOf, but domains other than the rationals are not " + \ +"currently supported. " + \ +"If there are no symbols in the matrix, " + \ +"it should still be possible to compute numeric approximations " + \ +"of the eigenvalues using " + \ +"M.evalf().eigenvals() or M.charpoly().nroots()." + + +def _eigenvals_list( + M, error_when_incomplete=True, simplify=False, **flags): + iblocks = M.strongly_connected_components() + all_eigs = [] + is_dom = M._rep.domain in (ZZ, QQ) + for b in iblocks: + + # Fast path for a 1x1 block: + if is_dom and len(b) == 1: + index = b[0] + val = M[index, index] + all_eigs.append(val) + continue + + block = M[b, b] + + if isinstance(simplify, FunctionType): + charpoly = block.charpoly(simplify=simplify) + else: + charpoly = block.charpoly() + + eigs = roots(charpoly, multiple=True, **flags) + + if len(eigs) != block.rows: + try: + eigs = charpoly.all_roots(multiple=True) + except NotImplementedError: + if error_when_incomplete: + raise MatrixError(eigenvals_error_message) + else: + eigs = [] + + all_eigs += eigs + + if not simplify: + return all_eigs + if not isinstance(simplify, FunctionType): + simplify = _simplify + return [simplify(value) for value in all_eigs] + + +def _eigenvals_dict( + M, error_when_incomplete=True, simplify=False, **flags): + iblocks = M.strongly_connected_components() + all_eigs = {} + is_dom = M._rep.domain in (ZZ, QQ) + for b in iblocks: + + # Fast path for a 1x1 block: + if is_dom and len(b) == 1: + index = b[0] + val = M[index, index] + all_eigs[val] = all_eigs.get(val, 0) + 1 + continue + + block = M[b, b] + + if isinstance(simplify, FunctionType): + charpoly = block.charpoly(simplify=simplify) + else: + charpoly = block.charpoly() + + eigs = roots(charpoly, multiple=False, **flags) + + if sum(eigs.values()) != block.rows: + try: + eigs = dict(charpoly.all_roots(multiple=False)) + except NotImplementedError: + if error_when_incomplete: + raise MatrixError(eigenvals_error_message) + else: + eigs = {} + + for k, v in eigs.items(): + if k in all_eigs: + all_eigs[k] += v + else: + all_eigs[k] = v + + if not simplify: + return all_eigs + if not isinstance(simplify, FunctionType): + simplify = _simplify + return {simplify(key): value for key, value in all_eigs.items()} + + +def _eigenspace(M, eigenval, iszerofunc=_iszero, simplify=False): + """Get a basis for the eigenspace for a particular eigenvalue""" + m = M - M.eye(M.rows) * eigenval + ret = m.nullspace(iszerofunc=iszerofunc) + + # The nullspace for a real eigenvalue should be non-trivial. + # If we didn't find an eigenvector, try once more a little harder + if len(ret) == 0 and simplify: + ret = m.nullspace(iszerofunc=iszerofunc, simplify=True) + if len(ret) == 0: + raise NotImplementedError( + "Can't evaluate eigenvector for eigenvalue {}".format(eigenval)) + return ret + + +def _eigenvects_DOM(M, **kwargs): + DOM = DomainMatrix.from_Matrix(M, field=True, extension=True) + DOM = DOM.to_dense() + + if DOM.domain != EX: + rational, algebraic = dom_eigenvects(DOM) + eigenvects = dom_eigenvects_to_sympy( + rational, algebraic, M.__class__, **kwargs) + eigenvects = sorted(eigenvects, key=lambda x: default_sort_key(x[0])) + + return eigenvects + return None + + +def _eigenvects_sympy(M, iszerofunc, simplify=True, **flags): + eigenvals = M.eigenvals(rational=False, **flags) + + # Make sure that we have all roots in radical form + for x in eigenvals: + if x.has(CRootOf): + raise MatrixError( + "Eigenvector computation is not implemented if the matrix have " + "eigenvalues in CRootOf form") + + eigenvals = sorted(eigenvals.items(), key=default_sort_key) + ret = [] + for val, mult in eigenvals: + vects = _eigenspace(M, val, iszerofunc=iszerofunc, simplify=simplify) + ret.append((val, mult, vects)) + return ret + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _eigenvects(M, error_when_incomplete=True, iszerofunc=_iszero, *, chop=False, **flags): + """Compute eigenvectors of the matrix. + + Parameters + ========== + + error_when_incomplete : bool, optional + Raise an error when not all eigenvalues are computed. This is + caused by ``roots`` not returning a full list of eigenvalues. + + iszerofunc : function, optional + Specifies a zero testing function to be used in ``rref``. + + Default value is ``_iszero``, which uses SymPy's naive and fast + default assumption handler. + + It can also accept any user-specified zero testing function, if it + is formatted as a function which accepts a single symbolic argument + and returns ``True`` if it is tested as zero and ``False`` if it + is tested as non-zero, and ``None`` if it is undecidable. + + simplify : bool or function, optional + If ``True``, ``as_content_primitive()`` will be used to tidy up + normalization artifacts. + + It will also be used by the ``nullspace`` routine. + + chop : bool or positive number, optional + If the matrix contains any Floats, they will be changed to Rationals + for computation purposes, but the answers will be returned after + being evaluated with evalf. The ``chop`` flag is passed to ``evalf``. + When ``chop=True`` a default precision will be used; a number will + be interpreted as the desired level of precision. + + Returns + ======= + + ret : [(eigenval, multiplicity, eigenspace), ...] + A ragged list containing tuples of data obtained by ``eigenvals`` + and ``nullspace``. + + ``eigenspace`` is a list containing the ``eigenvector`` for each + eigenvalue. + + ``eigenvector`` is a vector in the form of a ``Matrix``. e.g. + a vector of length 3 is returned as ``Matrix([a_1, a_2, a_3])``. + + Raises + ====== + + NotImplementedError + If failed to compute nullspace. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix(3, 3, [0, 1, 1, 1, 0, 0, 1, 1, 1]) + >>> M.eigenvects() + [(-1, 1, [Matrix([ + [-1], + [ 1], + [ 0]])]), (0, 1, [Matrix([ + [ 0], + [-1], + [ 1]])]), (2, 1, [Matrix([ + [2/3], + [1/3], + [ 1]])])] + + See Also + ======== + + eigenvals + MatrixBase.nullspace + """ + simplify = flags.get('simplify', True) + primitive = flags.get('simplify', False) + flags.pop('simplify', None) # remove this if it's there + flags.pop('multiple', None) # remove this if it's there + + if not isinstance(simplify, FunctionType): + simpfunc = _simplify if simplify else lambda x: x + + has_floats = M.has(Float) + if has_floats: + if all(x.is_number for x in M): + return _eigenvects_mpmath(M) + from sympy.simplify import nsimplify + M = M.applyfunc(lambda x: nsimplify(x, rational=True)) + + ret = _eigenvects_DOM(M) + if ret is None: + ret = _eigenvects_sympy(M, iszerofunc, simplify=simplify, **flags) + + if primitive: + # if the primitive flag is set, get rid of any common + # integer denominators + def denom_clean(l): + return [(v / gcd(list(v))).applyfunc(simpfunc) for v in l] + + ret = [(val, mult, denom_clean(es)) for val, mult, es in ret] + + if has_floats: + # if we had floats to start with, turn the eigenvectors to floats + ret = [(val.evalf(chop=chop), mult, [v.evalf(chop=chop) for v in es]) + for val, mult, es in ret] + + return ret + + +def _is_diagonalizable_with_eigen(M, reals_only=False): + """See _is_diagonalizable. This function returns the bool along with the + eigenvectors to avoid calculating them again in functions like + ``diagonalize``.""" + + if not M.is_square: + return False, [] + + eigenvecs = M.eigenvects(simplify=True) + + for val, mult, basis in eigenvecs: + if reals_only and not val.is_real: # if we have a complex eigenvalue + return False, eigenvecs + + if mult != len(basis): # if the geometric multiplicity doesn't equal the algebraic + return False, eigenvecs + + return True, eigenvecs + +def _is_diagonalizable(M, reals_only=False, **kwargs): + """Returns ``True`` if a matrix is diagonalizable. + + Parameters + ========== + + reals_only : bool, optional + If ``True``, it tests whether the matrix can be diagonalized + to contain only real numbers on the diagonal. + + + If ``False``, it tests whether the matrix can be diagonalized + at all, even with numbers that may not be real. + + Examples + ======== + + Example of a diagonalizable matrix: + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 0], [0, 3, 0], [2, -4, 2]]) + >>> M.is_diagonalizable() + True + + Example of a non-diagonalizable matrix: + + >>> M = Matrix([[0, 1], [0, 0]]) + >>> M.is_diagonalizable() + False + + Example of a matrix that is diagonalized in terms of non-real entries: + + >>> M = Matrix([[0, 1], [-1, 0]]) + >>> M.is_diagonalizable(reals_only=False) + True + >>> M.is_diagonalizable(reals_only=True) + False + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.is_diagonal + diagonalize + """ + if not M.is_square: + return False + + if all(e.is_real for e in M) and M.is_symmetric(): + return True + + if all(e.is_complex for e in M) and M.is_hermitian: + return True + + return _is_diagonalizable_with_eigen(M, reals_only=reals_only)[0] + + +#G&VL, Matrix Computations, Algo 5.4.2 +def _householder_vector(x): + if not x.cols == 1: + raise ValueError("Input must be a column matrix") + v = x.copy() + v_plus = x.copy() + v_minus = x.copy() + q = x[0, 0] / abs(x[0, 0]) + norm_x = x.norm() + v_plus[0, 0] = x[0, 0] + q * norm_x + v_minus[0, 0] = x[0, 0] - q * norm_x + if x[1:, 0].norm() == 0: + bet = 0 + v[0, 0] = 1 + else: + if v_plus.norm() <= v_minus.norm(): + v = v_plus + else: + v = v_minus + v = v / v[0] + bet = 2 / (v.norm() ** 2) + return v, bet + + +def _bidiagonal_decmp_hholder(M): + m = M.rows + n = M.cols + A = M.as_mutable() + U, V = A.eye(m), A.eye(n) + for i in range(min(m, n)): + v, bet = _householder_vector(A[i:, i]) + hh_mat = A.eye(m - i) - bet * v * v.H + A[i:, i:] = hh_mat * A[i:, i:] + temp = A.eye(m) + temp[i:, i:] = hh_mat + U = U * temp + if i + 1 <= n - 2: + v, bet = _householder_vector(A[i, i+1:].T) + hh_mat = A.eye(n - i - 1) - bet * v * v.H + A[i:, i+1:] = A[i:, i+1:] * hh_mat + temp = A.eye(n) + temp[i+1:, i+1:] = hh_mat + V = temp * V + return U, A, V + + +def _eval_bidiag_hholder(M): + m = M.rows + n = M.cols + A = M.as_mutable() + for i in range(min(m, n)): + v, bet = _householder_vector(A[i:, i]) + hh_mat = A.eye(m-i) - bet * v * v.H + A[i:, i:] = hh_mat * A[i:, i:] + if i + 1 <= n - 2: + v, bet = _householder_vector(A[i, i+1:].T) + hh_mat = A.eye(n - i - 1) - bet * v * v.H + A[i:, i+1:] = A[i:, i+1:] * hh_mat + return A + + +def _bidiagonal_decomposition(M, upper=True): + """ + Returns $(U,B,V.H)$ for + + $$A = UBV^{H}$$ + + where $A$ is the input matrix, and $B$ is its Bidiagonalized form + + Note: Bidiagonal Computation can hang for symbolic matrices. + + Parameters + ========== + + upper : bool. Whether to do upper bidiagnalization or lower. + True for upper and False for lower. + + References + ========== + + .. [1] Algorithm 5.4.2, Matrix computations by Golub and Van Loan, 4th edition + .. [2] Complex Matrix Bidiagonalization, https://github.com/vslobody/Householder-Bidiagonalization + + """ + + if not isinstance(upper, bool): + raise ValueError("upper must be a boolean") + + if upper: + return _bidiagonal_decmp_hholder(M) + + X = _bidiagonal_decmp_hholder(M.H) + return X[2].H, X[1].H, X[0].H + + +def _bidiagonalize(M, upper=True): + """ + Returns $B$, the Bidiagonalized form of the input matrix. + + Note: Bidiagonal Computation can hang for symbolic matrices. + + Parameters + ========== + + upper : bool. Whether to do upper bidiagnalization or lower. + True for upper and False for lower. + + References + ========== + + .. [1] Algorithm 5.4.2, Matrix computations by Golub and Van Loan, 4th edition + .. [2] Complex Matrix Bidiagonalization : https://github.com/vslobody/Householder-Bidiagonalization + + """ + + if not isinstance(upper, bool): + raise ValueError("upper must be a boolean") + + if upper: + return _eval_bidiag_hholder(M) + return _eval_bidiag_hholder(M.H).H + + +def _diagonalize(M, reals_only=False, sort=False, normalize=False): + """ + Return (P, D), where D is diagonal and + + D = P^-1 * M * P + + where M is current matrix. + + Parameters + ========== + + reals_only : bool. Whether to throw an error if complex numbers are need + to diagonalize. (Default: False) + + sort : bool. Sort the eigenvalues along the diagonal. (Default: False) + + normalize : bool. If True, normalize the columns of P. (Default: False) + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix(3, 3, [1, 2, 0, 0, 3, 0, 2, -4, 2]) + >>> M + Matrix([ + [1, 2, 0], + [0, 3, 0], + [2, -4, 2]]) + >>> (P, D) = M.diagonalize() + >>> D + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + >>> P + Matrix([ + [-1, 0, -1], + [ 0, 0, -1], + [ 2, 1, 2]]) + >>> P.inv() * M * P + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.is_diagonal + is_diagonalizable + """ + + if not M.is_square: + raise NonSquareMatrixError() + + is_diagonalizable, eigenvecs = _is_diagonalizable_with_eigen(M, + reals_only=reals_only) + + if not is_diagonalizable: + raise MatrixError("Matrix is not diagonalizable") + + if sort: + eigenvecs = sorted(eigenvecs, key=default_sort_key) + + p_cols, diag = [], [] + + for val, mult, basis in eigenvecs: + diag += [val] * mult + p_cols += basis + + if normalize: + p_cols = [v / v.norm() for v in p_cols] + + return M.hstack(*p_cols), M.diag(*diag) + + +def _fuzzy_positive_definite(M): + positive_diagonals = M._has_positive_diagonals() + if positive_diagonals is False: + return False + + if positive_diagonals and M.is_strongly_diagonally_dominant: + return True + + return None + + +def _fuzzy_positive_semidefinite(M): + nonnegative_diagonals = M._has_nonnegative_diagonals() + if nonnegative_diagonals is False: + return False + + if nonnegative_diagonals and M.is_weakly_diagonally_dominant: + return True + + return None + + +def _is_positive_definite(M): + if not M.is_hermitian: + if not M.is_square: + return False + M = M + M.H + + fuzzy = _fuzzy_positive_definite(M) + if fuzzy is not None: + return fuzzy + + return _is_positive_definite_GE(M) + + +def _is_positive_semidefinite(M): + if not M.is_hermitian: + if not M.is_square: + return False + M = M + M.H + + fuzzy = _fuzzy_positive_semidefinite(M) + if fuzzy is not None: + return fuzzy + + return _is_positive_semidefinite_cholesky(M) + + +def _is_negative_definite(M): + return _is_positive_definite(-M) + + +def _is_negative_semidefinite(M): + return _is_positive_semidefinite(-M) + + +def _is_indefinite(M): + if M.is_hermitian: + eigen = M.eigenvals() + args1 = [x.is_positive for x in eigen.keys()] + any_positive = fuzzy_or(args1) + args2 = [x.is_negative for x in eigen.keys()] + any_negative = fuzzy_or(args2) + + return fuzzy_and([any_positive, any_negative]) + + elif M.is_square: + return (M + M.H).is_indefinite + + return False + + +def _is_positive_definite_GE(M): + """A division-free gaussian elimination method for testing + positive-definiteness.""" + M = M.as_mutable() + size = M.rows + + for i in range(size): + is_positive = M[i, i].is_positive + if is_positive is not True: + return is_positive + for j in range(i+1, size): + M[j, i+1:] = M[i, i] * M[j, i+1:] - M[j, i] * M[i, i+1:] + return True + + +def _is_positive_semidefinite_cholesky(M): + """Uses Cholesky factorization with complete pivoting + + References + ========== + + .. [1] http://eprints.ma.man.ac.uk/1199/1/covered/MIMS_ep2008_116.pdf + + .. [2] https://www.value-at-risk.net/cholesky-factorization/ + """ + M = M.as_mutable() + for k in range(M.rows): + diags = [M[i, i] for i in range(k, M.rows)] + pivot, pivot_val, nonzero, _ = _find_reasonable_pivot(diags) + + if nonzero: + return None + + if pivot is None: + for i in range(k+1, M.rows): + for j in range(k, M.cols): + iszero = M[i, j].is_zero + if iszero is None: + return None + elif iszero is False: + return False + return True + + if M[k, k].is_negative or pivot_val.is_negative: + return False + elif not (M[k, k].is_nonnegative and pivot_val.is_nonnegative): + return None + + if pivot > 0: + M.col_swap(k, k+pivot) + M.row_swap(k, k+pivot) + + M[k, k] = sqrt(M[k, k]) + M[k, k+1:] /= M[k, k] + M[k+1:, k+1:] -= M[k, k+1:].H * M[k, k+1:] + + return M[-1, -1].is_nonnegative + + +_doc_positive_definite = \ + r"""Finds out the definiteness of a matrix. + + Explanation + =========== + + A square real matrix $A$ is: + + - A positive definite matrix if $x^T A x > 0$ + for all non-zero real vectors $x$. + - A positive semidefinite matrix if $x^T A x \geq 0$ + for all non-zero real vectors $x$. + - A negative definite matrix if $x^T A x < 0$ + for all non-zero real vectors $x$. + - A negative semidefinite matrix if $x^T A x \leq 0$ + for all non-zero real vectors $x$. + - An indefinite matrix if there exists non-zero real vectors + $x, y$ with $x^T A x > 0 > y^T A y$. + + A square complex matrix $A$ is: + + - A positive definite matrix if $\text{re}(x^H A x) > 0$ + for all non-zero complex vectors $x$. + - A positive semidefinite matrix if $\text{re}(x^H A x) \geq 0$ + for all non-zero complex vectors $x$. + - A negative definite matrix if $\text{re}(x^H A x) < 0$ + for all non-zero complex vectors $x$. + - A negative semidefinite matrix if $\text{re}(x^H A x) \leq 0$ + for all non-zero complex vectors $x$. + - An indefinite matrix if there exists non-zero complex vectors + $x, y$ with $\text{re}(x^H A x) > 0 > \text{re}(y^H A y)$. + + A matrix need not be symmetric or hermitian to be positive definite. + + - A real non-symmetric matrix is positive definite if and only if + $\frac{A + A^T}{2}$ is positive definite. + - A complex non-hermitian matrix is positive definite if and only if + $\frac{A + A^H}{2}$ is positive definite. + + And this extension can apply for all the definitions above. + + However, for complex cases, you can restrict the definition of + $\text{re}(x^H A x) > 0$ to $x^H A x > 0$ and require the matrix + to be hermitian. + But we do not present this restriction for computation because you + can check ``M.is_hermitian`` independently with this and use + the same procedure. + + Examples + ======== + + An example of symmetric positive definite matrix: + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import Matrix, symbols + >>> from sympy.plotting import plot3d + >>> a, b = symbols('a b') + >>> x = Matrix([a, b]) + + >>> A = Matrix([[1, 0], [0, 1]]) + >>> A.is_positive_definite + True + >>> A.is_positive_semidefinite + True + + >>> p = plot3d((x.T*A*x)[0, 0], (a, -1, 1), (b, -1, 1)) + + An example of symmetric positive semidefinite matrix: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> A = Matrix([[1, -1], [-1, 1]]) + >>> A.is_positive_definite + False + >>> A.is_positive_semidefinite + True + + >>> p = plot3d((x.T*A*x)[0, 0], (a, -1, 1), (b, -1, 1)) + + An example of symmetric negative definite matrix: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> A = Matrix([[-1, 0], [0, -1]]) + >>> A.is_negative_definite + True + >>> A.is_negative_semidefinite + True + >>> A.is_indefinite + False + + >>> p = plot3d((x.T*A*x)[0, 0], (a, -1, 1), (b, -1, 1)) + + An example of symmetric indefinite matrix: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> A = Matrix([[1, 2], [2, -1]]) + >>> A.is_indefinite + True + + >>> p = plot3d((x.T*A*x)[0, 0], (a, -1, 1), (b, -1, 1)) + + An example of non-symmetric positive definite matrix. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> A = Matrix([[1, 2], [-2, 1]]) + >>> A.is_positive_definite + True + >>> A.is_positive_semidefinite + True + + >>> p = plot3d((x.T*A*x)[0, 0], (a, -1, 1), (b, -1, 1)) + + Notes + ===== + + Although some people trivialize the definition of positive definite + matrices only for symmetric or hermitian matrices, this restriction + is not correct because it does not classify all instances of + positive definite matrices from the definition $x^T A x > 0$ or + $\text{re}(x^H A x) > 0$. + + For instance, ``Matrix([[1, 2], [-2, 1]])`` presented in + the example above is an example of real positive definite matrix + that is not symmetric. + + However, since the following formula holds true; + + .. math:: + \text{re}(x^H A x) > 0 \iff + \text{re}(x^H \frac{A + A^H}{2} x) > 0 + + We can classify all positive definite matrices that may or may not + be symmetric or hermitian by transforming the matrix to + $\frac{A + A^T}{2}$ or $\frac{A + A^H}{2}$ + (which is guaranteed to be always real symmetric or complex + hermitian) and we can defer most of the studies to symmetric or + hermitian positive definite matrices. + + But it is a different problem for the existence of Cholesky + decomposition. Because even though a non symmetric or a non + hermitian matrix can be positive definite, Cholesky or LDL + decomposition does not exist because the decompositions require the + matrix to be symmetric or hermitian. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Eigenvalues + + .. [2] https://mathworld.wolfram.com/PositiveDefiniteMatrix.html + + .. [3] Johnson, C. R. "Positive Definite Matrices." Amer. + Math. Monthly 77, 259-264 1970. + """ + +_is_positive_definite.__doc__ = _doc_positive_definite +_is_positive_semidefinite.__doc__ = _doc_positive_definite +_is_negative_definite.__doc__ = _doc_positive_definite +_is_negative_semidefinite.__doc__ = _doc_positive_definite +_is_indefinite.__doc__ = _doc_positive_definite + + +def _jordan_form(M, calc_transform=True, *, chop=False): + """Return $(P, J)$ where $J$ is a Jordan block + matrix and $P$ is a matrix such that $M = P J P^{-1}$ + + Parameters + ========== + + calc_transform : bool + If ``False``, then only $J$ is returned. + + chop : bool + All matrices are converted to exact types when computing + eigenvalues and eigenvectors. As a result, there may be + approximation errors. If ``chop==True``, these errors + will be truncated. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[ 6, 5, -2, -3], [-3, -1, 3, 3], [ 2, 1, -2, -3], [-1, 1, 5, 5]]) + >>> P, J = M.jordan_form() + >>> J + Matrix([ + [2, 1, 0, 0], + [0, 2, 0, 0], + [0, 0, 2, 1], + [0, 0, 0, 2]]) + + See Also + ======== + + jordan_block + """ + + if not M.is_square: + raise NonSquareMatrixError("Only square matrices have Jordan forms") + + mat = M + has_floats = M.has(Float) + + if has_floats: + try: + max_prec = max(term._prec for term in M.values() if isinstance(term, Float)) + except ValueError: + # if no term in the matrix is explicitly a Float calling max() + # will throw a error so setting max_prec to default value of 53 + max_prec = 53 + + # setting minimum max_dps to 15 to prevent loss of precision in + # matrix containing non evaluated expressions + max_dps = max(prec_to_dps(max_prec), 15) + + def restore_floats(*args): + """If ``has_floats`` is `True`, cast all ``args`` as + matrices of floats.""" + + if has_floats: + args = [m.evalf(n=max_dps, chop=chop) for m in args] + if len(args) == 1: + return args[0] + + return args + + # cache calculations for some speedup + mat_cache = {} + + def eig_mat(val, pow): + """Cache computations of ``(M - val*I)**pow`` for quick + retrieval""" + + if (val, pow) in mat_cache: + return mat_cache[(val, pow)] + + if (val, pow - 1) in mat_cache: + mat_cache[(val, pow)] = mat_cache[(val, pow - 1)].multiply( + mat_cache[(val, 1)], dotprodsimp=None) + else: + mat_cache[(val, pow)] = (mat - val*M.eye(M.rows)).pow(pow) + + return mat_cache[(val, pow)] + + # helper functions + def nullity_chain(val, algebraic_multiplicity): + """Calculate the sequence [0, nullity(E), nullity(E**2), ...] + until it is constant where ``E = M - val*I``""" + + # mat.rank() is faster than computing the null space, + # so use the rank-nullity theorem + cols = M.cols + ret = [0] + nullity = cols - eig_mat(val, 1).rank() + i = 2 + + while nullity != ret[-1]: + ret.append(nullity) + + if nullity == algebraic_multiplicity: + break + + nullity = cols - eig_mat(val, i).rank() + i += 1 + + # Due to issues like #7146 and #15872, SymPy sometimes + # gives the wrong rank. In this case, raise an error + # instead of returning an incorrect matrix + if nullity < ret[-1] or nullity > algebraic_multiplicity: + raise MatrixError( + "SymPy had encountered an inconsistent " + "result while computing Jordan block: " + "{}".format(M)) + + return ret + + def blocks_from_nullity_chain(d): + """Return a list of the size of each Jordan block. + If d_n is the nullity of E**n, then the number + of Jordan blocks of size n is + + 2*d_n - d_(n-1) - d_(n+1)""" + + # d[0] is always the number of columns, so skip past it + mid = [2*d[n] - d[n - 1] - d[n + 1] for n in range(1, len(d) - 1)] + # d is assumed to plateau with "d[ len(d) ] == d[-1]", so + # 2*d_n - d_(n-1) - d_(n+1) == d_n - d_(n-1) + end = [d[-1] - d[-2]] if len(d) > 1 else [d[0]] + + return mid + end + + def pick_vec(small_basis, big_basis): + """Picks a vector from big_basis that isn't in + the subspace spanned by small_basis""" + + if len(small_basis) == 0: + return big_basis[0] + + for v in big_basis: + _, pivots = M.hstack(*(small_basis + [v])).echelon_form( + with_pivots=True) + + if pivots[-1] == len(small_basis): + return v + + # roots doesn't like Floats, so replace them with Rationals + if has_floats: + from sympy.simplify import nsimplify + mat = mat.applyfunc(lambda x: nsimplify(x, rational=True)) + + # first calculate the jordan block structure + eigs = mat.eigenvals() + + # Make sure that we have all roots in radical form + for x in eigs: + if x.has(CRootOf): + raise MatrixError( + "Jordan normal form is not implemented if the matrix have " + "eigenvalues in CRootOf form") + + # most matrices have distinct eigenvalues + # and so are diagonalizable. In this case, don't + # do extra work! + if len(eigs.keys()) == mat.cols: + blocks = sorted(eigs.keys(), key=default_sort_key) + jordan_mat = mat.diag(*blocks) + + if not calc_transform: + return restore_floats(jordan_mat) + + jordan_basis = [eig_mat(eig, 1).nullspace()[0] + for eig in blocks] + basis_mat = mat.hstack(*jordan_basis) + + return restore_floats(basis_mat, jordan_mat) + + block_structure = [] + + for eig in sorted(eigs.keys(), key=default_sort_key): + algebraic_multiplicity = eigs[eig] + chain = nullity_chain(eig, algebraic_multiplicity) + block_sizes = blocks_from_nullity_chain(chain) + + # if block_sizes = = [a, b, c, ...], then the number of + # Jordan blocks of size 1 is a, of size 2 is b, etc. + # create an array that has (eig, block_size) with one + # entry for each block + size_nums = [(i+1, num) for i, num in enumerate(block_sizes)] + + # we expect larger Jordan blocks to come earlier + size_nums.reverse() + + block_structure.extend( + [(eig, size) for size, num in size_nums for _ in range(num)]) + + jordan_form_size = sum(size for eig, size in block_structure) + + if jordan_form_size != M.rows: + raise MatrixError( + "SymPy had encountered an inconsistent result while " + "computing Jordan block. : {}".format(M)) + + blocks = (mat.jordan_block(size=size, eigenvalue=eig) for eig, size in block_structure) + jordan_mat = mat.diag(*blocks) + + if not calc_transform: + return restore_floats(jordan_mat) + + # For each generalized eigenspace, calculate a basis. + # We start by looking for a vector in null( (A - eig*I)**n ) + # which isn't in null( (A - eig*I)**(n-1) ) where n is + # the size of the Jordan block + # + # Ideally we'd just loop through block_structure and + # compute each generalized eigenspace. However, this + # causes a lot of unneeded computation. Instead, we + # go through the eigenvalues separately, since we know + # their generalized eigenspaces must have bases that + # are linearly independent. + jordan_basis = [] + + for eig in sorted(eigs.keys(), key=default_sort_key): + eig_basis = [] + + for block_eig, size in block_structure: + if block_eig != eig: + continue + + null_big = (eig_mat(eig, size)).nullspace() + null_small = (eig_mat(eig, size - 1)).nullspace() + + # we want to pick something that is in the big basis + # and not the small, but also something that is independent + # of any other generalized eigenvectors from a different + # generalized eigenspace sharing the same eigenvalue. + vec = pick_vec(null_small + eig_basis, null_big) + new_vecs = [eig_mat(eig, i).multiply(vec, dotprodsimp=None) + for i in range(size)] + + eig_basis.extend(new_vecs) + jordan_basis.extend(reversed(new_vecs)) + + basis_mat = mat.hstack(*jordan_basis) + + return restore_floats(basis_mat, jordan_mat) + + +def _left_eigenvects(M, **flags): + """Returns left eigenvectors and eigenvalues. + + This function returns the list of triples (eigenval, multiplicity, + basis) for the left eigenvectors. Options are the same as for + eigenvects(), i.e. the ``**flags`` arguments gets passed directly to + eigenvects(). + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[0, 1, 1], [1, 0, 0], [1, 1, 1]]) + >>> M.eigenvects() + [(-1, 1, [Matrix([ + [-1], + [ 1], + [ 0]])]), (0, 1, [Matrix([ + [ 0], + [-1], + [ 1]])]), (2, 1, [Matrix([ + [2/3], + [1/3], + [ 1]])])] + >>> M.left_eigenvects() + [(-1, 1, [Matrix([[-2, 1, 1]])]), (0, 1, [Matrix([[-1, -1, 1]])]), (2, + 1, [Matrix([[1, 1, 1]])])] + + """ + + eigs = M.transpose().eigenvects(**flags) + + return [(val, mult, [l.transpose() for l in basis]) for val, mult, basis in eigs] + + +def _singular_values(M): + """Compute the singular values of a Matrix + + Examples + ======== + + >>> from sympy import Matrix, Symbol + >>> x = Symbol('x', real=True) + >>> M = Matrix([[0, 1, 0], [0, x, 0], [-1, 0, 0]]) + >>> M.singular_values() + [sqrt(x**2 + 1), 1, 0] + + See Also + ======== + + condition_number + """ + + if M.rows >= M.cols: + valmultpairs = M.H.multiply(M).eigenvals() + else: + valmultpairs = M.multiply(M.H).eigenvals() + + # Expands result from eigenvals into a simple list + vals = [] + + for k, v in valmultpairs.items(): + vals += [sqrt(k)] * v # dangerous! same k in several spots! + + # Pad with zeros if singular values are computed in reverse way, + # to give consistent format. + if len(vals) < M.cols: + vals += [M.zero] * (M.cols - len(vals)) + + # sort them in descending order + vals.sort(reverse=True, key=default_sort_key) + + return vals diff --git a/phivenv/Lib/site-packages/sympy/matrices/exceptions.py b/phivenv/Lib/site-packages/sympy/matrices/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..bfc7cfa0bdffd59ff2bc5a9cd85cf9b04ed1a63d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/exceptions.py @@ -0,0 +1,26 @@ +""" +Exceptions raised by the matrix module. +""" + + +class MatrixError(Exception): + pass + + +class ShapeError(ValueError, MatrixError): + """Wrong matrix shape""" + pass + + +class NonSquareMatrixError(ShapeError): + pass + + +class NonInvertibleMatrixError(ValueError, MatrixError): + """The matrix in not invertible (division by multidimensional zero error).""" + pass + + +class NonPositiveDefiniteMatrixError(ValueError, MatrixError): + """The matrix is not a positive-definite matrix.""" + pass diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__init__.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f4ab203ab74165d1003cdedd83945ea3fcf8f47 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/__init__.py @@ -0,0 +1,62 @@ +""" A module which handles Matrix Expressions """ + +from .slice import MatrixSlice +from .blockmatrix import BlockMatrix, BlockDiagMatrix, block_collapse, blockcut +from .companion import CompanionMatrix +from .funcmatrix import FunctionMatrix +from .inverse import Inverse +from .matadd import MatAdd +from .matexpr import MatrixExpr, MatrixSymbol, matrix_symbols +from .matmul import MatMul +from .matpow import MatPow +from .trace import Trace, trace +from .determinant import Determinant, det, Permanent, per +from .transpose import Transpose +from .adjoint import Adjoint +from .hadamard import hadamard_product, HadamardProduct, hadamard_power, HadamardPower +from .diagonal import DiagonalMatrix, DiagonalOf, DiagMatrix, diagonalize_vector +from .dotproduct import DotProduct +from .kronecker import kronecker_product, KroneckerProduct, combine_kronecker +from .permutation import PermutationMatrix, MatrixPermute +from .sets import MatrixSet +from .special import ZeroMatrix, Identity, OneMatrix + +__all__ = [ + 'MatrixSlice', + + 'BlockMatrix', 'BlockDiagMatrix', 'block_collapse', 'blockcut', + 'FunctionMatrix', + + 'CompanionMatrix', + + 'Inverse', + + 'MatAdd', + + 'Identity', 'MatrixExpr', 'MatrixSymbol', 'ZeroMatrix', 'OneMatrix', + 'matrix_symbols', 'MatrixSet', + + 'MatMul', + + 'MatPow', + + 'Trace', 'trace', + + 'Determinant', 'det', + + 'Transpose', + + 'Adjoint', + + 'hadamard_product', 'HadamardProduct', 'hadamard_power', 'HadamardPower', + + 'DiagonalMatrix', 'DiagonalOf', 'DiagMatrix', 'diagonalize_vector', + + 'DotProduct', + + 'kronecker_product', 'KroneckerProduct', 'combine_kronecker', + + 'PermutationMatrix', 'MatrixPermute', + + 'Permanent', 'per' +] diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f96068b3c80a7911ff8bf718397ef70292ea4fed Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/_shape.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/_shape.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47a9cac9820ee014f786b68ba29f4208e297d6f6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/_shape.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/adjoint.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/adjoint.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..846ea19bbf44cf9baa343e6c7270850c0ab0ddf1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/adjoint.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/applyfunc.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/applyfunc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c64ed6da6ba3c83a3e94c9f095034457f0fd6ebb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/applyfunc.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/blockmatrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/blockmatrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..017c359886796c9ac636b0a4351f8224cbbe48e3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/blockmatrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/companion.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/companion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9719af131950cc668b2585dd5011e2976cc16b8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/companion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/determinant.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/determinant.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d73f996a7366f6a84001569301755cd1ce34d906 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/determinant.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/diagonal.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/diagonal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b759aff98e56d6de88bb1e3099243fe901e5713d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/diagonal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/dotproduct.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/dotproduct.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..183c28ba222c268ed5dfb77534d86729c519a464 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/dotproduct.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/factorizations.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/factorizations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b699f9565b9edfbc1fe591da95baf7f60dc46074 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/factorizations.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/fourier.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/fourier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e28ad418c8ce94313b604c624195d1a303b6966 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/fourier.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/funcmatrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/funcmatrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d0b2307179e9ec45acf438dc1a4a33c1211b21b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/funcmatrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/hadamard.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/hadamard.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..58da59dc838c0d80f37f23d9c7a879206509e903 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/hadamard.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/inverse.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/inverse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67832222eafc3b1b51856cc5f7526d3460fcf0da Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/inverse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/kronecker.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/kronecker.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53d571827101fe6320bb7d2f826bff5f14e0dc89 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/kronecker.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/matadd.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/matadd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ba63afb1b7b5f303bff11ad204c97772874c3f6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/matadd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/matexpr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/matexpr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..126013784af85a975cd30407f2f1e306c93a1414 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/matexpr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/matmul.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/matmul.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d60bf1ad206bf5cce19fa03e7e0891be6510777 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/matmul.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/matpow.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/matpow.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..596201ede1cc90aa404d1812fbc1226728e2be6d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/matpow.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/permutation.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/permutation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f4036b232fbe72125fbb0db81610e17f2831f5d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/permutation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/sets.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/sets.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e6db7db1079ef9692d93dfd725b130f16d75b4c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/sets.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/slice.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/slice.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4da098bf6a6b47fb68ceac801b2301b55506199 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/slice.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/special.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/special.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6cf86cf92206e1ab014a996942becf6966fcd43 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/special.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/trace.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/trace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..441502288f329e5a16feab93eb075318f72341e6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/trace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/transpose.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/transpose.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89cbe482a4443d115b2c52b65244950b74b74312 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/__pycache__/transpose.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/_shape.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/_shape.py new file mode 100644 index 0000000000000000000000000000000000000000..a95d481bf8e1edf4c62992044cd50563b335caac --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/_shape.py @@ -0,0 +1,102 @@ +from sympy.core.relational import Eq +from sympy.core.expr import Expr +from sympy.core.numbers import Integer +from sympy.logic.boolalg import Boolean, And +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.exceptions import ShapeError +from typing import Union + + +def is_matadd_valid(*args: MatrixExpr) -> Boolean: + """Return the symbolic condition how ``MatAdd``, ``HadamardProduct`` + makes sense. + + Parameters + ========== + + args + The list of arguments of matrices to be tested for. + + Examples + ======== + + >>> from sympy import MatrixSymbol, symbols + >>> from sympy.matrices.expressions._shape import is_matadd_valid + + >>> m, n, p, q = symbols('m n p q') + >>> A = MatrixSymbol('A', m, n) + >>> B = MatrixSymbol('B', p, q) + >>> is_matadd_valid(A, B) + Eq(m, p) & Eq(n, q) + """ + rows, cols = zip(*(arg.shape for arg in args)) + return And( + *(Eq(i, j) for i, j in zip(rows[:-1], rows[1:])), + *(Eq(i, j) for i, j in zip(cols[:-1], cols[1:])), + ) + + +def is_matmul_valid(*args: Union[MatrixExpr, Expr]) -> Boolean: + """Return the symbolic condition how ``MatMul`` makes sense + + Parameters + ========== + + args + The list of arguments of matrices and scalar expressions to be tested + for. + + Examples + ======== + + >>> from sympy import MatrixSymbol, symbols + >>> from sympy.matrices.expressions._shape import is_matmul_valid + + >>> m, n, p, q = symbols('m n p q') + >>> A = MatrixSymbol('A', m, n) + >>> B = MatrixSymbol('B', p, q) + >>> is_matmul_valid(A, B) + Eq(n, p) + """ + rows, cols = zip(*(arg.shape for arg in args if isinstance(arg, MatrixExpr))) + return And(*(Eq(i, j) for i, j in zip(cols[:-1], rows[1:]))) + + +def is_square(arg: MatrixExpr, /) -> Boolean: + """Return the symbolic condition how the matrix is assumed to be square + + Parameters + ========== + + arg + The matrix to be tested for. + + Examples + ======== + + >>> from sympy import MatrixSymbol, symbols + >>> from sympy.matrices.expressions._shape import is_square + + >>> m, n = symbols('m n') + >>> A = MatrixSymbol('A', m, n) + >>> is_square(A) + Eq(m, n) + """ + return Eq(arg.rows, arg.cols) + + +def validate_matadd_integer(*args: MatrixExpr) -> None: + """Validate matrix shape for addition only for integer values""" + rows, cols = zip(*(x.shape for x in args)) + if len(set(filter(lambda x: isinstance(x, (int, Integer)), rows))) > 1: + raise ShapeError(f"Matrices have mismatching shape: {rows}") + if len(set(filter(lambda x: isinstance(x, (int, Integer)), cols))) > 1: + raise ShapeError(f"Matrices have mismatching shape: {cols}") + + +def validate_matmul_integer(*args: MatrixExpr) -> None: + """Validate matrix shape for multiplication only for integer values""" + for A, B in zip(args[:-1], args[1:]): + i, j = A.cols, B.rows + if isinstance(i, (int, Integer)) and isinstance(j, (int, Integer)) and i != j: + raise ShapeError("Matrices are not aligned", i, j) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/adjoint.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/adjoint.py new file mode 100644 index 0000000000000000000000000000000000000000..2039a7b2eb8eeacb02435979121c4133a11d8e02 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/adjoint.py @@ -0,0 +1,60 @@ +from sympy.core import Basic +from sympy.functions import adjoint, conjugate +from sympy.matrices.expressions.matexpr import MatrixExpr + + +class Adjoint(MatrixExpr): + """ + The Hermitian adjoint of a matrix expression. + + This is a symbolic object that simply stores its argument without + evaluating it. To actually compute the adjoint, use the ``adjoint()`` + function. + + Examples + ======== + + >>> from sympy import MatrixSymbol, Adjoint, adjoint + >>> A = MatrixSymbol('A', 3, 5) + >>> B = MatrixSymbol('B', 5, 3) + >>> Adjoint(A*B) + Adjoint(A*B) + >>> adjoint(A*B) + Adjoint(B)*Adjoint(A) + >>> adjoint(A*B) == Adjoint(A*B) + False + >>> adjoint(A*B) == Adjoint(A*B).doit() + True + """ + is_Adjoint = True + + def doit(self, **hints): + arg = self.arg + if hints.get('deep', True) and isinstance(arg, Basic): + return adjoint(arg.doit(**hints)) + else: + return adjoint(self.arg) + + @property + def arg(self): + return self.args[0] + + @property + def shape(self): + return self.arg.shape[::-1] + + def _entry(self, i, j, **kwargs): + return conjugate(self.arg._entry(j, i, **kwargs)) + + def _eval_adjoint(self): + return self.arg + + def _eval_transpose(self): + return self.arg.conjugate() + + def _eval_conjugate(self): + return self.arg.transpose() + + def _eval_trace(self): + from sympy.matrices.expressions.trace import Trace + return conjugate(Trace(self.arg)) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/applyfunc.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/applyfunc.py new file mode 100644 index 0000000000000000000000000000000000000000..c0363658447a8dc37a152b30e45533bac582b10c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/applyfunc.py @@ -0,0 +1,204 @@ +from sympy.core.expr import ExprBuilder +from sympy.core.function import (Function, FunctionClass, Lambda) +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify, _sympify +from sympy.matrices.expressions import MatrixExpr +from sympy.matrices.matrixbase import MatrixBase + + +class ElementwiseApplyFunction(MatrixExpr): + r""" + Apply function to a matrix elementwise without evaluating. + + Examples + ======== + + It can be created by calling ``.applyfunc()`` on a matrix + expression: + + >>> from sympy import MatrixSymbol + >>> from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction + >>> from sympy import exp + >>> X = MatrixSymbol("X", 3, 3) + >>> X.applyfunc(exp) + Lambda(_d, exp(_d)).(X) + + Otherwise using the class constructor: + + >>> from sympy import eye + >>> expr = ElementwiseApplyFunction(exp, eye(3)) + >>> expr + Lambda(_d, exp(_d)).(Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]])) + >>> expr.doit() + Matrix([ + [E, 1, 1], + [1, E, 1], + [1, 1, E]]) + + Notice the difference with the real mathematical functions: + + >>> exp(eye(3)) + Matrix([ + [E, 0, 0], + [0, E, 0], + [0, 0, E]]) + """ + + def __new__(cls, function, expr): + expr = _sympify(expr) + if not expr.is_Matrix: + raise ValueError("{} must be a matrix instance.".format(expr)) + + if expr.shape == (1, 1): + # Check if the function returns a matrix, in that case, just apply + # the function instead of creating an ElementwiseApplyFunc object: + ret = function(expr) + if isinstance(ret, MatrixExpr): + return ret + + if not isinstance(function, (FunctionClass, Lambda)): + d = Dummy('d') + function = Lambda(d, function(d)) + + function = sympify(function) + if not isinstance(function, (FunctionClass, Lambda)): + raise ValueError( + "{} should be compatible with SymPy function classes." + .format(function)) + + if 1 not in function.nargs: + raise ValueError( + '{} should be able to accept 1 arguments.'.format(function)) + + if not isinstance(function, Lambda): + d = Dummy('d') + function = Lambda(d, function(d)) + + obj = MatrixExpr.__new__(cls, function, expr) + return obj + + @property + def function(self): + return self.args[0] + + @property + def expr(self): + return self.args[1] + + @property + def shape(self): + return self.expr.shape + + def doit(self, **hints): + deep = hints.get("deep", True) + expr = self.expr + if deep: + expr = expr.doit(**hints) + function = self.function + if isinstance(function, Lambda) and function.is_identity: + # This is a Lambda containing the identity function. + return expr + if isinstance(expr, MatrixBase): + return expr.applyfunc(self.function) + elif isinstance(expr, ElementwiseApplyFunction): + return ElementwiseApplyFunction( + lambda x: self.function(expr.function(x)), + expr.expr + ).doit(**hints) + else: + return self + + def _entry(self, i, j, **kwargs): + return self.function(self.expr._entry(i, j, **kwargs)) + + def _get_function_fdiff(self): + d = Dummy("d") + function = self.function(d) + fdiff = function.diff(d) + if isinstance(fdiff, Function): + fdiff = type(fdiff) + else: + fdiff = Lambda(d, fdiff) + return fdiff + + def _eval_derivative(self, x): + from sympy.matrices.expressions.hadamard import hadamard_product + dexpr = self.expr.diff(x) + fdiff = self._get_function_fdiff() + return hadamard_product( + dexpr, + ElementwiseApplyFunction(fdiff, self.expr) + ) + + def _eval_derivative_matrix_lines(self, x): + from sympy.matrices.expressions.special import Identity + from sympy.tensor.array.expressions.array_expressions import ArrayContraction + from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct + + fdiff = self._get_function_fdiff() + lr = self.expr._eval_derivative_matrix_lines(x) + ewdiff = ElementwiseApplyFunction(fdiff, self.expr) + if 1 in x.shape: + # Vector: + iscolumn = self.shape[1] == 1 + for i in lr: + if iscolumn: + ptr1 = i.first_pointer + ptr2 = Identity(self.shape[1]) + else: + ptr1 = Identity(self.shape[0]) + ptr2 = i.second_pointer + + subexpr = ExprBuilder( + ArrayDiagonal, + [ + ExprBuilder( + ArrayTensorProduct, + [ + ewdiff, + ptr1, + ptr2, + ] + ), + (0, 2) if iscolumn else (1, 4) + ], + validator=ArrayDiagonal._validate + ) + i._lines = [subexpr] + i._first_pointer_parent = subexpr.args[0].args + i._first_pointer_index = 1 + i._second_pointer_parent = subexpr.args[0].args + i._second_pointer_index = 2 + else: + # Matrix case: + for i in lr: + ptr1 = i.first_pointer + ptr2 = i.second_pointer + newptr1 = Identity(ptr1.shape[1]) + newptr2 = Identity(ptr2.shape[1]) + subexpr = ExprBuilder( + ArrayContraction, + [ + ExprBuilder( + ArrayTensorProduct, + [ptr1, newptr1, ewdiff, ptr2, newptr2] + ), + (1, 2, 4), + (5, 7, 8), + ], + validator=ArrayContraction._validate + ) + i._first_pointer_parent = subexpr.args[0].args + i._first_pointer_index = 1 + i._second_pointer_parent = subexpr.args[0].args + i._second_pointer_index = 4 + i._lines = [subexpr] + return lr + + def _eval_transpose(self): + from sympy.matrices.expressions.transpose import Transpose + return self.func(self.function, Transpose(self.expr).doit()) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/blockmatrix.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/blockmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..0125d6233ba7cf8c0b590fbb655d9c7c447e0bd4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/blockmatrix.py @@ -0,0 +1,975 @@ +from sympy.assumptions.ask import (Q, ask) +from sympy.core import Basic, Add, Mul, S +from sympy.core.sympify import _sympify +from sympy.functions.elementary.complexes import re, im +from sympy.strategies import typed, exhaust, condition, do_one, unpack +from sympy.strategies.traverse import bottom_up +from sympy.utilities.iterables import is_sequence, sift +from sympy.utilities.misc import filldedent + +from sympy.matrices import Matrix, ShapeError +from sympy.matrices.exceptions import NonInvertibleMatrixError +from sympy.matrices.expressions.determinant import det, Determinant +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matadd import MatAdd +from sympy.matrices.expressions.matexpr import MatrixExpr, MatrixElement +from sympy.matrices.expressions.matmul import MatMul +from sympy.matrices.expressions.matpow import MatPow +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices.expressions.special import ZeroMatrix, Identity +from sympy.matrices.expressions.trace import trace +from sympy.matrices.expressions.transpose import Transpose, transpose + + +class BlockMatrix(MatrixExpr): + """A BlockMatrix is a Matrix comprised of other matrices. + + The submatrices are stored in a SymPy Matrix object but accessed as part of + a Matrix Expression + + >>> from sympy import (MatrixSymbol, BlockMatrix, symbols, + ... Identity, ZeroMatrix, block_collapse) + >>> n,m,l = symbols('n m l') + >>> X = MatrixSymbol('X', n, n) + >>> Y = MatrixSymbol('Y', m, m) + >>> Z = MatrixSymbol('Z', n, m) + >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m,n), Y]]) + >>> print(B) + Matrix([ + [X, Z], + [0, Y]]) + + >>> C = BlockMatrix([[Identity(n), Z]]) + >>> print(C) + Matrix([[I, Z]]) + + >>> print(block_collapse(C*B)) + Matrix([[X, Z + Z*Y]]) + + Some matrices might be comprised of rows of blocks with + the matrices in each row having the same height and the + rows all having the same total number of columns but + not having the same number of columns for each matrix + in each row. In this case, the matrix is not a block + matrix and should be instantiated by Matrix. + + >>> from sympy import ones, Matrix + >>> dat = [ + ... [ones(3,2), ones(3,3)*2], + ... [ones(2,3)*3, ones(2,2)*4]] + ... + >>> BlockMatrix(dat) + Traceback (most recent call last): + ... + ValueError: + Although this matrix is comprised of blocks, the blocks do not fill + the matrix in a size-symmetric fashion. To create a full matrix from + these arguments, pass them directly to Matrix. + >>> Matrix(dat) + Matrix([ + [1, 1, 2, 2, 2], + [1, 1, 2, 2, 2], + [1, 1, 2, 2, 2], + [3, 3, 3, 4, 4], + [3, 3, 3, 4, 4]]) + + See Also + ======== + sympy.matrices.matrixbase.MatrixBase.irregular + """ + def __new__(cls, *args, **kwargs): + from sympy.matrices.immutable import ImmutableDenseMatrix + isMat = lambda i: getattr(i, 'is_Matrix', False) + if len(args) != 1 or \ + not is_sequence(args[0]) or \ + len({isMat(r) for r in args[0]}) != 1: + raise ValueError(filldedent(''' + expecting a sequence of 1 or more rows + containing Matrices.''')) + rows = args[0] if args else [] + if not isMat(rows): + if rows and isMat(rows[0]): + rows = [rows] # rows is not list of lists or [] + # regularity check + # same number of matrices in each row + blocky = ok = len({len(r) for r in rows}) == 1 + if ok: + # same number of rows for each matrix in a row + for r in rows: + ok = len({i.rows for i in r}) == 1 + if not ok: + break + blocky = ok + if ok: + # same number of cols for each matrix in each col + for c in range(len(rows[0])): + ok = len({rows[i][c].cols + for i in range(len(rows))}) == 1 + if not ok: + break + if not ok: + # same total cols in each row + ok = len({ + sum(i.cols for i in r) for r in rows}) == 1 + if blocky and ok: + raise ValueError(filldedent(''' + Although this matrix is comprised of blocks, + the blocks do not fill the matrix in a + size-symmetric fashion. To create a full matrix + from these arguments, pass them directly to + Matrix.''')) + raise ValueError(filldedent(''' + When there are not the same number of rows in each + row's matrices or there are not the same number of + total columns in each row, the matrix is not a + block matrix. If this matrix is known to consist of + blocks fully filling a 2-D space then see + Matrix.irregular.''')) + mat = ImmutableDenseMatrix(rows, evaluate=False) + obj = Basic.__new__(cls, mat) + return obj + + @property + def shape(self): + numrows = numcols = 0 + M = self.blocks + for i in range(M.shape[0]): + numrows += M[i, 0].shape[0] + for i in range(M.shape[1]): + numcols += M[0, i].shape[1] + return (numrows, numcols) + + @property + def blockshape(self): + return self.blocks.shape + + @property + def blocks(self): + return self.args[0] + + @property + def rowblocksizes(self): + return [self.blocks[i, 0].rows for i in range(self.blockshape[0])] + + @property + def colblocksizes(self): + return [self.blocks[0, i].cols for i in range(self.blockshape[1])] + + def structurally_equal(self, other): + return (isinstance(other, BlockMatrix) + and self.shape == other.shape + and self.blockshape == other.blockshape + and self.rowblocksizes == other.rowblocksizes + and self.colblocksizes == other.colblocksizes) + + def _blockmul(self, other): + if (isinstance(other, BlockMatrix) and + self.colblocksizes == other.rowblocksizes): + return BlockMatrix(self.blocks*other.blocks) + + return self * other + + def _blockadd(self, other): + if (isinstance(other, BlockMatrix) + and self.structurally_equal(other)): + return BlockMatrix(self.blocks + other.blocks) + + return self + other + + def _eval_transpose(self): + # Flip all the individual matrices + matrices = [transpose(matrix) for matrix in self.blocks] + # Make a copy + M = Matrix(self.blockshape[0], self.blockshape[1], matrices) + # Transpose the block structure + M = M.transpose() + return BlockMatrix(M) + + def _eval_adjoint(self): + return BlockMatrix( + Matrix(self.blockshape[0], self.blockshape[1], self.blocks).adjoint() + ) + + def _eval_trace(self): + if self.rowblocksizes == self.colblocksizes: + blocks = [self.blocks[i, i] for i in range(self.blockshape[0])] + return Add(*[trace(block) for block in blocks]) + + def _eval_determinant(self): + if self.blockshape == (1, 1): + return det(self.blocks[0, 0]) + if self.blockshape == (2, 2): + [[A, B], + [C, D]] = self.blocks.tolist() + if ask(Q.invertible(A)): + return det(A)*det(D - C*A.I*B) + elif ask(Q.invertible(D)): + return det(D)*det(A - B*D.I*C) + return Determinant(self) + + def _eval_as_real_imag(self): + real_matrices = [re(matrix) for matrix in self.blocks] + real_matrices = Matrix(self.blockshape[0], self.blockshape[1], real_matrices) + + im_matrices = [im(matrix) for matrix in self.blocks] + im_matrices = Matrix(self.blockshape[0], self.blockshape[1], im_matrices) + + return (BlockMatrix(real_matrices), BlockMatrix(im_matrices)) + + def _eval_derivative(self, x): + return BlockMatrix(self.blocks.diff(x)) + + def transpose(self): + """Return transpose of matrix. + + Examples + ======== + + >>> from sympy import MatrixSymbol, BlockMatrix, ZeroMatrix + >>> from sympy.abc import m, n + >>> X = MatrixSymbol('X', n, n) + >>> Y = MatrixSymbol('Y', m, m) + >>> Z = MatrixSymbol('Z', n, m) + >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m,n), Y]]) + >>> B.transpose() + Matrix([ + [X.T, 0], + [Z.T, Y.T]]) + >>> _.transpose() + Matrix([ + [X, Z], + [0, Y]]) + """ + return self._eval_transpose() + + def schur(self, mat = 'A', generalized = False): + """Return the Schur Complement of the 2x2 BlockMatrix + + Parameters + ========== + + mat : String, optional + The matrix with respect to which the + Schur Complement is calculated. 'A' is + used by default + + generalized : bool, optional + If True, returns the generalized Schur + Component which uses Moore-Penrose Inverse + + Examples + ======== + + >>> from sympy import symbols, MatrixSymbol, BlockMatrix + >>> m, n = symbols('m n') + >>> A = MatrixSymbol('A', n, n) + >>> B = MatrixSymbol('B', n, m) + >>> C = MatrixSymbol('C', m, n) + >>> D = MatrixSymbol('D', m, m) + >>> X = BlockMatrix([[A, B], [C, D]]) + + The default Schur Complement is evaluated with "A" + + >>> X.schur() + -C*A**(-1)*B + D + >>> X.schur('D') + A - B*D**(-1)*C + + Schur complement with non-invertible matrices is not + defined. Instead, the generalized Schur complement can + be calculated which uses the Moore-Penrose Inverse. To + achieve this, `generalized` must be set to `True` + + >>> X.schur('B', generalized=True) + C - D*(B.T*B)**(-1)*B.T*A + >>> X.schur('C', generalized=True) + -A*(C.T*C)**(-1)*C.T*D + B + + Returns + ======= + + M : Matrix + The Schur Complement Matrix + + Raises + ====== + + ShapeError + If the block matrix is not a 2x2 matrix + + NonInvertibleMatrixError + If given matrix is non-invertible + + References + ========== + + .. [1] Wikipedia Article on Schur Component : https://en.wikipedia.org/wiki/Schur_complement + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.pinv + """ + + if self.blockshape == (2, 2): + [[A, B], + [C, D]] = self.blocks.tolist() + d={'A' : A, 'B' : B, 'C' : C, 'D' : D} + try: + inv = (d[mat].T*d[mat]).inv()*d[mat].T if generalized else d[mat].inv() + if mat == 'A': + return D - C * inv * B + elif mat == 'B': + return C - D * inv * A + elif mat == 'C': + return B - A * inv * D + elif mat == 'D': + return A - B * inv * C + #For matrices where no sub-matrix is square + return self + except NonInvertibleMatrixError: + raise NonInvertibleMatrixError('The given matrix is not invertible. Please set generalized=True \ + to compute the generalized Schur Complement which uses Moore-Penrose Inverse') + else: + raise ShapeError('Schur Complement can only be calculated for 2x2 block matrices') + + def LDUdecomposition(self): + """Returns the Block LDU decomposition of + a 2x2 Block Matrix + + Returns + ======= + + (L, D, U) : Matrices + L : Lower Diagonal Matrix + D : Diagonal Matrix + U : Upper Diagonal Matrix + + Examples + ======== + + >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse + >>> m, n = symbols('m n') + >>> A = MatrixSymbol('A', n, n) + >>> B = MatrixSymbol('B', n, m) + >>> C = MatrixSymbol('C', m, n) + >>> D = MatrixSymbol('D', m, m) + >>> X = BlockMatrix([[A, B], [C, D]]) + >>> L, D, U = X.LDUdecomposition() + >>> block_collapse(L*D*U) + Matrix([ + [A, B], + [C, D]]) + + Raises + ====== + + ShapeError + If the block matrix is not a 2x2 matrix + + NonInvertibleMatrixError + If the matrix "A" is non-invertible + + See Also + ======== + sympy.matrices.expressions.blockmatrix.BlockMatrix.UDLdecomposition + sympy.matrices.expressions.blockmatrix.BlockMatrix.LUdecomposition + """ + if self.blockshape == (2,2): + [[A, B], + [C, D]] = self.blocks.tolist() + try: + AI = A.I + except NonInvertibleMatrixError: + raise NonInvertibleMatrixError('Block LDU decomposition cannot be calculated when\ + "A" is singular') + Ip = Identity(B.shape[0]) + Iq = Identity(B.shape[1]) + Z = ZeroMatrix(*B.shape) + L = BlockMatrix([[Ip, Z], [C*AI, Iq]]) + D = BlockDiagMatrix(A, self.schur()) + U = BlockMatrix([[Ip, AI*B],[Z.T, Iq]]) + return L, D, U + else: + raise ShapeError("Block LDU decomposition is supported only for 2x2 block matrices") + + def UDLdecomposition(self): + """Returns the Block UDL decomposition of + a 2x2 Block Matrix + + Returns + ======= + + (U, D, L) : Matrices + U : Upper Diagonal Matrix + D : Diagonal Matrix + L : Lower Diagonal Matrix + + Examples + ======== + + >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse + >>> m, n = symbols('m n') + >>> A = MatrixSymbol('A', n, n) + >>> B = MatrixSymbol('B', n, m) + >>> C = MatrixSymbol('C', m, n) + >>> D = MatrixSymbol('D', m, m) + >>> X = BlockMatrix([[A, B], [C, D]]) + >>> U, D, L = X.UDLdecomposition() + >>> block_collapse(U*D*L) + Matrix([ + [A, B], + [C, D]]) + + Raises + ====== + + ShapeError + If the block matrix is not a 2x2 matrix + + NonInvertibleMatrixError + If the matrix "D" is non-invertible + + See Also + ======== + sympy.matrices.expressions.blockmatrix.BlockMatrix.LDUdecomposition + sympy.matrices.expressions.blockmatrix.BlockMatrix.LUdecomposition + """ + if self.blockshape == (2,2): + [[A, B], + [C, D]] = self.blocks.tolist() + try: + DI = D.I + except NonInvertibleMatrixError: + raise NonInvertibleMatrixError('Block UDL decomposition cannot be calculated when\ + "D" is singular') + Ip = Identity(A.shape[0]) + Iq = Identity(B.shape[1]) + Z = ZeroMatrix(*B.shape) + U = BlockMatrix([[Ip, B*DI], [Z.T, Iq]]) + D = BlockDiagMatrix(self.schur('D'), D) + L = BlockMatrix([[Ip, Z],[DI*C, Iq]]) + return U, D, L + else: + raise ShapeError("Block UDL decomposition is supported only for 2x2 block matrices") + + def LUdecomposition(self): + """Returns the Block LU decomposition of + a 2x2 Block Matrix + + Returns + ======= + + (L, U) : Matrices + L : Lower Diagonal Matrix + U : Upper Diagonal Matrix + + Examples + ======== + + >>> from sympy import symbols, MatrixSymbol, BlockMatrix, block_collapse + >>> m, n = symbols('m n') + >>> A = MatrixSymbol('A', n, n) + >>> B = MatrixSymbol('B', n, m) + >>> C = MatrixSymbol('C', m, n) + >>> D = MatrixSymbol('D', m, m) + >>> X = BlockMatrix([[A, B], [C, D]]) + >>> L, U = X.LUdecomposition() + >>> block_collapse(L*U) + Matrix([ + [A, B], + [C, D]]) + + Raises + ====== + + ShapeError + If the block matrix is not a 2x2 matrix + + NonInvertibleMatrixError + If the matrix "A" is non-invertible + + See Also + ======== + sympy.matrices.expressions.blockmatrix.BlockMatrix.UDLdecomposition + sympy.matrices.expressions.blockmatrix.BlockMatrix.LDUdecomposition + """ + if self.blockshape == (2,2): + [[A, B], + [C, D]] = self.blocks.tolist() + try: + A = A**S.Half + AI = A.I + except NonInvertibleMatrixError: + raise NonInvertibleMatrixError('Block LU decomposition cannot be calculated when\ + "A" is singular') + Z = ZeroMatrix(*B.shape) + Q = self.schur()**S.Half + L = BlockMatrix([[A, Z], [C*AI, Q]]) + U = BlockMatrix([[A, AI*B],[Z.T, Q]]) + return L, U + else: + raise ShapeError("Block LU decomposition is supported only for 2x2 block matrices") + + def _entry(self, i, j, **kwargs): + # Find row entry + orig_i, orig_j = i, j + for row_block, numrows in enumerate(self.rowblocksizes): + cmp = i < numrows + if cmp == True: + break + elif cmp == False: + i -= numrows + elif row_block < self.blockshape[0] - 1: + # Can't tell which block and it's not the last one, return unevaluated + return MatrixElement(self, orig_i, orig_j) + for col_block, numcols in enumerate(self.colblocksizes): + cmp = j < numcols + if cmp == True: + break + elif cmp == False: + j -= numcols + elif col_block < self.blockshape[1] - 1: + return MatrixElement(self, orig_i, orig_j) + return self.blocks[row_block, col_block][i, j] + + @property + def is_Identity(self): + if self.blockshape[0] != self.blockshape[1]: + return False + for i in range(self.blockshape[0]): + for j in range(self.blockshape[1]): + if i==j and not self.blocks[i, j].is_Identity: + return False + if i!=j and not self.blocks[i, j].is_ZeroMatrix: + return False + return True + + @property + def is_structurally_symmetric(self): + return self.rowblocksizes == self.colblocksizes + + def equals(self, other): + if self == other: + return True + if (isinstance(other, BlockMatrix) and self.blocks == other.blocks): + return True + return super().equals(other) + + +class BlockDiagMatrix(BlockMatrix): + """A sparse matrix with block matrices along its diagonals + + Examples + ======== + + >>> from sympy import MatrixSymbol, BlockDiagMatrix, symbols + >>> n, m, l = symbols('n m l') + >>> X = MatrixSymbol('X', n, n) + >>> Y = MatrixSymbol('Y', m, m) + >>> BlockDiagMatrix(X, Y) + Matrix([ + [X, 0], + [0, Y]]) + + Notes + ===== + + If you want to get the individual diagonal blocks, use + :meth:`get_diag_blocks`. + + See Also + ======== + + sympy.matrices.dense.diag + """ + def __new__(cls, *mats): + return Basic.__new__(BlockDiagMatrix, *[_sympify(m) for m in mats]) + + @property + def diag(self): + return self.args + + @property + def blocks(self): + from sympy.matrices.immutable import ImmutableDenseMatrix + mats = self.args + data = [[mats[i] if i == j else ZeroMatrix(mats[i].rows, mats[j].cols) + for j in range(len(mats))] + for i in range(len(mats))] + return ImmutableDenseMatrix(data, evaluate=False) + + @property + def shape(self): + return (sum(block.rows for block in self.args), + sum(block.cols for block in self.args)) + + @property + def blockshape(self): + n = len(self.args) + return (n, n) + + @property + def rowblocksizes(self): + return [block.rows for block in self.args] + + @property + def colblocksizes(self): + return [block.cols for block in self.args] + + def _all_square_blocks(self): + """Returns true if all blocks are square""" + return all(mat.is_square for mat in self.args) + + def _eval_determinant(self): + if self._all_square_blocks(): + return Mul(*[det(mat) for mat in self.args]) + # At least one block is non-square. Since the entire matrix must be square we know there must + # be at least two blocks in this matrix, in which case the entire matrix is necessarily rank-deficient + return S.Zero + + def _eval_inverse(self, expand='ignored'): + if self._all_square_blocks(): + return BlockDiagMatrix(*[mat.inverse() for mat in self.args]) + # See comment in _eval_determinant() + raise NonInvertibleMatrixError('Matrix det == 0; not invertible.') + + def _eval_transpose(self): + return BlockDiagMatrix(*[mat.transpose() for mat in self.args]) + + def _blockmul(self, other): + if (isinstance(other, BlockDiagMatrix) and + self.colblocksizes == other.rowblocksizes): + return BlockDiagMatrix(*[a*b for a, b in zip(self.args, other.args)]) + else: + return BlockMatrix._blockmul(self, other) + + def _blockadd(self, other): + if (isinstance(other, BlockDiagMatrix) and + self.blockshape == other.blockshape and + self.rowblocksizes == other.rowblocksizes and + self.colblocksizes == other.colblocksizes): + return BlockDiagMatrix(*[a + b for a, b in zip(self.args, other.args)]) + else: + return BlockMatrix._blockadd(self, other) + + def get_diag_blocks(self): + """Return the list of diagonal blocks of the matrix. + + Examples + ======== + + >>> from sympy import BlockDiagMatrix, Matrix + + >>> A = Matrix([[1, 2], [3, 4]]) + >>> B = Matrix([[5, 6], [7, 8]]) + >>> M = BlockDiagMatrix(A, B) + + How to get diagonal blocks from the block diagonal matrix: + + >>> diag_blocks = M.get_diag_blocks() + >>> diag_blocks[0] + Matrix([ + [1, 2], + [3, 4]]) + >>> diag_blocks[1] + Matrix([ + [5, 6], + [7, 8]]) + """ + return self.args + + +def block_collapse(expr): + """Evaluates a block matrix expression + + >>> from sympy import MatrixSymbol, BlockMatrix, symbols, Identity, ZeroMatrix, block_collapse + >>> n,m,l = symbols('n m l') + >>> X = MatrixSymbol('X', n, n) + >>> Y = MatrixSymbol('Y', m, m) + >>> Z = MatrixSymbol('Z', n, m) + >>> B = BlockMatrix([[X, Z], [ZeroMatrix(m, n), Y]]) + >>> print(B) + Matrix([ + [X, Z], + [0, Y]]) + + >>> C = BlockMatrix([[Identity(n), Z]]) + >>> print(C) + Matrix([[I, Z]]) + + >>> print(block_collapse(C*B)) + Matrix([[X, Z + Z*Y]]) + """ + from sympy.strategies.util import expr_fns + + hasbm = lambda expr: isinstance(expr, MatrixExpr) and expr.has(BlockMatrix) + + conditioned_rl = condition( + hasbm, + typed( + {MatAdd: do_one(bc_matadd, bc_block_plus_ident), + MatMul: do_one(bc_matmul, bc_dist), + MatPow: bc_matmul, + Transpose: bc_transpose, + Inverse: bc_inverse, + BlockMatrix: do_one(bc_unpack, deblock)} + ) + ) + + rule = exhaust( + bottom_up( + exhaust(conditioned_rl), + fns=expr_fns + ) + ) + + result = rule(expr) + doit = getattr(result, 'doit', None) + if doit is not None: + return doit() + else: + return result + +def bc_unpack(expr): + if expr.blockshape == (1, 1): + return expr.blocks[0, 0] + return expr + +def bc_matadd(expr): + args = sift(expr.args, lambda M: isinstance(M, BlockMatrix)) + blocks = args[True] + if not blocks: + return expr + + nonblocks = args[False] + block = blocks[0] + for b in blocks[1:]: + block = block._blockadd(b) + if nonblocks: + return MatAdd(*nonblocks) + block + else: + return block + +def bc_block_plus_ident(expr): + idents = [arg for arg in expr.args if arg.is_Identity] + if not idents: + return expr + + blocks = [arg for arg in expr.args if isinstance(arg, BlockMatrix)] + if (blocks and all(b.structurally_equal(blocks[0]) for b in blocks) + and blocks[0].is_structurally_symmetric): + block_id = BlockDiagMatrix(*[Identity(k) + for k in blocks[0].rowblocksizes]) + rest = [arg for arg in expr.args if not arg.is_Identity and not isinstance(arg, BlockMatrix)] + return MatAdd(block_id * len(idents), *blocks, *rest).doit() + + return expr + +def bc_dist(expr): + """ Turn a*[X, Y] into [a*X, a*Y] """ + factor, mat = expr.as_coeff_mmul() + if factor == 1: + return expr + + unpacked = unpack(mat) + + if isinstance(unpacked, BlockDiagMatrix): + B = unpacked.diag + new_B = [factor * mat for mat in B] + return BlockDiagMatrix(*new_B) + elif isinstance(unpacked, BlockMatrix): + B = unpacked.blocks + new_B = [ + [factor * B[i, j] for j in range(B.cols)] for i in range(B.rows)] + return BlockMatrix(new_B) + return expr + + +def bc_matmul(expr): + if isinstance(expr, MatPow): + if expr.args[1].is_Integer and expr.args[1] > 0: + factor, matrices = 1, [expr.args[0]]*expr.args[1] + else: + return expr + else: + factor, matrices = expr.as_coeff_matrices() + + i = 0 + while (i+1 < len(matrices)): + A, B = matrices[i:i+2] + if isinstance(A, BlockMatrix) and isinstance(B, BlockMatrix): + matrices[i] = A._blockmul(B) + matrices.pop(i+1) + elif isinstance(A, BlockMatrix): + matrices[i] = A._blockmul(BlockMatrix([[B]])) + matrices.pop(i+1) + elif isinstance(B, BlockMatrix): + matrices[i] = BlockMatrix([[A]])._blockmul(B) + matrices.pop(i+1) + else: + i+=1 + return MatMul(factor, *matrices).doit() + +def bc_transpose(expr): + collapse = block_collapse(expr.arg) + return collapse._eval_transpose() + + +def bc_inverse(expr): + if isinstance(expr.arg, BlockDiagMatrix): + return expr.inverse() + + expr2 = blockinverse_1x1(expr) + if expr != expr2: + return expr2 + return blockinverse_2x2(Inverse(reblock_2x2(expr.arg))) + +def blockinverse_1x1(expr): + if isinstance(expr.arg, BlockMatrix) and expr.arg.blockshape == (1, 1): + mat = Matrix([[expr.arg.blocks[0].inverse()]]) + return BlockMatrix(mat) + return expr + + +def blockinverse_2x2(expr): + if isinstance(expr.arg, BlockMatrix) and expr.arg.blockshape == (2, 2): + # See: Inverses of 2x2 Block Matrices, Tzon-Tzer Lu and Sheng-Hua Shiou + [[A, B], + [C, D]] = expr.arg.blocks.tolist() + + formula = _choose_2x2_inversion_formula(A, B, C, D) + if formula != None: + MI = expr.arg.schur(formula).I + if formula == 'A': + AI = A.I + return BlockMatrix([[AI + AI * B * MI * C * AI, -AI * B * MI], [-MI * C * AI, MI]]) + if formula == 'B': + BI = B.I + return BlockMatrix([[-MI * D * BI, MI], [BI + BI * A * MI * D * BI, -BI * A * MI]]) + if formula == 'C': + CI = C.I + return BlockMatrix([[-CI * D * MI, CI + CI * D * MI * A * CI], [MI, -MI * A * CI]]) + if formula == 'D': + DI = D.I + return BlockMatrix([[MI, -MI * B * DI], [-DI * C * MI, DI + DI * C * MI * B * DI]]) + + return expr + + +def _choose_2x2_inversion_formula(A, B, C, D): + """ + Assuming [[A, B], [C, D]] would form a valid square block matrix, find + which of the classical 2x2 block matrix inversion formulas would be + best suited. + + Returns 'A', 'B', 'C', 'D' to represent the algorithm involving inversion + of the given argument or None if the matrix cannot be inverted using + any of those formulas. + """ + # Try to find a known invertible matrix. Note that the Schur complement + # is currently not being considered for this + A_inv = ask(Q.invertible(A)) + if A_inv == True: + return 'A' + B_inv = ask(Q.invertible(B)) + if B_inv == True: + return 'B' + C_inv = ask(Q.invertible(C)) + if C_inv == True: + return 'C' + D_inv = ask(Q.invertible(D)) + if D_inv == True: + return 'D' + # Otherwise try to find a matrix that isn't known to be non-invertible + if A_inv != False: + return 'A' + if B_inv != False: + return 'B' + if C_inv != False: + return 'C' + if D_inv != False: + return 'D' + return None + + +def deblock(B): + """ Flatten a BlockMatrix of BlockMatrices """ + if not isinstance(B, BlockMatrix) or not B.blocks.has(BlockMatrix): + return B + wrap = lambda x: x if isinstance(x, BlockMatrix) else BlockMatrix([[x]]) + bb = B.blocks.applyfunc(wrap) # everything is a block + + try: + MM = Matrix(0, sum(bb[0, i].blocks.shape[1] for i in range(bb.shape[1])), []) + for row in range(0, bb.shape[0]): + M = Matrix(bb[row, 0].blocks) + for col in range(1, bb.shape[1]): + M = M.row_join(bb[row, col].blocks) + MM = MM.col_join(M) + + return BlockMatrix(MM) + except ShapeError: + return B + + +def reblock_2x2(expr): + """ + Reblock a BlockMatrix so that it has 2x2 blocks of block matrices. If + possible in such a way that the matrix continues to be invertible using the + classical 2x2 block inversion formulas. + """ + if not isinstance(expr, BlockMatrix) or not all(d > 2 for d in expr.blockshape): + return expr + + BM = BlockMatrix # for brevity's sake + rowblocks, colblocks = expr.blockshape + blocks = expr.blocks + for i in range(1, rowblocks): + for j in range(1, colblocks): + # try to split rows at i and cols at j + A = bc_unpack(BM(blocks[:i, :j])) + B = bc_unpack(BM(blocks[:i, j:])) + C = bc_unpack(BM(blocks[i:, :j])) + D = bc_unpack(BM(blocks[i:, j:])) + + formula = _choose_2x2_inversion_formula(A, B, C, D) + if formula is not None: + return BlockMatrix([[A, B], [C, D]]) + + # else: nothing worked, just split upper left corner + return BM([[blocks[0, 0], BM(blocks[0, 1:])], + [BM(blocks[1:, 0]), BM(blocks[1:, 1:])]]) + + +def bounds(sizes): + """ Convert sequence of numbers into pairs of low-high pairs + + >>> from sympy.matrices.expressions.blockmatrix import bounds + >>> bounds((1, 10, 50)) + [(0, 1), (1, 11), (11, 61)] + """ + low = 0 + rv = [] + for size in sizes: + rv.append((low, low + size)) + low += size + return rv + +def blockcut(expr, rowsizes, colsizes): + """ Cut a matrix expression into Blocks + + >>> from sympy import ImmutableMatrix, blockcut + >>> M = ImmutableMatrix(4, 4, range(16)) + >>> B = blockcut(M, (1, 3), (1, 3)) + >>> type(B).__name__ + 'BlockMatrix' + >>> ImmutableMatrix(B.blocks[0, 1]) + Matrix([[1, 2, 3]]) + """ + + rowbounds = bounds(rowsizes) + colbounds = bounds(colsizes) + return BlockMatrix([[MatrixSlice(expr, rowbound, colbound) + for colbound in colbounds] + for rowbound in rowbounds]) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/companion.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/companion.py new file mode 100644 index 0000000000000000000000000000000000000000..6969c917f63806cb1f5417804e01ecc1350d1406 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/companion.py @@ -0,0 +1,56 @@ +from sympy.core.singleton import S +from sympy.core.sympify import _sympify +from sympy.polys.polytools import Poly + +from .matexpr import MatrixExpr + + +class CompanionMatrix(MatrixExpr): + """A symbolic companion matrix of a polynomial. + + Examples + ======== + + >>> from sympy import Poly, Symbol, symbols + >>> from sympy.matrices.expressions import CompanionMatrix + >>> x = Symbol('x') + >>> c0, c1, c2, c3, c4 = symbols('c0:5') + >>> p = Poly(c0 + c1*x + c2*x**2 + c3*x**3 + c4*x**4 + x**5, x) + >>> CompanionMatrix(p) + CompanionMatrix(Poly(x**5 + c4*x**4 + c3*x**3 + c2*x**2 + c1*x + c0, + x, domain='ZZ[c0,c1,c2,c3,c4]')) + """ + def __new__(cls, poly): + poly = _sympify(poly) + if not isinstance(poly, Poly): + raise ValueError("{} must be a Poly instance.".format(poly)) + if not poly.is_monic: + raise ValueError("{} must be a monic polynomial.".format(poly)) + if not poly.is_univariate: + raise ValueError( + "{} must be a univariate polynomial.".format(poly)) + if not poly.degree() >= 1: + raise ValueError( + "{} must have degree not less than 1.".format(poly)) + + return super().__new__(cls, poly) + + + @property + def shape(self): + poly = self.args[0] + size = poly.degree() + return size, size + + + def _entry(self, i, j): + if j == self.cols - 1: + return -self.args[0].all_coeffs()[-1 - i] + elif i == j + 1: + return S.One + return S.Zero + + + def as_explicit(self): + from sympy.matrices.immutable import ImmutableDenseMatrix + return ImmutableDenseMatrix.companion(self.args[0]) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/determinant.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/determinant.py new file mode 100644 index 0000000000000000000000000000000000000000..b323b3f93a5a0404bf2205f39d25b931d173b6d9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/determinant.py @@ -0,0 +1,148 @@ +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices.matrixbase import MatrixBase + + +class Determinant(Expr): + """Matrix Determinant + + Represents the determinant of a matrix expression. + + Examples + ======== + + >>> from sympy import MatrixSymbol, Determinant, eye + >>> A = MatrixSymbol('A', 3, 3) + >>> Determinant(A) + Determinant(A) + >>> Determinant(eye(3)).doit() + 1 + """ + is_commutative = True + + def __new__(cls, mat): + mat = sympify(mat) + if not mat.is_Matrix: + raise TypeError("Input to Determinant, %s, not a matrix" % str(mat)) + + if mat.is_square is False: + raise NonSquareMatrixError("Det of a non-square matrix") + + return Basic.__new__(cls, mat) + + @property + def arg(self): + return self.args[0] + + @property + def kind(self): + return self.arg.kind.element_kind + + def doit(self, **hints): + arg = self.arg + if hints.get('deep', True): + arg = arg.doit(**hints) + + result = arg._eval_determinant() + if result is not None: + return result + + return self + + +def det(matexpr): + """ Matrix Determinant + + Examples + ======== + + >>> from sympy import MatrixSymbol, det, eye + >>> A = MatrixSymbol('A', 3, 3) + >>> det(A) + Determinant(A) + >>> det(eye(3)) + 1 + """ + + return Determinant(matexpr).doit() + +class Permanent(Expr): + """Matrix Permanent + + Represents the permanent of a matrix expression. + + Examples + ======== + + >>> from sympy import MatrixSymbol, Permanent, ones + >>> A = MatrixSymbol('A', 3, 3) + >>> Permanent(A) + Permanent(A) + >>> Permanent(ones(3, 3)).doit() + 6 + """ + + def __new__(cls, mat): + mat = sympify(mat) + if not mat.is_Matrix: + raise TypeError("Input to Permanent, %s, not a matrix" % str(mat)) + + return Basic.__new__(cls, mat) + + @property + def arg(self): + return self.args[0] + + def doit(self, expand=False, **hints): + if isinstance(self.arg, MatrixBase): + return self.arg.per() + else: + return self + +def per(matexpr): + """ Matrix Permanent + + Examples + ======== + + >>> from sympy import MatrixSymbol, Matrix, per, ones + >>> A = MatrixSymbol('A', 3, 3) + >>> per(A) + Permanent(A) + >>> per(ones(5, 5)) + 120 + >>> M = Matrix([1, 2, 5]) + >>> per(M) + 8 + """ + + return Permanent(matexpr).doit() + +from sympy.assumptions.ask import ask, Q +from sympy.assumptions.refine import handlers_dict + + +def refine_Determinant(expr, assumptions): + """ + >>> from sympy import MatrixSymbol, Q, assuming, refine, det + >>> X = MatrixSymbol('X', 2, 2) + >>> det(X) + Determinant(X) + >>> with assuming(Q.orthogonal(X)): + ... print(refine(det(X))) + 1 + """ + if ask(Q.orthogonal(expr.arg), assumptions): + return S.One + elif ask(Q.singular(expr.arg), assumptions): + return S.Zero + elif ask(Q.unit_triangular(expr.arg), assumptions): + return S.One + + return expr + + +handlers_dict['Determinant'] = refine_Determinant diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/diagonal.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/diagonal.py new file mode 100644 index 0000000000000000000000000000000000000000..ba8a0216588143e3e251dab84c25f038fad550a4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/diagonal.py @@ -0,0 +1,220 @@ +from sympy.core.sympify import _sympify + +from sympy.matrices.expressions import MatrixExpr +from sympy.core import S, Eq, Ge +from sympy.core.mul import Mul +from sympy.functions.special.tensor_functions import KroneckerDelta + + +class DiagonalMatrix(MatrixExpr): + """DiagonalMatrix(M) will create a matrix expression that + behaves as though all off-diagonal elements, + `M[i, j]` where `i != j`, are zero. + + Examples + ======== + + >>> from sympy import MatrixSymbol, DiagonalMatrix, Symbol + >>> n = Symbol('n', integer=True) + >>> m = Symbol('m', integer=True) + >>> D = DiagonalMatrix(MatrixSymbol('x', 2, 3)) + >>> D[1, 2] + 0 + >>> D[1, 1] + x[1, 1] + + The length of the diagonal -- the lesser of the two dimensions of `M` -- + is accessed through the `diagonal_length` property: + + >>> D.diagonal_length + 2 + >>> DiagonalMatrix(MatrixSymbol('x', n + 1, n)).diagonal_length + n + + When one of the dimensions is symbolic the other will be treated as + though it is smaller: + + >>> tall = DiagonalMatrix(MatrixSymbol('x', n, 3)) + >>> tall.diagonal_length + 3 + >>> tall[10, 1] + 0 + + When the size of the diagonal is not known, a value of None will + be returned: + + >>> DiagonalMatrix(MatrixSymbol('x', n, m)).diagonal_length is None + True + + """ + arg = property(lambda self: self.args[0]) + + shape = property(lambda self: self.arg.shape) # type:ignore + + @property + def diagonal_length(self): + r, c = self.shape + if r.is_Integer and c.is_Integer: + m = min(r, c) + elif r.is_Integer and not c.is_Integer: + m = r + elif c.is_Integer and not r.is_Integer: + m = c + elif r == c: + m = r + else: + try: + m = min(r, c) + except TypeError: + m = None + return m + + def _entry(self, i, j, **kwargs): + if self.diagonal_length is not None: + if Ge(i, self.diagonal_length) is S.true: + return S.Zero + elif Ge(j, self.diagonal_length) is S.true: + return S.Zero + eq = Eq(i, j) + if eq is S.true: + return self.arg[i, i] + elif eq is S.false: + return S.Zero + return self.arg[i, j]*KroneckerDelta(i, j) + + +class DiagonalOf(MatrixExpr): + """DiagonalOf(M) will create a matrix expression that + is equivalent to the diagonal of `M`, represented as + a single column matrix. + + Examples + ======== + + >>> from sympy import MatrixSymbol, DiagonalOf, Symbol + >>> n = Symbol('n', integer=True) + >>> m = Symbol('m', integer=True) + >>> x = MatrixSymbol('x', 2, 3) + >>> diag = DiagonalOf(x) + >>> diag.shape + (2, 1) + + The diagonal can be addressed like a matrix or vector and will + return the corresponding element of the original matrix: + + >>> diag[1, 0] == diag[1] == x[1, 1] + True + + The length of the diagonal -- the lesser of the two dimensions of `M` -- + is accessed through the `diagonal_length` property: + + >>> diag.diagonal_length + 2 + >>> DiagonalOf(MatrixSymbol('x', n + 1, n)).diagonal_length + n + + When only one of the dimensions is symbolic the other will be + treated as though it is smaller: + + >>> dtall = DiagonalOf(MatrixSymbol('x', n, 3)) + >>> dtall.diagonal_length + 3 + + When the size of the diagonal is not known, a value of None will + be returned: + + >>> DiagonalOf(MatrixSymbol('x', n, m)).diagonal_length is None + True + + """ + arg = property(lambda self: self.args[0]) + @property + def shape(self): + r, c = self.arg.shape + if r.is_Integer and c.is_Integer: + m = min(r, c) + elif r.is_Integer and not c.is_Integer: + m = r + elif c.is_Integer and not r.is_Integer: + m = c + elif r == c: + m = r + else: + try: + m = min(r, c) + except TypeError: + m = None + return m, S.One + + @property + def diagonal_length(self): + return self.shape[0] + + def _entry(self, i, j, **kwargs): + return self.arg._entry(i, i, **kwargs) + + +class DiagMatrix(MatrixExpr): + """ + Turn a vector into a diagonal matrix. + """ + def __new__(cls, vector): + vector = _sympify(vector) + obj = MatrixExpr.__new__(cls, vector) + shape = vector.shape + dim = shape[1] if shape[0] == 1 else shape[0] + if vector.shape[0] != 1: + obj._iscolumn = True + else: + obj._iscolumn = False + obj._shape = (dim, dim) + obj._vector = vector + return obj + + @property + def shape(self): + return self._shape + + def _entry(self, i, j, **kwargs): + if self._iscolumn: + result = self._vector._entry(i, 0, **kwargs) + else: + result = self._vector._entry(0, j, **kwargs) + if i != j: + result *= KroneckerDelta(i, j) + return result + + def _eval_transpose(self): + return self + + def as_explicit(self): + from sympy.matrices.dense import diag + return diag(*list(self._vector.as_explicit())) + + def doit(self, **hints): + from sympy.assumptions import ask, Q + from sympy.matrices.expressions.matmul import MatMul + from sympy.matrices.expressions.transpose import Transpose + from sympy.matrices.dense import eye + from sympy.matrices.matrixbase import MatrixBase + vector = self._vector + # This accounts for shape (1, 1) and identity matrices, among others: + if ask(Q.diagonal(vector)): + return vector + if isinstance(vector, MatrixBase): + ret = eye(max(vector.shape)) + for i in range(ret.shape[0]): + ret[i, i] = vector[i] + return type(vector)(ret) + if vector.is_MatMul: + matrices = [arg for arg in vector.args if arg.is_Matrix] + scalars = [arg for arg in vector.args if arg not in matrices] + if scalars: + return Mul.fromiter(scalars)*DiagMatrix(MatMul.fromiter(matrices).doit()).doit() + if isinstance(vector, Transpose): + vector = vector.arg + return DiagMatrix(vector) + + +def diagonalize_vector(vector): + return DiagMatrix(vector).doit() diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/dotproduct.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/dotproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..3a413f8c79a221505f0c082d7f19f78597a2befc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/dotproduct.py @@ -0,0 +1,55 @@ +from sympy.core import Basic, Expr +from sympy.core.sympify import _sympify +from sympy.matrices.expressions.transpose import transpose + + +class DotProduct(Expr): + """ + Dot product of vector matrices + + The input should be two 1 x n or n x 1 matrices. The output represents the + scalar dotproduct. + + This is similar to using MatrixElement and MatMul, except DotProduct does + not require that one vector to be a row vector and the other vector to be + a column vector. + + >>> from sympy import MatrixSymbol, DotProduct + >>> A = MatrixSymbol('A', 1, 3) + >>> B = MatrixSymbol('B', 1, 3) + >>> DotProduct(A, B) + DotProduct(A, B) + >>> DotProduct(A, B).doit() + A[0, 0]*B[0, 0] + A[0, 1]*B[0, 1] + A[0, 2]*B[0, 2] + """ + + def __new__(cls, arg1, arg2): + arg1, arg2 = _sympify((arg1, arg2)) + + if not arg1.is_Matrix: + raise TypeError("Argument 1 of DotProduct is not a matrix") + if not arg2.is_Matrix: + raise TypeError("Argument 2 of DotProduct is not a matrix") + if not (1 in arg1.shape): + raise TypeError("Argument 1 of DotProduct is not a vector") + if not (1 in arg2.shape): + raise TypeError("Argument 2 of DotProduct is not a vector") + + if set(arg1.shape) != set(arg2.shape): + raise TypeError("DotProduct arguments are not the same length") + + return Basic.__new__(cls, arg1, arg2) + + def doit(self, expand=False, **hints): + if self.args[0].shape == self.args[1].shape: + if self.args[0].shape[0] == 1: + mul = self.args[0]*transpose(self.args[1]) + else: + mul = transpose(self.args[0])*self.args[1] + else: + if self.args[0].shape[0] == 1: + mul = self.args[0]*self.args[1] + else: + mul = transpose(self.args[0])*transpose(self.args[1]) + + return mul[0] diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/factorizations.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/factorizations.py new file mode 100644 index 0000000000000000000000000000000000000000..aff2bb81ecff99d8e733f282ac2dd187d76ce895 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/factorizations.py @@ -0,0 +1,62 @@ +from sympy.matrices.expressions import MatrixExpr +from sympy.assumptions.ask import Q + +class Factorization(MatrixExpr): + arg = property(lambda self: self.args[0]) + shape = property(lambda self: self.arg.shape) # type: ignore + +class LofLU(Factorization): + @property + def predicates(self): + return (Q.lower_triangular,) +class UofLU(Factorization): + @property + def predicates(self): + return (Q.upper_triangular,) + +class LofCholesky(LofLU): pass +class UofCholesky(UofLU): pass + +class QofQR(Factorization): + @property + def predicates(self): + return (Q.orthogonal,) +class RofQR(Factorization): + @property + def predicates(self): + return (Q.upper_triangular,) + +class EigenVectors(Factorization): + @property + def predicates(self): + return (Q.orthogonal,) +class EigenValues(Factorization): + @property + def predicates(self): + return (Q.diagonal,) + +class UofSVD(Factorization): + @property + def predicates(self): + return (Q.orthogonal,) +class SofSVD(Factorization): + @property + def predicates(self): + return (Q.diagonal,) +class VofSVD(Factorization): + @property + def predicates(self): + return (Q.orthogonal,) + + +def lu(expr): + return LofLU(expr), UofLU(expr) + +def qr(expr): + return QofQR(expr), RofQR(expr) + +def eig(expr): + return EigenValues(expr), EigenVectors(expr) + +def svd(expr): + return UofSVD(expr), SofSVD(expr), VofSVD(expr) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/fourier.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/fourier.py new file mode 100644 index 0000000000000000000000000000000000000000..5fa9222c2a9b218f42636267235d5dd44c25f8bb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/fourier.py @@ -0,0 +1,91 @@ +from sympy.core.sympify import _sympify +from sympy.matrices.expressions import MatrixExpr +from sympy.core.numbers import I +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt + + +class DFT(MatrixExpr): + r""" + Returns a discrete Fourier transform matrix. The matrix is scaled + with :math:`\frac{1}{\sqrt{n}}` so that it is unitary. + + Parameters + ========== + + n : integer or Symbol + Size of the transform. + + Examples + ======== + + >>> from sympy.abc import n + >>> from sympy.matrices.expressions.fourier import DFT + >>> DFT(3) + DFT(3) + >>> DFT(3).as_explicit() + Matrix([ + [sqrt(3)/3, sqrt(3)/3, sqrt(3)/3], + [sqrt(3)/3, sqrt(3)*exp(-2*I*pi/3)/3, sqrt(3)*exp(2*I*pi/3)/3], + [sqrt(3)/3, sqrt(3)*exp(2*I*pi/3)/3, sqrt(3)*exp(-2*I*pi/3)/3]]) + >>> DFT(n).shape + (n, n) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/DFT_matrix + + """ + + def __new__(cls, n): + n = _sympify(n) + cls._check_dim(n) + + obj = super().__new__(cls, n) + return obj + + n = property(lambda self: self.args[0]) # type: ignore + shape = property(lambda self: (self.n, self.n)) # type: ignore + + def _entry(self, i, j, **kwargs): + w = exp(-2*S.Pi*I/self.n) + return w**(i*j) / sqrt(self.n) + + def _eval_inverse(self): + return IDFT(self.n) + + +class IDFT(DFT): + r""" + Returns an inverse discrete Fourier transform matrix. The matrix is scaled + with :math:`\frac{1}{\sqrt{n}}` so that it is unitary. + + Parameters + ========== + + n : integer or Symbol + Size of the transform + + Examples + ======== + + >>> from sympy.matrices.expressions.fourier import DFT, IDFT + >>> IDFT(3) + IDFT(3) + >>> IDFT(4)*DFT(4) + I + + See Also + ======== + + DFT + + """ + def _entry(self, i, j, **kwargs): + w = exp(-2*S.Pi*I/self.n) + return w**(-i*j) / sqrt(self.n) + + def _eval_inverse(self): + return DFT(self.n) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/funcmatrix.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/funcmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..91106edb489b73ac9dd6cb94adc508c0db75d3a5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/funcmatrix.py @@ -0,0 +1,118 @@ +from .matexpr import MatrixExpr +from sympy.core.function import FunctionClass, Lambda +from sympy.core.symbol import Dummy +from sympy.core.sympify import _sympify, sympify +from sympy.matrices import Matrix +from sympy.functions.elementary.complexes import re, im + + +class FunctionMatrix(MatrixExpr): + """Represents a matrix using a function (``Lambda``) which gives + outputs according to the coordinates of each matrix entries. + + Parameters + ========== + + rows : nonnegative integer. Can be symbolic. + + cols : nonnegative integer. Can be symbolic. + + lamda : Function, Lambda or str + If it is a SymPy ``Function`` or ``Lambda`` instance, + it should be able to accept two arguments which represents the + matrix coordinates. + + If it is a pure string containing Python ``lambda`` semantics, + it is interpreted by the SymPy parser and casted into a SymPy + ``Lambda`` instance. + + Examples + ======== + + Creating a ``FunctionMatrix`` from ``Lambda``: + + >>> from sympy import FunctionMatrix, symbols, Lambda, MatPow + >>> i, j, n, m = symbols('i,j,n,m') + >>> FunctionMatrix(n, m, Lambda((i, j), i + j)) + FunctionMatrix(n, m, Lambda((i, j), i + j)) + + Creating a ``FunctionMatrix`` from a SymPy function: + + >>> from sympy import KroneckerDelta + >>> X = FunctionMatrix(3, 3, KroneckerDelta) + >>> X.as_explicit() + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + Creating a ``FunctionMatrix`` from a SymPy undefined function: + + >>> from sympy import Function + >>> f = Function('f') + >>> X = FunctionMatrix(3, 3, f) + >>> X.as_explicit() + Matrix([ + [f(0, 0), f(0, 1), f(0, 2)], + [f(1, 0), f(1, 1), f(1, 2)], + [f(2, 0), f(2, 1), f(2, 2)]]) + + Creating a ``FunctionMatrix`` from Python ``lambda``: + + >>> FunctionMatrix(n, m, 'lambda i, j: i + j') + FunctionMatrix(n, m, Lambda((i, j), i + j)) + + Example of lazy evaluation of matrix product: + + >>> Y = FunctionMatrix(1000, 1000, Lambda((i, j), i + j)) + >>> isinstance(Y*Y, MatPow) # this is an expression object + True + >>> (Y**2)[10,10] # So this is evaluated lazily + 342923500 + + Notes + ===== + + This class provides an alternative way to represent an extremely + dense matrix with entries in some form of a sequence, in a most + sparse way. + """ + def __new__(cls, rows, cols, lamda): + rows, cols = _sympify(rows), _sympify(cols) + cls._check_dim(rows) + cls._check_dim(cols) + + lamda = sympify(lamda) + if not isinstance(lamda, (FunctionClass, Lambda)): + raise ValueError( + "{} should be compatible with SymPy function classes." + .format(lamda)) + + if 2 not in lamda.nargs: + raise ValueError( + '{} should be able to accept 2 arguments.'.format(lamda)) + + if not isinstance(lamda, Lambda): + i, j = Dummy('i'), Dummy('j') + lamda = Lambda((i, j), lamda(i, j)) + + return super().__new__(cls, rows, cols, lamda) + + @property + def shape(self): + return self.args[0:2] + + @property + def lamda(self): + return self.args[2] + + def _entry(self, i, j, **kwargs): + return self.lamda(i, j) + + def _eval_trace(self): + from sympy.matrices.expressions.trace import Trace + from sympy.concrete.summations import Sum + return Trace(self).rewrite(Sum).doit() + + def _eval_as_real_imag(self): + return (re(Matrix(self)), im(Matrix(self))) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/hadamard.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/hadamard.py new file mode 100644 index 0000000000000000000000000000000000000000..38c9033ebea3a7bfc569223978dc6ef3890206cf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/hadamard.py @@ -0,0 +1,464 @@ +from collections import Counter + +from sympy.core import Mul, sympify +from sympy.core.add import Add +from sympy.core.expr import ExprBuilder +from sympy.core.sorting import default_sort_key +from sympy.functions.elementary.exponential import log +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions._shape import validate_matadd_integer as validate +from sympy.matrices.expressions.special import ZeroMatrix, OneMatrix +from sympy.strategies import ( + unpack, flatten, condition, exhaust, rm_id, sort +) +from sympy.utilities.exceptions import sympy_deprecation_warning + + +def hadamard_product(*matrices): + """ + Return the elementwise (aka Hadamard) product of matrices. + + Examples + ======== + + >>> from sympy import hadamard_product, MatrixSymbol + >>> A = MatrixSymbol('A', 2, 3) + >>> B = MatrixSymbol('B', 2, 3) + >>> hadamard_product(A) + A + >>> hadamard_product(A, B) + HadamardProduct(A, B) + >>> hadamard_product(A, B)[0, 1] + A[0, 1]*B[0, 1] + """ + if not matrices: + raise TypeError("Empty Hadamard product is undefined") + if len(matrices) == 1: + return matrices[0] + return HadamardProduct(*matrices).doit() + + +class HadamardProduct(MatrixExpr): + """ + Elementwise product of matrix expressions + + Examples + ======== + + Hadamard product for matrix symbols: + + >>> from sympy import hadamard_product, HadamardProduct, MatrixSymbol + >>> A = MatrixSymbol('A', 5, 5) + >>> B = MatrixSymbol('B', 5, 5) + >>> isinstance(hadamard_product(A, B), HadamardProduct) + True + + Notes + ===== + + This is a symbolic object that simply stores its argument without + evaluating it. To actually compute the product, use the function + ``hadamard_product()`` or ``HadamardProduct.doit`` + """ + is_HadamardProduct = True + + def __new__(cls, *args, evaluate=False, check=None): + args = list(map(sympify, args)) + if len(args) == 0: + # We currently don't have a way to support one-matrices of generic dimensions: + raise ValueError("HadamardProduct needs at least one argument") + + if not all(isinstance(arg, MatrixExpr) for arg in args): + raise TypeError("Mix of Matrix and Scalar symbols") + + if check is not None: + sympy_deprecation_warning( + "Passing check to HadamardProduct is deprecated and the check argument will be removed in a future version.", + deprecated_since_version="1.11", + active_deprecations_target='remove-check-argument-from-matrix-operations') + + if check is not False: + validate(*args) + + obj = super().__new__(cls, *args) + if evaluate: + obj = obj.doit(deep=False) + return obj + + @property + def shape(self): + return self.args[0].shape + + def _entry(self, i, j, **kwargs): + return Mul(*[arg._entry(i, j, **kwargs) for arg in self.args]) + + def _eval_transpose(self): + from sympy.matrices.expressions.transpose import transpose + return HadamardProduct(*list(map(transpose, self.args))) + + def doit(self, **hints): + expr = self.func(*(i.doit(**hints) for i in self.args)) + # Check for explicit matrices: + from sympy.matrices.matrixbase import MatrixBase + from sympy.matrices.immutable import ImmutableMatrix + + explicit = [i for i in expr.args if isinstance(i, MatrixBase)] + if explicit: + remainder = [i for i in expr.args if i not in explicit] + expl_mat = ImmutableMatrix([ + Mul.fromiter(i) for i in zip(*explicit) + ]).reshape(*self.shape) + expr = HadamardProduct(*([expl_mat] + remainder)) + + return canonicalize(expr) + + def _eval_derivative(self, x): + terms = [] + args = list(self.args) + for i in range(len(args)): + factors = args[:i] + [args[i].diff(x)] + args[i+1:] + terms.append(hadamard_product(*factors)) + return Add.fromiter(terms) + + def _eval_derivative_matrix_lines(self, x): + from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct + from sympy.matrices.expressions.matexpr import _make_matrix + + with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)] + lines = [] + for ind in with_x_ind: + left_args = self.args[:ind] + right_args = self.args[ind+1:] + + d = self.args[ind]._eval_derivative_matrix_lines(x) + hadam = hadamard_product(*(right_args + left_args)) + diagonal = [(0, 2), (3, 4)] + diagonal = [e for j, e in enumerate(diagonal) if self.shape[j] != 1] + for i in d: + l1 = i._lines[i._first_line_index] + l2 = i._lines[i._second_line_index] + subexpr = ExprBuilder( + ArrayDiagonal, + [ + ExprBuilder( + ArrayTensorProduct, + [ + ExprBuilder(_make_matrix, [l1]), + hadam, + ExprBuilder(_make_matrix, [l2]), + ] + ), + *diagonal], + + ) + i._first_pointer_parent = subexpr.args[0].args[0].args + i._first_pointer_index = 0 + i._second_pointer_parent = subexpr.args[0].args[2].args + i._second_pointer_index = 0 + i._lines = [subexpr] + lines.append(i) + + return lines + + +# TODO Implement algorithm for rewriting Hadamard product as diagonal matrix +# if matmul identy matrix is multiplied. +def canonicalize(x): + """Canonicalize the Hadamard product ``x`` with mathematical properties. + + Examples + ======== + + >>> from sympy import MatrixSymbol, HadamardProduct + >>> from sympy import OneMatrix, ZeroMatrix + >>> from sympy.matrices.expressions.hadamard import canonicalize + >>> from sympy import init_printing + >>> init_printing(use_unicode=False) + + >>> A = MatrixSymbol('A', 2, 2) + >>> B = MatrixSymbol('B', 2, 2) + >>> C = MatrixSymbol('C', 2, 2) + + Hadamard product associativity: + + >>> X = HadamardProduct(A, HadamardProduct(B, C)) + >>> X + A.*(B.*C) + >>> canonicalize(X) + A.*B.*C + + Hadamard product commutativity: + + >>> X = HadamardProduct(A, B) + >>> Y = HadamardProduct(B, A) + >>> X + A.*B + >>> Y + B.*A + >>> canonicalize(X) + A.*B + >>> canonicalize(Y) + A.*B + + Hadamard product identity: + + >>> X = HadamardProduct(A, OneMatrix(2, 2)) + >>> X + A.*1 + >>> canonicalize(X) + A + + Absorbing element of Hadamard product: + + >>> X = HadamardProduct(A, ZeroMatrix(2, 2)) + >>> X + A.*0 + >>> canonicalize(X) + 0 + + Rewriting to Hadamard Power + + >>> X = HadamardProduct(A, A, A) + >>> X + A.*A.*A + >>> canonicalize(X) + .3 + A + + Notes + ===== + + As the Hadamard product is associative, nested products can be flattened. + + The Hadamard product is commutative so that factors can be sorted for + canonical form. + + A matrix of only ones is an identity for Hadamard product, + so every matrices of only ones can be removed. + + Any zero matrix will make the whole product a zero matrix. + + Duplicate elements can be collected and rewritten as HadamardPower + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hadamard_product_(matrices) + """ + # Associativity + rule = condition( + lambda x: isinstance(x, HadamardProduct), + flatten + ) + fun = exhaust(rule) + x = fun(x) + + # Identity + fun = condition( + lambda x: isinstance(x, HadamardProduct), + rm_id(lambda x: isinstance(x, OneMatrix)) + ) + x = fun(x) + + # Absorbing by Zero Matrix + def absorb(x): + if any(isinstance(c, ZeroMatrix) for c in x.args): + return ZeroMatrix(*x.shape) + else: + return x + fun = condition( + lambda x: isinstance(x, HadamardProduct), + absorb + ) + x = fun(x) + + # Rewriting with HadamardPower + if isinstance(x, HadamardProduct): + tally = Counter(x.args) + + new_arg = [] + for base, exp in tally.items(): + if exp == 1: + new_arg.append(base) + else: + new_arg.append(HadamardPower(base, exp)) + + x = HadamardProduct(*new_arg) + + # Commutativity + fun = condition( + lambda x: isinstance(x, HadamardProduct), + sort(default_sort_key) + ) + x = fun(x) + + # Unpacking + x = unpack(x) + return x + + +def hadamard_power(base, exp): + base = sympify(base) + exp = sympify(exp) + if exp == 1: + return base + if not base.is_Matrix: + return base**exp + if exp.is_Matrix: + raise ValueError("cannot raise expression to a matrix") + return HadamardPower(base, exp) + + +class HadamardPower(MatrixExpr): + r""" + Elementwise power of matrix expressions + + Parameters + ========== + + base : scalar or matrix + + exp : scalar or matrix + + Notes + ===== + + There are four definitions for the hadamard power which can be used. + Let's consider `A, B` as `(m, n)` matrices, and `a, b` as scalars. + + Matrix raised to a scalar exponent: + + .. math:: + A^{\circ b} = \begin{bmatrix} + A_{0, 0}^b & A_{0, 1}^b & \cdots & A_{0, n-1}^b \\ + A_{1, 0}^b & A_{1, 1}^b & \cdots & A_{1, n-1}^b \\ + \vdots & \vdots & \ddots & \vdots \\ + A_{m-1, 0}^b & A_{m-1, 1}^b & \cdots & A_{m-1, n-1}^b + \end{bmatrix} + + Scalar raised to a matrix exponent: + + .. math:: + a^{\circ B} = \begin{bmatrix} + a^{B_{0, 0}} & a^{B_{0, 1}} & \cdots & a^{B_{0, n-1}} \\ + a^{B_{1, 0}} & a^{B_{1, 1}} & \cdots & a^{B_{1, n-1}} \\ + \vdots & \vdots & \ddots & \vdots \\ + a^{B_{m-1, 0}} & a^{B_{m-1, 1}} & \cdots & a^{B_{m-1, n-1}} + \end{bmatrix} + + Matrix raised to a matrix exponent: + + .. math:: + A^{\circ B} = \begin{bmatrix} + A_{0, 0}^{B_{0, 0}} & A_{0, 1}^{B_{0, 1}} & + \cdots & A_{0, n-1}^{B_{0, n-1}} \\ + A_{1, 0}^{B_{1, 0}} & A_{1, 1}^{B_{1, 1}} & + \cdots & A_{1, n-1}^{B_{1, n-1}} \\ + \vdots & \vdots & + \ddots & \vdots \\ + A_{m-1, 0}^{B_{m-1, 0}} & A_{m-1, 1}^{B_{m-1, 1}} & + \cdots & A_{m-1, n-1}^{B_{m-1, n-1}} + \end{bmatrix} + + Scalar raised to a scalar exponent: + + .. math:: + a^{\circ b} = a^b + """ + + def __new__(cls, base, exp): + base = sympify(base) + exp = sympify(exp) + + if base.is_scalar and exp.is_scalar: + return base ** exp + + if isinstance(base, MatrixExpr) and isinstance(exp, MatrixExpr): + validate(base, exp) + + obj = super().__new__(cls, base, exp) + return obj + + @property + def base(self): + return self._args[0] + + @property + def exp(self): + return self._args[1] + + @property + def shape(self): + if self.base.is_Matrix: + return self.base.shape + return self.exp.shape + + def _entry(self, i, j, **kwargs): + base = self.base + exp = self.exp + + if base.is_Matrix: + a = base._entry(i, j, **kwargs) + elif base.is_scalar: + a = base + else: + raise ValueError( + 'The base {} must be a scalar or a matrix.'.format(base)) + + if exp.is_Matrix: + b = exp._entry(i, j, **kwargs) + elif exp.is_scalar: + b = exp + else: + raise ValueError( + 'The exponent {} must be a scalar or a matrix.'.format(exp)) + + return a ** b + + def _eval_transpose(self): + from sympy.matrices.expressions.transpose import transpose + return HadamardPower(transpose(self.base), self.exp) + + def _eval_derivative(self, x): + dexp = self.exp.diff(x) + logbase = self.base.applyfunc(log) + dlbase = logbase.diff(x) + return hadamard_product( + dexp*logbase + self.exp*dlbase, + self + ) + + def _eval_derivative_matrix_lines(self, x): + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct + from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal + from sympy.matrices.expressions.matexpr import _make_matrix + + lr = self.base._eval_derivative_matrix_lines(x) + for i in lr: + diagonal = [(1, 2), (3, 4)] + diagonal = [e for j, e in enumerate(diagonal) if self.base.shape[j] != 1] + l1 = i._lines[i._first_line_index] + l2 = i._lines[i._second_line_index] + subexpr = ExprBuilder( + ArrayDiagonal, + [ + ExprBuilder( + ArrayTensorProduct, + [ + ExprBuilder(_make_matrix, [l1]), + self.exp*hadamard_power(self.base, self.exp-1), + ExprBuilder(_make_matrix, [l2]), + ] + ), + *diagonal], + validator=ArrayDiagonal._validate + ) + i._first_pointer_parent = subexpr.args[0].args[0].args + i._first_pointer_index = 0 + i._first_line_index = 0 + i._second_pointer_parent = subexpr.args[0].args[2].args + i._second_pointer_index = 0 + i._second_line_index = 0 + i._lines = [subexpr] + return lr diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/inverse.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..cfc3feccd7126a761f18f23599eed9413c86a9e5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/inverse.py @@ -0,0 +1,112 @@ +from sympy.core.sympify import _sympify +from sympy.core import S, Basic + +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices.expressions.matpow import MatPow + + +class Inverse(MatPow): + """ + The multiplicative inverse of a matrix expression + + This is a symbolic object that simply stores its argument without + evaluating it. To actually compute the inverse, use the ``.inverse()`` + method of matrices. + + Examples + ======== + + >>> from sympy import MatrixSymbol, Inverse + >>> A = MatrixSymbol('A', 3, 3) + >>> B = MatrixSymbol('B', 3, 3) + >>> Inverse(A) + A**(-1) + >>> A.inverse() == Inverse(A) + True + >>> (A*B).inverse() + B**(-1)*A**(-1) + >>> Inverse(A*B) + (A*B)**(-1) + + """ + is_Inverse = True + exp = S.NegativeOne + + def __new__(cls, mat, exp=S.NegativeOne): + # exp is there to make it consistent with + # inverse.func(*inverse.args) == inverse + mat = _sympify(mat) + exp = _sympify(exp) + if not mat.is_Matrix: + raise TypeError("mat should be a matrix") + if mat.is_square is False: + raise NonSquareMatrixError("Inverse of non-square matrix %s" % mat) + return Basic.__new__(cls, mat, exp) + + @property + def arg(self): + return self.args[0] + + @property + def shape(self): + return self.arg.shape + + def _eval_inverse(self): + return self.arg + + def _eval_transpose(self): + return Inverse(self.arg.transpose()) + + def _eval_adjoint(self): + return Inverse(self.arg.adjoint()) + + def _eval_conjugate(self): + return Inverse(self.arg.conjugate()) + + def _eval_determinant(self): + from sympy.matrices.expressions.determinant import det + return 1/det(self.arg) + + def doit(self, **hints): + if 'inv_expand' in hints and hints['inv_expand'] == False: + return self + + arg = self.arg + if hints.get('deep', True): + arg = arg.doit(**hints) + + return arg.inverse() + + def _eval_derivative_matrix_lines(self, x): + arg = self.args[0] + lines = arg._eval_derivative_matrix_lines(x) + for line in lines: + line.first_pointer *= -self.T + line.second_pointer *= self + return lines + + +from sympy.assumptions.ask import ask, Q +from sympy.assumptions.refine import handlers_dict + + +def refine_Inverse(expr, assumptions): + """ + >>> from sympy import MatrixSymbol, Q, assuming, refine + >>> X = MatrixSymbol('X', 2, 2) + >>> X.I + X**(-1) + >>> with assuming(Q.orthogonal(X)): + ... print(refine(X.I)) + X.T + """ + if ask(Q.orthogonal(expr), assumptions): + return expr.arg.T + elif ask(Q.unitary(expr), assumptions): + return expr.arg.conjugate() + elif ask(Q.singular(expr), assumptions): + raise ValueError("Inverse of singular matrix %s" % expr.arg) + + return expr + +handlers_dict['Inverse'] = refine_Inverse diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/kronecker.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/kronecker.py new file mode 100644 index 0000000000000000000000000000000000000000..1dd175cb0d500af3e786e2d0dbf6b010947840b4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/kronecker.py @@ -0,0 +1,434 @@ +"""Implementation of the Kronecker product""" +from functools import reduce +from math import prod + +from sympy.core import Mul, sympify +from sympy.functions import adjoint +from sympy.matrices.exceptions import ShapeError +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.transpose import transpose +from sympy.matrices.expressions.special import Identity +from sympy.matrices.matrixbase import MatrixBase +from sympy.strategies import ( + canon, condition, distribute, do_one, exhaust, flatten, typed, unpack) +from sympy.strategies.traverse import bottom_up +from sympy.utilities import sift + +from .matadd import MatAdd +from .matmul import MatMul +from .matpow import MatPow + + +def kronecker_product(*matrices): + """ + The Kronecker product of two or more arguments. + + This computes the explicit Kronecker product for subclasses of + ``MatrixBase`` i.e. explicit matrices. Otherwise, a symbolic + ``KroneckerProduct`` object is returned. + + + Examples + ======== + + For ``MatrixSymbol`` arguments a ``KroneckerProduct`` object is returned. + Elements of this matrix can be obtained by indexing, or for MatrixSymbols + with known dimension the explicit matrix can be obtained with + ``.as_explicit()`` + + >>> from sympy import kronecker_product, MatrixSymbol + >>> A = MatrixSymbol('A', 2, 2) + >>> B = MatrixSymbol('B', 2, 2) + >>> kronecker_product(A) + A + >>> kronecker_product(A, B) + KroneckerProduct(A, B) + >>> kronecker_product(A, B)[0, 1] + A[0, 0]*B[0, 1] + >>> kronecker_product(A, B).as_explicit() + Matrix([ + [A[0, 0]*B[0, 0], A[0, 0]*B[0, 1], A[0, 1]*B[0, 0], A[0, 1]*B[0, 1]], + [A[0, 0]*B[1, 0], A[0, 0]*B[1, 1], A[0, 1]*B[1, 0], A[0, 1]*B[1, 1]], + [A[1, 0]*B[0, 0], A[1, 0]*B[0, 1], A[1, 1]*B[0, 0], A[1, 1]*B[0, 1]], + [A[1, 0]*B[1, 0], A[1, 0]*B[1, 1], A[1, 1]*B[1, 0], A[1, 1]*B[1, 1]]]) + + For explicit matrices the Kronecker product is returned as a Matrix + + >>> from sympy import Matrix, kronecker_product + >>> sigma_x = Matrix([ + ... [0, 1], + ... [1, 0]]) + ... + >>> Isigma_y = Matrix([ + ... [0, 1], + ... [-1, 0]]) + ... + >>> kronecker_product(sigma_x, Isigma_y) + Matrix([ + [ 0, 0, 0, 1], + [ 0, 0, -1, 0], + [ 0, 1, 0, 0], + [-1, 0, 0, 0]]) + + See Also + ======== + KroneckerProduct + + """ + if not matrices: + raise TypeError("Empty Kronecker product is undefined") + if len(matrices) == 1: + return matrices[0] + else: + return KroneckerProduct(*matrices).doit() + + +class KroneckerProduct(MatrixExpr): + """ + The Kronecker product of two or more arguments. + + The Kronecker product is a non-commutative product of matrices. + Given two matrices of dimension (m, n) and (s, t) it produces a matrix + of dimension (m s, n t). + + This is a symbolic object that simply stores its argument without + evaluating it. To actually compute the product, use the function + ``kronecker_product()`` or call the ``.doit()`` or ``.as_explicit()`` + methods. + + >>> from sympy import KroneckerProduct, MatrixSymbol + >>> A = MatrixSymbol('A', 5, 5) + >>> B = MatrixSymbol('B', 5, 5) + >>> isinstance(KroneckerProduct(A, B), KroneckerProduct) + True + """ + is_KroneckerProduct = True + + def __new__(cls, *args, check=True): + args = list(map(sympify, args)) + if all(a.is_Identity for a in args): + ret = Identity(prod(a.rows for a in args)) + if all(isinstance(a, MatrixBase) for a in args): + return ret.as_explicit() + else: + return ret + + if check: + validate(*args) + return super().__new__(cls, *args) + + @property + def shape(self): + rows, cols = self.args[0].shape + for mat in self.args[1:]: + rows *= mat.rows + cols *= mat.cols + return (rows, cols) + + def _entry(self, i, j, **kwargs): + result = 1 + for mat in reversed(self.args): + i, m = divmod(i, mat.rows) + j, n = divmod(j, mat.cols) + result *= mat[m, n] + return result + + def _eval_adjoint(self): + return KroneckerProduct(*list(map(adjoint, self.args))).doit() + + def _eval_conjugate(self): + return KroneckerProduct(*[a.conjugate() for a in self.args]).doit() + + def _eval_transpose(self): + return KroneckerProduct(*list(map(transpose, self.args))).doit() + + def _eval_trace(self): + from .trace import trace + return Mul(*[trace(a) for a in self.args]) + + def _eval_determinant(self): + from .determinant import det, Determinant + if not all(a.is_square for a in self.args): + return Determinant(self) + + m = self.rows + return Mul(*[det(a)**(m/a.rows) for a in self.args]) + + def _eval_inverse(self): + try: + return KroneckerProduct(*[a.inverse() for a in self.args]) + except ShapeError: + from sympy.matrices.expressions.inverse import Inverse + return Inverse(self) + + def structurally_equal(self, other): + '''Determine whether two matrices have the same Kronecker product structure + + Examples + ======== + + >>> from sympy import KroneckerProduct, MatrixSymbol, symbols + >>> m, n = symbols(r'm, n', integer=True) + >>> A = MatrixSymbol('A', m, m) + >>> B = MatrixSymbol('B', n, n) + >>> C = MatrixSymbol('C', m, m) + >>> D = MatrixSymbol('D', n, n) + >>> KroneckerProduct(A, B).structurally_equal(KroneckerProduct(C, D)) + True + >>> KroneckerProduct(A, B).structurally_equal(KroneckerProduct(D, C)) + False + >>> KroneckerProduct(A, B).structurally_equal(C) + False + ''' + # Inspired by BlockMatrix + return (isinstance(other, KroneckerProduct) + and self.shape == other.shape + and len(self.args) == len(other.args) + and all(a.shape == b.shape for (a, b) in zip(self.args, other.args))) + + def has_matching_shape(self, other): + '''Determine whether two matrices have the appropriate structure to bring matrix + multiplication inside the KroneckerProdut + + Examples + ======== + >>> from sympy import KroneckerProduct, MatrixSymbol, symbols + >>> m, n = symbols(r'm, n', integer=True) + >>> A = MatrixSymbol('A', m, n) + >>> B = MatrixSymbol('B', n, m) + >>> KroneckerProduct(A, B).has_matching_shape(KroneckerProduct(B, A)) + True + >>> KroneckerProduct(A, B).has_matching_shape(KroneckerProduct(A, B)) + False + >>> KroneckerProduct(A, B).has_matching_shape(A) + False + ''' + return (isinstance(other, KroneckerProduct) + and self.cols == other.rows + and len(self.args) == len(other.args) + and all(a.cols == b.rows for (a, b) in zip(self.args, other.args))) + + def _eval_expand_kroneckerproduct(self, **hints): + return flatten(canon(typed({KroneckerProduct: distribute(KroneckerProduct, MatAdd)}))(self)) + + def _kronecker_add(self, other): + if self.structurally_equal(other): + return self.__class__(*[a + b for (a, b) in zip(self.args, other.args)]) + else: + return self + other + + def _kronecker_mul(self, other): + if self.has_matching_shape(other): + return self.__class__(*[a*b for (a, b) in zip(self.args, other.args)]) + else: + return self * other + + def doit(self, **hints): + deep = hints.get('deep', True) + if deep: + args = [arg.doit(**hints) for arg in self.args] + else: + args = self.args + return canonicalize(KroneckerProduct(*args)) + + +def validate(*args): + if not all(arg.is_Matrix for arg in args): + raise TypeError("Mix of Matrix and Scalar symbols") + + +# rules + +def extract_commutative(kron): + c_part = [] + nc_part = [] + for arg in kron.args: + c, nc = arg.args_cnc() + c_part.extend(c) + nc_part.append(Mul._from_args(nc)) + + c_part = Mul(*c_part) + if c_part != 1: + return c_part*KroneckerProduct(*nc_part) + return kron + + +def matrix_kronecker_product(*matrices): + """Compute the Kronecker product of a sequence of SymPy Matrices. + + This is the standard Kronecker product of matrices [1]. + + Parameters + ========== + + matrices : tuple of MatrixBase instances + The matrices to take the Kronecker product of. + + Returns + ======= + + matrix : MatrixBase + The Kronecker product matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.matrices.expressions.kronecker import ( + ... matrix_kronecker_product) + + >>> m1 = Matrix([[1,2],[3,4]]) + >>> m2 = Matrix([[1,0],[0,1]]) + >>> matrix_kronecker_product(m1, m2) + Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2], + [3, 0, 4, 0], + [0, 3, 0, 4]]) + >>> matrix_kronecker_product(m2, m1) + Matrix([ + [1, 2, 0, 0], + [3, 4, 0, 0], + [0, 0, 1, 2], + [0, 0, 3, 4]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Kronecker_product + """ + # Make sure we have a sequence of Matrices + if not all(isinstance(m, MatrixBase) for m in matrices): + raise TypeError( + 'Sequence of Matrices expected, got: %s' % repr(matrices) + ) + + # Pull out the first element in the product. + matrix_expansion = matrices[-1] + # Do the kronecker product working from right to left. + for mat in reversed(matrices[:-1]): + rows = mat.rows + cols = mat.cols + # Go through each row appending kronecker product to. + # running matrix_expansion. + for i in range(rows): + start = matrix_expansion*mat[i*cols] + # Go through each column joining each item + for j in range(cols - 1): + start = start.row_join( + matrix_expansion*mat[i*cols + j + 1] + ) + # If this is the first element, make it the start of the + # new row. + if i == 0: + next = start + else: + next = next.col_join(start) + matrix_expansion = next + + MatrixClass = max(matrices, key=lambda M: M._class_priority).__class__ + if isinstance(matrix_expansion, MatrixClass): + return matrix_expansion + else: + return MatrixClass(matrix_expansion) + + +def explicit_kronecker_product(kron): + # Make sure we have a sequence of Matrices + if not all(isinstance(m, MatrixBase) for m in kron.args): + return kron + + return matrix_kronecker_product(*kron.args) + + +rules = (unpack, + explicit_kronecker_product, + flatten, + extract_commutative) + +canonicalize = exhaust(condition(lambda x: isinstance(x, KroneckerProduct), + do_one(*rules))) + + +def _kronecker_dims_key(expr): + if isinstance(expr, KroneckerProduct): + return tuple(a.shape for a in expr.args) + else: + return (0,) + + +def kronecker_mat_add(expr): + args = sift(expr.args, _kronecker_dims_key) + nonkrons = args.pop((0,), None) + if not args: + return expr + + krons = [reduce(lambda x, y: x._kronecker_add(y), group) + for group in args.values()] + + if not nonkrons: + return MatAdd(*krons) + else: + return MatAdd(*krons) + nonkrons + + +def kronecker_mat_mul(expr): + # modified from block matrix code + factor, matrices = expr.as_coeff_matrices() + + i = 0 + while i < len(matrices) - 1: + A, B = matrices[i:i+2] + if isinstance(A, KroneckerProduct) and isinstance(B, KroneckerProduct): + matrices[i] = A._kronecker_mul(B) + matrices.pop(i+1) + else: + i += 1 + + return factor*MatMul(*matrices) + + +def kronecker_mat_pow(expr): + if isinstance(expr.base, KroneckerProduct) and all(a.is_square for a in expr.base.args): + return KroneckerProduct(*[MatPow(a, expr.exp) for a in expr.base.args]) + else: + return expr + + +def combine_kronecker(expr): + """Combine KronekeckerProduct with expression. + + If possible write operations on KroneckerProducts of compatible shapes + as a single KroneckerProduct. + + Examples + ======== + + >>> from sympy.matrices.expressions import combine_kronecker + >>> from sympy import MatrixSymbol, KroneckerProduct, symbols + >>> m, n = symbols(r'm, n', integer=True) + >>> A = MatrixSymbol('A', m, n) + >>> B = MatrixSymbol('B', n, m) + >>> combine_kronecker(KroneckerProduct(A, B)*KroneckerProduct(B, A)) + KroneckerProduct(A*B, B*A) + >>> combine_kronecker(KroneckerProduct(A, B)+KroneckerProduct(B.T, A.T)) + KroneckerProduct(A + B.T, B + A.T) + >>> C = MatrixSymbol('C', n, n) + >>> D = MatrixSymbol('D', m, m) + >>> combine_kronecker(KroneckerProduct(C, D)**m) + KroneckerProduct(C**m, D**m) + """ + def haskron(expr): + return isinstance(expr, MatrixExpr) and expr.has(KroneckerProduct) + + rule = exhaust( + bottom_up(exhaust(condition(haskron, typed( + {MatAdd: kronecker_mat_add, + MatMul: kronecker_mat_mul, + MatPow: kronecker_mat_pow}))))) + result = rule(expr) + doit = getattr(result, 'doit', None) + if doit is not None: + return doit() + else: + return result diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/matadd.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/matadd.py new file mode 100644 index 0000000000000000000000000000000000000000..cfae1e5010e4077c7210c85c4315ed2404f245d7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/matadd.py @@ -0,0 +1,155 @@ +from functools import reduce +import operator + +from sympy.core import Basic, sympify +from sympy.core.add import add, Add, _could_extract_minus_sign +from sympy.core.sorting import default_sort_key +from sympy.functions import adjoint +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices.expressions.transpose import transpose +from sympy.strategies import (rm_id, unpack, flatten, sort, condition, + exhaust, do_one, glom) +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.matrices.expressions.special import ZeroMatrix, GenericZeroMatrix +from sympy.matrices.expressions._shape import validate_matadd_integer as validate +from sympy.utilities.iterables import sift +from sympy.utilities.exceptions import sympy_deprecation_warning + +# XXX: MatAdd should perhaps not subclass directly from Add +class MatAdd(MatrixExpr, Add): + """A Sum of Matrix Expressions + + MatAdd inherits from and operates like SymPy Add + + Examples + ======== + + >>> from sympy import MatAdd, MatrixSymbol + >>> A = MatrixSymbol('A', 5, 5) + >>> B = MatrixSymbol('B', 5, 5) + >>> C = MatrixSymbol('C', 5, 5) + >>> MatAdd(A, B, C) + A + B + C + """ + is_MatAdd = True + + identity = GenericZeroMatrix() + + def __new__(cls, *args, evaluate=False, check=None, _sympify=True): + if not args: + return cls.identity + + # This must be removed aggressively in the constructor to avoid + # TypeErrors from GenericZeroMatrix().shape + args = list(filter(lambda i: cls.identity != i, args)) + if _sympify: + args = list(map(sympify, args)) + + if not all(isinstance(arg, MatrixExpr) for arg in args): + raise TypeError("Mix of Matrix and Scalar symbols") + + obj = Basic.__new__(cls, *args) + + if check is not None: + sympy_deprecation_warning( + "Passing check to MatAdd is deprecated and the check argument will be removed in a future version.", + deprecated_since_version="1.11", + active_deprecations_target='remove-check-argument-from-matrix-operations') + + if check is not False: + validate(*args) + + if evaluate: + obj = cls._evaluate(obj) + + return obj + + @classmethod + def _evaluate(cls, expr): + return canonicalize(expr) + + @property + def shape(self): + return self.args[0].shape + + def could_extract_minus_sign(self): + return _could_extract_minus_sign(self) + + def expand(self, **kwargs): + expanded = super(MatAdd, self).expand(**kwargs) + return self._evaluate(expanded) + + def _entry(self, i, j, **kwargs): + return Add(*[arg._entry(i, j, **kwargs) for arg in self.args]) + + def _eval_transpose(self): + return MatAdd(*[transpose(arg) for arg in self.args]).doit() + + def _eval_adjoint(self): + return MatAdd(*[adjoint(arg) for arg in self.args]).doit() + + def _eval_trace(self): + from .trace import trace + return Add(*[trace(arg) for arg in self.args]).doit() + + def doit(self, **hints): + deep = hints.get('deep', True) + if deep: + args = [arg.doit(**hints) for arg in self.args] + else: + args = self.args + return canonicalize(MatAdd(*args)) + + def _eval_derivative_matrix_lines(self, x): + add_lines = [arg._eval_derivative_matrix_lines(x) for arg in self.args] + return [j for i in add_lines for j in i] + +add.register_handlerclass((Add, MatAdd), MatAdd) + + +factor_of = lambda arg: arg.as_coeff_mmul()[0] +matrix_of = lambda arg: unpack(arg.as_coeff_mmul()[1]) +def combine(cnt, mat): + if cnt == 1: + return mat + else: + return cnt * mat + + +def merge_explicit(matadd): + """ Merge explicit MatrixBase arguments + + Examples + ======== + + >>> from sympy import MatrixSymbol, eye, Matrix, MatAdd, pprint + >>> from sympy.matrices.expressions.matadd import merge_explicit + >>> A = MatrixSymbol('A', 2, 2) + >>> B = eye(2) + >>> C = Matrix([[1, 2], [3, 4]]) + >>> X = MatAdd(A, B, C) + >>> pprint(X) + [1 0] [1 2] + A + [ ] + [ ] + [0 1] [3 4] + >>> pprint(merge_explicit(X)) + [2 2] + A + [ ] + [3 5] + """ + groups = sift(matadd.args, lambda arg: isinstance(arg, MatrixBase)) + if len(groups[True]) > 1: + return MatAdd(*(groups[False] + [reduce(operator.add, groups[True])])) + else: + return matadd + + +rules = (rm_id(lambda x: x == 0 or isinstance(x, ZeroMatrix)), + unpack, + flatten, + glom(matrix_of, factor_of, combine), + merge_explicit, + sort(default_sort_key)) + +canonicalize = exhaust(condition(lambda x: isinstance(x, MatAdd), + do_one(*rules))) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/matexpr.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/matexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..a4e99296ccfcbdac5e09a86ecee020adf9831c73 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/matexpr.py @@ -0,0 +1,888 @@ +from __future__ import annotations +from functools import wraps + +from sympy.core import S, Integer, Basic, Mul, Add +from sympy.core.assumptions import check_assumptions +from sympy.core.decorators import call_highest_priority +from sympy.core.expr import Expr, ExprBuilder +from sympy.core.logic import FuzzyBool +from sympy.core.symbol import Str, Dummy, symbols, Symbol +from sympy.core.sympify import SympifyError, _sympify +from sympy.external.gmpy import SYMPY_INTS +from sympy.functions import conjugate, adjoint +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices.kind import MatrixKind +from sympy.matrices.matrixbase import MatrixBase +from sympy.multipledispatch import dispatch +from sympy.utilities.misc import filldedent + + +def _sympifyit(arg, retval=None): + # This version of _sympifyit sympifies MutableMatrix objects + def deco(func): + @wraps(func) + def __sympifyit_wrapper(a, b): + try: + b = _sympify(b) + return func(a, b) + except SympifyError: + return retval + + return __sympifyit_wrapper + + return deco + + +class MatrixExpr(Expr): + """Superclass for Matrix Expressions + + MatrixExprs represent abstract matrices, linear transformations represented + within a particular basis. + + Examples + ======== + + >>> from sympy import MatrixSymbol + >>> A = MatrixSymbol('A', 3, 3) + >>> y = MatrixSymbol('y', 3, 1) + >>> x = (A.T*A).I * A * y + + See Also + ======== + + MatrixSymbol, MatAdd, MatMul, Transpose, Inverse + """ + __slots__: tuple[str, ...] = () + + # Should not be considered iterable by the + # sympy.utilities.iterables.iterable function. Subclass that actually are + # iterable (i.e., explicit matrices) should set this to True. + _iterable = False + + _op_priority = 11.0 + + is_Matrix: bool = True + is_MatrixExpr: bool = True + is_Identity: FuzzyBool = None + is_Inverse = False + is_Transpose = False + is_ZeroMatrix = False + is_MatAdd = False + is_MatMul = False + + is_commutative = False + is_number = False + is_symbol = False + is_scalar = False + + kind: MatrixKind = MatrixKind() + + def __new__(cls, *args, **kwargs): + args = map(_sympify, args) + return Basic.__new__(cls, *args, **kwargs) + + # The following is adapted from the core Expr object + + @property + def shape(self) -> tuple[Expr, Expr]: + raise NotImplementedError + + @property + def _add_handler(self): + return MatAdd + + @property + def _mul_handler(self): + return MatMul + + def __neg__(self): + return MatMul(S.NegativeOne, self).doit() + + def __abs__(self): + raise NotImplementedError + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__radd__') + def __add__(self, other): + return MatAdd(self, other).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__add__') + def __radd__(self, other): + return MatAdd(other, self).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rsub__') + def __sub__(self, other): + return MatAdd(self, -other).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__sub__') + def __rsub__(self, other): + return MatAdd(other, -self).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rmul__') + def __mul__(self, other): + return MatMul(self, other).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rmul__') + def __matmul__(self, other): + return MatMul(self, other).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__mul__') + def __rmul__(self, other): + return MatMul(other, self).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__mul__') + def __rmatmul__(self, other): + return MatMul(other, self).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rpow__') + def __pow__(self, other): + return MatPow(self, other).doit() + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__pow__') + def __rpow__(self, other): + raise NotImplementedError("Matrix Power not defined") + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__rtruediv__') + def __truediv__(self, other): + return self * other**S.NegativeOne + + @_sympifyit('other', NotImplemented) + @call_highest_priority('__truediv__') + def __rtruediv__(self, other): + raise NotImplementedError() + #return MatMul(other, Pow(self, S.NegativeOne)) + + @property + def rows(self): + return self.shape[0] + + @property + def cols(self): + return self.shape[1] + + @property + def is_square(self) -> bool | None: + rows, cols = self.shape + if isinstance(rows, Integer) and isinstance(cols, Integer): + return rows == cols + if rows == cols: + return True + return None + + def _eval_conjugate(self): + from sympy.matrices.expressions.adjoint import Adjoint + return Adjoint(Transpose(self)) + + def as_real_imag(self, deep=True, **hints): + return self._eval_as_real_imag() + + def _eval_as_real_imag(self): + real = S.Half * (self + self._eval_conjugate()) + im = (self - self._eval_conjugate())/(2*S.ImaginaryUnit) + return (real, im) + + def _eval_inverse(self): + return Inverse(self) + + def _eval_determinant(self): + return Determinant(self) + + def _eval_transpose(self): + return Transpose(self) + + def _eval_trace(self): + return None + + def _eval_power(self, exp): + """ + Override this in sub-classes to implement simplification of powers. The cases where the exponent + is -1, 0, 1 are already covered in MatPow.doit(), so implementations can exclude these cases. + """ + return MatPow(self, exp) + + def _eval_simplify(self, **kwargs): + if self.is_Atom: + return self + else: + from sympy.simplify import simplify + return self.func(*[simplify(x, **kwargs) for x in self.args]) + + def _eval_adjoint(self): + from sympy.matrices.expressions.adjoint import Adjoint + return Adjoint(self) + + def _eval_derivative_n_times(self, x, n): + return Basic._eval_derivative_n_times(self, x, n) + + def _eval_derivative(self, x): + # `x` is a scalar: + if self.has(x): + # See if there are other methods using it: + return super()._eval_derivative(x) + else: + return ZeroMatrix(*self.shape) + + @classmethod + def _check_dim(cls, dim): + """Helper function to check invalid matrix dimensions""" + ok = not dim.is_Float and check_assumptions( + dim, integer=True, nonnegative=True) + if ok is False: + raise ValueError( + "The dimension specification {} should be " + "a nonnegative integer.".format(dim)) + + + def _entry(self, i, j, **kwargs): + raise NotImplementedError( + "Indexing not implemented for %s" % self.__class__.__name__) + + def adjoint(self): + return adjoint(self) + + def as_coeff_Mul(self, rational=False): + """Efficiently extract the coefficient of a product.""" + return S.One, self + + def conjugate(self): + return conjugate(self) + + def transpose(self): + from sympy.matrices.expressions.transpose import transpose + return transpose(self) + + @property + def T(self): + '''Matrix transposition''' + return self.transpose() + + def inverse(self): + if self.is_square is False: + raise NonSquareMatrixError('Inverse of non-square matrix') + return self._eval_inverse() + + def inv(self): + return self.inverse() + + def det(self): + from sympy.matrices.expressions.determinant import det + return det(self) + + @property + def I(self): + return self.inverse() + + def valid_index(self, i, j): + def is_valid(idx): + return isinstance(idx, (int, Integer, Symbol, Expr)) + return (is_valid(i) and is_valid(j) and + (self.rows is None or + (i >= -self.rows) != False and (i < self.rows) != False) and + (j >= -self.cols) != False and (j < self.cols) != False) + + def __getitem__(self, key): + if not isinstance(key, tuple) and isinstance(key, slice): + from sympy.matrices.expressions.slice import MatrixSlice + return MatrixSlice(self, key, (0, None, 1)) + if isinstance(key, tuple) and len(key) == 2: + i, j = key + if isinstance(i, slice) or isinstance(j, slice): + from sympy.matrices.expressions.slice import MatrixSlice + return MatrixSlice(self, i, j) + i, j = _sympify(i), _sympify(j) + if self.valid_index(i, j) != False: + return self._entry(i, j) + else: + raise IndexError("Invalid indices (%s, %s)" % (i, j)) + elif isinstance(key, (SYMPY_INTS, Integer)): + # row-wise decomposition of matrix + rows, cols = self.shape + # allow single indexing if number of columns is known + if not isinstance(cols, Integer): + raise IndexError(filldedent(''' + Single indexing is only supported when the number + of columns is known.''')) + key = _sympify(key) + i = key // cols + j = key % cols + if self.valid_index(i, j) != False: + return self._entry(i, j) + else: + raise IndexError("Invalid index %s" % key) + elif isinstance(key, (Symbol, Expr)): + raise IndexError(filldedent(''' + Only integers may be used when addressing the matrix + with a single index.''')) + raise IndexError("Invalid index, wanted %s[i,j]" % self) + + def _is_shape_symbolic(self) -> bool: + return (not isinstance(self.rows, (SYMPY_INTS, Integer)) + or not isinstance(self.cols, (SYMPY_INTS, Integer))) + + def as_explicit(self): + """ + Returns a dense Matrix with elements represented explicitly + + Returns an object of type ImmutableDenseMatrix. + + Examples + ======== + + >>> from sympy import Identity + >>> I = Identity(3) + >>> I + I + >>> I.as_explicit() + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + See Also + ======== + as_mutable: returns mutable Matrix type + + """ + if self._is_shape_symbolic(): + raise ValueError( + 'Matrix with symbolic shape ' + 'cannot be represented explicitly.') + from sympy.matrices.immutable import ImmutableDenseMatrix + return ImmutableDenseMatrix([[self[i, j] + for j in range(self.cols)] + for i in range(self.rows)]) + + def as_mutable(self): + """ + Returns a dense, mutable matrix with elements represented explicitly + + Examples + ======== + + >>> from sympy import Identity + >>> I = Identity(3) + >>> I + I + >>> I.shape + (3, 3) + >>> I.as_mutable() + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + See Also + ======== + as_explicit: returns ImmutableDenseMatrix + """ + return self.as_explicit().as_mutable() + + def __array__(self, dtype=object, copy=None): + if copy is not None and not copy: + raise TypeError("Cannot implement copy=False when converting Matrix to ndarray") + from numpy import empty + a = empty(self.shape, dtype=object) + for i in range(self.rows): + for j in range(self.cols): + a[i, j] = self[i, j] + return a + + def equals(self, other): + """ + Test elementwise equality between matrices, potentially of different + types + + >>> from sympy import Identity, eye + >>> Identity(3).equals(eye(3)) + True + """ + return self.as_explicit().equals(other) + + def canonicalize(self): + return self + + def as_coeff_mmul(self): + return S.One, MatMul(self) + + @staticmethod + def from_index_summation(expr, first_index=None, last_index=None, dimensions=None): + r""" + Parse expression of matrices with explicitly summed indices into a + matrix expression without indices, if possible. + + This transformation expressed in mathematical notation: + + `\sum_{j=0}^{N-1} A_{i,j} B_{j,k} \Longrightarrow \mathbf{A}\cdot \mathbf{B}` + + Optional parameter ``first_index``: specify which free index to use as + the index starting the expression. + + Examples + ======== + + >>> from sympy import MatrixSymbol, MatrixExpr, Sum + >>> from sympy.abc import i, j, k, l, N + >>> A = MatrixSymbol("A", N, N) + >>> B = MatrixSymbol("B", N, N) + >>> expr = Sum(A[i, j]*B[j, k], (j, 0, N-1)) + >>> MatrixExpr.from_index_summation(expr) + A*B + + Transposition is detected: + + >>> expr = Sum(A[j, i]*B[j, k], (j, 0, N-1)) + >>> MatrixExpr.from_index_summation(expr) + A.T*B + + Detect the trace: + + >>> expr = Sum(A[i, i], (i, 0, N-1)) + >>> MatrixExpr.from_index_summation(expr) + Trace(A) + + More complicated expressions: + + >>> expr = Sum(A[i, j]*B[k, j]*A[l, k], (j, 0, N-1), (k, 0, N-1)) + >>> MatrixExpr.from_index_summation(expr) + A*B.T*A.T + """ + from sympy.tensor.array.expressions.from_indexed_to_array import convert_indexed_to_array + from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + first_indices = [] + if first_index is not None: + first_indices.append(first_index) + if last_index is not None: + first_indices.append(last_index) + arr = convert_indexed_to_array(expr, first_indices=first_indices) + return convert_array_to_matrix(arr) + + def applyfunc(self, func): + from .applyfunc import ElementwiseApplyFunction + return ElementwiseApplyFunction(func, self) + + +@dispatch(MatrixExpr, Expr) +def _eval_is_eq(lhs, rhs): # noqa:F811 + return False + +@dispatch(MatrixExpr, MatrixExpr) # type: ignore +def _eval_is_eq(lhs, rhs): # noqa:F811 + if lhs.shape != rhs.shape: + return False + if (lhs - rhs).is_ZeroMatrix: + return True + +def get_postprocessor(cls): + def _postprocessor(expr): + # To avoid circular imports, we can't have MatMul/MatAdd on the top level + mat_class = {Mul: MatMul, Add: MatAdd}[cls] + nonmatrices = [] + matrices = [] + for term in expr.args: + if isinstance(term, MatrixExpr): + matrices.append(term) + else: + nonmatrices.append(term) + + if not matrices: + return cls._from_args(nonmatrices) + + if nonmatrices: + if cls == Mul: + for i in range(len(matrices)): + if not matrices[i].is_MatrixExpr: + # If one of the matrices explicit, absorb the scalar into it + # (doit will combine all explicit matrices into one, so it + # doesn't matter which) + matrices[i] = matrices[i].__mul__(cls._from_args(nonmatrices)) + nonmatrices = [] + break + + else: + # Maintain the ability to create Add(scalar, matrix) without + # raising an exception. That way different algorithms can + # replace matrix expressions with non-commutative symbols to + # manipulate them like non-commutative scalars. + return cls._from_args(nonmatrices + [mat_class(*matrices).doit(deep=False)]) + + if mat_class == MatAdd: + return mat_class(*matrices).doit(deep=False) + return mat_class(cls._from_args(nonmatrices), *matrices).doit(deep=False) + return _postprocessor + + +Basic._constructor_postprocessor_mapping[MatrixExpr] = { + "Mul": [get_postprocessor(Mul)], + "Add": [get_postprocessor(Add)], +} + + +def _matrix_derivative(expr, x, old_algorithm=False): + + if isinstance(expr, MatrixBase) or isinstance(x, MatrixBase): + # Do not use array expressions for explicit matrices: + old_algorithm = True + + if old_algorithm: + return _matrix_derivative_old_algorithm(expr, x) + + from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + from sympy.tensor.array.expressions.arrayexpr_derivatives import array_derive + from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + + array_expr = convert_matrix_to_array(expr) + diff_array_expr = array_derive(array_expr, x) + diff_matrix_expr = convert_array_to_matrix(diff_array_expr) + return diff_matrix_expr + + +def _matrix_derivative_old_algorithm(expr, x): + from sympy.tensor.array.array_derivatives import ArrayDerivative + lines = expr._eval_derivative_matrix_lines(x) + + parts = [i.build() for i in lines] + + from sympy.tensor.array.expressions.from_array_to_matrix import convert_array_to_matrix + + parts = [[convert_array_to_matrix(j) for j in i] for i in parts] + + def _get_shape(elem): + if isinstance(elem, MatrixExpr): + return elem.shape + return 1, 1 + + def get_rank(parts): + return sum(j not in (1, None) for i in parts for j in _get_shape(i)) + + ranks = [get_rank(i) for i in parts] + rank = ranks[0] + + def contract_one_dims(parts): + if len(parts) == 1: + return parts[0] + else: + p1, p2 = parts[:2] + if p2.is_Matrix: + p2 = p2.T + if p1 == Identity(1): + pbase = p2 + elif p2 == Identity(1): + pbase = p1 + else: + pbase = p1*p2 + if len(parts) == 2: + return pbase + else: # len(parts) > 2 + if pbase.is_Matrix: + raise ValueError("") + return pbase*Mul.fromiter(parts[2:]) + + if rank <= 2: + return Add.fromiter([contract_one_dims(i) for i in parts]) + + return ArrayDerivative(expr, x) + + +class MatrixElement(Expr): + parent = property(lambda self: self.args[0]) + i = property(lambda self: self.args[1]) + j = property(lambda self: self.args[2]) + _diff_wrt = True + is_symbol = True + is_commutative = True + + def __new__(cls, name, n, m): + n, m = map(_sympify, (n, m)) + if isinstance(name, str): + name = Symbol(name) + else: + if isinstance(name, MatrixBase): + if n.is_Integer and m.is_Integer: + return name[n, m] + name = _sympify(name) # change mutable into immutable + else: + name = _sympify(name) + if not isinstance(name.kind, MatrixKind): + raise TypeError("First argument of MatrixElement should be a matrix") + if not getattr(name, 'valid_index', lambda n, m: True)(n, m): + raise IndexError('indices out of range') + obj = Expr.__new__(cls, name, n, m) + return obj + + @property + def symbol(self): + return self.args[0] + + def doit(self, **hints): + deep = hints.get('deep', True) + if deep: + args = [arg.doit(**hints) for arg in self.args] + else: + args = self.args + return args[0][args[1], args[2]] + + @property + def indices(self): + return self.args[1:] + + def _eval_derivative(self, v): + + if not isinstance(v, MatrixElement): + return self.parent.diff(v)[self.i, self.j] + + M = self.args[0] + + m, n = self.parent.shape + + if M == v.args[0]: + return KroneckerDelta(self.args[1], v.args[1], (0, m-1)) * \ + KroneckerDelta(self.args[2], v.args[2], (0, n-1)) + + if isinstance(M, Inverse): + from sympy.concrete.summations import Sum + i, j = self.args[1:] + i1, i2 = symbols("z1, z2", cls=Dummy) + Y = M.args[0] + r1, r2 = Y.shape + return -Sum(M[i, i1]*Y[i1, i2].diff(v)*M[i2, j], (i1, 0, r1-1), (i2, 0, r2-1)) + + if self.has(v.args[0]): + return None + + return S.Zero + + +class MatrixSymbol(MatrixExpr): + """Symbolic representation of a Matrix object + + Creates a SymPy Symbol to represent a Matrix. This matrix has a shape and + can be included in Matrix Expressions + + Examples + ======== + + >>> from sympy import MatrixSymbol, Identity + >>> A = MatrixSymbol('A', 3, 4) # A 3 by 4 Matrix + >>> B = MatrixSymbol('B', 4, 3) # A 4 by 3 Matrix + >>> A.shape + (3, 4) + >>> 2*A*B + Identity(3) + I + 2*A*B + """ + is_commutative = False + is_symbol = True + _diff_wrt = True + + def __new__(cls, name, n, m): + n, m = _sympify(n), _sympify(m) + + cls._check_dim(m) + cls._check_dim(n) + + if isinstance(name, str): + name = Str(name) + obj = Basic.__new__(cls, name, n, m) + return obj + + @property + def shape(self): + return self.args[1], self.args[2] + + @property + def name(self): + return self.args[0].name + + def _entry(self, i, j, **kwargs): + return MatrixElement(self, i, j) + + @property + def free_symbols(self): + return {self} + + def _eval_simplify(self, **kwargs): + return self + + def _eval_derivative(self, x): + # x is a scalar: + return ZeroMatrix(self.shape[0], self.shape[1]) + + def _eval_derivative_matrix_lines(self, x): + if self != x: + first = ZeroMatrix(x.shape[0], self.shape[0]) if self.shape[0] != 1 else S.Zero + second = ZeroMatrix(x.shape[1], self.shape[1]) if self.shape[1] != 1 else S.Zero + return [_LeftRightArgs( + [first, second], + )] + else: + first = Identity(self.shape[0]) if self.shape[0] != 1 else S.One + second = Identity(self.shape[1]) if self.shape[1] != 1 else S.One + return [_LeftRightArgs( + [first, second], + )] + + +def matrix_symbols(expr): + return [sym for sym in expr.free_symbols if sym.is_Matrix] + + +class _LeftRightArgs: + r""" + Helper class to compute matrix derivatives. + + The logic: when an expression is derived by a matrix `X_{mn}`, two lines of + matrix multiplications are created: the one contracted to `m` (first line), + and the one contracted to `n` (second line). + + Transposition flips the side by which new matrices are connected to the + lines. + + The trace connects the end of the two lines. + """ + + def __init__(self, lines, higher=S.One): + self._lines = list(lines) + self._first_pointer_parent = self._lines + self._first_pointer_index = 0 + self._first_line_index = 0 + self._second_pointer_parent = self._lines + self._second_pointer_index = 1 + self._second_line_index = 1 + self.higher = higher + + @property + def first_pointer(self): + return self._first_pointer_parent[self._first_pointer_index] + + @first_pointer.setter + def first_pointer(self, value): + self._first_pointer_parent[self._first_pointer_index] = value + + @property + def second_pointer(self): + return self._second_pointer_parent[self._second_pointer_index] + + @second_pointer.setter + def second_pointer(self, value): + self._second_pointer_parent[self._second_pointer_index] = value + + def __repr__(self): + built = [self._build(i) for i in self._lines] + return "_LeftRightArgs(lines=%s, higher=%s)" % ( + built, + self.higher, + ) + + def transpose(self): + self._first_pointer_parent, self._second_pointer_parent = self._second_pointer_parent, self._first_pointer_parent + self._first_pointer_index, self._second_pointer_index = self._second_pointer_index, self._first_pointer_index + self._first_line_index, self._second_line_index = self._second_line_index, self._first_line_index + return self + + @staticmethod + def _build(expr): + if isinstance(expr, ExprBuilder): + return expr.build() + if isinstance(expr, list): + if len(expr) == 1: + return expr[0] + else: + return expr[0](*[_LeftRightArgs._build(i) for i in expr[1]]) + else: + return expr + + def build(self): + data = [self._build(i) for i in self._lines] + if self.higher != 1: + data += [self._build(self.higher)] + data = list(data) + return data + + def matrix_form(self): + if self.first != 1 and self.higher != 1: + raise ValueError("higher dimensional array cannot be represented") + + def _get_shape(elem): + if isinstance(elem, MatrixExpr): + return elem.shape + return (None, None) + + if _get_shape(self.first)[1] != _get_shape(self.second)[1]: + # Remove one-dimensional identity matrices: + # (this is needed by `a.diff(a)` where `a` is a vector) + if _get_shape(self.second) == (1, 1): + return self.first*self.second[0, 0] + if _get_shape(self.first) == (1, 1): + return self.first[1, 1]*self.second.T + raise ValueError("incompatible shapes") + if self.first != 1: + return self.first*self.second.T + else: + return self.higher + + def rank(self): + """ + Number of dimensions different from trivial (warning: not related to + matrix rank). + """ + rank = 0 + if self.first != 1: + rank += sum(i != 1 for i in self.first.shape) + if self.second != 1: + rank += sum(i != 1 for i in self.second.shape) + if self.higher != 1: + rank += 2 + return rank + + def _multiply_pointer(self, pointer, other): + from ...tensor.array.expressions.array_expressions import ArrayTensorProduct + from ...tensor.array.expressions.array_expressions import ArrayContraction + + subexpr = ExprBuilder( + ArrayContraction, + [ + ExprBuilder( + ArrayTensorProduct, + [ + pointer, + other + ] + ), + (1, 2) + ], + validator=ArrayContraction._validate + ) + + return subexpr + + def append_first(self, other): + self.first_pointer *= other + + def append_second(self, other): + self.second_pointer *= other + + +def _make_matrix(x): + from sympy.matrices.immutable import ImmutableDenseMatrix + if isinstance(x, MatrixExpr): + return x + return ImmutableDenseMatrix([[x]]) + + +from .matmul import MatMul +from .matadd import MatAdd +from .matpow import MatPow +from .transpose import Transpose +from .inverse import Inverse +from .special import ZeroMatrix, Identity +from .determinant import Determinant diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/matmul.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..1c46f7ff5251d89793423f92ea02d7243601de3f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/matmul.py @@ -0,0 +1,496 @@ +from sympy.assumptions.ask import ask, Q +from sympy.assumptions.refine import handlers_dict +from sympy.core import Basic, sympify, S +from sympy.core.mul import mul, Mul +from sympy.core.numbers import Number, Integer +from sympy.core.symbol import Dummy +from sympy.functions import adjoint +from sympy.strategies import (rm_id, unpack, typed, flatten, exhaust, + do_one, new) +from sympy.matrices.exceptions import NonInvertibleMatrixError +from sympy.matrices.matrixbase import MatrixBase +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.matrices.expressions._shape import validate_matmul_integer as validate + +from .inverse import Inverse +from .matexpr import MatrixExpr +from .matpow import MatPow +from .transpose import transpose +from .permutation import PermutationMatrix +from .special import ZeroMatrix, Identity, GenericIdentity, OneMatrix + + +# XXX: MatMul should perhaps not subclass directly from Mul +class MatMul(MatrixExpr, Mul): + """ + A product of matrix expressions + + Examples + ======== + + >>> from sympy import MatMul, MatrixSymbol + >>> A = MatrixSymbol('A', 5, 4) + >>> B = MatrixSymbol('B', 4, 3) + >>> C = MatrixSymbol('C', 3, 6) + >>> MatMul(A, B, C) + A*B*C + """ + is_MatMul = True + + identity = GenericIdentity() + + def __new__(cls, *args, evaluate=False, check=None, _sympify=True): + if not args: + return cls.identity + + # This must be removed aggressively in the constructor to avoid + # TypeErrors from GenericIdentity().shape + args = list(filter(lambda i: cls.identity != i, args)) + if _sympify: + args = list(map(sympify, args)) + obj = Basic.__new__(cls, *args) + factor, matrices = obj.as_coeff_matrices() + + if check is not None: + sympy_deprecation_warning( + "Passing check to MatMul is deprecated and the check argument will be removed in a future version.", + deprecated_since_version="1.11", + active_deprecations_target='remove-check-argument-from-matrix-operations') + + if check is not False: + validate(*matrices) + + if not matrices: + # Should it be + # + # return Basic.__neq__(cls, factor, GenericIdentity()) ? + return factor + + if evaluate: + return cls._evaluate(obj) + + return obj + + @classmethod + def _evaluate(cls, expr): + return canonicalize(expr) + + @property + def shape(self): + matrices = [arg for arg in self.args if arg.is_Matrix] + return (matrices[0].rows, matrices[-1].cols) + + def _entry(self, i, j, expand=True, **kwargs): + # Avoid cyclic imports + from sympy.concrete.summations import Sum + from sympy.matrices.immutable import ImmutableMatrix + + coeff, matrices = self.as_coeff_matrices() + + if len(matrices) == 1: # situation like 2*X, matmul is just X + return coeff * matrices[0][i, j] + + indices = [None]*(len(matrices) + 1) + ind_ranges = [None]*(len(matrices) - 1) + indices[0] = i + indices[-1] = j + + def f(): + counter = 1 + while True: + yield Dummy("i_%i" % counter) + counter += 1 + + dummy_generator = kwargs.get("dummy_generator", f()) + + for i in range(1, len(matrices)): + indices[i] = next(dummy_generator) + + for i, arg in enumerate(matrices[:-1]): + ind_ranges[i] = arg.shape[1] - 1 + matrices = [arg._entry(indices[i], indices[i+1], dummy_generator=dummy_generator) for i, arg in enumerate(matrices)] + expr_in_sum = Mul.fromiter(matrices) + if any(v.has(ImmutableMatrix) for v in matrices): + expand = True + result = coeff*Sum( + expr_in_sum, + *zip(indices[1:-1], [0]*len(ind_ranges), ind_ranges) + ) + + # Don't waste time in result.doit() if the sum bounds are symbolic + if not any(isinstance(v, (Integer, int)) for v in ind_ranges): + expand = False + return result.doit() if expand else result + + def as_coeff_matrices(self): + scalars = [x for x in self.args if not x.is_Matrix] + matrices = [x for x in self.args if x.is_Matrix] + coeff = Mul(*scalars) + if coeff.is_commutative is False: + raise NotImplementedError("noncommutative scalars in MatMul are not supported.") + + return coeff, matrices + + def as_coeff_mmul(self): + coeff, matrices = self.as_coeff_matrices() + return coeff, MatMul(*matrices) + + def expand(self, **kwargs): + expanded = super(MatMul, self).expand(**kwargs) + return self._evaluate(expanded) + + def _eval_transpose(self): + """Transposition of matrix multiplication. + + Notes + ===== + + The following rules are applied. + + Transposition for matrix multiplied with another matrix: + `\\left(A B\\right)^{T} = B^{T} A^{T}` + + Transposition for matrix multiplied with scalar: + `\\left(c A\\right)^{T} = c A^{T}` + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Transpose + """ + coeff, matrices = self.as_coeff_matrices() + return MatMul( + coeff, *[transpose(arg) for arg in matrices[::-1]]).doit() + + def _eval_adjoint(self): + return MatMul(*[adjoint(arg) for arg in self.args[::-1]]).doit() + + def _eval_trace(self): + factor, mmul = self.as_coeff_mmul() + if factor != 1: + from .trace import trace + return factor * trace(mmul.doit()) + + def _eval_determinant(self): + from sympy.matrices.expressions.determinant import Determinant + factor, matrices = self.as_coeff_matrices() + square_matrices = only_squares(*matrices) + return factor**self.rows * Mul(*list(map(Determinant, square_matrices))) + + def _eval_inverse(self): + if all(arg.is_square for arg in self.args if isinstance(arg, MatrixExpr)): + return MatMul(*( + arg.inverse() if isinstance(arg, MatrixExpr) else arg**-1 + for arg in self.args[::-1] + ) + ).doit() + return Inverse(self) + + def doit(self, **hints): + deep = hints.get('deep', True) + if deep: + args = tuple(arg.doit(**hints) for arg in self.args) + else: + args = self.args + + # treat scalar*MatrixSymbol or scalar*MatPow separately + expr = canonicalize(MatMul(*args)) + return expr + + # Needed for partial compatibility with Mul + def args_cnc(self, cset=False, warn=True, **kwargs): + coeff_c = [x for x in self.args if x.is_commutative] + coeff_nc = [x for x in self.args if not x.is_commutative] + if cset: + clen = len(coeff_c) + coeff_c = set(coeff_c) + if clen and warn and len(coeff_c) != clen: + raise ValueError('repeated commutative arguments: %s' % + [ci for ci in coeff_c if list(self.args).count(ci) > 1]) + return [coeff_c, coeff_nc] + + def _eval_derivative_matrix_lines(self, x): + from .transpose import Transpose + with_x_ind = [i for i, arg in enumerate(self.args) if arg.has(x)] + lines = [] + for ind in with_x_ind: + left_args = self.args[:ind] + right_args = self.args[ind+1:] + + if right_args: + right_mat = MatMul.fromiter(right_args) + else: + right_mat = Identity(self.shape[1]) + if left_args: + left_rev = MatMul.fromiter([Transpose(i).doit() if i.is_Matrix else i for i in reversed(left_args)]) + else: + left_rev = Identity(self.shape[0]) + + d = self.args[ind]._eval_derivative_matrix_lines(x) + for i in d: + i.append_first(left_rev) + i.append_second(right_mat) + lines.append(i) + + return lines + +mul.register_handlerclass((Mul, MatMul), MatMul) + + +# Rules +def newmul(*args): + if args[0] == 1: + args = args[1:] + return new(MatMul, *args) + +def any_zeros(mul): + if any(arg.is_zero or (arg.is_Matrix and arg.is_ZeroMatrix) + for arg in mul.args): + matrices = [arg for arg in mul.args if arg.is_Matrix] + return ZeroMatrix(matrices[0].rows, matrices[-1].cols) + return mul + +def merge_explicit(matmul): + """ Merge explicit MatrixBase arguments + + >>> from sympy import MatrixSymbol, Matrix, MatMul, pprint + >>> from sympy.matrices.expressions.matmul import merge_explicit + >>> A = MatrixSymbol('A', 2, 2) + >>> B = Matrix([[1, 1], [1, 1]]) + >>> C = Matrix([[1, 2], [3, 4]]) + >>> X = MatMul(A, B, C) + >>> pprint(X) + [1 1] [1 2] + A*[ ]*[ ] + [1 1] [3 4] + >>> pprint(merge_explicit(X)) + [4 6] + A*[ ] + [4 6] + + >>> X = MatMul(B, A, C) + >>> pprint(X) + [1 1] [1 2] + [ ]*A*[ ] + [1 1] [3 4] + >>> pprint(merge_explicit(X)) + [1 1] [1 2] + [ ]*A*[ ] + [1 1] [3 4] + """ + if not any(isinstance(arg, MatrixBase) for arg in matmul.args): + return matmul + newargs = [] + last = matmul.args[0] + for arg in matmul.args[1:]: + if isinstance(arg, (MatrixBase, Number)) and isinstance(last, (MatrixBase, Number)): + last = last * arg + else: + newargs.append(last) + last = arg + newargs.append(last) + + return MatMul(*newargs) + +def remove_ids(mul): + """ Remove Identities from a MatMul + + This is a modified version of sympy.strategies.rm_id. + This is necessary because MatMul may contain both MatrixExprs and Exprs + as args. + + See Also + ======== + + sympy.strategies.rm_id + """ + # Separate Exprs from MatrixExprs in args + factor, mmul = mul.as_coeff_mmul() + # Apply standard rm_id for MatMuls + result = rm_id(lambda x: x.is_Identity is True)(mmul) + if result != mmul: + return newmul(factor, *result.args) # Recombine and return + else: + return mul + +def factor_in_front(mul): + factor, matrices = mul.as_coeff_matrices() + if factor != 1: + return newmul(factor, *matrices) + return mul + +def combine_powers(mul): + r"""Combine consecutive powers with the same base into one, e.g. + $$A \times A^2 \Rightarrow A^3$$ + + This also cancels out the possible matrix inverses using the + knowledgebase of :class:`~.Inverse`, e.g., + $$ Y \times X \times X^{-1} \Rightarrow Y $$ + """ + factor, args = mul.as_coeff_matrices() + new_args = [args[0]] + + for i in range(1, len(args)): + A = new_args[-1] + B = args[i] + + if isinstance(B, Inverse) and isinstance(B.arg, MatMul): + Bargs = B.arg.args + l = len(Bargs) + if list(Bargs) == new_args[-l:]: + new_args = new_args[:-l] + [Identity(B.shape[0])] + continue + + if isinstance(A, Inverse) and isinstance(A.arg, MatMul): + Aargs = A.arg.args + l = len(Aargs) + if list(Aargs) == args[i:i+l]: + identity = Identity(A.shape[0]) + new_args[-1] = identity + for j in range(i, i+l): + args[j] = identity + continue + + if A.is_square == False or B.is_square == False: + new_args.append(B) + continue + + if isinstance(A, MatPow): + A_base, A_exp = A.args + else: + A_base, A_exp = A, S.One + + if isinstance(B, MatPow): + B_base, B_exp = B.args + else: + B_base, B_exp = B, S.One + + if A_base == B_base: + new_exp = A_exp + B_exp + new_args[-1] = MatPow(A_base, new_exp).doit(deep=False) + continue + elif not isinstance(B_base, MatrixBase): + try: + B_base_inv = B_base.inverse() + except NonInvertibleMatrixError: + B_base_inv = None + if B_base_inv is not None and A_base == B_base_inv: + new_exp = A_exp - B_exp + new_args[-1] = MatPow(A_base, new_exp).doit(deep=False) + continue + new_args.append(B) + + return newmul(factor, *new_args) + +def combine_permutations(mul): + """Refine products of permutation matrices as the products of cycles. + """ + args = mul.args + l = len(args) + if l < 2: + return mul + + result = [args[0]] + for i in range(1, l): + A = result[-1] + B = args[i] + if isinstance(A, PermutationMatrix) and \ + isinstance(B, PermutationMatrix): + cycle_1 = A.args[0] + cycle_2 = B.args[0] + result[-1] = PermutationMatrix(cycle_1 * cycle_2) + else: + result.append(B) + + return MatMul(*result) + +def combine_one_matrices(mul): + """ + Combine products of OneMatrix + + e.g. OneMatrix(2, 3) * OneMatrix(3, 4) -> 3 * OneMatrix(2, 4) + """ + factor, args = mul.as_coeff_matrices() + new_args = [args[0]] + + for B in args[1:]: + A = new_args[-1] + if not isinstance(A, OneMatrix) or not isinstance(B, OneMatrix): + new_args.append(B) + continue + new_args.pop() + new_args.append(OneMatrix(A.shape[0], B.shape[1])) + factor *= A.shape[1] + + return newmul(factor, *new_args) + +def distribute_monom(mul): + """ + Simplify MatMul expressions but distributing + rational term to MatMul. + + e.g. 2*(A+B) -> 2*A + 2*B + """ + args = mul.args + if len(args) == 2: + from .matadd import MatAdd + if args[0].is_MatAdd and args[1].is_Rational: + return MatAdd(*[MatMul(mat, args[1]).doit() for mat in args[0].args]) + if args[1].is_MatAdd and args[0].is_Rational: + return MatAdd(*[MatMul(args[0], mat).doit() for mat in args[1].args]) + return mul + +rules = ( + distribute_monom, any_zeros, remove_ids, combine_one_matrices, combine_powers, unpack, rm_id(lambda x: x == 1), + merge_explicit, factor_in_front, flatten, combine_permutations) + +canonicalize = exhaust(typed({MatMul: do_one(*rules)})) + +def only_squares(*matrices): + """factor matrices only if they are square""" + if matrices[0].rows != matrices[-1].cols: + raise RuntimeError("Invalid matrices being multiplied") + out = [] + start = 0 + for i, M in enumerate(matrices): + if M.cols == matrices[start].rows: + out.append(MatMul(*matrices[start:i+1]).doit()) + start = i+1 + return out + + +def refine_MatMul(expr, assumptions): + """ + >>> from sympy import MatrixSymbol, Q, assuming, refine + >>> X = MatrixSymbol('X', 2, 2) + >>> expr = X * X.T + >>> print(expr) + X*X.T + >>> with assuming(Q.orthogonal(X)): + ... print(refine(expr)) + I + """ + newargs = [] + exprargs = [] + + for args in expr.args: + if args.is_Matrix: + exprargs.append(args) + else: + newargs.append(args) + + last = exprargs[0] + for arg in exprargs[1:]: + if arg == last.T and ask(Q.orthogonal(arg), assumptions): + last = Identity(arg.shape[0]) + elif arg == last.conjugate() and ask(Q.unitary(arg), assumptions): + last = Identity(arg.shape[0]) + else: + newargs.append(last) + last = arg + newargs.append(last) + + return MatMul(*newargs) + + +handlers_dict['MatMul'] = refine_MatMul diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/matpow.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/matpow.py new file mode 100644 index 0000000000000000000000000000000000000000..b6472995e134e9e5ebfd28a901480665d1531275 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/matpow.py @@ -0,0 +1,150 @@ +from .matexpr import MatrixExpr +from .special import Identity +from sympy.core import S +from sympy.core.expr import ExprBuilder +from sympy.core.cache import cacheit +from sympy.core.power import Pow +from sympy.core.sympify import _sympify +from sympy.matrices import MatrixBase +from sympy.matrices.exceptions import NonSquareMatrixError + + +class MatPow(MatrixExpr): + def __new__(cls, base, exp, evaluate=False, **options): + base = _sympify(base) + if not base.is_Matrix: + raise TypeError("MatPow base should be a matrix") + + if base.is_square is False: + raise NonSquareMatrixError("Power of non-square matrix %s" % base) + + exp = _sympify(exp) + obj = super().__new__(cls, base, exp) + + if evaluate: + obj = obj.doit(deep=False) + + return obj + + @property + def base(self): + return self.args[0] + + @property + def exp(self): + return self.args[1] + + @property + def shape(self): + return self.base.shape + + @cacheit + def _get_explicit_matrix(self): + return self.base.as_explicit()**self.exp + + def _entry(self, i, j, **kwargs): + from sympy.matrices.expressions import MatMul + A = self.doit() + if isinstance(A, MatPow): + # We still have a MatPow, make an explicit MatMul out of it. + if A.exp.is_Integer and A.exp.is_positive: + A = MatMul(*[A.base for k in range(A.exp)]) + elif not self._is_shape_symbolic(): + return A._get_explicit_matrix()[i, j] + else: + # Leave the expression unevaluated: + from sympy.matrices.expressions.matexpr import MatrixElement + return MatrixElement(self, i, j) + return A[i, j] + + def doit(self, **hints): + if hints.get('deep', True): + base, exp = (arg.doit(**hints) for arg in self.args) + else: + base, exp = self.args + + # combine all powers, e.g. (A ** 2) ** 3 -> A ** 6 + while isinstance(base, MatPow): + exp *= base.args[1] + base = base.args[0] + + if isinstance(base, MatrixBase): + # Delegate + return base ** exp + + # Handle simple cases so that _eval_power() in MatrixExpr sub-classes can ignore them + if exp == S.One: + return base + if exp == S.Zero: + return Identity(base.rows) + if exp == S.NegativeOne: + from sympy.matrices.expressions import Inverse + return Inverse(base).doit(**hints) + + eval_power = getattr(base, '_eval_power', None) + if eval_power is not None: + return eval_power(exp) + + return MatPow(base, exp) + + def _eval_transpose(self): + base, exp = self.args + return MatPow(base.transpose(), exp) + + def _eval_adjoint(self): + base, exp = self.args + return MatPow(base.adjoint(), exp) + + def _eval_conjugate(self): + base, exp = self.args + return MatPow(base.conjugate(), exp) + + def _eval_derivative(self, x): + return Pow._eval_derivative(self, x) + + def _eval_derivative_matrix_lines(self, x): + from sympy.tensor.array.expressions.array_expressions import ArrayContraction + from ...tensor.array.expressions.array_expressions import ArrayTensorProduct + from .matmul import MatMul + from .inverse import Inverse + exp = self.exp + if self.base.shape == (1, 1) and not exp.has(x): + lr = self.base._eval_derivative_matrix_lines(x) + for i in lr: + subexpr = ExprBuilder( + ArrayContraction, + [ + ExprBuilder( + ArrayTensorProduct, + [ + Identity(1), + i._lines[0], + exp*self.base**(exp-1), + i._lines[1], + Identity(1), + ] + ), + (0, 3, 4), (5, 7, 8) + ], + validator=ArrayContraction._validate + ) + i._first_pointer_parent = subexpr.args[0].args + i._first_pointer_index = 0 + i._second_pointer_parent = subexpr.args[0].args + i._second_pointer_index = 4 + i._lines = [subexpr] + return lr + if (exp > 0) == True: + newexpr = MatMul.fromiter([self.base for i in range(exp)]) + elif (exp == -1) == True: + return Inverse(self.base)._eval_derivative_matrix_lines(x) + elif (exp < 0) == True: + newexpr = MatMul.fromiter([Inverse(self.base) for i in range(-exp)]) + elif (exp == 0) == True: + return self.doit()._eval_derivative_matrix_lines(x) + else: + raise NotImplementedError("cannot evaluate %s derived by %s" % (self, x)) + return newexpr._eval_derivative_matrix_lines(x) + + def _eval_inverse(self): + return MatPow(self.base, -self.exp) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/permutation.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/permutation.py new file mode 100644 index 0000000000000000000000000000000000000000..5634fa941a53d8890583fe61bb29bc34f4e6000d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/permutation.py @@ -0,0 +1,303 @@ +from sympy.core import S +from sympy.core.sympify import _sympify +from sympy.functions import KroneckerDelta + +from .matexpr import MatrixExpr +from .special import ZeroMatrix, Identity, OneMatrix + + +class PermutationMatrix(MatrixExpr): + """A Permutation Matrix + + Parameters + ========== + + perm : Permutation + The permutation the matrix uses. + + The size of the permutation determines the matrix size. + + See the documentation of + :class:`sympy.combinatorics.permutations.Permutation` for + the further information of how to create a permutation object. + + Examples + ======== + + >>> from sympy import Matrix, PermutationMatrix + >>> from sympy.combinatorics import Permutation + + Creating a permutation matrix: + + >>> p = Permutation(1, 2, 0) + >>> P = PermutationMatrix(p) + >>> P = P.as_explicit() + >>> P + Matrix([ + [0, 1, 0], + [0, 0, 1], + [1, 0, 0]]) + + Permuting a matrix row and column: + + >>> M = Matrix([0, 1, 2]) + >>> Matrix(P*M) + Matrix([ + [1], + [2], + [0]]) + + >>> Matrix(M.T*P) + Matrix([[2, 0, 1]]) + + See Also + ======== + + sympy.combinatorics.permutations.Permutation + """ + + def __new__(cls, perm): + from sympy.combinatorics.permutations import Permutation + + perm = _sympify(perm) + if not isinstance(perm, Permutation): + raise ValueError( + "{} must be a SymPy Permutation instance.".format(perm)) + + return super().__new__(cls, perm) + + @property + def shape(self): + size = self.args[0].size + return (size, size) + + @property + def is_Identity(self): + return self.args[0].is_Identity + + def doit(self, **hints): + if self.is_Identity: + return Identity(self.rows) + return self + + def _entry(self, i, j, **kwargs): + perm = self.args[0] + return KroneckerDelta(perm.apply(i), j) + + def _eval_power(self, exp): + return PermutationMatrix(self.args[0] ** exp).doit() + + def _eval_inverse(self): + return PermutationMatrix(self.args[0] ** -1) + + _eval_transpose = _eval_adjoint = _eval_inverse + + def _eval_determinant(self): + sign = self.args[0].signature() + if sign == 1: + return S.One + elif sign == -1: + return S.NegativeOne + raise NotImplementedError + + def _eval_rewrite_as_BlockDiagMatrix(self, *args, **kwargs): + from sympy.combinatorics.permutations import Permutation + from .blockmatrix import BlockDiagMatrix + + perm = self.args[0] + full_cyclic_form = perm.full_cyclic_form + + cycles_picks = [] + + # Stage 1. Decompose the cycles into the blockable form. + a, b, c = 0, 0, 0 + flag = False + for cycle in full_cyclic_form: + l = len(cycle) + m = max(cycle) + + if not flag: + if m + 1 > a + l: + flag = True + temp = [cycle] + b = m + c = l + else: + cycles_picks.append([cycle]) + a += l + + else: + if m > b: + if m + 1 == a + c + l: + temp.append(cycle) + cycles_picks.append(temp) + flag = False + a = m+1 + else: + b = m + temp.append(cycle) + c += l + else: + if b + 1 == a + c + l: + temp.append(cycle) + cycles_picks.append(temp) + flag = False + a = b+1 + else: + temp.append(cycle) + c += l + + # Stage 2. Normalize each decomposed cycles and build matrix. + p = 0 + args = [] + for pick in cycles_picks: + new_cycles = [] + l = 0 + for cycle in pick: + new_cycle = [i - p for i in cycle] + new_cycles.append(new_cycle) + l += len(cycle) + p += l + perm = Permutation(new_cycles) + mat = PermutationMatrix(perm) + args.append(mat) + + return BlockDiagMatrix(*args) + + +class MatrixPermute(MatrixExpr): + r"""Symbolic representation for permuting matrix rows or columns. + + Parameters + ========== + + perm : Permutation, PermutationMatrix + The permutation to use for permuting the matrix. + The permutation can be resized to the suitable one, + + axis : 0 or 1 + The axis to permute alongside. + If `0`, it will permute the matrix rows. + If `1`, it will permute the matrix columns. + + Notes + ===== + + This follows the same notation used in + :meth:`sympy.matrices.matrixbase.MatrixBase.permute`. + + Examples + ======== + + >>> from sympy import Matrix, MatrixPermute + >>> from sympy.combinatorics import Permutation + + Permuting the matrix rows: + + >>> p = Permutation(1, 2, 0) + >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> B = MatrixPermute(A, p, axis=0) + >>> B.as_explicit() + Matrix([ + [4, 5, 6], + [7, 8, 9], + [1, 2, 3]]) + + Permuting the matrix columns: + + >>> B = MatrixPermute(A, p, axis=1) + >>> B.as_explicit() + Matrix([ + [2, 3, 1], + [5, 6, 4], + [8, 9, 7]]) + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.permute + """ + def __new__(cls, mat, perm, axis=S.Zero): + from sympy.combinatorics.permutations import Permutation + + mat = _sympify(mat) + if not mat.is_Matrix: + raise ValueError( + "{} must be a SymPy matrix instance.".format(perm)) + + perm = _sympify(perm) + if isinstance(perm, PermutationMatrix): + perm = perm.args[0] + + if not isinstance(perm, Permutation): + raise ValueError( + "{} must be a SymPy Permutation or a PermutationMatrix " \ + "instance".format(perm)) + + axis = _sympify(axis) + if axis not in (0, 1): + raise ValueError("The axis must be 0 or 1.") + + mat_size = mat.shape[axis] + if mat_size != perm.size: + try: + perm = perm.resize(mat_size) + except ValueError: + raise ValueError( + "Size does not match between the permutation {} " + "and the matrix {} threaded over the axis {} " + "and cannot be converted." + .format(perm, mat, axis)) + + return super().__new__(cls, mat, perm, axis) + + def doit(self, deep=True, **hints): + mat, perm, axis = self.args + + if deep: + mat = mat.doit(deep=deep, **hints) + perm = perm.doit(deep=deep, **hints) + + if perm.is_Identity: + return mat + + if mat.is_Identity: + if axis is S.Zero: + return PermutationMatrix(perm) + elif axis is S.One: + return PermutationMatrix(perm**-1) + + if isinstance(mat, (ZeroMatrix, OneMatrix)): + return mat + + if isinstance(mat, MatrixPermute) and mat.args[2] == axis: + return MatrixPermute(mat.args[0], perm * mat.args[1], axis) + + return self + + @property + def shape(self): + return self.args[0].shape + + def _entry(self, i, j, **kwargs): + mat, perm, axis = self.args + + if axis == 0: + return mat[perm.apply(i), j] + elif axis == 1: + return mat[i, perm.apply(j)] + + def _eval_rewrite_as_MatMul(self, *args, **kwargs): + from .matmul import MatMul + + mat, perm, axis = self.args + + deep = kwargs.get("deep", True) + + if deep: + mat = mat.rewrite(MatMul) + + if axis == 0: + return MatMul(PermutationMatrix(perm), mat) + elif axis == 1: + return MatMul(mat, PermutationMatrix(perm**-1)) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/sets.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/sets.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4930ea8f1b058977a8dd1abdc62f1f5e2195c1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/sets.py @@ -0,0 +1,68 @@ +from sympy.core.assumptions import check_assumptions +from sympy.core.logic import fuzzy_and +from sympy.core.sympify import _sympify +from sympy.matrices.kind import MatrixKind +from sympy.sets.sets import Set, SetKind +from sympy.core.kind import NumberKind +from .matexpr import MatrixExpr + + +class MatrixSet(Set): + """ + MatrixSet represents the set of matrices with ``shape = (n, m)`` over the + given set. + + Examples + ======== + + >>> from sympy.matrices import MatrixSet + >>> from sympy import S, I, Matrix + >>> M = MatrixSet(2, 2, set=S.Reals) + >>> X = Matrix([[1, 2], [3, 4]]) + >>> X in M + True + >>> X = Matrix([[1, 2], [I, 4]]) + >>> X in M + False + + """ + is_empty = False + + def __new__(cls, n, m, set): + n, m, set = _sympify(n), _sympify(m), _sympify(set) + cls._check_dim(n) + cls._check_dim(m) + if not isinstance(set, Set): + raise TypeError("{} should be an instance of Set.".format(set)) + return Set.__new__(cls, n, m, set) + + @property + def shape(self): + return self.args[:2] + + @property + def set(self): + return self.args[2] + + def _contains(self, other): + if not isinstance(other, MatrixExpr): + raise TypeError("{} should be an instance of MatrixExpr.".format(other)) + if other.shape != self.shape: + are_symbolic = any(_sympify(x).is_Symbol for x in other.shape + self.shape) + if are_symbolic: + return None + return False + return fuzzy_and(self.set.contains(x) for x in other) + + @classmethod + def _check_dim(cls, dim): + """Helper function to check invalid matrix dimensions""" + ok = not dim.is_Float and check_assumptions( + dim, integer=True, nonnegative=True) + if ok is False: + raise ValueError( + "The dimension specification {} should be " + "a nonnegative integer.".format(dim)) + + def _kind(self): + return SetKind(MatrixKind(NumberKind)) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/slice.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/slice.py new file mode 100644 index 0000000000000000000000000000000000000000..1904b49f29c503fb4c0c909532f8342fb0f4b135 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/slice.py @@ -0,0 +1,114 @@ +from sympy.matrices.expressions.matexpr import MatrixExpr +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.functions.elementary.integers import floor + +def normalize(i, parentsize): + if isinstance(i, slice): + i = (i.start, i.stop, i.step) + if not isinstance(i, (tuple, list, Tuple)): + if (i < 0) == True: + i += parentsize + i = (i, i+1, 1) + i = list(i) + if len(i) == 2: + i.append(1) + start, stop, step = i + start = start or 0 + if stop is None: + stop = parentsize + if (start < 0) == True: + start += parentsize + if (stop < 0) == True: + stop += parentsize + step = step or 1 + + if ((stop - start) * step < 1) == True: + raise IndexError() + + return (start, stop, step) + +class MatrixSlice(MatrixExpr): + """ A MatrixSlice of a Matrix Expression + + Examples + ======== + + >>> from sympy import MatrixSlice, ImmutableMatrix + >>> M = ImmutableMatrix(4, 4, range(16)) + >>> M + Matrix([ + [ 0, 1, 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11], + [12, 13, 14, 15]]) + + >>> B = MatrixSlice(M, (0, 2), (2, 4)) + >>> ImmutableMatrix(B) + Matrix([ + [2, 3], + [6, 7]]) + """ + parent = property(lambda self: self.args[0]) + rowslice = property(lambda self: self.args[1]) + colslice = property(lambda self: self.args[2]) + + def __new__(cls, parent, rowslice, colslice): + rowslice = normalize(rowslice, parent.shape[0]) + colslice = normalize(colslice, parent.shape[1]) + if not (len(rowslice) == len(colslice) == 3): + raise IndexError() + if ((0 > rowslice[0]) == True or + (parent.shape[0] < rowslice[1]) == True or + (0 > colslice[0]) == True or + (parent.shape[1] < colslice[1]) == True): + raise IndexError() + if isinstance(parent, MatrixSlice): + return mat_slice_of_slice(parent, rowslice, colslice) + return Basic.__new__(cls, parent, Tuple(*rowslice), Tuple(*colslice)) + + @property + def shape(self): + rows = self.rowslice[1] - self.rowslice[0] + rows = rows if self.rowslice[2] == 1 else floor(rows/self.rowslice[2]) + cols = self.colslice[1] - self.colslice[0] + cols = cols if self.colslice[2] == 1 else floor(cols/self.colslice[2]) + return rows, cols + + def _entry(self, i, j, **kwargs): + return self.parent._entry(i*self.rowslice[2] + self.rowslice[0], + j*self.colslice[2] + self.colslice[0], + **kwargs) + + @property + def on_diag(self): + return self.rowslice == self.colslice + + +def slice_of_slice(s, t): + start1, stop1, step1 = s + start2, stop2, step2 = t + + start = start1 + start2*step1 + step = step1 * step2 + stop = start1 + step1*stop2 + + if stop > stop1: + raise IndexError() + + return start, stop, step + + +def mat_slice_of_slice(parent, rowslice, colslice): + """ Collapse nested matrix slices + + >>> from sympy import MatrixSymbol + >>> X = MatrixSymbol('X', 10, 10) + >>> X[:, 1:5][5:8, :] + X[5:8, 1:5] + >>> X[1:9:2, 2:6][1:3, 2] + X[3:7:2, 4:5] + """ + row = slice_of_slice(parent.rowslice, rowslice) + col = slice_of_slice(parent.colslice, colslice) + return MatrixSlice(parent.parent, row, col) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/special.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/special.py new file mode 100644 index 0000000000000000000000000000000000000000..d1e426f16ada0e4245b644867974b41b6f86b5cc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/special.py @@ -0,0 +1,299 @@ +from sympy.assumptions.ask import ask, Q +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.sympify import _sympify +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.exceptions import NonInvertibleMatrixError +from .matexpr import MatrixExpr + + +class ZeroMatrix(MatrixExpr): + """The Matrix Zero 0 - additive identity + + Examples + ======== + + >>> from sympy import MatrixSymbol, ZeroMatrix + >>> A = MatrixSymbol('A', 3, 5) + >>> Z = ZeroMatrix(3, 5) + >>> A + Z + A + >>> Z*A.T + 0 + """ + is_ZeroMatrix = True + + def __new__(cls, m, n): + m, n = _sympify(m), _sympify(n) + cls._check_dim(m) + cls._check_dim(n) + + return super().__new__(cls, m, n) + + @property + def shape(self): + return (self.args[0], self.args[1]) + + def _eval_power(self, exp): + # exp = -1, 0, 1 are already handled at this stage + if (exp < 0) == True: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible") + return self + + def _eval_transpose(self): + return ZeroMatrix(self.cols, self.rows) + + def _eval_adjoint(self): + return ZeroMatrix(self.cols, self.rows) + + def _eval_trace(self): + return S.Zero + + def _eval_determinant(self): + return S.Zero + + def _eval_inverse(self): + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + + def _eval_as_real_imag(self): + return (self, self) + + def _eval_conjugate(self): + return self + + def _entry(self, i, j, **kwargs): + return S.Zero + + +class GenericZeroMatrix(ZeroMatrix): + """ + A zero matrix without a specified shape + + This exists primarily so MatAdd() with no arguments can return something + meaningful. + """ + def __new__(cls): + # super(ZeroMatrix, cls) instead of super(GenericZeroMatrix, cls) + # because ZeroMatrix.__new__ doesn't have the same signature + return super(ZeroMatrix, cls).__new__(cls) + + @property + def rows(self): + raise TypeError("GenericZeroMatrix does not have a specified shape") + + @property + def cols(self): + raise TypeError("GenericZeroMatrix does not have a specified shape") + + @property + def shape(self): + raise TypeError("GenericZeroMatrix does not have a specified shape") + + # Avoid Matrix.__eq__ which might call .shape + def __eq__(self, other): + return isinstance(other, GenericZeroMatrix) + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return super().__hash__() + + + +class Identity(MatrixExpr): + """The Matrix Identity I - multiplicative identity + + Examples + ======== + + >>> from sympy import Identity, MatrixSymbol + >>> A = MatrixSymbol('A', 3, 5) + >>> I = Identity(3) + >>> I*A + A + """ + + is_Identity = True + + def __new__(cls, n): + n = _sympify(n) + cls._check_dim(n) + + return super().__new__(cls, n) + + @property + def rows(self): + return self.args[0] + + @property + def cols(self): + return self.args[0] + + @property + def shape(self): + return (self.args[0], self.args[0]) + + @property + def is_square(self): + return True + + def _eval_transpose(self): + return self + + def _eval_trace(self): + return self.rows + + def _eval_inverse(self): + return self + + def _eval_as_real_imag(self): + return (self, ZeroMatrix(*self.shape)) + + def _eval_conjugate(self): + return self + + def _eval_adjoint(self): + return self + + def _entry(self, i, j, **kwargs): + eq = Eq(i, j) + if eq is S.true: + return S.One + elif eq is S.false: + return S.Zero + return KroneckerDelta(i, j, (0, self.cols-1)) + + def _eval_determinant(self): + return S.One + + def _eval_power(self, exp): + return self + + +class GenericIdentity(Identity): + """ + An identity matrix without a specified shape + + This exists primarily so MatMul() with no arguments can return something + meaningful. + """ + def __new__(cls): + # super(Identity, cls) instead of super(GenericIdentity, cls) because + # Identity.__new__ doesn't have the same signature + return super(Identity, cls).__new__(cls) + + @property + def rows(self): + raise TypeError("GenericIdentity does not have a specified shape") + + @property + def cols(self): + raise TypeError("GenericIdentity does not have a specified shape") + + @property + def shape(self): + raise TypeError("GenericIdentity does not have a specified shape") + + @property + def is_square(self): + return True + + # Avoid Matrix.__eq__ which might call .shape + def __eq__(self, other): + return isinstance(other, GenericIdentity) + + def __ne__(self, other): + return not (self == other) + + def __hash__(self): + return super().__hash__() + + +class OneMatrix(MatrixExpr): + """ + Matrix whose all entries are ones. + """ + def __new__(cls, m, n, evaluate=False): + m, n = _sympify(m), _sympify(n) + cls._check_dim(m) + cls._check_dim(n) + + if evaluate: + condition = Eq(m, 1) & Eq(n, 1) + if condition == True: + return Identity(1) + + obj = super().__new__(cls, m, n) + return obj + + @property + def shape(self): + return self._args + + @property + def is_Identity(self): + return self._is_1x1() == True + + def as_explicit(self): + from sympy.matrices.immutable import ImmutableDenseMatrix + return ImmutableDenseMatrix.ones(*self.shape) + + def doit(self, **hints): + args = self.args + if hints.get('deep', True): + args = [a.doit(**hints) for a in args] + return self.func(*args, evaluate=True) + + def _eval_power(self, exp): + # exp = -1, 0, 1 are already handled at this stage + if self._is_1x1() == True: + return Identity(1) + if (exp < 0) == True: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible") + if ask(Q.integer(exp)): + return self.shape[0] ** (exp - 1) * OneMatrix(*self.shape) + return super()._eval_power(exp) + + def _eval_transpose(self): + return OneMatrix(self.cols, self.rows) + + def _eval_adjoint(self): + return OneMatrix(self.cols, self.rows) + + def _eval_trace(self): + return S.One*self.rows + + def _is_1x1(self): + """Returns true if the matrix is known to be 1x1""" + shape = self.shape + return Eq(shape[0], 1) & Eq(shape[1], 1) + + def _eval_determinant(self): + condition = self._is_1x1() + if condition == True: + return S.One + elif condition == False: + return S.Zero + else: + from sympy.matrices.expressions.determinant import Determinant + return Determinant(self) + + def _eval_inverse(self): + condition = self._is_1x1() + if condition == True: + return Identity(1) + elif condition == False: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + else: + from .inverse import Inverse + return Inverse(self) + + def _eval_as_real_imag(self): + return (self, ZeroMatrix(*self.shape)) + + def _eval_conjugate(self): + return self + + def _entry(self, i, j, **kwargs): + return S.One diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__init__.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acbf59164dbefa9bf0592b744cd9f0fffe96d835 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_adjoint.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_adjoint.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4af1b68652193098e7723bc395aaa2bb81c36edd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_adjoint.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_applyfunc.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_applyfunc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b0fb5aa99e350f1f42b6c6c13b0fea8edf8ac37 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_applyfunc.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_blockmatrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_blockmatrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e985f166a4d4060b541cc0f77f9d338c5dad9542 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_blockmatrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_companion.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_companion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..461ff552f9a255d7d50d0205a0a95d4e2ae1fbff Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_companion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_derivatives.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_derivatives.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..021a3b242e7627e5124cd698b2b244fa7a3ff921 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_derivatives.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_determinant.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_determinant.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32e8fdab56a5179ce110a9eee31962f4ce4a4e68 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_determinant.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_diagonal.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_diagonal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7de40f2d95cf8904b8e5baa2820494747bdb2f2e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_diagonal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_dotproduct.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_dotproduct.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..247cf6cd060dd854473e45e84b96edc8e992812f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_dotproduct.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_factorizations.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_factorizations.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d26fc62c94c393653783ce4fcf27f713aafac65 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_factorizations.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_fourier.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_fourier.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..516b0c841ae904f2291cb48bd78d1a7d041735d7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_fourier.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_funcmatrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_funcmatrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d577ff67db7adbde341afe1f2d6bff817ed24c98 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_funcmatrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_hadamard.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_hadamard.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4f92e736376a6061aa27cd4f862c6e180d1dbc1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_hadamard.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_indexing.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_indexing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..980cd411ff865f87d362e51599c0a1693c9f5c73 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_indexing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_inverse.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_inverse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33af4a556c820b3d1869002fbe7715aeb3170515 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_inverse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_kronecker.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_kronecker.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..305d1546edbd60f3db4b26db1a351c6818cf6888 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_kronecker.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matadd.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matadd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fda0517c1f8f5d8e58ec57dfe059582d5b5158a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matadd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matexpr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matexpr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70739a1499ab192d01d42029cf82b43c7227b881 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matexpr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matmul.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matmul.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e5d117cce90540b9fd2e09c1938dfe9656386e86 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matmul.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matpow.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matpow.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..855c46e6f1ba47d050dd1c46bd8087061d182f7c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_matpow.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_permutation.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_permutation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..89224edba9234686d584b52676f10fecad5bfaa0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_permutation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_sets.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_sets.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4eb484aa17641fcac9f9528b7cb5e865a4f40936 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_sets.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_slice.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_slice.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1033facf1ae5dd756036437f8e734437f97f1fad Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_slice.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_special.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_special.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7656495f8cf5c9dd3b138eac8985964d684044ef Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_special.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_trace.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_trace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67d3bd9f0ccc4b6e9cc5602064d59251a44c9344 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_trace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_transpose.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_transpose.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c47461047dacef2fbe7ad58a25d5ac98e7c29253 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/__pycache__/test_transpose.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_adjoint.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_adjoint.py new file mode 100644 index 0000000000000000000000000000000000000000..7106b5740b1dc7c32f2c6f5ecb9d286b5e1dd222 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_adjoint.py @@ -0,0 +1,34 @@ +from sympy.core import symbols, S +from sympy.functions import adjoint, conjugate, transpose +from sympy.matrices.expressions import MatrixSymbol, Adjoint, trace, Transpose +from sympy.matrices import eye, Matrix + +n, m, l, k, p = symbols('n m l k p', integer=True) +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) + + +def test_adjoint(): + Sq = MatrixSymbol('Sq', n, n) + + assert Adjoint(A).shape == (m, n) + assert Adjoint(A*B).shape == (l, n) + assert adjoint(Adjoint(A)) == A + assert isinstance(Adjoint(Adjoint(A)), Adjoint) + + assert conjugate(Adjoint(A)) == Transpose(A) + assert transpose(Adjoint(A)) == Adjoint(Transpose(A)) + + assert Adjoint(eye(3)).doit() == eye(3) + + assert Adjoint(S(5)).doit() == S(5) + + assert Adjoint(Matrix([[1, 2], [3, 4]])).doit() == Matrix([[1, 3], [2, 4]]) + + assert adjoint(trace(Sq)) == conjugate(trace(Sq)) + assert trace(adjoint(Sq)) == conjugate(trace(Sq)) + + assert Adjoint(Sq)[0, 1] == conjugate(Sq[1, 0]) + + assert Adjoint(A*B).doit() == Adjoint(B) * Adjoint(A) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_applyfunc.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_applyfunc.py new file mode 100644 index 0000000000000000000000000000000000000000..d98732e2751e53938d96d7ea56c916e6fee4578e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_applyfunc.py @@ -0,0 +1,118 @@ +from sympy.core.symbol import symbols, Dummy +from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction +from sympy.core.function import Lambda +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import sin +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.matmul import MatMul +from sympy.simplify.simplify import simplify + + +X = MatrixSymbol("X", 3, 3) +Y = MatrixSymbol("Y", 3, 3) + +k = symbols("k") +Xk = MatrixSymbol("X", k, k) + +Xd = X.as_explicit() + +x, y, z, t = symbols("x y z t") + + +def test_applyfunc_matrix(): + x = Dummy('x') + double = Lambda(x, x**2) + + expr = ElementwiseApplyFunction(double, Xd) + assert isinstance(expr, ElementwiseApplyFunction) + assert expr.doit() == Xd.applyfunc(lambda x: x**2) + assert expr.shape == (3, 3) + assert expr.func(*expr.args) == expr + assert simplify(expr) == expr + assert expr[0, 0] == double(Xd[0, 0]) + + expr = ElementwiseApplyFunction(double, X) + assert isinstance(expr, ElementwiseApplyFunction) + assert isinstance(expr.doit(), ElementwiseApplyFunction) + assert expr == X.applyfunc(double) + assert expr.func(*expr.args) == expr + + expr = ElementwiseApplyFunction(exp, X*Y) + assert expr.expr == X*Y + assert expr.function.dummy_eq(Lambda(x, exp(x))) + assert expr.dummy_eq((X*Y).applyfunc(exp)) + assert expr.func(*expr.args) == expr + + assert isinstance(X*expr, MatMul) + assert (X*expr).shape == (3, 3) + Z = MatrixSymbol("Z", 2, 3) + assert (Z*expr).shape == (2, 3) + + expr = ElementwiseApplyFunction(exp, Z.T)*ElementwiseApplyFunction(exp, Z) + assert expr.shape == (3, 3) + expr = ElementwiseApplyFunction(exp, Z)*ElementwiseApplyFunction(exp, Z.T) + assert expr.shape == (2, 2) + + M = Matrix([[x, y], [z, t]]) + expr = ElementwiseApplyFunction(sin, M) + assert isinstance(expr, ElementwiseApplyFunction) + assert expr.function.dummy_eq(Lambda(x, sin(x))) + assert expr.expr == M + assert expr.doit() == M.applyfunc(sin) + assert expr.doit() == Matrix([[sin(x), sin(y)], [sin(z), sin(t)]]) + assert expr.func(*expr.args) == expr + + expr = ElementwiseApplyFunction(double, Xk) + assert expr.doit() == expr + assert expr.subs(k, 2).shape == (2, 2) + assert (expr*expr).shape == (k, k) + M = MatrixSymbol("M", k, t) + expr2 = M.T*expr*M + assert isinstance(expr2, MatMul) + assert expr2.args[1] == expr + assert expr2.shape == (t, t) + expr3 = expr*M + assert expr3.shape == (k, t) + + expr1 = ElementwiseApplyFunction(lambda x: x+1, Xk) + expr2 = ElementwiseApplyFunction(lambda x: x, Xk) + assert expr1 != expr2 + + +def test_applyfunc_entry(): + + af = X.applyfunc(sin) + assert af[0, 0] == sin(X[0, 0]) + + af = Xd.applyfunc(sin) + assert af[0, 0] == sin(X[0, 0]) + + +def test_applyfunc_as_explicit(): + + af = X.applyfunc(sin) + assert af.as_explicit() == Matrix([ + [sin(X[0, 0]), sin(X[0, 1]), sin(X[0, 2])], + [sin(X[1, 0]), sin(X[1, 1]), sin(X[1, 2])], + [sin(X[2, 0]), sin(X[2, 1]), sin(X[2, 2])], + ]) + + +def test_applyfunc_transpose(): + + af = Xk.applyfunc(sin) + assert af.T.dummy_eq(Xk.T.applyfunc(sin)) + + +def test_applyfunc_shape_11_matrices(): + M = MatrixSymbol("M", 1, 1) + + double = Lambda(x, x*2) + + expr = M.applyfunc(sin) + assert isinstance(expr, ElementwiseApplyFunction) + + expr = M.applyfunc(double) + assert isinstance(expr, MatMul) + assert expr == 2*M diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_blockmatrix.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_blockmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..1d4893cd9a4b3e47dd8e84db33031f7f6f3201fd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_blockmatrix.py @@ -0,0 +1,469 @@ +from sympy.matrices.expressions.trace import Trace +from sympy.testing.pytest import raises, slow +from sympy.matrices.expressions.blockmatrix import ( + block_collapse, bc_matmul, bc_block_plus_ident, BlockDiagMatrix, + BlockMatrix, bc_dist, bc_matadd, bc_transpose, bc_inverse, + blockcut, reblock_2x2, deblock) +from sympy.matrices.expressions import ( + MatrixSymbol, Identity, trace, det, ZeroMatrix, OneMatrix) +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matpow import MatPow +from sympy.matrices.expressions.transpose import Transpose +from sympy.matrices.exceptions import NonInvertibleMatrixError +from sympy.matrices import ( + Matrix, ImmutableMatrix, ImmutableSparseMatrix, zeros) +from sympy.core import Tuple, Expr, S, Function +from sympy.core.symbol import Symbol, symbols +from sympy.functions import transpose, im, re + +i, j, k, l, m, n, p = symbols('i:n, p', integer=True) +A = MatrixSymbol('A', n, n) +B = MatrixSymbol('B', n, n) +C = MatrixSymbol('C', n, n) +D = MatrixSymbol('D', n, n) +G = MatrixSymbol('G', n, n) +H = MatrixSymbol('H', n, n) +b1 = BlockMatrix([[G, H]]) +b2 = BlockMatrix([[G], [H]]) + +def test_bc_matmul(): + assert bc_matmul(H*b1*b2*G) == BlockMatrix([[(H*G*G + H*H*H)*G]]) + +def test_bc_matadd(): + assert bc_matadd(BlockMatrix([[G, H]]) + BlockMatrix([[H, H]])) == \ + BlockMatrix([[G+H, H+H]]) + +def test_bc_transpose(): + assert bc_transpose(Transpose(BlockMatrix([[A, B], [C, D]]))) == \ + BlockMatrix([[A.T, C.T], [B.T, D.T]]) + +def test_bc_dist_diag(): + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', m, m) + C = MatrixSymbol('C', l, l) + X = BlockDiagMatrix(A, B, C) + + assert bc_dist(X+X).equals(BlockDiagMatrix(2*A, 2*B, 2*C)) + +def test_block_plus_ident(): + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, m) + C = MatrixSymbol('C', m, n) + D = MatrixSymbol('D', m, m) + X = BlockMatrix([[A, B], [C, D]]) + Z = MatrixSymbol('Z', n + m, n + m) + assert bc_block_plus_ident(X + Identity(m + n) + Z) == \ + BlockDiagMatrix(Identity(n), Identity(m)) + X + Z + +def test_BlockMatrix(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', n, k) + C = MatrixSymbol('C', l, m) + D = MatrixSymbol('D', l, k) + M = MatrixSymbol('M', m + k, p) + N = MatrixSymbol('N', l + n, k + m) + X = BlockMatrix(Matrix([[A, B], [C, D]])) + + assert X.__class__(*X.args) == X + + # block_collapse does nothing on normal inputs + E = MatrixSymbol('E', n, m) + assert block_collapse(A + 2*E) == A + 2*E + F = MatrixSymbol('F', m, m) + assert block_collapse(E.T*A*F) == E.T*A*F + + assert X.shape == (l + n, k + m) + assert X.blockshape == (2, 2) + assert transpose(X) == BlockMatrix(Matrix([[A.T, C.T], [B.T, D.T]])) + assert transpose(X).shape == X.shape[::-1] + + # Test that BlockMatrices and MatrixSymbols can still mix + assert (X*M).is_MatMul + assert X._blockmul(M).is_MatMul + assert (X*M).shape == (n + l, p) + assert (X + N).is_MatAdd + assert X._blockadd(N).is_MatAdd + assert (X + N).shape == X.shape + + E = MatrixSymbol('E', m, 1) + F = MatrixSymbol('F', k, 1) + + Y = BlockMatrix(Matrix([[E], [F]])) + + assert (X*Y).shape == (l + n, 1) + assert block_collapse(X*Y).blocks[0, 0] == A*E + B*F + assert block_collapse(X*Y).blocks[1, 0] == C*E + D*F + + # block_collapse passes down into container objects, transposes, and inverse + assert block_collapse(transpose(X*Y)) == transpose(block_collapse(X*Y)) + assert block_collapse(Tuple(X*Y, 2*X)) == ( + block_collapse(X*Y), block_collapse(2*X)) + + # Make sure that MatrixSymbols will enter 1x1 BlockMatrix if it simplifies + Ab = BlockMatrix([[A]]) + Z = MatrixSymbol('Z', *A.shape) + assert block_collapse(Ab + Z) == A + Z + +def test_block_collapse_explicit_matrices(): + A = Matrix([[1, 2], [3, 4]]) + assert block_collapse(BlockMatrix([[A]])) == A + + A = ImmutableSparseMatrix([[1, 2], [3, 4]]) + assert block_collapse(BlockMatrix([[A]])) == A + +def test_issue_17624(): + a = MatrixSymbol("a", 2, 2) + z = ZeroMatrix(2, 2) + b = BlockMatrix([[a, z], [z, z]]) + assert block_collapse(b * b) == BlockMatrix([[a**2, z], [z, z]]) + assert block_collapse(b * b * b) == BlockMatrix([[a**3, z], [z, z]]) + +def test_issue_18618(): + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert A == Matrix(BlockDiagMatrix(A)) + +def test_BlockMatrix_trace(): + A, B, C, D = [MatrixSymbol(s, 3, 3) for s in 'ABCD'] + X = BlockMatrix([[A, B], [C, D]]) + assert trace(X) == trace(A) + trace(D) + assert trace(BlockMatrix([ZeroMatrix(n, n)])) == 0 + +def test_BlockMatrix_Determinant(): + A, B, C, D = [MatrixSymbol(s, 3, 3) for s in 'ABCD'] + X = BlockMatrix([[A, B], [C, D]]) + from sympy.assumptions.ask import Q + from sympy.assumptions.assume import assuming + with assuming(Q.invertible(A)): + assert det(X) == det(A) * det(X.schur('A')) + + assert isinstance(det(X), Expr) + assert det(BlockMatrix([A])) == det(A) + assert det(BlockMatrix([ZeroMatrix(n, n)])) == 0 + +def test_squareBlockMatrix(): + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, m) + C = MatrixSymbol('C', m, n) + D = MatrixSymbol('D', m, m) + X = BlockMatrix([[A, B], [C, D]]) + Y = BlockMatrix([[A]]) + + assert X.is_square + + Q = X + Identity(m + n) + assert (block_collapse(Q) == + BlockMatrix([[A + Identity(n), B], [C, D + Identity(m)]])) + + assert (X + MatrixSymbol('Q', n + m, n + m)).is_MatAdd + assert (X * MatrixSymbol('Q', n + m, n + m)).is_MatMul + + assert block_collapse(Y.I) == A.I + + assert isinstance(X.inverse(), Inverse) + + assert not X.is_Identity + + Z = BlockMatrix([[Identity(n), B], [C, D]]) + assert not Z.is_Identity + + +def test_BlockMatrix_2x2_inverse_symbolic(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', n, k - m) + C = MatrixSymbol('C', k - n, m) + D = MatrixSymbol('D', k - n, k - m) + X = BlockMatrix([[A, B], [C, D]]) + assert X.is_square and X.shape == (k, k) + assert isinstance(block_collapse(X.I), Inverse) # Can't invert when none of the blocks is square + + # test code path where only A is invertible + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, m) + C = MatrixSymbol('C', m, n) + D = ZeroMatrix(m, m) + X = BlockMatrix([[A, B], [C, D]]) + assert block_collapse(X.inverse()) == BlockMatrix([ + [A.I + A.I * B * X.schur('A').I * C * A.I, -A.I * B * X.schur('A').I], + [-X.schur('A').I * C * A.I, X.schur('A').I], + ]) + + # test code path where only B is invertible + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', n, n) + C = ZeroMatrix(m, m) + D = MatrixSymbol('D', m, n) + X = BlockMatrix([[A, B], [C, D]]) + assert block_collapse(X.inverse()) == BlockMatrix([ + [-X.schur('B').I * D * B.I, X.schur('B').I], + [B.I + B.I * A * X.schur('B').I * D * B.I, -B.I * A * X.schur('B').I], + ]) + + # test code path where only C is invertible + A = MatrixSymbol('A', n, m) + B = ZeroMatrix(n, n) + C = MatrixSymbol('C', m, m) + D = MatrixSymbol('D', m, n) + X = BlockMatrix([[A, B], [C, D]]) + assert block_collapse(X.inverse()) == BlockMatrix([ + [-C.I * D * X.schur('C').I, C.I + C.I * D * X.schur('C').I * A * C.I], + [X.schur('C').I, -X.schur('C').I * A * C.I], + ]) + + # test code path where only D is invertible + A = ZeroMatrix(n, n) + B = MatrixSymbol('B', n, m) + C = MatrixSymbol('C', m, n) + D = MatrixSymbol('D', m, m) + X = BlockMatrix([[A, B], [C, D]]) + assert block_collapse(X.inverse()) == BlockMatrix([ + [X.schur('D').I, -X.schur('D').I * B * D.I], + [-D.I * C * X.schur('D').I, D.I + D.I * C * X.schur('D').I * B * D.I], + ]) + + +def test_BlockMatrix_2x2_inverse_numeric(): + """Test 2x2 block matrix inversion numerically for all 4 formulas""" + M = Matrix([[1, 2], [3, 4]]) + # rank deficient matrices that have full rank when two of them combined + D1 = Matrix([[1, 2], [2, 4]]) + D2 = Matrix([[1, 3], [3, 9]]) + D3 = Matrix([[1, 4], [4, 16]]) + assert D1.rank() == D2.rank() == D3.rank() == 1 + assert (D1 + D2).rank() == (D2 + D3).rank() == (D3 + D1).rank() == 2 + + # Only A is invertible + K = BlockMatrix([[M, D1], [D2, D3]]) + assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv() + # Only B is invertible + K = BlockMatrix([[D1, M], [D2, D3]]) + assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv() + # Only C is invertible + K = BlockMatrix([[D1, D2], [M, D3]]) + assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv() + # Only D is invertible + K = BlockMatrix([[D1, D2], [D3, M]]) + assert block_collapse(K.inv()).as_explicit() == K.as_explicit().inv() + + +@slow +def test_BlockMatrix_3x3_symbolic(): + # Only test one of these, instead of all permutations, because it's slow + rowblocksizes = (n, m, k) + colblocksizes = (m, k, n) + K = BlockMatrix([ + [MatrixSymbol('M%s%s' % (rows, cols), rows, cols) for cols in colblocksizes] + for rows in rowblocksizes + ]) + collapse = block_collapse(K.I) + assert isinstance(collapse, BlockMatrix) + + +def test_BlockDiagMatrix(): + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', m, m) + C = MatrixSymbol('C', l, l) + M = MatrixSymbol('M', n + m + l, n + m + l) + + X = BlockDiagMatrix(A, B, C) + Y = BlockDiagMatrix(A, 2*B, 3*C) + + assert X.blocks[1, 1] == B + assert X.shape == (n + m + l, n + m + l) + assert all(X.blocks[i, j].is_ZeroMatrix if i != j else X.blocks[i, j] in [A, B, C] + for i in range(3) for j in range(3)) + assert X.__class__(*X.args) == X + assert X.get_diag_blocks() == (A, B, C) + + assert isinstance(block_collapse(X.I * X), Identity) + + assert bc_matmul(X*X) == BlockDiagMatrix(A*A, B*B, C*C) + assert block_collapse(X*X) == BlockDiagMatrix(A*A, B*B, C*C) + #XXX: should be == ?? + assert block_collapse(X + X).equals(BlockDiagMatrix(2*A, 2*B, 2*C)) + assert block_collapse(X*Y) == BlockDiagMatrix(A*A, 2*B*B, 3*C*C) + assert block_collapse(X + Y) == BlockDiagMatrix(2*A, 3*B, 4*C) + + # Ensure that BlockDiagMatrices can still interact with normal MatrixExprs + assert (X*(2*M)).is_MatMul + assert (X + (2*M)).is_MatAdd + + assert (X._blockmul(M)).is_MatMul + assert (X._blockadd(M)).is_MatAdd + +def test_BlockDiagMatrix_nonsquare(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', k, l) + X = BlockDiagMatrix(A, B) + assert X.shape == (n + k, m + l) + assert X.shape == (n + k, m + l) + assert X.rowblocksizes == [n, k] + assert X.colblocksizes == [m, l] + C = MatrixSymbol('C', n, m) + D = MatrixSymbol('D', k, l) + Y = BlockDiagMatrix(C, D) + assert block_collapse(X + Y) == BlockDiagMatrix(A + C, B + D) + assert block_collapse(X * Y.T) == BlockDiagMatrix(A * C.T, B * D.T) + raises(NonInvertibleMatrixError, lambda: BlockDiagMatrix(A, C.T).inverse()) + +def test_BlockDiagMatrix_determinant(): + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', m, m) + assert det(BlockDiagMatrix()) == 1 + assert det(BlockDiagMatrix(A)) == det(A) + assert det(BlockDiagMatrix(A, B)) == det(A) * det(B) + + # non-square blocks + C = MatrixSymbol('C', m, n) + D = MatrixSymbol('D', n, m) + assert det(BlockDiagMatrix(C, D)) == 0 + +def test_BlockDiagMatrix_trace(): + assert trace(BlockDiagMatrix()) == 0 + assert trace(BlockDiagMatrix(ZeroMatrix(n, n))) == 0 + A = MatrixSymbol('A', n, n) + assert trace(BlockDiagMatrix(A)) == trace(A) + B = MatrixSymbol('B', m, m) + assert trace(BlockDiagMatrix(A, B)) == trace(A) + trace(B) + + # non-square blocks + C = MatrixSymbol('C', m, n) + D = MatrixSymbol('D', n, m) + assert isinstance(trace(BlockDiagMatrix(C, D)), Trace) + +def test_BlockDiagMatrix_transpose(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', k, l) + assert transpose(BlockDiagMatrix()) == BlockDiagMatrix() + assert transpose(BlockDiagMatrix(A)) == BlockDiagMatrix(A.T) + assert transpose(BlockDiagMatrix(A, B)) == BlockDiagMatrix(A.T, B.T) + +def test_issue_2460(): + bdm1 = BlockDiagMatrix(Matrix([i]), Matrix([j])) + bdm2 = BlockDiagMatrix(Matrix([k]), Matrix([l])) + assert block_collapse(bdm1 + bdm2) == BlockDiagMatrix(Matrix([i + k]), Matrix([j + l])) + +def test_blockcut(): + A = MatrixSymbol('A', n, m) + B = blockcut(A, (n/2, n/2), (m/2, m/2)) + assert B == BlockMatrix([[A[:n/2, :m/2], A[:n/2, m/2:]], + [A[n/2:, :m/2], A[n/2:, m/2:]]]) + + M = ImmutableMatrix(4, 4, range(16)) + B = blockcut(M, (2, 2), (2, 2)) + assert M == ImmutableMatrix(B) + + B = blockcut(M, (1, 3), (2, 2)) + assert ImmutableMatrix(B.blocks[0, 1]) == ImmutableMatrix([[2, 3]]) + +def test_reblock_2x2(): + B = BlockMatrix([[MatrixSymbol('A_%d%d'%(i,j), 2, 2) + for j in range(3)] + for i in range(3)]) + assert B.blocks.shape == (3, 3) + + BB = reblock_2x2(B) + assert BB.blocks.shape == (2, 2) + + assert B.shape == BB.shape + assert B.as_explicit() == BB.as_explicit() + +def test_deblock(): + B = BlockMatrix([[MatrixSymbol('A_%d%d'%(i,j), n, n) + for j in range(4)] + for i in range(4)]) + + assert deblock(reblock_2x2(B)) == B + +def test_block_collapse_type(): + bm1 = BlockDiagMatrix(ImmutableMatrix([1]), ImmutableMatrix([2])) + bm2 = BlockDiagMatrix(ImmutableMatrix([3]), ImmutableMatrix([4])) + + assert bm1.T.__class__ == BlockDiagMatrix + assert block_collapse(bm1 - bm2).__class__ == BlockDiagMatrix + assert block_collapse(Inverse(bm1)).__class__ == BlockDiagMatrix + assert block_collapse(Transpose(bm1)).__class__ == BlockDiagMatrix + assert bc_transpose(Transpose(bm1)).__class__ == BlockDiagMatrix + assert bc_inverse(Inverse(bm1)).__class__ == BlockDiagMatrix + +def test_invalid_block_matrix(): + raises(ValueError, lambda: BlockMatrix([ + [Identity(2), Identity(5)], + ])) + raises(ValueError, lambda: BlockMatrix([ + [Identity(n), Identity(m)], + ])) + raises(ValueError, lambda: BlockMatrix([ + [ZeroMatrix(n, n), ZeroMatrix(n, n)], + [ZeroMatrix(n, n - 1), ZeroMatrix(n, n + 1)], + ])) + raises(ValueError, lambda: BlockMatrix([ + [ZeroMatrix(n - 1, n), ZeroMatrix(n, n)], + [ZeroMatrix(n + 1, n), ZeroMatrix(n, n)], + ])) + +def test_block_lu_decomposition(): + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, m) + C = MatrixSymbol('C', m, n) + D = MatrixSymbol('D', m, m) + X = BlockMatrix([[A, B], [C, D]]) + + #LDU decomposition + L, D, U = X.LDUdecomposition() + assert block_collapse(L*D*U) == X + + #UDL decomposition + U, D, L = X.UDLdecomposition() + assert block_collapse(U*D*L) == X + + #LU decomposition + L, U = X.LUdecomposition() + assert block_collapse(L*U) == X + +def test_issue_21866(): + n = 10 + I = Identity(n) + O = ZeroMatrix(n, n) + A = BlockMatrix([[ I, O, O, O ], + [ O, I, O, O ], + [ O, O, I, O ], + [ I, O, O, I ]]) + Ainv = block_collapse(A.inv()) + AinvT = BlockMatrix([[ I, O, O, O ], + [ O, I, O, O ], + [ O, O, I, O ], + [ -I, O, O, I ]]) + assert Ainv == AinvT + + +def test_adjoint_and_special_matrices(): + A = Identity(3) + B = OneMatrix(3, 2) + C = ZeroMatrix(2, 3) + D = Identity(2) + X = BlockMatrix([[A, B], [C, D]]) + X2 = BlockMatrix([[A, S.ImaginaryUnit*B], [C, D]]) + assert X.adjoint() == BlockMatrix([[A, ZeroMatrix(3, 2)], [OneMatrix(2, 3), D]]) + assert re(X) == X + assert X2.adjoint() == BlockMatrix([[A, ZeroMatrix(3, 2)], [-S.ImaginaryUnit*OneMatrix(2, 3), D]]) + assert im(X2) == BlockMatrix([[ZeroMatrix(3, 3), OneMatrix(3, 2)], [ZeroMatrix(2, 3), ZeroMatrix(2, 2)]]) + + +def test_block_matrix_derivative(): + x = symbols('x') + A = Matrix(3, 3, [Function(f'a{i}')(x) for i in range(9)]) + bc = BlockMatrix([[A[:2, :2], A[:2, 2]], [A[2, :2], A[2:, 2]]]) + assert Matrix(bc.diff(x)) - A.diff(x) == zeros(3, 3) + + +def test_transpose_inverse_commute(): + n = Symbol('n') + I = Identity(n) + Z = ZeroMatrix(n, n) + A = BlockMatrix([[I, Z], [Z, I]]) + + assert block_collapse(A.transpose().inverse()) == A + assert block_collapse(A.inverse().transpose()) == A + + assert block_collapse(MatPow(A.transpose(), -2)) == MatPow(A, -2) + assert block_collapse(MatPow(A, -2).transpose()) == MatPow(A, -2) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_companion.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_companion.py new file mode 100644 index 0000000000000000000000000000000000000000..edc592c29098eddce0c6352806aa73d5d889e999 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_companion.py @@ -0,0 +1,48 @@ +from sympy.core.expr import unchanged +from sympy.core.symbol import Symbol, symbols +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.matrices.expressions.companion import CompanionMatrix +from sympy.polys.polytools import Poly +from sympy.testing.pytest import raises + + +def test_creation(): + x = Symbol('x') + y = Symbol('y') + raises(ValueError, lambda: CompanionMatrix(1)) + raises(ValueError, lambda: CompanionMatrix(Poly([1], x))) + raises(ValueError, lambda: CompanionMatrix(Poly([2, 1], x))) + raises(ValueError, lambda: CompanionMatrix(Poly(x*y, [x, y]))) + assert unchanged(CompanionMatrix, Poly([1, 2, 3], x)) + + +def test_shape(): + c0, c1, c2 = symbols('c0:3') + x = Symbol('x') + assert CompanionMatrix(Poly([1, c0], x)).shape == (1, 1) + assert CompanionMatrix(Poly([1, c1, c0], x)).shape == (2, 2) + assert CompanionMatrix(Poly([1, c2, c1, c0], x)).shape == (3, 3) + + +def test_entry(): + c0, c1, c2 = symbols('c0:3') + x = Symbol('x') + A = CompanionMatrix(Poly([1, c2, c1, c0], x)) + assert A[0, 0] == 0 + assert A[1, 0] == 1 + assert A[1, 1] == 0 + assert A[2, 1] == 1 + assert A[0, 2] == -c0 + assert A[1, 2] == -c1 + assert A[2, 2] == -c2 + + +def test_as_explicit(): + c0, c1, c2 = symbols('c0:3') + x = Symbol('x') + assert CompanionMatrix(Poly([1, c0], x)).as_explicit() == \ + ImmutableDenseMatrix([-c0]) + assert CompanionMatrix(Poly([1, c1, c0], x)).as_explicit() == \ + ImmutableDenseMatrix([[0, -c0], [1, -c1]]) + assert CompanionMatrix(Poly([1, c2, c1, c0], x)).as_explicit() == \ + ImmutableDenseMatrix([[0, 0, -c0], [1, 0, -c1], [0, 1, -c2]]) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_derivatives.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_derivatives.py new file mode 100644 index 0000000000000000000000000000000000000000..77484c994dda62eea9771a76afd8b3caeadacb93 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_derivatives.py @@ -0,0 +1,477 @@ +""" +Some examples have been taken from: + +http://www.math.uwaterloo.ca/~hwolkowi//matrixcookbook.pdf +""" +from sympy import KroneckerProduct +from sympy.combinatorics import Permutation +from sympy.concrete.summations import Sum +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.expressions.determinant import Determinant +from sympy.matrices.expressions.diagonal import DiagMatrix +from sympy.matrices.expressions.hadamard import (HadamardPower, HadamardProduct, hadamard_product) +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import OneMatrix +from sympy.matrices.expressions.trace import Trace +from sympy.matrices.expressions.matadd import MatAdd +from sympy.matrices.expressions.matmul import MatMul +from sympy.matrices.expressions.special import (Identity, ZeroMatrix) +from sympy.tensor.array.array_derivatives import ArrayDerivative +from sympy.matrices.expressions import hadamard_power +from sympy.tensor.array.expressions.array_expressions import ArrayAdd, ArrayTensorProduct, PermuteDims + +i, j, k = symbols("i j k") +m, n = symbols("m n") + +X = MatrixSymbol("X", k, k) +x = MatrixSymbol("x", k, 1) +y = MatrixSymbol("y", k, 1) + +A = MatrixSymbol("A", k, k) +B = MatrixSymbol("B", k, k) +C = MatrixSymbol("C", k, k) +D = MatrixSymbol("D", k, k) + +a = MatrixSymbol("a", k, 1) +b = MatrixSymbol("b", k, 1) +c = MatrixSymbol("c", k, 1) +d = MatrixSymbol("d", k, 1) + + +KDelta = lambda i, j: KroneckerDelta(i, j, (0, k-1)) + + +def _check_derivative_with_explicit_matrix(expr, x, diffexpr, dim=2): + # TODO: this is commented because it slows down the tests. + return + + expr = expr.xreplace({k: dim}) + x = x.xreplace({k: dim}) + diffexpr = diffexpr.xreplace({k: dim}) + + expr = expr.as_explicit() + x = x.as_explicit() + diffexpr = diffexpr.as_explicit() + + assert expr.diff(x).reshape(*diffexpr.shape).tomatrix() == diffexpr + + +def test_matrix_derivative_by_scalar(): + assert A.diff(i) == ZeroMatrix(k, k) + assert (A*(X + B)*c).diff(i) == ZeroMatrix(k, 1) + assert x.diff(i) == ZeroMatrix(k, 1) + assert (x.T*y).diff(i) == ZeroMatrix(1, 1) + assert (x*x.T).diff(i) == ZeroMatrix(k, k) + assert (x + y).diff(i) == ZeroMatrix(k, 1) + assert hadamard_power(x, 2).diff(i) == ZeroMatrix(k, 1) + assert hadamard_power(x, i).diff(i).dummy_eq( + HadamardProduct(x.applyfunc(log), HadamardPower(x, i))) + assert hadamard_product(x, y).diff(i) == ZeroMatrix(k, 1) + assert hadamard_product(i*OneMatrix(k, 1), x, y).diff(i) == hadamard_product(x, y) + assert (i*x).diff(i) == x + assert (sin(i)*A*B*x).diff(i) == cos(i)*A*B*x + assert x.applyfunc(sin).diff(i) == ZeroMatrix(k, 1) + assert Trace(i**2*X).diff(i) == 2*i*Trace(X) + + mu = symbols("mu") + expr = (2*mu*x) + assert expr.diff(x) == 2*mu*Identity(k) + + +def test_one_matrix(): + assert MatMul(x.T, OneMatrix(k, 1)).diff(x) == OneMatrix(k, 1) + + +def test_matrix_derivative_non_matrix_result(): + # This is a 4-dimensional array: + I = Identity(k) + AdA = PermuteDims(ArrayTensorProduct(I, I), Permutation(3)(1, 2)) + assert A.diff(A) == AdA + assert A.T.diff(A) == PermuteDims(ArrayTensorProduct(I, I), Permutation(3)(1, 2, 3)) + assert (2*A).diff(A) == PermuteDims(ArrayTensorProduct(2*I, I), Permutation(3)(1, 2)) + assert MatAdd(A, A).diff(A) == ArrayAdd(AdA, AdA) + assert (A + B).diff(A) == AdA + + +def test_matrix_derivative_trivial_cases(): + # Cookbook example 33: + # TODO: find a way to represent a four-dimensional zero-array: + assert X.diff(A) == ArrayDerivative(X, A) + + +def test_matrix_derivative_with_inverse(): + + # Cookbook example 61: + expr = a.T*Inverse(X)*b + assert expr.diff(X) == -Inverse(X).T*a*b.T*Inverse(X).T + + # Cookbook example 62: + expr = Determinant(Inverse(X)) + # Not implemented yet: + # assert expr.diff(X) == -Determinant(X.inv())*(X.inv()).T + + # Cookbook example 63: + expr = Trace(A*Inverse(X)*B) + assert expr.diff(X) == -(X**(-1)*B*A*X**(-1)).T + + # Cookbook example 64: + expr = Trace(Inverse(X + A)) + assert expr.diff(X) == -(Inverse(X + A)).T**2 + + +def test_matrix_derivative_vectors_and_scalars(): + + assert x.diff(x) == Identity(k) + assert x[i, 0].diff(x[m, 0]).doit() == KDelta(m, i) + + assert x.T.diff(x) == Identity(k) + + # Cookbook example 69: + expr = x.T*a + assert expr.diff(x) == a + assert expr[0, 0].diff(x[m, 0]).doit() == a[m, 0] + expr = a.T*x + assert expr.diff(x) == a + + # Cookbook example 70: + expr = a.T*X*b + assert expr.diff(X) == a*b.T + + # Cookbook example 71: + expr = a.T*X.T*b + assert expr.diff(X) == b*a.T + + # Cookbook example 72: + expr = a.T*X*a + assert expr.diff(X) == a*a.T + expr = a.T*X.T*a + assert expr.diff(X) == a*a.T + + # Cookbook example 77: + expr = b.T*X.T*X*c + assert expr.diff(X) == X*b*c.T + X*c*b.T + + # Cookbook example 78: + expr = (B*x + b).T*C*(D*x + d) + assert expr.diff(x) == B.T*C*(D*x + d) + D.T*C.T*(B*x + b) + + # Cookbook example 81: + expr = x.T*B*x + assert expr.diff(x) == B*x + B.T*x + + # Cookbook example 82: + expr = b.T*X.T*D*X*c + assert expr.diff(X) == D.T*X*b*c.T + D*X*c*b.T + + # Cookbook example 83: + expr = (X*b + c).T*D*(X*b + c) + assert expr.diff(X) == D*(X*b + c)*b.T + D.T*(X*b + c)*b.T + assert str(expr[0, 0].diff(X[m, n]).doit()) == \ + 'b[n, 0]*Sum((c[_i_1, 0] + Sum(X[_i_1, _i_3]*b[_i_3, 0], (_i_3, 0, k - 1)))*D[_i_1, m], (_i_1, 0, k - 1)) + Sum((c[_i_2, 0] + Sum(X[_i_2, _i_4]*b[_i_4, 0], (_i_4, 0, k - 1)))*D[m, _i_2]*b[n, 0], (_i_2, 0, k - 1))' + + # See https://github.com/sympy/sympy/issues/16504#issuecomment-1018339957 + expr = x*x.T*x + I = Identity(k) + assert expr.diff(x) == KroneckerProduct(I, x.T*x) + 2*x*x.T + + +def test_matrix_derivatives_of_traces(): + + expr = Trace(A)*A + I = Identity(k) + assert expr.diff(A) == ArrayAdd(ArrayTensorProduct(I, A), PermuteDims(ArrayTensorProduct(Trace(A)*I, I), Permutation(3)(1, 2))) + assert expr[i, j].diff(A[m, n]).doit() == ( + KDelta(i, m)*KDelta(j, n)*Trace(A) + + KDelta(m, n)*A[i, j] + ) + + ## First order: + + # Cookbook example 99: + expr = Trace(X) + assert expr.diff(X) == Identity(k) + assert expr.rewrite(Sum).diff(X[m, n]).doit() == KDelta(m, n) + + # Cookbook example 100: + expr = Trace(X*A) + assert expr.diff(X) == A.T + assert expr.rewrite(Sum).diff(X[m, n]).doit() == A[n, m] + + # Cookbook example 101: + expr = Trace(A*X*B) + assert expr.diff(X) == A.T*B.T + assert expr.rewrite(Sum).diff(X[m, n]).doit().dummy_eq((A.T*B.T)[m, n]) + + # Cookbook example 102: + expr = Trace(A*X.T*B) + assert expr.diff(X) == B*A + + # Cookbook example 103: + expr = Trace(X.T*A) + assert expr.diff(X) == A + + # Cookbook example 104: + expr = Trace(A*X.T) + assert expr.diff(X) == A + + # Cookbook example 105: + # TODO: TensorProduct is not supported + #expr = Trace(TensorProduct(A, X)) + #assert expr.diff(X) == Trace(A)*Identity(k) + + ## Second order: + + # Cookbook example 106: + expr = Trace(X**2) + assert expr.diff(X) == 2*X.T + + # Cookbook example 107: + expr = Trace(X**2*B) + assert expr.diff(X) == (X*B + B*X).T + expr = Trace(MatMul(X, X, B)) + assert expr.diff(X) == (X*B + B*X).T + + # Cookbook example 108: + expr = Trace(X.T*B*X) + assert expr.diff(X) == B*X + B.T*X + + # Cookbook example 109: + expr = Trace(B*X*X.T) + assert expr.diff(X) == B*X + B.T*X + + # Cookbook example 110: + expr = Trace(X*X.T*B) + assert expr.diff(X) == B*X + B.T*X + + # Cookbook example 111: + expr = Trace(X*B*X.T) + assert expr.diff(X) == X*B.T + X*B + + # Cookbook example 112: + expr = Trace(B*X.T*X) + assert expr.diff(X) == X*B.T + X*B + + # Cookbook example 113: + expr = Trace(X.T*X*B) + assert expr.diff(X) == X*B.T + X*B + + # Cookbook example 114: + expr = Trace(A*X*B*X) + assert expr.diff(X) == A.T*X.T*B.T + B.T*X.T*A.T + + # Cookbook example 115: + expr = Trace(X.T*X) + assert expr.diff(X) == 2*X + expr = Trace(X*X.T) + assert expr.diff(X) == 2*X + + # Cookbook example 116: + expr = Trace(B.T*X.T*C*X*B) + assert expr.diff(X) == C.T*X*B*B.T + C*X*B*B.T + + # Cookbook example 117: + expr = Trace(X.T*B*X*C) + assert expr.diff(X) == B*X*C + B.T*X*C.T + + # Cookbook example 118: + expr = Trace(A*X*B*X.T*C) + assert expr.diff(X) == A.T*C.T*X*B.T + C*A*X*B + + # Cookbook example 119: + expr = Trace((A*X*B + C)*(A*X*B + C).T) + assert expr.diff(X) == 2*A.T*(A*X*B + C)*B.T + + # Cookbook example 120: + # TODO: no support for TensorProduct. + # expr = Trace(TensorProduct(X, X)) + # expr = Trace(X)*Trace(X) + # expr.diff(X) == 2*Trace(X)*Identity(k) + + # Higher Order + + # Cookbook example 121: + expr = Trace(X**k) + #assert expr.diff(X) == k*(X**(k-1)).T + + # Cookbook example 122: + expr = Trace(A*X**k) + #assert expr.diff(X) == # Needs indices + + # Cookbook example 123: + expr = Trace(B.T*X.T*C*X*X.T*C*X*B) + assert expr.diff(X) == C*X*X.T*C*X*B*B.T + C.T*X*B*B.T*X.T*C.T*X + C*X*B*B.T*X.T*C*X + C.T*X*X.T*C.T*X*B*B.T + + # Other + + # Cookbook example 124: + expr = Trace(A*X**(-1)*B) + assert expr.diff(X) == -Inverse(X).T*A.T*B.T*Inverse(X).T + + # Cookbook example 125: + expr = Trace(Inverse(X.T*C*X)*A) + # Warning: result in the cookbook is equivalent if B and C are symmetric: + assert expr.diff(X) == - X.inv().T*A.T*X.inv()*C.inv().T*X.inv().T - X.inv().T*A*X.inv()*C.inv()*X.inv().T + + # Cookbook example 126: + expr = Trace((X.T*C*X).inv()*(X.T*B*X)) + assert expr.diff(X) == -2*C*X*(X.T*C*X).inv()*X.T*B*X*(X.T*C*X).inv() + 2*B*X*(X.T*C*X).inv() + + # Cookbook example 127: + expr = Trace((A + X.T*C*X).inv()*(X.T*B*X)) + # Warning: result in the cookbook is equivalent if B and C are symmetric: + assert expr.diff(X) == B*X*Inverse(A + X.T*C*X) - C*X*Inverse(A + X.T*C*X)*X.T*B*X*Inverse(A + X.T*C*X) - C.T*X*Inverse(A.T + (C*X).T*X)*X.T*B.T*X*Inverse(A.T + (C*X).T*X) + B.T*X*Inverse(A.T + (C*X).T*X) + + +def test_derivatives_of_complicated_matrix_expr(): + expr = a.T*(A*X*(X.T*B + X*A) + B.T*X.T*(a*b.T*(X*D*X.T + X*(X.T*B + A*X)*D*B - X.T*C.T*A)*B + B*(X*D.T + B*A*X*A.T - 3*X*D))*B + 42*X*B*X.T*A.T*(X + X.T))*b + result = (B*(B*A*X*A.T - 3*X*D + X*D.T) + a*b.T*(X*(A*X + X.T*B)*D*B + X*D*X.T - X.T*C.T*A)*B)*B*b*a.T*B.T + B**2*b*a.T*B.T*X.T*a*b.T*X*D + 42*A*X*B.T*X.T*a*b.T + B*D*B**3*b*a.T*B.T*X.T*a*b.T*X + B*b*a.T*A*X + a*b.T*(42*X + 42*X.T)*A*X*B.T + b*a.T*X*B*a*b.T*B.T**2*X*D.T + b*a.T*X*B*a*b.T*B.T**3*D.T*(B.T*X + X.T*A.T) + 42*b*a.T*X*B*X.T*A.T + A.T*(42*X + 42*X.T)*b*a.T*X*B + A.T*B.T**2*X*B*a*b.T*B.T*A + A.T*a*b.T*(A.T*X.T + B.T*X) + A.T*X.T*b*a.T*X*B*a*b.T*B.T**3*D.T + B.T*X*B*a*b.T*B.T*D - 3*B.T*X*B*a*b.T*B.T*D.T - C.T*A*B**2*b*a.T*B.T*X.T*a*b.T + X.T*A.T*a*b.T*A.T + assert expr.diff(X) == result + + +def test_mixed_deriv_mixed_expressions(): + + expr = 3*Trace(A) + assert expr.diff(A) == 3*Identity(k) + + expr = k + deriv = expr.diff(A) + assert isinstance(deriv, ZeroMatrix) + assert deriv == ZeroMatrix(k, k) + + expr = Trace(A)**2 + assert expr.diff(A) == (2*Trace(A))*Identity(k) + + expr = Trace(A)*A + I = Identity(k) + assert expr.diff(A) == ArrayAdd(ArrayTensorProduct(I, A), PermuteDims(ArrayTensorProduct(Trace(A)*I, I), Permutation(3)(1, 2))) + + expr = Trace(Trace(A)*A) + assert expr.diff(A) == (2*Trace(A))*Identity(k) + + expr = Trace(Trace(Trace(A)*A)*A) + assert expr.diff(A) == (3*Trace(A)**2)*Identity(k) + + +def test_derivatives_matrix_norms(): + + expr = x.T*y + assert expr.diff(x) == y + assert expr[0, 0].diff(x[m, 0]).doit() == y[m, 0] + + expr = (x.T*y)**S.Half + assert expr.diff(x) == y/(2*sqrt(x.T*y)) + + expr = (x.T*x)**S.Half + assert expr.diff(x) == x*(x.T*x)**Rational(-1, 2) + + expr = (c.T*a*x.T*b)**S.Half + assert expr.diff(x) == b*a.T*c/sqrt(c.T*a*x.T*b)/2 + + expr = (c.T*a*x.T*b)**Rational(1, 3) + assert expr.diff(x) == b*a.T*c*(c.T*a*x.T*b)**Rational(-2, 3)/3 + + expr = (a.T*X*b)**S.Half + assert expr.diff(X) == a/(2*sqrt(a.T*X*b))*b.T + + expr = d.T*x*(a.T*X*b)**S.Half*y.T*c + assert expr.diff(X) == a/(2*sqrt(a.T*X*b))*x.T*d*y.T*c*b.T + + +def test_derivatives_elementwise_applyfunc(): + + expr = x.applyfunc(tan) + assert expr.diff(x).dummy_eq( + DiagMatrix(x.applyfunc(lambda x: tan(x)**2 + 1))) + assert expr[i, 0].diff(x[m, 0]).doit() == (tan(x[i, 0])**2 + 1)*KDelta(i, m) + _check_derivative_with_explicit_matrix(expr, x, expr.diff(x)) + + expr = (i**2*x).applyfunc(sin) + assert expr.diff(i).dummy_eq( + HadamardProduct((2*i)*x, (i**2*x).applyfunc(cos))) + assert expr[i, 0].diff(i).doit() == 2*i*x[i, 0]*cos(i**2*x[i, 0]) + _check_derivative_with_explicit_matrix(expr, i, expr.diff(i)) + + expr = (log(i)*A*B).applyfunc(sin) + assert expr.diff(i).dummy_eq( + HadamardProduct(A*B/i, (log(i)*A*B).applyfunc(cos))) + _check_derivative_with_explicit_matrix(expr, i, expr.diff(i)) + + expr = A*x.applyfunc(exp) + # TODO: restore this result (currently returning the transpose): + # assert expr.diff(x).dummy_eq(DiagMatrix(x.applyfunc(exp))*A.T) + _check_derivative_with_explicit_matrix(expr, x, expr.diff(x)) + + expr = x.T*A*x + k*y.applyfunc(sin).T*x + assert expr.diff(x).dummy_eq(A.T*x + A*x + k*y.applyfunc(sin)) + _check_derivative_with_explicit_matrix(expr, x, expr.diff(x)) + + expr = x.applyfunc(sin).T*y + # TODO: restore (currently returning the transpose): + # assert expr.diff(x).dummy_eq(DiagMatrix(x.applyfunc(cos))*y) + _check_derivative_with_explicit_matrix(expr, x, expr.diff(x)) + + expr = (a.T * X * b).applyfunc(sin) + assert expr.diff(X).dummy_eq(a*(a.T*X*b).applyfunc(cos)*b.T) + _check_derivative_with_explicit_matrix(expr, X, expr.diff(X)) + + expr = a.T * X.applyfunc(sin) * b + assert expr.diff(X).dummy_eq( + DiagMatrix(a)*X.applyfunc(cos)*DiagMatrix(b)) + _check_derivative_with_explicit_matrix(expr, X, expr.diff(X)) + + expr = a.T * (A*X*B).applyfunc(sin) * b + assert expr.diff(X).dummy_eq( + A.T*DiagMatrix(a)*(A*X*B).applyfunc(cos)*DiagMatrix(b)*B.T) + _check_derivative_with_explicit_matrix(expr, X, expr.diff(X)) + + expr = a.T * (A*X*b).applyfunc(sin) * b.T + # TODO: not implemented + #assert expr.diff(X) == ... + #_check_derivative_with_explicit_matrix(expr, X, expr.diff(X)) + + expr = a.T*A*X.applyfunc(sin)*B*b + assert expr.diff(X).dummy_eq( + HadamardProduct(A.T * a * b.T * B.T, X.applyfunc(cos))) + + expr = a.T * (A*X.applyfunc(sin)*B).applyfunc(log) * b + # TODO: wrong + # assert expr.diff(X) == A.T*DiagMatrix(a)*(A*X.applyfunc(sin)*B).applyfunc(Lambda(k, 1/k))*DiagMatrix(b)*B.T + + expr = a.T * (X.applyfunc(sin)).applyfunc(log) * b + # TODO: wrong + # assert expr.diff(X) == DiagMatrix(a)*X.applyfunc(sin).applyfunc(Lambda(k, 1/k))*DiagMatrix(b) + + +def test_derivatives_of_hadamard_expressions(): + + # Hadamard Product + + expr = hadamard_product(a, x, b) + assert expr.diff(x) == DiagMatrix(hadamard_product(b, a)) + + expr = a.T*hadamard_product(A, X, B)*b + assert expr.diff(X) == HadamardProduct(a*b.T, A, B) + + # Hadamard Power + + expr = hadamard_power(x, 2) + assert expr.diff(x).doit() == 2*DiagMatrix(x) + + expr = hadamard_power(x.T, 2) + assert expr.diff(x).doit() == 2*DiagMatrix(x) + + expr = hadamard_power(x, S.Half) + assert expr.diff(x) == S.Half*DiagMatrix(hadamard_power(x, Rational(-1, 2))) + + expr = hadamard_power(a.T*X*b, 2) + assert expr.diff(X) == 2*a*a.T*X*b*b.T + + expr = hadamard_power(a.T*X*b, S.Half) + assert expr.diff(X) == a/(2*sqrt(a.T*X*b))*b.T diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_determinant.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_determinant.py new file mode 100644 index 0000000000000000000000000000000000000000..d1a66c728f076f8c769d2519ee47c8a9cc90a90e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_determinant.py @@ -0,0 +1,65 @@ +from sympy.core import S, symbols +from sympy.matrices import eye, ones, Matrix, ShapeError +from sympy.matrices.expressions import ( + Identity, MatrixExpr, MatrixSymbol, Determinant, + det, per, ZeroMatrix, Transpose, + Permanent, MatMul +) +from sympy.matrices.expressions.special import OneMatrix +from sympy.testing.pytest import raises +from sympy.assumptions.ask import Q +from sympy.assumptions.refine import refine + +n = symbols('n', integer=True) +A = MatrixSymbol('A', n, n) +B = MatrixSymbol('B', n, n) +C = MatrixSymbol('C', 3, 4) + + +def test_det(): + assert isinstance(Determinant(A), Determinant) + assert not isinstance(Determinant(A), MatrixExpr) + raises(ShapeError, lambda: Determinant(C)) + assert det(eye(3)) == 1 + assert det(Matrix(3, 3, [1, 3, 2, 4, 1, 3, 2, 5, 2])) == 17 + _ = A / det(A) # Make sure this is possible + + raises(TypeError, lambda: Determinant(S.One)) + + assert Determinant(A).arg is A + + +def test_eval_determinant(): + assert det(Identity(n)) == 1 + assert det(ZeroMatrix(n, n)) == 0 + assert det(OneMatrix(n, n)) == Determinant(OneMatrix(n, n)) + assert det(OneMatrix(1, 1)) == 1 + assert det(OneMatrix(2, 2)) == 0 + assert det(Transpose(A)) == det(A) + assert Determinant(MatMul(eye(2), eye(2))).doit(deep=True) == 1 + + +def test_refine(): + assert refine(det(A), Q.orthogonal(A)) == 1 + assert refine(det(A), Q.singular(A)) == 0 + assert refine(det(A), Q.unit_triangular(A)) == 1 + assert refine(det(A), Q.normal(A)) == det(A) + + +def test_commutative(): + det_a = Determinant(A) + det_b = Determinant(B) + assert det_a.is_commutative + assert det_b.is_commutative + assert det_a * det_b == det_b * det_a + + +def test_permanent(): + assert isinstance(Permanent(A), Permanent) + assert not isinstance(Permanent(A), MatrixExpr) + assert isinstance(Permanent(C), Permanent) + assert Permanent(ones(3, 3)).doit() == 6 + _ = C / per(C) + assert per(Matrix(3, 3, [1, 3, 2, 4, 1, 3, 2, 5, 2])) == 103 + raises(TypeError, lambda: Permanent(S.One)) + assert Permanent(A).arg is A diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_diagonal.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_diagonal.py new file mode 100644 index 0000000000000000000000000000000000000000..3e4f7ea4c178121c33eeb26c09675403d274c1e8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_diagonal.py @@ -0,0 +1,156 @@ +from sympy.matrices.expressions import MatrixSymbol +from sympy.matrices.expressions.diagonal import DiagonalMatrix, DiagonalOf, DiagMatrix, diagonalize_vector +from sympy.assumptions.ask import (Q, ask) +from sympy.core.symbol import Symbol +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matmul import MatMul +from sympy.matrices.expressions.special import Identity +from sympy.testing.pytest import raises + + +n = Symbol('n') +m = Symbol('m') + + +def test_DiagonalMatrix(): + x = MatrixSymbol('x', n, m) + D = DiagonalMatrix(x) + assert D.diagonal_length is None + assert D.shape == (n, m) + + x = MatrixSymbol('x', n, n) + D = DiagonalMatrix(x) + assert D.diagonal_length == n + assert D.shape == (n, n) + assert D[1, 2] == 0 + assert D[1, 1] == x[1, 1] + i = Symbol('i') + j = Symbol('j') + x = MatrixSymbol('x', 3, 3) + ij = DiagonalMatrix(x)[i, j] + assert ij != 0 + assert ij.subs({i:0, j:0}) == x[0, 0] + assert ij.subs({i:0, j:1}) == 0 + assert ij.subs({i:1, j:1}) == x[1, 1] + assert ask(Q.diagonal(D)) # affirm that D is diagonal + + x = MatrixSymbol('x', n, 3) + D = DiagonalMatrix(x) + assert D.diagonal_length == 3 + assert D.shape == (n, 3) + assert D[2, m] == KroneckerDelta(2, m)*x[2, m] + assert D[3, m] == 0 + raises(IndexError, lambda: D[m, 3]) + + x = MatrixSymbol('x', 3, n) + D = DiagonalMatrix(x) + assert D.diagonal_length == 3 + assert D.shape == (3, n) + assert D[m, 2] == KroneckerDelta(m, 2)*x[m, 2] + assert D[m, 3] == 0 + raises(IndexError, lambda: D[3, m]) + + x = MatrixSymbol('x', n, m) + D = DiagonalMatrix(x) + assert D.diagonal_length is None + assert D.shape == (n, m) + assert D[m, 4] != 0 + + x = MatrixSymbol('x', 3, 4) + assert [DiagonalMatrix(x)[i] for i in range(12)] == [ + x[0, 0], 0, 0, 0, 0, x[1, 1], 0, 0, 0, 0, x[2, 2], 0] + + # shape is retained, issue 12427 + assert ( + DiagonalMatrix(MatrixSymbol('x', 3, 4))* + DiagonalMatrix(MatrixSymbol('x', 4, 2))).shape == (3, 2) + + +def test_DiagonalOf(): + x = MatrixSymbol('x', n, n) + d = DiagonalOf(x) + assert d.shape == (n, 1) + assert d.diagonal_length == n + assert d[2, 0] == d[2] == x[2, 2] + + x = MatrixSymbol('x', n, m) + d = DiagonalOf(x) + assert d.shape == (None, 1) + assert d.diagonal_length is None + assert d[2, 0] == d[2] == x[2, 2] + + d = DiagonalOf(MatrixSymbol('x', 4, 3)) + assert d.shape == (3, 1) + d = DiagonalOf(MatrixSymbol('x', n, 3)) + assert d.shape == (3, 1) + d = DiagonalOf(MatrixSymbol('x', 3, n)) + assert d.shape == (3, 1) + x = MatrixSymbol('x', n, m) + assert [DiagonalOf(x)[i] for i in range(4)] ==[ + x[0, 0], x[1, 1], x[2, 2], x[3, 3]] + + +def test_DiagMatrix(): + x = MatrixSymbol('x', n, 1) + d = DiagMatrix(x) + assert d.shape == (n, n) + assert d[0, 1] == 0 + assert d[0, 0] == x[0, 0] + + a = MatrixSymbol('a', 1, 1) + d = diagonalize_vector(a) + assert isinstance(d, MatrixSymbol) + assert a == d + assert diagonalize_vector(Identity(3)) == Identity(3) + assert DiagMatrix(Identity(3)).doit() == Identity(3) + assert isinstance(DiagMatrix(Identity(3)), DiagMatrix) + + # A diagonal matrix is equal to its transpose: + assert DiagMatrix(x).T == DiagMatrix(x) + assert diagonalize_vector(x.T) == DiagMatrix(x) + + dx = DiagMatrix(x) + assert dx[0, 0] == x[0, 0] + assert dx[1, 1] == x[1, 0] + assert dx[0, 1] == 0 + assert dx[0, m] == x[0, 0]*KroneckerDelta(0, m) + + z = MatrixSymbol('z', 1, n) + dz = DiagMatrix(z) + assert dz[0, 0] == z[0, 0] + assert dz[1, 1] == z[0, 1] + assert dz[0, 1] == 0 + assert dz[0, m] == z[0, m]*KroneckerDelta(0, m) + + v = MatrixSymbol('v', 3, 1) + dv = DiagMatrix(v) + assert dv.as_explicit() == Matrix([ + [v[0, 0], 0, 0], + [0, v[1, 0], 0], + [0, 0, v[2, 0]], + ]) + + v = MatrixSymbol('v', 1, 3) + dv = DiagMatrix(v) + assert dv.as_explicit() == Matrix([ + [v[0, 0], 0, 0], + [0, v[0, 1], 0], + [0, 0, v[0, 2]], + ]) + + dv = DiagMatrix(3*v) + assert dv.args == (3*v,) + assert dv.doit() == 3*DiagMatrix(v) + assert isinstance(dv.doit(), MatMul) + + a = MatrixSymbol("a", 3, 1).as_explicit() + expr = DiagMatrix(a) + result = Matrix([ + [a[0, 0], 0, 0], + [0, a[1, 0], 0], + [0, 0, a[2, 0]], + ]) + assert expr.doit() == result + expr = DiagMatrix(a.T) + assert expr.doit() == result diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_dotproduct.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_dotproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..abf8ab8e935cbd3039f25f83d3603ac444e5a7bb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_dotproduct.py @@ -0,0 +1,35 @@ +from sympy.core.expr import unchanged +from sympy.core.mul import Mul +from sympy.matrices import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.dotproduct import DotProduct +from sympy.testing.pytest import raises + + +A = Matrix(3, 1, [1, 2, 3]) +B = Matrix(3, 1, [1, 3, 5]) +C = Matrix(4, 1, [1, 2, 4, 5]) +D = Matrix(2, 2, [1, 2, 3, 4]) + +def test_docproduct(): + assert DotProduct(A, B).doit() == 22 + assert DotProduct(A.T, B).doit() == 22 + assert DotProduct(A, B.T).doit() == 22 + assert DotProduct(A.T, B.T).doit() == 22 + + raises(TypeError, lambda: DotProduct(1, A)) + raises(TypeError, lambda: DotProduct(A, 1)) + raises(TypeError, lambda: DotProduct(A, D)) + raises(TypeError, lambda: DotProduct(D, A)) + + raises(TypeError, lambda: DotProduct(B, C).doit()) + +def test_dotproduct_symbolic(): + A = MatrixSymbol('A', 3, 1) + B = MatrixSymbol('B', 3, 1) + + dot = DotProduct(A, B) + assert dot.is_scalar == True + assert unchanged(Mul, 2, dot) + # XXX Fix forced evaluation for arithmetics with matrix expressions + assert dot * A == (A[0, 0]*B[0, 0] + A[1, 0]*B[1, 0] + A[2, 0]*B[2, 0])*A diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_factorizations.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_factorizations.py new file mode 100644 index 0000000000000000000000000000000000000000..a0319acabbb7409dfa5c24ceca39e25ff0240618 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_factorizations.py @@ -0,0 +1,29 @@ +from sympy.matrices.expressions.factorizations import lu, LofCholesky, qr, svd +from sympy.assumptions.ask import (Q, ask) +from sympy.core.symbol import Symbol +from sympy.matrices.expressions.matexpr import MatrixSymbol + +n = Symbol('n') +X = MatrixSymbol('X', n, n) + +def test_LU(): + L, U = lu(X) + assert L.shape == U.shape == X.shape + assert ask(Q.lower_triangular(L)) + assert ask(Q.upper_triangular(U)) + +def test_Cholesky(): + LofCholesky(X) + +def test_QR(): + Q_, R = qr(X) + assert Q_.shape == R.shape == X.shape + assert ask(Q.orthogonal(Q_)) + assert ask(Q.upper_triangular(R)) + +def test_svd(): + U, S, V = svd(X) + assert U.shape == S.shape == V.shape == X.shape + assert ask(Q.orthogonal(U)) + assert ask(Q.orthogonal(V)) + assert ask(Q.diagonal(S)) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_fourier.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_fourier.py new file mode 100644 index 0000000000000000000000000000000000000000..0230c8a0957ed28fb0a5cc1e9ee77ecae797265b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_fourier.py @@ -0,0 +1,44 @@ +from sympy.assumptions.ask import (Q, ask) +from sympy.core.numbers import (I, Rational) +from sympy.core.singleton import S +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.simplify.simplify import simplify +from sympy.core.symbol import symbols +from sympy.matrices.expressions.fourier import DFT, IDFT +from sympy.matrices import det, Matrix, Identity +from sympy.testing.pytest import raises + + +def test_dft_creation(): + assert DFT(2) + assert DFT(0) + raises(ValueError, lambda: DFT(-1)) + raises(ValueError, lambda: DFT(2.0)) + raises(ValueError, lambda: DFT(2 + 1j)) + + n = symbols('n') + assert DFT(n) + n = symbols('n', integer=False) + raises(ValueError, lambda: DFT(n)) + n = symbols('n', negative=True) + raises(ValueError, lambda: DFT(n)) + + +def test_dft(): + n, i, j = symbols('n i j') + assert DFT(4).shape == (4, 4) + assert ask(Q.unitary(DFT(4))) + assert Abs(simplify(det(Matrix(DFT(4))))) == 1 + assert DFT(n)*IDFT(n) == Identity(n) + assert DFT(n)[i, j] == exp(-2*S.Pi*I/n)**(i*j) / sqrt(n) + + +def test_dft2(): + assert DFT(1).as_explicit() == Matrix([[1]]) + assert DFT(2).as_explicit() == 1/sqrt(2)*Matrix([[1,1],[1,-1]]) + assert DFT(4).as_explicit() == Matrix([[S.Half, S.Half, S.Half, S.Half], + [S.Half, -I/2, Rational(-1,2), I/2], + [S.Half, Rational(-1,2), S.Half, Rational(-1,2)], + [S.Half, I/2, Rational(-1,2), -I/2]]) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_funcmatrix.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_funcmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..e4850fe5c739b9390fac6afa10757b5babf821c6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_funcmatrix.py @@ -0,0 +1,54 @@ +from sympy.core import symbols, Lambda +from sympy.core.sympify import SympifyError +from sympy.functions import KroneckerDelta +from sympy.matrices import Matrix +from sympy.matrices.expressions import FunctionMatrix, MatrixExpr, Identity +from sympy.testing.pytest import raises + + +def test_funcmatrix_creation(): + i, j, k = symbols('i j k') + assert FunctionMatrix(2, 2, Lambda((i, j), 0)) + assert FunctionMatrix(0, 0, Lambda((i, j), 0)) + + raises(ValueError, lambda: FunctionMatrix(-1, 0, Lambda((i, j), 0))) + raises(ValueError, lambda: FunctionMatrix(2.0, 0, Lambda((i, j), 0))) + raises(ValueError, lambda: FunctionMatrix(2j, 0, Lambda((i, j), 0))) + raises(ValueError, lambda: FunctionMatrix(0, -1, Lambda((i, j), 0))) + raises(ValueError, lambda: FunctionMatrix(0, 2.0, Lambda((i, j), 0))) + raises(ValueError, lambda: FunctionMatrix(0, 2j, Lambda((i, j), 0))) + + raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda(i, 0))) + raises(SympifyError, lambda: FunctionMatrix(2, 2, lambda i, j: 0)) + raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda((i,), 0))) + raises(ValueError, lambda: FunctionMatrix(2, 2, Lambda((i, j, k), 0))) + raises(ValueError, lambda: FunctionMatrix(2, 2, i+j)) + assert FunctionMatrix(2, 2, "lambda i, j: 0") == \ + FunctionMatrix(2, 2, Lambda((i, j), 0)) + + m = FunctionMatrix(2, 2, KroneckerDelta) + assert m.as_explicit() == Identity(2).as_explicit() + assert m.args[2].dummy_eq(Lambda((i, j), KroneckerDelta(i, j))) + + n = symbols('n') + assert FunctionMatrix(n, n, Lambda((i, j), 0)) + n = symbols('n', integer=False) + raises(ValueError, lambda: FunctionMatrix(n, n, Lambda((i, j), 0))) + n = symbols('n', negative=True) + raises(ValueError, lambda: FunctionMatrix(n, n, Lambda((i, j), 0))) + + +def test_funcmatrix(): + i, j = symbols('i,j') + X = FunctionMatrix(3, 3, Lambda((i, j), i - j)) + assert X[1, 1] == 0 + assert X[1, 2] == -1 + assert X.shape == (3, 3) + assert X.rows == X.cols == 3 + assert Matrix(X) == Matrix(3, 3, lambda i, j: i - j) + assert isinstance(X*X + X, MatrixExpr) + + +def test_replace_issue(): + X = FunctionMatrix(3, 3, KroneckerDelta) + assert X.replace(lambda x: True, lambda x: x) == X diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_hadamard.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_hadamard.py new file mode 100644 index 0000000000000000000000000000000000000000..800fa830a9b089103d69b372db93ebcea541d02b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_hadamard.py @@ -0,0 +1,141 @@ +from sympy.matrices.dense import Matrix, eye +from sympy.matrices.exceptions import ShapeError +from sympy.matrices.expressions.matadd import MatAdd +from sympy.matrices.expressions.special import Identity, OneMatrix, ZeroMatrix +from sympy.core import symbols +from sympy.testing.pytest import raises, warns_deprecated_sympy + +from sympy.matrices import MatrixSymbol +from sympy.matrices.expressions import (HadamardProduct, hadamard_product, HadamardPower, hadamard_power) + +n, m, k = symbols('n,m,k') +Z = MatrixSymbol('Z', n, n) +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', n, m) +C = MatrixSymbol('C', m, k) + + +def test_HadamardProduct(): + assert HadamardProduct(A, B, A).shape == A.shape + + raises(TypeError, lambda: HadamardProduct(A, n)) + raises(TypeError, lambda: HadamardProduct(A, 1)) + + assert HadamardProduct(A, 2*B, -A)[1, 1] == \ + -2 * A[1, 1] * B[1, 1] * A[1, 1] + + mix = HadamardProduct(Z*A, B)*C + assert mix.shape == (n, k) + + assert set(HadamardProduct(A, B, A).T.args) == {A.T, A.T, B.T} + + +def test_HadamardProduct_isnt_commutative(): + assert HadamardProduct(A, B) != HadamardProduct(B, A) + + +def test_mixed_indexing(): + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + Z = MatrixSymbol('Z', 2, 2) + + assert (X*HadamardProduct(Y, Z))[0, 0] == \ + X[0, 0]*Y[0, 0]*Z[0, 0] + X[0, 1]*Y[1, 0]*Z[1, 0] + + +def test_canonicalize(): + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + with warns_deprecated_sympy(): + expr = HadamardProduct(X, check=False) + assert isinstance(expr, HadamardProduct) + expr2 = expr.doit() # unpack is called + assert isinstance(expr2, MatrixSymbol) + Z = ZeroMatrix(2, 2) + U = OneMatrix(2, 2) + assert HadamardProduct(Z, X).doit() == Z + assert HadamardProduct(U, X, X, U).doit() == HadamardPower(X, 2) + assert HadamardProduct(X, U, Y).doit() == HadamardProduct(X, Y) + assert HadamardProduct(X, Z, U, Y).doit() == Z + + +def test_hadamard(): + m, n, p = symbols('m, n, p', integer=True) + A = MatrixSymbol('A', m, n) + B = MatrixSymbol('B', m, n) + X = MatrixSymbol('X', m, m) + I = Identity(m) + + raises(TypeError, lambda: hadamard_product()) + assert hadamard_product(A) == A + assert isinstance(hadamard_product(A, B), HadamardProduct) + assert hadamard_product(A, B).doit() == hadamard_product(A, B) + assert hadamard_product(X, I) == HadamardProduct(I, X) + assert isinstance(hadamard_product(X, I), HadamardProduct) + + a = MatrixSymbol("a", k, 1) + expr = MatAdd(ZeroMatrix(k, 1), OneMatrix(k, 1)) + expr = HadamardProduct(expr, a) + assert expr.doit() == a + + raises(ValueError, lambda: HadamardProduct()) + + +def test_hadamard_product_with_explicit_mat(): + A = MatrixSymbol("A", 3, 3).as_explicit() + B = MatrixSymbol("B", 3, 3).as_explicit() + X = MatrixSymbol("X", 3, 3) + expr = hadamard_product(A, B) + ret = Matrix([i*j for i, j in zip(A, B)]).reshape(3, 3) + assert expr == ret + expr = hadamard_product(A, X, B) + assert expr == HadamardProduct(ret, X) + expr = hadamard_product(eye(3), A) + assert expr == Matrix([[A[0, 0], 0, 0], [0, A[1, 1], 0], [0, 0, A[2, 2]]]) + expr = hadamard_product(eye(3), eye(3)) + assert expr == eye(3) + + +def test_hadamard_power(): + m, n, p = symbols('m, n, p', integer=True) + A = MatrixSymbol('A', m, n) + + assert hadamard_power(A, 1) == A + assert isinstance(hadamard_power(A, 2), HadamardPower) + assert hadamard_power(A, n).T == hadamard_power(A.T, n) + assert hadamard_power(A, n)[0, 0] == A[0, 0]**n + assert hadamard_power(m, n) == m**n + raises(ValueError, lambda: hadamard_power(A, A)) + + +def test_hadamard_power_explicit(): + A = MatrixSymbol('A', 2, 2) + B = MatrixSymbol('B', 2, 2) + a, b = symbols('a b') + + assert HadamardPower(a, b) == a**b + + assert HadamardPower(a, B).as_explicit() == \ + Matrix([ + [a**B[0, 0], a**B[0, 1]], + [a**B[1, 0], a**B[1, 1]]]) + + assert HadamardPower(A, b).as_explicit() == \ + Matrix([ + [A[0, 0]**b, A[0, 1]**b], + [A[1, 0]**b, A[1, 1]**b]]) + + assert HadamardPower(A, B).as_explicit() == \ + Matrix([ + [A[0, 0]**B[0, 0], A[0, 1]**B[0, 1]], + [A[1, 0]**B[1, 0], A[1, 1]**B[1, 1]]]) + + +def test_shape_error(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 3, 3) + raises(ShapeError, lambda: HadamardProduct(A, B)) + raises(ShapeError, lambda: HadamardPower(A, B)) + A = MatrixSymbol('A', 3, 2) + raises(ShapeError, lambda: HadamardProduct(A, B)) + raises(ShapeError, lambda: HadamardPower(A, B)) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_indexing.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_indexing.py new file mode 100644 index 0000000000000000000000000000000000000000..500761f248eef5f627c2a7344a6817aca0b8a802 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_indexing.py @@ -0,0 +1,299 @@ +from sympy.concrete.summations import Sum +from sympy.core.symbol import symbols, Symbol, Dummy +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.dense import eye +from sympy.matrices.expressions.blockmatrix import BlockMatrix +from sympy.matrices.expressions.hadamard import HadamardPower +from sympy.matrices.expressions.matexpr import (MatrixSymbol, + MatrixExpr, MatrixElement) +from sympy.matrices.expressions.matpow import MatPow +from sympy.matrices.expressions.special import (ZeroMatrix, Identity, + OneMatrix) +from sympy.matrices.expressions.trace import Trace, trace +from sympy.matrices.immutable import ImmutableMatrix +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct +from sympy.testing.pytest import XFAIL, raises + +k, l, m, n = symbols('k l m n', integer=True) +i, j = symbols('i j', integer=True) + +W = MatrixSymbol('W', k, l) +X = MatrixSymbol('X', l, m) +Y = MatrixSymbol('Y', l, m) +Z = MatrixSymbol('Z', m, n) + +X1 = MatrixSymbol('X1', m, m) +X2 = MatrixSymbol('X2', m, m) +X3 = MatrixSymbol('X3', m, m) +X4 = MatrixSymbol('X4', m, m) + +A = MatrixSymbol('A', 2, 2) +B = MatrixSymbol('B', 2, 2) +x = MatrixSymbol('x', 1, 2) +y = MatrixSymbol('x', 2, 1) + + +def test_symbolic_indexing(): + x12 = X[1, 2] + assert all(s in str(x12) for s in ['1', '2', X.name]) + # We don't care about the exact form of this. We do want to make sure + # that all of these features are present + + +def test_add_index(): + assert (X + Y)[i, j] == X[i, j] + Y[i, j] + + +def test_mul_index(): + assert (A*y)[0, 0] == A[0, 0]*y[0, 0] + A[0, 1]*y[1, 0] + assert (A*B).as_mutable() == (A.as_mutable() * B.as_mutable()) + X = MatrixSymbol('X', n, m) + Y = MatrixSymbol('Y', m, k) + + result = (X*Y)[4,2] + expected = Sum(X[4, i]*Y[i, 2], (i, 0, m - 1)) + assert result.args[0].dummy_eq(expected.args[0], i) + assert result.args[1][1:] == expected.args[1][1:] + + +def test_pow_index(): + Q = MatPow(A, 2) + assert Q[0, 0] == A[0, 0]**2 + A[0, 1]*A[1, 0] + n = symbols("n") + Q2 = A**n + assert Q2[0, 0] == 2*( + -sqrt((A[0, 0] + A[1, 1])**2 - 4*A[0, 0]*A[1, 1] + + 4*A[0, 1]*A[1, 0])/2 + A[0, 0]/2 + A[1, 1]/2 + )**n * \ + A[0, 1]*A[1, 0]/( + (sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + + A[1, 1]**2) + A[0, 0] - A[1, 1])* + sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + A[1, 1]**2) + ) - 2*( + sqrt((A[0, 0] + A[1, 1])**2 - 4*A[0, 0]*A[1, 1] + + 4*A[0, 1]*A[1, 0])/2 + A[0, 0]/2 + A[1, 1]/2 + )**n * A[0, 1]*A[1, 0]/( + (-sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + + A[1, 1]**2) + A[0, 0] - A[1, 1])* + sqrt(A[0, 0]**2 - 2*A[0, 0]*A[1, 1] + 4*A[0, 1]*A[1, 0] + A[1, 1]**2) + ) + + +def test_transpose_index(): + assert X.T[i, j] == X[j, i] + + +def test_Identity_index(): + I = Identity(3) + assert I[0, 0] == I[1, 1] == I[2, 2] == 1 + assert I[1, 0] == I[0, 1] == I[2, 1] == 0 + assert I[i, 0].delta_range == (0, 2) + raises(IndexError, lambda: I[3, 3]) + + +def test_block_index(): + I = Identity(3) + Z = ZeroMatrix(3, 3) + B = BlockMatrix([[I, I], [I, I]]) + e3 = ImmutableMatrix(eye(3)) + BB = BlockMatrix([[e3, e3], [e3, e3]]) + assert B[0, 0] == B[3, 0] == B[0, 3] == B[3, 3] == 1 + assert B[4, 3] == B[5, 1] == 0 + + BB = BlockMatrix([[e3, e3], [e3, e3]]) + assert B.as_explicit() == BB.as_explicit() + + BI = BlockMatrix([[I, Z], [Z, I]]) + + assert BI.as_explicit().equals(eye(6)) + + +def test_block_index_symbolic(): + # Note that these matrices may be zero-sized and indices may be negative, which causes + # all naive simplifications given in the comments to be invalid + A1 = MatrixSymbol('A1', n, k) + A2 = MatrixSymbol('A2', n, l) + A3 = MatrixSymbol('A3', m, k) + A4 = MatrixSymbol('A4', m, l) + A = BlockMatrix([[A1, A2], [A3, A4]]) + assert A[0, 0] == MatrixElement(A, 0, 0) # Cannot be A1[0, 0] + assert A[n - 1, k - 1] == A1[n - 1, k - 1] + assert A[n, k] == A4[0, 0] + assert A[n + m - 1, 0] == MatrixElement(A, n + m - 1, 0) # Cannot be A3[m - 1, 0] + assert A[0, k + l - 1] == MatrixElement(A, 0, k + l - 1) # Cannot be A2[0, l - 1] + assert A[n + m - 1, k + l - 1] == MatrixElement(A, n + m - 1, k + l - 1) # Cannot be A4[m - 1, l - 1] + assert A[i, j] == MatrixElement(A, i, j) + assert A[n + i, k + j] == MatrixElement(A, n + i, k + j) # Cannot be A4[i, j] + assert A[n - i - 1, k - j - 1] == MatrixElement(A, n - i - 1, k - j - 1) # Cannot be A1[n - i - 1, k - j - 1] + + +def test_block_index_symbolic_nonzero(): + # All invalid simplifications from test_block_index_symbolic() that become valid if all + # matrices have nonzero size and all indices are nonnegative + k, l, m, n = symbols('k l m n', integer=True, positive=True) + i, j = symbols('i j', integer=True, nonnegative=True) + A1 = MatrixSymbol('A1', n, k) + A2 = MatrixSymbol('A2', n, l) + A3 = MatrixSymbol('A3', m, k) + A4 = MatrixSymbol('A4', m, l) + A = BlockMatrix([[A1, A2], [A3, A4]]) + assert A[0, 0] == A1[0, 0] + assert A[n + m - 1, 0] == A3[m - 1, 0] + assert A[0, k + l - 1] == A2[0, l - 1] + assert A[n + m - 1, k + l - 1] == A4[m - 1, l - 1] + assert A[i, j] == MatrixElement(A, i, j) + assert A[n + i, k + j] == A4[i, j] + assert A[n - i - 1, k - j - 1] == A1[n - i - 1, k - j - 1] + assert A[2 * n, 2 * k] == A4[n, k] + + +def test_block_index_large(): + n, m, k = symbols('n m k', integer=True, positive=True) + i = symbols('i', integer=True, nonnegative=True) + A1 = MatrixSymbol('A1', n, n) + A2 = MatrixSymbol('A2', n, m) + A3 = MatrixSymbol('A3', n, k) + A4 = MatrixSymbol('A4', m, n) + A5 = MatrixSymbol('A5', m, m) + A6 = MatrixSymbol('A6', m, k) + A7 = MatrixSymbol('A7', k, n) + A8 = MatrixSymbol('A8', k, m) + A9 = MatrixSymbol('A9', k, k) + A = BlockMatrix([[A1, A2, A3], [A4, A5, A6], [A7, A8, A9]]) + assert A[n + i, n + i] == MatrixElement(A, n + i, n + i) + + +@XFAIL +def test_block_index_symbolic_fail(): + # To make this work, symbolic matrix dimensions would need to be somehow assumed nonnegative + # even if the symbols aren't specified as such. Then 2 * n < n would correctly evaluate to + # False in BlockMatrix._entry() + A1 = MatrixSymbol('A1', n, 1) + A2 = MatrixSymbol('A2', m, 1) + A = BlockMatrix([[A1], [A2]]) + assert A[2 * n, 0] == A2[n, 0] + + +def test_slicing(): + A.as_explicit()[0, :] # does not raise an error + + +def test_errors(): + raises(IndexError, lambda: Identity(2)[1, 2, 3, 4, 5]) + raises(IndexError, lambda: Identity(2)[[1, 2, 3, 4, 5]]) + + +def test_matrix_expression_to_indices(): + i, j = symbols("i, j") + i1, i2, i3 = symbols("i_1:4") + + def replace_dummies(expr): + repl = {i: Symbol(i.name) for i in expr.atoms(Dummy)} + return expr.xreplace(repl) + + expr = W*X*Z + assert replace_dummies(expr._entry(i, j)) == \ + Sum(W[i, i1]*X[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) + assert MatrixExpr.from_index_summation(expr._entry(i, j)) == expr + + expr = Z.T*X.T*W.T + assert replace_dummies(expr._entry(i, j)) == \ + Sum(W[j, i2]*X[i2, i1]*Z[i1, i], (i1, 0, m-1), (i2, 0, l-1)) + assert MatrixExpr.from_index_summation(expr._entry(i, j), i) == expr + + expr = W*X*Z + W*Y*Z + assert replace_dummies(expr._entry(i, j)) == \ + Sum(W[i, i1]*X[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) +\ + Sum(W[i, i1]*Y[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) + assert MatrixExpr.from_index_summation(expr._entry(i, j)) == expr + + expr = 2*W*X*Z + 3*W*Y*Z + assert replace_dummies(expr._entry(i, j)) == \ + 2*Sum(W[i, i1]*X[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) +\ + 3*Sum(W[i, i1]*Y[i1, i2]*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) + assert MatrixExpr.from_index_summation(expr._entry(i, j)) == expr + + expr = W*(X + Y)*Z + assert replace_dummies(expr._entry(i, j)) == \ + Sum(W[i, i1]*(X[i1, i2] + Y[i1, i2])*Z[i2, j], (i1, 0, l-1), (i2, 0, m-1)) + assert MatrixExpr.from_index_summation(expr._entry(i, j)) == expr + + expr = A*B**2*A + #assert replace_dummies(expr._entry(i, j)) == \ + # Sum(A[i, i1]*B[i1, i2]*B[i2, i3]*A[i3, j], (i1, 0, 1), (i2, 0, 1), (i3, 0, 1)) + + # Check that different dummies are used in sub-multiplications: + expr = (X1*X2 + X2*X1)*X3 + assert replace_dummies(expr._entry(i, j)) == \ + Sum((Sum(X1[i, i2] * X2[i2, i1], (i2, 0, m - 1)) + Sum(X1[i3, i1] * X2[i, i3], (i3, 0, m - 1))) * X3[ + i1, j], (i1, 0, m - 1)) + + +def test_matrix_expression_from_index_summation(): + from sympy.abc import a,b,c,d + A = MatrixSymbol("A", k, k) + B = MatrixSymbol("B", k, k) + C = MatrixSymbol("C", k, k) + w1 = MatrixSymbol("w1", k, 1) + + i0, i1, i2, i3, i4 = symbols("i0:5", cls=Dummy) + + expr = Sum(W[a,b]*X[b,c]*Z[c,d], (b, 0, l-1), (c, 0, m-1)) + assert MatrixExpr.from_index_summation(expr, a) == W*X*Z + expr = Sum(W.T[b,a]*X[b,c]*Z[c,d], (b, 0, l-1), (c, 0, m-1)) + assert MatrixExpr.from_index_summation(expr, a) == W*X*Z + expr = Sum(A[b, a]*B[b, c]*C[c, d], (b, 0, k-1), (c, 0, k-1)) + assert MatrixSymbol.from_index_summation(expr, a) == A.T*B*C + expr = Sum(A[b, a]*B[c, b]*C[c, d], (b, 0, k-1), (c, 0, k-1)) + assert MatrixSymbol.from_index_summation(expr, a) == A.T*B.T*C + expr = Sum(C[c, d]*A[b, a]*B[c, b], (b, 0, k-1), (c, 0, k-1)) + assert MatrixSymbol.from_index_summation(expr, a) == A.T*B.T*C + expr = Sum(A[a, b] + B[a, b], (a, 0, k-1), (b, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == OneMatrix(1, k)*A*OneMatrix(k, 1) + OneMatrix(1, k)*B*OneMatrix(k, 1) + expr = Sum(A[a, b]**2, (a, 0, k - 1), (b, 0, k - 1)) + assert MatrixExpr.from_index_summation(expr, a) == Trace(A * A.T) + expr = Sum(A[a, b]**3, (a, 0, k - 1), (b, 0, k - 1)) + assert MatrixExpr.from_index_summation(expr, a) == Trace(HadamardPower(A.T, 2) * A) + expr = Sum((A[a, b] + B[a, b])*C[b, c], (b, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == (A+B)*C + expr = Sum((A[a, b] + B[b, a])*C[b, c], (b, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == (A+B.T)*C + expr = Sum(A[a, b]*A[b, c]*A[c, d], (b, 0, k-1), (c, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == A**3 + expr = Sum(A[a, b]*A[b, c]*B[c, d], (b, 0, k-1), (c, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == A**2*B + + # Parse the trace of a matrix: + + expr = Sum(A[a, a], (a, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, None) == trace(A) + expr = Sum(A[a, a]*B[b, c]*C[c, d], (a, 0, k-1), (c, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, b) == trace(A)*B*C + + # Check wrong sum ranges (should raise an exception): + + ## Case 1: 0 to m instead of 0 to m-1 + expr = Sum(W[a,b]*X[b,c]*Z[c,d], (b, 0, l-1), (c, 0, m)) + raises(ValueError, lambda: MatrixExpr.from_index_summation(expr, a)) + ## Case 2: 1 to m-1 instead of 0 to m-1 + expr = Sum(W[a,b]*X[b,c]*Z[c,d], (b, 0, l-1), (c, 1, m-1)) + raises(ValueError, lambda: MatrixExpr.from_index_summation(expr, a)) + + # Parse nested sums: + expr = Sum(A[a, b]*Sum(B[b, c]*C[c, d], (c, 0, k-1)), (b, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == A*B*C + + # Test Kronecker delta: + expr = Sum(A[a, b]*KroneckerDelta(b, c)*B[c, d], (b, 0, k-1), (c, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, a) == A*B + + expr = Sum(KroneckerDelta(i1, m)*KroneckerDelta(i2, n)*A[i, i1]*A[j, i2], (i1, 0, k-1), (i2, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, m) == ArrayTensorProduct(A.T, A) + + # Test numbered indices: + expr = Sum(A[i1, i2]*w1[i2, 0], (i2, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, i1) == MatrixElement(A*w1, i1, 0) + + expr = Sum(A[i1, i2]*B[i2, 0], (i2, 0, k-1)) + assert MatrixExpr.from_index_summation(expr, i1) == MatrixElement(A*B, i1, 0) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_inverse.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcc7d4de2b2bee4c4922bda8bc48a52aa205961 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_inverse.py @@ -0,0 +1,69 @@ +from sympy.core import symbols, S +from sympy.matrices.expressions import MatrixSymbol, Inverse, MatPow, ZeroMatrix, OneMatrix +from sympy.matrices.exceptions import NonInvertibleMatrixError, NonSquareMatrixError +from sympy.matrices import eye, Identity +from sympy.testing.pytest import raises +from sympy.assumptions.ask import Q +from sympy.assumptions.refine import refine + +n, m, l = symbols('n m l', integer=True) +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) +D = MatrixSymbol('D', n, n) +E = MatrixSymbol('E', m, n) + + +def test_inverse(): + assert Inverse(C).args == (C, S.NegativeOne) + assert Inverse(C).shape == (n, n) + assert Inverse(A*E).shape == (n, n) + assert Inverse(E*A).shape == (m, m) + assert Inverse(C).inverse() == C + assert Inverse(Inverse(C)).doit() == C + assert isinstance(Inverse(Inverse(C)), Inverse) + + assert Inverse(*Inverse(E*A).args) == Inverse(E*A) + + assert C.inverse().inverse() == C + + assert C.inverse()*C == Identity(C.rows) + + assert Identity(n).inverse() == Identity(n) + assert (3*Identity(n)).inverse() == Identity(n)/3 + + # Simplifies Muls if possible (i.e. submatrices are square) + assert (C*D).inverse() == D.I*C.I + # But still works when not possible + assert isinstance((A*E).inverse(), Inverse) + assert Inverse(C*D).doit(inv_expand=False) == Inverse(C*D) + + assert Inverse(eye(3)).doit() == eye(3) + assert Inverse(eye(3)).doit(deep=False) == eye(3) + + assert OneMatrix(1, 1).I == Identity(1) + assert isinstance(OneMatrix(n, n).I, Inverse) + +def test_inverse_non_invertible(): + raises(NonInvertibleMatrixError, lambda: ZeroMatrix(n, n).I) + raises(NonInvertibleMatrixError, lambda: OneMatrix(2, 2).I) + +def test_refine(): + assert refine(C.I, Q.orthogonal(C)) == C.T + + +def test_inverse_matpow_canonicalization(): + A = MatrixSymbol('A', 3, 3) + assert Inverse(MatPow(A, 3)).doit() == MatPow(Inverse(A), 3).doit() + + +def test_nonsquare_error(): + A = MatrixSymbol('A', 3, 4) + raises(NonSquareMatrixError, lambda: Inverse(A)) + + +def test_adjoint_trnaspose_conjugate(): + A = MatrixSymbol('A', n, n) + assert A.transpose().inverse() == A.inverse().transpose() + assert A.conjugate().inverse() == A.inverse().conjugate() + assert A.adjoint().inverse() == A.inverse().adjoint() diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_kronecker.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_kronecker.py new file mode 100644 index 0000000000000000000000000000000000000000..b4444716a76a52e3638dd7a36238a9f459179083 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_kronecker.py @@ -0,0 +1,150 @@ +from sympy.core.mod import Mod +from sympy.core.numbers import I +from sympy.core.symbol import symbols +from sympy.functions.elementary.integers import floor +from sympy.matrices.dense import (Matrix, eye) +from sympy.matrices import MatrixSymbol, Identity +from sympy.matrices.expressions import det, trace + +from sympy.matrices.expressions.kronecker import (KroneckerProduct, + kronecker_product, + combine_kronecker) + + +mat1 = Matrix([[1, 2 * I], [1 + I, 3]]) +mat2 = Matrix([[2 * I, 3], [4 * I, 2]]) + +i, j, k, n, m, o, p, x = symbols('i,j,k,n,m,o,p,x') +Z = MatrixSymbol('Z', n, n) +W = MatrixSymbol('W', m, m) +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', n, m) +C = MatrixSymbol('C', m, k) + + +def test_KroneckerProduct(): + assert isinstance(KroneckerProduct(A, B), KroneckerProduct) + assert KroneckerProduct(A, B).subs(A, C) == KroneckerProduct(C, B) + assert KroneckerProduct(A, C).shape == (n*m, m*k) + assert (KroneckerProduct(A, C) + KroneckerProduct(-A, C)).is_ZeroMatrix + assert (KroneckerProduct(W, Z) * KroneckerProduct(W.I, Z.I)).is_Identity + + +def test_KroneckerProduct_identity(): + assert KroneckerProduct(Identity(m), Identity(n)) == Identity(m*n) + assert KroneckerProduct(eye(2), eye(3)) == eye(6) + + +def test_KroneckerProduct_explicit(): + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + kp = KroneckerProduct(X, Y) + assert kp.shape == (4, 4) + assert kp.as_explicit() == Matrix( + [ + [X[0, 0]*Y[0, 0], X[0, 0]*Y[0, 1], X[0, 1]*Y[0, 0], X[0, 1]*Y[0, 1]], + [X[0, 0]*Y[1, 0], X[0, 0]*Y[1, 1], X[0, 1]*Y[1, 0], X[0, 1]*Y[1, 1]], + [X[1, 0]*Y[0, 0], X[1, 0]*Y[0, 1], X[1, 1]*Y[0, 0], X[1, 1]*Y[0, 1]], + [X[1, 0]*Y[1, 0], X[1, 0]*Y[1, 1], X[1, 1]*Y[1, 0], X[1, 1]*Y[1, 1]] + ] + ) + + +def test_tensor_product_adjoint(): + assert KroneckerProduct(I*A, B).adjoint() == \ + -I*KroneckerProduct(A.adjoint(), B.adjoint()) + assert KroneckerProduct(mat1, mat2).adjoint() == \ + kronecker_product(mat1.adjoint(), mat2.adjoint()) + + +def test_tensor_product_conjugate(): + assert KroneckerProduct(I*A, B).conjugate() == \ + -I*KroneckerProduct(A.conjugate(), B.conjugate()) + assert KroneckerProduct(mat1, mat2).conjugate() == \ + kronecker_product(mat1.conjugate(), mat2.conjugate()) + + +def test_tensor_product_transpose(): + assert KroneckerProduct(I*A, B).transpose() == \ + I*KroneckerProduct(A.transpose(), B.transpose()) + assert KroneckerProduct(mat1, mat2).transpose() == \ + kronecker_product(mat1.transpose(), mat2.transpose()) + + +def test_KroneckerProduct_is_associative(): + assert kronecker_product(A, kronecker_product( + B, C)) == kronecker_product(kronecker_product(A, B), C) + assert kronecker_product(A, kronecker_product( + B, C)) == KroneckerProduct(A, B, C) + + +def test_KroneckerProduct_is_bilinear(): + assert kronecker_product(x*A, B) == x*kronecker_product(A, B) + assert kronecker_product(A, x*B) == x*kronecker_product(A, B) + + +def test_KroneckerProduct_determinant(): + kp = kronecker_product(W, Z) + assert det(kp) == det(W)**n * det(Z)**m + + +def test_KroneckerProduct_trace(): + kp = kronecker_product(W, Z) + assert trace(kp) == trace(W)*trace(Z) + + +def test_KroneckerProduct_isnt_commutative(): + assert KroneckerProduct(A, B) != KroneckerProduct(B, A) + assert KroneckerProduct(A, B).is_commutative is False + + +def test_KroneckerProduct_extracts_commutative_part(): + assert kronecker_product(x * A, 2 * B) == x * \ + 2 * KroneckerProduct(A, B) + + +def test_KroneckerProduct_inverse(): + kp = kronecker_product(W, Z) + assert kp.inverse() == kronecker_product(W.inverse(), Z.inverse()) + + +def test_KroneckerProduct_combine_add(): + kp1 = kronecker_product(A, B) + kp2 = kronecker_product(C, W) + assert combine_kronecker(kp1*kp2) == kronecker_product(A*C, B*W) + + +def test_KroneckerProduct_combine_mul(): + X = MatrixSymbol('X', m, n) + Y = MatrixSymbol('Y', m, n) + kp1 = kronecker_product(A, X) + kp2 = kronecker_product(B, Y) + assert combine_kronecker(kp1+kp2) == kronecker_product(A+B, X+Y) + + +def test_KroneckerProduct_combine_pow(): + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', n, n) + assert combine_kronecker(KroneckerProduct( + X, Y)**x) == KroneckerProduct(X**x, Y**x) + assert combine_kronecker(x * KroneckerProduct(X, Y) + ** 2) == x * KroneckerProduct(X**2, Y**2) + assert combine_kronecker( + x * (KroneckerProduct(X, Y)**2) * KroneckerProduct(A, B)) == x * KroneckerProduct(X**2 * A, Y**2 * B) + # cannot simplify because of non-square arguments to kronecker product: + assert combine_kronecker(KroneckerProduct(A, B.T) ** m) == KroneckerProduct(A, B.T) ** m + + +def test_KroneckerProduct_expand(): + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', n, n) + + assert KroneckerProduct(X + Y, Y + Z).expand(kroneckerproduct=True) == \ + KroneckerProduct(X, Y) + KroneckerProduct(X, Z) + \ + KroneckerProduct(Y, Y) + KroneckerProduct(Y, Z) + +def test_KroneckerProduct_entry(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', o, p) + + assert KroneckerProduct(A, B)._entry(i, j) == A[Mod(floor(i/o), n), Mod(floor(j/p), m)]*B[Mod(i, o), Mod(j, p)] diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_matadd.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_matadd.py new file mode 100644 index 0000000000000000000000000000000000000000..43229ae8c2e42f0253a5f3eceefa5fffe7a99f29 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_matadd.py @@ -0,0 +1,58 @@ +from sympy.matrices.expressions import MatrixSymbol, MatAdd, MatPow, MatMul +from sympy.matrices.expressions.special import GenericZeroMatrix, ZeroMatrix +from sympy.matrices.exceptions import ShapeError +from sympy.matrices import eye, ImmutableMatrix +from sympy.core import Add, Basic, S +from sympy.core.add import add +from sympy.testing.pytest import XFAIL, raises + +X = MatrixSymbol('X', 2, 2) +Y = MatrixSymbol('Y', 2, 2) + +def test_evaluate(): + assert MatAdd(X, X, evaluate=True) == add(X, X, evaluate=True) == MatAdd(X, X).doit() + +def test_sort_key(): + assert MatAdd(Y, X).doit().args == add(Y, X).doit().args == (X, Y) + + +def test_matadd_sympify(): + assert isinstance(MatAdd(eye(1), eye(1)).args[0], Basic) + assert isinstance(add(eye(1), eye(1)).args[0], Basic) + + +def test_matadd_of_matrices(): + assert MatAdd(eye(2), 4*eye(2), eye(2)).doit() == ImmutableMatrix(6*eye(2)) + assert add(eye(2), 4*eye(2), eye(2)).doit() == ImmutableMatrix(6*eye(2)) + + +def test_doit_args(): + A = ImmutableMatrix([[1, 2], [3, 4]]) + B = ImmutableMatrix([[2, 3], [4, 5]]) + assert MatAdd(A, MatPow(B, 2)).doit() == A + B**2 + assert MatAdd(A, MatMul(A, B)).doit() == A + A*B + assert (MatAdd(A, X, MatMul(A, B), Y, MatAdd(2*A, B)).doit() == + add(A, X, MatMul(A, B), Y, add(2*A, B)).doit() == + MatAdd(3*A + A*B + B, X, Y)) + + +def test_generic_identity(): + assert MatAdd.identity == GenericZeroMatrix() + assert MatAdd.identity != S.Zero + + +def test_zero_matrix_add(): + assert Add(ZeroMatrix(2, 2), ZeroMatrix(2, 2)) == ZeroMatrix(2, 2) + +@XFAIL +def test_matrix_Add_with_scalar(): + raises(TypeError, lambda: Add(0, ZeroMatrix(2, 2))) + + +def test_shape_error(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 3, 3) + raises(ShapeError, lambda: MatAdd(A, B)) + + A = MatrixSymbol('A', 3, 2) + raises(ShapeError, lambda: MatAdd(A, B)) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_matexpr.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_matexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..f2319e8d8097c2ad3519eab783c4665623c55b80 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_matexpr.py @@ -0,0 +1,592 @@ +from sympy.concrete.summations import Sum +from sympy.core.exprtools import gcd_terms +from sympy.core.function import (diff, expand) +from sympy.core.relational import Eq +from sympy.core.symbol import (Dummy, Symbol, Str) +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.dense import zeros +from sympy.polys.polytools import factor + +from sympy.core import (S, symbols, Add, Mul, SympifyError, Rational, + Function) +from sympy.functions import sin, cos, tan, sqrt, cbrt, exp +from sympy.simplify import simplify +from sympy.matrices import (ImmutableMatrix, Inverse, MatAdd, MatMul, + MatPow, Matrix, MatrixExpr, MatrixSymbol, + SparseMatrix, Transpose, Adjoint, MatrixSet) +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices.expressions.determinant import Determinant, det +from sympy.matrices.expressions.matexpr import MatrixElement +from sympy.matrices.expressions.special import ZeroMatrix, Identity +from sympy.testing.pytest import raises, XFAIL, skip +from importlib.metadata import version + +n, m, l, k, p = symbols('n m l k p', integer=True) +x = symbols('x') +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) +D = MatrixSymbol('D', n, n) +E = MatrixSymbol('E', m, n) +w = MatrixSymbol('w', n, 1) + + +def test_matrix_symbol_creation(): + assert MatrixSymbol('A', 2, 2) + assert MatrixSymbol('A', 0, 0) + raises(ValueError, lambda: MatrixSymbol('A', -1, 2)) + raises(ValueError, lambda: MatrixSymbol('A', 2.0, 2)) + raises(ValueError, lambda: MatrixSymbol('A', 2j, 2)) + raises(ValueError, lambda: MatrixSymbol('A', 2, -1)) + raises(ValueError, lambda: MatrixSymbol('A', 2, 2.0)) + raises(ValueError, lambda: MatrixSymbol('A', 2, 2j)) + + n = symbols('n') + assert MatrixSymbol('A', n, n) + n = symbols('n', integer=False) + raises(ValueError, lambda: MatrixSymbol('A', n, n)) + n = symbols('n', negative=True) + raises(ValueError, lambda: MatrixSymbol('A', n, n)) + + +def test_matexpr_properties(): + assert A.shape == (n, m) + assert (A * B).shape == (n, l) + assert A[0, 1].indices == (0, 1) + assert A[0, 0].symbol == A + assert A[0, 0].symbol.name == 'A' + + +def test_matexpr(): + assert (x*A).shape == A.shape + assert (x*A).__class__ == MatMul + assert 2*A - A - A == ZeroMatrix(*A.shape) + assert (A*B).shape == (n, l) + + +def test_matexpr_subs(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', m, l) + C = MatrixSymbol('C', m, l) + + assert A.subs(n, m).shape == (m, m) + assert (A*B).subs(B, C) == A*C + assert (A*B).subs(l, n).is_square + + W = MatrixSymbol("W", 3, 3) + X = MatrixSymbol("X", 2, 2) + Y = MatrixSymbol("Y", 1, 2) + Z = MatrixSymbol("Z", n, 2) + # no restrictions on Symbol replacement + assert X.subs(X, Y) == Y + # it might be better to just change the name + y = Str('y') + assert X.subs(Str("X"), y).args == (y, 2, 2) + # it's ok to introduce a wider matrix + assert X[1, 1].subs(X, W) == W[1, 1] + # but for a given MatrixExpression, only change + # name if indexing on the new shape is valid. + # Here, X is 2,2; Y is 1,2 and Y[1, 1] is out + # of range so an error is raised + raises(IndexError, lambda: X[1, 1].subs(X, Y)) + # here, [0, 1] is in range so the subs succeeds + assert X[0, 1].subs(X, Y) == Y[0, 1] + # and here the size of n will accept any index + # in the first position + assert W[2, 1].subs(W, Z) == Z[2, 1] + # but not in the second position + raises(IndexError, lambda: W[2, 2].subs(W, Z)) + # any matrix should raise if invalid + raises(IndexError, lambda: W[2, 2].subs(W, zeros(2))) + + A = SparseMatrix([[1, 2], [3, 4]]) + B = Matrix([[1, 2], [3, 4]]) + C, D = MatrixSymbol('C', 2, 2), MatrixSymbol('D', 2, 2) + + assert (C*D).subs({C: A, D: B}) == MatMul(A, B) + + +def test_addition(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', n, m) + + assert isinstance(A + B, MatAdd) + assert (A + B).shape == A.shape + assert isinstance(A - A + 2*B, MatMul) + + raises(TypeError, lambda: A + 1) + raises(TypeError, lambda: 5 + A) + raises(TypeError, lambda: 5 - A) + + assert A + ZeroMatrix(n, m) - A == ZeroMatrix(n, m) + raises(TypeError, lambda: ZeroMatrix(n, m) + S.Zero) + + +def test_multiplication(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', m, l) + C = MatrixSymbol('C', n, n) + + assert (2*A*B).shape == (n, l) + assert (A*0*B) == ZeroMatrix(n, l) + assert (2*A).shape == A.shape + + assert A * ZeroMatrix(m, m) * B == ZeroMatrix(n, l) + + assert C * Identity(n) * C.I == Identity(n) + + assert B/2 == S.Half*B + raises(NotImplementedError, lambda: 2/B) + + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert Identity(n) * (A + B) == A + B + + assert A**2*A == A**3 + assert A**2*(A.I)**3 == A.I + assert A**3*(A.I)**2 == A + + +def test_MatPow(): + A = MatrixSymbol('A', n, n) + + AA = MatPow(A, 2) + assert AA.exp == 2 + assert AA.base == A + assert (A**n).exp == n + + assert A**0 == Identity(n) + assert A**1 == A + assert A**2 == AA + assert A**-1 == Inverse(A) + assert (A**-1)**-1 == A + assert (A**2)**3 == A**6 + assert A**S.Half == sqrt(A) + assert A**Rational(1, 3) == cbrt(A) + raises(NonSquareMatrixError, lambda: MatrixSymbol('B', 3, 2)**2) + + +def test_MatrixSymbol(): + n, m, t = symbols('n,m,t') + X = MatrixSymbol('X', n, m) + assert X.shape == (n, m) + raises(TypeError, lambda: MatrixSymbol('X', n, m)(t)) # issue 5855 + assert X.doit() == X + + +def test_dense_conversion(): + X = MatrixSymbol('X', 2, 2) + assert ImmutableMatrix(X) == ImmutableMatrix(2, 2, lambda i, j: X[i, j]) + assert Matrix(X) == Matrix(2, 2, lambda i, j: X[i, j]) + + +def test_free_symbols(): + assert (C*D).free_symbols == {C, D} + + +def test_zero_matmul(): + assert isinstance(S.Zero * MatrixSymbol('X', 2, 2), MatrixExpr) + + +def test_matadd_simplify(): + A = MatrixSymbol('A', 1, 1) + assert simplify(MatAdd(A, ImmutableMatrix([[sin(x)**2 + cos(x)**2]]))) == \ + MatAdd(A, Matrix([[1]])) + + +def test_matmul_simplify(): + A = MatrixSymbol('A', 1, 1) + assert simplify(MatMul(A, ImmutableMatrix([[sin(x)**2 + cos(x)**2]]))) == \ + MatMul(A, Matrix([[1]])) + + +def test_invariants(): + A = MatrixSymbol('A', n, m) + B = MatrixSymbol('B', m, l) + X = MatrixSymbol('X', n, n) + objs = [Identity(n), ZeroMatrix(m, n), A, MatMul(A, B), MatAdd(A, A), + Transpose(A), Adjoint(A), Inverse(X), MatPow(X, 2), MatPow(X, -1), + MatPow(X, 0)] + for obj in objs: + assert obj == obj.__class__(*obj.args) + + +def test_matexpr_indexing(): + A = MatrixSymbol('A', n, m) + A[1, 2] + A[l, k] + A[l + 1, k + 1] + A = MatrixSymbol('A', 2, 1) + for i in range(-2, 2): + for j in range(-1, 1): + A[i, j] + + +def test_single_indexing(): + A = MatrixSymbol('A', 2, 3) + assert A[1] == A[0, 1] + assert A[int(1)] == A[0, 1] + assert A[3] == A[1, 0] + assert list(A[:2, :2]) == [A[0, 0], A[0, 1], A[1, 0], A[1, 1]] + raises(IndexError, lambda: A[6]) + raises(IndexError, lambda: A[n]) + B = MatrixSymbol('B', n, m) + raises(IndexError, lambda: B[1]) + B = MatrixSymbol('B', n, 3) + assert B[3] == B[1, 0] + + +def test_MatrixElement_commutative(): + assert A[0, 1]*A[1, 0] == A[1, 0]*A[0, 1] + + +def test_MatrixSymbol_determinant(): + A = MatrixSymbol('A', 4, 4) + assert A.as_explicit().det() == A[0, 0]*A[1, 1]*A[2, 2]*A[3, 3] - \ + A[0, 0]*A[1, 1]*A[2, 3]*A[3, 2] - A[0, 0]*A[1, 2]*A[2, 1]*A[3, 3] + \ + A[0, 0]*A[1, 2]*A[2, 3]*A[3, 1] + A[0, 0]*A[1, 3]*A[2, 1]*A[3, 2] - \ + A[0, 0]*A[1, 3]*A[2, 2]*A[3, 1] - A[0, 1]*A[1, 0]*A[2, 2]*A[3, 3] + \ + A[0, 1]*A[1, 0]*A[2, 3]*A[3, 2] + A[0, 1]*A[1, 2]*A[2, 0]*A[3, 3] - \ + A[0, 1]*A[1, 2]*A[2, 3]*A[3, 0] - A[0, 1]*A[1, 3]*A[2, 0]*A[3, 2] + \ + A[0, 1]*A[1, 3]*A[2, 2]*A[3, 0] + A[0, 2]*A[1, 0]*A[2, 1]*A[3, 3] - \ + A[0, 2]*A[1, 0]*A[2, 3]*A[3, 1] - A[0, 2]*A[1, 1]*A[2, 0]*A[3, 3] + \ + A[0, 2]*A[1, 1]*A[2, 3]*A[3, 0] + A[0, 2]*A[1, 3]*A[2, 0]*A[3, 1] - \ + A[0, 2]*A[1, 3]*A[2, 1]*A[3, 0] - A[0, 3]*A[1, 0]*A[2, 1]*A[3, 2] + \ + A[0, 3]*A[1, 0]*A[2, 2]*A[3, 1] + A[0, 3]*A[1, 1]*A[2, 0]*A[3, 2] - \ + A[0, 3]*A[1, 1]*A[2, 2]*A[3, 0] - A[0, 3]*A[1, 2]*A[2, 0]*A[3, 1] + \ + A[0, 3]*A[1, 2]*A[2, 1]*A[3, 0] + + B = MatrixSymbol('B', 4, 4) + assert Determinant(A + B).doit() == det(A + B) == (A + B).det() + + +def test_MatrixElement_diff(): + assert (A[3, 0]*A[0, 0]).diff(A[0, 0]) == A[3, 0] + + +def test_MatrixElement_doit(): + u = MatrixSymbol('u', 2, 1) + v = ImmutableMatrix([3, 5]) + assert u[0, 0].subs(u, v).doit() == v[0, 0] + + +def test_identity_powers(): + M = Identity(n) + assert MatPow(M, 3).doit() == M**3 + assert M**n == M + assert MatPow(M, 0).doit() == M**2 + assert M**-2 == M + assert MatPow(M, -2).doit() == M**0 + N = Identity(3) + assert MatPow(N, 2).doit() == N**n + assert MatPow(N, 3).doit() == N + assert MatPow(N, -2).doit() == N**4 + assert MatPow(N, 2).doit() == N**0 + + +def test_Zero_power(): + z1 = ZeroMatrix(n, n) + assert z1**4 == z1 + raises(ValueError, lambda:z1**-2) + assert z1**0 == Identity(n) + assert MatPow(z1, 2).doit() == z1**2 + raises(ValueError, lambda:MatPow(z1, -2).doit()) + z2 = ZeroMatrix(3, 3) + assert MatPow(z2, 4).doit() == z2**4 + raises(ValueError, lambda:z2**-3) + assert z2**3 == MatPow(z2, 3).doit() + assert z2**0 == Identity(3) + raises(ValueError, lambda:MatPow(z2, -1).doit()) + + +def test_matrixelement_diff(): + dexpr = diff((D*w)[k,0], w[p,0]) + + assert w[k, p].diff(w[k, p]) == 1 + assert w[k, p].diff(w[0, 0]) == KroneckerDelta(0, k, (0, n-1))*KroneckerDelta(0, p, (0, 0)) + _i_1 = Dummy("_i_1") + assert dexpr.dummy_eq(Sum(KroneckerDelta(_i_1, p, (0, n-1))*D[k, _i_1], (_i_1, 0, n - 1))) + assert dexpr.doit() == D[k, p] + + +def test_MatrixElement_with_values(): + x, y, z, w = symbols("x y z w") + M = Matrix([[x, y], [z, w]]) + i, j = symbols("i, j") + Mij = M[i, j] + assert isinstance(Mij, MatrixElement) + Ms = SparseMatrix([[2, 3], [4, 5]]) + msij = Ms[i, j] + assert isinstance(msij, MatrixElement) + for oi, oj in [(0, 0), (0, 1), (1, 0), (1, 1)]: + assert Mij.subs({i: oi, j: oj}) == M[oi, oj] + assert msij.subs({i: oi, j: oj}) == Ms[oi, oj] + A = MatrixSymbol("A", 2, 2) + assert A[0, 0].subs(A, M) == x + assert A[i, j].subs(A, M) == M[i, j] + assert M[i, j].subs(M, A) == A[i, j] + + assert isinstance(M[3*i - 2, j], MatrixElement) + assert M[3*i - 2, j].subs({i: 1, j: 0}) == M[1, 0] + assert isinstance(M[i, 0], MatrixElement) + assert M[i, 0].subs(i, 0) == M[0, 0] + assert M[0, i].subs(i, 1) == M[0, 1] + + assert M[i, j].diff(x) == Matrix([[1, 0], [0, 0]])[i, j] + + raises(ValueError, lambda: M[i, 2]) + raises(ValueError, lambda: M[i, -1]) + raises(ValueError, lambda: M[2, i]) + raises(ValueError, lambda: M[-1, i]) + + +def test_inv(): + B = MatrixSymbol('B', 3, 3) + assert B.inv() == B**-1 + + # https://github.com/sympy/sympy/issues/19162 + X = MatrixSymbol('X', 1, 1).as_explicit() + assert X.inv() == Matrix([[1/X[0, 0]]]) + + X = MatrixSymbol('X', 2, 2).as_explicit() + detX = X[0, 0]*X[1, 1] - X[0, 1]*X[1, 0] + invX = Matrix([[ X[1, 1], -X[0, 1]], + [-X[1, 0], X[0, 0]]]) / detX + assert X.inv() == invX + + +@XFAIL +def test_factor_expand(): + A = MatrixSymbol("A", n, n) + B = MatrixSymbol("B", n, n) + expr1 = (A + B)*(C + D) + expr2 = A*C + B*C + A*D + B*D + assert expr1 != expr2 + assert expand(expr1) == expr2 + assert factor(expr2) == expr1 + + expr = B**(-1)*(A**(-1)*B**(-1) - A**(-1)*C*B**(-1))**(-1)*A**(-1) + I = Identity(n) + # Ideally we get the first, but we at least don't want a wrong answer + assert factor(expr) in [I - C, B**-1*(A**-1*(I - C)*B**-1)**-1*A**-1] + +def test_numpy_conversion(): + try: + from numpy import array, array_equal + except ImportError: + skip('NumPy must be available to test creating matrices from ndarrays') + A = MatrixSymbol('A', 2, 2) + np_array = array([[MatrixElement(A, 0, 0), MatrixElement(A, 0, 1)], + [MatrixElement(A, 1, 0), MatrixElement(A, 1, 1)]]) + assert array_equal(array(A), np_array) + assert array_equal(array(A, copy=True), np_array) + if(int(version('numpy').split('.')[0]) >= 2): #run this test only if numpy is new enough that copy variable is passed properly. + raises(TypeError, lambda: array(A, copy=False)) + +def test_issue_2749(): + A = MatrixSymbol("A", 5, 2) + assert (A.T * A).I.as_explicit() == Matrix([[(A.T * A).I[0, 0], (A.T * A).I[0, 1]], \ + [(A.T * A).I[1, 0], (A.T * A).I[1, 1]]]) + + +def test_issue_2750(): + x = MatrixSymbol('x', 1, 1) + assert (x.T*x).as_explicit()**-1 == Matrix([[x[0, 0]**(-2)]]) + + +def test_issue_7842(): + A = MatrixSymbol('A', 3, 1) + B = MatrixSymbol('B', 2, 1) + assert Eq(A, B) == False + assert Eq(A[1,0], B[1, 0]).func is Eq + A = ZeroMatrix(2, 3) + B = ZeroMatrix(2, 3) + assert Eq(A, B) == True + + +def test_issue_21195(): + t = symbols('t') + x = Function('x')(t) + dx = x.diff(t) + exp1 = cos(x) + cos(x)*dx + exp2 = sin(x) + tan(x)*(dx.diff(t)) + exp3 = sin(x)*sin(t)*(dx.diff(t)).diff(t) + A = Matrix([[exp1], [exp2], [exp3]]) + B = Matrix([[exp1.diff(x)], [exp2.diff(x)], [exp3.diff(x)]]) + assert A.diff(x) == B + + +def test_issue_24859(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 3, 2) + J = A*B + Jinv = Matrix(J).adjugate() + u = MatrixSymbol('u', 2, 3) + Jk = Jinv.subs(A, A + x*u) + + expected = B[0, 1]*u[1, 0] + B[1, 1]*u[1, 1] + B[2, 1]*u[1, 2] + assert Jk[0, 0].diff(x) == expected + assert diff(Jk[0, 0], x).doit() == expected + + +def test_MatMul_postprocessor(): + z = zeros(2) + z1 = ZeroMatrix(2, 2) + assert Mul(0, z) == Mul(z, 0) in [z, z1] + + M = Matrix([[1, 2], [3, 4]]) + Mx = Matrix([[x, 2*x], [3*x, 4*x]]) + assert Mul(x, M) == Mul(M, x) == Mx + + A = MatrixSymbol("A", 2, 2) + assert Mul(A, M) == MatMul(A, M) + assert Mul(M, A) == MatMul(M, A) + # Scalars should be absorbed into constant matrices + a = Mul(x, M, A) + b = Mul(M, x, A) + c = Mul(M, A, x) + assert a == b == c == MatMul(Mx, A) + a = Mul(x, A, M) + b = Mul(A, x, M) + c = Mul(A, M, x) + assert a == b == c == MatMul(A, Mx) + assert Mul(M, M) == M**2 + assert Mul(A, M, M) == MatMul(A, M**2) + assert Mul(M, M, A) == MatMul(M**2, A) + assert Mul(M, A, M) == MatMul(M, A, M) + + assert Mul(A, x, M, M, x) == MatMul(A, Mx**2) + + +@XFAIL +def test_MatAdd_postprocessor_xfail(): + # This is difficult to get working because of the way that Add processes + # its args. + z = zeros(2) + assert Add(z, S.NaN) == Add(S.NaN, z) + + +def test_MatAdd_postprocessor(): + # Some of these are nonsensical, but we do not raise errors for Add + # because that breaks algorithms that want to replace matrices with dummy + # symbols. + + z = zeros(2) + + assert Add(0, z) == Add(z, 0) == z + + a = Add(S.Infinity, z) + assert a == Add(z, S.Infinity) + assert isinstance(a, Add) + assert a.args == (S.Infinity, z) + + a = Add(S.ComplexInfinity, z) + assert a == Add(z, S.ComplexInfinity) + assert isinstance(a, Add) + assert a.args == (S.ComplexInfinity, z) + + a = Add(z, S.NaN) + # assert a == Add(S.NaN, z) # See the XFAIL above + assert isinstance(a, Add) + assert a.args == (S.NaN, z) + + M = Matrix([[1, 2], [3, 4]]) + a = Add(x, M) + assert a == Add(M, x) + assert isinstance(a, Add) + assert a.args == (x, M) + + A = MatrixSymbol("A", 2, 2) + assert Add(A, M) == Add(M, A) == A + M + + # Scalars should be absorbed into constant matrices (producing an error) + a = Add(x, M, A) + assert a == Add(M, x, A) == Add(M, A, x) == Add(x, A, M) == Add(A, x, M) == Add(A, M, x) + assert isinstance(a, Add) + assert a.args == (x, A + M) + + assert Add(M, M) == 2*M + assert Add(M, A, M) == Add(M, M, A) == Add(A, M, M) == A + 2*M + + a = Add(A, x, M, M, x) + assert isinstance(a, Add) + assert a.args == (2*x, A + 2*M) + + +def test_simplify_matrix_expressions(): + # Various simplification functions + assert type(gcd_terms(C*D + D*C)) == MatAdd + a = gcd_terms(2*C*D + 4*D*C) + assert type(a) == MatAdd + assert a.args == (2*C*D, 4*D*C) + + +def test_exp(): + A = MatrixSymbol('A', 2, 2) + B = MatrixSymbol('B', 2, 2) + expr1 = exp(A)*exp(B) + expr2 = exp(B)*exp(A) + assert expr1 != expr2 + assert expr1 - expr2 != 0 + assert not isinstance(expr1, exp) + assert not isinstance(expr2, exp) + + +def test_invalid_args(): + raises(SympifyError, lambda: MatrixSymbol(1, 2, 'A')) + + +def test_matrixsymbol_from_symbol(): + # The label should be preserved during doit and subs + A_label = Symbol('A', complex=True) + A = MatrixSymbol(A_label, 2, 2) + + A_1 = A.doit() + A_2 = A.subs(2, 3) + assert A_1.args == A.args + assert A_2.args[0] == A.args[0] + + +def test_as_explicit(): + Z = MatrixSymbol('Z', 2, 3) + assert Z.as_explicit() == ImmutableMatrix([ + [Z[0, 0], Z[0, 1], Z[0, 2]], + [Z[1, 0], Z[1, 1], Z[1, 2]], + ]) + raises(ValueError, lambda: A.as_explicit()) + + +def test_MatrixSet(): + M = MatrixSet(2, 2, set=S.Reals) + assert M.shape == (2, 2) + assert M.set == S.Reals + X = Matrix([[1, 2], [3, 4]]) + assert X in M + X = ZeroMatrix(2, 2) + assert X in M + raises(TypeError, lambda: A in M) + raises(TypeError, lambda: 1 in M) + M = MatrixSet(n, m, set=S.Reals) + assert A in M + raises(TypeError, lambda: C in M) + raises(TypeError, lambda: X in M) + M = MatrixSet(2, 2, set={1, 2, 3}) + X = Matrix([[1, 2], [3, 4]]) + Y = Matrix([[1, 2]]) + assert (X in M) == S.false + assert (Y in M) == S.false + raises(ValueError, lambda: MatrixSet(2, -2, S.Reals)) + raises(ValueError, lambda: MatrixSet(2.4, -1, S.Reals)) + raises(TypeError, lambda: MatrixSet(2, 2, (1, 2, 3))) + + +def test_matrixsymbol_solving(): + A = MatrixSymbol('A', 2, 2) + B = MatrixSymbol('B', 2, 2) + Z = ZeroMatrix(2, 2) + assert -(-A + B) - A + B == Z + assert (-(-A + B) - A + B).simplify() == Z + assert (-(-A + B) - A + B).expand() == Z + assert (-(-A + B) - A + B - Z).simplify() == Z + assert (-(-A + B) - A + B - Z).expand() == Z + assert (A*(A + B) + B*(A.T + B.T)).expand() == A**2 + A*B + B*A.T + B*B.T diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_matmul.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..813926e2c83e27716f4f894ebebd09b2a576f046 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_matmul.py @@ -0,0 +1,193 @@ +from sympy.core import I, symbols, Basic, Mul, S +from sympy.core.mul import mul +from sympy.functions import adjoint, transpose +from sympy.matrices.exceptions import ShapeError +from sympy.matrices import (Identity, Inverse, Matrix, MatrixSymbol, ZeroMatrix, + eye, ImmutableMatrix) +from sympy.matrices.expressions import Adjoint, Transpose, det, MatPow +from sympy.matrices.expressions.special import GenericIdentity +from sympy.matrices.expressions.matmul import (factor_in_front, remove_ids, + MatMul, combine_powers, any_zeros, unpack, only_squares) +from sympy.strategies import null_safe +from sympy.assumptions.ask import Q +from sympy.assumptions.refine import refine +from sympy.core.symbol import Symbol + +from sympy.testing.pytest import XFAIL, raises + +n, m, l, k = symbols('n m l k', integer=True) +x = symbols('x') +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) +D = MatrixSymbol('D', n, n) +E = MatrixSymbol('E', m, n) + +def test_evaluate(): + assert MatMul(C, C, evaluate=True) == MatMul(C, C).doit() + +def test_adjoint(): + assert adjoint(A*B) == Adjoint(B)*Adjoint(A) + assert adjoint(2*A*B) == 2*Adjoint(B)*Adjoint(A) + assert adjoint(2*I*C) == -2*I*Adjoint(C) + + M = Matrix(2, 2, [1, 2 + I, 3, 4]) + MA = Matrix(2, 2, [1, 3, 2 - I, 4]) + assert adjoint(M) == MA + assert adjoint(2*M) == 2*MA + assert adjoint(MatMul(2, M)) == MatMul(2, MA).doit() + + +def test_transpose(): + assert transpose(A*B) == Transpose(B)*Transpose(A) + assert transpose(2*A*B) == 2*Transpose(B)*Transpose(A) + assert transpose(2*I*C) == 2*I*Transpose(C) + + M = Matrix(2, 2, [1, 2 + I, 3, 4]) + MT = Matrix(2, 2, [1, 3, 2 + I, 4]) + assert transpose(M) == MT + assert transpose(2*M) == 2*MT + assert transpose(x*M) == x*MT + assert transpose(MatMul(2, M)) == MatMul(2, MT).doit() + + +def test_factor_in_front(): + assert factor_in_front(MatMul(A, 2, B, evaluate=False)) ==\ + MatMul(2, A, B, evaluate=False) + + +def test_remove_ids(): + assert remove_ids(MatMul(A, Identity(m), B, evaluate=False)) == \ + MatMul(A, B, evaluate=False) + assert null_safe(remove_ids)(MatMul(Identity(n), evaluate=False)) == \ + MatMul(Identity(n), evaluate=False) + + +def test_combine_powers(): + assert combine_powers(MatMul(D, Inverse(D), D, evaluate=False)) == \ + MatMul(Identity(n), D, evaluate=False) + assert combine_powers(MatMul(B.T, Inverse(E*A), E, A, B, evaluate=False)) == \ + MatMul(B.T, Identity(m), B, evaluate=False) + assert combine_powers(MatMul(A, E, Inverse(A*E), D, evaluate=False)) == \ + MatMul(Identity(n), D, evaluate=False) + + +def test_any_zeros(): + assert any_zeros(MatMul(A, ZeroMatrix(m, k), evaluate=False)) == \ + ZeroMatrix(n, k) + + +def test_unpack(): + assert unpack(MatMul(A, evaluate=False)) == A + x = MatMul(A, B) + assert unpack(x) == x + + +def test_only_squares(): + assert only_squares(C) == [C] + assert only_squares(C, D) == [C, D] + assert only_squares(C, A, A.T, D) == [C, A*A.T, D] + + +def test_determinant(): + assert det(2*C) == 2**n*det(C) + assert det(2*C*D) == 2**n*det(C)*det(D) + assert det(3*C*A*A.T*D) == 3**n*det(C)*det(A*A.T)*det(D) + + +def test_doit(): + assert MatMul(C, 2, D).args == (C, 2, D) + assert MatMul(C, 2, D).doit().args == (2, C, D) + assert MatMul(C, Transpose(D*C)).args == (C, Transpose(D*C)) + assert MatMul(C, Transpose(D*C)).doit(deep=True).args == (C, C.T, D.T) + + +def test_doit_drills_down(): + X = ImmutableMatrix([[1, 2], [3, 4]]) + Y = ImmutableMatrix([[2, 3], [4, 5]]) + assert MatMul(X, MatPow(Y, 2)).doit() == X*Y**2 + assert MatMul(C, Transpose(D*C)).doit().args == (C, C.T, D.T) + + +def test_doit_deep_false_still_canonical(): + assert (MatMul(C, Transpose(D*C), 2).doit(deep=False).args == + (2, C, Transpose(D*C))) + + +def test_matmul_scalar_Matrix_doit(): + # Issue 9053 + X = Matrix([[1, 2], [3, 4]]) + assert MatMul(2, X).doit() == 2*X + + +def test_matmul_sympify(): + assert isinstance(MatMul(eye(1), eye(1)).args[0], Basic) + + +def test_collapse_MatrixBase(): + A = Matrix([[1, 1], [1, 1]]) + B = Matrix([[1, 2], [3, 4]]) + assert MatMul(A, B).doit() == ImmutableMatrix([[4, 6], [4, 6]]) + + +def test_refine(): + assert refine(C*C.T*D, Q.orthogonal(C)).doit() == D + + kC = k*C + assert refine(kC*C.T, Q.orthogonal(C)).doit() == k*Identity(n) + assert refine(kC* kC.T, Q.orthogonal(C)).doit() == (k**2)*Identity(n) + +def test_matmul_no_matrices(): + assert MatMul(1) == 1 + assert MatMul(n, m) == n*m + assert not isinstance(MatMul(n, m), MatMul) + +def test_matmul_args_cnc(): + assert MatMul(n, A, A.T).args_cnc() == [[n], [A, A.T]] + assert MatMul(A, A.T).args_cnc() == [[], [A, A.T]] + +@XFAIL +def test_matmul_args_cnc_symbols(): + # Not currently supported + a, b = symbols('a b', commutative=False) + assert MatMul(n, a, b, A, A.T).args_cnc() == [[n], [a, b, A, A.T]] + assert MatMul(n, a, A, b, A.T).args_cnc() == [[n], [a, A, b, A.T]] + +def test_issue_12950(): + M = Matrix([[Symbol("x")]]) * MatrixSymbol("A", 1, 1) + assert MatrixSymbol("A", 1, 1).as_explicit()[0]*Symbol('x') == M.as_explicit()[0] + +def test_construction_with_Mul(): + assert Mul(C, D) == MatMul(C, D) + assert Mul(D, C) == MatMul(D, C) + +def test_construction_with_mul(): + assert mul(C, D) == MatMul(C, D) + assert mul(D, C) == MatMul(D, C) + assert mul(C, D) != MatMul(D, C) + +def test_generic_identity(): + assert MatMul.identity == GenericIdentity() + assert MatMul.identity != S.One + + +def test_issue_23519(): + N = Symbol("N", integer=True) + M1 = MatrixSymbol("M1", N, N) + M2 = MatrixSymbol("M2", N, N) + I = Identity(N) + z = (M2 + 2 * (M2 + I) * M1 + I) + assert z.coeff(M1) == 2*I + 2*M2 + + +def test_shape_error(): + A = MatrixSymbol('A', 2, 2) + B = MatrixSymbol('B', 3, 3) + raises(ShapeError, lambda: MatMul(A, B)) + + +def test_matmul_transpose(): + # https://github.com/sympy/sympy/issues/9503 + M = Matrix(2, 2, [1, 2 + I, 3, 4]) + a = Symbol('a') + assert (MatMul(a, M).T).expand() == (a*Matrix([[1, 3],[2 + I, 4]])).expand() diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_matpow.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_matpow.py new file mode 100644 index 0000000000000000000000000000000000000000..2afb5fdc2aa652c321de52aba43db63da60941fd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_matpow.py @@ -0,0 +1,217 @@ +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.simplify.powsimp import powsimp +from sympy.testing.pytest import raises +from sympy.core.expr import unchanged +from sympy.core import symbols, S +from sympy.matrices import Identity, MatrixSymbol, ImmutableMatrix, ZeroMatrix, OneMatrix, Matrix +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices.expressions import MatPow, MatAdd, MatMul +from sympy.matrices.expressions.inverse import Inverse +from sympy.matrices.expressions.matexpr import MatrixElement + +n, m, l, k = symbols('n m l k', integer=True) +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) +D = MatrixSymbol('D', n, n) +E = MatrixSymbol('E', m, n) + + +def test_entry_matrix(): + X = ImmutableMatrix([[1, 2], [3, 4]]) + assert MatPow(X, 0)[0, 0] == 1 + assert MatPow(X, 0)[0, 1] == 0 + assert MatPow(X, 1)[0, 0] == 1 + assert MatPow(X, 1)[0, 1] == 2 + assert MatPow(X, 2)[0, 0] == 7 + + +def test_entry_symbol(): + from sympy.concrete import Sum + assert MatPow(C, 0)[0, 0] == 1 + assert MatPow(C, 0)[0, 1] == 0 + assert MatPow(C, 1)[0, 0] == C[0, 0] + assert isinstance(MatPow(C, 2)[0, 0], Sum) + assert isinstance(MatPow(C, n)[0, 0], MatrixElement) + + +def test_as_explicit_symbol(): + X = MatrixSymbol('X', 2, 2) + assert MatPow(X, 0).as_explicit() == ImmutableMatrix(Identity(2)) + assert MatPow(X, 1).as_explicit() == X.as_explicit() + assert MatPow(X, 2).as_explicit() == (X.as_explicit())**2 + assert MatPow(X, n).as_explicit() == ImmutableMatrix([ + [(X ** n)[0, 0], (X ** n)[0, 1]], + [(X ** n)[1, 0], (X ** n)[1, 1]], + ]) + + a = MatrixSymbol("a", 3, 1) + b = MatrixSymbol("b", 3, 1) + c = MatrixSymbol("c", 3, 1) + + expr = (a.T*b)**S.Half + assert expr.as_explicit() == Matrix([[sqrt(a[0, 0]*b[0, 0] + a[1, 0]*b[1, 0] + a[2, 0]*b[2, 0])]]) + + expr = c*(a.T*b)**S.Half + m = sqrt(a[0, 0]*b[0, 0] + a[1, 0]*b[1, 0] + a[2, 0]*b[2, 0]) + assert expr.as_explicit() == Matrix([[c[0, 0]*m], [c[1, 0]*m], [c[2, 0]*m]]) + + expr = (a*b.T)**S.Half + denom = sqrt(a[0, 0]*b[0, 0] + a[1, 0]*b[1, 0] + a[2, 0]*b[2, 0]) + expected = (a*b.T).as_explicit()/denom + assert expr.as_explicit() == expected + + expr = X**-1 + det = X[0, 0]*X[1, 1] - X[1, 0]*X[0, 1] + expected = Matrix([[X[1, 1], -X[0, 1]], [-X[1, 0], X[0, 0]]])/det + assert expr.as_explicit() == expected + + expr = X**m + assert expr.as_explicit() == X.as_explicit()**m + + +def test_as_explicit_matrix(): + A = ImmutableMatrix([[1, 2], [3, 4]]) + assert MatPow(A, 0).as_explicit() == ImmutableMatrix(Identity(2)) + assert MatPow(A, 1).as_explicit() == A + assert MatPow(A, 2).as_explicit() == A**2 + assert MatPow(A, -1).as_explicit() == A.inv() + assert MatPow(A, -2).as_explicit() == (A.inv())**2 + # less expensive than testing on a 2x2 + A = ImmutableMatrix([4]) + assert MatPow(A, S.Half).as_explicit() == A**S.Half + + +def test_doit_symbol(): + assert MatPow(C, 0).doit() == Identity(n) + assert MatPow(C, 1).doit() == C + assert MatPow(C, -1).doit() == C.I + for r in [2, S.Half, S.Pi, n]: + assert MatPow(C, r).doit() == MatPow(C, r) + + +def test_doit_matrix(): + X = ImmutableMatrix([[1, 2], [3, 4]]) + assert MatPow(X, 0).doit() == ImmutableMatrix(Identity(2)) + assert MatPow(X, 1).doit() == X + assert MatPow(X, 2).doit() == X**2 + assert MatPow(X, -1).doit() == X.inv() + assert MatPow(X, -2).doit() == (X.inv())**2 + # less expensive than testing on a 2x2 + assert MatPow(ImmutableMatrix([4]), S.Half).doit() == ImmutableMatrix([2]) + X = ImmutableMatrix([[0, 2], [0, 4]]) # det() == 0 + raises(ValueError, lambda: MatPow(X,-1).doit()) + raises(ValueError, lambda: MatPow(X,-2).doit()) + + +def test_nonsquare(): + A = MatrixSymbol('A', 2, 3) + B = ImmutableMatrix([[1, 2, 3], [4, 5, 6]]) + for r in [-1, 0, 1, 2, S.Half, S.Pi, n]: + raises(NonSquareMatrixError, lambda: MatPow(A, r)) + raises(NonSquareMatrixError, lambda: MatPow(B, r)) + + +def test_doit_equals_pow(): #17179 + X = ImmutableMatrix ([[1,0],[0,1]]) + assert MatPow(X, n).doit() == X**n == X + + +def test_doit_nested_MatrixExpr(): + X = ImmutableMatrix([[1, 2], [3, 4]]) + Y = ImmutableMatrix([[2, 3], [4, 5]]) + assert MatPow(MatMul(X, Y), 2).doit() == (X*Y)**2 + assert MatPow(MatAdd(X, Y), 2).doit() == (X + Y)**2 + + +def test_identity_power(): + k = Identity(n) + assert MatPow(k, 4).doit() == k + assert MatPow(k, n).doit() == k + assert MatPow(k, -3).doit() == k + assert MatPow(k, 0).doit() == k + l = Identity(3) + assert MatPow(l, n).doit() == l + assert MatPow(l, -1).doit() == l + assert MatPow(l, 0).doit() == l + + +def test_zero_power(): + z1 = ZeroMatrix(n, n) + assert MatPow(z1, 3).doit() == z1 + raises(ValueError, lambda:MatPow(z1, -1).doit()) + assert MatPow(z1, 0).doit() == Identity(n) + assert MatPow(z1, n).doit() == z1 + raises(ValueError, lambda:MatPow(z1, -2).doit()) + z2 = ZeroMatrix(4, 4) + assert MatPow(z2, n).doit() == z2 + raises(ValueError, lambda:MatPow(z2, -3).doit()) + assert MatPow(z2, 2).doit() == z2 + assert MatPow(z2, 0).doit() == Identity(4) + raises(ValueError, lambda:MatPow(z2, -1).doit()) + + +def test_OneMatrix_power(): + o = OneMatrix(3, 3) + assert o ** 0 == Identity(3) + assert o ** 1 == o + assert o * o == o ** 2 == 3 * o + assert o * o * o == o ** 3 == 9 * o + + o = OneMatrix(n, n) + assert o * o == o ** 2 == n * o + # powsimp necessary as n ** (n - 2) * n does not produce n ** (n - 1) + assert powsimp(o ** (n - 1) * o) == o ** n == n ** (n - 1) * o + + +def test_transpose_power(): + from sympy.matrices.expressions.transpose import Transpose as TP + + assert (C*D).T**5 == ((C*D)**5).T == (D.T * C.T)**5 + assert ((C*D).T**5).T == (C*D)**5 + + assert (C.T.I.T)**7 == C**-7 + assert (C.T**l).T**k == C**(l*k) + + assert ((E.T * A.T)**5).T == (A*E)**5 + assert ((A*E).T**5).T**7 == (A*E)**35 + assert TP(TP(C**2 * D**3)**5).doit() == (C**2 * D**3)**5 + + assert ((D*C)**-5).T**-5 == ((D*C)**25).T + assert (((D*C)**l).T**k).T == (D*C)**(l*k) + + +def test_Inverse(): + assert Inverse(MatPow(C, 0)).doit() == Identity(n) + assert Inverse(MatPow(C, 1)).doit() == Inverse(C) + assert Inverse(MatPow(C, 2)).doit() == MatPow(C, -2) + assert Inverse(MatPow(C, -1)).doit() == C + + assert MatPow(Inverse(C), 0).doit() == Identity(n) + assert MatPow(Inverse(C), 1).doit() == Inverse(C) + assert MatPow(Inverse(C), 2).doit() == MatPow(C, -2) + assert MatPow(Inverse(C), -1).doit() == C + + +def test_combine_powers(): + assert (C ** 1) ** 1 == C + assert (C ** 2) ** 3 == MatPow(C, 6) + assert (C ** -2) ** -3 == MatPow(C, 6) + assert (C ** -1) ** -1 == C + assert (((C ** 2) ** 3) ** 4) ** 5 == MatPow(C, 120) + assert (C ** n) ** n == C ** (n ** 2) + + +def test_unchanged(): + assert unchanged(MatPow, C, 0) + assert unchanged(MatPow, C, 1) + assert unchanged(MatPow, Inverse(C), -1) + assert unchanged(Inverse, MatPow(C, -1), -1) + assert unchanged(MatPow, MatPow(C, -1), -1) + assert unchanged(MatPow, MatPow(C, 1), 1) + + +def test_no_exponentiation(): + # if this passes, Pow.as_numer_denom should recognize + # MatAdd as exponent + raises(NotImplementedError, lambda: 3**(-2*C)) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_permutation.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_permutation.py new file mode 100644 index 0000000000000000000000000000000000000000..41a924f6636afb2e5b6560987e38a0fa0c861f1e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_permutation.py @@ -0,0 +1,166 @@ +from sympy.combinatorics import Permutation +from sympy.core.expr import unchanged +from sympy.matrices import Matrix +from sympy.matrices.expressions import \ + MatMul, BlockDiagMatrix, Determinant, Inverse +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import ZeroMatrix, OneMatrix, Identity +from sympy.matrices.expressions.permutation import \ + MatrixPermute, PermutationMatrix +from sympy.testing.pytest import raises +from sympy.core.symbol import Symbol + + +def test_PermutationMatrix_basic(): + p = Permutation([1, 0]) + assert unchanged(PermutationMatrix, p) + raises(ValueError, lambda: PermutationMatrix((0, 1, 2))) + assert PermutationMatrix(p).as_explicit() == Matrix([[0, 1], [1, 0]]) + assert isinstance(PermutationMatrix(p)*MatrixSymbol('A', 2, 2), MatMul) + + +def test_PermutationMatrix_matmul(): + p = Permutation([1, 2, 0]) + P = PermutationMatrix(p) + M = Matrix([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) + assert (P*M).as_explicit() == P.as_explicit()*M + assert (M*P).as_explicit() == M*P.as_explicit() + + P1 = PermutationMatrix(Permutation([1, 2, 0])) + P2 = PermutationMatrix(Permutation([2, 1, 0])) + P3 = PermutationMatrix(Permutation([1, 0, 2])) + assert P1*P2 == P3 + + +def test_PermutationMatrix_matpow(): + p1 = Permutation([1, 2, 0]) + P1 = PermutationMatrix(p1) + p2 = Permutation([2, 0, 1]) + P2 = PermutationMatrix(p2) + assert P1**2 == P2 + assert P1**3 == Identity(3) + + +def test_PermutationMatrix_identity(): + p = Permutation([0, 1]) + assert PermutationMatrix(p).is_Identity + + p = Permutation([1, 0]) + assert not PermutationMatrix(p).is_Identity + + +def test_PermutationMatrix_determinant(): + P = PermutationMatrix(Permutation([0, 1, 2])) + assert Determinant(P).doit() == 1 + P = PermutationMatrix(Permutation([0, 2, 1])) + assert Determinant(P).doit() == -1 + P = PermutationMatrix(Permutation([2, 0, 1])) + assert Determinant(P).doit() == 1 + + +def test_PermutationMatrix_inverse(): + P = PermutationMatrix(Permutation(0, 1, 2)) + assert Inverse(P).doit() == PermutationMatrix(Permutation(0, 2, 1)) + + +def test_PermutationMatrix_rewrite_BlockDiagMatrix(): + P = PermutationMatrix(Permutation([0, 1, 2, 3, 4, 5])) + P0 = PermutationMatrix(Permutation([0])) + assert P.rewrite(BlockDiagMatrix) == \ + BlockDiagMatrix(P0, P0, P0, P0, P0, P0) + + P = PermutationMatrix(Permutation([0, 1, 3, 2, 4, 5])) + P10 = PermutationMatrix(Permutation(0, 1)) + assert P.rewrite(BlockDiagMatrix) == \ + BlockDiagMatrix(P0, P0, P10, P0, P0) + + P = PermutationMatrix(Permutation([1, 0, 3, 2, 5, 4])) + assert P.rewrite(BlockDiagMatrix) == \ + BlockDiagMatrix(P10, P10, P10) + + P = PermutationMatrix(Permutation([0, 4, 3, 2, 1, 5])) + P3210 = PermutationMatrix(Permutation([3, 2, 1, 0])) + assert P.rewrite(BlockDiagMatrix) == \ + BlockDiagMatrix(P0, P3210, P0) + + P = PermutationMatrix(Permutation([0, 4, 2, 3, 1, 5])) + P3120 = PermutationMatrix(Permutation([3, 1, 2, 0])) + assert P.rewrite(BlockDiagMatrix) == \ + BlockDiagMatrix(P0, P3120, P0) + + P = PermutationMatrix(Permutation(0, 3)(1, 4)(2, 5)) + assert P.rewrite(BlockDiagMatrix) == BlockDiagMatrix(P) + + +def test_MartrixPermute_basic(): + p = Permutation(0, 1) + P = PermutationMatrix(p) + A = MatrixSymbol('A', 2, 2) + + raises(ValueError, lambda: MatrixPermute(Symbol('x'), p)) + raises(ValueError, lambda: MatrixPermute(A, Symbol('x'))) + + assert MatrixPermute(A, P) == MatrixPermute(A, p) + raises(ValueError, lambda: MatrixPermute(A, p, 2)) + + pp = Permutation(0, 1, size=3) + assert MatrixPermute(A, pp) == MatrixPermute(A, p) + pp = Permutation(0, 1, 2) + raises(ValueError, lambda: MatrixPermute(A, pp)) + + +def test_MatrixPermute_shape(): + p = Permutation(0, 1) + A = MatrixSymbol('A', 2, 3) + assert MatrixPermute(A, p).shape == (2, 3) + + +def test_MatrixPermute_explicit(): + p = Permutation(0, 1, 2) + A = MatrixSymbol('A', 3, 3) + AA = A.as_explicit() + assert MatrixPermute(A, p, 0).as_explicit() == \ + AA.permute(p, orientation='rows') + assert MatrixPermute(A, p, 1).as_explicit() == \ + AA.permute(p, orientation='cols') + + +def test_MatrixPermute_rewrite_MatMul(): + p = Permutation(0, 1, 2) + A = MatrixSymbol('A', 3, 3) + + assert MatrixPermute(A, p, 0).rewrite(MatMul).as_explicit() == \ + MatrixPermute(A, p, 0).as_explicit() + assert MatrixPermute(A, p, 1).rewrite(MatMul).as_explicit() == \ + MatrixPermute(A, p, 1).as_explicit() + + +def test_MatrixPermute_doit(): + p = Permutation(0, 1, 2) + A = MatrixSymbol('A', 3, 3) + assert MatrixPermute(A, p).doit() == MatrixPermute(A, p) + + p = Permutation(0, size=3) + A = MatrixSymbol('A', 3, 3) + assert MatrixPermute(A, p).doit().as_explicit() == \ + MatrixPermute(A, p).as_explicit() + + p = Permutation(0, 1, 2) + A = Identity(3) + assert MatrixPermute(A, p, 0).doit().as_explicit() == \ + MatrixPermute(A, p, 0).as_explicit() + assert MatrixPermute(A, p, 1).doit().as_explicit() == \ + MatrixPermute(A, p, 1).as_explicit() + + A = ZeroMatrix(3, 3) + assert MatrixPermute(A, p).doit() == A + A = OneMatrix(3, 3) + assert MatrixPermute(A, p).doit() == A + + A = MatrixSymbol('A', 4, 4) + p1 = Permutation(0, 1, 2, 3) + p2 = Permutation(0, 2, 3, 1) + expr = MatrixPermute(MatrixPermute(A, p1, 0), p2, 0) + assert expr.as_explicit() == expr.doit().as_explicit() + expr = MatrixPermute(MatrixPermute(A, p1, 1), p2, 1) + assert expr.as_explicit() == expr.doit().as_explicit() diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_sets.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_sets.py new file mode 100644 index 0000000000000000000000000000000000000000..e811c7968c5a22d65f1c99e995aaa7e5e59d15c4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_sets.py @@ -0,0 +1,42 @@ +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.matrices import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.sets import MatrixSet +from sympy.matrices.expressions.special import ZeroMatrix +from sympy.testing.pytest import raises +from sympy.sets.sets import SetKind +from sympy.matrices.kind import MatrixKind +from sympy.core.kind import NumberKind + + +def test_MatrixSet(): + n, m = symbols('n m', integer=True) + A = MatrixSymbol('A', n, m) + C = MatrixSymbol('C', n, n) + + M = MatrixSet(2, 2, set=S.Reals) + assert M.shape == (2, 2) + assert M.set == S.Reals + X = Matrix([[1, 2], [3, 4]]) + assert X in M + X = ZeroMatrix(2, 2) + assert X in M + raises(TypeError, lambda: A in M) + raises(TypeError, lambda: 1 in M) + M = MatrixSet(n, m, set=S.Reals) + assert A in M + raises(TypeError, lambda: C in M) + raises(TypeError, lambda: X in M) + M = MatrixSet(2, 2, set={1, 2, 3}) + X = Matrix([[1, 2], [3, 4]]) + Y = Matrix([[1, 2]]) + assert (X in M) == S.false + assert (Y in M) == S.false + raises(ValueError, lambda: MatrixSet(2, -2, S.Reals)) + raises(ValueError, lambda: MatrixSet(2.4, -1, S.Reals)) + raises(TypeError, lambda: MatrixSet(2, 2, (1, 2, 3))) + + +def test_SetKind_MatrixSet(): + assert MatrixSet(2, 2, set=S.Reals).kind is SetKind(MatrixKind(NumberKind)) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_slice.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_slice.py new file mode 100644 index 0000000000000000000000000000000000000000..36490719e26908b9e913ed99b7673d602647c492 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_slice.py @@ -0,0 +1,65 @@ +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices.expressions import MatrixSymbol +from sympy.abc import a, b, c, d, k, l, m, n +from sympy.testing.pytest import raises, XFAIL +from sympy.functions.elementary.integers import floor +from sympy.assumptions import assuming, Q + + +X = MatrixSymbol('X', n, m) +Y = MatrixSymbol('Y', m, k) + +def test_shape(): + B = MatrixSlice(X, (a, b), (c, d)) + assert B.shape == (b - a, d - c) + +def test_entry(): + B = MatrixSlice(X, (a, b), (c, d)) + assert B[0,0] == X[a, c] + assert B[k,l] == X[a+k, c+l] + raises(IndexError, lambda : MatrixSlice(X, 1, (2, 5))[1, 0]) + + assert X[1::2, :][1, 3] == X[1+2, 3] + assert X[:, 1::2][3, 1] == X[3, 1+2] + +def test_on_diag(): + assert not MatrixSlice(X, (a, b), (c, d)).on_diag + assert MatrixSlice(X, (a, b), (a, b)).on_diag + +def test_inputs(): + assert MatrixSlice(X, 1, (2, 5)) == MatrixSlice(X, (1, 2), (2, 5)) + assert MatrixSlice(X, 1, (2, 5)).shape == (1, 3) + +def test_slicing(): + assert X[1:5, 2:4] == MatrixSlice(X, (1, 5), (2, 4)) + assert X[1, 2:4] == MatrixSlice(X, 1, (2, 4)) + assert X[1:5, :].shape == (4, X.shape[1]) + assert X[:, 1:5].shape == (X.shape[0], 4) + + assert X[::2, ::2].shape == (floor(n/2), floor(m/2)) + assert X[2, :] == MatrixSlice(X, 2, (0, m)) + assert X[k, :] == MatrixSlice(X, k, (0, m)) + +def test_exceptions(): + X = MatrixSymbol('x', 10, 20) + raises(IndexError, lambda: X[0:12, 2]) + raises(IndexError, lambda: X[0:9, 22]) + raises(IndexError, lambda: X[-1:5, 2]) + +@XFAIL +def test_symmetry(): + X = MatrixSymbol('x', 10, 10) + Y = X[:5, 5:] + with assuming(Q.symmetric(X)): + assert Y.T == X[5:, :5] + +def test_slice_of_slice(): + X = MatrixSymbol('x', 10, 10) + assert X[2, :][:, 3][0, 0] == X[2, 3] + assert X[:5, :5][:4, :4] == X[:4, :4] + assert X[1:5, 2:6][1:3, 2] == X[2:4, 4] + assert X[1:9:2, 2:6][1:3, 2] == X[3:7:2, 4] + +def test_negative_index(): + X = MatrixSymbol('x', 10, 10) + assert X[-1, :] == X[9, :] diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_special.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_special.py new file mode 100644 index 0000000000000000000000000000000000000000..beeaf1d76a63673b6622709cda598dfcb295bba4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_special.py @@ -0,0 +1,228 @@ +from sympy.core.add import Add +from sympy.core.expr import unchanged +from sympy.core.mul import Mul +from sympy.core.symbol import symbols +from sympy.core.relational import Eq +from sympy.concrete.summations import Sum +from sympy.functions.elementary.complexes import im, re +from sympy.functions.elementary.piecewise import Piecewise +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.matadd import MatAdd +from sympy.matrices.expressions.special import ( + ZeroMatrix, GenericZeroMatrix, Identity, GenericIdentity, OneMatrix) +from sympy.matrices.expressions.matmul import MatMul +from sympy.testing.pytest import raises + + +def test_zero_matrix_creation(): + assert unchanged(ZeroMatrix, 2, 2) + assert unchanged(ZeroMatrix, 0, 0) + raises(ValueError, lambda: ZeroMatrix(-1, 2)) + raises(ValueError, lambda: ZeroMatrix(2.0, 2)) + raises(ValueError, lambda: ZeroMatrix(2j, 2)) + raises(ValueError, lambda: ZeroMatrix(2, -1)) + raises(ValueError, lambda: ZeroMatrix(2, 2.0)) + raises(ValueError, lambda: ZeroMatrix(2, 2j)) + + n = symbols('n') + assert unchanged(ZeroMatrix, n, n) + n = symbols('n', integer=False) + raises(ValueError, lambda: ZeroMatrix(n, n)) + n = symbols('n', negative=True) + raises(ValueError, lambda: ZeroMatrix(n, n)) + + +def test_generic_zero_matrix(): + z = GenericZeroMatrix() + n = symbols('n', integer=True) + A = MatrixSymbol("A", n, n) + + assert z == z + assert z != A + assert A != z + + assert z.is_ZeroMatrix + + raises(TypeError, lambda: z.shape) + raises(TypeError, lambda: z.rows) + raises(TypeError, lambda: z.cols) + + assert MatAdd() == z + assert MatAdd(z, A) == MatAdd(A) + # Make sure it is hashable + hash(z) + + +def test_identity_matrix_creation(): + assert Identity(2) + assert Identity(0) + raises(ValueError, lambda: Identity(-1)) + raises(ValueError, lambda: Identity(2.0)) + raises(ValueError, lambda: Identity(2j)) + + n = symbols('n') + assert Identity(n) + n = symbols('n', integer=False) + raises(ValueError, lambda: Identity(n)) + n = symbols('n', negative=True) + raises(ValueError, lambda: Identity(n)) + + +def test_generic_identity(): + I = GenericIdentity() + n = symbols('n', integer=True) + A = MatrixSymbol("A", n, n) + + assert I == I + assert I != A + assert A != I + + assert I.is_Identity + assert I**-1 == I + + raises(TypeError, lambda: I.shape) + raises(TypeError, lambda: I.rows) + raises(TypeError, lambda: I.cols) + + assert MatMul() == I + assert MatMul(I, A) == MatMul(A) + # Make sure it is hashable + hash(I) + + +def test_one_matrix_creation(): + assert OneMatrix(2, 2) + assert OneMatrix(0, 0) + assert Eq(OneMatrix(1, 1), Identity(1)) + raises(ValueError, lambda: OneMatrix(-1, 2)) + raises(ValueError, lambda: OneMatrix(2.0, 2)) + raises(ValueError, lambda: OneMatrix(2j, 2)) + raises(ValueError, lambda: OneMatrix(2, -1)) + raises(ValueError, lambda: OneMatrix(2, 2.0)) + raises(ValueError, lambda: OneMatrix(2, 2j)) + + n = symbols('n') + assert OneMatrix(n, n) + n = symbols('n', integer=False) + raises(ValueError, lambda: OneMatrix(n, n)) + n = symbols('n', negative=True) + raises(ValueError, lambda: OneMatrix(n, n)) + + +def test_ZeroMatrix(): + n, m = symbols('n m', integer=True) + A = MatrixSymbol('A', n, m) + Z = ZeroMatrix(n, m) + + assert A + Z == A + assert A*Z.T == ZeroMatrix(n, n) + assert Z*A.T == ZeroMatrix(n, n) + assert A - A == ZeroMatrix(*A.shape) + + assert Z + + assert Z.transpose() == ZeroMatrix(m, n) + assert Z.conjugate() == Z + assert Z.adjoint() == ZeroMatrix(m, n) + assert re(Z) == Z + assert im(Z) == Z + + assert ZeroMatrix(n, n)**0 == Identity(n) + assert ZeroMatrix(3, 3).as_explicit() == ImmutableDenseMatrix.zeros(3, 3) + + +def test_ZeroMatrix_doit(): + n = symbols('n', integer=True) + Znn = ZeroMatrix(Add(n, n, evaluate=False), n) + assert isinstance(Znn.rows, Add) + assert Znn.doit() == ZeroMatrix(2*n, n) + assert isinstance(Znn.doit().rows, Mul) + + +def test_OneMatrix(): + n, m = symbols('n m', integer=True) + A = MatrixSymbol('A', n, m) + U = OneMatrix(n, m) + + assert U.shape == (n, m) + assert isinstance(A + U, Add) + assert U.transpose() == OneMatrix(m, n) + assert U.conjugate() == U + assert U.adjoint() == OneMatrix(m, n) + assert re(U) == U + assert im(U) == ZeroMatrix(n, m) + + assert OneMatrix(n, n) ** 0 == Identity(n) + + U = OneMatrix(n, n) + assert U[1, 2] == 1 + + U = OneMatrix(2, 3) + assert U.as_explicit() == ImmutableDenseMatrix.ones(2, 3) + + +def test_OneMatrix_doit(): + n = symbols('n', integer=True) + Unn = OneMatrix(Add(n, n, evaluate=False), n) + assert isinstance(Unn.rows, Add) + assert Unn.doit() == OneMatrix(2 * n, n) + assert isinstance(Unn.doit().rows, Mul) + + +def test_OneMatrix_mul(): + n, m, k = symbols('n m k', integer=True) + w = MatrixSymbol('w', n, 1) + assert OneMatrix(n, m) * OneMatrix(m, k) == OneMatrix(n, k) * m + assert w * OneMatrix(1, 1) == w + assert OneMatrix(1, 1) * w.T == w.T + + +def test_Identity(): + n, m = symbols('n m', integer=True) + A = MatrixSymbol('A', n, m) + i, j = symbols('i j') + + In = Identity(n) + Im = Identity(m) + + assert A*Im == A + assert In*A == A + + assert In.transpose() == In + assert In.inverse() == In + assert In.conjugate() == In + assert In.adjoint() == In + assert re(In) == In + assert im(In) == ZeroMatrix(n, n) + + assert In[i, j] != 0 + assert Sum(In[i, j], (i, 0, n-1), (j, 0, n-1)).subs(n,3).doit() == 3 + assert Sum(Sum(In[i, j], (i, 0, n-1)), (j, 0, n-1)).subs(n,3).doit() == 3 + + # If range exceeds the limit `(0, n-1)`, do not remove `Piecewise`: + expr = Sum(In[i, j], (i, 0, n-1)) + assert expr.doit() == 1 + expr = Sum(In[i, j], (i, 0, n-2)) + assert expr.doit().dummy_eq( + Piecewise( + (1, (j >= 0) & (j <= n-2)), + (0, True) + ) + ) + expr = Sum(In[i, j], (i, 1, n-1)) + assert expr.doit().dummy_eq( + Piecewise( + (1, (j >= 1) & (j <= n-1)), + (0, True) + ) + ) + assert Identity(3).as_explicit() == ImmutableDenseMatrix.eye(3) + + +def test_Identity_doit(): + n = symbols('n', integer=True) + Inn = Identity(Add(n, n, evaluate=False)) + assert isinstance(Inn.rows, Add) + assert Inn.doit() == Identity(2*n) + assert isinstance(Inn.doit().rows, Mul) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_trace.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..3bd66bec2377dae634ff486f42cc474eda7b23b1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_trace.py @@ -0,0 +1,116 @@ +from sympy.core import Lambda, S, symbols +from sympy.concrete import Sum +from sympy.functions import adjoint, conjugate, transpose +from sympy.matrices import eye, Matrix, ShapeError, ImmutableMatrix +from sympy.matrices.expressions import ( + Adjoint, Identity, FunctionMatrix, MatrixExpr, MatrixSymbol, Trace, + ZeroMatrix, trace, MatPow, MatAdd, MatMul +) +from sympy.matrices.expressions.special import OneMatrix +from sympy.testing.pytest import raises +from sympy.abc import i + + +n = symbols('n', integer=True) +A = MatrixSymbol('A', n, n) +B = MatrixSymbol('B', n, n) +C = MatrixSymbol('C', 3, 4) + + +def test_Trace(): + assert isinstance(Trace(A), Trace) + assert not isinstance(Trace(A), MatrixExpr) + raises(ShapeError, lambda: Trace(C)) + assert trace(eye(3)) == 3 + assert trace(Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9])) == 15 + + assert adjoint(Trace(A)) == trace(Adjoint(A)) + assert conjugate(Trace(A)) == trace(Adjoint(A)) + assert transpose(Trace(A)) == Trace(A) + + _ = A / Trace(A) # Make sure this is possible + + # Some easy simplifications + assert trace(Identity(5)) == 5 + assert trace(ZeroMatrix(5, 5)) == 0 + assert trace(OneMatrix(1, 1)) == 1 + assert trace(OneMatrix(2, 2)) == 2 + assert trace(OneMatrix(n, n)) == n + assert trace(2*A*B) == 2*Trace(A*B) + assert trace(A.T) == trace(A) + + i, j = symbols('i j') + F = FunctionMatrix(3, 3, Lambda((i, j), i + j)) + assert trace(F) == (0 + 0) + (1 + 1) + (2 + 2) + + raises(TypeError, lambda: Trace(S.One)) + + assert Trace(A).arg is A + + assert str(trace(A)) == str(Trace(A).doit()) + + assert Trace(A).is_commutative is True + +def test_Trace_A_plus_B(): + assert trace(A + B) == Trace(A) + Trace(B) + assert Trace(A + B).arg == MatAdd(A, B) + assert Trace(A + B).doit() == Trace(A) + Trace(B) + + +def test_Trace_MatAdd_doit(): + # See issue #9028 + X = ImmutableMatrix([[1, 2, 3]]*3) + Y = MatrixSymbol('Y', 3, 3) + q = MatAdd(X, 2*X, Y, -3*Y) + assert Trace(q).arg == q + assert Trace(q).doit() == 18 - 2*Trace(Y) + + +def test_Trace_MatPow_doit(): + X = Matrix([[1, 2], [3, 4]]) + assert Trace(X).doit() == 5 + q = MatPow(X, 2) + assert Trace(q).arg == q + assert Trace(q).doit() == 29 + + +def test_Trace_MutableMatrix_plus(): + # See issue #9043 + X = Matrix([[1, 2], [3, 4]]) + assert Trace(X) + Trace(X) == 2*Trace(X) + + +def test_Trace_doit_deep_False(): + X = Matrix([[1, 2], [3, 4]]) + q = MatPow(X, 2) + assert Trace(q).doit(deep=False).arg == q + q = MatAdd(X, 2*X) + assert Trace(q).doit(deep=False).arg == q + q = MatMul(X, 2*X) + assert Trace(q).doit(deep=False).arg == q + + +def test_trace_constant_factor(): + # Issue 9052: gave 2*Trace(MatMul(A)) instead of 2*Trace(A) + assert trace(2*A) == 2*Trace(A) + X = ImmutableMatrix([[1, 2], [3, 4]]) + assert trace(MatMul(2, X)) == 10 + + +def test_trace_rewrite(): + assert trace(A).rewrite(Sum) == Sum(A[i, i], (i, 0, n - 1)) + assert trace(eye(3)).rewrite(Sum) == 3 + + +def test_trace_normalize(): + assert Trace(B*A) != Trace(A*B) + assert Trace(B*A)._normalize() == Trace(A*B) + assert Trace(B*A.T)._normalize() == Trace(A*B.T) + + +def test_trace_as_explicit(): + raises(ValueError, lambda: Trace(A).as_explicit()) + + X = MatrixSymbol("X", 3, 3) + assert Trace(X).as_explicit() == X[0, 0] + X[1, 1] + X[2, 2] + assert Trace(eye(3)).as_explicit() == 3 diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_transpose.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_transpose.py new file mode 100644 index 0000000000000000000000000000000000000000..a1a6113873426d99bacf85484d3b66781f300af7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/tests/test_transpose.py @@ -0,0 +1,69 @@ +from sympy.functions import adjoint, conjugate, transpose +from sympy.matrices.expressions import MatrixSymbol, Adjoint, trace, Transpose +from sympy.matrices import eye, Matrix +from sympy.assumptions.ask import Q +from sympy.assumptions.refine import refine +from sympy.core.singleton import S +from sympy.core.symbol import symbols + +n, m, l, k, p = symbols('n m l k p', integer=True) +A = MatrixSymbol('A', n, m) +B = MatrixSymbol('B', m, l) +C = MatrixSymbol('C', n, n) + + +def test_transpose(): + Sq = MatrixSymbol('Sq', n, n) + + assert transpose(A) == Transpose(A) + assert Transpose(A).shape == (m, n) + assert Transpose(A*B).shape == (l, n) + assert transpose(Transpose(A)) == A + assert isinstance(Transpose(Transpose(A)), Transpose) + + assert adjoint(Transpose(A)) == Adjoint(Transpose(A)) + assert conjugate(Transpose(A)) == Adjoint(A) + + assert Transpose(eye(3)).doit() == eye(3) + + assert Transpose(S(5)).doit() == S(5) + + assert Transpose(Matrix([[1, 2], [3, 4]])).doit() == Matrix([[1, 3], [2, 4]]) + + assert transpose(trace(Sq)) == trace(Sq) + assert trace(Transpose(Sq)) == trace(Sq) + + assert Transpose(Sq)[0, 1] == Sq[1, 0] + + assert Transpose(A*B).doit() == Transpose(B) * Transpose(A) + + +def test_transpose_MatAdd_MatMul(): + # Issue 16807 + from sympy.functions.elementary.trigonometric import cos + + x = symbols('x') + M = MatrixSymbol('M', 3, 3) + N = MatrixSymbol('N', 3, 3) + + assert (N + (cos(x) * M)).T == cos(x)*M.T + N.T + + +def test_refine(): + assert refine(C.T, Q.symmetric(C)) == C + + +def test_transpose1x1(): + m = MatrixSymbol('m', 1, 1) + assert m == refine(m.T) + assert m == refine(m.T.T) + +def test_issue_9817(): + from sympy.matrices.expressions import Identity + v = MatrixSymbol('v', 3, 1) + A = MatrixSymbol('A', 3, 3) + x = Matrix([i + 1 for i in range(3)]) + X = Identity(3) + quadratic = v.T * A * v + subbed = quadratic.xreplace({v:x, A:X}) + assert subbed.as_explicit() == Matrix([[14]]) diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/trace.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/trace.py new file mode 100644 index 0000000000000000000000000000000000000000..b5f9f94ea7486dc21b47c2e2e783a93280b180e0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/trace.py @@ -0,0 +1,167 @@ +from sympy.core.basic import Basic +from sympy.core.expr import Expr, ExprBuilder +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import uniquely_named_symbol +from sympy.core.sympify import sympify +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices.exceptions import NonSquareMatrixError + + +class Trace(Expr): + """Matrix Trace + + Represents the trace of a matrix expression. + + Examples + ======== + + >>> from sympy import MatrixSymbol, Trace, eye + >>> A = MatrixSymbol('A', 3, 3) + >>> Trace(A) + Trace(A) + >>> Trace(eye(3)) + Trace(Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]])) + >>> Trace(eye(3)).simplify() + 3 + """ + is_Trace = True + is_commutative = True + + def __new__(cls, mat): + mat = sympify(mat) + + if not mat.is_Matrix: + raise TypeError("input to Trace, %s, is not a matrix" % str(mat)) + + if mat.is_square is False: + raise NonSquareMatrixError("Trace of a non-square matrix") + + return Basic.__new__(cls, mat) + + def _eval_transpose(self): + return self + + def _eval_derivative(self, v): + from sympy.concrete.summations import Sum + from .matexpr import MatrixElement + if isinstance(v, MatrixElement): + return self.rewrite(Sum).diff(v) + expr = self.doit() + if isinstance(expr, Trace): + # Avoid looping infinitely: + raise NotImplementedError + return expr._eval_derivative(v) + + def _eval_derivative_matrix_lines(self, x): + from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayContraction + r = self.args[0]._eval_derivative_matrix_lines(x) + for lr in r: + if lr.higher == 1: + lr.higher = ExprBuilder( + ArrayContraction, + [ + ExprBuilder( + ArrayTensorProduct, + [ + lr._lines[0], + lr._lines[1], + ] + ), + (1, 3), + ], + validator=ArrayContraction._validate + ) + else: + # This is not a matrix line: + lr.higher = ExprBuilder( + ArrayContraction, + [ + ExprBuilder( + ArrayTensorProduct, + [ + lr._lines[0], + lr._lines[1], + lr.higher, + ] + ), + (1, 3), (0, 2) + ] + ) + lr._lines = [S.One, S.One] + lr._first_pointer_parent = lr._lines + lr._second_pointer_parent = lr._lines + lr._first_pointer_index = 0 + lr._second_pointer_index = 1 + return r + + @property + def arg(self): + return self.args[0] + + def doit(self, **hints): + if hints.get('deep', True): + arg = self.arg.doit(**hints) + result = arg._eval_trace() + if result is not None: + return result + else: + return Trace(arg) + else: + # _eval_trace would go too deep here + if isinstance(self.arg, MatrixBase): + return trace(self.arg) + else: + return Trace(self.arg) + + def as_explicit(self): + return Trace(self.arg.as_explicit()).doit() + + def _normalize(self): + # Normalization of trace of matrix products. Use transposition and + # cyclic properties of traces to make sure the arguments of the matrix + # product are sorted and the first argument is not a transposition. + from sympy.matrices.expressions.matmul import MatMul + from sympy.matrices.expressions.transpose import Transpose + trace_arg = self.arg + if isinstance(trace_arg, MatMul): + + def get_arg_key(x): + a = trace_arg.args[x] + if isinstance(a, Transpose): + a = a.arg + return default_sort_key(a) + + indmin = min(range(len(trace_arg.args)), key=get_arg_key) + if isinstance(trace_arg.args[indmin], Transpose): + trace_arg = Transpose(trace_arg).doit() + indmin = min(range(len(trace_arg.args)), key=lambda x: default_sort_key(trace_arg.args[x])) + trace_arg = MatMul.fromiter(trace_arg.args[indmin:] + trace_arg.args[:indmin]) + return Trace(trace_arg) + return self + + def _eval_rewrite_as_Sum(self, expr, **kwargs): + from sympy.concrete.summations import Sum + i = uniquely_named_symbol('i', [expr]) + s = Sum(self.arg[i, i], (i, 0, self.arg.rows - 1)) + return s.doit() + + +def trace(expr): + """Trace of a Matrix. Sum of the diagonal elements. + + Examples + ======== + + >>> from sympy import trace, Symbol, MatrixSymbol, eye + >>> n = Symbol('n') + >>> X = MatrixSymbol('X', n, n) # A square matrix + >>> trace(2*X) + 2*Trace(X) + >>> trace(eye(3)) + 3 + """ + return Trace(expr).doit() diff --git a/phivenv/Lib/site-packages/sympy/matrices/expressions/transpose.py b/phivenv/Lib/site-packages/sympy/matrices/expressions/transpose.py new file mode 100644 index 0000000000000000000000000000000000000000..b11f7fc21490aab219420610ca529d81d6995d40 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/expressions/transpose.py @@ -0,0 +1,103 @@ +from sympy.core.basic import Basic +from sympy.matrices.expressions.matexpr import MatrixExpr + + +class Transpose(MatrixExpr): + """ + The transpose of a matrix expression. + + This is a symbolic object that simply stores its argument without + evaluating it. To actually compute the transpose, use the ``transpose()`` + function, or the ``.T`` attribute of matrices. + + Examples + ======== + + >>> from sympy import MatrixSymbol, Transpose, transpose + >>> A = MatrixSymbol('A', 3, 5) + >>> B = MatrixSymbol('B', 5, 3) + >>> Transpose(A) + A.T + >>> A.T == transpose(A) == Transpose(A) + True + >>> Transpose(A*B) + (A*B).T + >>> transpose(A*B) + B.T*A.T + + """ + is_Transpose = True + + def doit(self, **hints): + arg = self.arg + if hints.get('deep', True) and isinstance(arg, Basic): + arg = arg.doit(**hints) + _eval_transpose = getattr(arg, '_eval_transpose', None) + if _eval_transpose is not None: + result = _eval_transpose() + return result if result is not None else Transpose(arg) + else: + return Transpose(arg) + + @property + def arg(self): + return self.args[0] + + @property + def shape(self): + return self.arg.shape[::-1] + + def _entry(self, i, j, expand=False, **kwargs): + return self.arg._entry(j, i, expand=expand, **kwargs) + + def _eval_adjoint(self): + return self.arg.conjugate() + + def _eval_conjugate(self): + return self.arg.adjoint() + + def _eval_transpose(self): + return self.arg + + def _eval_trace(self): + from .trace import Trace + return Trace(self.arg) # Trace(X.T) => Trace(X) + + def _eval_determinant(self): + from sympy.matrices.expressions.determinant import det + return det(self.arg) + + def _eval_derivative(self, x): + # x is a scalar: + return self.arg._eval_derivative(x) + + def _eval_derivative_matrix_lines(self, x): + lines = self.args[0]._eval_derivative_matrix_lines(x) + return [i.transpose() for i in lines] + + +def transpose(expr): + """Matrix transpose""" + return Transpose(expr).doit(deep=False) + + +from sympy.assumptions.ask import ask, Q +from sympy.assumptions.refine import handlers_dict + + +def refine_Transpose(expr, assumptions): + """ + >>> from sympy import MatrixSymbol, Q, assuming, refine + >>> X = MatrixSymbol('X', 2, 2) + >>> X.T + X.T + >>> with assuming(Q.symmetric(X)): + ... print(refine(X.T)) + X + """ + if ask(Q.symmetric(expr), assumptions): + return expr.arg + + return expr + +handlers_dict['Transpose'] = refine_Transpose diff --git a/phivenv/Lib/site-packages/sympy/matrices/graph.py b/phivenv/Lib/site-packages/sympy/matrices/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..4c6356db884cfcd3c759ada07ac559f43dbcbbcb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/graph.py @@ -0,0 +1,279 @@ +from sympy.utilities.iterables import \ + flatten, connected_components, strongly_connected_components +from .exceptions import NonSquareMatrixError + + +def _connected_components(M): + """Returns the list of connected vertices of the graph when + a square matrix is viewed as a weighted graph. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([ + ... [66, 0, 0, 68, 0, 0, 0, 0, 67], + ... [0, 55, 0, 0, 0, 0, 54, 53, 0], + ... [0, 0, 0, 0, 1, 2, 0, 0, 0], + ... [86, 0, 0, 88, 0, 0, 0, 0, 87], + ... [0, 0, 10, 0, 11, 12, 0, 0, 0], + ... [0, 0, 20, 0, 21, 22, 0, 0, 0], + ... [0, 45, 0, 0, 0, 0, 44, 43, 0], + ... [0, 35, 0, 0, 0, 0, 34, 33, 0], + ... [76, 0, 0, 78, 0, 0, 0, 0, 77]]) + >>> A.connected_components() + [[0, 3, 8], [1, 6, 7], [2, 4, 5]] + + Notes + ===== + + Even if any symbolic elements of the matrix can be indeterminate + to be zero mathematically, this only takes the account of the + structural aspect of the matrix, so they will considered to be + nonzero. + """ + if not M.is_square: + raise NonSquareMatrixError + + V = range(M.rows) + E = sorted(M.todok().keys()) + return connected_components((V, E)) + + +def _strongly_connected_components(M): + """Returns the list of strongly connected vertices of the graph when + a square matrix is viewed as a weighted graph. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([ + ... [44, 0, 0, 0, 43, 0, 45, 0, 0], + ... [0, 66, 62, 61, 0, 68, 0, 60, 67], + ... [0, 0, 22, 21, 0, 0, 0, 20, 0], + ... [0, 0, 12, 11, 0, 0, 0, 10, 0], + ... [34, 0, 0, 0, 33, 0, 35, 0, 0], + ... [0, 86, 82, 81, 0, 88, 0, 80, 87], + ... [54, 0, 0, 0, 53, 0, 55, 0, 0], + ... [0, 0, 2, 1, 0, 0, 0, 0, 0], + ... [0, 76, 72, 71, 0, 78, 0, 70, 77]]) + >>> A.strongly_connected_components() + [[0, 4, 6], [2, 3, 7], [1, 5, 8]] + """ + if not M.is_square: + raise NonSquareMatrixError + + # RepMatrix uses the more efficient DomainMatrix.scc() method + rep = getattr(M, '_rep', None) + if rep is not None: + return rep.scc() + + V = range(M.rows) + E = sorted(M.todok().keys()) + return strongly_connected_components((V, E)) + + +def _connected_components_decomposition(M): + """Decomposes a square matrix into block diagonal form only + using the permutations. + + Explanation + =========== + + The decomposition is in a form of $A = P^{-1} B P$ where $P$ is a + permutation matrix and $B$ is a block diagonal matrix. + + Returns + ======= + + P, B : PermutationMatrix, BlockDiagMatrix + *P* is a permutation matrix for the similarity transform + as in the explanation. And *B* is the block diagonal matrix of + the result of the permutation. + + If you would like to get the diagonal blocks from the + BlockDiagMatrix, see + :meth:`~sympy.matrices.expressions.blockmatrix.BlockDiagMatrix.get_diag_blocks`. + + Examples + ======== + + >>> from sympy import Matrix, pprint + >>> A = Matrix([ + ... [66, 0, 0, 68, 0, 0, 0, 0, 67], + ... [0, 55, 0, 0, 0, 0, 54, 53, 0], + ... [0, 0, 0, 0, 1, 2, 0, 0, 0], + ... [86, 0, 0, 88, 0, 0, 0, 0, 87], + ... [0, 0, 10, 0, 11, 12, 0, 0, 0], + ... [0, 0, 20, 0, 21, 22, 0, 0, 0], + ... [0, 45, 0, 0, 0, 0, 44, 43, 0], + ... [0, 35, 0, 0, 0, 0, 34, 33, 0], + ... [76, 0, 0, 78, 0, 0, 0, 0, 77]]) + + >>> P, B = A.connected_components_decomposition() + >>> pprint(P) + PermutationMatrix((1 3)(2 8 5 7 4 6)) + >>> pprint(B) + [[66 68 67] ] + [[ ] ] + [[86 88 87] 0 0 ] + [[ ] ] + [[76 78 77] ] + [ ] + [ [55 54 53] ] + [ [ ] ] + [ 0 [45 44 43] 0 ] + [ [ ] ] + [ [35 34 33] ] + [ ] + [ [0 1 2 ]] + [ [ ]] + [ 0 0 [10 11 12]] + [ [ ]] + [ [20 21 22]] + + >>> P = P.as_explicit() + >>> B = B.as_explicit() + >>> P.T*B*P == A + True + + Notes + ===== + + This problem corresponds to the finding of the connected components + of a graph, when a matrix is viewed as a weighted graph. + """ + from sympy.combinatorics.permutations import Permutation + from sympy.matrices.expressions.blockmatrix import BlockDiagMatrix + from sympy.matrices.expressions.permutation import PermutationMatrix + + iblocks = M.connected_components() + + p = Permutation(flatten(iblocks)) + P = PermutationMatrix(p) + + blocks = [] + for b in iblocks: + blocks.append(M[b, b]) + B = BlockDiagMatrix(*blocks) + return P, B + + +def _strongly_connected_components_decomposition(M, lower=True): + """Decomposes a square matrix into block triangular form only + using the permutations. + + Explanation + =========== + + The decomposition is in a form of $A = P^{-1} B P$ where $P$ is a + permutation matrix and $B$ is a block diagonal matrix. + + Parameters + ========== + + lower : bool + Makes $B$ lower block triangular when ``True``. + Otherwise, makes $B$ upper block triangular. + + Returns + ======= + + P, B : PermutationMatrix, BlockMatrix + *P* is a permutation matrix for the similarity transform + as in the explanation. And *B* is the block triangular matrix of + the result of the permutation. + + Examples + ======== + + >>> from sympy import Matrix, pprint + >>> A = Matrix([ + ... [44, 0, 0, 0, 43, 0, 45, 0, 0], + ... [0, 66, 62, 61, 0, 68, 0, 60, 67], + ... [0, 0, 22, 21, 0, 0, 0, 20, 0], + ... [0, 0, 12, 11, 0, 0, 0, 10, 0], + ... [34, 0, 0, 0, 33, 0, 35, 0, 0], + ... [0, 86, 82, 81, 0, 88, 0, 80, 87], + ... [54, 0, 0, 0, 53, 0, 55, 0, 0], + ... [0, 0, 2, 1, 0, 0, 0, 0, 0], + ... [0, 76, 72, 71, 0, 78, 0, 70, 77]]) + + A lower block triangular decomposition: + + >>> P, B = A.strongly_connected_components_decomposition() + >>> pprint(P) + PermutationMatrix((8)(1 4 3 2 6)(5 7)) + >>> pprint(B) + [[44 43 45] [0 0 0] [0 0 0] ] + [[ ] [ ] [ ] ] + [[34 33 35] [0 0 0] [0 0 0] ] + [[ ] [ ] [ ] ] + [[54 53 55] [0 0 0] [0 0 0] ] + [ ] + [ [0 0 0] [22 21 20] [0 0 0] ] + [ [ ] [ ] [ ] ] + [ [0 0 0] [12 11 10] [0 0 0] ] + [ [ ] [ ] [ ] ] + [ [0 0 0] [2 1 0 ] [0 0 0] ] + [ ] + [ [0 0 0] [62 61 60] [66 68 67]] + [ [ ] [ ] [ ]] + [ [0 0 0] [82 81 80] [86 88 87]] + [ [ ] [ ] [ ]] + [ [0 0 0] [72 71 70] [76 78 77]] + + >>> P = P.as_explicit() + >>> B = B.as_explicit() + >>> P.T * B * P == A + True + + An upper block triangular decomposition: + + >>> P, B = A.strongly_connected_components_decomposition(lower=False) + >>> pprint(P) + PermutationMatrix((0 1 5 7 4 3 2 8 6)) + >>> pprint(B) + [[66 68 67] [62 61 60] [0 0 0] ] + [[ ] [ ] [ ] ] + [[86 88 87] [82 81 80] [0 0 0] ] + [[ ] [ ] [ ] ] + [[76 78 77] [72 71 70] [0 0 0] ] + [ ] + [ [0 0 0] [22 21 20] [0 0 0] ] + [ [ ] [ ] [ ] ] + [ [0 0 0] [12 11 10] [0 0 0] ] + [ [ ] [ ] [ ] ] + [ [0 0 0] [2 1 0 ] [0 0 0] ] + [ ] + [ [0 0 0] [0 0 0] [44 43 45]] + [ [ ] [ ] [ ]] + [ [0 0 0] [0 0 0] [34 33 35]] + [ [ ] [ ] [ ]] + [ [0 0 0] [0 0 0] [54 53 55]] + + >>> P = P.as_explicit() + >>> B = B.as_explicit() + >>> P.T * B * P == A + True + """ + from sympy.combinatorics.permutations import Permutation + from sympy.matrices.expressions.blockmatrix import BlockMatrix + from sympy.matrices.expressions.permutation import PermutationMatrix + + iblocks = M.strongly_connected_components() + if not lower: + iblocks = list(reversed(iblocks)) + + p = Permutation(flatten(iblocks)) + P = PermutationMatrix(p) + + rows = [] + for a in iblocks: + cols = [] + for b in iblocks: + cols.append(M[a, b]) + rows.append(cols) + B = BlockMatrix(rows) + return P, B diff --git a/phivenv/Lib/site-packages/sympy/matrices/immutable.py b/phivenv/Lib/site-packages/sympy/matrices/immutable.py new file mode 100644 index 0000000000000000000000000000000000000000..7ec2174bf1c785e1a4698e1b55078d300e62dafe --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/immutable.py @@ -0,0 +1,196 @@ +from mpmath.matrices.matrices import _matrix + +from sympy.core import Basic, Dict, Tuple +from sympy.core.numbers import Integer +from sympy.core.cache import cacheit +from sympy.core.sympify import _sympy_converter as sympify_converter, _sympify +from sympy.matrices.dense import DenseMatrix +from sympy.matrices.expressions import MatrixExpr +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices.repmatrix import RepMatrix +from sympy.matrices.sparse import SparseRepMatrix +from sympy.multipledispatch import dispatch + + +def sympify_matrix(arg): + return arg.as_immutable() + + +sympify_converter[MatrixBase] = sympify_matrix + + +def sympify_mpmath_matrix(arg): + mat = [_sympify(x) for x in arg] + return ImmutableDenseMatrix(arg.rows, arg.cols, mat) + + +sympify_converter[_matrix] = sympify_mpmath_matrix + + +class ImmutableRepMatrix(RepMatrix, MatrixExpr): # type: ignore + """Immutable matrix based on RepMatrix + + Uses DomainMAtrix as the internal representation. + """ + + # + # This is a subclass of RepMatrix that adds/overrides some methods to make + # the instances Basic and immutable. ImmutableRepMatrix is a superclass for + # both ImmutableDenseMatrix and ImmutableSparseMatrix. + # + + def __new__(cls, *args, **kwargs): + return cls._new(*args, **kwargs) + + __hash__ = MatrixExpr.__hash__ + + def copy(self): + return self + + @property + def cols(self): + return self._cols + + @property + def rows(self): + return self._rows + + @property + def shape(self): + return self._rows, self._cols + + def as_immutable(self): + return self + + def _entry(self, i, j, **kwargs): + return self[i, j] + + def __setitem__(self, *args): + raise TypeError("Cannot set values of {}".format(self.__class__)) + + def is_diagonalizable(self, reals_only=False, **kwargs): + return super().is_diagonalizable( + reals_only=reals_only, **kwargs) + + is_diagonalizable.__doc__ = SparseRepMatrix.is_diagonalizable.__doc__ + is_diagonalizable = cacheit(is_diagonalizable) + + def analytic_func(self, f, x): + return self.as_mutable().analytic_func(f, x).as_immutable() + + +class ImmutableDenseMatrix(DenseMatrix, ImmutableRepMatrix): # type: ignore + """Create an immutable version of a matrix. + + Examples + ======== + + >>> from sympy import eye, ImmutableMatrix + >>> ImmutableMatrix(eye(3)) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> _[0, 0] = 42 + Traceback (most recent call last): + ... + TypeError: Cannot set values of ImmutableDenseMatrix + """ + + # MatrixExpr is set as NotIterable, but we want explicit matrices to be + # iterable + _iterable = True + _class_priority = 8 + _op_priority = 10.001 + + @classmethod + def _new(cls, *args, **kwargs): + if len(args) == 1 and isinstance(args[0], ImmutableDenseMatrix): + return args[0] + if kwargs.get('copy', True) is False: + if len(args) != 3: + raise TypeError("'copy=False' requires a matrix be initialized as rows,cols,[list]") + rows, cols, flat_list = args + else: + rows, cols, flat_list = cls._handle_creation_inputs(*args, **kwargs) + flat_list = list(flat_list) # create a shallow copy + + rep = cls._flat_list_to_DomainMatrix(rows, cols, flat_list) + + return cls._fromrep(rep) + + @classmethod + def _fromrep(cls, rep): + rows, cols = rep.shape + flat_list = rep.to_sympy().to_list_flat() + obj = Basic.__new__(cls, + Integer(rows), + Integer(cols), + Tuple(*flat_list, sympify=False)) + obj._rows = rows + obj._cols = cols + obj._rep = rep + return obj + + +# make sure ImmutableDenseMatrix is aliased as ImmutableMatrix +ImmutableMatrix = ImmutableDenseMatrix + + +class ImmutableSparseMatrix(SparseRepMatrix, ImmutableRepMatrix): # type:ignore + """Create an immutable version of a sparse matrix. + + Examples + ======== + + >>> from sympy import eye, ImmutableSparseMatrix + >>> ImmutableSparseMatrix(1, 1, {}) + Matrix([[0]]) + >>> ImmutableSparseMatrix(eye(3)) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> _[0, 0] = 42 + Traceback (most recent call last): + ... + TypeError: Cannot set values of ImmutableSparseMatrix + >>> _.shape + (3, 3) + """ + is_Matrix = True + _class_priority = 9 + + @classmethod + def _new(cls, *args, **kwargs): + rows, cols, smat = cls._handle_creation_inputs(*args, **kwargs) + + rep = cls._smat_to_DomainMatrix(rows, cols, smat) + + return cls._fromrep(rep) + + @classmethod + def _fromrep(cls, rep): + rows, cols = rep.shape + smat = rep.to_sympy().to_dok() + obj = Basic.__new__(cls, Integer(rows), Integer(cols), Dict(smat)) + obj._rows = rows + obj._cols = cols + obj._rep = rep + return obj + + +@dispatch(ImmutableDenseMatrix, ImmutableDenseMatrix) +def _eval_is_eq(lhs, rhs): # noqa:F811 + """Helper method for Equality with matrices.sympy. + + Relational automatically converts matrices to ImmutableDenseMatrix + instances, so this method only applies here. Returns True if the + matrices are definitively the same, False if they are definitively + different, and None if undetermined (e.g. if they contain Symbols). + Returning None triggers default handling of Equalities. + + """ + if lhs.shape != rhs.shape: + return False + return (lhs - rhs).is_zero_matrix diff --git a/phivenv/Lib/site-packages/sympy/matrices/inverse.py b/phivenv/Lib/site-packages/sympy/matrices/inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..61d9e12edf013d2f5555d61786343aa3840edfd3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/inverse.py @@ -0,0 +1,524 @@ +from sympy.polys.matrices.exceptions import DMNonInvertibleMatrixError +from sympy.polys.domains import EX + +from .exceptions import MatrixError, NonSquareMatrixError, NonInvertibleMatrixError +from .utilities import _iszero + + +def _pinv_full_rank(M): + """Subroutine for full row or column rank matrices. + + For full row rank matrices, inverse of ``A * A.H`` Exists. + For full column rank matrices, inverse of ``A.H * A`` Exists. + + This routine can apply for both cases by checking the shape + and have small decision. + """ + + if M.is_zero_matrix: + return M.H + + if M.rows >= M.cols: + return M.H.multiply(M).inv().multiply(M.H) + else: + return M.H.multiply(M.multiply(M.H).inv()) + +def _pinv_rank_decomposition(M): + """Subroutine for rank decomposition + + With rank decompositions, `A` can be decomposed into two full- + rank matrices, and each matrix can take pseudoinverse + individually. + """ + + if M.is_zero_matrix: + return M.H + + B, C = M.rank_decomposition() + + Bp = _pinv_full_rank(B) + Cp = _pinv_full_rank(C) + + return Cp.multiply(Bp) + +def _pinv_diagonalization(M): + """Subroutine using diagonalization + + This routine can sometimes fail if SymPy's eigenvalue + computation is not reliable. + """ + + if M.is_zero_matrix: + return M.H + + A = M + AH = M.H + + try: + if M.rows >= M.cols: + P, D = AH.multiply(A).diagonalize(normalize=True) + D_pinv = D.applyfunc(lambda x: 0 if _iszero(x) else 1 / x) + + return P.multiply(D_pinv).multiply(P.H).multiply(AH) + + else: + P, D = A.multiply(AH).diagonalize( + normalize=True) + D_pinv = D.applyfunc(lambda x: 0 if _iszero(x) else 1 / x) + + return AH.multiply(P).multiply(D_pinv).multiply(P.H) + + except MatrixError: + raise NotImplementedError( + 'pinv for rank-deficient matrices where ' + 'diagonalization of A.H*A fails is not supported yet.') + +def _pinv(M, method='RD'): + """Calculate the Moore-Penrose pseudoinverse of the matrix. + + The Moore-Penrose pseudoinverse exists and is unique for any matrix. + If the matrix is invertible, the pseudoinverse is the same as the + inverse. + + Parameters + ========== + + method : String, optional + Specifies the method for computing the pseudoinverse. + + If ``'RD'``, Rank-Decomposition will be used. + + If ``'ED'``, Diagonalization will be used. + + Examples + ======== + + Computing pseudoinverse by rank decomposition : + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> A.pinv() + Matrix([ + [-17/18, 4/9], + [ -1/9, 1/9], + [ 13/18, -2/9]]) + + Computing pseudoinverse by diagonalization : + + >>> B = A.pinv(method='ED') + >>> B.simplify() + >>> B + Matrix([ + [-17/18, 4/9], + [ -1/9, 1/9], + [ 13/18, -2/9]]) + + See Also + ======== + + inv + pinv_solve + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Moore-Penrose_pseudoinverse + + """ + + # Trivial case: pseudoinverse of all-zero matrix is its transpose. + if M.is_zero_matrix: + return M.H + + if method == 'RD': + return _pinv_rank_decomposition(M) + elif method == 'ED': + return _pinv_diagonalization(M) + else: + raise ValueError('invalid pinv method %s' % repr(method)) + + +def _verify_invertible(M, iszerofunc=_iszero): + """Initial check to see if a matrix is invertible. Raises or returns + determinant for use in _inv_ADJ.""" + + if not M.is_square: + raise NonSquareMatrixError("A Matrix must be square to invert.") + + d = M.det(method='berkowitz') + zero = d.equals(0) + + if zero is None: # if equals() can't decide, will rref be able to? + ok = M.rref(simplify=True)[0] + zero = any(iszerofunc(ok[j, j]) for j in range(ok.rows)) + + if zero: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + + return d + +def _inv_ADJ(M, iszerofunc=_iszero): + """Calculates the inverse using the adjugate matrix and a determinant. + + See Also + ======== + + inv + inverse_GE + inverse_LU + inverse_CH + inverse_LDL + """ + + d = _verify_invertible(M, iszerofunc=iszerofunc) + + return M.adjugate() / d + +def _inv_GE(M, iszerofunc=_iszero): + """Calculates the inverse using Gaussian elimination. + + See Also + ======== + + inv + inverse_ADJ + inverse_LU + inverse_CH + inverse_LDL + """ + + from .dense import Matrix + + if not M.is_square: + raise NonSquareMatrixError("A Matrix must be square to invert.") + + big = Matrix.hstack(M.as_mutable(), Matrix.eye(M.rows)) + red = big.rref(iszerofunc=iszerofunc, simplify=True)[0] + + if any(iszerofunc(red[j, j]) for j in range(red.rows)): + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + + return M._new(red[:, big.rows:]) + +def _inv_LU(M, iszerofunc=_iszero): + """Calculates the inverse using LU decomposition. + + See Also + ======== + + inv + inverse_ADJ + inverse_GE + inverse_CH + inverse_LDL + """ + + if not M.is_square: + raise NonSquareMatrixError("A Matrix must be square to invert.") + if M.free_symbols: + _verify_invertible(M, iszerofunc=iszerofunc) + + return M.LUsolve(M.eye(M.rows), iszerofunc=_iszero) + +def _inv_CH(M, iszerofunc=_iszero): + """Calculates the inverse using cholesky decomposition. + + See Also + ======== + + inv + inverse_ADJ + inverse_GE + inverse_LU + inverse_LDL + """ + + _verify_invertible(M, iszerofunc=iszerofunc) + + return M.cholesky_solve(M.eye(M.rows)) + +def _inv_LDL(M, iszerofunc=_iszero): + """Calculates the inverse using LDL decomposition. + + See Also + ======== + + inv + inverse_ADJ + inverse_GE + inverse_LU + inverse_CH + """ + + _verify_invertible(M, iszerofunc=iszerofunc) + + return M.LDLsolve(M.eye(M.rows)) + +def _inv_QR(M, iszerofunc=_iszero): + """Calculates the inverse using QR decomposition. + + See Also + ======== + + inv + inverse_ADJ + inverse_GE + inverse_CH + inverse_LDL + """ + + _verify_invertible(M, iszerofunc=iszerofunc) + + return M.QRsolve(M.eye(M.rows)) + +def _try_DM(M, use_EX=False): + """Try to convert a matrix to a ``DomainMatrix``.""" + dM = M.to_DM() + K = dM.domain + + # Return DomainMatrix if a domain is found. Only use EX if use_EX=True. + if not use_EX and K.is_EXRAW: + return None + elif K.is_EXRAW: + return dM.convert_to(EX) + else: + return dM + + +def _use_exact_domain(dom): + """Check whether to convert to an exact domain.""" + # DomainMatrix can handle RR and CC with partial pivoting. Other inexact + # domains like RR[a,b,...] can only be handled by converting to an exact + # domain like QQ[a,b,...] + if dom.is_RR or dom.is_CC: + return False + else: + return not dom.is_Exact + + +def _inv_DM(dM, cancel=True): + """Calculates the inverse using ``DomainMatrix``. + + See Also + ======== + + inv + inverse_ADJ + inverse_GE + inverse_CH + inverse_LDL + sympy.polys.matrices.domainmatrix.DomainMatrix.inv + """ + m, n = dM.shape + dom = dM.domain + + if m != n: + raise NonSquareMatrixError("A Matrix must be square to invert.") + + # Convert RR[a,b,...] to QQ[a,b,...] + use_exact = _use_exact_domain(dom) + + if use_exact: + dom_exact = dom.get_exact() + dM = dM.convert_to(dom_exact) + + try: + dMi, den = dM.inv_den() + except DMNonInvertibleMatrixError: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + + if use_exact: + dMi = dMi.convert_to(dom) + den = dom.convert_from(den, dom_exact) + + if cancel: + # Convert to field and cancel with the denominator. + if not dMi.domain.is_Field: + dMi = dMi.to_field() + Mi = (dMi / den).to_Matrix() + else: + # Convert to Matrix and divide without cancelling + Mi = dMi.to_Matrix() / dMi.domain.to_sympy(den) + + return Mi + +def _inv_block(M, iszerofunc=_iszero): + """Calculates the inverse using BLOCKWISE inversion. + + See Also + ======== + + inv + inverse_ADJ + inverse_GE + inverse_CH + inverse_LDL + """ + from sympy.matrices.expressions.blockmatrix import BlockMatrix + i = M.shape[0] + if i <= 20 : + return M.inv(method="LU", iszerofunc=_iszero) + A = M[:i // 2, :i //2] + B = M[:i // 2, i // 2:] + C = M[i // 2:, :i // 2] + D = M[i // 2:, i // 2:] + try: + D_inv = _inv_block(D) + except NonInvertibleMatrixError: + return M.inv(method="LU", iszerofunc=_iszero) + B_D_i = B*D_inv + BDC = B_D_i*C + A_n = A - BDC + try: + A_n = _inv_block(A_n) + except NonInvertibleMatrixError: + return M.inv(method="LU", iszerofunc=_iszero) + B_n = -A_n*B_D_i + dc = D_inv*C + C_n = -dc*A_n + D_n = D_inv + dc*-B_n + nn = BlockMatrix([[A_n, B_n], [C_n, D_n]]).as_explicit() + return nn + +def _inv(M, method=None, iszerofunc=_iszero, try_block_diag=False): + """ + Return the inverse of a matrix using the method indicated. The default + is DM if a suitable domain is found or otherwise GE for dense matrices + LDL for sparse matrices. + + Parameters + ========== + + method : ('DM', 'DMNC', 'GE', 'LU', 'ADJ', 'CH', 'LDL', 'QR') + + iszerofunc : function, optional + Zero-testing function to use. + + try_block_diag : bool, optional + If True then will try to form block diagonal matrices using the + method get_diag_blocks(), invert these individually, and then + reconstruct the full inverse matrix. + + Examples + ======== + + >>> from sympy import SparseMatrix, Matrix + >>> A = SparseMatrix([ + ... [ 2, -1, 0], + ... [-1, 2, -1], + ... [ 0, 0, 2]]) + >>> A.inv('CH') + Matrix([ + [2/3, 1/3, 1/6], + [1/3, 2/3, 1/3], + [ 0, 0, 1/2]]) + >>> A.inv(method='LDL') # use of 'method=' is optional + Matrix([ + [2/3, 1/3, 1/6], + [1/3, 2/3, 1/3], + [ 0, 0, 1/2]]) + >>> A * _ + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> A = Matrix(A) + >>> A.inv('CH') + Matrix([ + [2/3, 1/3, 1/6], + [1/3, 2/3, 1/3], + [ 0, 0, 1/2]]) + >>> A.inv('ADJ') == A.inv('GE') == A.inv('LU') == A.inv('CH') == A.inv('LDL') == A.inv('QR') + True + + Notes + ===== + + According to the ``method`` keyword, it calls the appropriate method: + + DM .... Use DomainMatrix ``inv_den`` method + DMNC .... Use DomainMatrix ``inv_den`` method without cancellation + GE .... inverse_GE(); default for dense matrices + LU .... inverse_LU() + ADJ ... inverse_ADJ() + CH ... inverse_CH() + LDL ... inverse_LDL(); default for sparse matrices + QR ... inverse_QR() + + Note, the GE and LU methods may require the matrix to be simplified + before it is inverted in order to properly detect zeros during + pivoting. In difficult cases a custom zero detection function can + be provided by setting the ``iszerofunc`` argument to a function that + should return True if its argument is zero. The ADJ routine computes + the determinant and uses that to detect singular matrices in addition + to testing for zeros on the diagonal. + + See Also + ======== + + inverse_ADJ + inverse_GE + inverse_LU + inverse_CH + inverse_LDL + + Raises + ====== + + ValueError + If the determinant of the matrix is zero. + """ + + from sympy.matrices import diag, SparseMatrix + + if not M.is_square: + raise NonSquareMatrixError("A Matrix must be square to invert.") + + if try_block_diag: + blocks = M.get_diag_blocks() + r = [] + + for block in blocks: + r.append(block.inv(method=method, iszerofunc=iszerofunc)) + + return diag(*r) + + # Default: Use DomainMatrix if the domain is not EX. + # If DM is requested explicitly then use it even if the domain is EX. + if method is None and iszerofunc is _iszero: + dM = _try_DM(M, use_EX=False) + if dM is not None: + method = 'DM' + elif method in ("DM", "DMNC"): + dM = _try_DM(M, use_EX=True) + + # A suitable domain was not found, fall back to GE for dense matrices + # and LDL for sparse matrices. + if method is None: + if isinstance(M, SparseMatrix): + method = 'LDL' + else: + method = 'GE' + + if method == "DM": + rv = _inv_DM(dM) + elif method == "DMNC": + rv = _inv_DM(dM, cancel=False) + elif method == "GE": + rv = M.inverse_GE(iszerofunc=iszerofunc) + elif method == "LU": + rv = M.inverse_LU(iszerofunc=iszerofunc) + elif method == "ADJ": + rv = M.inverse_ADJ(iszerofunc=iszerofunc) + elif method == "CH": + rv = M.inverse_CH(iszerofunc=iszerofunc) + elif method == "LDL": + rv = M.inverse_LDL(iszerofunc=iszerofunc) + elif method == "QR": + rv = M.inverse_QR(iszerofunc=iszerofunc) + elif method == "BLOCK": + rv = M.inverse_BLOCK(iszerofunc=iszerofunc) + else: + raise ValueError("Inversion method unrecognized") + + return M._new(rv) diff --git a/phivenv/Lib/site-packages/sympy/matrices/kind.py b/phivenv/Lib/site-packages/sympy/matrices/kind.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f53ffe16f7cbde60213e49071a2a74e80e5c6c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/kind.py @@ -0,0 +1,97 @@ +# sympy.matrices.kind + +from sympy.core.kind import Kind, _NumberKind, NumberKind +from sympy.core.mul import Mul + + +class MatrixKind(Kind): + """ + Kind for all matrices in SymPy. + + Basic class for this kind is ``MatrixBase`` and ``MatrixExpr``, + but any expression representing the matrix can have this. + + Parameters + ========== + + element_kind : Kind + Kind of the element. Default is + :class:`sympy.core.kind.NumberKind`, + which means that the matrix contains only numbers. + + Examples + ======== + + Any instance of matrix class has kind ``MatrixKind``: + + >>> from sympy import MatrixSymbol + >>> A = MatrixSymbol('A', 2, 2) + >>> A.kind + MatrixKind(NumberKind) + + An expression representing a matrix may not be an instance of + the Matrix class, but it will have kind ``MatrixKind``: + + >>> from sympy import MatrixExpr, Integral + >>> from sympy.abc import x + >>> intM = Integral(A, x) + >>> isinstance(intM, MatrixExpr) + False + >>> intM.kind + MatrixKind(NumberKind) + + Use ``isinstance()`` to check for ``MatrixKind`` without specifying the + element kind. Use ``is`` to check the kind including the element kind: + + >>> from sympy import Matrix + >>> from sympy.core import NumberKind + >>> from sympy.matrices import MatrixKind + >>> M = Matrix([1, 2]) + >>> isinstance(M.kind, MatrixKind) + True + >>> M.kind is MatrixKind(NumberKind) + True + + See Also + ======== + + sympy.core.kind.NumberKind + sympy.core.kind.UndefinedKind + sympy.core.containers.TupleKind + sympy.sets.sets.SetKind + + """ + def __new__(cls, element_kind=NumberKind): + obj = super().__new__(cls, element_kind) + obj.element_kind = element_kind + return obj + + def __repr__(self): + return "MatrixKind(%s)" % self.element_kind + + +@Mul._kind_dispatcher.register(_NumberKind, MatrixKind) +def num_mat_mul(k1, k2): + """ + Return MatrixKind. The element kind is selected by recursive dispatching. + Do not need to dispatch in reversed order because KindDispatcher + searches for this automatically. + """ + # Deal with Mul._kind_dispatcher's commutativity + # XXX: this function is called with either k1 or k2 as MatrixKind because + # the Mul kind dispatcher is commutative. Maybe it shouldn't be. Need to + # swap the args here because NumberKind does not have an element_kind + # attribute. + if not isinstance(k2, MatrixKind): + k1, k2 = k2, k1 + elemk = Mul._kind_dispatcher(k1, k2.element_kind) + return MatrixKind(elemk) + + +@Mul._kind_dispatcher.register(MatrixKind, MatrixKind) +def mat_mat_mul(k1, k2): + """ + Return MatrixKind. The element kind is selected by recursive dispatching. + """ + elemk = Mul._kind_dispatcher(k1.element_kind, k2.element_kind) + return MatrixKind(elemk) diff --git a/phivenv/Lib/site-packages/sympy/matrices/matrices.py b/phivenv/Lib/site-packages/sympy/matrices/matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..fed41a626cb395ac0529071317630d853e0d3a96 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/matrices.py @@ -0,0 +1,687 @@ +# +# A module consisting of deprecated matrix classes. New code should not be +# added here. +# +from sympy.core.basic import Basic +from sympy.core.symbol import Dummy + +from .common import MatrixCommon + +from .exceptions import NonSquareMatrixError + +from .utilities import _iszero, _is_zero_after_expand_mul, _simplify + +from .determinant import ( + _find_reasonable_pivot, _find_reasonable_pivot_naive, + _adjugate, _charpoly, _cofactor, _cofactor_matrix, _per, + _det, _det_bareiss, _det_berkowitz, _det_bird, _det_laplace, _det_LU, + _minor, _minor_submatrix) + +from .reductions import _is_echelon, _echelon_form, _rank, _rref +from .subspaces import _columnspace, _nullspace, _rowspace, _orthogonalize + +from .eigen import ( + _eigenvals, _eigenvects, + _bidiagonalize, _bidiagonal_decomposition, + _is_diagonalizable, _diagonalize, + _is_positive_definite, _is_positive_semidefinite, + _is_negative_definite, _is_negative_semidefinite, _is_indefinite, + _jordan_form, _left_eigenvects, _singular_values) + + +# This class was previously defined in this module, but was moved to +# sympy.matrices.matrixbase. We import it here for backwards compatibility in +# case someone was importing it from here. +from .matrixbase import MatrixBase + + +__doctest_requires__ = { + ('MatrixEigen.is_indefinite', + 'MatrixEigen.is_negative_definite', + 'MatrixEigen.is_negative_semidefinite', + 'MatrixEigen.is_positive_definite', + 'MatrixEigen.is_positive_semidefinite'): ['matplotlib'], +} + + +class MatrixDeterminant(MatrixCommon): + """Provides basic matrix determinant operations. Should not be instantiated + directly. See ``determinant.py`` for their implementations.""" + + def _eval_det_bareiss(self, iszerofunc=_is_zero_after_expand_mul): + return _det_bareiss(self, iszerofunc=iszerofunc) + + def _eval_det_berkowitz(self): + return _det_berkowitz(self) + + def _eval_det_lu(self, iszerofunc=_iszero, simpfunc=None): + return _det_LU(self, iszerofunc=iszerofunc, simpfunc=simpfunc) + + def _eval_det_bird(self): + return _det_bird(self) + + def _eval_det_laplace(self): + return _det_laplace(self) + + def _eval_determinant(self): # for expressions.determinant.Determinant + return _det(self) + + def adjugate(self, method="berkowitz"): + return _adjugate(self, method=method) + + def charpoly(self, x='lambda', simplify=_simplify): + return _charpoly(self, x=x, simplify=simplify) + + def cofactor(self, i, j, method="berkowitz"): + return _cofactor(self, i, j, method=method) + + def cofactor_matrix(self, method="berkowitz"): + return _cofactor_matrix(self, method=method) + + def det(self, method="bareiss", iszerofunc=None): + return _det(self, method=method, iszerofunc=iszerofunc) + + def per(self): + return _per(self) + + def minor(self, i, j, method="berkowitz"): + return _minor(self, i, j, method=method) + + def minor_submatrix(self, i, j): + return _minor_submatrix(self, i, j) + + _find_reasonable_pivot.__doc__ = _find_reasonable_pivot.__doc__ + _find_reasonable_pivot_naive.__doc__ = _find_reasonable_pivot_naive.__doc__ + _eval_det_bareiss.__doc__ = _det_bareiss.__doc__ + _eval_det_berkowitz.__doc__ = _det_berkowitz.__doc__ + _eval_det_bird.__doc__ = _det_bird.__doc__ + _eval_det_laplace.__doc__ = _det_laplace.__doc__ + _eval_det_lu.__doc__ = _det_LU.__doc__ + _eval_determinant.__doc__ = _det.__doc__ + adjugate.__doc__ = _adjugate.__doc__ + charpoly.__doc__ = _charpoly.__doc__ + cofactor.__doc__ = _cofactor.__doc__ + cofactor_matrix.__doc__ = _cofactor_matrix.__doc__ + det.__doc__ = _det.__doc__ + per.__doc__ = _per.__doc__ + minor.__doc__ = _minor.__doc__ + minor_submatrix.__doc__ = _minor_submatrix.__doc__ + + +class MatrixReductions(MatrixDeterminant): + """Provides basic matrix row/column operations. Should not be instantiated + directly. See ``reductions.py`` for some of their implementations.""" + + def echelon_form(self, iszerofunc=_iszero, simplify=False, with_pivots=False): + return _echelon_form(self, iszerofunc=iszerofunc, simplify=simplify, + with_pivots=with_pivots) + + @property + def is_echelon(self): + return _is_echelon(self) + + def rank(self, iszerofunc=_iszero, simplify=False): + return _rank(self, iszerofunc=iszerofunc, simplify=simplify) + + def rref_rhs(self, rhs): + """Return reduced row-echelon form of matrix, matrix showing + rhs after reduction steps. ``rhs`` must have the same number + of rows as ``self``. + + Examples + ======== + + >>> from sympy import Matrix, symbols + >>> r1, r2 = symbols('r1 r2') + >>> Matrix([[1, 1], [2, 1]]).rref_rhs(Matrix([r1, r2])) + (Matrix([ + [1, 0], + [0, 1]]), Matrix([ + [ -r1 + r2], + [2*r1 - r2]])) + """ + r, _ = _rref(self.hstack(self, self.eye(self.rows), rhs)) + return r[:, :self.cols], r[:, -rhs.cols:] + + def rref(self, iszerofunc=_iszero, simplify=False, pivots=True, + normalize_last=True): + return _rref(self, iszerofunc=iszerofunc, simplify=simplify, + pivots=pivots, normalize_last=normalize_last) + + echelon_form.__doc__ = _echelon_form.__doc__ + is_echelon.__doc__ = _is_echelon.__doc__ + rank.__doc__ = _rank.__doc__ + rref.__doc__ = _rref.__doc__ + + def _normalize_op_args(self, op, col, k, col1, col2, error_str="col"): + """Validate the arguments for a row/column operation. ``error_str`` + can be one of "row" or "col" depending on the arguments being parsed.""" + if op not in ["n->kn", "n<->m", "n->n+km"]: + raise ValueError("Unknown {} operation '{}'. Valid col operations " + "are 'n->kn', 'n<->m', 'n->n+km'".format(error_str, op)) + + # define self_col according to error_str + self_cols = self.cols if error_str == 'col' else self.rows + + # normalize and validate the arguments + if op == "n->kn": + col = col if col is not None else col1 + if col is None or k is None: + raise ValueError("For a {0} operation 'n->kn' you must provide the " + "kwargs `{0}` and `k`".format(error_str)) + if not 0 <= col < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col)) + + elif op == "n<->m": + # we need two cols to swap. It does not matter + # how they were specified, so gather them together and + # remove `None` + cols = {col, k, col1, col2}.difference([None]) + if len(cols) > 2: + # maybe the user left `k` by mistake? + cols = {col, col1, col2}.difference([None]) + if len(cols) != 2: + raise ValueError("For a {0} operation 'n<->m' you must provide the " + "kwargs `{0}1` and `{0}2`".format(error_str)) + col1, col2 = cols + if not 0 <= col1 < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col1)) + if not 0 <= col2 < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col2)) + + elif op == "n->n+km": + col = col1 if col is None else col + col2 = col1 if col2 is None else col2 + if col is None or col2 is None or k is None: + raise ValueError("For a {0} operation 'n->n+km' you must provide the " + "kwargs `{0}`, `k`, and `{0}2`".format(error_str)) + if col == col2: + raise ValueError("For a {0} operation 'n->n+km' `{0}` and `{0}2` must " + "be different.".format(error_str)) + if not 0 <= col < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col)) + if not 0 <= col2 < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col2)) + + else: + raise ValueError('invalid operation %s' % repr(op)) + + return op, col, k, col1, col2 + + def _eval_col_op_multiply_col_by_const(self, col, k): + def entry(i, j): + if j == col: + return k * self[i, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_col_op_swap(self, col1, col2): + def entry(i, j): + if j == col1: + return self[i, col2] + elif j == col2: + return self[i, col1] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_col_op_add_multiple_to_other_col(self, col, k, col2): + def entry(i, j): + if j == col: + return self[i, j] + k * self[i, col2] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_row_op_swap(self, row1, row2): + def entry(i, j): + if i == row1: + return self[row2, j] + elif i == row2: + return self[row1, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_row_op_multiply_row_by_const(self, row, k): + def entry(i, j): + if i == row: + return k * self[i, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_row_op_add_multiple_to_other_row(self, row, k, row2): + def entry(i, j): + if i == row: + return self[i, j] + k * self[row2, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def elementary_col_op(self, op="n->kn", col=None, k=None, col1=None, col2=None): + """Performs the elementary column operation `op`. + + `op` may be one of + + * ``"n->kn"`` (column n goes to k*n) + * ``"n<->m"`` (swap column n and column m) + * ``"n->n+km"`` (column n goes to column n + k*column m) + + Parameters + ========== + + op : string; the elementary row operation + col : the column to apply the column operation + k : the multiple to apply in the column operation + col1 : one column of a column swap + col2 : second column of a column swap or column "m" in the column operation + "n->n+km" + """ + + op, col, k, col1, col2 = self._normalize_op_args(op, col, k, col1, col2, "col") + + # now that we've validated, we're all good to dispatch + if op == "n->kn": + return self._eval_col_op_multiply_col_by_const(col, k) + if op == "n<->m": + return self._eval_col_op_swap(col1, col2) + if op == "n->n+km": + return self._eval_col_op_add_multiple_to_other_col(col, k, col2) + + def elementary_row_op(self, op="n->kn", row=None, k=None, row1=None, row2=None): + """Performs the elementary row operation `op`. + + `op` may be one of + + * ``"n->kn"`` (row n goes to k*n) + * ``"n<->m"`` (swap row n and row m) + * ``"n->n+km"`` (row n goes to row n + k*row m) + + Parameters + ========== + + op : string; the elementary row operation + row : the row to apply the row operation + k : the multiple to apply in the row operation + row1 : one row of a row swap + row2 : second row of a row swap or row "m" in the row operation + "n->n+km" + """ + + op, row, k, row1, row2 = self._normalize_op_args(op, row, k, row1, row2, "row") + + # now that we've validated, we're all good to dispatch + if op == "n->kn": + return self._eval_row_op_multiply_row_by_const(row, k) + if op == "n<->m": + return self._eval_row_op_swap(row1, row2) + if op == "n->n+km": + return self._eval_row_op_add_multiple_to_other_row(row, k, row2) + + +class MatrixSubspaces(MatrixReductions): + """Provides methods relating to the fundamental subspaces of a matrix. + Should not be instantiated directly. See ``subspaces.py`` for their + implementations.""" + + def columnspace(self, simplify=False): + return _columnspace(self, simplify=simplify) + + def nullspace(self, simplify=False, iszerofunc=_iszero): + return _nullspace(self, simplify=simplify, iszerofunc=iszerofunc) + + def rowspace(self, simplify=False): + return _rowspace(self, simplify=simplify) + + # This is a classmethod but is converted to such later in order to allow + # assignment of __doc__ since that does not work for already wrapped + # classmethods in Python 3.6. + def orthogonalize(cls, *vecs, **kwargs): + return _orthogonalize(cls, *vecs, **kwargs) + + columnspace.__doc__ = _columnspace.__doc__ + nullspace.__doc__ = _nullspace.__doc__ + rowspace.__doc__ = _rowspace.__doc__ + orthogonalize.__doc__ = _orthogonalize.__doc__ + + orthogonalize = classmethod(orthogonalize) # type:ignore + + +class MatrixEigen(MatrixSubspaces): + """Provides basic matrix eigenvalue/vector operations. + Should not be instantiated directly. See ``eigen.py`` for their + implementations.""" + + def eigenvals(self, error_when_incomplete=True, **flags): + return _eigenvals(self, error_when_incomplete=error_when_incomplete, **flags) + + def eigenvects(self, error_when_incomplete=True, iszerofunc=_iszero, **flags): + return _eigenvects(self, error_when_incomplete=error_when_incomplete, + iszerofunc=iszerofunc, **flags) + + def is_diagonalizable(self, reals_only=False, **kwargs): + return _is_diagonalizable(self, reals_only=reals_only, **kwargs) + + def diagonalize(self, reals_only=False, sort=False, normalize=False): + return _diagonalize(self, reals_only=reals_only, sort=sort, + normalize=normalize) + + def bidiagonalize(self, upper=True): + return _bidiagonalize(self, upper=upper) + + def bidiagonal_decomposition(self, upper=True): + return _bidiagonal_decomposition(self, upper=upper) + + @property + def is_positive_definite(self): + return _is_positive_definite(self) + + @property + def is_positive_semidefinite(self): + return _is_positive_semidefinite(self) + + @property + def is_negative_definite(self): + return _is_negative_definite(self) + + @property + def is_negative_semidefinite(self): + return _is_negative_semidefinite(self) + + @property + def is_indefinite(self): + return _is_indefinite(self) + + def jordan_form(self, calc_transform=True, **kwargs): + return _jordan_form(self, calc_transform=calc_transform, **kwargs) + + def left_eigenvects(self, **flags): + return _left_eigenvects(self, **flags) + + def singular_values(self): + return _singular_values(self) + + eigenvals.__doc__ = _eigenvals.__doc__ + eigenvects.__doc__ = _eigenvects.__doc__ + is_diagonalizable.__doc__ = _is_diagonalizable.__doc__ + diagonalize.__doc__ = _diagonalize.__doc__ + is_positive_definite.__doc__ = _is_positive_definite.__doc__ + is_positive_semidefinite.__doc__ = _is_positive_semidefinite.__doc__ + is_negative_definite.__doc__ = _is_negative_definite.__doc__ + is_negative_semidefinite.__doc__ = _is_negative_semidefinite.__doc__ + is_indefinite.__doc__ = _is_indefinite.__doc__ + jordan_form.__doc__ = _jordan_form.__doc__ + left_eigenvects.__doc__ = _left_eigenvects.__doc__ + singular_values.__doc__ = _singular_values.__doc__ + bidiagonalize.__doc__ = _bidiagonalize.__doc__ + bidiagonal_decomposition.__doc__ = _bidiagonal_decomposition.__doc__ + + +class MatrixCalculus(MatrixCommon): + """Provides calculus-related matrix operations.""" + + def diff(self, *args, evaluate=True, **kwargs): + """Calculate the derivative of each element in the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.diff(x) + Matrix([ + [1, 0], + [0, 0]]) + + See Also + ======== + + integrate + limit + """ + # XXX this should be handled here rather than in Derivative + from sympy.tensor.array.array_derivatives import ArrayDerivative + deriv = ArrayDerivative(self, *args, evaluate=evaluate) + # XXX This can rather changed to always return immutable matrix + if not isinstance(self, Basic) and evaluate: + return deriv.as_mutable() + return deriv + + def _eval_derivative(self, arg): + return self.applyfunc(lambda x: x.diff(arg)) + + def integrate(self, *args, **kwargs): + """Integrate each element of the matrix. ``args`` will + be passed to the ``integrate`` function. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.integrate((x, )) + Matrix([ + [x**2/2, x*y], + [ x, 0]]) + >>> M.integrate((x, 0, 2)) + Matrix([ + [2, 2*y], + [2, 0]]) + + See Also + ======== + + limit + diff + """ + return self.applyfunc(lambda x: x.integrate(*args, **kwargs)) + + def jacobian(self, X): + """Calculates the Jacobian matrix (derivative of a vector-valued function). + + Parameters + ========== + + ``self`` : vector of expressions representing functions f_i(x_1, ..., x_n). + X : set of x_i's in order, it can be a list or a Matrix + + Both ``self`` and X can be a row or a column matrix in any order + (i.e., jacobian() should always work). + + Examples + ======== + + >>> from sympy import sin, cos, Matrix + >>> from sympy.abc import rho, phi + >>> X = Matrix([rho*cos(phi), rho*sin(phi), rho**2]) + >>> Y = Matrix([rho, phi]) + >>> X.jacobian(Y) + Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)], + [ 2*rho, 0]]) + >>> X = Matrix([rho*cos(phi), rho*sin(phi)]) + >>> X.jacobian(Y) + Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)]]) + + See Also + ======== + + hessian + wronskian + """ + if not isinstance(X, MatrixBase): + X = self._new(X) + # Both X and ``self`` can be a row or a column matrix, so we need to make + # sure all valid combinations work, but everything else fails: + if self.shape[0] == 1: + m = self.shape[1] + elif self.shape[1] == 1: + m = self.shape[0] + else: + raise TypeError("``self`` must be a row or a column matrix") + if X.shape[0] == 1: + n = X.shape[1] + elif X.shape[1] == 1: + n = X.shape[0] + else: + raise TypeError("X must be a row or a column matrix") + + # m is the number of functions and n is the number of variables + # computing the Jacobian is now easy: + return self._new(m, n, lambda j, i: self[j].diff(X[i])) + + def limit(self, *args): + """Calculate the limit of each element in the matrix. + ``args`` will be passed to the ``limit`` function. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.limit(x, 2) + Matrix([ + [2, y], + [1, 0]]) + + See Also + ======== + + integrate + diff + """ + return self.applyfunc(lambda x: x.limit(*args)) + + +# https://github.com/sympy/sympy/pull/12854 +class MatrixDeprecated(MatrixCommon): + """A class to house deprecated matrix methods.""" + def berkowitz_charpoly(self, x=Dummy('lambda'), simplify=_simplify): + return self.charpoly(x=x) + + def berkowitz_det(self): + """Computes determinant using Berkowitz method. + + See Also + ======== + + det + berkowitz + """ + return self.det(method='berkowitz') + + def berkowitz_eigenvals(self, **flags): + """Computes eigenvalues of a Matrix using Berkowitz method. + + See Also + ======== + + berkowitz + """ + return self.eigenvals(**flags) + + def berkowitz_minors(self): + """Computes principal minors using Berkowitz method. + + See Also + ======== + + berkowitz + """ + sign, minors = self.one, [] + + for poly in self.berkowitz(): + minors.append(sign * poly[-1]) + sign = -sign + + return tuple(minors) + + def berkowitz(self): + from sympy.matrices import zeros + berk = ((1,),) + if not self: + return berk + + if not self.is_square: + raise NonSquareMatrixError() + + A, N = self, self.rows + transforms = [0] * (N - 1) + + for n in range(N, 1, -1): + T, k = zeros(n + 1, n), n - 1 + + R, C = -A[k, :k], A[:k, k] + A, a = A[:k, :k], -A[k, k] + + items = [C] + + for i in range(0, n - 2): + items.append(A * items[i]) + + for i, B in enumerate(items): + items[i] = (R * B)[0, 0] + + items = [self.one, a] + items + + for i in range(n): + T[i:, i] = items[:n - i + 1] + + transforms[k - 1] = T + + polys = [self._new([self.one, -A[0, 0]])] + + for i, T in enumerate(transforms): + polys.append(T * polys[i]) + + return berk + tuple(map(tuple, polys)) + + def cofactorMatrix(self, method="berkowitz"): + return self.cofactor_matrix(method=method) + + def det_bareis(self): + return _det_bareiss(self) + + def det_LU_decomposition(self): + """Compute matrix determinant using LU decomposition. + + + Note that this method fails if the LU decomposition itself + fails. In particular, if the matrix has no inverse this method + will fail. + + TODO: Implement algorithm for sparse matrices (SFF), + https://www.eecis.udel.edu/~saunders/papers/sffge/it5.ps + + See Also + ======== + + + det + det_bareiss + berkowitz_det + """ + return self.det(method='lu') + + def jordan_cell(self, eigenval, n): + return self.jordan_block(size=n, eigenvalue=eigenval) + + def jordan_cells(self, calc_transformation=True): + P, J = self.jordan_form() + return P, J.get_diag_blocks() + + def minorEntry(self, i, j, method="berkowitz"): + return self.minor(i, j, method=method) + + def minorMatrix(self, i, j): + return self.minor_submatrix(i, j) + + def permuteBkwd(self, perm): + """Permute the rows of the matrix with the given permutation in reverse.""" + return self.permute_rows(perm, direction='backward') + + def permuteFwd(self, perm): + """Permute the rows of the matrix with the given permutation.""" + return self.permute_rows(perm, direction='forward') diff --git a/phivenv/Lib/site-packages/sympy/matrices/matrixbase.py b/phivenv/Lib/site-packages/sympy/matrices/matrixbase.py new file mode 100644 index 0000000000000000000000000000000000000000..49acc04043b30e003f7eed256f2e06e6a6556401 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/matrixbase.py @@ -0,0 +1,5428 @@ +from __future__ import annotations +from collections import defaultdict +from collections.abc import Iterable +from inspect import isfunction +from functools import reduce + +from sympy.assumptions.refine import refine +from sympy.core import SympifyError, Add +from sympy.core.basic import Atom, Basic +from sympy.core.kind import UndefinedKind +from sympy.core.numbers import Integer +from sympy.core.mod import Mod +from sympy.core.symbol import Symbol, Dummy +from sympy.core.sympify import sympify, _sympify +from sympy.core.function import diff +from sympy.polys import cancel +from sympy.functions.elementary.complexes import Abs, re, im +from sympy.printing import sstr +from sympy.functions.elementary.miscellaneous import Max, Min, sqrt +from sympy.functions.special.tensor_functions import KroneckerDelta, LeviCivita +from sympy.core.singleton import S +from sympy.printing.defaults import Printable +from sympy.printing.str import StrPrinter +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.combinatorial.factorials import binomial, factorial + +import mpmath as mp +from collections.abc import Callable +from sympy.utilities.iterables import reshape +from sympy.core.expr import Expr +from sympy.core.power import Pow +from sympy.core.symbol import uniquely_named_symbol + +from .utilities import _dotprodsimp, _simplify as _utilities_simplify +from sympy.polys.polytools import Poly +from sympy.utilities.iterables import flatten, is_sequence +from sympy.utilities.misc import as_int, filldedent +from sympy.core.decorators import call_highest_priority +from sympy.core.logic import fuzzy_and, FuzzyBool +from sympy.tensor.array import NDimArray +from sympy.utilities.iterables import NotIterable + +from .utilities import _get_intermediate_simp_bool + +from .kind import MatrixKind + +from .exceptions import ( + MatrixError, ShapeError, NonSquareMatrixError, NonInvertibleMatrixError, +) + +from .utilities import _iszero, _is_zero_after_expand_mul + +from .determinant import ( + _find_reasonable_pivot, _find_reasonable_pivot_naive, + _adjugate, _charpoly, _cofactor, _cofactor_matrix, _per, + _det, _det_bareiss, _det_berkowitz, _det_bird, _det_laplace, _det_LU, + _minor, _minor_submatrix) + +from .reductions import _is_echelon, _echelon_form, _rank, _rref + +from .solvers import ( + _diagonal_solve, _lower_triangular_solve, _upper_triangular_solve, + _cholesky_solve, _LDLsolve, _LUsolve, _QRsolve, _gauss_jordan_solve, + _pinv_solve, _cramer_solve, _solve, _solve_least_squares) + +from .inverse import ( + _pinv, _inv_ADJ, _inv_GE, _inv_LU, _inv_CH, _inv_LDL, _inv_QR, + _inv, _inv_block) + +from .subspaces import _columnspace, _nullspace, _rowspace, _orthogonalize + +from .eigen import ( + _eigenvals, _eigenvects, + _bidiagonalize, _bidiagonal_decomposition, + _is_diagonalizable, _diagonalize, + _is_positive_definite, _is_positive_semidefinite, + _is_negative_definite, _is_negative_semidefinite, _is_indefinite, + _jordan_form, _left_eigenvects, _singular_values) + +from .decompositions import ( + _rank_decomposition, _cholesky, _LDLdecomposition, + _LUdecomposition, _LUdecomposition_Simple, _LUdecompositionFF, + _singular_value_decomposition, _QRdecomposition, _upper_hessenberg_decomposition) + +from .graph import ( + _connected_components, _connected_components_decomposition, + _strongly_connected_components, _strongly_connected_components_decomposition) + + +__doctest_requires__ = { + ('MatrixBase.is_indefinite', + 'MatrixBase.is_positive_definite', + 'MatrixBase.is_positive_semidefinite', + 'MatrixBase.is_negative_definite', + 'MatrixBase.is_negative_semidefinite'): ['matplotlib'], +} + + +class MatrixBase(Printable): + """All common matrix operations including basic arithmetic, shaping, + and special matrices like `zeros`, and `eye`.""" + + _op_priority = 10.01 + + # Added just for numpy compatibility + __array_priority__ = 11 + + is_Matrix = True + _class_priority = 3 + _sympify = staticmethod(sympify) + zero = S.Zero + one = S.One + + _diff_wrt: bool = True + rows: int + cols: int + _simplify = None + + @classmethod + def _new(cls, *args, **kwargs): + """`_new` must, at minimum, be callable as + `_new(rows, cols, mat) where mat is a flat list of the + elements of the matrix.""" + raise NotImplementedError("Subclasses must implement this.") + + def __eq__(self, other): + raise NotImplementedError("Subclasses must implement this.") + + def __getitem__(self, key): + """Implementations of __getitem__ should accept ints, in which + case the matrix is indexed as a flat list, tuples (i,j) in which + case the (i,j) entry is returned, slices, or mixed tuples (a,b) + where a and b are any combination of slices and integers.""" + raise NotImplementedError("Subclasses must implement this.") + + @property + def shape(self): + """The shape (dimensions) of the matrix as the 2-tuple (rows, cols). + + Examples + ======== + + >>> from sympy import zeros + >>> M = zeros(2, 3) + >>> M.shape + (2, 3) + >>> M.rows + 2 + >>> M.cols + 3 + """ + return (self.rows, self.cols) + + def _eval_col_del(self, col): + def entry(i, j): + return self[i, j] if j < col else self[i, j + 1] + return self._new(self.rows, self.cols - 1, entry) + + def _eval_col_insert(self, pos, other): + + def entry(i, j): + if j < pos: + return self[i, j] + elif pos <= j < pos + other.cols: + return other[i, j - pos] + return self[i, j - other.cols] + + return self._new(self.rows, self.cols + other.cols, entry) + + def _eval_col_join(self, other): + rows = self.rows + + def entry(i, j): + if i < rows: + return self[i, j] + return other[i - rows, j] + + return classof(self, other)._new(self.rows + other.rows, self.cols, + entry) + + def _eval_extract(self, rowsList, colsList): + mat = list(self) + cols = self.cols + indices = (i * cols + j for i in rowsList for j in colsList) + return self._new(len(rowsList), len(colsList), + [mat[i] for i in indices]) + + def _eval_get_diag_blocks(self): + sub_blocks = [] + + def recurse_sub_blocks(M): + for i in range(1, M.shape[0] + 1): + if i == 1: + to_the_right = M[0, i:] + to_the_bottom = M[i:, 0] + else: + to_the_right = M[:i, i:] + to_the_bottom = M[i:, :i] + if any(to_the_right) or any(to_the_bottom): + continue + sub_blocks.append(M[:i, :i]) + if M.shape != M[:i, :i].shape: + recurse_sub_blocks(M[i:, i:]) + return + + recurse_sub_blocks(self) + return sub_blocks + + def _eval_row_del(self, row): + def entry(i, j): + return self[i, j] if i < row else self[i + 1, j] + return self._new(self.rows - 1, self.cols, entry) + + def _eval_row_insert(self, pos, other): + entries = list(self) + insert_pos = pos * self.cols + entries[insert_pos:insert_pos] = list(other) + return self._new(self.rows + other.rows, self.cols, entries) + + def _eval_row_join(self, other): + cols = self.cols + + def entry(i, j): + if j < cols: + return self[i, j] + return other[i, j - cols] + + return classof(self, other)._new(self.rows, self.cols + other.cols, + entry) + + def _eval_tolist(self): + return [list(self[i,:]) for i in range(self.rows)] + + def _eval_todok(self): + dok = {} + rows, cols = self.shape + for i in range(rows): + for j in range(cols): + val = self[i, j] + if val != self.zero: + dok[i, j] = val + return dok + + @classmethod + def _eval_from_dok(cls, rows, cols, dok): + out_flat = [cls.zero] * (rows * cols) + for (i, j), val in dok.items(): + out_flat[i * cols + j] = val + return cls._new(rows, cols, out_flat) + + def _eval_vec(self): + rows = self.rows + + def entry(n, _): + # we want to read off the columns first + j = n // rows + i = n - j * rows + return self[i, j] + + return self._new(len(self), 1, entry) + + def _eval_vech(self, diagonal): + c = self.cols + v = [] + if diagonal: + for j in range(c): + for i in range(j, c): + v.append(self[i, j]) + else: + for j in range(c): + for i in range(j + 1, c): + v.append(self[i, j]) + return self._new(len(v), 1, v) + + def col_del(self, col): + """Delete the specified column.""" + if col < 0: + col += self.cols + if not 0 <= col < self.cols: + raise IndexError("Column {} is out of range.".format(col)) + return self._eval_col_del(col) + + def col_insert(self, pos, other): + """Insert one or more columns at the given column position. + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(3, 1) + >>> M.col_insert(1, V) + Matrix([ + [0, 1, 0, 0], + [0, 1, 0, 0], + [0, 1, 0, 0]]) + + See Also + ======== + + col + row_insert + """ + # Allows you to build a matrix even if it is null matrix + if not self: + return type(self)(other) + + pos = as_int(pos) + + if pos < 0: + pos = self.cols + pos + if pos < 0: + pos = 0 + elif pos > self.cols: + pos = self.cols + + if self.rows != other.rows: + raise ShapeError( + "The matrices have incompatible number of rows ({} and {})" + .format(self.rows, other.rows)) + + return self._eval_col_insert(pos, other) + + def col_join(self, other): + """Concatenates two matrices along self's last and other's first row. + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(1, 3) + >>> M.col_join(V) + Matrix([ + [0, 0, 0], + [0, 0, 0], + [0, 0, 0], + [1, 1, 1]]) + + See Also + ======== + + col + row_join + """ + # A null matrix can always be stacked (see #10770) + if self.rows == 0 and self.cols != other.cols: + return self._new(0, other.cols, []).col_join(other) + + if self.cols != other.cols: + raise ShapeError( + "The matrices have incompatible number of columns ({} and {})" + .format(self.cols, other.cols)) + return self._eval_col_join(other) + + def col(self, j): + """Elementary column selector. + + Examples + ======== + + >>> from sympy import eye + >>> eye(2).col(0) + Matrix([ + [1], + [0]]) + + See Also + ======== + + row + col_del + col_join + col_insert + """ + return self[:, j] + + def extract(self, rowsList, colsList): + r"""Return a submatrix by specifying a list of rows and columns. + Negative indices can be given. All indices must be in the range + $-n \le i < n$ where $n$ is the number of rows or columns. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(4, 3, range(12)) + >>> m + Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11]]) + >>> m.extract([0, 1, 3], [0, 1]) + Matrix([ + [0, 1], + [3, 4], + [9, 10]]) + + Rows or columns can be repeated: + + >>> m.extract([0, 0, 1], [-1]) + Matrix([ + [2], + [2], + [5]]) + + Every other row can be taken by using range to provide the indices: + + >>> m.extract(range(0, m.rows, 2), [-1]) + Matrix([ + [2], + [8]]) + + RowsList or colsList can also be a list of booleans, in which case + the rows or columns corresponding to the True values will be selected: + + >>> m.extract([0, 1, 2, 3], [True, False, True]) + Matrix([ + [0, 2], + [3, 5], + [6, 8], + [9, 11]]) + """ + + if not is_sequence(rowsList) or not is_sequence(colsList): + raise TypeError("rowsList and colsList must be iterable") + # ensure rowsList and colsList are lists of integers + if rowsList and all(isinstance(i, bool) for i in rowsList): + rowsList = [index for index, item in enumerate(rowsList) if item] + if colsList and all(isinstance(i, bool) for i in colsList): + colsList = [index for index, item in enumerate(colsList) if item] + + # ensure everything is in range + rowsList = [a2idx(k, self.rows) for k in rowsList] + colsList = [a2idx(k, self.cols) for k in colsList] + + return self._eval_extract(rowsList, colsList) + + def get_diag_blocks(self): + """Obtains the square sub-matrices on the main diagonal of a square matrix. + + Useful for inverting symbolic matrices or solving systems of + linear equations which may be decoupled by having a block diagonal + structure. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y, z + >>> A = Matrix([[1, 3, 0, 0], [y, z*z, 0, 0], [0, 0, x, 0], [0, 0, 0, 0]]) + >>> a1, a2, a3 = A.get_diag_blocks() + >>> a1 + Matrix([ + [1, 3], + [y, z**2]]) + >>> a2 + Matrix([[x]]) + >>> a3 + Matrix([[0]]) + + """ + return self._eval_get_diag_blocks() + + @classmethod + def hstack(cls, *args): + """Return a matrix formed by joining args horizontally (i.e. + by repeated application of row_join). + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> Matrix.hstack(eye(2), 2*eye(2)) + Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2]]) + """ + if len(args) == 0: + return cls._new() + + kls = type(args[0]) + return reduce(kls.row_join, args) + + def reshape(self, rows, cols): + """Reshape the matrix. Total number of elements must remain the same. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 3, lambda i, j: 1) + >>> m + Matrix([ + [1, 1, 1], + [1, 1, 1]]) + >>> m.reshape(1, 6) + Matrix([[1, 1, 1, 1, 1, 1]]) + >>> m.reshape(3, 2) + Matrix([ + [1, 1], + [1, 1], + [1, 1]]) + + """ + if self.rows * self.cols != rows * cols: + raise ValueError("Invalid reshape parameters %d %d" % (rows, cols)) + dok = {divmod(i*self.cols + j, cols): + v for (i, j), v in self.todok().items()} + return self._eval_from_dok(rows, cols, dok) + + def row_del(self, row): + """Delete the specified row.""" + if row < 0: + row += self.rows + if not 0 <= row < self.rows: + raise IndexError("Row {} is out of range.".format(row)) + + return self._eval_row_del(row) + + def row_insert(self, pos, other): + """Insert one or more rows at the given row position. + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(1, 3) + >>> M.row_insert(1, V) + Matrix([ + [0, 0, 0], + [1, 1, 1], + [0, 0, 0], + [0, 0, 0]]) + + See Also + ======== + + row + col_insert + """ + # Allows you to build a matrix even if it is null matrix + if not self: + return self._new(other) + + pos = as_int(pos) + + if pos < 0: + pos = self.rows + pos + if pos < 0: + pos = 0 + elif pos > self.rows: + pos = self.rows + + if self.cols != other.cols: + raise ShapeError( + "The matrices have incompatible number of columns ({} and {})" + .format(self.cols, other.cols)) + + return self._eval_row_insert(pos, other) + + def row_join(self, other): + """Concatenates two matrices along self's last and rhs's first column + + Examples + ======== + + >>> from sympy import zeros, ones + >>> M = zeros(3) + >>> V = ones(3, 1) + >>> M.row_join(V) + Matrix([ + [0, 0, 0, 1], + [0, 0, 0, 1], + [0, 0, 0, 1]]) + + See Also + ======== + + row + col_join + """ + # A null matrix can always be stacked (see #10770) + if self.cols == 0 and self.rows != other.rows: + return self._new(other.rows, 0, []).row_join(other) + + if self.rows != other.rows: + raise ShapeError( + "The matrices have incompatible number of rows ({} and {})" + .format(self.rows, other.rows)) + return self._eval_row_join(other) + + def diagonal(self, k=0): + """Returns the kth diagonal of self. The main diagonal + corresponds to `k=0`; diagonals above and below correspond to + `k > 0` and `k < 0`, respectively. The values of `self[i, j]` + for which `j - i = k`, are returned in order of increasing + `i + j`, starting with `i + j = |k|`. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(3, 3, lambda i, j: j - i); m + Matrix([ + [ 0, 1, 2], + [-1, 0, 1], + [-2, -1, 0]]) + >>> _.diagonal() + Matrix([[0, 0, 0]]) + >>> m.diagonal(1) + Matrix([[1, 1]]) + >>> m.diagonal(-2) + Matrix([[-2]]) + + Even though the diagonal is returned as a Matrix, the element + retrieval can be done with a single index: + + >>> Matrix.diag(1, 2, 3).diagonal()[1] # instead of [0, 1] + 2 + + See Also + ======== + + diag + """ + rv = [] + k = as_int(k) + r = 0 if k > 0 else -k + c = 0 if r else k + while True: + if r == self.rows or c == self.cols: + break + rv.append(self[r, c]) + r += 1 + c += 1 + if not rv: + raise ValueError(filldedent(''' + The %s diagonal is out of range [%s, %s]''' % ( + k, 1 - self.rows, self.cols - 1))) + return self._new(1, len(rv), rv) + + def row(self, i): + """Elementary row selector. + + Examples + ======== + + >>> from sympy import eye + >>> eye(2).row(0) + Matrix([[1, 0]]) + + See Also + ======== + + col + row_del + row_join + row_insert + """ + return self[i, :] + + def todok(self): + """Return the matrix as dictionary of keys. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix.eye(3) + >>> M.todok() + {(0, 0): 1, (1, 1): 1, (2, 2): 1} + """ + return self._eval_todok() + + @classmethod + def from_dok(cls, rows, cols, dok): + """Create a matrix from a dictionary of keys. + + Examples + ======== + + >>> from sympy import Matrix + >>> d = {(0, 0): 1, (1, 2): 3, (2, 1): 4} + >>> Matrix.from_dok(3, 3, d) + Matrix([ + [1, 0, 0], + [0, 0, 3], + [0, 4, 0]]) + """ + dok = {ij: cls._sympify(val) for ij, val in dok.items()} + return cls._eval_from_dok(rows, cols, dok) + + def tolist(self): + """Return the Matrix as a nested Python list. + + Examples + ======== + + >>> from sympy import Matrix, ones + >>> m = Matrix(3, 3, range(9)) + >>> m + Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8]]) + >>> m.tolist() + [[0, 1, 2], [3, 4, 5], [6, 7, 8]] + >>> ones(3, 0).tolist() + [[], [], []] + + When there are no rows then it will not be possible to tell how + many columns were in the original matrix: + + >>> ones(0, 3).tolist() + [] + + """ + if not self.rows: + return [] + if not self.cols: + return [[] for i in range(self.rows)] + return self._eval_tolist() + + def todod(M): + """Returns matrix as dict of dicts containing non-zero elements of the Matrix + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[0, 1],[0, 3]]) + >>> A + Matrix([ + [0, 1], + [0, 3]]) + >>> A.todod() + {0: {1: 1}, 1: {1: 3}} + + + """ + rowsdict = {} + Mlol = M.tolist() + for i, Mi in enumerate(Mlol): + row = {j: Mij for j, Mij in enumerate(Mi) if Mij} + if row: + rowsdict[i] = row + return rowsdict + + def vec(self): + """Return the Matrix converted into a one column matrix by stacking columns + + Examples + ======== + + >>> from sympy import Matrix + >>> m=Matrix([[1, 3], [2, 4]]) + >>> m + Matrix([ + [1, 3], + [2, 4]]) + >>> m.vec() + Matrix([ + [1], + [2], + [3], + [4]]) + + See Also + ======== + + vech + """ + return self._eval_vec() + + def vech(self, diagonal=True, check_symmetry=True): + """Reshapes the matrix into a column vector by stacking the + elements in the lower triangle. + + Parameters + ========== + + diagonal : bool, optional + If ``True``, it includes the diagonal elements. + + check_symmetry : bool, optional + If ``True``, it checks whether the matrix is symmetric. + + Examples + ======== + + >>> from sympy import Matrix + >>> m=Matrix([[1, 2], [2, 3]]) + >>> m + Matrix([ + [1, 2], + [2, 3]]) + >>> m.vech() + Matrix([ + [1], + [2], + [3]]) + >>> m.vech(diagonal=False) + Matrix([[2]]) + + Notes + ===== + + This should work for symmetric matrices and ``vech`` can + represent symmetric matrices in vector form with less size than + ``vec``. + + See Also + ======== + + vec + """ + if not self.is_square: + raise NonSquareMatrixError + + if check_symmetry and not self.is_symmetric(): + raise ValueError("The matrix is not symmetric.") + + return self._eval_vech(diagonal) + + @classmethod + def vstack(cls, *args): + """Return a matrix formed by joining args vertically (i.e. + by repeated application of col_join). + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> Matrix.vstack(eye(2), 2*eye(2)) + Matrix([ + [1, 0], + [0, 1], + [2, 0], + [0, 2]]) + """ + if len(args) == 0: + return cls._new() + + kls = type(args[0]) + return reduce(kls.col_join, args) + + @classmethod + def _eval_diag(cls, rows, cols, diag_dict): + """diag_dict is a defaultdict containing + all the entries of the diagonal matrix.""" + def entry(i, j): + return diag_dict[(i, j)] + return cls._new(rows, cols, entry) + + @classmethod + def _eval_eye(cls, rows, cols): + vals = [cls.zero]*(rows*cols) + vals[::cols+1] = [cls.one]*min(rows, cols) + return cls._new(rows, cols, vals, copy=False) + + @classmethod + def _eval_jordan_block(cls, size: int, eigenvalue, band='upper'): + if band == 'lower': + def entry(i, j): + if i == j: + return eigenvalue + elif j + 1 == i: + return cls.one + return cls.zero + else: + def entry(i, j): + if i == j: + return eigenvalue + elif i + 1 == j: + return cls.one + return cls.zero + return cls._new(size, size, entry) + + @classmethod + def _eval_ones(cls, rows, cols): + def entry(i, j): + return cls.one + return cls._new(rows, cols, entry) + + @classmethod + def _eval_zeros(cls, rows, cols): + return cls._new(rows, cols, [cls.zero]*(rows*cols), copy=False) + + @classmethod + def _eval_wilkinson(cls, n): + def entry(i, j): + return cls.one if i + 1 == j else cls.zero + + D = cls._new(2*n + 1, 2*n + 1, entry) + + wminus = cls.diag(list(range(-n, n + 1)), unpack=True) + D + D.T + wplus = abs(cls.diag(list(range(-n, n + 1)), unpack=True)) + D + D.T + + return wminus, wplus + + @classmethod + def diag(kls, *args, strict=False, unpack=True, rows=None, cols=None, **kwargs): + """Returns a matrix with the specified diagonal. + If matrices are passed, a block-diagonal matrix + is created (i.e. the "direct sum" of the matrices). + + kwargs + ====== + + rows : rows of the resulting matrix; computed if + not given. + + cols : columns of the resulting matrix; computed if + not given. + + cls : class for the resulting matrix + + unpack : bool which, when True (default), unpacks a single + sequence rather than interpreting it as a Matrix. + + strict : bool which, when False (default), allows Matrices to + have variable-length rows. + + Examples + ======== + + >>> from sympy import Matrix + >>> Matrix.diag(1, 2, 3) + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + + The current default is to unpack a single sequence. If this is + not desired, set `unpack=False` and it will be interpreted as + a matrix. + + >>> Matrix.diag([1, 2, 3]) == Matrix.diag(1, 2, 3) + True + + When more than one element is passed, each is interpreted as + something to put on the diagonal. Lists are converted to + matrices. Filling of the diagonal always continues from + the bottom right hand corner of the previous item: this + will create a block-diagonal matrix whether the matrices + are square or not. + + >>> col = [1, 2, 3] + >>> row = [[4, 5]] + >>> Matrix.diag(col, row) + Matrix([ + [1, 0, 0], + [2, 0, 0], + [3, 0, 0], + [0, 4, 5]]) + + When `unpack` is False, elements within a list need not all be + of the same length. Setting `strict` to True would raise a + ValueError for the following: + + >>> Matrix.diag([[1, 2, 3], [4, 5], [6]], unpack=False) + Matrix([ + [1, 2, 3], + [4, 5, 0], + [6, 0, 0]]) + + The type of the returned matrix can be set with the ``cls`` + keyword. + + >>> from sympy import ImmutableMatrix + >>> from sympy.utilities.misc import func_name + >>> func_name(Matrix.diag(1, cls=ImmutableMatrix)) + 'ImmutableDenseMatrix' + + A zero dimension matrix can be used to position the start of + the filling at the start of an arbitrary row or column: + + >>> from sympy import ones + >>> r2 = ones(0, 2) + >>> Matrix.diag(r2, 1, 2) + Matrix([ + [0, 0, 1, 0], + [0, 0, 0, 2]]) + + See Also + ======== + eye + diagonal + .dense.diag + .expressions.blockmatrix.BlockMatrix + .sparsetools.banded + """ + from sympy.matrices.matrixbase import MatrixBase + from sympy.matrices.dense import Matrix + from sympy.matrices import SparseMatrix + klass = kwargs.get('cls', kls) + if unpack and len(args) == 1 and is_sequence(args[0]) and \ + not isinstance(args[0], MatrixBase): + args = args[0] + + # fill a default dict with the diagonal entries + diag_entries = defaultdict(int) + rmax = cmax = 0 # keep track of the biggest index seen + for m in args: + if isinstance(m, list): + if strict: + # if malformed, Matrix will raise an error + _ = Matrix(m) + r, c = _.shape + m = _.tolist() + else: + r, c, smat = SparseMatrix._handle_creation_inputs(m) + for (i, j), _ in smat.items(): + diag_entries[(i + rmax, j + cmax)] = _ + m = [] # to skip process below + elif hasattr(m, 'shape'): # a Matrix + # convert to list of lists + r, c = m.shape + m = m.tolist() + else: # in this case, we're a single value + diag_entries[(rmax, cmax)] = m + rmax += 1 + cmax += 1 + continue + # process list of lists + for i, mi in enumerate(m): + for j, _ in enumerate(mi): + diag_entries[(i + rmax, j + cmax)] = _ + rmax += r + cmax += c + if rows is None: + rows, cols = cols, rows + if rows is None: + rows, cols = rmax, cmax + else: + cols = rows if cols is None else cols + if rows < rmax or cols < cmax: + raise ValueError(filldedent(''' + The constructed matrix is {} x {} but a size of {} x {} + was specified.'''.format(rmax, cmax, rows, cols))) + return klass._eval_diag(rows, cols, diag_entries) + + @classmethod + def eye(kls, rows, cols=None, **kwargs): + """Returns an identity matrix. + + Parameters + ========== + + rows : rows of the matrix + cols : cols of the matrix (if None, cols=rows) + + kwargs + ====== + cls : class of the returned matrix + """ + if cols is None: + cols = rows + if rows < 0 or cols < 0: + raise ValueError("Cannot create a {} x {} matrix. " + "Both dimensions must be positive".format(rows, cols)) + klass = kwargs.get('cls', kls) + rows, cols = as_int(rows), as_int(cols) + + return klass._eval_eye(rows, cols) + + @classmethod + def jordan_block(kls, size=None, eigenvalue=None, *, band='upper', **kwargs): + """Returns a Jordan block + + Parameters + ========== + + size : Integer, optional + Specifies the shape of the Jordan block matrix. + + eigenvalue : Number or Symbol + Specifies the value for the main diagonal of the matrix. + + .. note:: + The keyword ``eigenval`` is also specified as an alias + of this keyword, but it is not recommended to use. + + We may deprecate the alias in later release. + + band : 'upper' or 'lower', optional + Specifies the position of the off-diagonal to put `1` s on. + + cls : Matrix, optional + Specifies the matrix class of the output form. + + If it is not specified, the class type where the method is + being executed on will be returned. + + Returns + ======= + + Matrix + A Jordan block matrix. + + Raises + ====== + + ValueError + If insufficient arguments are given for matrix size + specification, or no eigenvalue is given. + + Examples + ======== + + Creating a default Jordan block: + + >>> from sympy import Matrix + >>> from sympy.abc import x + >>> Matrix.jordan_block(4, x) + Matrix([ + [x, 1, 0, 0], + [0, x, 1, 0], + [0, 0, x, 1], + [0, 0, 0, x]]) + + Creating an alternative Jordan block matrix where `1` is on + lower off-diagonal: + + >>> Matrix.jordan_block(4, x, band='lower') + Matrix([ + [x, 0, 0, 0], + [1, x, 0, 0], + [0, 1, x, 0], + [0, 0, 1, x]]) + + Creating a Jordan block with keyword arguments + + >>> Matrix.jordan_block(size=4, eigenvalue=x) + Matrix([ + [x, 1, 0, 0], + [0, x, 1, 0], + [0, 0, x, 1], + [0, 0, 0, x]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Jordan_matrix + """ + klass = kwargs.pop('cls', kls) + + eigenval = kwargs.get('eigenval', None) + if eigenvalue is None and eigenval is None: + raise ValueError("Must supply an eigenvalue") + elif eigenvalue != eigenval and None not in (eigenval, eigenvalue): + raise ValueError( + "Inconsistent values are given: 'eigenval'={}, " + "'eigenvalue'={}".format(eigenval, eigenvalue)) + else: + if eigenval is not None: + eigenvalue = eigenval + + if size is None: + raise ValueError("Must supply a matrix size") + + size = as_int(size) + return klass._eval_jordan_block(size, eigenvalue, band) + + @classmethod + def ones(kls, rows, cols=None, **kwargs): + """Returns a matrix of ones. + + Parameters + ========== + + rows : rows of the matrix + cols : cols of the matrix (if None, cols=rows) + + kwargs + ====== + cls : class of the returned matrix + """ + if cols is None: + cols = rows + klass = kwargs.get('cls', kls) + rows, cols = as_int(rows), as_int(cols) + + return klass._eval_ones(rows, cols) + + @classmethod + def zeros(kls, rows, cols=None, **kwargs): + """Returns a matrix of zeros. + + Parameters + ========== + + rows : rows of the matrix + cols : cols of the matrix (if None, cols=rows) + + kwargs + ====== + cls : class of the returned matrix + """ + if cols is None: + cols = rows + if rows < 0 or cols < 0: + raise ValueError("Cannot create a {} x {} matrix. " + "Both dimensions must be positive".format(rows, cols)) + klass = kwargs.get('cls', kls) + rows, cols = as_int(rows), as_int(cols) + + return klass._eval_zeros(rows, cols) + + @classmethod + def companion(kls, poly): + """Returns a companion matrix of a polynomial. + + Examples + ======== + + >>> from sympy import Matrix, Poly, Symbol, symbols + >>> x = Symbol('x') + >>> c0, c1, c2, c3, c4 = symbols('c0:5') + >>> p = Poly(c0 + c1*x + c2*x**2 + c3*x**3 + c4*x**4 + x**5, x) + >>> Matrix.companion(p) + Matrix([ + [0, 0, 0, 0, -c0], + [1, 0, 0, 0, -c1], + [0, 1, 0, 0, -c2], + [0, 0, 1, 0, -c3], + [0, 0, 0, 1, -c4]]) + """ + poly = kls._sympify(poly) + if not isinstance(poly, Poly): + raise ValueError("{} must be a Poly instance.".format(poly)) + if not poly.is_monic: + raise ValueError("{} must be a monic polynomial.".format(poly)) + if not poly.is_univariate: + raise ValueError( + "{} must be a univariate polynomial.".format(poly)) + + size = poly.degree() + if not size >= 1: + raise ValueError( + "{} must have degree not less than 1.".format(poly)) + + coeffs = poly.all_coeffs() + def entry(i, j): + if j == size - 1: + return -coeffs[-1 - i] + elif i == j + 1: + return kls.one + return kls.zero + return kls._new(size, size, entry) + + + @classmethod + def wilkinson(kls, n, **kwargs): + """Returns two square Wilkinson Matrix of size 2*n + 1 + $W_{2n + 1}^-, W_{2n + 1}^+ =$ Wilkinson(n) + + Examples + ======== + + >>> from sympy import Matrix + >>> wminus, wplus = Matrix.wilkinson(3) + >>> wminus + Matrix([ + [-3, 1, 0, 0, 0, 0, 0], + [ 1, -2, 1, 0, 0, 0, 0], + [ 0, 1, -1, 1, 0, 0, 0], + [ 0, 0, 1, 0, 1, 0, 0], + [ 0, 0, 0, 1, 1, 1, 0], + [ 0, 0, 0, 0, 1, 2, 1], + [ 0, 0, 0, 0, 0, 1, 3]]) + >>> wplus + Matrix([ + [3, 1, 0, 0, 0, 0, 0], + [1, 2, 1, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 2, 1], + [0, 0, 0, 0, 0, 1, 3]]) + + References + ========== + + .. [1] https://blogs.mathworks.com/cleve/2013/04/15/wilkinsons-matrices-2/ + .. [2] J. H. Wilkinson, The Algebraic Eigenvalue Problem, Claredon Press, Oxford, 1965, 662 pp. + + """ + klass = kwargs.get('cls', kls) + n = as_int(n) + return klass._eval_wilkinson(n) + + # The RepMatrix subclass uses more efficient sparse implementations of + # _eval_iter_values and other things. + + def _eval_iter_values(self): + return (i for i in self if i is not S.Zero) + + def _eval_values(self): + return list(self.iter_values()) + + def _eval_iter_items(self): + for i in range(self.rows): + for j in range(self.cols): + if self[i, j]: + yield (i, j), self[i, j] + + def _eval_atoms(self, *types): + values = self.values() + if len(values) < self.rows * self.cols and isinstance(S.Zero, types): + s = {S.Zero} + else: + s = set() + return s.union(*[v.atoms(*types) for v in values]) + + def _eval_free_symbols(self): + return set().union(*(i.free_symbols for i in set(self.values()))) + + def _eval_has(self, *patterns): + return any(a.has(*patterns) for a in self.iter_values()) + + def _eval_is_symbolic(self): + return self.has(Symbol) + + # _eval_is_hermitian is called by some general SymPy + # routines and has a different *args signature. Make + # sure the names don't clash by adding `_matrix_` in name. + def _eval_is_matrix_hermitian(self, simpfunc): + herm = lambda i, j: simpfunc(self[i, j] - self[j, i].adjoint()).is_zero + return fuzzy_and(herm(i, j) for (i, j), v in self.iter_items()) + + def _eval_is_zero_matrix(self): + return fuzzy_and(v.is_zero for v in self.iter_values()) + + def _eval_is_Identity(self) -> FuzzyBool: + one = self.one + zero = self.zero + ident = lambda i, j, v: v is one if i == j else v is zero + return all(ident(i, j, v) for (i, j), v in self.iter_items()) + + def _eval_is_diagonal(self): + return fuzzy_and(v.is_zero for (i, j), v in self.iter_items() if i != j) + + def _eval_is_lower(self): + return all(v.is_zero for (i, j), v in self.iter_items() if i < j) + + def _eval_is_upper(self): + return all(v.is_zero for (i, j), v in self.iter_items() if i > j) + + def _eval_is_lower_hessenberg(self): + return all(v.is_zero for (i, j), v in self.iter_items() if i + 1 < j) + + def _eval_is_upper_hessenberg(self): + return all(v.is_zero for (i, j), v in self.iter_items() if i > j + 1) + + def _eval_is_symmetric(self, simpfunc): + sym = lambda i, j: simpfunc(self[i, j] - self[j, i]).is_zero + return fuzzy_and(sym(i, j) for (i, j), v in self.iter_items()) + + def _eval_is_anti_symmetric(self, simpfunc): + anti = lambda i, j: simpfunc(self[i, j] + self[j, i]).is_zero + return fuzzy_and(anti(i, j) for (i, j), v in self.iter_items()) + + def _has_positive_diagonals(self): + diagonal_entries = (self[i, i] for i in range(self.rows)) + return fuzzy_and(x.is_positive for x in diagonal_entries) + + def _has_nonnegative_diagonals(self): + diagonal_entries = (self[i, i] for i in range(self.rows)) + return fuzzy_and(x.is_nonnegative for x in diagonal_entries) + + def atoms(self, *types): + """Returns the atoms that form the current object. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import Matrix + >>> Matrix([[x]]) + Matrix([[x]]) + >>> _.atoms() + {x} + >>> Matrix([[x, y], [y, x]]) + Matrix([ + [x, y], + [y, x]]) + >>> _.atoms() + {x, y} + """ + + types = tuple(t if isinstance(t, type) else type(t) for t in types) + if not types: + types = (Atom,) + return self._eval_atoms(*types) + + @property + def free_symbols(self): + """Returns the free symbols within the matrix. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import Matrix + >>> Matrix([[x], [1]]).free_symbols + {x} + """ + return self._eval_free_symbols() + + def has(self, *patterns): + """Test whether any subexpression matches any of the patterns. + + Examples + ======== + + >>> from sympy import Matrix, SparseMatrix, Float + >>> from sympy.abc import x, y + >>> A = Matrix(((1, x), (0.2, 3))) + >>> B = SparseMatrix(((1, x), (0.2, 3))) + >>> A.has(x) + True + >>> A.has(y) + False + >>> A.has(Float) + True + >>> B.has(x) + True + >>> B.has(y) + False + >>> B.has(Float) + True + """ + return self._eval_has(*patterns) + + def is_anti_symmetric(self, simplify=True): + """Check if matrix M is an antisymmetric matrix, + that is, M is a square matrix with all M[i, j] == -M[j, i]. + + When ``simplify=True`` (default), the sum M[i, j] + M[j, i] is + simplified before testing to see if it is zero. By default, + the SymPy simplify function is used. To use a custom function + set simplify to a function that accepts a single argument which + returns a simplified expression. To skip simplification, set + simplify to False but note that although this will be faster, + it may induce false negatives. + + Examples + ======== + + >>> from sympy import Matrix, symbols + >>> m = Matrix(2, 2, [0, 1, -1, 0]) + >>> m + Matrix([ + [ 0, 1], + [-1, 0]]) + >>> m.is_anti_symmetric() + True + >>> x, y = symbols('x y') + >>> m = Matrix(2, 3, [0, 0, x, -y, 0, 0]) + >>> m + Matrix([ + [ 0, 0, x], + [-y, 0, 0]]) + >>> m.is_anti_symmetric() + False + + >>> from sympy.abc import x, y + >>> m = Matrix(3, 3, [0, x**2 + 2*x + 1, y, + ... -(x + 1)**2, 0, x*y, + ... -y, -x*y, 0]) + + Simplification of matrix elements is done by default so even + though two elements which should be equal and opposite would not + pass an equality test, the matrix is still reported as + anti-symmetric: + + >>> m[0, 1] == -m[1, 0] + False + >>> m.is_anti_symmetric() + True + + If ``simplify=False`` is used for the case when a Matrix is already + simplified, this will speed things up. Here, we see that without + simplification the matrix does not appear anti-symmetric: + + >>> print(m.is_anti_symmetric(simplify=False)) + None + + But if the matrix were already expanded, then it would appear + anti-symmetric and simplification in the is_anti_symmetric routine + is not needed: + + >>> m = m.expand() + >>> m.is_anti_symmetric(simplify=False) + True + """ + # accept custom simplification + simpfunc = simplify + if not isfunction(simplify): + simpfunc = _utilities_simplify if simplify else lambda x: x + + if not self.is_square: + return False + return self._eval_is_anti_symmetric(simpfunc) + + def is_diagonal(self): + """Check if matrix is diagonal, + that is matrix in which the entries outside the main diagonal are all zero. + + Examples + ======== + + >>> from sympy import Matrix, diag + >>> m = Matrix(2, 2, [1, 0, 0, 2]) + >>> m + Matrix([ + [1, 0], + [0, 2]]) + >>> m.is_diagonal() + True + + >>> m = Matrix(2, 2, [1, 1, 0, 2]) + >>> m + Matrix([ + [1, 1], + [0, 2]]) + >>> m.is_diagonal() + False + + >>> m = diag(1, 2, 3) + >>> m + Matrix([ + [1, 0, 0], + [0, 2, 0], + [0, 0, 3]]) + >>> m.is_diagonal() + True + + See Also + ======== + + is_lower + is_upper + sympy.matrices.matrixbase.MatrixBase.is_diagonalizable + diagonalize + """ + return self._eval_is_diagonal() + + @property + def is_weakly_diagonally_dominant(self): + r"""Tests if the matrix is row weakly diagonally dominant. + + Explanation + =========== + + A $n, n$ matrix $A$ is row weakly diagonally dominant if + + .. math:: + \left|A_{i, i}\right| \ge \sum_{j = 0, j \neq i}^{n-1} + \left|A_{i, j}\right| \quad {\text{for all }} + i \in \{ 0, ..., n-1 \} + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[3, -2, 1], [1, -3, 2], [-1, 2, 4]]) + >>> A.is_weakly_diagonally_dominant + True + + >>> A = Matrix([[-2, 2, 1], [1, 3, 2], [1, -2, 0]]) + >>> A.is_weakly_diagonally_dominant + False + + >>> A = Matrix([[-4, 2, 1], [1, 6, 2], [1, -2, 5]]) + >>> A.is_weakly_diagonally_dominant + True + + Notes + ===== + + If you want to test whether a matrix is column diagonally + dominant, you can apply the test after transposing the matrix. + """ + if not self.is_square: + return False + + rows, cols = self.shape + + def test_row(i): + summation = self.zero + for j in range(cols): + if i != j: + summation += Abs(self[i, j]) + return (Abs(self[i, i]) - summation).is_nonnegative + + return fuzzy_and(test_row(i) for i in range(rows)) + + @property + def is_strongly_diagonally_dominant(self): + r"""Tests if the matrix is row strongly diagonally dominant. + + Explanation + =========== + + A $n, n$ matrix $A$ is row strongly diagonally dominant if + + .. math:: + \left|A_{i, i}\right| > \sum_{j = 0, j \neq i}^{n-1} + \left|A_{i, j}\right| \quad {\text{for all }} + i \in \{ 0, ..., n-1 \} + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[3, -2, 1], [1, -3, 2], [-1, 2, 4]]) + >>> A.is_strongly_diagonally_dominant + False + + >>> A = Matrix([[-2, 2, 1], [1, 3, 2], [1, -2, 0]]) + >>> A.is_strongly_diagonally_dominant + False + + >>> A = Matrix([[-4, 2, 1], [1, 6, 2], [1, -2, 5]]) + >>> A.is_strongly_diagonally_dominant + True + + Notes + ===== + + If you want to test whether a matrix is column diagonally + dominant, you can apply the test after transposing the matrix. + """ + if not self.is_square: + return False + + rows, cols = self.shape + + def test_row(i): + summation = self.zero + for j in range(cols): + if i != j: + summation += Abs(self[i, j]) + return (Abs(self[i, i]) - summation).is_positive + + return fuzzy_and(test_row(i) for i in range(rows)) + + @property + def is_hermitian(self): + """Checks if the matrix is Hermitian. + + In a Hermitian matrix element i,j is the complex conjugate of + element j,i. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy import I + >>> from sympy.abc import x + >>> a = Matrix([[1, I], [-I, 1]]) + >>> a + Matrix([ + [ 1, I], + [-I, 1]]) + >>> a.is_hermitian + True + >>> a[0, 0] = 2*I + >>> a.is_hermitian + False + >>> a[0, 0] = x + >>> a.is_hermitian + >>> a[0, 1] = a[1, 0]*I + >>> a.is_hermitian + False + """ + if not self.is_square: + return False + + return self._eval_is_matrix_hermitian(_utilities_simplify) + + @property + def is_Identity(self) -> FuzzyBool: + if not self.is_square: + return False + return self._eval_is_Identity() + + @property + def is_lower_hessenberg(self): + r"""Checks if the matrix is in the lower-Hessenberg form. + + The lower hessenberg matrix has zero entries + above the first superdiagonal. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[1, 2, 0, 0], [5, 2, 3, 0], [3, 4, 3, 7], [5, 6, 1, 1]]) + >>> a + Matrix([ + [1, 2, 0, 0], + [5, 2, 3, 0], + [3, 4, 3, 7], + [5, 6, 1, 1]]) + >>> a.is_lower_hessenberg + True + + See Also + ======== + + is_upper_hessenberg + is_lower + """ + return self._eval_is_lower_hessenberg() + + @property + def is_lower(self): + """Check if matrix is a lower triangular matrix. True can be returned + even if the matrix is not square. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, [1, 0, 0, 1]) + >>> m + Matrix([ + [1, 0], + [0, 1]]) + >>> m.is_lower + True + + >>> m = Matrix(4, 3, [0, 0, 0, 2, 0, 0, 1, 4, 0, 6, 6, 5]) + >>> m + Matrix([ + [0, 0, 0], + [2, 0, 0], + [1, 4, 0], + [6, 6, 5]]) + >>> m.is_lower + True + + >>> from sympy.abc import x, y + >>> m = Matrix(2, 2, [x**2 + y, y**2 + x, 0, x + y]) + >>> m + Matrix([ + [x**2 + y, x + y**2], + [ 0, x + y]]) + >>> m.is_lower + False + + See Also + ======== + + is_upper + is_diagonal + is_lower_hessenberg + """ + return self._eval_is_lower() + + @property + def is_square(self): + """Checks if a matrix is square. + + A matrix is square if the number of rows equals the number of columns. + The empty matrix is square by definition, since the number of rows and + the number of columns are both zero. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> b = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> c = Matrix([]) + >>> a.is_square + False + >>> b.is_square + True + >>> c.is_square + True + """ + return self.rows == self.cols + + def is_symbolic(self): + """Checks if any elements contain Symbols. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.is_symbolic() + True + + """ + return self._eval_is_symbolic() + + def is_symmetric(self, simplify=True): + """Check if matrix is symmetric matrix, + that is square matrix and is equal to its transpose. + + By default, simplifications occur before testing symmetry. + They can be skipped using 'simplify=False'; while speeding things a bit, + this may however induce false negatives. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, [0, 1, 1, 2]) + >>> m + Matrix([ + [0, 1], + [1, 2]]) + >>> m.is_symmetric() + True + + >>> m = Matrix(2, 2, [0, 1, 2, 0]) + >>> m + Matrix([ + [0, 1], + [2, 0]]) + >>> m.is_symmetric() + False + + >>> m = Matrix(2, 3, [0, 0, 0, 0, 0, 0]) + >>> m + Matrix([ + [0, 0, 0], + [0, 0, 0]]) + >>> m.is_symmetric() + False + + >>> from sympy.abc import x, y + >>> m = Matrix(3, 3, [1, x**2 + 2*x + 1, y, (x + 1)**2, 2, 0, y, 0, 3]) + >>> m + Matrix([ + [ 1, x**2 + 2*x + 1, y], + [(x + 1)**2, 2, 0], + [ y, 0, 3]]) + >>> m.is_symmetric() + True + + If the matrix is already simplified, you may speed-up is_symmetric() + test by using 'simplify=False'. + + >>> bool(m.is_symmetric(simplify=False)) + False + >>> m1 = m.expand() + >>> m1.is_symmetric(simplify=False) + True + """ + simpfunc = simplify + if not isfunction(simplify): + simpfunc = _utilities_simplify if simplify else lambda x: x + + if not self.is_square: + return False + + return self._eval_is_symmetric(simpfunc) + + @property + def is_upper_hessenberg(self): + """Checks if the matrix is the upper-Hessenberg form. + + The upper hessenberg matrix has zero entries + below the first subdiagonal. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[1, 4, 2, 3], [3, 4, 1, 7], [0, 2, 3, 4], [0, 0, 1, 3]]) + >>> a + Matrix([ + [1, 4, 2, 3], + [3, 4, 1, 7], + [0, 2, 3, 4], + [0, 0, 1, 3]]) + >>> a.is_upper_hessenberg + True + + See Also + ======== + + is_lower_hessenberg + is_upper + """ + return self._eval_is_upper_hessenberg() + + @property + def is_upper(self): + """Check if matrix is an upper triangular matrix. True can be returned + even if the matrix is not square. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, [1, 0, 0, 1]) + >>> m + Matrix([ + [1, 0], + [0, 1]]) + >>> m.is_upper + True + + >>> m = Matrix(4, 3, [5, 1, 9, 0, 4, 6, 0, 0, 5, 0, 0, 0]) + >>> m + Matrix([ + [5, 1, 9], + [0, 4, 6], + [0, 0, 5], + [0, 0, 0]]) + >>> m.is_upper + True + + >>> m = Matrix(2, 3, [4, 2, 5, 6, 1, 1]) + >>> m + Matrix([ + [4, 2, 5], + [6, 1, 1]]) + >>> m.is_upper + False + + See Also + ======== + + is_lower + is_diagonal + is_upper_hessenberg + """ + return self._eval_is_upper() + + @property + def is_zero_matrix(self): + """Checks if a matrix is a zero matrix. + + A matrix is zero if every element is zero. A matrix need not be square + to be considered zero. The empty matrix is zero by the principle of + vacuous truth. For a matrix that may or may not be zero (e.g. + contains a symbol), this will be None + + Examples + ======== + + >>> from sympy import Matrix, zeros + >>> from sympy.abc import x + >>> a = Matrix([[0, 0], [0, 0]]) + >>> b = zeros(3, 4) + >>> c = Matrix([[0, 1], [0, 0]]) + >>> d = Matrix([]) + >>> e = Matrix([[x, 0], [0, 0]]) + >>> a.is_zero_matrix + True + >>> b.is_zero_matrix + True + >>> c.is_zero_matrix + False + >>> d.is_zero_matrix + True + >>> e.is_zero_matrix + """ + return self._eval_is_zero_matrix() + + def values(self): + """Return non-zero values of self. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix([[0, 1], [2, 3]]) + >>> m.values() + [1, 2, 3] + + See Also + ======== + + iter_values + tolist + flat + """ + return self._eval_values() + + def iter_values(self): + """ + Iterate over non-zero values of self. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix([[0, 1], [2, 3]]) + >>> list(m.iter_values()) + [1, 2, 3] + + See Also + ======== + + values + """ + return self._eval_iter_values() + + def iter_items(self): + """Iterate over indices and values of nonzero items. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix([[0, 1], [2, 3]]) + >>> list(m.iter_items()) + [((0, 1), 1), ((1, 0), 2), ((1, 1), 3)] + + See Also + ======== + + iter_values + todok + """ + return self._eval_iter_items() + + def _eval_adjoint(self): + return self.transpose().applyfunc(lambda x: x.adjoint()) + + def _eval_applyfunc(self, f): + cols = self.cols + size = self.rows*self.cols + + dok = self.todok() + valmap = {v: f(v) for v in dok.values()} + + if len(dok) < size and ((fzero := f(S.Zero)) is not S.Zero): + out_flat = [fzero]*size + for (i, j), v in dok.items(): + out_flat[i*cols + j] = valmap[v] + out = self._new(self.rows, self.cols, out_flat) + else: + fdok = {ij: valmap[v] for ij, v in dok.items()} + out = self.from_dok(self.rows, self.cols, fdok) + + return out + + def _eval_as_real_imag(self): # type: ignore + return (self.applyfunc(re), self.applyfunc(im)) + + def _eval_conjugate(self): + return self.applyfunc(lambda x: x.conjugate()) + + def _eval_permute_cols(self, perm): + # apply the permutation to a list + mapping = list(perm) + + def entry(i, j): + return self[i, mapping[j]] + + return self._new(self.rows, self.cols, entry) + + def _eval_permute_rows(self, perm): + # apply the permutation to a list + mapping = list(perm) + + def entry(i, j): + return self[mapping[i], j] + + return self._new(self.rows, self.cols, entry) + + def _eval_trace(self): + return sum(self[i, i] for i in range(self.rows)) + + def _eval_transpose(self): + return self._new(self.cols, self.rows, lambda i, j: self[j, i]) + + def adjoint(self): + """Conjugate transpose or Hermitian conjugation.""" + return self._eval_adjoint() + + def applyfunc(self, f): + """Apply a function to each element of the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix(2, 2, lambda i, j: i*2+j) + >>> m + Matrix([ + [0, 1], + [2, 3]]) + >>> m.applyfunc(lambda i: 2*i) + Matrix([ + [0, 2], + [4, 6]]) + + """ + if not callable(f): + raise TypeError("`f` must be callable.") + + return self._eval_applyfunc(f) + + def as_real_imag(self, deep=True, **hints): + """Returns a tuple containing the (real, imaginary) part of matrix.""" + # XXX: Ignoring deep and hints... + return self._eval_as_real_imag() + + def conjugate(self): + """Return the by-element conjugation. + + Examples + ======== + + >>> from sympy import SparseMatrix, I + >>> a = SparseMatrix(((1, 2 + I), (3, 4), (I, -I))) + >>> a + Matrix([ + [1, 2 + I], + [3, 4], + [I, -I]]) + >>> a.C + Matrix([ + [ 1, 2 - I], + [ 3, 4], + [-I, I]]) + + See Also + ======== + + transpose: Matrix transposition + H: Hermite conjugation + sympy.matrices.matrixbase.MatrixBase.D: Dirac conjugation + """ + return self._eval_conjugate() + + def doit(self, **hints): + return self.applyfunc(lambda x: x.doit(**hints)) + + def evalf(self, n=15, subs=None, maxn=100, chop=False, strict=False, quad=None, verbose=False): + """Apply evalf() to each element of self.""" + options = {'subs':subs, 'maxn':maxn, 'chop':chop, 'strict':strict, + 'quad':quad, 'verbose':verbose} + return self.applyfunc(lambda i: i.evalf(n, **options)) + + def expand(self, deep=True, modulus=None, power_base=True, power_exp=True, + mul=True, log=True, multinomial=True, basic=True, **hints): + """Apply core.function.expand to each entry of the matrix. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import Matrix + >>> Matrix(1, 1, [x*(x+1)]) + Matrix([[x*(x + 1)]]) + >>> _.expand() + Matrix([[x**2 + x]]) + + """ + return self.applyfunc(lambda x: x.expand( + deep, modulus, power_base, power_exp, mul, log, multinomial, basic, + **hints)) + + @property + def H(self): + """Return Hermite conjugate. + + Examples + ======== + + >>> from sympy import Matrix, I + >>> m = Matrix((0, 1 + I, 2, 3)) + >>> m + Matrix([ + [ 0], + [1 + I], + [ 2], + [ 3]]) + >>> m.H + Matrix([[0, 1 - I, 2, 3]]) + + See Also + ======== + + conjugate: By-element conjugation + sympy.matrices.matrixbase.MatrixBase.D: Dirac conjugation + """ + return self.adjoint() + + def permute(self, perm, orientation='rows', direction='forward'): + r"""Permute the rows or columns of a matrix by the given list of + swaps. + + Parameters + ========== + + perm : Permutation, list, or list of lists + A representation for the permutation. + + If it is ``Permutation``, it is used directly with some + resizing with respect to the matrix size. + + If it is specified as list of lists, + (e.g., ``[[0, 1], [0, 2]]``), then the permutation is formed + from applying the product of cycles. The direction how the + cyclic product is applied is described in below. + + If it is specified as a list, the list should represent + an array form of a permutation. (e.g., ``[1, 2, 0]``) which + would would form the swapping function + `0 \mapsto 1, 1 \mapsto 2, 2\mapsto 0`. + + orientation : 'rows', 'cols' + A flag to control whether to permute the rows or the columns + + direction : 'forward', 'backward' + A flag to control whether to apply the permutations from + the start of the list first, or from the back of the list + first. + + For example, if the permutation specification is + ``[[0, 1], [0, 2]]``, + + If the flag is set to ``'forward'``, the cycle would be + formed as `0 \mapsto 2, 2 \mapsto 1, 1 \mapsto 0`. + + If the flag is set to ``'backward'``, the cycle would be + formed as `0 \mapsto 1, 1 \mapsto 2, 2 \mapsto 0`. + + If the argument ``perm`` is not in a form of list of lists, + this flag takes no effect. + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.permute([[0, 1], [0, 2]], orientation='rows', direction='forward') + Matrix([ + [0, 0, 1], + [1, 0, 0], + [0, 1, 0]]) + + >>> from sympy import eye + >>> M = eye(3) + >>> M.permute([[0, 1], [0, 2]], orientation='rows', direction='backward') + Matrix([ + [0, 1, 0], + [0, 0, 1], + [1, 0, 0]]) + + Notes + ===== + + If a bijective function + `\sigma : \mathbb{N}_0 \rightarrow \mathbb{N}_0` denotes the + permutation. + + If the matrix `A` is the matrix to permute, represented as + a horizontal or a vertical stack of vectors: + + .. math:: + A = + \begin{bmatrix} + a_0 \\ a_1 \\ \vdots \\ a_{n-1} + \end{bmatrix} = + \begin{bmatrix} + \alpha_0 & \alpha_1 & \cdots & \alpha_{n-1} + \end{bmatrix} + + If the matrix `B` is the result, the permutation of matrix rows + is defined as: + + .. math:: + B := \begin{bmatrix} + a_{\sigma(0)} \\ a_{\sigma(1)} \\ \vdots \\ a_{\sigma(n-1)} + \end{bmatrix} + + And the permutation of matrix columns is defined as: + + .. math:: + B := \begin{bmatrix} + \alpha_{\sigma(0)} & \alpha_{\sigma(1)} & + \cdots & \alpha_{\sigma(n-1)} + \end{bmatrix} + """ + from sympy.combinatorics import Permutation + + # allow british variants and `columns` + if direction == 'forwards': + direction = 'forward' + if direction == 'backwards': + direction = 'backward' + if orientation == 'columns': + orientation = 'cols' + + if direction not in ('forward', 'backward'): + raise TypeError("direction='{}' is an invalid kwarg. " + "Try 'forward' or 'backward'".format(direction)) + if orientation not in ('rows', 'cols'): + raise TypeError("orientation='{}' is an invalid kwarg. " + "Try 'rows' or 'cols'".format(orientation)) + + if not isinstance(perm, (Permutation, Iterable)): + raise ValueError( + "{} must be a list, a list of lists, " + "or a SymPy permutation object.".format(perm)) + + # ensure all swaps are in range + max_index = self.rows if orientation == 'rows' else self.cols + if not all(0 <= t <= max_index for t in flatten(list(perm))): + raise IndexError("`swap` indices out of range.") + + if perm and not isinstance(perm, Permutation) and \ + isinstance(perm[0], Iterable): + if direction == 'forward': + perm = list(reversed(perm)) + perm = Permutation(perm, size=max_index+1) + else: + perm = Permutation(perm, size=max_index+1) + + if orientation == 'rows': + return self._eval_permute_rows(perm) + if orientation == 'cols': + return self._eval_permute_cols(perm) + + def permute_cols(self, swaps, direction='forward'): + """Alias for + ``self.permute(swaps, orientation='cols', direction=direction)`` + + See Also + ======== + + permute + """ + return self.permute(swaps, orientation='cols', direction=direction) + + def permute_rows(self, swaps, direction='forward'): + """Alias for + ``self.permute(swaps, orientation='rows', direction=direction)`` + + See Also + ======== + + permute + """ + return self.permute(swaps, orientation='rows', direction=direction) + + def refine(self, assumptions=True): + """Apply refine to each element of the matrix. + + Examples + ======== + + >>> from sympy import Symbol, Matrix, Abs, sqrt, Q + >>> x = Symbol('x') + >>> Matrix([[Abs(x)**2, sqrt(x**2)],[sqrt(x**2), Abs(x)**2]]) + Matrix([ + [ Abs(x)**2, sqrt(x**2)], + [sqrt(x**2), Abs(x)**2]]) + >>> _.refine(Q.real(x)) + Matrix([ + [ x**2, Abs(x)], + [Abs(x), x**2]]) + + """ + return self.applyfunc(lambda x: refine(x, assumptions)) + + def replace(self, F, G, map=False, simultaneous=True, exact=None): + """Replaces Function F in Matrix entries with Function G. + + Examples + ======== + + >>> from sympy import symbols, Function, Matrix + >>> F, G = symbols('F, G', cls=Function) + >>> M = Matrix(2, 2, lambda i, j: F(i+j)) ; M + Matrix([ + [F(0), F(1)], + [F(1), F(2)]]) + >>> N = M.replace(F,G) + >>> N + Matrix([ + [G(0), G(1)], + [G(1), G(2)]]) + """ + kwargs = {'map': map, 'simultaneous': simultaneous, 'exact': exact} + + if map: + + d = {} + def func(eij): + eij, dij = eij.replace(F, G, **kwargs) + d.update(dij) + return eij + + M = self.applyfunc(func) + return M, d + + else: + return self.applyfunc(lambda i: i.replace(F, G, **kwargs)) + + def rot90(self, k=1): + """Rotates Matrix by 90 degrees + + Parameters + ========== + + k : int + Specifies how many times the matrix is rotated by 90 degrees + (clockwise when positive, counter-clockwise when negative). + + Examples + ======== + + >>> from sympy import Matrix, symbols + >>> A = Matrix(2, 2, symbols('a:d')) + >>> A + Matrix([ + [a, b], + [c, d]]) + + Rotating the matrix clockwise one time: + + >>> A.rot90(1) + Matrix([ + [c, a], + [d, b]]) + + Rotating the matrix anticlockwise two times: + + >>> A.rot90(-2) + Matrix([ + [d, c], + [b, a]]) + """ + + mod = k%4 + if mod == 0: + return self + if mod == 1: + return self[::-1, ::].T + if mod == 2: + return self[::-1, ::-1] + if mod == 3: + return self[::, ::-1].T + + def simplify(self, **kwargs): + """Apply simplify to each element of the matrix. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import SparseMatrix, sin, cos + >>> SparseMatrix(1, 1, [x*sin(y)**2 + x*cos(y)**2]) + Matrix([[x*sin(y)**2 + x*cos(y)**2]]) + >>> _.simplify() + Matrix([[x]]) + """ + return self.applyfunc(lambda x: x.simplify(**kwargs)) + + def subs(self, *args, **kwargs): # should mirror core.basic.subs + """Return a new matrix with subs applied to each entry. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import SparseMatrix, Matrix + >>> SparseMatrix(1, 1, [x]) + Matrix([[x]]) + >>> _.subs(x, y) + Matrix([[y]]) + >>> Matrix(_).subs(y, x) + Matrix([[x]]) + """ + + if len(args) == 1 and not isinstance(args[0], (dict, set)) and iter(args[0]) and not is_sequence(args[0]): + args = (list(args[0]),) + + return self.applyfunc(lambda x: x.subs(*args, **kwargs)) + + def trace(self): + """ + Returns the trace of a square matrix i.e. the sum of the + diagonal elements. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(2, 2, [1, 2, 3, 4]) + >>> A.trace() + 5 + + """ + if self.rows != self.cols: + raise NonSquareMatrixError() + return self._eval_trace() + + def transpose(self): + """ + Returns the transpose of the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(2, 2, [1, 2, 3, 4]) + >>> A.transpose() + Matrix([ + [1, 3], + [2, 4]]) + + >>> from sympy import Matrix, I + >>> m=Matrix(((1, 2+I), (3, 4))) + >>> m + Matrix([ + [1, 2 + I], + [3, 4]]) + >>> m.transpose() + Matrix([ + [ 1, 3], + [2 + I, 4]]) + >>> m.T == m.transpose() + True + + See Also + ======== + + conjugate: By-element conjugation + + """ + return self._eval_transpose() + + @property + def T(self): + '''Matrix transposition''' + return self.transpose() + + @property + def C(self): + '''By-element conjugation''' + return self.conjugate() + + def n(self, *args, **kwargs): + """Apply evalf() to each element of self.""" + return self.evalf(*args, **kwargs) + + def xreplace(self, rule): # should mirror core.basic.xreplace + """Return a new matrix with xreplace applied to each entry. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import SparseMatrix, Matrix + >>> SparseMatrix(1, 1, [x]) + Matrix([[x]]) + >>> _.xreplace({x: y}) + Matrix([[y]]) + >>> Matrix(_).xreplace({y: x}) + Matrix([[x]]) + """ + return self.applyfunc(lambda x: x.xreplace(rule)) + + def _eval_simplify(self, **kwargs): + # XXX: We can't use self.simplify here as mutable subclasses will + # override simplify and have it return None + return self.applyfunc(lambda x: x.simplify(**kwargs)) + + def _eval_trigsimp(self, **opts): + from sympy.simplify.trigsimp import trigsimp + return self.applyfunc(lambda x: trigsimp(x, **opts)) + + def upper_triangular(self, k=0): + """Return the elements on and above the kth diagonal of a matrix. + If k is not specified then simply returns upper-triangular portion + of a matrix + + Examples + ======== + + >>> from sympy import ones + >>> A = ones(4) + >>> A.upper_triangular() + Matrix([ + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 1]]) + + >>> A.upper_triangular(2) + Matrix([ + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + + >>> A.upper_triangular(-1) + Matrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1]]) + + """ + + def entry(i, j): + return self[i, j] if i + k <= j else self.zero + + return self._new(self.rows, self.cols, entry) + + def lower_triangular(self, k=0): + """Return the elements on and below the kth diagonal of a matrix. + If k is not specified then simply returns lower-triangular portion + of a matrix + + Examples + ======== + + >>> from sympy import ones + >>> A = ones(4) + >>> A.lower_triangular() + Matrix([ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]]) + + >>> A.lower_triangular(-2) + Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [1, 1, 0, 0]]) + + >>> A.lower_triangular(1) + Matrix([ + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1], + [1, 1, 1, 1]]) + + """ + + def entry(i, j): + return self[i, j] if i + k >= j else self.zero + + return self._new(self.rows, self.cols, entry) + + def _eval_Abs(self): + return self._new(self.rows, self.cols, lambda i, j: Abs(self[i, j])) + + def _eval_add(self, other): + return self._new(self.rows, self.cols, + lambda i, j: self[i, j] + other[i, j]) + + def _eval_matrix_mul(self, other): + def entry(i, j): + vec = [self[i,k]*other[k,j] for k in range(self.cols)] + try: + return Add(*vec) + except (TypeError, SympifyError): + # Some matrices don't work with `sum` or `Add` + # They don't work with `sum` because `sum` tries to add `0` + # Fall back to a safe way to multiply if the `Add` fails. + return reduce(lambda a, b: a + b, vec) + + return self._new(self.rows, other.cols, entry) + + def _eval_matrix_mul_elementwise(self, other): + return self._new(self.rows, self.cols, lambda i, j: self[i,j]*other[i,j]) + + def _eval_matrix_rmul(self, other): + def entry(i, j): + return sum(other[i,k]*self[k,j] for k in range(other.cols)) + return self._new(other.rows, self.cols, entry) + + def _eval_pow_by_recursion(self, num): + if num == 1: + return self + + if num % 2 == 1: + a, b = self, self._eval_pow_by_recursion(num - 1) + else: + a = b = self._eval_pow_by_recursion(num // 2) + + return a.multiply(b) + + def _eval_pow_by_cayley(self, exp): + from sympy.discrete.recurrences import linrec_coeffs + row = self.shape[0] + p = self.charpoly() + + coeffs = (-p).all_coeffs()[1:] + coeffs = linrec_coeffs(coeffs, exp) + new_mat = self.eye(row) + ans = self.zeros(row) + + for i in range(row): + ans += coeffs[i]*new_mat + new_mat *= self + + return ans + + def _eval_pow_by_recursion_dotprodsimp(self, num, prevsimp=None): + if prevsimp is None: + prevsimp = [True]*len(self) + + if num == 1: + return self + + if num % 2 == 1: + a, b = self, self._eval_pow_by_recursion_dotprodsimp(num - 1, + prevsimp=prevsimp) + else: + a = b = self._eval_pow_by_recursion_dotprodsimp(num // 2, + prevsimp=prevsimp) + + m = a.multiply(b, dotprodsimp=False) + lenm = len(m) + elems = [None]*lenm + + for i in range(lenm): + if prevsimp[i]: + elems[i], prevsimp[i] = _dotprodsimp(m[i], withsimp=True) + else: + elems[i] = m[i] + + return m._new(m.rows, m.cols, elems) + + def _eval_scalar_mul(self, other): + return self._new(self.rows, self.cols, lambda i, j: self[i,j]*other) + + def _eval_scalar_rmul(self, other): + return self._new(self.rows, self.cols, lambda i, j: other*self[i,j]) + + def _eval_Mod(self, other): + return self._new(self.rows, self.cols, lambda i, j: Mod(self[i, j], other)) + + # Python arithmetic functions + def __abs__(self): + """Returns a new matrix with entry-wise absolute values.""" + return self._eval_Abs() + + @call_highest_priority('__radd__') + def __add__(self, other): + """Return self + other, raising ShapeError if shapes do not match.""" + + other, T = _coerce_operand(self, other) + + if T != "is_matrix": + return NotImplemented + + if self.shape != other.shape: + raise ShapeError(f"Matrix size mismatch: {self.shape} + {other.shape}.") + + # Unify matrix types + a, b = self, other + if a.__class__ != classof(a, b): + b, a = a, b + + return a._eval_add(b) + + @call_highest_priority('__rtruediv__') + def __truediv__(self, other): + return self * (self.one / other) + + @call_highest_priority('__rmatmul__') + def __matmul__(self, other): + self, other, T = _unify_with_other(self, other) + + if T != "is_matrix": + return NotImplemented + + return self.__mul__(other) + + def __mod__(self, other): + return self.applyfunc(lambda x: x % other) + + @call_highest_priority('__rmul__') + def __mul__(self, other): + """Return self*other where other is either a scalar or a matrix + of compatible dimensions. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> 2*A == A*2 == Matrix([[2, 4, 6], [8, 10, 12]]) + True + >>> B = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> A*B + Matrix([ + [30, 36, 42], + [66, 81, 96]]) + >>> B*A + Traceback (most recent call last): + ... + ShapeError: Matrices size mismatch. + >>> + + See Also + ======== + + matrix_multiply_elementwise + """ + + return self.multiply(other) + + def multiply(self, other, dotprodsimp=None): + """Same as __mul__() but with optional simplification. + + Parameters + ========== + + dotprodsimp : bool, optional + Specifies whether intermediate term algebraic simplification is used + during matrix multiplications to control expression blowup and thus + speed up calculation. Default is off. + """ + + isimpbool = _get_intermediate_simp_bool(False, dotprodsimp) + + self, other, T = _unify_with_other(self, other) + + if T == "possible_scalar": + try: + return self._eval_scalar_mul(other) + except TypeError: + return NotImplemented + + elif T == "is_matrix": + + if self.shape[1] != other.shape[0]: + raise ShapeError(f"Matrix size mismatch: {self.shape} * {other.shape}.") + + m = self._eval_matrix_mul(other) + + if isimpbool: + m = m._new(m.rows, m.cols, [_dotprodsimp(e) for e in m]) + + return m + + else: + return NotImplemented + + def multiply_elementwise(self, other): + """Return the Hadamard product (elementwise product) of A and B + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[0, 1, 2], [3, 4, 5]]) + >>> B = Matrix([[1, 10, 100], [100, 10, 1]]) + >>> A.multiply_elementwise(B) + Matrix([ + [ 0, 10, 200], + [300, 40, 5]]) + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.cross + sympy.matrices.matrixbase.MatrixBase.dot + multiply + """ + if self.shape != other.shape: + raise ShapeError("Matrix shapes must agree {} != {}".format(self.shape, other.shape)) + + return self._eval_matrix_mul_elementwise(other) + + def __neg__(self): + return self._eval_scalar_mul(-1) + + @call_highest_priority('__rpow__') + def __pow__(self, exp): + """Return self**exp a scalar or symbol.""" + + return self.pow(exp) + + + def pow(self, exp, method=None): + r"""Return self**exp a scalar or symbol. + + Parameters + ========== + + method : multiply, mulsimp, jordan, cayley + If multiply then it returns exponentiation using recursion. + If jordan then Jordan form exponentiation will be used. + If cayley then the exponentiation is done using Cayley-Hamilton + theorem. + If mulsimp then the exponentiation is done using recursion + with dotprodsimp. This specifies whether intermediate term + algebraic simplification is used during naive matrix power to + control expression blowup and thus speed up calculation. + If None, then it heuristically decides which method to use. + + """ + + if method is not None and method not in ['multiply', 'mulsimp', 'jordan', 'cayley']: + raise TypeError('No such method') + if self.rows != self.cols: + raise NonSquareMatrixError() + a = self + jordan_pow = getattr(a, '_matrix_pow_by_jordan_blocks', None) + exp = sympify(exp) + + if exp.is_zero: + return a._new(a.rows, a.cols, lambda i, j: int(i == j)) + if exp == 1: + return a + + diagonal = getattr(a, 'is_diagonal', None) + if diagonal is not None and diagonal(): + return a._new(a.rows, a.cols, lambda i, j: a[i,j]**exp if i == j else 0) + + if exp.is_Number and exp % 1 == 0: + if a.rows == 1: + return a._new([[a[0]**exp]]) + if exp < 0: + exp = -exp + a = a.inv() + # When certain conditions are met, + # Jordan block algorithm is faster than + # computation by recursion. + if method == 'jordan': + try: + return jordan_pow(exp) + except MatrixError: + if method == 'jordan': + raise + + elif method == 'cayley': + if not exp.is_Number or exp % 1 != 0: + raise ValueError("cayley method is only valid for integer powers") + return a._eval_pow_by_cayley(exp) + + elif method == "mulsimp": + if not exp.is_Number or exp % 1 != 0: + raise ValueError("mulsimp method is only valid for integer powers") + return a._eval_pow_by_recursion_dotprodsimp(exp) + + elif method == "multiply": + if not exp.is_Number or exp % 1 != 0: + raise ValueError("multiply method is only valid for integer powers") + return a._eval_pow_by_recursion(exp) + + elif method is None and exp.is_Number and exp % 1 == 0: + if exp.is_Float: + exp = Integer(exp) + # Decide heuristically which method to apply + if a.rows == 2 and exp > 100000: + return jordan_pow(exp) + elif _get_intermediate_simp_bool(True, None): + return a._eval_pow_by_recursion_dotprodsimp(exp) + elif exp > 10000: + return a._eval_pow_by_cayley(exp) + else: + return a._eval_pow_by_recursion(exp) + + if jordan_pow: + try: + return jordan_pow(exp) + except NonInvertibleMatrixError: + # Raised by jordan_pow on zero determinant matrix unless exp is + # definitely known to be a non-negative integer. + # Here we raise if n is definitely not a non-negative integer + # but otherwise we can leave this as an unevaluated MatPow. + if exp.is_integer is False or exp.is_nonnegative is False: + raise + + from sympy.matrices.expressions import MatPow + return MatPow(a, exp) + + @call_highest_priority('__add__') + def __radd__(self, other): + return self.__add__(other) + + @call_highest_priority('__matmul__') + def __rmatmul__(self, other): + self, other, T = _unify_with_other(self, other) + + if T != "is_matrix": + return NotImplemented + + return self.__rmul__(other) + + @call_highest_priority('__mul__') + def __rmul__(self, other): + return self.rmultiply(other) + + def rmultiply(self, other, dotprodsimp=None): + """Same as __rmul__() but with optional simplification. + + Parameters + ========== + + dotprodsimp : bool, optional + Specifies whether intermediate term algebraic simplification is used + during matrix multiplications to control expression blowup and thus + speed up calculation. Default is off. + """ + isimpbool = _get_intermediate_simp_bool(False, dotprodsimp) + self, other, T = _unify_with_other(self, other) + + if T == "possible_scalar": + try: + return self._eval_scalar_rmul(other) + except TypeError: + return NotImplemented + + elif T == "is_matrix": + if self.shape[0] != other.shape[1]: + raise ShapeError("Matrix size mismatch.") + + m = self._eval_matrix_rmul(other) + + if isimpbool: + return m._new(m.rows, m.cols, [_dotprodsimp(e) for e in m]) + + return m + + else: + return NotImplemented + + @call_highest_priority('__sub__') + def __rsub__(self, a): + return (-self) + a + + @call_highest_priority('__rsub__') + def __sub__(self, a): + return self + (-a) + + def _eval_det_bareiss(self, iszerofunc=_is_zero_after_expand_mul): + return _det_bareiss(self, iszerofunc=iszerofunc) + + def _eval_det_berkowitz(self): + return _det_berkowitz(self) + + def _eval_det_lu(self, iszerofunc=_iszero, simpfunc=None): + return _det_LU(self, iszerofunc=iszerofunc, simpfunc=simpfunc) + + def _eval_det_bird(self): + return _det_bird(self) + + def _eval_det_laplace(self): + return _det_laplace(self) + + def _eval_determinant(self): # for expressions.determinant.Determinant + return _det(self) + + def adjugate(self, method="berkowitz"): + return _adjugate(self, method=method) + + def charpoly(self, x='lambda', simplify=_utilities_simplify): + return _charpoly(self, x=x, simplify=simplify) + + def cofactor(self, i, j, method="berkowitz"): + return _cofactor(self, i, j, method=method) + + def cofactor_matrix(self, method="berkowitz"): + return _cofactor_matrix(self, method=method) + + def det(self, method="bareiss", iszerofunc=None): + return _det(self, method=method, iszerofunc=iszerofunc) + + def per(self): + return _per(self) + + def minor(self, i, j, method="berkowitz"): + return _minor(self, i, j, method=method) + + def minor_submatrix(self, i, j): + return _minor_submatrix(self, i, j) + + _find_reasonable_pivot.__doc__ = _find_reasonable_pivot.__doc__ + _find_reasonable_pivot_naive.__doc__ = _find_reasonable_pivot_naive.__doc__ + _eval_det_bareiss.__doc__ = _det_bareiss.__doc__ + _eval_det_berkowitz.__doc__ = _det_berkowitz.__doc__ + _eval_det_bird.__doc__ = _det_bird.__doc__ + _eval_det_laplace.__doc__ = _det_laplace.__doc__ + _eval_det_lu.__doc__ = _det_LU.__doc__ + _eval_determinant.__doc__ = _det.__doc__ + adjugate.__doc__ = _adjugate.__doc__ + charpoly.__doc__ = _charpoly.__doc__ + cofactor.__doc__ = _cofactor.__doc__ + cofactor_matrix.__doc__ = _cofactor_matrix.__doc__ + det.__doc__ = _det.__doc__ + per.__doc__ = _per.__doc__ + minor.__doc__ = _minor.__doc__ + minor_submatrix.__doc__ = _minor_submatrix.__doc__ + + def echelon_form(self, iszerofunc=_iszero, simplify=False, with_pivots=False): + return _echelon_form(self, iszerofunc=iszerofunc, simplify=simplify, + with_pivots=with_pivots) + + @property + def is_echelon(self): + return _is_echelon(self) + + def rank(self, iszerofunc=_iszero, simplify=False): + return _rank(self, iszerofunc=iszerofunc, simplify=simplify) + + def rref_rhs(self, rhs): + """Return reduced row-echelon form of matrix, matrix showing + rhs after reduction steps. ``rhs`` must have the same number + of rows as ``self``. + + Examples + ======== + + >>> from sympy import Matrix, symbols + >>> r1, r2 = symbols('r1 r2') + >>> Matrix([[1, 1], [2, 1]]).rref_rhs(Matrix([r1, r2])) + (Matrix([ + [1, 0], + [0, 1]]), Matrix([ + [ -r1 + r2], + [2*r1 - r2]])) + """ + r, _ = _rref(self.hstack(self, self.eye(self.rows), rhs)) + return r[:, :self.cols], r[:, -rhs.cols:] + + def rref(self, iszerofunc=_iszero, simplify=False, pivots=True, + normalize_last=True): + return _rref(self, iszerofunc=iszerofunc, simplify=simplify, + pivots=pivots, normalize_last=normalize_last) + + echelon_form.__doc__ = _echelon_form.__doc__ + is_echelon.__doc__ = _is_echelon.__doc__ + rank.__doc__ = _rank.__doc__ + rref.__doc__ = _rref.__doc__ + + def _normalize_op_args(self, op, col, k, col1, col2, error_str="col"): + """Validate the arguments for a row/column operation. ``error_str`` + can be one of "row" or "col" depending on the arguments being parsed.""" + if op not in ["n->kn", "n<->m", "n->n+km"]: + raise ValueError("Unknown {} operation '{}'. Valid col operations " + "are 'n->kn', 'n<->m', 'n->n+km'".format(error_str, op)) + + # define self_col according to error_str + self_cols = self.cols if error_str == 'col' else self.rows + + # normalize and validate the arguments + if op == "n->kn": + col = col if col is not None else col1 + if col is None or k is None: + raise ValueError("For a {0} operation 'n->kn' you must provide the " + "kwargs `{0}` and `k`".format(error_str)) + if not 0 <= col < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col)) + + elif op == "n<->m": + # we need two cols to swap. It does not matter + # how they were specified, so gather them together and + # remove `None` + cols = {col, k, col1, col2}.difference([None]) + if len(cols) > 2: + # maybe the user left `k` by mistake? + cols = {col, col1, col2}.difference([None]) + if len(cols) != 2: + raise ValueError("For a {0} operation 'n<->m' you must provide the " + "kwargs `{0}1` and `{0}2`".format(error_str)) + col1, col2 = cols + if not 0 <= col1 < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col1)) + if not 0 <= col2 < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col2)) + + elif op == "n->n+km": + col = col1 if col is None else col + col2 = col1 if col2 is None else col2 + if col is None or col2 is None or k is None: + raise ValueError("For a {0} operation 'n->n+km' you must provide the " + "kwargs `{0}`, `k`, and `{0}2`".format(error_str)) + if col == col2: + raise ValueError("For a {0} operation 'n->n+km' `{0}` and `{0}2` must " + "be different.".format(error_str)) + if not 0 <= col < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col)) + if not 0 <= col2 < self_cols: + raise ValueError("This matrix does not have a {} '{}'".format(error_str, col2)) + + else: + raise ValueError('invalid operation %s' % repr(op)) + + return op, col, k, col1, col2 + + def _eval_col_op_multiply_col_by_const(self, col, k): + def entry(i, j): + if j == col: + return k * self[i, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_col_op_swap(self, col1, col2): + def entry(i, j): + if j == col1: + return self[i, col2] + elif j == col2: + return self[i, col1] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_col_op_add_multiple_to_other_col(self, col, k, col2): + def entry(i, j): + if j == col: + return self[i, j] + k * self[i, col2] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_row_op_swap(self, row1, row2): + def entry(i, j): + if i == row1: + return self[row2, j] + elif i == row2: + return self[row1, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_row_op_multiply_row_by_const(self, row, k): + def entry(i, j): + if i == row: + return k * self[i, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def _eval_row_op_add_multiple_to_other_row(self, row, k, row2): + def entry(i, j): + if i == row: + return self[i, j] + k * self[row2, j] + return self[i, j] + return self._new(self.rows, self.cols, entry) + + def elementary_col_op(self, op="n->kn", col=None, k=None, col1=None, col2=None): + """Performs the elementary column operation `op`. + + `op` may be one of + + * ``"n->kn"`` (column n goes to k*n) + * ``"n<->m"`` (swap column n and column m) + * ``"n->n+km"`` (column n goes to column n + k*column m) + + Parameters + ========== + + op : string; the elementary row operation + col : the column to apply the column operation + k : the multiple to apply in the column operation + col1 : one column of a column swap + col2 : second column of a column swap or column "m" in the column operation + "n->n+km" + """ + + op, col, k, col1, col2 = self._normalize_op_args(op, col, k, col1, col2, "col") + + # now that we've validated, we're all good to dispatch + if op == "n->kn": + return self._eval_col_op_multiply_col_by_const(col, k) + if op == "n<->m": + return self._eval_col_op_swap(col1, col2) + if op == "n->n+km": + return self._eval_col_op_add_multiple_to_other_col(col, k, col2) + + def elementary_row_op(self, op="n->kn", row=None, k=None, row1=None, row2=None): + """Performs the elementary row operation `op`. + + `op` may be one of + + * ``"n->kn"`` (row n goes to k*n) + * ``"n<->m"`` (swap row n and row m) + * ``"n->n+km"`` (row n goes to row n + k*row m) + + Parameters + ========== + + op : string; the elementary row operation + row : the row to apply the row operation + k : the multiple to apply in the row operation + row1 : one row of a row swap + row2 : second row of a row swap or row "m" in the row operation + "n->n+km" + """ + + op, row, k, row1, row2 = self._normalize_op_args(op, row, k, row1, row2, "row") + + # now that we've validated, we're all good to dispatch + if op == "n->kn": + return self._eval_row_op_multiply_row_by_const(row, k) + if op == "n<->m": + return self._eval_row_op_swap(row1, row2) + if op == "n->n+km": + return self._eval_row_op_add_multiple_to_other_row(row, k, row2) + + def columnspace(self, simplify=False): + return _columnspace(self, simplify=simplify) + + def nullspace(self, simplify=False, iszerofunc=_iszero): + return _nullspace(self, simplify=simplify, iszerofunc=iszerofunc) + + def rowspace(self, simplify=False): + return _rowspace(self, simplify=simplify) + + # This is a classmethod but is converted to such later in order to allow + # assignment of __doc__ since that does not work for already wrapped + # classmethods in Python 3.6. + def orthogonalize(cls, *vecs, **kwargs): + return _orthogonalize(cls, *vecs, **kwargs) + + columnspace.__doc__ = _columnspace.__doc__ + nullspace.__doc__ = _nullspace.__doc__ + rowspace.__doc__ = _rowspace.__doc__ + orthogonalize.__doc__ = _orthogonalize.__doc__ + + orthogonalize = classmethod(orthogonalize) # type:ignore + + def eigenvals(self, error_when_incomplete=True, **flags): + return _eigenvals(self, error_when_incomplete=error_when_incomplete, **flags) + + def eigenvects(self, error_when_incomplete=True, iszerofunc=_iszero, **flags): + return _eigenvects(self, error_when_incomplete=error_when_incomplete, + iszerofunc=iszerofunc, **flags) + + def is_diagonalizable(self, reals_only=False, **kwargs): + return _is_diagonalizable(self, reals_only=reals_only, **kwargs) + + def diagonalize(self, reals_only=False, sort=False, normalize=False): + return _diagonalize(self, reals_only=reals_only, sort=sort, + normalize=normalize) + + def bidiagonalize(self, upper=True): + return _bidiagonalize(self, upper=upper) + + def bidiagonal_decomposition(self, upper=True): + return _bidiagonal_decomposition(self, upper=upper) + + @property + def is_positive_definite(self): + return _is_positive_definite(self) + + @property + def is_positive_semidefinite(self): + return _is_positive_semidefinite(self) + + @property + def is_negative_definite(self): + return _is_negative_definite(self) + + @property + def is_negative_semidefinite(self): + return _is_negative_semidefinite(self) + + @property + def is_indefinite(self): + return _is_indefinite(self) + + def jordan_form(self, calc_transform=True, **kwargs): + return _jordan_form(self, calc_transform=calc_transform, **kwargs) + + def left_eigenvects(self, **flags): + return _left_eigenvects(self, **flags) + + def singular_values(self): + return _singular_values(self) + + eigenvals.__doc__ = _eigenvals.__doc__ + eigenvects.__doc__ = _eigenvects.__doc__ + is_diagonalizable.__doc__ = _is_diagonalizable.__doc__ + diagonalize.__doc__ = _diagonalize.__doc__ + is_positive_definite.__doc__ = _is_positive_definite.__doc__ + is_positive_semidefinite.__doc__ = _is_positive_semidefinite.__doc__ + is_negative_definite.__doc__ = _is_negative_definite.__doc__ + is_negative_semidefinite.__doc__ = _is_negative_semidefinite.__doc__ + is_indefinite.__doc__ = _is_indefinite.__doc__ + jordan_form.__doc__ = _jordan_form.__doc__ + left_eigenvects.__doc__ = _left_eigenvects.__doc__ + singular_values.__doc__ = _singular_values.__doc__ + bidiagonalize.__doc__ = _bidiagonalize.__doc__ + bidiagonal_decomposition.__doc__ = _bidiagonal_decomposition.__doc__ + + def diff(self, *args, evaluate=True, **kwargs): + """Calculate the derivative of each element in the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.diff(x) + Matrix([ + [1, 0], + [0, 0]]) + + See Also + ======== + + integrate + limit + """ + # XXX this should be handled here rather than in Derivative + from sympy.tensor.array.array_derivatives import ArrayDerivative + deriv = ArrayDerivative(self, *args, evaluate=evaluate) + # XXX This can rather changed to always return immutable matrix + if not isinstance(self, Basic) and evaluate: + return deriv.as_mutable() + return deriv + + def _eval_derivative(self, arg): + return self.applyfunc(lambda x: x.diff(arg)) + + def integrate(self, *args, **kwargs): + """Integrate each element of the matrix. ``args`` will + be passed to the ``integrate`` function. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.integrate((x, )) + Matrix([ + [x**2/2, x*y], + [ x, 0]]) + >>> M.integrate((x, 0, 2)) + Matrix([ + [2, 2*y], + [2, 0]]) + + See Also + ======== + + limit + diff + """ + return self.applyfunc(lambda x: x.integrate(*args, **kwargs)) + + def jacobian(self, X): + """Calculates the Jacobian matrix (derivative of a vector-valued function). + + Parameters + ========== + + ``self`` : vector of expressions representing functions f_i(x_1, ..., x_n). + X : set of x_i's in order, it can be a list or a Matrix + + Both ``self`` and X can be a row or a column matrix in any order + (i.e., jacobian() should always work). + + Examples + ======== + + >>> from sympy import sin, cos, Matrix + >>> from sympy.abc import rho, phi + >>> X = Matrix([rho*cos(phi), rho*sin(phi), rho**2]) + >>> Y = Matrix([rho, phi]) + >>> X.jacobian(Y) + Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)], + [ 2*rho, 0]]) + >>> X = Matrix([rho*cos(phi), rho*sin(phi)]) + >>> X.jacobian(Y) + Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)]]) + + See Also + ======== + + hessian + wronskian + """ + from sympy.matrices.matrixbase import MatrixBase + if not isinstance(X, MatrixBase): + X = self._new(X) + # Both X and ``self`` can be a row or a column matrix, so we need to make + # sure all valid combinations work, but everything else fails: + if self.shape[0] == 1: + m = self.shape[1] + elif self.shape[1] == 1: + m = self.shape[0] + else: + raise TypeError("``self`` must be a row or a column matrix") + if X.shape[0] == 1: + n = X.shape[1] + elif X.shape[1] == 1: + n = X.shape[0] + else: + raise TypeError("X must be a row or a column matrix") + + # m is the number of functions and n is the number of variables + # computing the Jacobian is now easy: + return self._new(m, n, lambda j, i: self[j].diff(X[i])) + + def limit(self, *args): + """Calculate the limit of each element in the matrix. + ``args`` will be passed to the ``limit`` function. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x, y + >>> M = Matrix([[x, y], [1, 0]]) + >>> M.limit(x, 2) + Matrix([ + [2, y], + [1, 0]]) + + See Also + ======== + + integrate + diff + """ + return self.applyfunc(lambda x: x.limit(*args)) + + def berkowitz_charpoly(self, x=Dummy('lambda'), simplify=_utilities_simplify): + return self.charpoly(x=x) + + def berkowitz_det(self): + """Computes determinant using Berkowitz method. + + See Also + ======== + + det + """ + return self.det(method='berkowitz') + + def berkowitz_eigenvals(self, **flags): + """Computes eigenvalues of a Matrix using Berkowitz method.""" + return self.eigenvals(**flags) + + def berkowitz_minors(self): + """Computes principal minors using Berkowitz method.""" + sign, minors = self.one, [] + + for poly in self.berkowitz(): + minors.append(sign * poly[-1]) + sign = -sign + + return tuple(minors) + + def berkowitz(self): + from sympy.matrices import zeros + berk = ((1,),) + if not self: + return berk + + if not self.is_square: + raise NonSquareMatrixError() + + A, N = self, self.rows + transforms = [0] * (N - 1) + + for n in range(N, 1, -1): + T, k = zeros(n + 1, n), n - 1 + + R, C = -A[k, :k], A[:k, k] + A, a = A[:k, :k], -A[k, k] + + items = [C] + + for i in range(0, n - 2): + items.append(A * items[i]) + + for i, B in enumerate(items): + items[i] = (R * B)[0, 0] + + items = [self.one, a] + items + + for i in range(n): + T[i:, i] = items[:n - i + 1] + + transforms[k - 1] = T + + polys = [self._new([self.one, -A[0, 0]])] + + for i, T in enumerate(transforms): + polys.append(T * polys[i]) + + return berk + tuple(map(tuple, polys)) + + def cofactorMatrix(self, method="berkowitz"): + return self.cofactor_matrix(method=method) + + def det_bareis(self): + return _det_bareiss(self) + + def det_LU_decomposition(self): + """Compute matrix determinant using LU decomposition. + + + Note that this method fails if the LU decomposition itself + fails. In particular, if the matrix has no inverse this method + will fail. + + TODO: Implement algorithm for sparse matrices (SFF), + http://www.eecis.udel.edu/~saunders/papers/sffge/it5.ps. + + See Also + ======== + + + det + berkowitz_det + """ + return self.det(method='lu') + + def jordan_cell(self, eigenval, n): + return self.jordan_block(size=n, eigenvalue=eigenval) + + def jordan_cells(self, calc_transformation=True): + P, J = self.jordan_form() + return P, J.get_diag_blocks() + + def minorEntry(self, i, j, method="berkowitz"): + return self.minor(i, j, method=method) + + def minorMatrix(self, i, j): + return self.minor_submatrix(i, j) + + def permuteBkwd(self, perm): + """Permute the rows of the matrix with the given permutation in reverse.""" + return self.permute_rows(perm, direction='backward') + + def permuteFwd(self, perm): + """Permute the rows of the matrix with the given permutation.""" + return self.permute_rows(perm, direction='forward') + + @property + def kind(self) -> MatrixKind: + elem_kinds = {e.kind for e in self.flat()} + if len(elem_kinds) == 1: + elemkind, = elem_kinds + else: + elemkind = UndefinedKind + return MatrixKind(elemkind) + + def flat(self): + """ + Returns a flat list of all elements in the matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> m = Matrix([[0, 2], [3, 4]]) + >>> m.flat() + [0, 2, 3, 4] + + See Also + ======== + + tolist + values + """ + return [self[i, j] for i in range(self.rows) for j in range(self.cols)] + + def __array__(self, dtype=object, copy=None): + if copy is not None and not copy: + raise TypeError("Cannot implement copy=False when converting Matrix to ndarray") + from .dense import matrix2numpy + return matrix2numpy(self, dtype=dtype) + + def __len__(self): + """Return the number of elements of ``self``. + + Implemented mainly so bool(Matrix()) == False. + """ + return self.rows * self.cols + + def _matrix_pow_by_jordan_blocks(self, num): + from sympy.matrices import diag, MutableMatrix + + def jordan_cell_power(jc, n): + N = jc.shape[0] + l = jc[0,0] + if l.is_zero: + if N == 1 and n.is_nonnegative: + jc[0,0] = l**n + elif not (n.is_integer and n.is_nonnegative): + raise NonInvertibleMatrixError("Non-invertible matrix can only be raised to a nonnegative integer") + else: + for i in range(N): + jc[0,i] = KroneckerDelta(i, n) + else: + for i in range(N): + bn = binomial(n, i) + if isinstance(bn, binomial): + bn = bn._eval_expand_func() + jc[0,i] = l**(n-i)*bn + for i in range(N): + for j in range(1, N-i): + jc[j,i+j] = jc [j-1,i+j-1] + + P, J = self.jordan_form() + jordan_cells = J.get_diag_blocks() + # Make sure jordan_cells matrices are mutable: + jordan_cells = [MutableMatrix(j) for j in jordan_cells] + for j in jordan_cells: + jordan_cell_power(j, num) + return self._new(P.multiply(diag(*jordan_cells)) + .multiply(P.inv())) + + def __str__(self): + if S.Zero in self.shape: + return 'Matrix(%s, %s, [])' % (self.rows, self.cols) + return "Matrix(%s)" % str(self.tolist()) + + def _format_str(self, printer=None): + if not printer: + printer = StrPrinter() + # Handle zero dimensions: + if S.Zero in self.shape: + return 'Matrix(%s, %s, [])' % (self.rows, self.cols) + if self.rows == 1: + return "Matrix([%s])" % self.table(printer, rowsep=',\n') + return "Matrix([\n%s])" % self.table(printer, rowsep=',\n') + + @classmethod + def irregular(cls, ntop, *matrices, **kwargs): + """Return a matrix filled by the given matrices which + are listed in order of appearance from left to right, top to + bottom as they first appear in the matrix. They must fill the + matrix completely. + + Examples + ======== + + >>> from sympy import ones, Matrix + >>> Matrix.irregular(3, ones(2,1), ones(3,3)*2, ones(2,2)*3, + ... ones(1,1)*4, ones(2,2)*5, ones(1,2)*6, ones(1,2)*7) + Matrix([ + [1, 2, 2, 2, 3, 3], + [1, 2, 2, 2, 3, 3], + [4, 2, 2, 2, 5, 5], + [6, 6, 7, 7, 5, 5]]) + """ + ntop = as_int(ntop) + # make sure we are working with explicit matrices + b = [i.as_explicit() if hasattr(i, 'as_explicit') else i + for i in matrices] + q = list(range(len(b))) + dat = [i.rows for i in b] + active = [q.pop(0) for _ in range(ntop)] + cols = sum(b[i].cols for i in active) + rows = [] + while any(dat): + r = [] + for a, j in enumerate(active): + r.extend(b[j][-dat[j], :]) + dat[j] -= 1 + if dat[j] == 0 and q: + active[a] = q.pop(0) + if len(r) != cols: + raise ValueError(filldedent(''' + Matrices provided do not appear to fill + the space completely.''')) + rows.append(r) + return cls._new(rows) + + @classmethod + def _handle_ndarray(cls, arg): + # NumPy array or matrix or some other object that implements + # __array__. So let's first use this method to get a + # numpy.array() and then make a Python list out of it. + arr = arg.__array__() + if len(arr.shape) == 2: + rows, cols = arr.shape[0], arr.shape[1] + flat_list = [cls._sympify(i) for i in arr.ravel()] + return rows, cols, flat_list + elif len(arr.shape) == 1: + flat_list = [cls._sympify(i) for i in arr] + return arr.shape[0], 1, flat_list + else: + raise NotImplementedError( + "SymPy supports just 1D and 2D matrices") + + @classmethod + def _handle_creation_inputs(cls, *args, **kwargs): + """Return the number of rows, cols and flat matrix elements. + + Examples + ======== + + >>> from sympy import Matrix, I + + Matrix can be constructed as follows: + + * from a nested list of iterables + + >>> Matrix( ((1, 2+I), (3, 4)) ) + Matrix([ + [1, 2 + I], + [3, 4]]) + + * from un-nested iterable (interpreted as a column) + + >>> Matrix( [1, 2] ) + Matrix([ + [1], + [2]]) + + * from un-nested iterable with dimensions + + >>> Matrix(1, 2, [1, 2] ) + Matrix([[1, 2]]) + + * from no arguments (a 0 x 0 matrix) + + >>> Matrix() + Matrix(0, 0, []) + + * from a rule + + >>> Matrix(2, 2, lambda i, j: i/(j + 1) ) + Matrix([ + [0, 0], + [1, 1/2]]) + + See Also + ======== + irregular - filling a matrix with irregular blocks + """ + from sympy.matrices import SparseMatrix + from sympy.matrices.expressions.matexpr import MatrixSymbol + from sympy.matrices.expressions.blockmatrix import BlockMatrix + + flat_list = None + + if len(args) == 1: + # Matrix(SparseMatrix(...)) + if isinstance(args[0], SparseMatrix): + return args[0].rows, args[0].cols, flatten(args[0].tolist()) + + # Matrix(Matrix(...)) + elif isinstance(args[0], MatrixBase): + return args[0].rows, args[0].cols, args[0].flat() + + # Matrix(MatrixSymbol('X', 2, 2)) + elif isinstance(args[0], Basic) and args[0].is_Matrix: + return args[0].rows, args[0].cols, args[0].as_explicit().flat() + + elif isinstance(args[0], mp.matrix): + M = args[0] + flat_list = [cls._sympify(x) for x in M] + return M.rows, M.cols, flat_list + + # Matrix(numpy.ones((2, 2))) + elif hasattr(args[0], "__array__"): + return cls._handle_ndarray(args[0]) + + # Matrix([1, 2, 3]) or Matrix([[1, 2], [3, 4]]) + elif is_sequence(args[0]) \ + and not isinstance(args[0], DeferredVector): + dat = list(args[0]) + ismat = lambda i: isinstance(i, MatrixBase) and ( + evaluate or isinstance(i, (BlockMatrix, MatrixSymbol))) + raw = lambda i: is_sequence(i) and not ismat(i) + evaluate = kwargs.get('evaluate', True) + + + if evaluate: + + def make_explicit(x): + """make Block and Symbol explicit""" + if isinstance(x, BlockMatrix): + return x.as_explicit() + elif isinstance(x, MatrixSymbol) and all(_.is_Integer for _ in x.shape): + return x.as_explicit() + else: + return x + + def make_explicit_row(row): + # Could be list or could be list of lists + if isinstance(row, (list, tuple)): + return [make_explicit(x) for x in row] + else: + return make_explicit(row) + + if isinstance(dat, (list, tuple)): + dat = [make_explicit_row(row) for row in dat] + + if len(dat) == 0: + rows = cols = 0 + flat_list = [] + elif all(raw(i) for i in dat) and len(dat[0]) == 0: + if not all(len(i) == 0 for i in dat): + raise ValueError('mismatched dimensions') + rows = len(dat) + cols = 0 + flat_list = [] + elif not any(raw(i) or ismat(i) for i in dat): + # a column as a list of values + flat_list = [cls._sympify(i) for i in dat] + rows = len(flat_list) + cols = 1 if rows else 0 + elif evaluate and all(ismat(i) for i in dat): + # a column as a list of matrices + ncol = {i.cols for i in dat if any(i.shape)} + if ncol: + if len(ncol) != 1: + raise ValueError('mismatched dimensions') + flat_list = [_ for i in dat for r in i.tolist() for _ in r] + cols = ncol.pop() + rows = len(flat_list)//cols + else: + rows = cols = 0 + flat_list = [] + elif evaluate and any(ismat(i) for i in dat): + ncol = set() + flat_list = [] + for i in dat: + if ismat(i): + flat_list.extend( + [k for j in i.tolist() for k in j]) + if any(i.shape): + ncol.add(i.cols) + elif raw(i): + if i: + ncol.add(len(i)) + flat_list.extend([cls._sympify(ij) for ij in i]) + else: + ncol.add(1) + flat_list.append(i) + if len(ncol) > 1: + raise ValueError('mismatched dimensions') + cols = ncol.pop() + rows = len(flat_list)//cols + else: + # list of lists; each sublist is a logical row + # which might consist of many rows if the values in + # the row are matrices + flat_list = [] + ncol = set() + rows = cols = 0 + for row in dat: + if not is_sequence(row) and \ + not getattr(row, 'is_Matrix', False): + raise ValueError('expecting list of lists') + + if hasattr(row, '__array__'): + if 0 in row.shape: + continue + + if evaluate and all(ismat(i) for i in row): + r, c, flatT = cls._handle_creation_inputs( + [i.T for i in row]) + T = reshape(flatT, [c]) + flat = \ + [T[i][j] for j in range(c) for i in range(r)] + r, c = c, r + else: + r = 1 + if getattr(row, 'is_Matrix', False): + c = 1 + flat = [row] + else: + c = len(row) + flat = [cls._sympify(i) for i in row] + ncol.add(c) + if len(ncol) > 1: + raise ValueError('mismatched dimensions') + flat_list.extend(flat) + rows += r + cols = ncol.pop() if ncol else 0 + + elif len(args) == 3: + rows = as_int(args[0]) + cols = as_int(args[1]) + + if rows < 0 or cols < 0: + raise ValueError("Cannot create a {} x {} matrix. " + "Both dimensions must be positive".format(rows, cols)) + + # Matrix(2, 2, lambda i, j: i+j) + if len(args) == 3 and isinstance(args[2], Callable): + op = args[2] + flat_list = [] + for i in range(rows): + flat_list.extend( + [cls._sympify(op(cls._sympify(i), cls._sympify(j))) + for j in range(cols)]) + + # Matrix(2, 2, [1, 2, 3, 4]) + elif len(args) == 3 and is_sequence(args[2]): + flat_list = args[2] + if len(flat_list) != rows * cols: + raise ValueError( + 'List length should be equal to rows*columns') + flat_list = [cls._sympify(i) for i in flat_list] + + + # Matrix() + elif len(args) == 0: + # Empty Matrix + rows = cols = 0 + flat_list = [] + + if flat_list is None: + raise TypeError(filldedent(''' + Data type not understood; expecting list of lists + or lists of values.''')) + + return rows, cols, flat_list + + def _setitem(self, key, value): + """Helper to set value at location given by key. + + Examples + ======== + + >>> from sympy import Matrix, I, zeros, ones + >>> m = Matrix(((1, 2+I), (3, 4))) + >>> m + Matrix([ + [1, 2 + I], + [3, 4]]) + >>> m[1, 0] = 9 + >>> m + Matrix([ + [1, 2 + I], + [9, 4]]) + >>> m[1, 0] = [[0, 1]] + + To replace row r you assign to position r*m where m + is the number of columns: + + >>> M = zeros(4) + >>> m = M.cols + >>> M[3*m] = ones(1, m)*2; M + Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [2, 2, 2, 2]]) + + And to replace column c you can assign to position c: + + >>> M[2] = ones(m, 1)*4; M + Matrix([ + [0, 0, 4, 0], + [0, 0, 4, 0], + [0, 0, 4, 0], + [2, 2, 4, 2]]) + """ + from .dense import Matrix + + is_slice = isinstance(key, slice) + i, j = key = self.key2ij(key) + is_mat = isinstance(value, MatrixBase) + if isinstance(i, slice) or isinstance(j, slice): + if is_mat: + self.copyin_matrix(key, value) + return + if not isinstance(value, Expr) and is_sequence(value): + self.copyin_list(key, value) + return + raise ValueError('unexpected value: %s' % value) + else: + if (not is_mat and + not isinstance(value, Basic) and is_sequence(value)): + value = Matrix(value) + is_mat = True + if is_mat: + if is_slice: + key = (slice(*divmod(i, self.cols)), + slice(*divmod(j, self.cols))) + else: + key = (slice(i, i + value.rows), + slice(j, j + value.cols)) + self.copyin_matrix(key, value) + else: + return i, j, self._sympify(value) + return + + def add(self, b): + """Return self + b.""" + return self + b + + def condition_number(self): + """Returns the condition number of a matrix. + + This is the maximum singular value divided by the minimum singular value + + Examples + ======== + + >>> from sympy import Matrix, S + >>> A = Matrix([[1, 0, 0], [0, 10, 0], [0, 0, S.One/10]]) + >>> A.condition_number() + 100 + + See Also + ======== + + singular_values + """ + + if not self: + return self.zero + singularvalues = self.singular_values() + return Max(*singularvalues) / Min(*singularvalues) + + def copy(self): + """ + Returns the copy of a matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(2, 2, [1, 2, 3, 4]) + >>> A.copy() + Matrix([ + [1, 2], + [3, 4]]) + + """ + return self._new(self.rows, self.cols, self.flat()) + + def cross(self, b): + r""" + Return the cross product of ``self`` and ``b`` relaxing the condition + of compatible dimensions: if each has 3 elements, a matrix of the + same type and shape as ``self`` will be returned. If ``b`` has the same + shape as ``self`` then common identities for the cross product (like + `a \times b = - b \times a`) will hold. + + Parameters + ========== + b : 3x1 or 1x3 Matrix + + See Also + ======== + + dot + hat + vee + multiply + multiply_elementwise + """ + from sympy.matrices.expressions.matexpr import MatrixExpr + + if not isinstance(b, (MatrixBase, MatrixExpr)): + raise TypeError( + "{} must be a Matrix, not {}.".format(b, type(b))) + + if not (self.rows * self.cols == b.rows * b.cols == 3): + raise ShapeError("Dimensions incorrect for cross product: %s x %s" % + ((self.rows, self.cols), (b.rows, b.cols))) + else: + return self._new(self.rows, self.cols, ( + (self[1] * b[2] - self[2] * b[1]), + (self[2] * b[0] - self[0] * b[2]), + (self[0] * b[1] - self[1] * b[0]))) + + def hat(self): + r""" + Return the skew-symmetric matrix representing the cross product, + so that ``self.hat() * b`` is equivalent to ``self.cross(b)``. + + Examples + ======== + + Calling ``hat`` creates a skew-symmetric 3x3 Matrix from a 3x1 Matrix: + + >>> from sympy import Matrix + >>> a = Matrix([1, 2, 3]) + >>> a.hat() + Matrix([ + [ 0, -3, 2], + [ 3, 0, -1], + [-2, 1, 0]]) + + Multiplying it with another 3x1 Matrix calculates the cross product: + + >>> b = Matrix([3, 2, 1]) + >>> a.hat() * b + Matrix([ + [-4], + [ 8], + [-4]]) + + Which is equivalent to calling the ``cross`` method: + + >>> a.cross(b) + Matrix([ + [-4], + [ 8], + [-4]]) + + See Also + ======== + + dot + cross + vee + multiply + multiply_elementwise + """ + + if self.shape != (3, 1): + raise ShapeError("Dimensions incorrect, expected (3, 1), got " + + str(self.shape)) + else: + x, y, z = self + return self._new(3, 3, ( + 0, -z, y, + z, 0, -x, + -y, x, 0)) + + def vee(self): + r""" + Return a 3x1 vector from a skew-symmetric matrix representing the cross product, + so that ``self * b`` is equivalent to ``self.vee().cross(b)``. + + Examples + ======== + + Calling ``vee`` creates a vector from a skew-symmetric Matrix: + + >>> from sympy import Matrix + >>> A = Matrix([[0, -3, 2], [3, 0, -1], [-2, 1, 0]]) + >>> a = A.vee() + >>> a + Matrix([ + [1], + [2], + [3]]) + + Calculating the matrix product of the original matrix with a vector + is equivalent to a cross product: + + >>> b = Matrix([3, 2, 1]) + >>> A * b + Matrix([ + [-4], + [ 8], + [-4]]) + + >>> a.cross(b) + Matrix([ + [-4], + [ 8], + [-4]]) + + ``vee`` can also be used to retrieve angular velocity expressions. + Defining a rotation matrix: + + >>> from sympy import rot_ccw_axis3, trigsimp + >>> from sympy.physics.mechanics import dynamicsymbols + >>> theta = dynamicsymbols('theta') + >>> R = rot_ccw_axis3(theta) + >>> R + Matrix([ + [cos(theta(t)), -sin(theta(t)), 0], + [sin(theta(t)), cos(theta(t)), 0], + [ 0, 0, 1]]) + + We can retrieve the angular velocity: + + >>> Omega = R.T * R.diff() + >>> Omega = trigsimp(Omega) + >>> Omega.vee() + Matrix([ + [ 0], + [ 0], + [Derivative(theta(t), t)]]) + + See Also + ======== + + dot + cross + hat + multiply + multiply_elementwise + """ + + if self.shape != (3, 3): + raise ShapeError("Dimensions incorrect, expected (3, 3), got " + + str(self.shape)) + elif not self.is_anti_symmetric(): + raise ValueError("Matrix is not skew-symmetric") + else: + return self._new(3, 1, ( + self[2, 1], + self[0, 2], + self[1, 0])) + + @property + def D(self): + """Return Dirac conjugate (if ``self.rows == 4``). + + Examples + ======== + + >>> from sympy import Matrix, I, eye + >>> m = Matrix((0, 1 + I, 2, 3)) + >>> m.D + Matrix([[0, 1 - I, -2, -3]]) + >>> m = (eye(4) + I*eye(4)) + >>> m[0, 3] = 2 + >>> m.D + Matrix([ + [1 - I, 0, 0, 0], + [ 0, 1 - I, 0, 0], + [ 0, 0, -1 + I, 0], + [ 2, 0, 0, -1 + I]]) + + If the matrix does not have 4 rows an AttributeError will be raised + because this property is only defined for matrices with 4 rows. + + >>> Matrix(eye(2)).D + Traceback (most recent call last): + ... + AttributeError: Matrix has no attribute D. + + See Also + ======== + + sympy.matrices.matrixbase.MatrixBase.conjugate: By-element conjugation + sympy.matrices.matrixbase.MatrixBase.H: Hermite conjugation + """ + from sympy.physics.matrices import mgamma + if self.rows != 4: + # In Python 3.2, properties can only return an AttributeError + # so we can't raise a ShapeError -- see commit which added the + # first line of this inline comment. Also, there is no need + # for a message since MatrixBase will raise the AttributeError + raise AttributeError + return self.H * mgamma(0) + + def dot(self, b, hermitian=None, conjugate_convention=None): + """Return the dot or inner product of two vectors of equal length. + Here ``self`` must be a ``Matrix`` of size 1 x n or n x 1, and ``b`` + must be either a matrix of size 1 x n, n x 1, or a list/tuple of length n. + A scalar is returned. + + By default, ``dot`` does not conjugate ``self`` or ``b``, even if there are + complex entries. Set ``hermitian=True`` (and optionally a ``conjugate_convention``) + to compute the hermitian inner product. + + Possible kwargs are ``hermitian`` and ``conjugate_convention``. + + If ``conjugate_convention`` is ``"left"``, ``"math"`` or ``"maths"``, + the conjugate of the first vector (``self``) is used. If ``"right"`` + or ``"physics"`` is specified, the conjugate of the second vector ``b`` is used. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + >>> v = Matrix([1, 1, 1]) + >>> M.row(0).dot(v) + 6 + >>> M.col(0).dot(v) + 12 + >>> v = [3, 2, 1] + >>> M.row(0).dot(v) + 10 + + >>> from sympy import I + >>> q = Matrix([1*I, 1*I, 1*I]) + >>> q.dot(q, hermitian=False) + -3 + + >>> q.dot(q, hermitian=True) + 3 + + >>> q1 = Matrix([1, 1, 1*I]) + >>> q.dot(q1, hermitian=True, conjugate_convention="maths") + 1 - 2*I + >>> q.dot(q1, hermitian=True, conjugate_convention="physics") + 1 + 2*I + + + See Also + ======== + + cross + multiply + multiply_elementwise + """ + from .dense import Matrix + + if not isinstance(b, MatrixBase): + if is_sequence(b): + if len(b) != self.cols and len(b) != self.rows: + raise ShapeError( + "Dimensions incorrect for dot product: %s, %s" % ( + self.shape, len(b))) + return self.dot(Matrix(b)) + else: + raise TypeError( + "`b` must be an ordered iterable or Matrix, not %s." % + type(b)) + + if (1 not in self.shape) or (1 not in b.shape): + raise ShapeError + if len(self) != len(b): + raise ShapeError( + "Dimensions incorrect for dot product: %s, %s" % (self.shape, b.shape)) + + mat = self + n = len(mat) + if mat.shape != (1, n): + mat = mat.reshape(1, n) + if b.shape != (n, 1): + b = b.reshape(n, 1) + + # Now ``mat`` is a row vector and ``b`` is a column vector. + + # If it so happens that only conjugate_convention is passed + # then automatically set hermitian to True. If only hermitian + # is true but no conjugate_convention is not passed then + # automatically set it to ``"maths"`` + + if conjugate_convention is not None and hermitian is None: + hermitian = True + if hermitian and conjugate_convention is None: + conjugate_convention = "maths" + + if hermitian == True: + if conjugate_convention in ("maths", "left", "math"): + mat = mat.conjugate() + elif conjugate_convention in ("physics", "right"): + b = b.conjugate() + else: + raise ValueError("Unknown conjugate_convention was entered." + " conjugate_convention must be one of the" + " following: math, maths, left, physics or right.") + return (mat * b)[0] + + def dual(self): + """Returns the dual of a matrix. + + A dual of a matrix is: + + ``(1/2)*levicivita(i, j, k, l)*M(k, l)`` summed over indices `k` and `l` + + Since the levicivita method is anti_symmetric for any pairwise + exchange of indices, the dual of a symmetric matrix is the zero + matrix. Strictly speaking the dual defined here assumes that the + 'matrix' `M` is a contravariant anti_symmetric second rank tensor, + so that the dual is a covariant second rank tensor. + + """ + from sympy.matrices import zeros + + M, n = self[:, :], self.rows + work = zeros(n) + if self.is_symmetric(): + return work + + for i in range(1, n): + for j in range(1, n): + acum = 0 + for k in range(1, n): + acum += LeviCivita(i, j, 0, k) * M[0, k] + work[i, j] = acum + work[j, i] = -acum + + for l in range(1, n): + acum = 0 + for a in range(1, n): + for b in range(1, n): + acum += LeviCivita(0, l, a, b) * M[a, b] + acum /= 2 + work[0, l] = -acum + work[l, 0] = acum + + return work + + def _eval_matrix_exp_jblock(self): + """A helper function to compute an exponential of a Jordan block + matrix + + Examples + ======== + + >>> from sympy import Symbol, Matrix + >>> l = Symbol('lamda') + + A trivial example of 1*1 Jordan block: + + >>> m = Matrix.jordan_block(1, l) + >>> m._eval_matrix_exp_jblock() + Matrix([[exp(lamda)]]) + + An example of 3*3 Jordan block: + + >>> m = Matrix.jordan_block(3, l) + >>> m._eval_matrix_exp_jblock() + Matrix([ + [exp(lamda), exp(lamda), exp(lamda)/2], + [ 0, exp(lamda), exp(lamda)], + [ 0, 0, exp(lamda)]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Matrix_function#Jordan_decomposition + """ + size = self.rows + l = self[0, 0] + exp_l = exp(l) + + bands = {i: exp_l / factorial(i) for i in range(size)} + + from .sparsetools import banded + return self.__class__(banded(size, bands)) + + + def analytic_func(self, f, x): + """ + Computes f(A) where A is a Square Matrix + and f is an analytic function. + + Examples + ======== + + >>> from sympy import Symbol, Matrix, S, log + + >>> x = Symbol('x') + >>> m = Matrix([[S(5)/4, S(3)/4], [S(3)/4, S(5)/4]]) + >>> f = log(x) + >>> m.analytic_func(f, x) + Matrix([ + [ 0, log(2)], + [log(2), 0]]) + + Parameters + ========== + + f : Expr + Analytic Function + x : Symbol + parameter of f + + """ + + f, x = _sympify(f), _sympify(x) + if not self.is_square: + raise NonSquareMatrixError + if not x.is_symbol: + raise ValueError("{} must be a symbol.".format(x)) + if x not in f.free_symbols: + raise ValueError( + "{} must be a parameter of {}.".format(x, f)) + if x in self.free_symbols: + raise ValueError( + "{} must not be a parameter of {}.".format(x, self)) + + eigen = self.eigenvals() + max_mul = max(eigen.values()) + derivative = {} + dd = f + for i in range(max_mul - 1): + dd = diff(dd, x) + derivative[i + 1] = dd + n = self.shape[0] + r = self.zeros(n) + f_val = self.zeros(n, 1) + row = 0 + + for i in eigen: + mul = eigen[i] + f_val[row] = f.subs(x, i) + if f_val[row].is_number and not f_val[row].is_complex: + raise ValueError( + "Cannot evaluate the function because the " + "function {} is not analytic at the given " + "eigenvalue {}".format(f, f_val[row])) + val = 1 + for a in range(n): + r[row, a] = val + val *= i + if mul > 1: + coe = [1 for ii in range(n)] + deri = 1 + while mul > 1: + row = row + 1 + mul -= 1 + d_i = derivative[deri].subs(x, i) + if d_i.is_number and not d_i.is_complex: + raise ValueError( + "Cannot evaluate the function because the " + "derivative {} is not analytic at the given " + "eigenvalue {}".format(derivative[deri], d_i)) + f_val[row] = d_i + for a in range(n): + if a - deri + 1 <= 0: + r[row, a] = 0 + coe[a] = 0 + continue + coe[a] = coe[a]*(a - deri + 1) + r[row, a] = coe[a]*pow(i, a - deri) + deri += 1 + row += 1 + c = r.solve(f_val) + ans = self.zeros(n) + pre = self.eye(n) + for i in range(n): + ans = ans + c[i]*pre + pre *= self + return ans + + + def exp(self): + """Return the exponential of a square matrix. + + Examples + ======== + + >>> from sympy import Symbol, Matrix + + >>> t = Symbol('t') + >>> m = Matrix([[0, 1], [-1, 0]]) * t + >>> m.exp() + Matrix([ + [ exp(I*t)/2 + exp(-I*t)/2, -I*exp(I*t)/2 + I*exp(-I*t)/2], + [I*exp(I*t)/2 - I*exp(-I*t)/2, exp(I*t)/2 + exp(-I*t)/2]]) + """ + if not self.is_square: + raise NonSquareMatrixError( + "Exponentiation is valid only for square matrices") + try: + P, J = self.jordan_form() + cells = J.get_diag_blocks() + except MatrixError: + raise NotImplementedError( + "Exponentiation is implemented only for matrices for which the Jordan normal form can be computed") + + blocks = [cell._eval_matrix_exp_jblock() for cell in cells] + from sympy.matrices import diag + eJ = diag(*blocks) + # n = self.rows + ret = P.multiply(eJ, dotprodsimp=None).multiply(P.inv(), dotprodsimp=None) + if all(value.is_real for value in self.values()): + return type(self)(re(ret)) + else: + return type(self)(ret) + + def _eval_matrix_log_jblock(self): + """Helper function to compute logarithm of a jordan block. + + Examples + ======== + + >>> from sympy import Symbol, Matrix + >>> l = Symbol('lamda') + + A trivial example of 1*1 Jordan block: + + >>> m = Matrix.jordan_block(1, l) + >>> m._eval_matrix_log_jblock() + Matrix([[log(lamda)]]) + + An example of 3*3 Jordan block: + + >>> m = Matrix.jordan_block(3, l) + >>> m._eval_matrix_log_jblock() + Matrix([ + [log(lamda), 1/lamda, -1/(2*lamda**2)], + [ 0, log(lamda), 1/lamda], + [ 0, 0, log(lamda)]]) + """ + size = self.rows + l = self[0, 0] + + if l.is_zero: + raise MatrixError( + 'Could not take logarithm or reciprocal for the given ' + 'eigenvalue {}'.format(l)) + + bands = {0: log(l)} + for i in range(1, size): + bands[i] = -((-l) ** -i) / i + + from .sparsetools import banded + return self.__class__(banded(size, bands)) + + def log(self, simplify=cancel): + """Return the logarithm of a square matrix. + + Parameters + ========== + + simplify : function, bool + The function to simplify the result with. + + Default is ``cancel``, which is effective to reduce the + expression growing for taking reciprocals and inverses for + symbolic matrices. + + Examples + ======== + + >>> from sympy import S, Matrix + + Examples for positive-definite matrices: + + >>> m = Matrix([[1, 1], [0, 1]]) + >>> m.log() + Matrix([ + [0, 1], + [0, 0]]) + + >>> m = Matrix([[S(5)/4, S(3)/4], [S(3)/4, S(5)/4]]) + >>> m.log() + Matrix([ + [ 0, log(2)], + [log(2), 0]]) + + Examples for non positive-definite matrices: + + >>> m = Matrix([[S(3)/4, S(5)/4], [S(5)/4, S(3)/4]]) + >>> m.log() + Matrix([ + [ I*pi/2, log(2) - I*pi/2], + [log(2) - I*pi/2, I*pi/2]]) + + >>> m = Matrix( + ... [[0, 0, 0, 1], + ... [0, 0, 1, 0], + ... [0, 1, 0, 0], + ... [1, 0, 0, 0]]) + >>> m.log() + Matrix([ + [ I*pi/2, 0, 0, -I*pi/2], + [ 0, I*pi/2, -I*pi/2, 0], + [ 0, -I*pi/2, I*pi/2, 0], + [-I*pi/2, 0, 0, I*pi/2]]) + """ + if not self.is_square: + raise NonSquareMatrixError( + "Logarithm is valid only for square matrices") + + try: + if simplify: + P, J = simplify(self).jordan_form() + else: + P, J = self.jordan_form() + + cells = J.get_diag_blocks() + except MatrixError: + raise NotImplementedError( + "Logarithm is implemented only for matrices for which " + "the Jordan normal form can be computed") + + blocks = [ + cell._eval_matrix_log_jblock() + for cell in cells] + from sympy.matrices import diag + eJ = diag(*blocks) + + if simplify: + ret = simplify(P * eJ * simplify(P.inv())) + ret = self.__class__(ret) + else: + ret = P * eJ * P.inv() + + return ret + + def is_nilpotent(self): + """Checks if a matrix is nilpotent. + + A matrix B is nilpotent if for some integer k, B**k is + a zero matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> a = Matrix([[0, 0, 0], [1, 0, 0], [1, 1, 0]]) + >>> a.is_nilpotent() + True + + >>> a = Matrix([[1, 0, 1], [1, 0, 0], [1, 1, 0]]) + >>> a.is_nilpotent() + False + """ + if not self: + return True + if not self.is_square: + raise NonSquareMatrixError( + "Nilpotency is valid only for square matrices") + x = uniquely_named_symbol('x', self, modify=lambda s: '_' + s) + p = self.charpoly(x) + if p.args[0] == x ** self.rows: + return True + return False + + def key2bounds(self, keys): + """Converts a key with potentially mixed types of keys (integer and slice) + into a tuple of ranges and raises an error if any index is out of ``self``'s + range. + + See Also + ======== + + key2ij + """ + islice, jslice = [isinstance(k, slice) for k in keys] + if islice: + if not self.rows: + rlo = rhi = 0 + else: + rlo, rhi = keys[0].indices(self.rows)[:2] + else: + rlo = a2idx(keys[0], self.rows) + rhi = rlo + 1 + if jslice: + if not self.cols: + clo = chi = 0 + else: + clo, chi = keys[1].indices(self.cols)[:2] + else: + clo = a2idx(keys[1], self.cols) + chi = clo + 1 + return rlo, rhi, clo, chi + + def key2ij(self, key): + """Converts key into canonical form, converting integers or indexable + items into valid integers for ``self``'s range or returning slices + unchanged. + + See Also + ======== + + key2bounds + """ + if is_sequence(key): + if not len(key) == 2: + raise TypeError('key must be a sequence of length 2') + return [a2idx(i, n) if not isinstance(i, slice) else i + for i, n in zip(key, self.shape)] + elif isinstance(key, slice): + return key.indices(len(self))[:2] + else: + return divmod(a2idx(key, len(self)), self.cols) + + def normalized(self, iszerofunc=_iszero): + """Return the normalized version of ``self``. + + Parameters + ========== + + iszerofunc : Function, optional + A function to determine whether ``self`` is a zero vector. + The default ``_iszero`` tests to see if each element is + exactly zero. + + Returns + ======= + + Matrix + Normalized vector form of ``self``. + It has the same length as a unit vector. However, a zero vector + will be returned for a vector with norm 0. + + Raises + ====== + + ShapeError + If the matrix is not in a vector form. + + See Also + ======== + + norm + """ + if self.rows != 1 and self.cols != 1: + raise ShapeError("A Matrix must be a vector to normalize.") + norm = self.norm() + if iszerofunc(norm): + out = self.zeros(self.rows, self.cols) + else: + out = self.applyfunc(lambda i: i / norm) + return out + + def norm(self, ord=None): + """Return the Norm of a Matrix or Vector. + + In the simplest case this is the geometric size of the vector + Other norms can be specified by the ord parameter + + + ===== ============================ ========================== + ord norm for matrices norm for vectors + ===== ============================ ========================== + None Frobenius norm 2-norm + 'fro' Frobenius norm - does not exist + inf maximum row sum max(abs(x)) + -inf -- min(abs(x)) + 1 maximum column sum as below + -1 -- as below + 2 2-norm (largest sing. value) as below + -2 smallest singular value as below + other - does not exist sum(abs(x)**ord)**(1./ord) + ===== ============================ ========================== + + Examples + ======== + + >>> from sympy import Matrix, Symbol, trigsimp, cos, sin, oo + >>> x = Symbol('x', real=True) + >>> v = Matrix([cos(x), sin(x)]) + >>> trigsimp( v.norm() ) + 1 + >>> v.norm(10) + (sin(x)**10 + cos(x)**10)**(1/10) + >>> A = Matrix([[1, 1], [1, 1]]) + >>> A.norm(1) # maximum sum of absolute values of A is 2 + 2 + >>> A.norm(2) # Spectral norm (max of |Ax|/|x| under 2-vector-norm) + 2 + >>> A.norm(-2) # Inverse spectral norm (smallest singular value) + 0 + >>> A.norm() # Frobenius Norm + 2 + >>> A.norm(oo) # Infinity Norm + 2 + >>> Matrix([1, -2]).norm(oo) + 2 + >>> Matrix([-1, 2]).norm(-oo) + 1 + + See Also + ======== + + normalized + """ + # Row or Column Vector Norms + vals = list(self.values()) or [0] + if S.One in self.shape: + if ord in (2, None): # Common case sqrt() + return sqrt(Add(*(abs(i) ** 2 for i in vals))) + + elif ord == 1: # sum(abs(x)) + return Add(*(abs(i) for i in vals)) + + elif ord is S.Infinity: # max(abs(x)) + return Max(*[abs(i) for i in vals]) + + elif ord is S.NegativeInfinity: # min(abs(x)) + return Min(*[abs(i) for i in vals]) + + # Otherwise generalize the 2-norm, Sum(x_i**ord)**(1/ord) + # Note that while useful this is not mathematically a norm + try: + return Pow(Add(*(abs(i) ** ord for i in vals)), S.One / ord) + except (NotImplementedError, TypeError): + raise ValueError("Expected order to be Number, Symbol, oo") + + # Matrix Norms + else: + if ord == 1: # Maximum column sum + m = self.applyfunc(abs) + return Max(*[sum(m.col(i)) for i in range(m.cols)]) + + elif ord == 2: # Spectral Norm + # Maximum singular value + return Max(*self.singular_values()) + + elif ord == -2: + # Minimum singular value + return Min(*self.singular_values()) + + elif ord is S.Infinity: # Infinity Norm - Maximum row sum + m = self.applyfunc(abs) + return Max(*[sum(m.row(i)) for i in range(m.rows)]) + + elif (ord is None or isinstance(ord, + str) and ord.lower() in + ['f', 'fro', 'frobenius', 'vector']): + # Reshape as vector and send back to norm function + return self.vec().norm(ord=2) + + else: + raise NotImplementedError("Matrix Norms under development") + + def print_nonzero(self, symb="X"): + """Shows location of non-zero entries for fast shape lookup. + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> m = Matrix(2, 3, lambda i, j: i*3+j) + >>> m + Matrix([ + [0, 1, 2], + [3, 4, 5]]) + >>> m.print_nonzero() + [ XX] + [XXX] + >>> m = eye(4) + >>> m.print_nonzero("x") + [x ] + [ x ] + [ x ] + [ x] + + """ + s = [] + for i in range(self.rows): + line = [] + for j in range(self.cols): + if self[i, j] == 0: + line.append(" ") + else: + line.append(str(symb)) + s.append("[%s]" % ''.join(line)) + print('\n'.join(s)) + + def project(self, v): + """Return the projection of ``self`` onto the line containing ``v``. + + Examples + ======== + + >>> from sympy import Matrix, S, sqrt + >>> V = Matrix([sqrt(3)/2, S.Half]) + >>> x = Matrix([[1, 0]]) + >>> V.project(x) + Matrix([[sqrt(3)/2, 0]]) + >>> V.project(-x) + Matrix([[sqrt(3)/2, 0]]) + """ + return v * (self.dot(v) / v.dot(v)) + + def table(self, printer, rowstart='[', rowend=']', rowsep='\n', + colsep=', ', align='right'): + r""" + String form of Matrix as a table. + + ``printer`` is the printer to use for on the elements (generally + something like StrPrinter()) + + ``rowstart`` is the string used to start each row (by default '['). + + ``rowend`` is the string used to end each row (by default ']'). + + ``rowsep`` is the string used to separate rows (by default a newline). + + ``colsep`` is the string used to separate columns (by default ', '). + + ``align`` defines how the elements are aligned. Must be one of 'left', + 'right', or 'center'. You can also use '<', '>', and '^' to mean the + same thing, respectively. + + This is used by the string printer for Matrix. + + Examples + ======== + + >>> from sympy import Matrix, StrPrinter + >>> M = Matrix([[1, 2], [-33, 4]]) + >>> printer = StrPrinter() + >>> M.table(printer) + '[ 1, 2]\n[-33, 4]' + >>> print(M.table(printer)) + [ 1, 2] + [-33, 4] + >>> print(M.table(printer, rowsep=',\n')) + [ 1, 2], + [-33, 4] + >>> print('[%s]' % M.table(printer, rowsep=',\n')) + [[ 1, 2], + [-33, 4]] + >>> print(M.table(printer, colsep=' ')) + [ 1 2] + [-33 4] + >>> print(M.table(printer, align='center')) + [ 1 , 2] + [-33, 4] + >>> print(M.table(printer, rowstart='{', rowend='}')) + { 1, 2} + {-33, 4} + """ + # Handle zero dimensions: + if S.Zero in self.shape: + return '[]' + # Build table of string representations of the elements + res = [] + # Track per-column max lengths for pretty alignment + maxlen = [0] * self.cols + for i in range(self.rows): + res.append([]) + for j in range(self.cols): + s = printer._print(self[i, j]) + res[-1].append(s) + maxlen[j] = max(len(s), maxlen[j]) + # Patch strings together + align = { + 'left': 'ljust', + 'right': 'rjust', + 'center': 'center', + '<': 'ljust', + '>': 'rjust', + '^': 'center', + }[align] + for i, row in enumerate(res): + for j, elem in enumerate(row): + row[j] = getattr(elem, align)(maxlen[j]) + res[i] = rowstart + colsep.join(row) + rowend + return rowsep.join(res) + + def rank_decomposition(self, iszerofunc=_iszero, simplify=False): + return _rank_decomposition(self, iszerofunc=iszerofunc, + simplify=simplify) + + def cholesky(self, hermitian=True): + raise NotImplementedError('This function is implemented in DenseMatrix or SparseMatrix') + + def LDLdecomposition(self, hermitian=True): + raise NotImplementedError('This function is implemented in DenseMatrix or SparseMatrix') + + def LUdecomposition(self, iszerofunc=_iszero, simpfunc=None, + rankcheck=False): + return _LUdecomposition(self, iszerofunc=iszerofunc, simpfunc=simpfunc, + rankcheck=rankcheck) + + def LUdecomposition_Simple(self, iszerofunc=_iszero, simpfunc=None, + rankcheck=False): + return _LUdecomposition_Simple(self, iszerofunc=iszerofunc, + simpfunc=simpfunc, rankcheck=rankcheck) + + def LUdecompositionFF(self): + return _LUdecompositionFF(self) + + def singular_value_decomposition(self): + return _singular_value_decomposition(self) + + def QRdecomposition(self): + return _QRdecomposition(self) + + def upper_hessenberg_decomposition(self): + return _upper_hessenberg_decomposition(self) + + def diagonal_solve(self, rhs): + return _diagonal_solve(self, rhs) + + def lower_triangular_solve(self, rhs): + raise NotImplementedError('This function is implemented in DenseMatrix or SparseMatrix') + + def upper_triangular_solve(self, rhs): + raise NotImplementedError('This function is implemented in DenseMatrix or SparseMatrix') + + def cholesky_solve(self, rhs): + return _cholesky_solve(self, rhs) + + def LDLsolve(self, rhs): + return _LDLsolve(self, rhs) + + def LUsolve(self, rhs, iszerofunc=_iszero): + return _LUsolve(self, rhs, iszerofunc=iszerofunc) + + def QRsolve(self, b): + return _QRsolve(self, b) + + def gauss_jordan_solve(self, B, freevar=False): + return _gauss_jordan_solve(self, B, freevar=freevar) + + def pinv_solve(self, B, arbitrary_matrix=None): + return _pinv_solve(self, B, arbitrary_matrix=arbitrary_matrix) + + def cramer_solve(self, rhs, det_method="laplace"): + return _cramer_solve(self, rhs, det_method=det_method) + + def solve(self, rhs, method='GJ'): + return _solve(self, rhs, method=method) + + def solve_least_squares(self, rhs, method='CH'): + return _solve_least_squares(self, rhs, method=method) + + def pinv(self, method='RD'): + return _pinv(self, method=method) + + def inverse_ADJ(self, iszerofunc=_iszero): + return _inv_ADJ(self, iszerofunc=iszerofunc) + + def inverse_BLOCK(self, iszerofunc=_iszero): + return _inv_block(self, iszerofunc=iszerofunc) + + def inverse_GE(self, iszerofunc=_iszero): + return _inv_GE(self, iszerofunc=iszerofunc) + + def inverse_LU(self, iszerofunc=_iszero): + return _inv_LU(self, iszerofunc=iszerofunc) + + def inverse_CH(self, iszerofunc=_iszero): + return _inv_CH(self, iszerofunc=iszerofunc) + + def inverse_LDL(self, iszerofunc=_iszero): + return _inv_LDL(self, iszerofunc=iszerofunc) + + def inverse_QR(self, iszerofunc=_iszero): + return _inv_QR(self, iszerofunc=iszerofunc) + + def inv(self, method=None, iszerofunc=_iszero, try_block_diag=False): + return _inv(self, method=method, iszerofunc=iszerofunc, + try_block_diag=try_block_diag) + + def connected_components(self): + return _connected_components(self) + + def connected_components_decomposition(self): + return _connected_components_decomposition(self) + + def strongly_connected_components(self): + return _strongly_connected_components(self) + + def strongly_connected_components_decomposition(self, lower=True): + return _strongly_connected_components_decomposition(self, lower=lower) + + _sage_ = Basic._sage_ + + rank_decomposition.__doc__ = _rank_decomposition.__doc__ + cholesky.__doc__ = _cholesky.__doc__ + LDLdecomposition.__doc__ = _LDLdecomposition.__doc__ + LUdecomposition.__doc__ = _LUdecomposition.__doc__ + LUdecomposition_Simple.__doc__ = _LUdecomposition_Simple.__doc__ + LUdecompositionFF.__doc__ = _LUdecompositionFF.__doc__ + singular_value_decomposition.__doc__ = _singular_value_decomposition.__doc__ + QRdecomposition.__doc__ = _QRdecomposition.__doc__ + upper_hessenberg_decomposition.__doc__ = _upper_hessenberg_decomposition.__doc__ + + diagonal_solve.__doc__ = _diagonal_solve.__doc__ + lower_triangular_solve.__doc__ = _lower_triangular_solve.__doc__ + upper_triangular_solve.__doc__ = _upper_triangular_solve.__doc__ + cholesky_solve.__doc__ = _cholesky_solve.__doc__ + LDLsolve.__doc__ = _LDLsolve.__doc__ + LUsolve.__doc__ = _LUsolve.__doc__ + QRsolve.__doc__ = _QRsolve.__doc__ + gauss_jordan_solve.__doc__ = _gauss_jordan_solve.__doc__ + pinv_solve.__doc__ = _pinv_solve.__doc__ + cramer_solve.__doc__ = _cramer_solve.__doc__ + solve.__doc__ = _solve.__doc__ + solve_least_squares.__doc__ = _solve_least_squares.__doc__ + + pinv.__doc__ = _pinv.__doc__ + inverse_ADJ.__doc__ = _inv_ADJ.__doc__ + inverse_GE.__doc__ = _inv_GE.__doc__ + inverse_LU.__doc__ = _inv_LU.__doc__ + inverse_CH.__doc__ = _inv_CH.__doc__ + inverse_LDL.__doc__ = _inv_LDL.__doc__ + inverse_QR.__doc__ = _inv_QR.__doc__ + inverse_BLOCK.__doc__ = _inv_block.__doc__ + inv.__doc__ = _inv.__doc__ + + connected_components.__doc__ = _connected_components.__doc__ + connected_components_decomposition.__doc__ = \ + _connected_components_decomposition.__doc__ + strongly_connected_components.__doc__ = \ + _strongly_connected_components.__doc__ + strongly_connected_components_decomposition.__doc__ = \ + _strongly_connected_components_decomposition.__doc__ + + +def _convert_matrix(typ, mat): + """Convert mat to a Matrix of type typ.""" + from sympy.matrices.matrixbase import MatrixBase + if getattr(mat, "is_Matrix", False) and not isinstance(mat, MatrixBase): + # This is needed for interop between Matrix and the redundant matrix + # mixin types like _MinimalMatrix etc. If anyone should happen to be + # using those then this keeps them working. Really _MinimalMatrix etc + # should be deprecated and removed though. + return typ(*mat.shape, list(mat)) + else: + return typ(mat) + + +def _has_matrix_shape(other): + shape = getattr(other, 'shape', None) + if shape is None: + return False + return isinstance(shape, tuple) and len(shape) == 2 + + +def _has_rows_cols(other): + return hasattr(other, 'rows') and hasattr(other, 'cols') + + +def _coerce_operand(self, other): + """Convert other to a Matrix, or check for possible scalar.""" + + INVALID = None, 'invalid_type' + + # Disallow mixing Matrix and Array + if isinstance(other, NDimArray): + return INVALID + + is_Matrix = getattr(other, 'is_Matrix', None) + + # Return a Matrix as-is + if is_Matrix: + return other, 'is_matrix' + + # Try to convert numpy array, mpmath matrix etc. + if is_Matrix is None: + if _has_matrix_shape(other) or _has_rows_cols(other): + return _convert_matrix(type(self), other), 'is_matrix' + + # Could be a scalar but only if not iterable... + if not isinstance(other, Iterable): + return other, 'possible_scalar' + + return INVALID + + +def classof(A, B): + """ + Get the type of the result when combining matrices of different types. + + Currently the strategy is that immutability is contagious. + + Examples + ======== + + >>> from sympy import Matrix, ImmutableMatrix + >>> from sympy.matrices.matrixbase import classof + >>> M = Matrix([[1, 2], [3, 4]]) # a Mutable Matrix + >>> IM = ImmutableMatrix([[1, 2], [3, 4]]) + >>> classof(M, IM) + + """ + priority_A = getattr(A, '_class_priority', None) + priority_B = getattr(B, '_class_priority', None) + if None not in (priority_A, priority_B): + if A._class_priority > B._class_priority: + return A.__class__ + else: + return B.__class__ + + try: + import numpy + except ImportError: + pass + else: + if isinstance(A, numpy.ndarray): + return B.__class__ + if isinstance(B, numpy.ndarray): + return A.__class__ + + raise TypeError("Incompatible classes %s, %s" % (A.__class__, B.__class__)) + + +def _unify_with_other(self, other): + """Unify self and other into a single matrix type, or check for scalar.""" + other, T = _coerce_operand(self, other) + + if T == "is_matrix": + typ = classof(self, other) + if typ != self.__class__: + self = _convert_matrix(typ, self) + if typ != other.__class__: + other = _convert_matrix(typ, other) + + return self, other, T + + +def a2idx(j, n=None): + """Return integer after making positive and validating against n.""" + if not isinstance(j, int): + jindex = getattr(j, '__index__', None) + if jindex is not None: + j = jindex() + else: + raise IndexError("Invalid index a[%r]" % (j,)) + if n is not None: + if j < 0: + j += n + if not (j >= 0 and j < n): + raise IndexError("Index out of range: a[%s]" % (j,)) + return int(j) + + +class DeferredVector(Symbol, NotIterable): # type: ignore + """A vector whose components are deferred (e.g. for use with lambdify). + + Examples + ======== + + >>> from sympy import DeferredVector, lambdify + >>> X = DeferredVector( 'X' ) + >>> X + X + >>> expr = (X[0] + 2, X[2] + 3) + >>> func = lambdify( X, expr) + >>> func( [1, 2, 3] ) + (3, 6) + """ + + def __getitem__(self, i): + if i == -0: + i = 0 + if i < 0: + raise IndexError('DeferredVector index out of range') + component_name = '%s[%d]' % (self.name, i) + return Symbol(component_name) + + def __str__(self): + return sstr(self) + + def __repr__(self): + return "DeferredVector('%s')" % self.name diff --git a/phivenv/Lib/site-packages/sympy/matrices/normalforms.py b/phivenv/Lib/site-packages/sympy/matrices/normalforms.py new file mode 100644 index 0000000000000000000000000000000000000000..61a7d26bbdb8c8a3e8e3044d39b2403b2e14b7d5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/normalforms.py @@ -0,0 +1,156 @@ +'''Functions returning normal forms of matrices''' + +from sympy.polys.domains.integerring import ZZ +from sympy.polys.polytools import Poly +from sympy.polys.matrices import DomainMatrix +from sympy.polys.matrices.normalforms import ( + smith_normal_form as _snf, + is_smith_normal_form as _is_snf, + smith_normal_decomp as _snd, + invariant_factors as _invf, + hermite_normal_form as _hnf, + ) + + +def _to_domain(m, domain=None): + """Convert Matrix to DomainMatrix""" + # XXX: deprecated support for RawMatrix: + ring = getattr(m, "ring", None) + m = m.applyfunc(lambda e: e.as_expr() if isinstance(e, Poly) else e) + + dM = DomainMatrix.from_Matrix(m) + + domain = domain or ring + if domain is not None: + dM = dM.convert_to(domain) + return dM + + +def smith_normal_form(m, domain=None): + ''' + Return the Smith Normal Form of a matrix `m` over the ring `domain`. + This will only work if the ring is a principal ideal domain. + + Examples + ======== + + >>> from sympy import Matrix, ZZ + >>> from sympy.matrices.normalforms import smith_normal_form + >>> m = Matrix([[12, 6, 4], [3, 9, 6], [2, 16, 14]]) + >>> print(smith_normal_form(m, domain=ZZ)) + Matrix([[1, 0, 0], [0, 10, 0], [0, 0, 30]]) + + ''' + dM = _to_domain(m, domain) + return _snf(dM).to_Matrix() + + +def is_smith_normal_form(m, domain=None): + ''' + Checks that the matrix is in Smith Normal Form + ''' + dM = _to_domain(m, domain) + return _is_snf(dM) + + +def smith_normal_decomp(m, domain=None): + ''' + Return the Smith Normal Decomposition of a matrix `m` over the ring + `domain`. This will only work if the ring is a principal ideal domain. + + Examples + ======== + + >>> from sympy import Matrix, ZZ + >>> from sympy.matrices.normalforms import smith_normal_decomp + >>> m = Matrix([[12, 6, 4], [3, 9, 6], [2, 16, 14]]) + >>> a, s, t = smith_normal_decomp(m, domain=ZZ) + >>> assert a == s * m * t + ''' + dM = _to_domain(m, domain) + a, s, t = _snd(dM) + return a.to_Matrix(), s.to_Matrix(), t.to_Matrix() + + +def invariant_factors(m, domain=None): + ''' + Return the tuple of abelian invariants for a matrix `m` + (as in the Smith-Normal form) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Smith_normal_form#Algorithm + .. [2] https://web.archive.org/web/20200331143852/https://sierra.nmsu.edu/morandi/notes/SmithNormalForm.pdf + + ''' + dM = _to_domain(m, domain) + factors = _invf(dM) + factors = tuple(dM.domain.to_sympy(f) for f in factors) + # XXX: deprecated. + if hasattr(m, "ring"): + if m.ring.is_PolynomialRing: + K = m.ring + to_poly = lambda f: Poly(f, K.symbols, domain=K.domain) + factors = tuple(to_poly(f) for f in factors) + return factors + + +def hermite_normal_form(A, *, D=None, check_rank=False): + r""" + Compute the Hermite Normal Form of a Matrix *A* of integers. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.matrices.normalforms import hermite_normal_form + >>> m = Matrix([[12, 6, 4], [3, 9, 6], [2, 16, 14]]) + >>> print(hermite_normal_form(m)) + Matrix([[10, 0, 2], [0, 15, 3], [0, 0, 2]]) + + Parameters + ========== + + A : $m \times n$ ``Matrix`` of integers. + + D : int, optional + Let $W$ be the HNF of *A*. If known in advance, a positive integer *D* + being any multiple of $\det(W)$ may be provided. In this case, if *A* + also has rank $m$, then we may use an alternative algorithm that works + mod *D* in order to prevent coefficient explosion. + + check_rank : boolean, optional (default=False) + The basic assumption is that, if you pass a value for *D*, then + you already believe that *A* has rank $m$, so we do not waste time + checking it for you. If you do want this to be checked (and the + ordinary, non-modulo *D* algorithm to be used if the check fails), then + set *check_rank* to ``True``. + + Returns + ======= + + ``Matrix`` + The HNF of matrix *A*. + + Raises + ====== + + DMDomainError + If the domain of the matrix is not :ref:`ZZ`. + + DMShapeError + If the mod *D* algorithm is used but the matrix has more rows than + columns. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithms 2.4.5 and 2.4.8.) + + """ + # Accept any of Python int, SymPy Integer, and ZZ itself: + if D is not None and not ZZ.of_type(D): + D = ZZ(int(D)) + return _hnf(A._rep, D=D, check_rank=check_rank).to_Matrix() diff --git a/phivenv/Lib/site-packages/sympy/matrices/reductions.py b/phivenv/Lib/site-packages/sympy/matrices/reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..aace8c0336358e1869a34d99f79390cc0c0163fe --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/reductions.py @@ -0,0 +1,387 @@ +from types import FunctionType + +from sympy.polys.polyerrors import CoercionFailed +from sympy.polys.domains import ZZ, QQ + +from .utilities import _get_intermediate_simp, _iszero, _dotprodsimp, _simplify +from .determinant import _find_reasonable_pivot + + +def _row_reduce_list(mat, rows, cols, one, iszerofunc, simpfunc, + normalize_last=True, normalize=True, zero_above=True): + """Row reduce a flat list representation of a matrix and return a tuple + (rref_matrix, pivot_cols, swaps) where ``rref_matrix`` is a flat list, + ``pivot_cols`` are the pivot columns and ``swaps`` are any row swaps that + were used in the process of row reduction. + + Parameters + ========== + + mat : list + list of matrix elements, must be ``rows`` * ``cols`` in length + + rows, cols : integer + number of rows and columns in flat list representation + + one : SymPy object + represents the value one, from ``Matrix.one`` + + iszerofunc : determines if an entry can be used as a pivot + + simpfunc : used to simplify elements and test if they are + zero if ``iszerofunc`` returns `None` + + normalize_last : indicates where all row reduction should + happen in a fraction-free manner and then the rows are + normalized (so that the pivots are 1), or whether + rows should be normalized along the way (like the naive + row reduction algorithm) + + normalize : whether pivot rows should be normalized so that + the pivot value is 1 + + zero_above : whether entries above the pivot should be zeroed. + If ``zero_above=False``, an echelon matrix will be returned. + """ + + def get_col(i): + return mat[i::cols] + + def row_swap(i, j): + mat[i*cols:(i + 1)*cols], mat[j*cols:(j + 1)*cols] = \ + mat[j*cols:(j + 1)*cols], mat[i*cols:(i + 1)*cols] + + def cross_cancel(a, i, b, j): + """Does the row op row[i] = a*row[i] - b*row[j]""" + q = (j - i)*cols + for p in range(i*cols, (i + 1)*cols): + mat[p] = isimp(a*mat[p] - b*mat[p + q]) + + isimp = _get_intermediate_simp(_dotprodsimp) + piv_row, piv_col = 0, 0 + pivot_cols = [] + swaps = [] + + # use a fraction free method to zero above and below each pivot + while piv_col < cols and piv_row < rows: + pivot_offset, pivot_val, \ + assumed_nonzero, newly_determined = _find_reasonable_pivot( + get_col(piv_col)[piv_row:], iszerofunc, simpfunc) + + # _find_reasonable_pivot may have simplified some things + # in the process. Let's not let them go to waste + for (offset, val) in newly_determined: + offset += piv_row + mat[offset*cols + piv_col] = val + + if pivot_offset is None: + piv_col += 1 + continue + + pivot_cols.append(piv_col) + if pivot_offset != 0: + row_swap(piv_row, pivot_offset + piv_row) + swaps.append((piv_row, pivot_offset + piv_row)) + + # if we aren't normalizing last, we normalize + # before we zero the other rows + if normalize_last is False: + i, j = piv_row, piv_col + mat[i*cols + j] = one + for p in range(i*cols + j + 1, (i + 1)*cols): + mat[p] = isimp(mat[p] / pivot_val) + # after normalizing, the pivot value is 1 + pivot_val = one + + # zero above and below the pivot + for row in range(rows): + # don't zero our current row + if row == piv_row: + continue + # don't zero above the pivot unless we're told. + if zero_above is False and row < piv_row: + continue + # if we're already a zero, don't do anything + val = mat[row*cols + piv_col] + if iszerofunc(val): + continue + + cross_cancel(pivot_val, row, val, piv_row) + piv_row += 1 + + # normalize each row + if normalize_last is True and normalize is True: + for piv_i, piv_j in enumerate(pivot_cols): + pivot_val = mat[piv_i*cols + piv_j] + mat[piv_i*cols + piv_j] = one + for p in range(piv_i*cols + piv_j + 1, (piv_i + 1)*cols): + mat[p] = isimp(mat[p] / pivot_val) + + return mat, tuple(pivot_cols), tuple(swaps) + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _row_reduce(M, iszerofunc, simpfunc, normalize_last=True, + normalize=True, zero_above=True): + + mat, pivot_cols, swaps = _row_reduce_list(list(M), M.rows, M.cols, M.one, + iszerofunc, simpfunc, normalize_last=normalize_last, + normalize=normalize, zero_above=zero_above) + + return M._new(M.rows, M.cols, mat), pivot_cols, swaps + + +def _is_echelon(M, iszerofunc=_iszero): + """Returns `True` if the matrix is in echelon form. That is, all rows of + zeros are at the bottom, and below each leading non-zero in a row are + exclusively zeros.""" + + if M.rows <= 0 or M.cols <= 0: + return True + + zeros_below = all(iszerofunc(t) for t in M[1:, 0]) + + if iszerofunc(M[0, 0]): + return zeros_below and _is_echelon(M[:, 1:], iszerofunc) + + return zeros_below and _is_echelon(M[1:, 1:], iszerofunc) + + +def _echelon_form(M, iszerofunc=_iszero, simplify=False, with_pivots=False): + """Returns a matrix row-equivalent to ``M`` that is in echelon form. Note + that echelon form of a matrix is *not* unique, however, properties like the + row space and the null space are preserved. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> M.echelon_form() + Matrix([ + [1, 2], + [0, -2]]) + """ + + simpfunc = simplify if isinstance(simplify, FunctionType) else _simplify + + mat, pivots, _ = _row_reduce(M, iszerofunc, simpfunc, + normalize_last=True, normalize=False, zero_above=False) + + if with_pivots: + return mat, pivots + + return mat + + +# This functions is a candidate for caching if it gets implemented for matrices. +def _rank(M, iszerofunc=_iszero, simplify=False): + """Returns the rank of a matrix. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x + >>> m = Matrix([[1, 2], [x, 1 - 1/x]]) + >>> m.rank() + 2 + >>> n = Matrix(3, 3, range(1, 10)) + >>> n.rank() + 2 + """ + + def _permute_complexity_right(M, iszerofunc): + """Permute columns with complicated elements as + far right as they can go. Since the ``sympy`` row reduction + algorithms start on the left, having complexity right-shifted + speeds things up. + + Returns a tuple (mat, perm) where perm is a permutation + of the columns to perform to shift the complex columns right, and mat + is the permuted matrix.""" + + def complexity(i): + # the complexity of a column will be judged by how many + # element's zero-ness cannot be determined + return sum(1 if iszerofunc(e) is None else 0 for e in M[:, i]) + + complex = [(complexity(i), i) for i in range(M.cols)] + perm = [j for (i, j) in sorted(complex)] + + return (M.permute(perm, orientation='cols'), perm) + + simpfunc = simplify if isinstance(simplify, FunctionType) else _simplify + + # for small matrices, we compute the rank explicitly + # if is_zero on elements doesn't answer the question + # for small matrices, we fall back to the full routine. + if M.rows <= 0 or M.cols <= 0: + return 0 + + if M.rows <= 1 or M.cols <= 1: + zeros = [iszerofunc(x) for x in M] + + if False in zeros: + return 1 + + if M.rows == 2 and M.cols == 2: + zeros = [iszerofunc(x) for x in M] + + if False not in zeros and None not in zeros: + return 0 + + d = M.det() + + if iszerofunc(d) and False in zeros: + return 1 + if iszerofunc(d) is False: + return 2 + + mat, _ = _permute_complexity_right(M, iszerofunc=iszerofunc) + _, pivots, _ = _row_reduce(mat, iszerofunc, simpfunc, normalize_last=True, + normalize=False, zero_above=False) + + return len(pivots) + + +def _to_DM_ZZ_QQ(M): + # We have to test for _rep here because there are tests that otherwise fail + # with e.g. "AttributeError: 'SubspaceOnlyMatrix' object has no attribute + # '_rep'." There is almost certainly no value in such tests. The + # presumption seems to be that someone could create a new class by + # inheriting some of the Matrix classes and not the full set that is used + # by the standard Matrix class but if anyone tried that it would fail in + # many ways. + if not hasattr(M, '_rep'): + return None + + rep = M._rep + K = rep.domain + + if K.is_ZZ: + return rep + elif K.is_QQ: + try: + return rep.convert_to(ZZ) + except CoercionFailed: + return rep + else: + if not all(e.is_Rational for e in M): + return None + try: + return rep.convert_to(ZZ) + except CoercionFailed: + return rep.convert_to(QQ) + + +def _rref_dm(dM): + """Compute the reduced row echelon form of a DomainMatrix.""" + K = dM.domain + + if K.is_ZZ: + dM_rref, den, pivots = dM.rref_den(keep_domain=False) + dM_rref = dM_rref.to_field() / den + elif K.is_QQ: + dM_rref, pivots = dM.rref() + else: + assert False # pragma: no cover + + M_rref = dM_rref.to_Matrix() + + return M_rref, pivots + + +def _rref(M, iszerofunc=_iszero, simplify=False, pivots=True, + normalize_last=True): + """Return reduced row-echelon form of matrix and indices + of pivot vars. + + Parameters + ========== + + iszerofunc : Function + A function used for detecting whether an element can + act as a pivot. ``lambda x: x.is_zero`` is used by default. + + simplify : Function + A function used to simplify elements when looking for a pivot. + By default SymPy's ``simplify`` is used. + + pivots : True or False + If ``True``, a tuple containing the row-reduced matrix and a tuple + of pivot columns is returned. If ``False`` just the row-reduced + matrix is returned. + + normalize_last : True or False + If ``True``, no pivots are normalized to `1` until after all + entries above and below each pivot are zeroed. This means the row + reduction algorithm is fraction free until the very last step. + If ``False``, the naive row reduction procedure is used where + each pivot is normalized to be `1` before row operations are + used to zero above and below the pivot. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x + >>> m = Matrix([[1, 2], [x, 1 - 1/x]]) + >>> m.rref() + (Matrix([ + [1, 0], + [0, 1]]), (0, 1)) + >>> rref_matrix, rref_pivots = m.rref() + >>> rref_matrix + Matrix([ + [1, 0], + [0, 1]]) + >>> rref_pivots + (0, 1) + + ``iszerofunc`` can correct rounding errors in matrices with float + values. In the following example, calling ``rref()`` leads to + floating point errors, incorrectly row reducing the matrix. + ``iszerofunc= lambda x: abs(x) < 1e-9`` sets sufficiently small numbers + to zero, avoiding this error. + + >>> m = Matrix([[0.9, -0.1, -0.2, 0], [-0.8, 0.9, -0.4, 0], [-0.1, -0.8, 0.6, 0]]) + >>> m.rref() + (Matrix([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0]]), (0, 1, 2)) + >>> m.rref(iszerofunc=lambda x:abs(x)<1e-9) + (Matrix([ + [1, 0, -0.301369863013699, 0], + [0, 1, -0.712328767123288, 0], + [0, 0, 0, 0]]), (0, 1)) + + Notes + ===== + + The default value of ``normalize_last=True`` can provide significant + speedup to row reduction, especially on matrices with symbols. However, + if you depend on the form row reduction algorithm leaves entries + of the matrix, set ``normalize_last=False`` + """ + # Try to use DomainMatrix for ZZ or QQ + dM = _to_DM_ZZ_QQ(M) + + if dM is not None: + # Use DomainMatrix for ZZ or QQ + mat, pivot_cols = _rref_dm(dM) + else: + # Use the generic Matrix routine. + if isinstance(simplify, FunctionType): + simpfunc = simplify + else: + simpfunc = _simplify + + mat, pivot_cols, _ = _row_reduce(M, iszerofunc, simpfunc, + normalize_last, normalize=True, zero_above=True) + + if pivots: + return mat, pivot_cols + else: + return mat diff --git a/phivenv/Lib/site-packages/sympy/matrices/repmatrix.py b/phivenv/Lib/site-packages/sympy/matrices/repmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..57f32fae34786f68f579fad7de38c9e3cf43e131 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/repmatrix.py @@ -0,0 +1,1034 @@ +from collections import defaultdict + +from operator import index as index_ + +from sympy.core.expr import Expr +from sympy.core.kind import Kind, NumberKind, UndefinedKind +from sympy.core.numbers import Integer, Rational +from sympy.core.sympify import _sympify, SympifyError +from sympy.core.singleton import S +from sympy.polys.domains import ZZ, QQ, GF, EXRAW +from sympy.polys.matrices import DomainMatrix +from sympy.polys.matrices.exceptions import DMNonInvertibleMatrixError +from sympy.polys.polyerrors import CoercionFailed, NotInvertible +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import is_sequence +from sympy.utilities.misc import filldedent, as_int + +from .exceptions import ShapeError, NonSquareMatrixError, NonInvertibleMatrixError +from .matrixbase import classof, MatrixBase +from .kind import MatrixKind + + +class RepMatrix(MatrixBase): + """Matrix implementation based on DomainMatrix as an internal representation. + + The RepMatrix class is a superclass for Matrix, ImmutableMatrix, + SparseMatrix and ImmutableSparseMatrix which are the main usable matrix + classes in SymPy. Most methods on this class are simply forwarded to + DomainMatrix. + """ + + # + # MatrixBase is the common superclass for all of the usable explicit matrix + # classes in SymPy. The idea is that MatrixBase is an abstract class though + # and that subclasses will implement the lower-level methods. + # + # RepMatrix is a subclass of MatrixBase that uses DomainMatrix as an + # internal representation and delegates lower-level methods to + # DomainMatrix. All of SymPy's standard explicit matrix classes subclass + # RepMatrix and so use DomainMatrix internally. + # + # A RepMatrix uses an internal DomainMatrix with the domain set to ZZ, QQ + # or EXRAW. The EXRAW domain is equivalent to the previous implementation + # of Matrix that used Expr for the elements. The ZZ and QQ domains are used + # when applicable just because they are compatible with the previous + # implementation but are much more efficient. Other domains such as QQ[x] + # are not used because they differ from Expr in some way (e.g. automatic + # expansion of powers and products). + # + + _rep: DomainMatrix + + def __eq__(self, other): + # Skip sympify for mutable matrices... + if not isinstance(other, RepMatrix): + try: + other = _sympify(other) + except SympifyError: + return NotImplemented + if not isinstance(other, RepMatrix): + return NotImplemented + + return self._rep.unify_eq(other._rep) + + def to_DM(self, domain=None, **kwargs): + """Convert to a :class:`~.DomainMatrix`. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> M.to_DM() + DomainMatrix({0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}}, (2, 2), ZZ) + + The :meth:`DomainMatrix.to_Matrix` method can be used to convert back: + + >>> M.to_DM().to_Matrix() == M + True + + The domain can be given explicitly or otherwise it will be chosen by + :func:`construct_domain`. Any keyword arguments (besides ``domain``) + are passed to :func:`construct_domain`: + + >>> from sympy import QQ, symbols + >>> x = symbols('x') + >>> M = Matrix([[x, 1], [1, x]]) + >>> M + Matrix([ + [x, 1], + [1, x]]) + >>> M.to_DM().domain + ZZ[x] + >>> M.to_DM(field=True).domain + ZZ(x) + >>> M.to_DM(domain=QQ[x]).domain + QQ[x] + + See Also + ======== + + DomainMatrix + DomainMatrix.to_Matrix + DomainMatrix.convert_to + DomainMatrix.choose_domain + construct_domain + """ + if domain is not None: + if kwargs: + raise TypeError("Options cannot be used with domain parameter") + return self._rep.convert_to(domain) + + rep = self._rep + dom = rep.domain + + # If the internal DomainMatrix is already ZZ or QQ then we can maybe + # bypass calling construct_domain or performing any conversions. Some + # kwargs might affect this though e.g. field=True (not sure if there + # are others). + if not kwargs: + if dom.is_ZZ: + return rep.copy() + elif dom.is_QQ: + # All elements might be integers + try: + return rep.convert_to(ZZ) + except CoercionFailed: + pass + return rep.copy() + + # Let construct_domain choose a domain + rep_dom = rep.choose_domain(**kwargs) + + # XXX: There should be an option to construct_domain to choose EXRAW + # instead of EX. At least converting to EX does not initially trigger + # EX.simplify which is what we want here but should probably be + # considered a bug in EX. Perhaps also this could be handled in + # DomainMatrix.choose_domain rather than here... + if rep_dom.domain.is_EX: + rep_dom = rep_dom.convert_to(EXRAW) + + return rep_dom + + @classmethod + def _unify_element_sympy(cls, rep, element): + domain = rep.domain + element = _sympify(element) + + if domain != EXRAW: + # The domain can only be ZZ, QQ or EXRAW + if element.is_Integer: + new_domain = domain + elif element.is_Rational: + new_domain = QQ + else: + new_domain = EXRAW + + # XXX: This converts the domain for all elements in the matrix + # which can be slow. This happens e.g. if __setitem__ changes one + # element to something that does not fit in the domain + if new_domain != domain: + rep = rep.convert_to(new_domain) + domain = new_domain + + if domain != EXRAW: + element = new_domain.from_sympy(element) + + if domain == EXRAW and not isinstance(element, Expr): + sympy_deprecation_warning( + """ + non-Expr objects in a Matrix is deprecated. Matrix represents + a mathematical matrix. To represent a container of non-numeric + entities, Use a list of lists, TableForm, NumPy array, or some + other data structure instead. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-non-expr-in-matrix", + stacklevel=4, + ) + + return rep, element + + @classmethod + def _dod_to_DomainMatrix(cls, rows, cols, dod, types): + + if not all(issubclass(typ, Expr) for typ in types): + sympy_deprecation_warning( + """ + non-Expr objects in a Matrix is deprecated. Matrix represents + a mathematical matrix. To represent a container of non-numeric + entities, Use a list of lists, TableForm, NumPy array, or some + other data structure instead. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-non-expr-in-matrix", + stacklevel=6, + ) + + rep = DomainMatrix(dod, (rows, cols), EXRAW) + + if all(issubclass(typ, Rational) for typ in types): + if all(issubclass(typ, Integer) for typ in types): + rep = rep.convert_to(ZZ) + else: + rep = rep.convert_to(QQ) + + return rep + + @classmethod + def _flat_list_to_DomainMatrix(cls, rows, cols, flat_list): + + elements_dod = defaultdict(dict) + for n, element in enumerate(flat_list): + if element != 0: + i, j = divmod(n, cols) + elements_dod[i][j] = element + + types = set(map(type, flat_list)) + + rep = cls._dod_to_DomainMatrix(rows, cols, elements_dod, types) + return rep + + @classmethod + def _smat_to_DomainMatrix(cls, rows, cols, smat): + + elements_dod = defaultdict(dict) + for (i, j), element in smat.items(): + if element != 0: + elements_dod[i][j] = element + + types = set(map(type, smat.values())) + + rep = cls._dod_to_DomainMatrix(rows, cols, elements_dod, types) + return rep + + def flat(self): + return self._rep.to_sympy().to_list_flat() + + def _eval_tolist(self): + return self._rep.to_sympy().to_list() + + def _eval_todok(self): + return self._rep.to_sympy().to_dok() + + @classmethod + def _eval_from_dok(cls, rows, cols, dok): + return cls._fromrep(cls._smat_to_DomainMatrix(rows, cols, dok)) + + def _eval_values(self): + return list(self._eval_iter_values()) + + def _eval_iter_values(self): + rep = self._rep + K = rep.domain + values = rep.iter_values() + if not K.is_EXRAW: + values = map(K.to_sympy, values) + return values + + def _eval_iter_items(self): + rep = self._rep + K = rep.domain + to_sympy = K.to_sympy + items = rep.iter_items() + if not K.is_EXRAW: + items = ((i, to_sympy(v)) for i, v in items) + return items + + def copy(self): + return self._fromrep(self._rep.copy()) + + @property + def kind(self) -> MatrixKind: + domain = self._rep.domain + element_kind: Kind + if domain in (ZZ, QQ): + element_kind = NumberKind + elif domain == EXRAW: + kinds = {e.kind for e in self.values()} + if len(kinds) == 1: + [element_kind] = kinds + else: + element_kind = UndefinedKind + else: # pragma: no cover + raise RuntimeError("Domain should only be ZZ, QQ or EXRAW") + return MatrixKind(element_kind) + + def _eval_has(self, *patterns): + # if the matrix has any zeros, see if S.Zero + # has the pattern. If _smat is full length, + # the matrix has no zeros. + zhas = False + dok = self.todok() + if len(dok) != self.rows*self.cols: + zhas = S.Zero.has(*patterns) + return zhas or any(value.has(*patterns) for value in dok.values()) + + def _eval_is_Identity(self): + if not all(self[i, i] == 1 for i in range(self.rows)): + return False + return len(self.todok()) == self.rows + + def _eval_is_symmetric(self, simpfunc): + diff = (self - self.T).applyfunc(simpfunc) + return len(diff.values()) == 0 + + def _eval_transpose(self): + """Returns the transposed SparseMatrix of this SparseMatrix. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> a = SparseMatrix(((1, 2), (3, 4))) + >>> a + Matrix([ + [1, 2], + [3, 4]]) + >>> a.T + Matrix([ + [1, 3], + [2, 4]]) + """ + return self._fromrep(self._rep.transpose()) + + def _eval_col_join(self, other): + return self._fromrep(self._rep.vstack(other._rep)) + + def _eval_row_join(self, other): + return self._fromrep(self._rep.hstack(other._rep)) + + def _eval_extract(self, rowsList, colsList): + return self._fromrep(self._rep.extract(rowsList, colsList)) + + def __getitem__(self, key): + return _getitem_RepMatrix(self, key) + + @classmethod + def _eval_zeros(cls, rows, cols): + rep = DomainMatrix.zeros((rows, cols), ZZ) + return cls._fromrep(rep) + + @classmethod + def _eval_eye(cls, rows, cols): + rep = DomainMatrix.eye((rows, cols), ZZ) + return cls._fromrep(rep) + + def _eval_add(self, other): + return classof(self, other)._fromrep(self._rep + other._rep) + + def _eval_matrix_mul(self, other): + return classof(self, other)._fromrep(self._rep * other._rep) + + def _eval_matrix_mul_elementwise(self, other): + selfrep, otherrep = self._rep.unify(other._rep) + newrep = selfrep.mul_elementwise(otherrep) + return classof(self, other)._fromrep(newrep) + + def _eval_scalar_mul(self, other): + rep, other = self._unify_element_sympy(self._rep, other) + return self._fromrep(rep.scalarmul(other)) + + def _eval_scalar_rmul(self, other): + rep, other = self._unify_element_sympy(self._rep, other) + return self._fromrep(rep.rscalarmul(other)) + + def _eval_Abs(self): + return self._fromrep(self._rep.applyfunc(abs)) + + def _eval_conjugate(self): + rep = self._rep + domain = rep.domain + if domain in (ZZ, QQ): + return self.copy() + else: + return self._fromrep(rep.applyfunc(lambda e: e.conjugate())) + + def equals(self, other, failing_expression=False): + """Applies ``equals`` to corresponding elements of the matrices, + trying to prove that the elements are equivalent, returning True + if they are, False if any pair is not, and None (or the first + failing expression if failing_expression is True) if it cannot + be decided if the expressions are equivalent or not. This is, in + general, an expensive operation. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import x + >>> A = Matrix([x*(x - 1), 0]) + >>> B = Matrix([x**2 - x, 0]) + >>> A == B + False + >>> A.simplify() == B.simplify() + True + >>> A.equals(B) + True + >>> A.equals(2) + False + + See Also + ======== + sympy.core.expr.Expr.equals + """ + if self.shape != getattr(other, 'shape', None): + return False + + rv = True + for i in range(self.rows): + for j in range(self.cols): + ans = self[i, j].equals(other[i, j], failing_expression) + if ans is False: + return False + elif ans is not True and rv is True: + rv = ans + return rv + + def inv_mod(M, m): + r""" + Returns the inverse of the integer matrix ``M`` modulo ``m``. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix(2, 2, [1, 2, 3, 4]) + >>> A.inv_mod(5) + Matrix([ + [3, 1], + [4, 2]]) + >>> A.inv_mod(3) + Matrix([ + [1, 1], + [0, 1]]) + + """ + + if not M.is_square: + raise NonSquareMatrixError() + + try: + m = as_int(m) + except ValueError: + raise TypeError("inv_mod: modulus m must be an integer") + + K = GF(m, symmetric=False) + + try: + dM = M.to_DM(K) + except CoercionFailed: + raise ValueError("inv_mod: matrix entries must be integers") + + if K.is_Field: + try: + dMi = dM.inv() + except DMNonInvertibleMatrixError as exc: + msg = f'Matrix is not invertible (mod {m})' + raise NonInvertibleMatrixError(msg) from exc + else: + dMadj, det = dM.adj_det() + try: + detinv = 1 / det + except NotInvertible: + msg = f'Matrix is not invertible (mod {m})' + raise NonInvertibleMatrixError(msg) + dMi = dMadj * detinv + + return dMi.to_Matrix() + + def lll(self, delta=0.75): + """LLL-reduced basis for the rowspace of a matrix of integers. + + Performs the Lenstra–Lenstra–Lovász (LLL) basis reduction algorithm. + + The implementation is provided by :class:`~DomainMatrix`. See + :meth:`~DomainMatrix.lll` for more details. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 0, 0, 0, -20160], + ... [0, 1, 0, 0, 33768], + ... [0, 0, 1, 0, 39578], + ... [0, 0, 0, 1, 47757]]) + >>> M.lll() + Matrix([ + [ 10, -3, -2, 8, -4], + [ 3, -9, 8, 1, -11], + [ -3, 13, -9, -3, -9], + [-12, -7, -11, 9, -1]]) + + See Also + ======== + + lll_transform + sympy.polys.matrices.domainmatrix.DomainMatrix.lll + """ + delta = QQ.from_sympy(_sympify(delta)) + dM = self._rep.convert_to(ZZ) + basis = dM.lll(delta=delta) + return self._fromrep(basis) + + def lll_transform(self, delta=0.75): + """LLL-reduced basis and transformation matrix. + + Performs the Lenstra–Lenstra–Lovász (LLL) basis reduction algorithm. + + The implementation is provided by :class:`~DomainMatrix`. See + :meth:`~DomainMatrix.lll_transform` for more details. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 0, 0, 0, -20160], + ... [0, 1, 0, 0, 33768], + ... [0, 0, 1, 0, 39578], + ... [0, 0, 0, 1, 47757]]) + >>> B, T = M.lll_transform() + >>> B + Matrix([ + [ 10, -3, -2, 8, -4], + [ 3, -9, 8, 1, -11], + [ -3, 13, -9, -3, -9], + [-12, -7, -11, 9, -1]]) + >>> T + Matrix([ + [ 10, -3, -2, 8], + [ 3, -9, 8, 1], + [ -3, 13, -9, -3], + [-12, -7, -11, 9]]) + + The transformation matrix maps the original basis to the LLL-reduced + basis: + + >>> T * M == B + True + + See Also + ======== + + lll + sympy.polys.matrices.domainmatrix.DomainMatrix.lll_transform + """ + delta = QQ.from_sympy(_sympify(delta)) + dM = self._rep.convert_to(ZZ) + basis, transform = dM.lll_transform(delta=delta) + B = self._fromrep(basis) + T = self._fromrep(transform) + return B, T + + +class MutableRepMatrix(RepMatrix): + """Mutable matrix based on DomainMatrix as the internal representation""" + + # + # MutableRepMatrix is a subclass of RepMatrix that adds/overrides methods + # to make the instances mutable. MutableRepMatrix is a superclass for both + # MutableDenseMatrix and MutableSparseMatrix. + # + + is_zero = False + + def __new__(cls, *args, **kwargs): + return cls._new(*args, **kwargs) + + @classmethod + def _new(cls, *args, copy=True, **kwargs): + if copy is False: + # The input was rows, cols, [list]. + # It should be used directly without creating a copy. + if len(args) != 3: + raise TypeError("'copy=False' requires a matrix be initialized as rows,cols,[list]") + rows, cols, flat_list = args + else: + rows, cols, flat_list = cls._handle_creation_inputs(*args, **kwargs) + flat_list = list(flat_list) # create a shallow copy + + rep = cls._flat_list_to_DomainMatrix(rows, cols, flat_list) + + return cls._fromrep(rep) + + @classmethod + def _fromrep(cls, rep): + obj = super().__new__(cls) + obj.rows, obj.cols = rep.shape + obj._rep = rep + return obj + + def copy(self): + return self._fromrep(self._rep.copy()) + + def as_mutable(self): + return self.copy() + + def __setitem__(self, key, value): + """ + + Examples + ======== + + >>> from sympy import Matrix, I, zeros, ones + >>> m = Matrix(((1, 2+I), (3, 4))) + >>> m + Matrix([ + [1, 2 + I], + [3, 4]]) + >>> m[1, 0] = 9 + >>> m + Matrix([ + [1, 2 + I], + [9, 4]]) + >>> m[1, 0] = [[0, 1]] + + To replace row r you assign to position r*m where m + is the number of columns: + + >>> M = zeros(4) + >>> m = M.cols + >>> M[3*m] = ones(1, m)*2; M + Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [2, 2, 2, 2]]) + + And to replace column c you can assign to position c: + + >>> M[2] = ones(m, 1)*4; M + Matrix([ + [0, 0, 4, 0], + [0, 0, 4, 0], + [0, 0, 4, 0], + [2, 2, 4, 2]]) + """ + rv = self._setitem(key, value) + if rv is not None: + i, j, value = rv + self._rep, value = self._unify_element_sympy(self._rep, value) + self._rep.rep.setitem(i, j, value) + + def _eval_col_del(self, col): + self._rep = DomainMatrix.hstack(self._rep[:,:col], self._rep[:,col+1:]) + self.cols -= 1 + + def _eval_row_del(self, row): + self._rep = DomainMatrix.vstack(self._rep[:row,:], self._rep[row+1:, :]) + self.rows -= 1 + + def _eval_col_insert(self, col, other): + other = self._new(other) + return self.hstack(self[:,:col], other, self[:,col:]) + + def _eval_row_insert(self, row, other): + other = self._new(other) + return self.vstack(self[:row,:], other, self[row:,:]) + + def col_op(self, j, f): + """In-place operation on col j using two-arg functor whose args are + interpreted as (self[i, j], i). + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.col_op(1, lambda v, i: v + 2*M[i, 0]); M + Matrix([ + [1, 2, 0], + [0, 1, 0], + [0, 0, 1]]) + + See Also + ======== + col + row_op + """ + for i in range(self.rows): + self[i, j] = f(self[i, j], i) + + def col_swap(self, i, j): + """Swap the two given columns of the matrix in-place. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 0], [1, 0]]) + >>> M + Matrix([ + [1, 0], + [1, 0]]) + >>> M.col_swap(0, 1) + >>> M + Matrix([ + [0, 1], + [0, 1]]) + + See Also + ======== + + col + row_swap + """ + for k in range(0, self.rows): + self[k, i], self[k, j] = self[k, j], self[k, i] + + def row_op(self, i, f): + """In-place operation on row ``i`` using two-arg functor whose args are + interpreted as ``(self[i, j], j)``. + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.row_op(1, lambda v, j: v + 2*M[0, j]); M + Matrix([ + [1, 0, 0], + [2, 1, 0], + [0, 0, 1]]) + + See Also + ======== + row + zip_row_op + col_op + + """ + for j in range(self.cols): + self[i, j] = f(self[i, j], j) + + #The next three methods give direct support for the most common row operations inplace. + def row_mult(self,i,factor): + """Multiply the given row by the given factor in-place. + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.row_mult(1,7); M + Matrix([ + [1, 0, 0], + [0, 7, 0], + [0, 0, 1]]) + + """ + for j in range(self.cols): + self[i,j] *= factor + + def row_add(self,s,t,k): + """Add k times row s (source) to row t (target) in place. + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.row_add(0, 2,3); M + Matrix([ + [1, 0, 0], + [0, 1, 0], + [3, 0, 1]]) + """ + + for j in range(self.cols): + self[t,j] += k*self[s,j] + + def row_swap(self, i, j): + """Swap the two given rows of the matrix in-place. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[0, 1], [1, 0]]) + >>> M + Matrix([ + [0, 1], + [1, 0]]) + >>> M.row_swap(0, 1) + >>> M + Matrix([ + [1, 0], + [0, 1]]) + + See Also + ======== + + row + col_swap + """ + for k in range(0, self.cols): + self[i, k], self[j, k] = self[j, k], self[i, k] + + def zip_row_op(self, i, k, f): + """In-place operation on row ``i`` using two-arg functor whose args are + interpreted as ``(self[i, j], self[k, j])``. + + Examples + ======== + + >>> from sympy import eye + >>> M = eye(3) + >>> M.zip_row_op(1, 0, lambda v, u: v + 2*u); M + Matrix([ + [1, 0, 0], + [2, 1, 0], + [0, 0, 1]]) + + See Also + ======== + row + row_op + col_op + + """ + for j in range(self.cols): + self[i, j] = f(self[i, j], self[k, j]) + + def copyin_list(self, key, value): + """Copy in elements from a list. + + Parameters + ========== + + key : slice + The section of this matrix to replace. + value : iterable + The iterable to copy values from. + + Examples + ======== + + >>> from sympy import eye + >>> I = eye(3) + >>> I[:2, 0] = [1, 2] # col + >>> I + Matrix([ + [1, 0, 0], + [2, 1, 0], + [0, 0, 1]]) + >>> I[1, :2] = [[3, 4]] + >>> I + Matrix([ + [1, 0, 0], + [3, 4, 0], + [0, 0, 1]]) + + See Also + ======== + + copyin_matrix + """ + if not is_sequence(value): + raise TypeError("`value` must be an ordered iterable, not %s." % type(value)) + return self.copyin_matrix(key, type(self)(value)) + + def copyin_matrix(self, key, value): + """Copy in values from a matrix into the given bounds. + + Parameters + ========== + + key : slice + The section of this matrix to replace. + value : Matrix + The matrix to copy values from. + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> M = Matrix([[0, 1], [2, 3], [4, 5]]) + >>> I = eye(3) + >>> I[:3, :2] = M + >>> I + Matrix([ + [0, 1, 0], + [2, 3, 0], + [4, 5, 1]]) + >>> I[0, 1] = M + >>> I + Matrix([ + [0, 0, 1], + [2, 2, 3], + [4, 4, 5]]) + + See Also + ======== + + copyin_list + """ + rlo, rhi, clo, chi = self.key2bounds(key) + shape = value.shape + dr, dc = rhi - rlo, chi - clo + if shape != (dr, dc): + raise ShapeError(filldedent("The Matrix `value` doesn't have the " + "same dimensions " + "as the in sub-Matrix given by `key`.")) + + for i in range(value.rows): + for j in range(value.cols): + self[i + rlo, j + clo] = value[i, j] + + def fill(self, value): + """Fill self with the given value. + + Notes + ===== + + Unless many values are going to be deleted (i.e. set to zero) + this will create a matrix that is slower than a dense matrix in + operations. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> M = SparseMatrix.zeros(3); M + Matrix([ + [0, 0, 0], + [0, 0, 0], + [0, 0, 0]]) + >>> M.fill(1); M + Matrix([ + [1, 1, 1], + [1, 1, 1], + [1, 1, 1]]) + + See Also + ======== + + zeros + ones + """ + value = _sympify(value) + if not value: + self._rep = DomainMatrix.zeros(self.shape, EXRAW) + else: + elements_dod = {i: dict.fromkeys(range(self.cols), value) for i in range(self.rows)} + self._rep = DomainMatrix(elements_dod, self.shape, EXRAW) + + +def _getitem_RepMatrix(self, key): + """Return portion of self defined by key. If the key involves a slice + then a list will be returned (if key is a single slice) or a matrix + (if key was a tuple involving a slice). + + Examples + ======== + + >>> from sympy import Matrix, I + >>> m = Matrix([ + ... [1, 2 + I], + ... [3, 4 ]]) + + If the key is a tuple that does not involve a slice then that element + is returned: + + >>> m[1, 0] + 3 + + When a tuple key involves a slice, a matrix is returned. Here, the + first column is selected (all rows, column 0): + + >>> m[:, 0] + Matrix([ + [1], + [3]]) + + If the slice is not a tuple then it selects from the underlying + list of elements that are arranged in row order and a list is + returned if a slice is involved: + + >>> m[0] + 1 + >>> m[::2] + [1, 3] + """ + if isinstance(key, tuple): + i, j = key + try: + return self._rep.getitem_sympy(index_(i), index_(j)) + except (TypeError, IndexError): + if (isinstance(i, Expr) and not i.is_number) or (isinstance(j, Expr) and not j.is_number): + if ((j < 0) is True) or ((j >= self.shape[1]) is True) or\ + ((i < 0) is True) or ((i >= self.shape[0]) is True): + raise ValueError("index out of boundary") + from sympy.matrices.expressions.matexpr import MatrixElement + return MatrixElement(self, i, j) + + if isinstance(i, slice): + i = range(self.rows)[i] + elif is_sequence(i): + pass + else: + i = [i] + if isinstance(j, slice): + j = range(self.cols)[j] + elif is_sequence(j): + pass + else: + j = [j] + return self.extract(i, j) + + else: + # Index/slice like a flattened list + rows, cols = self.shape + + # Raise the appropriate exception: + if not rows * cols: + return [][key] + + rep = self._rep.rep + domain = rep.domain + is_slice = isinstance(key, slice) + + if is_slice: + values = [rep.getitem(*divmod(n, cols)) for n in range(rows * cols)[key]] + else: + values = [rep.getitem(*divmod(index_(key), cols))] + + if domain != EXRAW: + to_sympy = domain.to_sympy + values = [to_sympy(val) for val in values] + + if is_slice: + return values + else: + return values[0] diff --git a/phivenv/Lib/site-packages/sympy/matrices/solvers.py b/phivenv/Lib/site-packages/sympy/matrices/solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..1fba990df80dcf46304ecb1412f5382f60948c51 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/solvers.py @@ -0,0 +1,942 @@ +from sympy.core.function import expand_mul +from sympy.core.symbol import Dummy, uniquely_named_symbol, symbols +from sympy.utilities.iterables import numbered_symbols + +from .exceptions import ShapeError, NonSquareMatrixError, NonInvertibleMatrixError +from .eigen import _fuzzy_positive_definite +from .utilities import _get_intermediate_simp, _iszero + + +def _diagonal_solve(M, rhs): + """Solves ``Ax = B`` efficiently, where A is a diagonal Matrix, + with non-zero diagonal entries. + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> A = eye(2)*2 + >>> B = Matrix([[1, 2], [3, 4]]) + >>> A.diagonal_solve(B) == B/2 + True + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + gauss_jordan_solve + cholesky_solve + LDLsolve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + if not M.is_diagonal(): + raise TypeError("Matrix should be diagonal") + if rhs.rows != M.rows: + raise TypeError("Size mismatch") + + return M._new( + rhs.rows, rhs.cols, lambda i, j: rhs[i, j] / M[i, i]) + + +def _lower_triangular_solve(M, rhs): + """Solves ``Ax = B``, where A is a lower triangular matrix. + + See Also + ======== + + upper_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + from .dense import MutableDenseMatrix + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if rhs.rows != M.rows: + raise ShapeError("Matrices size mismatch.") + if not M.is_lower: + raise ValueError("Matrix must be lower triangular.") + + dps = _get_intermediate_simp() + X = MutableDenseMatrix.zeros(M.rows, rhs.cols) + + for j in range(rhs.cols): + for i in range(M.rows): + if M[i, i] == 0: + raise TypeError("Matrix must be non-singular.") + + X[i, j] = dps((rhs[i, j] - sum(M[i, k]*X[k, j] + for k in range(i))) / M[i, i]) + + return M._new(X) + +def _lower_triangular_solve_sparse(M, rhs): + """Solves ``Ax = B``, where A is a lower triangular matrix. + + See Also + ======== + + upper_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if rhs.rows != M.rows: + raise ShapeError("Matrices size mismatch.") + if not M.is_lower: + raise ValueError("Matrix must be lower triangular.") + + dps = _get_intermediate_simp() + rows = [[] for i in range(M.rows)] + + for i, j, v in M.row_list(): + if i > j: + rows[i].append((j, v)) + + X = rhs.as_mutable() + + for j in range(rhs.cols): + for i in range(rhs.rows): + for u, v in rows[i]: + X[i, j] -= v*X[u, j] + + X[i, j] = dps(X[i, j] / M[i, i]) + + return M._new(X) + + +def _upper_triangular_solve(M, rhs): + """Solves ``Ax = B``, where A is an upper triangular matrix. + + See Also + ======== + + lower_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + from .dense import MutableDenseMatrix + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if rhs.rows != M.rows: + raise ShapeError("Matrix size mismatch.") + if not M.is_upper: + raise TypeError("Matrix is not upper triangular.") + + dps = _get_intermediate_simp() + X = MutableDenseMatrix.zeros(M.rows, rhs.cols) + + for j in range(rhs.cols): + for i in reversed(range(M.rows)): + if M[i, i] == 0: + raise ValueError("Matrix must be non-singular.") + + X[i, j] = dps((rhs[i, j] - sum(M[i, k]*X[k, j] + for k in range(i + 1, M.rows))) / M[i, i]) + + return M._new(X) + +def _upper_triangular_solve_sparse(M, rhs): + """Solves ``Ax = B``, where A is an upper triangular matrix. + + See Also + ======== + + lower_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + if not M.is_square: + raise NonSquareMatrixError("Matrix must be square.") + if rhs.rows != M.rows: + raise ShapeError("Matrix size mismatch.") + if not M.is_upper: + raise TypeError("Matrix is not upper triangular.") + + dps = _get_intermediate_simp() + rows = [[] for i in range(M.rows)] + + for i, j, v in M.row_list(): + if i < j: + rows[i].append((j, v)) + + X = rhs.as_mutable() + + for j in range(rhs.cols): + for i in reversed(range(rhs.rows)): + for u, v in reversed(rows[i]): + X[i, j] -= v*X[u, j] + + X[i, j] = dps(X[i, j] / M[i, i]) + + return M._new(X) + + +def _cholesky_solve(M, rhs): + """Solves ``Ax = B`` using Cholesky decomposition, + for a general square non-singular matrix. + For a non-square matrix with rows > cols, + the least squares solution is returned. + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + gauss_jordan_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + if M.rows < M.cols: + raise NotImplementedError( + 'Under-determined System. Try M.gauss_jordan_solve(rhs)') + + hermitian = True + reform = False + + if M.is_symmetric(): + hermitian = False + elif not M.is_hermitian: + reform = True + + if reform or _fuzzy_positive_definite(M) is False: + H = M.H + M = H.multiply(M) + rhs = H.multiply(rhs) + hermitian = not M.is_symmetric() + + L = M.cholesky(hermitian=hermitian) + Y = L.lower_triangular_solve(rhs) + + if hermitian: + return (L.H).upper_triangular_solve(Y) + else: + return (L.T).upper_triangular_solve(Y) + + +def _LDLsolve(M, rhs): + """Solves ``Ax = B`` using LDL decomposition, + for a general square and non-singular matrix. + + For a non-square matrix with rows > cols, + the least squares solution is returned. + + Examples + ======== + + >>> from sympy import Matrix, eye + >>> A = eye(2)*2 + >>> B = Matrix([[1, 2], [3, 4]]) + >>> A.LDLsolve(B) == B/2 + True + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.LDLdecomposition + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LUsolve + QRsolve + pinv_solve + cramer_solve + """ + + if M.rows < M.cols: + raise NotImplementedError( + 'Under-determined System. Try M.gauss_jordan_solve(rhs)') + + hermitian = True + reform = False + + if M.is_symmetric(): + hermitian = False + elif not M.is_hermitian: + reform = True + + if reform or _fuzzy_positive_definite(M) is False: + H = M.H + M = H.multiply(M) + rhs = H.multiply(rhs) + hermitian = not M.is_symmetric() + + L, D = M.LDLdecomposition(hermitian=hermitian) + Y = L.lower_triangular_solve(rhs) + Z = D.diagonal_solve(Y) + + if hermitian: + return (L.H).upper_triangular_solve(Z) + else: + return (L.T).upper_triangular_solve(Z) + + +def _LUsolve(M, rhs, iszerofunc=_iszero): + """Solve the linear system ``Ax = rhs`` for ``x`` where ``A = M``. + + This is for symbolic matrices, for real or complex ones use + mpmath.lu_solve or mpmath.qr_solve. + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + QRsolve + pinv_solve + LUdecomposition + cramer_solve + """ + + if rhs.rows != M.rows: + raise ShapeError( + "``M`` and ``rhs`` must have the same number of rows.") + + m = M.rows + n = M.cols + + if m < n: + raise NotImplementedError("Underdetermined systems not supported.") + + try: + A, perm = M.LUdecomposition_Simple( + iszerofunc=iszerofunc, rankcheck=True) + except ValueError: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + + dps = _get_intermediate_simp() + b = rhs.permute_rows(perm).as_mutable() + + # forward substitution, all diag entries are scaled to 1 + for i in range(m): + for j in range(min(i, n)): + scale = A[i, j] + b.zip_row_op(i, j, lambda x, y: dps(x - scale * y)) + + # consistency check for overdetermined systems + if m > n: + for i in range(n, m): + for j in range(b.cols): + if not iszerofunc(b[i, j]): + raise ValueError("The system is inconsistent.") + + b = b[0:n, :] # truncate zero rows if consistent + + # backward substitution + for i in range(n - 1, -1, -1): + for j in range(i + 1, n): + scale = A[i, j] + b.zip_row_op(i, j, lambda x, y: dps(x - scale * y)) + + scale = A[i, i] + b.row_op(i, lambda x, _: dps(scale**-1 * x)) + + return rhs.__class__(b) + + +def _QRsolve(M, b): + """Solve the linear system ``Ax = b``. + + ``M`` is the matrix ``A``, the method argument is the vector + ``b``. The method returns the solution vector ``x``. If ``b`` is a + matrix, the system is solved for each column of ``b`` and the + return value is a matrix of the same shape as ``b``. + + This method is slower (approximately by a factor of 2) but + more stable for floating-point arithmetic than the LUsolve method. + However, LUsolve usually uses an exact arithmetic, so you do not need + to use QRsolve. + + This is mainly for educational purposes and symbolic matrices, for real + (or complex) matrices use mpmath.qr_solve. + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + pinv_solve + QRdecomposition + cramer_solve + """ + + dps = _get_intermediate_simp(expand_mul, expand_mul) + Q, R = M.QRdecomposition() + y = Q.T * b + + # back substitution to solve R*x = y: + # We build up the result "backwards" in the vector 'x' and reverse it + # only in the end. + x = [] + n = R.rows + + for j in range(n - 1, -1, -1): + tmp = y[j, :] + + for k in range(j + 1, n): + tmp -= R[j, k] * x[n - 1 - k] + + tmp = dps(tmp) + + x.append(tmp / R[j, j]) + + return M.vstack(*x[::-1]) + + +def _gauss_jordan_solve(M, B, freevar=False): + """ + Solves ``Ax = B`` using Gauss Jordan elimination. + + There may be zero, one, or infinite solutions. If one solution + exists, it will be returned. If infinite solutions exist, it will + be returned parametrically. If no solutions exist, It will throw + ValueError. + + Parameters + ========== + + B : Matrix + The right hand side of the equation to be solved for. Must have + the same number of rows as matrix A. + + freevar : boolean, optional + Flag, when set to `True` will return the indices of the free + variables in the solutions (column Matrix), for a system that is + undetermined (e.g. A has more columns than rows), for which + infinite solutions are possible, in terms of arbitrary + values of free variables. Default `False`. + + Returns + ======= + + x : Matrix + The matrix that will satisfy ``Ax = B``. Will have as many rows as + matrix A has columns, and as many columns as matrix B. + + params : Matrix + If the system is underdetermined (e.g. A has more columns than + rows), infinite solutions are possible, in terms of arbitrary + parameters. These arbitrary parameters are returned as params + Matrix. + + free_var_index : List, optional + If the system is underdetermined (e.g. A has more columns than + rows), infinite solutions are possible, in terms of arbitrary + values of free variables. Then the indices of the free variables + in the solutions (column Matrix) are returned by free_var_index, + if the flag `freevar` is set to `True`. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2, 1, 1], [1, 2, 2, -1], [2, 4, 0, 6]]) + >>> B = Matrix([7, 12, 4]) + >>> sol, params = A.gauss_jordan_solve(B) + >>> sol + Matrix([ + [-2*tau0 - 3*tau1 + 2], + [ tau0], + [ 2*tau1 + 5], + [ tau1]]) + >>> params + Matrix([ + [tau0], + [tau1]]) + >>> taus_zeroes = { tau:0 for tau in params } + >>> sol_unique = sol.xreplace(taus_zeroes) + >>> sol_unique + Matrix([ + [2], + [0], + [5], + [0]]) + + + >>> A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + >>> B = Matrix([3, 6, 9]) + >>> sol, params = A.gauss_jordan_solve(B) + >>> sol + Matrix([ + [-1], + [ 2], + [ 0]]) + >>> params + Matrix(0, 1, []) + + >>> A = Matrix([[2, -7], [-1, 4]]) + >>> B = Matrix([[-21, 3], [12, -2]]) + >>> sol, params = A.gauss_jordan_solve(B) + >>> sol + Matrix([ + [0, -2], + [3, -1]]) + >>> params + Matrix(0, 2, []) + + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2, 1, 1], [1, 2, 2, -1], [2, 4, 0, 6]]) + >>> B = Matrix([7, 12, 4]) + >>> sol, params, freevars = A.gauss_jordan_solve(B, freevar=True) + >>> sol + Matrix([ + [-2*tau0 - 3*tau1 + 2], + [ tau0], + [ 2*tau1 + 5], + [ tau1]]) + >>> params + Matrix([ + [tau0], + [tau1]]) + >>> freevars + [1, 3] + + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gaussian_elimination + + """ + + from sympy.matrices import Matrix, zeros + + cls = M.__class__ + aug = M.hstack(M.copy(), B.copy()) + B_cols = B.cols + row, col = aug[:, :-B_cols].shape + + # solve by reduced row echelon form + A, pivots = aug.rref(simplify=True) + A, v = A[:, :-B_cols], A[:, -B_cols:] + pivots = list(filter(lambda p: p < col, pivots)) + rank = len(pivots) + + # Get index of free symbols (free parameters) + # non-pivots columns are free variables + free_var_index = [c for c in range(A.cols) if c not in pivots] + + # Bring to block form + permutation = Matrix(pivots + free_var_index).T + + # check for existence of solutions + # rank of aug Matrix should be equal to rank of coefficient matrix + if not v[rank:, :].is_zero_matrix: + raise ValueError("Linear system has no solution") + + # Free parameters + # what are current unnumbered free symbol names? + name = uniquely_named_symbol('tau', [aug], + compare=lambda i: str(i).rstrip('1234567890'), + modify=lambda s: '_' + s).name + gen = numbered_symbols(name) + tau = Matrix([next(gen) for k in range((col - rank)*B_cols)]).reshape( + col - rank, B_cols) + + # Full parametric solution + V = A[:rank, free_var_index] + vt = v[:rank, :] + free_sol = tau.vstack(vt - V * tau, tau) + + # Undo permutation + sol = zeros(col, B_cols) + + for k in range(col): + sol[permutation[k], :] = free_sol[k,:] + + sol, tau = cls(sol), cls(tau) + + if freevar: + return sol, tau, free_var_index + else: + return sol, tau + + +def _pinv_solve(M, B, arbitrary_matrix=None): + """Solve ``Ax = B`` using the Moore-Penrose pseudoinverse. + + There may be zero, one, or infinite solutions. If one solution + exists, it will be returned. If infinite solutions exist, one will + be returned based on the value of arbitrary_matrix. If no solutions + exist, the least-squares solution is returned. + + Parameters + ========== + + B : Matrix + The right hand side of the equation to be solved for. Must have + the same number of rows as matrix A. + arbitrary_matrix : Matrix + If the system is underdetermined (e.g. A has more columns than + rows), infinite solutions are possible, in terms of an arbitrary + matrix. This parameter may be set to a specific matrix to use + for that purpose; if so, it must be the same shape as x, with as + many rows as matrix A has columns, and as many columns as matrix + B. If left as None, an appropriate matrix containing dummy + symbols in the form of ``wn_m`` will be used, with n and m being + row and column position of each symbol. + + Returns + ======= + + x : Matrix + The matrix that will satisfy ``Ax = B``. Will have as many rows as + matrix A has columns, and as many columns as matrix B. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> B = Matrix([7, 8]) + >>> A.pinv_solve(B) + Matrix([ + [ _w0_0/6 - _w1_0/3 + _w2_0/6 - 55/18], + [-_w0_0/3 + 2*_w1_0/3 - _w2_0/3 + 1/9], + [ _w0_0/6 - _w1_0/3 + _w2_0/6 + 59/18]]) + >>> A.pinv_solve(B, arbitrary_matrix=Matrix([0, 0, 0])) + Matrix([ + [-55/18], + [ 1/9], + [ 59/18]]) + + See Also + ======== + + sympy.matrices.dense.DenseMatrix.lower_triangular_solve + sympy.matrices.dense.DenseMatrix.upper_triangular_solve + gauss_jordan_solve + cholesky_solve + diagonal_solve + LDLsolve + LUsolve + QRsolve + pinv + + Notes + ===== + + This may return either exact solutions or least squares solutions. + To determine which, check ``A * A.pinv() * B == B``. It will be + True if exact solutions exist, and False if only a least-squares + solution exists. Be aware that the left hand side of that equation + may need to be simplified to correctly compare to the right hand + side. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Moore-Penrose_pseudoinverse#Obtaining_all_solutions_of_a_linear_system + + """ + + from sympy.matrices import eye + + A = M + A_pinv = M.pinv() + + if arbitrary_matrix is None: + rows, cols = A.cols, B.cols + w = symbols('w:{}_:{}'.format(rows, cols), cls=Dummy) + arbitrary_matrix = M.__class__(cols, rows, w).T + + return A_pinv.multiply(B) + (eye(A.cols) - + A_pinv.multiply(A)).multiply(arbitrary_matrix) + + +def _cramer_solve(M, rhs, det_method="laplace"): + """Solves system of linear equations using Cramer's rule. + + This method is relatively inefficient compared to other methods. + However it only uses a single division, assuming a division-free determinant + method is provided. This is helpful to minimize the chance of divide-by-zero + cases in symbolic solutions to linear systems. + + Parameters + ========== + M : Matrix + The matrix representing the left hand side of the equation. + rhs : Matrix + The matrix representing the right hand side of the equation. + det_method : str or callable + The method to use to calculate the determinant of the matrix. + The default is ``'laplace'``. If a callable is passed, it should take a + single argument, the matrix, and return the determinant of the matrix. + + Returns + ======= + x : Matrix + The matrix that will satisfy ``Ax = B``. Will have as many rows as + matrix A has columns, and as many columns as matrix B. + + Examples + ======== + + >>> from sympy import Matrix + >>> A = Matrix([[0, -6, 1], [0, -6, -1], [-5, -2, 3]]) + >>> B = Matrix([[-30, -9], [-18, -27], [-26, 46]]) + >>> x = A.cramer_solve(B) + >>> x + Matrix([ + [ 0, -5], + [ 4, 3], + [-6, 9]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Cramer%27s_rule#Explicit_formulas_for_small_systems + + """ + from .dense import zeros + + def entry(i, j): + return rhs[i, sol] if j == col else M[i, j] + + if det_method == "bird": + from .determinant import _det_bird + det = _det_bird + elif det_method == "laplace": + from .determinant import _det_laplace + det = _det_laplace + elif isinstance(det_method, str): + det = lambda matrix: matrix.det(method=det_method) + else: + det = det_method + det_M = det(M) + x = zeros(*rhs.shape) + for sol in range(rhs.shape[1]): + for col in range(rhs.shape[0]): + x[col, sol] = det(M.__class__(*M.shape, entry)) / det_M + return M.__class__(x) + + +def _solve(M, rhs, method='GJ'): + """Solves linear equation where the unique solution exists. + + Parameters + ========== + + rhs : Matrix + Vector representing the right hand side of the linear equation. + + method : string, optional + If set to ``'GJ'`` or ``'GE'``, the Gauss-Jordan elimination will be + used, which is implemented in the routine ``gauss_jordan_solve``. + + If set to ``'LU'``, ``LUsolve`` routine will be used. + + If set to ``'QR'``, ``QRsolve`` routine will be used. + + If set to ``'PINV'``, ``pinv_solve`` routine will be used. + + If set to ``'CRAMER'``, ``cramer_solve`` routine will be used. + + It also supports the methods available for special linear systems + + For positive definite systems: + + If set to ``'CH'``, ``cholesky_solve`` routine will be used. + + If set to ``'LDL'``, ``LDLsolve`` routine will be used. + + To use a different method and to compute the solution via the + inverse, use a method defined in the .inv() docstring. + + Returns + ======= + + solutions : Matrix + Vector representing the solution. + + Raises + ====== + + ValueError + If there is not a unique solution then a ``ValueError`` will be + raised. + + If ``M`` is not square, a ``ValueError`` and a different routine + for solving the system will be suggested. + """ + + if method in ('GJ', 'GE'): + try: + soln, param = M.gauss_jordan_solve(rhs) + + if param: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible. " + "Try ``M.gauss_jordan_solve(rhs)`` to obtain a parametric solution.") + + except ValueError: + raise NonInvertibleMatrixError("Matrix det == 0; not invertible.") + + return soln + + elif method == 'LU': + return M.LUsolve(rhs) + elif method == 'CH': + return M.cholesky_solve(rhs) + elif method == 'QR': + return M.QRsolve(rhs) + elif method == 'LDL': + return M.LDLsolve(rhs) + elif method == 'PINV': + return M.pinv_solve(rhs) + elif method == 'CRAMER': + return M.cramer_solve(rhs) + else: + return M.inv(method=method).multiply(rhs) + + +def _solve_least_squares(M, rhs, method='CH'): + """Return the least-square fit to the data. + + Parameters + ========== + + rhs : Matrix + Vector representing the right hand side of the linear equation. + + method : string or boolean, optional + If set to ``'CH'``, ``cholesky_solve`` routine will be used. + + If set to ``'LDL'``, ``LDLsolve`` routine will be used. + + If set to ``'QR'``, ``QRsolve`` routine will be used. + + If set to ``'PINV'``, ``pinv_solve`` routine will be used. + + Otherwise, the conjugate of ``M`` will be used to create a system + of equations that is passed to ``solve`` along with the hint + defined by ``method``. + + Returns + ======= + + solutions : Matrix + Vector representing the solution. + + Examples + ======== + + >>> from sympy import Matrix, ones + >>> A = Matrix([1, 2, 3]) + >>> B = Matrix([2, 3, 4]) + >>> S = Matrix(A.row_join(B)) + >>> S + Matrix([ + [1, 2], + [2, 3], + [3, 4]]) + + If each line of S represent coefficients of Ax + By + and x and y are [2, 3] then S*xy is: + + >>> r = S*Matrix([2, 3]); r + Matrix([ + [ 8], + [13], + [18]]) + + But let's add 1 to the middle value and then solve for the + least-squares value of xy: + + >>> xy = S.solve_least_squares(Matrix([8, 14, 18])); xy + Matrix([ + [ 5/3], + [10/3]]) + + The error is given by S*xy - r: + + >>> S*xy - r + Matrix([ + [1/3], + [1/3], + [1/3]]) + >>> _.norm().n(2) + 0.58 + + If a different xy is used, the norm will be higher: + + >>> xy += ones(2, 1)/10 + >>> (S*xy - r).norm().n(2) + 1.5 + + """ + + if method == 'CH': + return M.cholesky_solve(rhs) + elif method == 'QR': + return M.QRsolve(rhs) + elif method == 'LDL': + return M.LDLsolve(rhs) + elif method == 'PINV': + return M.pinv_solve(rhs) + else: + t = M.H + return (t * M).solve(t * rhs, method=method) diff --git a/phivenv/Lib/site-packages/sympy/matrices/sparse.py b/phivenv/Lib/site-packages/sympy/matrices/sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..95a7b3ca0ac29cf4409ec1eeecd059f9643e9bbc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/sparse.py @@ -0,0 +1,473 @@ +from collections.abc import Callable + +from sympy.core.containers import Dict +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import is_sequence +from sympy.utilities.misc import as_int + +from .matrixbase import MatrixBase +from .repmatrix import MutableRepMatrix, RepMatrix + +from .utilities import _iszero + +from .decompositions import ( + _liupc, _row_structure_symbolic_cholesky, _cholesky_sparse, + _LDLdecomposition_sparse) + +from .solvers import ( + _lower_triangular_solve_sparse, _upper_triangular_solve_sparse) + + +class SparseRepMatrix(RepMatrix): + """ + A sparse matrix (a matrix with a large number of zero elements). + + Examples + ======== + + >>> from sympy import SparseMatrix, ones + >>> SparseMatrix(2, 2, range(4)) + Matrix([ + [0, 1], + [2, 3]]) + >>> SparseMatrix(2, 2, {(1, 1): 2}) + Matrix([ + [0, 0], + [0, 2]]) + + A SparseMatrix can be instantiated from a ragged list of lists: + + >>> SparseMatrix([[1, 2, 3], [1, 2], [1]]) + Matrix([ + [1, 2, 3], + [1, 2, 0], + [1, 0, 0]]) + + For safety, one may include the expected size and then an error + will be raised if the indices of any element are out of range or + (for a flat list) if the total number of elements does not match + the expected shape: + + >>> SparseMatrix(2, 2, [1, 2]) + Traceback (most recent call last): + ... + ValueError: List length (2) != rows*columns (4) + + Here, an error is not raised because the list is not flat and no + element is out of range: + + >>> SparseMatrix(2, 2, [[1, 2]]) + Matrix([ + [1, 2], + [0, 0]]) + + But adding another element to the first (and only) row will cause + an error to be raised: + + >>> SparseMatrix(2, 2, [[1, 2, 3]]) + Traceback (most recent call last): + ... + ValueError: The location (0, 2) is out of designated range: (1, 1) + + To autosize the matrix, pass None for rows: + + >>> SparseMatrix(None, [[1, 2, 3]]) + Matrix([[1, 2, 3]]) + >>> SparseMatrix(None, {(1, 1): 1, (3, 3): 3}) + Matrix([ + [0, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 3]]) + + Values that are themselves a Matrix are automatically expanded: + + >>> SparseMatrix(4, 4, {(1, 1): ones(2)}) + Matrix([ + [0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]]) + + A ValueError is raised if the expanding matrix tries to overwrite + a different element already present: + + >>> SparseMatrix(3, 3, {(0, 0): ones(2), (1, 1): 2}) + Traceback (most recent call last): + ... + ValueError: collision at (1, 1) + + See Also + ======== + DenseMatrix + MutableSparseMatrix + ImmutableSparseMatrix + """ + + @classmethod + def _handle_creation_inputs(cls, *args, **kwargs): + if len(args) == 1 and isinstance(args[0], MatrixBase): + rows = args[0].rows + cols = args[0].cols + smat = args[0].todok() + return rows, cols, smat + + smat = {} + # autosizing + if len(args) == 2 and args[0] is None: + args = [None, None, args[1]] + + if len(args) == 3: + r, c = args[:2] + if r is c is None: + rows = cols = None + elif None in (r, c): + raise ValueError( + 'Pass rows=None and no cols for autosizing.') + else: + rows, cols = as_int(args[0]), as_int(args[1]) + + if isinstance(args[2], Callable): + op = args[2] + + if None in (rows, cols): + raise ValueError( + "{} and {} must be integers for this " + "specification.".format(rows, cols)) + + row_indices = [cls._sympify(i) for i in range(rows)] + col_indices = [cls._sympify(j) for j in range(cols)] + + for i in row_indices: + for j in col_indices: + value = cls._sympify(op(i, j)) + if value != cls.zero: + smat[i, j] = value + + return rows, cols, smat + + elif isinstance(args[2], (dict, Dict)): + def update(i, j, v): + # update smat and make sure there are no collisions + if v: + if (i, j) in smat and v != smat[i, j]: + raise ValueError( + "There is a collision at {} for {} and {}." + .format((i, j), v, smat[i, j]) + ) + smat[i, j] = v + + # manual copy, copy.deepcopy() doesn't work + for (r, c), v in args[2].items(): + if isinstance(v, MatrixBase): + for (i, j), vv in v.todok().items(): + update(r + i, c + j, vv) + elif isinstance(v, (list, tuple)): + _, _, smat = cls._handle_creation_inputs(v, **kwargs) + for i, j in smat: + update(r + i, c + j, smat[i, j]) + else: + v = cls._sympify(v) + update(r, c, cls._sympify(v)) + + elif is_sequence(args[2]): + flat = not any(is_sequence(i) for i in args[2]) + if not flat: + _, _, smat = \ + cls._handle_creation_inputs(args[2], **kwargs) + else: + flat_list = args[2] + if len(flat_list) != rows * cols: + raise ValueError( + "The length of the flat list ({}) does not " + "match the specified size ({} * {})." + .format(len(flat_list), rows, cols) + ) + + for i in range(rows): + for j in range(cols): + value = flat_list[i*cols + j] + value = cls._sympify(value) + if value != cls.zero: + smat[i, j] = value + + if rows is None: # autosizing + keys = smat.keys() + rows = max(r for r, _ in keys) + 1 if keys else 0 + cols = max(c for _, c in keys) + 1 if keys else 0 + + else: + for i, j in smat.keys(): + if i and i >= rows or j and j >= cols: + raise ValueError( + "The location {} is out of the designated range" + "[{}, {}]x[{}, {}]" + .format((i, j), 0, rows - 1, 0, cols - 1) + ) + + return rows, cols, smat + + elif len(args) == 1 and isinstance(args[0], (list, tuple)): + # list of values or lists + v = args[0] + c = 0 + for i, row in enumerate(v): + if not isinstance(row, (list, tuple)): + row = [row] + for j, vv in enumerate(row): + if vv != cls.zero: + smat[i, j] = cls._sympify(vv) + c = max(c, len(row)) + rows = len(v) if c else 0 + cols = c + return rows, cols, smat + + else: + # handle full matrix forms with _handle_creation_inputs + rows, cols, mat = super()._handle_creation_inputs(*args) + for i in range(rows): + for j in range(cols): + value = mat[cols*i + j] + if value != cls.zero: + smat[i, j] = value + + return rows, cols, smat + + @property + def _smat(self): + + sympy_deprecation_warning( + """ + The private _smat attribute of SparseMatrix is deprecated. Use the + .todok() method instead. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-private-matrix-attributes" + ) + + return self.todok() + + def _eval_inverse(self, **kwargs): + return self.inv(method=kwargs.get('method', 'LDL'), + iszerofunc=kwargs.get('iszerofunc', _iszero), + try_block_diag=kwargs.get('try_block_diag', False)) + + def applyfunc(self, f): + """Apply a function to each element of the matrix. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> m = SparseMatrix(2, 2, lambda i, j: i*2+j) + >>> m + Matrix([ + [0, 1], + [2, 3]]) + >>> m.applyfunc(lambda i: 2*i) + Matrix([ + [0, 2], + [4, 6]]) + + """ + if not callable(f): + raise TypeError("`f` must be callable.") + + # XXX: This only applies the function to the nonzero elements of the + # matrix so is inconsistent with DenseMatrix.applyfunc e.g. + # zeros(2, 2).applyfunc(lambda x: x + 1) + dok = {} + for k, v in self.todok().items(): + fv = f(v) + if fv != 0: + dok[k] = fv + + return self._new(self.rows, self.cols, dok) + + def as_immutable(self): + """Returns an Immutable version of this Matrix.""" + from .immutable import ImmutableSparseMatrix + return ImmutableSparseMatrix(self) + + def as_mutable(self): + """Returns a mutable version of this matrix. + + Examples + ======== + + >>> from sympy import ImmutableMatrix + >>> X = ImmutableMatrix([[1, 2], [3, 4]]) + >>> Y = X.as_mutable() + >>> Y[1, 1] = 5 # Can set values in Y + >>> Y + Matrix([ + [1, 2], + [3, 5]]) + """ + return MutableSparseMatrix(self) + + def col_list(self): + """Returns a column-sorted list of non-zero elements of the matrix. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> a=SparseMatrix(((1, 2), (3, 4))) + >>> a + Matrix([ + [1, 2], + [3, 4]]) + >>> a.CL + [(0, 0, 1), (1, 0, 3), (0, 1, 2), (1, 1, 4)] + + See Also + ======== + + sympy.matrices.sparse.SparseMatrix.row_list + """ + return [tuple(k + (self[k],)) for k in sorted(self.todok().keys(), key=lambda k: list(reversed(k)))] + + def nnz(self): + """Returns the number of non-zero elements in Matrix.""" + return len(self.todok()) + + def row_list(self): + """Returns a row-sorted list of non-zero elements of the matrix. + + Examples + ======== + + >>> from sympy import SparseMatrix + >>> a = SparseMatrix(((1, 2), (3, 4))) + >>> a + Matrix([ + [1, 2], + [3, 4]]) + >>> a.RL + [(0, 0, 1), (0, 1, 2), (1, 0, 3), (1, 1, 4)] + + See Also + ======== + + sympy.matrices.sparse.SparseMatrix.col_list + """ + return [tuple(k + (self[k],)) for k in + sorted(self.todok().keys(), key=list)] + + def scalar_multiply(self, scalar): + "Scalar element-wise multiplication" + return scalar * self + + def solve_least_squares(self, rhs, method='LDL'): + """Return the least-square fit to the data. + + By default the cholesky_solve routine is used (method='CH'); other + methods of matrix inversion can be used. To find out which are + available, see the docstring of the .inv() method. + + Examples + ======== + + >>> from sympy import SparseMatrix, Matrix, ones + >>> A = Matrix([1, 2, 3]) + >>> B = Matrix([2, 3, 4]) + >>> S = SparseMatrix(A.row_join(B)) + >>> S + Matrix([ + [1, 2], + [2, 3], + [3, 4]]) + + If each line of S represent coefficients of Ax + By + and x and y are [2, 3] then S*xy is: + + >>> r = S*Matrix([2, 3]); r + Matrix([ + [ 8], + [13], + [18]]) + + But let's add 1 to the middle value and then solve for the + least-squares value of xy: + + >>> xy = S.solve_least_squares(Matrix([8, 14, 18])); xy + Matrix([ + [ 5/3], + [10/3]]) + + The error is given by S*xy - r: + + >>> S*xy - r + Matrix([ + [1/3], + [1/3], + [1/3]]) + >>> _.norm().n(2) + 0.58 + + If a different xy is used, the norm will be higher: + + >>> xy += ones(2, 1)/10 + >>> (S*xy - r).norm().n(2) + 1.5 + + """ + t = self.T + return (t*self).inv(method=method)*t*rhs + + def solve(self, rhs, method='LDL'): + """Return solution to self*soln = rhs using given inversion method. + + For a list of possible inversion methods, see the .inv() docstring. + """ + if not self.is_square: + if self.rows < self.cols: + raise ValueError('Under-determined system.') + elif self.rows > self.cols: + raise ValueError('For over-determined system, M, having ' + 'more rows than columns, try M.solve_least_squares(rhs).') + else: + return self.inv(method=method).multiply(rhs) + + RL = property(row_list, None, None, "Alternate faster representation") + CL = property(col_list, None, None, "Alternate faster representation") + + def liupc(self): + return _liupc(self) + + def row_structure_symbolic_cholesky(self): + return _row_structure_symbolic_cholesky(self) + + def cholesky(self, hermitian=True): + return _cholesky_sparse(self, hermitian=hermitian) + + def LDLdecomposition(self, hermitian=True): + return _LDLdecomposition_sparse(self, hermitian=hermitian) + + def lower_triangular_solve(self, rhs): + return _lower_triangular_solve_sparse(self, rhs) + + def upper_triangular_solve(self, rhs): + return _upper_triangular_solve_sparse(self, rhs) + + liupc.__doc__ = _liupc.__doc__ + row_structure_symbolic_cholesky.__doc__ = _row_structure_symbolic_cholesky.__doc__ + cholesky.__doc__ = _cholesky_sparse.__doc__ + LDLdecomposition.__doc__ = _LDLdecomposition_sparse.__doc__ + lower_triangular_solve.__doc__ = lower_triangular_solve.__doc__ + upper_triangular_solve.__doc__ = upper_triangular_solve.__doc__ + + +class MutableSparseMatrix(SparseRepMatrix, MutableRepMatrix): + + @classmethod + def _new(cls, *args, **kwargs): + rows, cols, smat = cls._handle_creation_inputs(*args, **kwargs) + + rep = cls._smat_to_DomainMatrix(rows, cols, smat) + + return cls._fromrep(rep) + + +SparseMatrix = MutableSparseMatrix diff --git a/phivenv/Lib/site-packages/sympy/matrices/sparsetools.py b/phivenv/Lib/site-packages/sympy/matrices/sparsetools.py new file mode 100644 index 0000000000000000000000000000000000000000..50048f6dc7e5cf160366963d16427987616ddce7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/sparsetools.py @@ -0,0 +1,300 @@ +from sympy.core.containers import Dict +from sympy.core.symbol import Dummy +from sympy.utilities.iterables import is_sequence +from sympy.utilities.misc import as_int, filldedent + +from .sparse import MutableSparseMatrix as SparseMatrix + + +def _doktocsr(dok): + """Converts a sparse matrix to Compressed Sparse Row (CSR) format. + + Parameters + ========== + + A : contains non-zero elements sorted by key (row, column) + JA : JA[i] is the column corresponding to A[i] + IA : IA[i] contains the index in A for the first non-zero element + of row[i]. Thus IA[i+1] - IA[i] gives number of non-zero + elements row[i]. The length of IA is always 1 more than the + number of rows in the matrix. + + Examples + ======== + + >>> from sympy.matrices.sparsetools import _doktocsr + >>> from sympy import SparseMatrix, diag + >>> m = SparseMatrix(diag(1, 2, 3)) + >>> m[2, 0] = -1 + >>> _doktocsr(m) + [[1, 2, -1, 3], [0, 1, 0, 2], [0, 1, 2, 4], [3, 3]] + + """ + row, JA, A = [list(i) for i in zip(*dok.row_list())] + IA = [0]*((row[0] if row else 0) + 1) + for i, r in enumerate(row): + IA.extend([i]*(r - row[i - 1])) # if i = 0 nothing is extended + IA.extend([len(A)]*(dok.rows - len(IA) + 1)) + shape = [dok.rows, dok.cols] + return [A, JA, IA, shape] + + +def _csrtodok(csr): + """Converts a CSR representation to DOK representation. + + Examples + ======== + + >>> from sympy.matrices.sparsetools import _csrtodok + >>> _csrtodok([[5, 8, 3, 6], [0, 1, 2, 1], [0, 0, 2, 3, 4], [4, 3]]) + Matrix([ + [0, 0, 0], + [5, 8, 0], + [0, 0, 3], + [0, 6, 0]]) + + """ + smat = {} + A, JA, IA, shape = csr + for i in range(len(IA) - 1): + indices = slice(IA[i], IA[i + 1]) + for l, m in zip(A[indices], JA[indices]): + smat[i, m] = l + return SparseMatrix(*shape, smat) + + +def banded(*args, **kwargs): + """Returns a SparseMatrix from the given dictionary describing + the diagonals of the matrix. The keys are positive for upper + diagonals and negative for those below the main diagonal. The + values may be: + + * expressions or single-argument functions, + + * lists or tuples of values, + + * matrices + + Unless dimensions are given, the size of the returned matrix will + be large enough to contain the largest non-zero value provided. + + kwargs + ====== + + rows : rows of the resulting matrix; computed if + not given. + + cols : columns of the resulting matrix; computed if + not given. + + Examples + ======== + + >>> from sympy import banded, ones, Matrix + >>> from sympy.abc import x + + If explicit values are given in tuples, + the matrix will autosize to contain all values, otherwise + a single value is filled onto the entire diagonal: + + >>> banded({1: (1, 2, 3), -1: (4, 5, 6), 0: x}) + Matrix([ + [x, 1, 0, 0], + [4, x, 2, 0], + [0, 5, x, 3], + [0, 0, 6, x]]) + + A function accepting a single argument can be used to fill the + diagonal as a function of diagonal index (which starts at 0). + The size (or shape) of the matrix must be given to obtain more + than a 1x1 matrix: + + >>> s = lambda d: (1 + d)**2 + >>> banded(5, {0: s, 2: s, -2: 2}) + Matrix([ + [1, 0, 1, 0, 0], + [0, 4, 0, 4, 0], + [2, 0, 9, 0, 9], + [0, 2, 0, 16, 0], + [0, 0, 2, 0, 25]]) + + The diagonal of matrices placed on a diagonal will coincide + with the indicated diagonal: + + >>> vert = Matrix([1, 2, 3]) + >>> banded({0: vert}, cols=3) + Matrix([ + [1, 0, 0], + [2, 1, 0], + [3, 2, 1], + [0, 3, 2], + [0, 0, 3]]) + + >>> banded(4, {0: ones(2)}) + Matrix([ + [1, 1, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 1], + [0, 0, 1, 1]]) + + Errors are raised if the designated size will not hold + all values an integral number of times. Here, the rows + are designated as odd (but an even number is required to + hold the off-diagonal 2x2 ones): + + >>> banded({0: 2, 1: ones(2)}, rows=5) + Traceback (most recent call last): + ... + ValueError: + sequence does not fit an integral number of times in the matrix + + And here, an even number of rows is given...but the square + matrix has an even number of columns, too. As we saw + in the previous example, an odd number is required: + + >>> banded(4, {0: 2, 1: ones(2)}) # trying to make 4x4 and cols must be odd + Traceback (most recent call last): + ... + ValueError: + sequence does not fit an integral number of times in the matrix + + A way around having to count rows is to enclosing matrix elements + in a tuple and indicate the desired number of them to the right: + + >>> banded({0: 2, 2: (ones(2),)*3}) + Matrix([ + [2, 0, 1, 1, 0, 0, 0, 0], + [0, 2, 1, 1, 0, 0, 0, 0], + [0, 0, 2, 0, 1, 1, 0, 0], + [0, 0, 0, 2, 1, 1, 0, 0], + [0, 0, 0, 0, 2, 0, 1, 1], + [0, 0, 0, 0, 0, 2, 1, 1]]) + + An error will be raised if more than one value + is written to a given entry. Here, the ones overlap + with the main diagonal if they are placed on the + first diagonal: + + >>> banded({0: (2,)*5, 1: (ones(2),)*3}) + Traceback (most recent call last): + ... + ValueError: collision at (1, 1) + + By placing a 0 at the bottom left of the 2x2 matrix of + ones, the collision is avoided: + + >>> u2 = Matrix([ + ... [1, 1], + ... [0, 1]]) + >>> banded({0: [2]*5, 1: [u2]*3}) + Matrix([ + [2, 1, 1, 0, 0, 0, 0], + [0, 2, 1, 0, 0, 0, 0], + [0, 0, 2, 1, 1, 0, 0], + [0, 0, 0, 2, 1, 0, 0], + [0, 0, 0, 0, 2, 1, 1], + [0, 0, 0, 0, 0, 0, 1]]) + """ + try: + if len(args) not in (1, 2, 3): + raise TypeError + if not isinstance(args[-1], (dict, Dict)): + raise TypeError + if len(args) == 1: + rows = kwargs.get('rows', None) + cols = kwargs.get('cols', None) + if rows is not None: + rows = as_int(rows) + if cols is not None: + cols = as_int(cols) + elif len(args) == 2: + rows = cols = as_int(args[0]) + else: + rows, cols = map(as_int, args[:2]) + # fails with ValueError if any keys are not ints + _ = all(as_int(k) for k in args[-1]) + except (ValueError, TypeError): + raise TypeError(filldedent( + '''unrecognized input to banded: + expecting [[row,] col,] {int: value}''')) + def rc(d): + # return row,col coord of diagonal start + r = -d if d < 0 else 0 + c = 0 if r else d + return r, c + smat = {} + undone = [] + tba = Dummy() + # first handle objects with size + for d, v in args[-1].items(): + r, c = rc(d) + # note: only list and tuple are recognized since this + # will allow other Basic objects like Tuple + # into the matrix if so desired + if isinstance(v, (list, tuple)): + extra = 0 + for i, vi in enumerate(v): + i += extra + if is_sequence(vi): + vi = SparseMatrix(vi) + smat[r + i, c + i] = vi + extra += min(vi.shape) - 1 + else: + smat[r + i, c + i] = vi + elif is_sequence(v): + v = SparseMatrix(v) + rv, cv = v.shape + if rows and cols: + nr, xr = divmod(rows - r, rv) + nc, xc = divmod(cols - c, cv) + x = xr or xc + do = min(nr, nc) + elif rows: + do, x = divmod(rows - r, rv) + elif cols: + do, x = divmod(cols - c, cv) + else: + do = 1 + x = 0 + if x: + raise ValueError(filldedent(''' + sequence does not fit an integral number of times + in the matrix''')) + j = min(v.shape) + for i in range(do): + smat[r, c] = v + r += j + c += j + elif v: + smat[r, c] = tba + undone.append((d, v)) + s = SparseMatrix(None, smat) # to expand matrices + smat = s.todok() + # check for dim errors here + if rows is not None and rows < s.rows: + raise ValueError('Designated rows %s < needed %s' % (rows, s.rows)) + if cols is not None and cols < s.cols: + raise ValueError('Designated cols %s < needed %s' % (cols, s.cols)) + if rows is cols is None: + rows = s.rows + cols = s.cols + elif rows is not None and cols is None: + cols = max(rows, s.cols) + elif cols is not None and rows is None: + rows = max(cols, s.rows) + def update(i, j, v): + # update smat and make sure there are + # no collisions + if v: + if (i, j) in smat and smat[i, j] not in (tba, v): + raise ValueError('collision at %s' % ((i, j),)) + smat[i, j] = v + if undone: + for d, vi in undone: + r, c = rc(d) + v = vi if callable(vi) else lambda _: vi + i = 0 + while r + i < rows and c + i < cols: + update(r + i, c + i, v(i)) + i += 1 + return SparseMatrix(rows, cols, smat) diff --git a/phivenv/Lib/site-packages/sympy/matrices/subspaces.py b/phivenv/Lib/site-packages/sympy/matrices/subspaces.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab0b71b4289ebaeb6394059c6a7cd49d3a148a1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/subspaces.py @@ -0,0 +1,174 @@ +from .utilities import _iszero + + +def _columnspace(M, simplify=False): + """Returns a list of vectors (Matrix objects) that span columnspace of ``M`` + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix(3, 3, [1, 3, 0, -2, -6, 0, 3, 9, 6]) + >>> M + Matrix([ + [ 1, 3, 0], + [-2, -6, 0], + [ 3, 9, 6]]) + >>> M.columnspace() + [Matrix([ + [ 1], + [-2], + [ 3]]), Matrix([ + [0], + [0], + [6]])] + + See Also + ======== + + nullspace + rowspace + """ + + reduced, pivots = M.echelon_form(simplify=simplify, with_pivots=True) + + return [M.col(i) for i in pivots] + + +def _nullspace(M, simplify=False, iszerofunc=_iszero): + """Returns list of vectors (Matrix objects) that span nullspace of ``M`` + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix(3, 3, [1, 3, 0, -2, -6, 0, 3, 9, 6]) + >>> M + Matrix([ + [ 1, 3, 0], + [-2, -6, 0], + [ 3, 9, 6]]) + >>> M.nullspace() + [Matrix([ + [-3], + [ 1], + [ 0]])] + + See Also + ======== + + columnspace + rowspace + """ + + reduced, pivots = M.rref(iszerofunc=iszerofunc, simplify=simplify) + + free_vars = [i for i in range(M.cols) if i not in pivots] + basis = [] + + for free_var in free_vars: + # for each free variable, we will set it to 1 and all others + # to 0. Then, we will use back substitution to solve the system + vec = [M.zero] * M.cols + vec[free_var] = M.one + + for piv_row, piv_col in enumerate(pivots): + vec[piv_col] -= reduced[piv_row, free_var] + + basis.append(vec) + + return [M._new(M.cols, 1, b) for b in basis] + + +def _rowspace(M, simplify=False): + """Returns a list of vectors that span the row space of ``M``. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix(3, 3, [1, 3, 0, -2, -6, 0, 3, 9, 6]) + >>> M + Matrix([ + [ 1, 3, 0], + [-2, -6, 0], + [ 3, 9, 6]]) + >>> M.rowspace() + [Matrix([[1, 3, 0]]), Matrix([[0, 0, 6]])] + """ + + reduced, pivots = M.echelon_form(simplify=simplify, with_pivots=True) + + return [reduced.row(i) for i in range(len(pivots))] + + +def _orthogonalize(cls, *vecs, normalize=False, rankcheck=False): + """Apply the Gram-Schmidt orthogonalization procedure + to vectors supplied in ``vecs``. + + Parameters + ========== + + vecs + vectors to be made orthogonal + + normalize : bool + If ``True``, return an orthonormal basis. + + rankcheck : bool + If ``True``, the computation does not stop when encountering + linearly dependent vectors. + + If ``False``, it will raise ``ValueError`` when any zero + or linearly dependent vectors are found. + + Returns + ======= + + list + List of orthogonal (or orthonormal) basis vectors. + + Examples + ======== + + >>> from sympy import I, Matrix + >>> v = [Matrix([1, I]), Matrix([1, -I])] + >>> Matrix.orthogonalize(*v) + [Matrix([ + [1], + [I]]), Matrix([ + [ 1], + [-I]])] + + See Also + ======== + + MatrixBase.QRdecomposition + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gram%E2%80%93Schmidt_process + """ + from .decompositions import _QRdecomposition_optional + + if not vecs: + return [] + + all_row_vecs = (vecs[0].rows == 1) + + vecs = [x.vec() for x in vecs] + M = cls.hstack(*vecs) + Q, R = _QRdecomposition_optional(M, normalize=normalize) + + if rankcheck and Q.cols < len(vecs): + raise ValueError("GramSchmidt: vector set not linearly independent") + + ret = [] + for i in range(Q.cols): + if all_row_vecs: + col = cls(Q[:, i].T) + else: + col = cls(Q[:, i]) + ret.append(col) + return ret diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__init__.py b/phivenv/Lib/site-packages/sympy/matrices/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff2003d80b62610a4773f8df236477b7413747ff Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_commonmatrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_commonmatrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..615aa6c2493af2e41478d0de181b3660832702bf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_commonmatrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_decompositions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_decompositions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05b991e115093fa452d9989faac98819a6e9b153 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_decompositions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_determinant.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_determinant.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc603f4c04254fd65543811b4d99156d529272a4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_determinant.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_domains.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_domains.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0bbafbf7c0eddfe5a4185da8e1c2753c28fb3478 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_domains.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_eigen.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_eigen.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f58181b38e5c097ae3cee345d3705b58c74a5538 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_eigen.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_graph.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b641874fe36430abb7ef3260c53c6b9c904be038 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_immutable.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_immutable.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91128efdd2a5d3221163d84fa75dde28ea7121b4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_immutable.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_interactions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_interactions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a95db1c8f7935133f468e3a22b4c4a58b9daa30 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_interactions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_normalforms.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_normalforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..007be9d5e4c640bcb2140b96ba3c5a11cc625311 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_normalforms.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_reductions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_reductions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3ca177d45009d65a9928430abed799f5e4c4e17 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_reductions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_repmatrix.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_repmatrix.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8da60ef3da56c713ed5648712ff0a52fe933fa4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_repmatrix.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_solvers.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_solvers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..734b143a9789d71d0cad0607c3baa0817f40c234 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_solvers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_sparse.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_sparse.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9597b11ce937bce163bf7aed1220dc75d43dadd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_sparse.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_sparsetools.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_sparsetools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bbbd738cf44e7181d3a02d32afe82a2c2b89934 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_sparsetools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_subspaces.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_subspaces.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69992b9b0ec2efc0150de7f813751ca32092596c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/matrices/tests/__pycache__/test_subspaces.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_commonmatrix.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_commonmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..6735adc1a9d4f9934a55c7ee70b087a19d3a48b4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_commonmatrix.py @@ -0,0 +1,1266 @@ +# +# Code for testing deprecated matrix classes. New test code should not be added +# here. Instead, add it to test_matrixbase.py. +# +# This entire test module and the corresponding sympy/matrices/common.py +# module will be removed in a future release. +# +from sympy.testing.pytest import raises, XFAIL, warns_deprecated_sympy + +from sympy.assumptions import Q +from sympy.core.expr import Expr +from sympy.core.add import Add +from sympy.core.function import Function +from sympy.core.kind import NumberKind, UndefinedKind +from sympy.core.numbers import I, Integer, oo, pi, Rational +from sympy.core.singleton import S +from sympy.core.symbol import Symbol, symbols +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.matrices.exceptions import ShapeError, NonSquareMatrixError +from sympy.matrices.kind import MatrixKind +from sympy.matrices.common import ( + _MinimalMatrix, _CastableMatrix, MatrixShaping, MatrixProperties, + MatrixOperations, MatrixArithmetic, MatrixSpecial) +from sympy.matrices.matrices import MatrixCalculus +from sympy.matrices import (Matrix, diag, eye, + matrix_multiply_elementwise, ones, zeros, SparseMatrix, banded, + MutableDenseMatrix, MutableSparseMatrix, ImmutableDenseMatrix, + ImmutableSparseMatrix) +from sympy.polys.polytools import Poly +from sympy.utilities.iterables import flatten +from sympy.tensor.array.dense_ndim_array import ImmutableDenseNDimArray as Array + +from sympy.abc import x, y, z + + +def test_matrix_deprecated_isinstance(): + + # Test that e.g. isinstance(M, MatrixCommon) still gives True when M is a + # Matrix for each of the deprecated matrix classes. + + from sympy.matrices.common import ( + MatrixRequired, + MatrixShaping, + MatrixSpecial, + MatrixProperties, + MatrixOperations, + MatrixArithmetic, + MatrixCommon + ) + from sympy.matrices.matrices import ( + MatrixDeterminant, + MatrixReductions, + MatrixSubspaces, + MatrixEigen, + MatrixCalculus, + MatrixDeprecated + ) + from sympy import ( + Matrix, + ImmutableMatrix, + SparseMatrix, + ImmutableSparseMatrix + ) + all_mixins = ( + MatrixRequired, + MatrixShaping, + MatrixSpecial, + MatrixProperties, + MatrixOperations, + MatrixArithmetic, + MatrixCommon, + MatrixDeterminant, + MatrixReductions, + MatrixSubspaces, + MatrixEigen, + MatrixCalculus, + MatrixDeprecated + ) + all_matrices = ( + Matrix, + ImmutableMatrix, + SparseMatrix, + ImmutableSparseMatrix + ) + + Ms = [M([[1, 2], [3, 4]]) for M in all_matrices] + t = () + + for mixin in all_mixins: + for M in Ms: + with warns_deprecated_sympy(): + assert isinstance(M, mixin) is True + with warns_deprecated_sympy(): + assert isinstance(t, mixin) is False + + +# classes to test the deprecated matrix classes. We use warns_deprecated_sympy +# to suppress the deprecation warnings because subclassing the deprecated +# classes causes a warning to be raised. + +with warns_deprecated_sympy(): + class ShapingOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixShaping): + pass + + +def eye_Shaping(n): + return ShapingOnlyMatrix(n, n, lambda i, j: int(i == j)) + + +def zeros_Shaping(n): + return ShapingOnlyMatrix(n, n, lambda i, j: 0) + + +with warns_deprecated_sympy(): + class PropertiesOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixProperties): + pass + + +def eye_Properties(n): + return PropertiesOnlyMatrix(n, n, lambda i, j: int(i == j)) + + +def zeros_Properties(n): + return PropertiesOnlyMatrix(n, n, lambda i, j: 0) + + +with warns_deprecated_sympy(): + class OperationsOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixOperations): + pass + + +def eye_Operations(n): + return OperationsOnlyMatrix(n, n, lambda i, j: int(i == j)) + + +def zeros_Operations(n): + return OperationsOnlyMatrix(n, n, lambda i, j: 0) + + +with warns_deprecated_sympy(): + class ArithmeticOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixArithmetic): + pass + + +def eye_Arithmetic(n): + return ArithmeticOnlyMatrix(n, n, lambda i, j: int(i == j)) + + +def zeros_Arithmetic(n): + return ArithmeticOnlyMatrix(n, n, lambda i, j: 0) + + +with warns_deprecated_sympy(): + class SpecialOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixSpecial): + pass + + +with warns_deprecated_sympy(): + class CalculusOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixCalculus): + pass + + +def test__MinimalMatrix(): + x = _MinimalMatrix(2, 3, [1, 2, 3, 4, 5, 6]) + assert x.rows == 2 + assert x.cols == 3 + assert x[2] == 3 + assert x[1, 1] == 5 + assert list(x) == [1, 2, 3, 4, 5, 6] + assert list(x[1, :]) == [4, 5, 6] + assert list(x[:, 1]) == [2, 5] + assert list(x[:, :]) == list(x) + assert x[:, :] == x + assert _MinimalMatrix(x) == x + assert _MinimalMatrix([[1, 2, 3], [4, 5, 6]]) == x + assert _MinimalMatrix(([1, 2, 3], [4, 5, 6])) == x + assert _MinimalMatrix([(1, 2, 3), (4, 5, 6)]) == x + assert _MinimalMatrix(((1, 2, 3), (4, 5, 6))) == x + assert not (_MinimalMatrix([[1, 2], [3, 4], [5, 6]]) == x) + + +def test_kind(): + assert Matrix([[1, 2], [3, 4]]).kind == MatrixKind(NumberKind) + assert Matrix([[0, 0], [0, 0]]).kind == MatrixKind(NumberKind) + assert Matrix(0, 0, []).kind == MatrixKind(NumberKind) + assert Matrix([[x]]).kind == MatrixKind(NumberKind) + assert Matrix([[1, Matrix([[1]])]]).kind == MatrixKind(UndefinedKind) + assert SparseMatrix([[1]]).kind == MatrixKind(NumberKind) + assert SparseMatrix([[1, Matrix([[1]])]]).kind == MatrixKind(UndefinedKind) + + +# ShapingOnlyMatrix tests +def test_vec(): + m = ShapingOnlyMatrix(2, 2, [1, 3, 2, 4]) + m_vec = m.vec() + assert m_vec.cols == 1 + for i in range(4): + assert m_vec[i] == i + 1 + + +def test_todok(): + a, b, c, d = symbols('a:d') + m1 = MutableDenseMatrix([[a, b], [c, d]]) + m2 = ImmutableDenseMatrix([[a, b], [c, d]]) + m3 = MutableSparseMatrix([[a, b], [c, d]]) + m4 = ImmutableSparseMatrix([[a, b], [c, d]]) + assert m1.todok() == m2.todok() == m3.todok() == m4.todok() == \ + {(0, 0): a, (0, 1): b, (1, 0): c, (1, 1): d} + + +def test_tolist(): + lst = [[S.One, S.Half, x*y, S.Zero], [x, y, z, x**2], [y, -S.One, z*x, 3]] + flat_lst = [S.One, S.Half, x*y, S.Zero, x, y, z, x**2, y, -S.One, z*x, 3] + m = ShapingOnlyMatrix(3, 4, flat_lst) + assert m.tolist() == lst + +def test_todod(): + m = ShapingOnlyMatrix(3, 2, [[S.One, 0], [0, S.Half], [x, 0]]) + dict = {0: {0: S.One}, 1: {1: S.Half}, 2: {0: x}} + assert m.todod() == dict + +def test_row_col_del(): + e = ShapingOnlyMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + raises(IndexError, lambda: e.row_del(5)) + raises(IndexError, lambda: e.row_del(-5)) + raises(IndexError, lambda: e.col_del(5)) + raises(IndexError, lambda: e.col_del(-5)) + + assert e.row_del(2) == e.row_del(-1) == Matrix([[1, 2, 3], [4, 5, 6]]) + assert e.col_del(2) == e.col_del(-1) == Matrix([[1, 2], [4, 5], [7, 8]]) + + assert e.row_del(1) == e.row_del(-2) == Matrix([[1, 2, 3], [7, 8, 9]]) + assert e.col_del(1) == e.col_del(-2) == Matrix([[1, 3], [4, 6], [7, 9]]) + + +def test_get_diag_blocks1(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + assert a.get_diag_blocks() == [a] + assert b.get_diag_blocks() == [b] + assert c.get_diag_blocks() == [c] + + +def test_get_diag_blocks2(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + A, B, C, D = diag(a, b, b), diag(a, b, c), diag(a, c, b), diag(c, c, b) + A = ShapingOnlyMatrix(A.rows, A.cols, A) + B = ShapingOnlyMatrix(B.rows, B.cols, B) + C = ShapingOnlyMatrix(C.rows, C.cols, C) + D = ShapingOnlyMatrix(D.rows, D.cols, D) + + assert A.get_diag_blocks() == [a, b, b] + assert B.get_diag_blocks() == [a, b, c] + assert C.get_diag_blocks() == [a, c, b] + assert D.get_diag_blocks() == [c, c, b] + + +def test_shape(): + m = ShapingOnlyMatrix(1, 2, [0, 0]) + assert m.shape == (1, 2) + + +def test_reshape(): + m0 = eye_Shaping(3) + assert m0.reshape(1, 9) == Matrix(1, 9, (1, 0, 0, 0, 1, 0, 0, 0, 1)) + m1 = ShapingOnlyMatrix(3, 4, lambda i, j: i + j) + assert m1.reshape( + 4, 3) == Matrix(((0, 1, 2), (3, 1, 2), (3, 4, 2), (3, 4, 5))) + assert m1.reshape(2, 6) == Matrix(((0, 1, 2, 3, 1, 2), (3, 4, 2, 3, 4, 5))) + + +def test_row_col(): + m = ShapingOnlyMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + assert m.row(0) == Matrix(1, 3, [1, 2, 3]) + assert m.col(0) == Matrix(3, 1, [1, 4, 7]) + + +def test_row_join(): + assert eye_Shaping(3).row_join(Matrix([7, 7, 7])) == \ + Matrix([[1, 0, 0, 7], + [0, 1, 0, 7], + [0, 0, 1, 7]]) + + +def test_col_join(): + assert eye_Shaping(3).col_join(Matrix([[7, 7, 7]])) == \ + Matrix([[1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [7, 7, 7]]) + + +def test_row_insert(): + r4 = Matrix([[4, 4, 4]]) + for i in range(-4, 5): + l = [1, 0, 0] + l.insert(i, 4) + assert flatten(eye_Shaping(3).row_insert(i, r4).col(0).tolist()) == l + + +def test_col_insert(): + c4 = Matrix([4, 4, 4]) + for i in range(-4, 5): + l = [0, 0, 0] + l.insert(i, 4) + assert flatten(zeros_Shaping(3).col_insert(i, c4).row(0).tolist()) == l + # issue 13643 + assert eye_Shaping(6).col_insert(3, Matrix([[2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]])) == \ + Matrix([[1, 0, 0, 2, 2, 0, 0, 0], + [0, 1, 0, 2, 2, 0, 0, 0], + [0, 0, 1, 2, 2, 0, 0, 0], + [0, 0, 0, 2, 2, 1, 0, 0], + [0, 0, 0, 2, 2, 0, 1, 0], + [0, 0, 0, 2, 2, 0, 0, 1]]) + + +def test_extract(): + m = ShapingOnlyMatrix(4, 3, lambda i, j: i*3 + j) + assert m.extract([0, 1, 3], [0, 1]) == Matrix(3, 2, [0, 1, 3, 4, 9, 10]) + assert m.extract([0, 3], [0, 0, 2]) == Matrix(2, 3, [0, 0, 2, 9, 9, 11]) + assert m.extract(range(4), range(3)) == m + raises(IndexError, lambda: m.extract([4], [0])) + raises(IndexError, lambda: m.extract([0], [3])) + + +def test_hstack(): + m = ShapingOnlyMatrix(4, 3, lambda i, j: i*3 + j) + m2 = ShapingOnlyMatrix(3, 4, lambda i, j: i*3 + j) + assert m == m.hstack(m) + assert m.hstack(m, m, m) == ShapingOnlyMatrix.hstack(m, m, m) == Matrix([ + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5], + [6, 7, 8, 6, 7, 8, 6, 7, 8], + [9, 10, 11, 9, 10, 11, 9, 10, 11]]) + raises(ShapeError, lambda: m.hstack(m, m2)) + assert Matrix.hstack() == Matrix() + + # test regression #12938 + M1 = Matrix.zeros(0, 0) + M2 = Matrix.zeros(0, 1) + M3 = Matrix.zeros(0, 2) + M4 = Matrix.zeros(0, 3) + m = ShapingOnlyMatrix.hstack(M1, M2, M3, M4) + assert m.rows == 0 and m.cols == 6 + + +def test_vstack(): + m = ShapingOnlyMatrix(4, 3, lambda i, j: i*3 + j) + m2 = ShapingOnlyMatrix(3, 4, lambda i, j: i*3 + j) + assert m == m.vstack(m) + assert m.vstack(m, m, m) == ShapingOnlyMatrix.vstack(m, m, m) == Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11], + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11], + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11]]) + raises(ShapeError, lambda: m.vstack(m, m2)) + assert Matrix.vstack() == Matrix() + + +# PropertiesOnlyMatrix tests +def test_atoms(): + m = PropertiesOnlyMatrix(2, 2, [1, 2, x, 1 - 1/x]) + assert m.atoms() == {S.One, S(2), S.NegativeOne, x} + assert m.atoms(Symbol) == {x} + + +def test_free_symbols(): + assert PropertiesOnlyMatrix([[x], [0]]).free_symbols == {x} + + +def test_has(): + A = PropertiesOnlyMatrix(((x, y), (2, 3))) + assert A.has(x) + assert not A.has(z) + assert A.has(Symbol) + + A = PropertiesOnlyMatrix(((2, y), (2, 3))) + assert not A.has(x) + + +def test_is_anti_symmetric(): + x = symbols('x') + assert PropertiesOnlyMatrix(2, 1, [1, 2]).is_anti_symmetric() is False + m = PropertiesOnlyMatrix(3, 3, [0, x**2 + 2*x + 1, y, -(x + 1)**2, 0, x*y, -y, -x*y, 0]) + assert m.is_anti_symmetric() is True + assert m.is_anti_symmetric(simplify=False) is False + assert m.is_anti_symmetric(simplify=lambda x: x) is False + + m = PropertiesOnlyMatrix(3, 3, [x.expand() for x in m]) + assert m.is_anti_symmetric(simplify=False) is True + m = PropertiesOnlyMatrix(3, 3, [x.expand() for x in [S.One] + list(m)[1:]]) + assert m.is_anti_symmetric() is False + + +def test_diagonal_symmetrical(): + m = PropertiesOnlyMatrix(2, 2, [0, 1, 1, 0]) + assert not m.is_diagonal() + assert m.is_symmetric() + assert m.is_symmetric(simplify=False) + + m = PropertiesOnlyMatrix(2, 2, [1, 0, 0, 1]) + assert m.is_diagonal() + + m = PropertiesOnlyMatrix(3, 3, diag(1, 2, 3)) + assert m.is_diagonal() + assert m.is_symmetric() + + m = PropertiesOnlyMatrix(3, 3, [1, 0, 0, 0, 2, 0, 0, 0, 3]) + assert m == diag(1, 2, 3) + + m = PropertiesOnlyMatrix(2, 3, zeros(2, 3)) + assert not m.is_symmetric() + assert m.is_diagonal() + + m = PropertiesOnlyMatrix(((5, 0), (0, 6), (0, 0))) + assert m.is_diagonal() + + m = PropertiesOnlyMatrix(((5, 0, 0), (0, 6, 0))) + assert m.is_diagonal() + + m = Matrix(3, 3, [1, x**2 + 2*x + 1, y, (x + 1)**2, 2, 0, y, 0, 3]) + assert m.is_symmetric() + assert not m.is_symmetric(simplify=False) + assert m.expand().is_symmetric(simplify=False) + + +def test_is_hermitian(): + a = PropertiesOnlyMatrix([[1, I], [-I, 1]]) + assert a.is_hermitian + a = PropertiesOnlyMatrix([[2*I, I], [-I, 1]]) + assert a.is_hermitian is False + a = PropertiesOnlyMatrix([[x, I], [-I, 1]]) + assert a.is_hermitian is None + a = PropertiesOnlyMatrix([[x, 1], [-I, 1]]) + assert a.is_hermitian is False + + +def test_is_Identity(): + assert eye_Properties(3).is_Identity + assert not PropertiesOnlyMatrix(zeros(3)).is_Identity + assert not PropertiesOnlyMatrix(ones(3)).is_Identity + # issue 6242 + assert not PropertiesOnlyMatrix([[1, 0, 0]]).is_Identity + + +def test_is_symbolic(): + a = PropertiesOnlyMatrix([[x, x], [x, x]]) + assert a.is_symbolic() is True + a = PropertiesOnlyMatrix([[1, 2, 3, 4], [5, 6, 7, 8]]) + assert a.is_symbolic() is False + a = PropertiesOnlyMatrix([[1, 2, 3, 4], [5, 6, x, 8]]) + assert a.is_symbolic() is True + a = PropertiesOnlyMatrix([[1, x, 3]]) + assert a.is_symbolic() is True + a = PropertiesOnlyMatrix([[1, 2, 3]]) + assert a.is_symbolic() is False + a = PropertiesOnlyMatrix([[1], [x], [3]]) + assert a.is_symbolic() is True + a = PropertiesOnlyMatrix([[1], [2], [3]]) + assert a.is_symbolic() is False + + +def test_is_upper(): + a = PropertiesOnlyMatrix([[1, 2, 3]]) + assert a.is_upper is True + a = PropertiesOnlyMatrix([[1], [2], [3]]) + assert a.is_upper is False + + +def test_is_lower(): + a = PropertiesOnlyMatrix([[1, 2, 3]]) + assert a.is_lower is False + a = PropertiesOnlyMatrix([[1], [2], [3]]) + assert a.is_lower is True + + +def test_is_square(): + m = PropertiesOnlyMatrix([[1], [1]]) + m2 = PropertiesOnlyMatrix([[2, 2], [2, 2]]) + assert not m.is_square + assert m2.is_square + + +def test_is_symmetric(): + m = PropertiesOnlyMatrix(2, 2, [0, 1, 1, 0]) + assert m.is_symmetric() + m = PropertiesOnlyMatrix(2, 2, [0, 1, 0, 1]) + assert not m.is_symmetric() + + +def test_is_hessenberg(): + A = PropertiesOnlyMatrix([[3, 4, 1], [2, 4, 5], [0, 1, 2]]) + assert A.is_upper_hessenberg + A = PropertiesOnlyMatrix(3, 3, [3, 2, 0, 4, 4, 1, 1, 5, 2]) + assert A.is_lower_hessenberg + A = PropertiesOnlyMatrix(3, 3, [3, 2, -1, 4, 4, 1, 1, 5, 2]) + assert A.is_lower_hessenberg is False + assert A.is_upper_hessenberg is False + + A = PropertiesOnlyMatrix([[3, 4, 1], [2, 4, 5], [3, 1, 2]]) + assert not A.is_upper_hessenberg + + +def test_is_zero(): + assert PropertiesOnlyMatrix(0, 0, []).is_zero_matrix + assert PropertiesOnlyMatrix([[0, 0], [0, 0]]).is_zero_matrix + assert PropertiesOnlyMatrix(zeros(3, 4)).is_zero_matrix + assert not PropertiesOnlyMatrix(eye(3)).is_zero_matrix + assert PropertiesOnlyMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert PropertiesOnlyMatrix([[x, 1], [0, 0]]).is_zero_matrix == False + a = Symbol('a', nonzero=True) + assert PropertiesOnlyMatrix([[a, 0], [0, 0]]).is_zero_matrix == False + + +def test_values(): + assert set(PropertiesOnlyMatrix(2, 2, [0, 1, 2, 3] + ).values()) == {1, 2, 3} + x = Symbol('x', real=True) + assert set(PropertiesOnlyMatrix(2, 2, [x, 0, 0, 1] + ).values()) == {x, 1} + + +# OperationsOnlyMatrix tests +def test_applyfunc(): + m0 = OperationsOnlyMatrix(eye(3)) + assert m0.applyfunc(lambda x: 2*x) == eye(3)*2 + assert m0.applyfunc(lambda x: 0) == zeros(3) + assert m0.applyfunc(lambda x: 1) == ones(3) + + +def test_adjoint(): + dat = [[0, I], [1, 0]] + ans = OperationsOnlyMatrix([[0, 1], [-I, 0]]) + assert ans.adjoint() == Matrix(dat) + + +def test_as_real_imag(): + m1 = OperationsOnlyMatrix(2, 2, [1, 2, 3, 4]) + m3 = OperationsOnlyMatrix(2, 2, + [1 + S.ImaginaryUnit, 2 + 2*S.ImaginaryUnit, + 3 + 3*S.ImaginaryUnit, 4 + 4*S.ImaginaryUnit]) + + a, b = m3.as_real_imag() + assert a == m1 + assert b == m1 + + +def test_conjugate(): + M = OperationsOnlyMatrix([[0, I, 5], + [1, 2, 0]]) + + assert M.T == Matrix([[0, 1], + [I, 2], + [5, 0]]) + + assert M.C == Matrix([[0, -I, 5], + [1, 2, 0]]) + assert M.C == M.conjugate() + + assert M.H == M.T.C + assert M.H == Matrix([[ 0, 1], + [-I, 2], + [ 5, 0]]) + + +def test_doit(): + a = OperationsOnlyMatrix([[Add(x, x, evaluate=False)]]) + assert a[0] != 2*x + assert a.doit() == Matrix([[2*x]]) + + +def test_evalf(): + a = OperationsOnlyMatrix(2, 1, [sqrt(5), 6]) + assert all(a.evalf()[i] == a[i].evalf() for i in range(2)) + assert all(a.evalf(2)[i] == a[i].evalf(2) for i in range(2)) + assert all(a.n(2)[i] == a[i].n(2) for i in range(2)) + + +def test_expand(): + m0 = OperationsOnlyMatrix([[x*(x + y), 2], [((x + y)*y)*x, x*(y + x*(x + y))]]) + # Test if expand() returns a matrix + m1 = m0.expand() + assert m1 == Matrix( + [[x*y + x**2, 2], [x*y**2 + y*x**2, x*y + y*x**2 + x**3]]) + + a = Symbol('a', real=True) + + assert OperationsOnlyMatrix(1, 1, [exp(I*a)]).expand(complex=True) == \ + Matrix([cos(a) + I*sin(a)]) + + +def test_refine(): + m0 = OperationsOnlyMatrix([[Abs(x)**2, sqrt(x**2)], + [sqrt(x**2)*Abs(y)**2, sqrt(y**2)*Abs(x)**2]]) + m1 = m0.refine(Q.real(x) & Q.real(y)) + assert m1 == Matrix([[x**2, Abs(x)], [y**2*Abs(x), x**2*Abs(y)]]) + + m1 = m0.refine(Q.positive(x) & Q.positive(y)) + assert m1 == Matrix([[x**2, x], [x*y**2, x**2*y]]) + + m1 = m0.refine(Q.negative(x) & Q.negative(y)) + assert m1 == Matrix([[x**2, -x], [-x*y**2, -x**2*y]]) + + +def test_replace(): + F, G = symbols('F, G', cls=Function) + K = OperationsOnlyMatrix(2, 2, lambda i, j: G(i+j)) + M = OperationsOnlyMatrix(2, 2, lambda i, j: F(i+j)) + N = M.replace(F, G) + assert N == K + + +def test_replace_map(): + F, G = symbols('F, G', cls=Function) + K = OperationsOnlyMatrix(2, 2, [(G(0), {F(0): G(0)}), (G(1), {F(1): G(1)}), (G(1), {F(1) \ + : G(1)}), (G(2), {F(2): G(2)})]) + M = OperationsOnlyMatrix(2, 2, lambda i, j: F(i+j)) + N = M.replace(F, G, True) + assert N == K + + +def test_rot90(): + A = Matrix([[1, 2], [3, 4]]) + assert A == A.rot90(0) == A.rot90(4) + assert A.rot90(2) == A.rot90(-2) == A.rot90(6) == Matrix(((4, 3), (2, 1))) + assert A.rot90(3) == A.rot90(-1) == A.rot90(7) == Matrix(((2, 4), (1, 3))) + assert A.rot90() == A.rot90(-7) == A.rot90(-3) == Matrix(((3, 1), (4, 2))) + +def test_simplify(): + n = Symbol('n') + f = Function('f') + + M = OperationsOnlyMatrix([[ 1/x + 1/y, (x + x*y) / x ], + [ (f(x) + y*f(x))/f(x), 2 * (1/n - cos(n * pi)/n) / pi ]]) + assert M.simplify() == Matrix([[ (x + y)/(x * y), 1 + y ], + [ 1 + y, 2*((1 - 1*cos(pi*n))/(pi*n)) ]]) + eq = (1 + x)**2 + M = OperationsOnlyMatrix([[eq]]) + assert M.simplify() == Matrix([[eq]]) + assert M.simplify(ratio=oo) == Matrix([[eq.simplify(ratio=oo)]]) + + # https://github.com/sympy/sympy/issues/19353 + m = Matrix([[30, 2], [3, 4]]) + assert (1/(m.trace())).simplify() == Rational(1, 34) + + +def test_subs(): + assert OperationsOnlyMatrix([[1, x], [x, 4]]).subs(x, 5) == Matrix([[1, 5], [5, 4]]) + assert OperationsOnlyMatrix([[x, 2], [x + y, 4]]).subs([[x, -1], [y, -2]]) == \ + Matrix([[-1, 2], [-3, 4]]) + assert OperationsOnlyMatrix([[x, 2], [x + y, 4]]).subs([(x, -1), (y, -2)]) == \ + Matrix([[-1, 2], [-3, 4]]) + assert OperationsOnlyMatrix([[x, 2], [x + y, 4]]).subs({x: -1, y: -2}) == \ + Matrix([[-1, 2], [-3, 4]]) + assert OperationsOnlyMatrix([[x*y]]).subs({x: y - 1, y: x - 1}, simultaneous=True) == \ + Matrix([[(x - 1)*(y - 1)]]) + + +def test_trace(): + M = OperationsOnlyMatrix([[1, 0, 0], + [0, 5, 0], + [0, 0, 8]]) + assert M.trace() == 14 + + +def test_xreplace(): + assert OperationsOnlyMatrix([[1, x], [x, 4]]).xreplace({x: 5}) == \ + Matrix([[1, 5], [5, 4]]) + assert OperationsOnlyMatrix([[x, 2], [x + y, 4]]).xreplace({x: -1, y: -2}) == \ + Matrix([[-1, 2], [-3, 4]]) + + +def test_permute(): + a = OperationsOnlyMatrix(3, 4, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + + raises(IndexError, lambda: a.permute([[0, 5]])) + raises(ValueError, lambda: a.permute(Symbol('x'))) + b = a.permute_rows([[0, 2], [0, 1]]) + assert a.permute([[0, 2], [0, 1]]) == b == Matrix([ + [5, 6, 7, 8], + [9, 10, 11, 12], + [1, 2, 3, 4]]) + + b = a.permute_cols([[0, 2], [0, 1]]) + assert a.permute([[0, 2], [0, 1]], orientation='cols') == b ==\ + Matrix([ + [ 2, 3, 1, 4], + [ 6, 7, 5, 8], + [10, 11, 9, 12]]) + + b = a.permute_cols([[0, 2], [0, 1]], direction='backward') + assert a.permute([[0, 2], [0, 1]], orientation='cols', direction='backward') == b ==\ + Matrix([ + [ 3, 1, 2, 4], + [ 7, 5, 6, 8], + [11, 9, 10, 12]]) + + assert a.permute([1, 2, 0, 3]) == Matrix([ + [5, 6, 7, 8], + [9, 10, 11, 12], + [1, 2, 3, 4]]) + + from sympy.combinatorics import Permutation + assert a.permute(Permutation([1, 2, 0, 3])) == Matrix([ + [5, 6, 7, 8], + [9, 10, 11, 12], + [1, 2, 3, 4]]) + +def test_upper_triangular(): + + A = OperationsOnlyMatrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] + ]) + + R = A.upper_triangular(2) + assert R == OperationsOnlyMatrix([ + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0], + [0, 0, 0, 0] + ]) + + R = A.upper_triangular(-2) + assert R == OperationsOnlyMatrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [0, 1, 1, 1] + ]) + + R = A.upper_triangular() + assert R == OperationsOnlyMatrix([ + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 1] + ]) + +def test_lower_triangular(): + A = OperationsOnlyMatrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] + ]) + + L = A.lower_triangular() + assert L == ArithmeticOnlyMatrix([ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]]) + + L = A.lower_triangular(2) + assert L == ArithmeticOnlyMatrix([ + [1, 1, 1, 0], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] + ]) + + L = A.lower_triangular(-2) + assert L == ArithmeticOnlyMatrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [1, 1, 0, 0] + ]) + + +# ArithmeticOnlyMatrix tests +def test_abs(): + m = ArithmeticOnlyMatrix([[1, -2], [x, y]]) + assert abs(m) == ArithmeticOnlyMatrix([[1, 2], [Abs(x), Abs(y)]]) + + +def test_add(): + m = ArithmeticOnlyMatrix([[1, 2, 3], [x, y, x], [2*y, -50, z*x]]) + assert m + m == ArithmeticOnlyMatrix([[2, 4, 6], [2*x, 2*y, 2*x], [4*y, -100, 2*z*x]]) + n = ArithmeticOnlyMatrix(1, 2, [1, 2]) + raises(ShapeError, lambda: m + n) + + +def test_multiplication(): + a = ArithmeticOnlyMatrix(( + (1, 2), + (3, 1), + (0, 6), + )) + + b = ArithmeticOnlyMatrix(( + (1, 2), + (3, 0), + )) + + raises(ShapeError, lambda: b*a) + raises(TypeError, lambda: a*{}) + + c = a*b + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + try: + eval('c = a @ b') + except SyntaxError: + pass + else: + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + h = a.multiply_elementwise(c) + assert h == matrix_multiply_elementwise(a, c) + assert h[0, 0] == 7 + assert h[0, 1] == 4 + assert h[1, 0] == 18 + assert h[1, 1] == 6 + assert h[2, 0] == 0 + assert h[2, 1] == 0 + raises(ShapeError, lambda: a.multiply_elementwise(b)) + + c = b * Symbol("x") + assert isinstance(c, ArithmeticOnlyMatrix) + assert c[0, 0] == x + assert c[0, 1] == 2*x + assert c[1, 0] == 3*x + assert c[1, 1] == 0 + + c2 = x * b + assert c == c2 + + c = 5 * b + assert isinstance(c, ArithmeticOnlyMatrix) + assert c[0, 0] == 5 + assert c[0, 1] == 2*5 + assert c[1, 0] == 3*5 + assert c[1, 1] == 0 + + try: + eval('c = 5 @ b') + except SyntaxError: + pass + else: + assert isinstance(c, ArithmeticOnlyMatrix) + assert c[0, 0] == 5 + assert c[0, 1] == 2*5 + assert c[1, 0] == 3*5 + assert c[1, 1] == 0 + + # https://github.com/sympy/sympy/issues/22353 + A = Matrix(ones(3, 1)) + _h = -Rational(1, 2) + B = Matrix([_h, _h, _h]) + assert A.multiply_elementwise(B) == Matrix([ + [_h], + [_h], + [_h]]) + + +def test_matmul(): + a = Matrix([[1, 2], [3, 4]]) + + assert a.__matmul__(2) == NotImplemented + + assert a.__rmatmul__(2) == NotImplemented + + #This is done this way because @ is only supported in Python 3.5+ + #To check 2@a case + try: + eval('2 @ a') + except SyntaxError: + pass + except TypeError: #TypeError is raised in case of NotImplemented is returned + pass + + #Check a@2 case + try: + eval('a @ 2') + except SyntaxError: + pass + except TypeError: #TypeError is raised in case of NotImplemented is returned + pass + + +def test_non_matmul(): + """ + Test that if explicitly specified as non-matrix, mul reverts + to scalar multiplication. + """ + class foo(Expr): + is_Matrix=False + is_MatrixLike=False + shape = (1, 1) + + A = Matrix([[1, 2], [3, 4]]) + b = foo() + assert b*A == Matrix([[b, 2*b], [3*b, 4*b]]) + assert A*b == Matrix([[b, 2*b], [3*b, 4*b]]) + + +def test_power(): + raises(NonSquareMatrixError, lambda: Matrix((1, 2))**2) + + A = ArithmeticOnlyMatrix([[2, 3], [4, 5]]) + assert (A**5)[:] == (6140, 8097, 10796, 14237) + A = ArithmeticOnlyMatrix([[2, 1, 3], [4, 2, 4], [6, 12, 1]]) + assert (A**3)[:] == (290, 262, 251, 448, 440, 368, 702, 954, 433) + assert A**0 == eye(3) + assert A**1 == A + assert (ArithmeticOnlyMatrix([[2]]) ** 100)[0, 0] == 2**100 + assert ArithmeticOnlyMatrix([[1, 2], [3, 4]])**Integer(2) == ArithmeticOnlyMatrix([[7, 10], [15, 22]]) + A = Matrix([[1,2],[4,5]]) + assert A.pow(20, method='cayley') == A.pow(20, method='multiply') + +def test_neg(): + n = ArithmeticOnlyMatrix(1, 2, [1, 2]) + assert -n == ArithmeticOnlyMatrix(1, 2, [-1, -2]) + + +def test_sub(): + n = ArithmeticOnlyMatrix(1, 2, [1, 2]) + assert n - n == ArithmeticOnlyMatrix(1, 2, [0, 0]) + + +def test_div(): + n = ArithmeticOnlyMatrix(1, 2, [1, 2]) + assert n/2 == ArithmeticOnlyMatrix(1, 2, [S.Half, S(2)/2]) + +# SpecialOnlyMatrix tests +def test_eye(): + assert list(SpecialOnlyMatrix.eye(2, 2)) == [1, 0, 0, 1] + assert list(SpecialOnlyMatrix.eye(2)) == [1, 0, 0, 1] + assert type(SpecialOnlyMatrix.eye(2)) == SpecialOnlyMatrix + assert type(SpecialOnlyMatrix.eye(2, cls=Matrix)) == Matrix + + +def test_ones(): + assert list(SpecialOnlyMatrix.ones(2, 2)) == [1, 1, 1, 1] + assert list(SpecialOnlyMatrix.ones(2)) == [1, 1, 1, 1] + assert SpecialOnlyMatrix.ones(2, 3) == Matrix([[1, 1, 1], [1, 1, 1]]) + assert type(SpecialOnlyMatrix.ones(2)) == SpecialOnlyMatrix + assert type(SpecialOnlyMatrix.ones(2, cls=Matrix)) == Matrix + + +def test_zeros(): + assert list(SpecialOnlyMatrix.zeros(2, 2)) == [0, 0, 0, 0] + assert list(SpecialOnlyMatrix.zeros(2)) == [0, 0, 0, 0] + assert SpecialOnlyMatrix.zeros(2, 3) == Matrix([[0, 0, 0], [0, 0, 0]]) + assert type(SpecialOnlyMatrix.zeros(2)) == SpecialOnlyMatrix + assert type(SpecialOnlyMatrix.zeros(2, cls=Matrix)) == Matrix + + +def test_diag_make(): + diag = SpecialOnlyMatrix.diag + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + assert diag(a, b, b) == Matrix([ + [1, 2, 0, 0, 0, 0], + [2, 3, 0, 0, 0, 0], + [0, 0, 3, x, 0, 0], + [0, 0, y, 3, 0, 0], + [0, 0, 0, 0, 3, x], + [0, 0, 0, 0, y, 3], + ]) + assert diag(a, b, c) == Matrix([ + [1, 2, 0, 0, 0, 0, 0], + [2, 3, 0, 0, 0, 0, 0], + [0, 0, 3, x, 0, 0, 0], + [0, 0, y, 3, 0, 0, 0], + [0, 0, 0, 0, 3, x, 3], + [0, 0, 0, 0, y, 3, z], + [0, 0, 0, 0, x, y, z], + ]) + assert diag(a, c, b) == Matrix([ + [1, 2, 0, 0, 0, 0, 0], + [2, 3, 0, 0, 0, 0, 0], + [0, 0, 3, x, 3, 0, 0], + [0, 0, y, 3, z, 0, 0], + [0, 0, x, y, z, 0, 0], + [0, 0, 0, 0, 0, 3, x], + [0, 0, 0, 0, 0, y, 3], + ]) + a = Matrix([x, y, z]) + b = Matrix([[1, 2], [3, 4]]) + c = Matrix([[5, 6]]) + # this "wandering diagonal" is what makes this + # a block diagonal where each block is independent + # of the others + assert diag(a, 7, b, c) == Matrix([ + [x, 0, 0, 0, 0, 0], + [y, 0, 0, 0, 0, 0], + [z, 0, 0, 0, 0, 0], + [0, 7, 0, 0, 0, 0], + [0, 0, 1, 2, 0, 0], + [0, 0, 3, 4, 0, 0], + [0, 0, 0, 0, 5, 6]]) + raises(ValueError, lambda: diag(a, 7, b, c, rows=5)) + assert diag(1) == Matrix([[1]]) + assert diag(1, rows=2) == Matrix([[1, 0], [0, 0]]) + assert diag(1, cols=2) == Matrix([[1, 0], [0, 0]]) + assert diag(1, rows=3, cols=2) == Matrix([[1, 0], [0, 0], [0, 0]]) + assert diag(*[2, 3]) == Matrix([ + [2, 0], + [0, 3]]) + assert diag(Matrix([2, 3])) == Matrix([ + [2], + [3]]) + assert diag([1, [2, 3], 4], unpack=False) == \ + diag([[1], [2, 3], [4]], unpack=False) == Matrix([ + [1, 0], + [2, 3], + [4, 0]]) + assert type(diag(1)) == SpecialOnlyMatrix + assert type(diag(1, cls=Matrix)) == Matrix + assert Matrix.diag([1, 2, 3]) == Matrix.diag(1, 2, 3) + assert Matrix.diag([1, 2, 3], unpack=False).shape == (3, 1) + assert Matrix.diag([[1, 2, 3]]).shape == (3, 1) + assert Matrix.diag([[1, 2, 3]], unpack=False).shape == (1, 3) + assert Matrix.diag([[[1, 2, 3]]]).shape == (1, 3) + # kerning can be used to move the starting point + assert Matrix.diag(ones(0, 2), 1, 2) == Matrix([ + [0, 0, 1, 0], + [0, 0, 0, 2]]) + assert Matrix.diag(ones(2, 0), 1, 2) == Matrix([ + [0, 0], + [0, 0], + [1, 0], + [0, 2]]) + + +def test_diagonal(): + m = Matrix(3, 3, range(9)) + d = m.diagonal() + assert d == m.diagonal(0) + assert tuple(d) == (0, 4, 8) + assert tuple(m.diagonal(1)) == (1, 5) + assert tuple(m.diagonal(-1)) == (3, 7) + assert tuple(m.diagonal(2)) == (2,) + assert type(m.diagonal()) == type(m) + s = SparseMatrix(3, 3, {(1, 1): 1}) + assert type(s.diagonal()) == type(s) + assert type(m) != type(s) + raises(ValueError, lambda: m.diagonal(3)) + raises(ValueError, lambda: m.diagonal(-3)) + raises(ValueError, lambda: m.diagonal(pi)) + M = ones(2, 3) + assert banded({i: list(M.diagonal(i)) + for i in range(1-M.rows, M.cols)}) == M + + +def test_jordan_block(): + assert SpecialOnlyMatrix.jordan_block(3, 2) == SpecialOnlyMatrix.jordan_block(3, eigenvalue=2) \ + == SpecialOnlyMatrix.jordan_block(size=3, eigenvalue=2) \ + == SpecialOnlyMatrix.jordan_block(3, 2, band='upper') \ + == SpecialOnlyMatrix.jordan_block( + size=3, eigenval=2, eigenvalue=2) \ + == Matrix([ + [2, 1, 0], + [0, 2, 1], + [0, 0, 2]]) + + assert SpecialOnlyMatrix.jordan_block(3, 2, band='lower') == Matrix([ + [2, 0, 0], + [1, 2, 0], + [0, 1, 2]]) + # missing eigenvalue + raises(ValueError, lambda: SpecialOnlyMatrix.jordan_block(2)) + # non-integral size + raises(ValueError, lambda: SpecialOnlyMatrix.jordan_block(3.5, 2)) + # size not specified + raises(ValueError, lambda: SpecialOnlyMatrix.jordan_block(eigenvalue=2)) + # inconsistent eigenvalue + raises(ValueError, + lambda: SpecialOnlyMatrix.jordan_block( + eigenvalue=2, eigenval=4)) + + # Using alias keyword + assert SpecialOnlyMatrix.jordan_block(size=3, eigenvalue=2) == \ + SpecialOnlyMatrix.jordan_block(size=3, eigenval=2) + + +def test_orthogonalize(): + m = Matrix([[1, 2], [3, 4]]) + assert m.orthogonalize(Matrix([[2], [1]])) == [Matrix([[2], [1]])] + assert m.orthogonalize(Matrix([[2], [1]]), normalize=True) == \ + [Matrix([[2*sqrt(5)/5], [sqrt(5)/5]])] + assert m.orthogonalize(Matrix([[1], [2]]), Matrix([[-1], [4]])) == \ + [Matrix([[1], [2]]), Matrix([[Rational(-12, 5)], [Rational(6, 5)]])] + assert m.orthogonalize(Matrix([[0], [0]]), Matrix([[-1], [4]])) == \ + [Matrix([[-1], [4]])] + assert m.orthogonalize(Matrix([[0], [0]])) == [] + + n = Matrix([[9, 1, 9], [3, 6, 10], [8, 5, 2]]) + vecs = [Matrix([[-5], [1]]), Matrix([[-5], [2]]), Matrix([[-5], [-2]])] + assert n.orthogonalize(*vecs) == \ + [Matrix([[-5], [1]]), Matrix([[Rational(5, 26)], [Rational(25, 26)]])] + + vecs = [Matrix([0, 0, 0]), Matrix([1, 2, 3]), Matrix([1, 4, 5])] + raises(ValueError, lambda: Matrix.orthogonalize(*vecs, rankcheck=True)) + + vecs = [Matrix([1, 2, 3]), Matrix([4, 5, 6]), Matrix([7, 8, 9])] + raises(ValueError, lambda: Matrix.orthogonalize(*vecs, rankcheck=True)) + +def test_wilkinson(): + + wminus, wplus = Matrix.wilkinson(1) + assert wminus == Matrix([ + [-1, 1, 0], + [1, 0, 1], + [0, 1, 1]]) + assert wplus == Matrix([ + [1, 1, 0], + [1, 0, 1], + [0, 1, 1]]) + + wminus, wplus = Matrix.wilkinson(3) + assert wminus == Matrix([ + [-3, 1, 0, 0, 0, 0, 0], + [1, -2, 1, 0, 0, 0, 0], + [0, 1, -1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 2, 1], + + [0, 0, 0, 0, 0, 1, 3]]) + + assert wplus == Matrix([ + [3, 1, 0, 0, 0, 0, 0], + [1, 2, 1, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 2, 1], + [0, 0, 0, 0, 0, 1, 3]]) + + +# CalculusOnlyMatrix tests +@XFAIL +def test_diff(): + x, y = symbols('x y') + m = CalculusOnlyMatrix(2, 1, [x, y]) + # TODO: currently not working as ``_MinimalMatrix`` cannot be sympified: + assert m.diff(x) == Matrix(2, 1, [1, 0]) + + +def test_integrate(): + x, y = symbols('x y') + m = CalculusOnlyMatrix(2, 1, [x, y]) + assert m.integrate(x) == Matrix(2, 1, [x**2/2, y*x]) + + +def test_jacobian2(): + rho, phi = symbols("rho,phi") + X = CalculusOnlyMatrix(3, 1, [rho*cos(phi), rho*sin(phi), rho**2]) + Y = CalculusOnlyMatrix(2, 1, [rho, phi]) + J = Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)], + [ 2*rho, 0], + ]) + assert X.jacobian(Y) == J + + m = CalculusOnlyMatrix(2, 2, [1, 2, 3, 4]) + m2 = CalculusOnlyMatrix(4, 1, [1, 2, 3, 4]) + raises(TypeError, lambda: m.jacobian(Matrix([1, 2]))) + raises(TypeError, lambda: m2.jacobian(m)) + + +def test_limit(): + x, y = symbols('x y') + m = CalculusOnlyMatrix(2, 1, [1/x, y]) + assert m.limit(x, 5) == Matrix(2, 1, [Rational(1, 5), y]) + + +def test_issue_13774(): + M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + v = [1, 1, 1] + raises(TypeError, lambda: M*v) + raises(TypeError, lambda: v*M) + +def test_companion(): + x = Symbol('x') + y = Symbol('y') + raises(ValueError, lambda: Matrix.companion(1)) + raises(ValueError, lambda: Matrix.companion(Poly([1], x))) + raises(ValueError, lambda: Matrix.companion(Poly([2, 1], x))) + raises(ValueError, lambda: Matrix.companion(Poly(x*y, [x, y]))) + + c0, c1, c2 = symbols('c0:3') + assert Matrix.companion(Poly([1, c0], x)) == Matrix([-c0]) + assert Matrix.companion(Poly([1, c1, c0], x)) == \ + Matrix([[0, -c0], [1, -c1]]) + assert Matrix.companion(Poly([1, c2, c1, c0], x)) == \ + Matrix([[0, 0, -c0], [1, 0, -c1], [0, 1, -c2]]) + +def test_issue_10589(): + x, y, z = symbols("x, y z") + M1 = Matrix([x, y, z]) + M1 = M1.subs(zip([x, y, z], [1, 2, 3])) + assert M1 == Matrix([[1], [2], [3]]) + + M2 = Matrix([[x, x, x, x, x], [x, x, x, x, x], [x, x, x, x, x]]) + M2 = M2.subs(zip([x], [1])) + assert M2 == Matrix([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]) + +def test_rmul_pr19860(): + class Foo(ImmutableDenseMatrix): + _op_priority = MutableDenseMatrix._op_priority + 0.01 + + a = Matrix(2, 2, [1, 2, 3, 4]) + b = Foo(2, 2, [1, 2, 3, 4]) + + # This would throw a RecursionError: maximum recursion depth + # since b always has higher priority even after a.as_mutable() + c = a*b + + assert isinstance(c, Foo) + assert c == Matrix([[7, 10], [15, 22]]) + + +def test_issue_18956(): + A = Array([[1, 2], [3, 4]]) + B = Matrix([[1,2],[3,4]]) + raises(TypeError, lambda: B + A) + raises(TypeError, lambda: A + B) + + +def test__eq__(): + class My(object): + def __iter__(self): + yield 1 + yield 2 + return + def __getitem__(self, i): + return list(self)[i] + a = Matrix(2, 1, [1, 2]) + assert a != My() + class My_sympy(My): + def _sympy_(self): + return Matrix(self) + assert a == My_sympy() diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_decompositions.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_decompositions.py new file mode 100644 index 0000000000000000000000000000000000000000..d169ec3a8846fed786981e62d932fd860b6d4951 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_decompositions.py @@ -0,0 +1,474 @@ +from sympy.core.function import expand_mul +from sympy.core.numbers import I, Rational +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.complexes import Abs +from sympy.simplify.simplify import simplify +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.matrices import Matrix, zeros, eye, SparseMatrix +from sympy.abc import x, y, z +from sympy.testing.pytest import raises, slow +from sympy.testing.matrices import allclose + + +def test_LUdecomp(): + testmat = Matrix([[0, 2, 5, 3], + [3, 3, 7, 4], + [8, 4, 0, 2], + [-2, 6, 3, 4]]) + L, U, p = testmat.LUdecomposition() + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - testmat == zeros(4) + + testmat = Matrix([[6, -2, 7, 4], + [0, 3, 6, 7], + [1, -2, 7, 4], + [-9, 2, 6, 3]]) + L, U, p = testmat.LUdecomposition() + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - testmat == zeros(4) + + # non-square + testmat = Matrix([[1, 2, 3], + [4, 5, 6], + [7, 8, 9], + [10, 11, 12]]) + L, U, p = testmat.LUdecomposition(rankcheck=False) + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - testmat == zeros(4, 3) + + # square and singular + testmat = Matrix([[1, 2, 3], + [2, 4, 6], + [4, 5, 6]]) + L, U, p = testmat.LUdecomposition(rankcheck=False) + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - testmat == zeros(3) + + M = Matrix(((1, x, 1), (2, y, 0), (y, 0, z))) + L, U, p = M.LUdecomposition() + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - M == zeros(3) + + mL = Matrix(( + (1, 0, 0), + (2, 3, 0), + )) + assert mL.is_lower is True + assert mL.is_upper is False + mU = Matrix(( + (1, 2, 3), + (0, 4, 5), + )) + assert mU.is_lower is False + assert mU.is_upper is True + + # test FF LUdecomp + M = Matrix([[1, 3, 3], + [3, 2, 6], + [3, 2, 2]]) + P, L, Dee, U = M.LUdecompositionFF() + assert P*M == L*Dee.inv()*U + + M = Matrix([[1, 2, 3, 4], + [3, -1, 2, 3], + [3, 1, 3, -2], + [6, -1, 0, 2]]) + P, L, Dee, U = M.LUdecompositionFF() + assert P*M == L*Dee.inv()*U + + M = Matrix([[0, 0, 1], + [2, 3, 0], + [3, 1, 4]]) + P, L, Dee, U = M.LUdecompositionFF() + assert P*M == L*Dee.inv()*U + + # issue 15794 + M = Matrix( + [[1, 2, 3], + [4, 5, 6], + [7, 8, 9]] + ) + raises(ValueError, lambda : M.LUdecomposition_Simple(rankcheck=True)) + +def test_singular_value_decompositionD(): + A = Matrix([[1, 2], [2, 1]]) + U, S, V = A.singular_value_decomposition() + assert U * S * V.T == A + assert U.T * U == eye(U.cols) + assert V.T * V == eye(V.cols) + + B = Matrix([[1, 2]]) + U, S, V = B.singular_value_decomposition() + + assert U * S * V.T == B + assert U.T * U == eye(U.cols) + assert V.T * V == eye(V.cols) + + C = Matrix([ + [1, 0, 0, 0, 2], + [0, 0, 3, 0, 0], + [0, 0, 0, 0, 0], + [0, 2, 0, 0, 0], + ]) + + U, S, V = C.singular_value_decomposition() + + assert U * S * V.T == C + assert U.T * U == eye(U.cols) + assert V.T * V == eye(V.cols) + + D = Matrix([[Rational(1, 3), sqrt(2)], [0, Rational(1, 4)]]) + U, S, V = D.singular_value_decomposition() + assert simplify(U.T * U) == eye(U.cols) + assert simplify(V.T * V) == eye(V.cols) + assert simplify(U * S * V.T) == D + + +def test_QR(): + A = Matrix([[1, 2], [2, 3]]) + Q, S = A.QRdecomposition() + R = Rational + assert Q == Matrix([ + [ 5**R(-1, 2), (R(2)/5)*(R(1)/5)**R(-1, 2)], + [2*5**R(-1, 2), (-R(1)/5)*(R(1)/5)**R(-1, 2)]]) + assert S == Matrix([[5**R(1, 2), 8*5**R(-1, 2)], [0, (R(1)/5)**R(1, 2)]]) + assert Q*S == A + assert Q.T * Q == eye(2) + + A = Matrix([[1, 1, 1], [1, 1, 3], [2, 3, 4]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[12, 0, -51], [6, 0, 167], [-4, 0, 24]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + x = Symbol('x') + A = Matrix([x]) + Q, R = A.QRdecomposition() + assert Q == Matrix([x / Abs(x)]) + assert R == Matrix([Abs(x)]) + + A = Matrix([[x, 0], [0, x]]) + Q, R = A.QRdecomposition() + assert Q == x / Abs(x) * Matrix([[1, 0], [0, 1]]) + assert R == Abs(x) * Matrix([[1, 0], [0, 1]]) + + +def test_QR_non_square(): + # Narrow (cols < rows) matrices + A = Matrix([[9, 0, 26], [12, 0, -7], [0, 4, 4], [0, -3, -3]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[1, -1, 4], [1, 4, -2], [1, 4, 2], [1, -1, 0]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix(2, 1, [1, 2]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + # Wide (cols > rows) matrices + A = Matrix([[1, 2, 3], [4, 5, 6]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[1, 2, 3, 4], [1, 4, 9, 16], [1, 8, 27, 64]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix(1, 2, [1, 2]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + +def test_QR_trivial(): + # Rank deficient matrices + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + # Zero rank matrices + A = Matrix([[0, 0, 0]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[0, 0, 0]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[0, 0, 0], [0, 0, 0]]) + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[0, 0, 0], [0, 0, 0]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + # Rank deficient matrices with zero norm from beginning columns + A = Matrix([[0, 0, 0], [1, 2, 3]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[0, 0, 0, 0], [1, 2, 3, 4], [0, 0, 0, 0]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[0, 0, 0, 0], [1, 2, 3, 4], [0, 0, 0, 0], [2, 4, 6, 8]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + A = Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 2, 3]]).T + Q, R = A.QRdecomposition() + assert Q.T * Q == eye(Q.cols) + assert R.is_upper + assert A == Q*R + + +def test_QR_float(): + A = Matrix([[1, 1], [1, 1.01]]) + Q, R = A.QRdecomposition() + assert allclose(Q * R, A) + assert allclose(Q * Q.T, Matrix.eye(2)) + assert allclose(Q.T * Q, Matrix.eye(2)) + + A = Matrix([[1, 1], [1, 1.001]]) + Q, R = A.QRdecomposition() + assert allclose(Q * R, A) + assert allclose(Q * Q.T, Matrix.eye(2)) + assert allclose(Q.T * Q, Matrix.eye(2)) + + +def test_LUdecomposition_Simple_iszerofunc(): + # Test if callable passed to matrices.LUdecomposition_Simple() as iszerofunc keyword argument is used inside + # matrices.LUdecomposition_Simple() + magic_string = "I got passed in!" + def goofyiszero(value): + raise ValueError(magic_string) + + try: + lu, p = Matrix([[1, 0], [0, 1]]).LUdecomposition_Simple(iszerofunc=goofyiszero) + except ValueError as err: + assert magic_string == err.args[0] + return + + assert False + +def test_LUdecomposition_iszerofunc(): + # Test if callable passed to matrices.LUdecomposition() as iszerofunc keyword argument is used inside + # matrices.LUdecomposition_Simple() + magic_string = "I got passed in!" + def goofyiszero(value): + raise ValueError(magic_string) + + try: + l, u, p = Matrix([[1, 0], [0, 1]]).LUdecomposition(iszerofunc=goofyiszero) + except ValueError as err: + assert magic_string == err.args[0] + return + + assert False + +def test_LDLdecomposition(): + raises(NonSquareMatrixError, lambda: Matrix((1, 2)).LDLdecomposition()) + raises(ValueError, lambda: Matrix(((1, 2), (3, 4))).LDLdecomposition()) + raises(ValueError, lambda: Matrix(((5 + I, 0), (0, 1))).LDLdecomposition()) + raises(ValueError, lambda: Matrix(((1, 5), (5, 1))).LDLdecomposition()) + raises(ValueError, lambda: Matrix(((1, 2), (3, 4))).LDLdecomposition(hermitian=False)) + A = Matrix(((1, 5), (5, 1))) + L, D = A.LDLdecomposition(hermitian=False) + assert L * D * L.T == A + A = Matrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L, D = A.LDLdecomposition() + assert L * D * L.T == A + assert L.is_lower + assert L == Matrix([[1, 0, 0], [ Rational(3, 5), 1, 0], [Rational(-1, 5), Rational(1, 3), 1]]) + assert D.is_diagonal() + assert D == Matrix([[25, 0, 0], [0, 9, 0], [0, 0, 9]]) + A = Matrix(((4, -2*I, 2 + 2*I), (2*I, 2, -1 + I), (2 - 2*I, -1 - I, 11))) + L, D = A.LDLdecomposition() + assert expand_mul(L * D * L.H) == A + assert L.expand() == Matrix([[1, 0, 0], [I/2, 1, 0], [S.Half - I/2, 0, 1]]) + assert D.expand() == Matrix(((4, 0, 0), (0, 1, 0), (0, 0, 9))) + + raises(NonSquareMatrixError, lambda: SparseMatrix((1, 2)).LDLdecomposition()) + raises(ValueError, lambda: SparseMatrix(((1, 2), (3, 4))).LDLdecomposition()) + raises(ValueError, lambda: SparseMatrix(((5 + I, 0), (0, 1))).LDLdecomposition()) + raises(ValueError, lambda: SparseMatrix(((1, 5), (5, 1))).LDLdecomposition()) + raises(ValueError, lambda: SparseMatrix(((1, 2), (3, 4))).LDLdecomposition(hermitian=False)) + A = SparseMatrix(((1, 5), (5, 1))) + L, D = A.LDLdecomposition(hermitian=False) + assert L * D * L.T == A + A = SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L, D = A.LDLdecomposition() + assert L * D * L.T == A + assert L.is_lower + assert L == Matrix([[1, 0, 0], [ Rational(3, 5), 1, 0], [Rational(-1, 5), Rational(1, 3), 1]]) + assert D.is_diagonal() + assert D == Matrix([[25, 0, 0], [0, 9, 0], [0, 0, 9]]) + A = SparseMatrix(((4, -2*I, 2 + 2*I), (2*I, 2, -1 + I), (2 - 2*I, -1 - I, 11))) + L, D = A.LDLdecomposition() + assert expand_mul(L * D * L.H) == A + assert L == Matrix(((1, 0, 0), (I/2, 1, 0), (S.Half - I/2, 0, 1))) + assert D == Matrix(((4, 0, 0), (0, 1, 0), (0, 0, 9))) + +def test_pinv_succeeds_with_rank_decomposition_method(): + # Test rank decomposition method of pseudoinverse succeeding + As = [Matrix([ + [61, 89, 55, 20, 71, 0], + [62, 96, 85, 85, 16, 0], + [69, 56, 17, 4, 54, 0], + [10, 54, 91, 41, 71, 0], + [ 7, 30, 10, 48, 90, 0], + [0,0,0,0,0,0]])] + for A in As: + A_pinv = A.pinv(method="RD") + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + +def test_rank_decomposition(): + a = Matrix(0, 0, []) + c, f = a.rank_decomposition() + assert f.is_echelon + assert c.cols == f.rows == a.rank() + assert c * f == a + + a = Matrix(1, 1, [5]) + c, f = a.rank_decomposition() + assert f.is_echelon + assert c.cols == f.rows == a.rank() + assert c * f == a + + a = Matrix(3, 3, [1, 2, 3, 1, 2, 3, 1, 2, 3]) + c, f = a.rank_decomposition() + assert f.is_echelon + assert c.cols == f.rows == a.rank() + assert c * f == a + + a = Matrix([ + [0, 0, 1, 2, 2, -5, 3], + [-1, 5, 2, 2, 1, -7, 5], + [0, 0, -2, -3, -3, 8, -5], + [-1, 5, 0, -1, -2, 1, 0]]) + c, f = a.rank_decomposition() + assert f.is_echelon + assert c.cols == f.rows == a.rank() + assert c * f == a + + +@slow +def test_upper_hessenberg_decomposition(): + A = Matrix([ + [1, 0, sqrt(3)], + [sqrt(2), Rational(1, 2), 2], + [1, Rational(1, 4), 3], + ]) + H, P = A.upper_hessenberg_decomposition() + assert simplify(P * P.H) == eye(P.cols) + assert simplify(P.H * P) == eye(P.cols) + assert H.is_upper_hessenberg + assert (simplify(P * H * P.H)) == A + + + B = Matrix([ + [1, 2, 10], + [8, 2, 5], + [3, 12, 34], + ]) + H, P = B.upper_hessenberg_decomposition() + assert simplify(P * P.H) == eye(P.cols) + assert simplify(P.H * P) == eye(P.cols) + assert H.is_upper_hessenberg + assert simplify(P * H * P.H) == B + + C = Matrix([ + [1, sqrt(2), 2, 3], + [0, 5, 3, 4], + [1, 1, 4, sqrt(5)], + [0, 2, 2, 3] + ]) + + H, P = C.upper_hessenberg_decomposition() + assert simplify(P * P.H) == eye(P.cols) + assert simplify(P.H * P) == eye(P.cols) + assert H.is_upper_hessenberg + assert simplify(P * H * P.H) == C + + D = Matrix([ + [1, 2, 3], + [-3, 5, 6], + [4, -8, 9], + ]) + H, P = D.upper_hessenberg_decomposition() + assert simplify(P * P.H) == eye(P.cols) + assert simplify(P.H * P) == eye(P.cols) + assert H.is_upper_hessenberg + assert simplify(P * H * P.H) == D + + E = Matrix([ + [1, 0, 0, 0], + [0, 1, 0, 0], + [1, 1, 0, 1], + [1, 1, 1, 0] + ]) + + H, P = E.upper_hessenberg_decomposition() + assert simplify(P * P.H) == eye(P.cols) + assert simplify(P.H * P) == eye(P.cols) + assert H.is_upper_hessenberg + assert simplify(P * H * P.H) == E diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_determinant.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_determinant.py new file mode 100644 index 0000000000000000000000000000000000000000..82b42ccf67efa4757bf270782bdf1d65e0efa306 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_determinant.py @@ -0,0 +1,280 @@ +import random +import pytest +from sympy.core.numbers import I +from sympy.core.numbers import Rational +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys.polytools import Poly +from sympy.matrices import Matrix, eye, ones +from sympy.abc import x, y, z +from sympy.testing.pytest import raises +from sympy.matrices.exceptions import NonSquareMatrixError +from sympy.functions.combinatorial.factorials import factorial, subfactorial + + +@pytest.mark.parametrize("method", [ + # Evaluating these directly because they are never reached via M.det() + Matrix._eval_det_bareiss, Matrix._eval_det_berkowitz, + Matrix._eval_det_bird, Matrix._eval_det_laplace, Matrix._eval_det_lu +]) +@pytest.mark.parametrize("M, sol", [ + (Matrix(), 1), + (Matrix([[0]]), 0), + (Matrix([[5]]), 5), +]) +def test_eval_determinant(method, M, sol): + assert method(M) == sol + + +@pytest.mark.parametrize("method", [ + "domain-ge", "bareiss", "berkowitz", "bird", "laplace", "lu"]) +@pytest.mark.parametrize("M, sol", [ + (Matrix(( (-3, 2), + ( 8, -5) )), -1), + (Matrix(( (x, 1), + (y, 2*y) )), 2*x*y - y), + (Matrix(( (1, 1, 1), + (1, 2, 3), + (1, 3, 6) )), 1), + (Matrix(( ( 3, -2, 0, 5), + (-2, 1, -2, 2), + ( 0, -2, 5, 0), + ( 5, 0, 3, 4) )), -289), + (Matrix(( ( 1, 2, 3, 4), + ( 5, 6, 7, 8), + ( 9, 10, 11, 12), + (13, 14, 15, 16) )), 0), + (Matrix(( (3, 2, 0, 0, 0), + (0, 3, 2, 0, 0), + (0, 0, 3, 2, 0), + (0, 0, 0, 3, 2), + (2, 0, 0, 0, 3) )), 275), + (Matrix(( ( 3, 0, 0, 0), + (-2, 1, 0, 0), + ( 0, -2, 5, 0), + ( 5, 0, 3, 4) )), 60), + (Matrix(( ( 1, 0, 0, 0), + ( 5, 0, 0, 0), + ( 9, 10, 11, 0), + (13, 14, 15, 16) )), 0), + (Matrix(( (3, 2, 0, 0, 0), + (0, 3, 2, 0, 0), + (0, 0, 3, 2, 0), + (0, 0, 0, 3, 2), + (0, 0, 0, 0, 3) )), 243), + (Matrix(( (1, 0, 1, 2, 12), + (2, 0, 1, 1, 4), + (2, 1, 1, -1, 3), + (3, 2, -1, 1, 8), + (1, 1, 1, 0, 6) )), -55), + (Matrix(( (-5, 2, 3, 4, 5), + ( 1, -4, 3, 4, 5), + ( 1, 2, -3, 4, 5), + ( 1, 2, 3, -2, 5), + ( 1, 2, 3, 4, -1) )), 11664), + (Matrix(( ( 2, 7, -1, 3, 2), + ( 0, 0, 1, 0, 1), + (-2, 0, 7, 0, 2), + (-3, -2, 4, 5, 3), + ( 1, 0, 0, 0, 1) )), 123), + (Matrix(( (x, y, z), + (1, 0, 0), + (y, z, x) )), z**2 - x*y), +]) +def test_determinant(method, M, sol): + assert M.det(method=method) == sol + + +def test_issue_13835(): + a = symbols('a') + M = lambda n: Matrix([[i + a*j for i in range(n)] + for j in range(n)]) + assert M(5).det() == 0 + assert M(6).det() == 0 + assert M(7).det() == 0 + + +def test_issue_14517(): + M = Matrix([ + [ 0, 10*I, 10*I, 0], + [10*I, 0, 0, 10*I], + [10*I, 0, 5 + 2*I, 10*I], + [ 0, 10*I, 10*I, 5 + 2*I]]) + ev = M.eigenvals() + # test one random eigenvalue, the computation is a little slow + test_ev = random.choice(list(ev.keys())) + assert (M - test_ev*eye(4)).det() == 0 + + +@pytest.mark.parametrize("method", [ + "bareis", "det_lu", "det_LU", "Bareis", "BAREISS", "BERKOWITZ", "LU"]) +@pytest.mark.parametrize("M, sol", [ + (Matrix(( ( 3, -2, 0, 5), + (-2, 1, -2, 2), + ( 0, -2, 5, 0), + ( 5, 0, 3, 4) )), -289), + (Matrix(( (-5, 2, 3, 4, 5), + ( 1, -4, 3, 4, 5), + ( 1, 2, -3, 4, 5), + ( 1, 2, 3, -2, 5), + ( 1, 2, 3, 4, -1) )), 11664), +]) +def test_legacy_det(method, M, sol): + # Minimal support for legacy keys for 'method' in det() + # Partially copied from test_determinant() + assert M.det(method=method) == sol + + +def eye_Determinant(n): + return Matrix(n, n, lambda i, j: int(i == j)) + +def zeros_Determinant(n): + return Matrix(n, n, lambda i, j: 0) + +def test_det(): + a = Matrix(2, 3, [1, 2, 3, 4, 5, 6]) + raises(NonSquareMatrixError, lambda: a.det()) + + z = zeros_Determinant(2) + ey = eye_Determinant(2) + assert z.det() == 0 + assert ey.det() == 1 + + x = Symbol('x') + a = Matrix(0, 0, []) + b = Matrix(1, 1, [5]) + c = Matrix(2, 2, [1, 2, 3, 4]) + d = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 8]) + e = Matrix(4, 4, + [x, 1, 2, 3, 4, 5, 6, 7, 2, 9, 10, 11, 12, 13, 14, 14]) + from sympy.abc import i, j, k, l, m, n + f = Matrix(3, 3, [i, l, m, 0, j, n, 0, 0, k]) + g = Matrix(3, 3, [i, 0, 0, l, j, 0, m, n, k]) + h = Matrix(3, 3, [x**3, 0, 0, i, x**-1, 0, j, k, x**-2]) + # the method keyword for `det` doesn't kick in until 4x4 matrices, + # so there is no need to test all methods on smaller ones + + assert a.det() == 1 + assert b.det() == 5 + assert c.det() == -2 + assert d.det() == 3 + assert e.det() == 4*x - 24 + assert e.det(method="domain-ge") == 4*x - 24 + assert e.det(method='bareiss') == 4*x - 24 + assert e.det(method='berkowitz') == 4*x - 24 + assert f.det() == i*j*k + assert g.det() == i*j*k + assert h.det() == 1 + raises(ValueError, lambda: e.det(iszerofunc="test")) + +def test_permanent(): + M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert M.per() == 450 + for i in range(1, 12): + assert ones(i, i).per() == ones(i, i).T.per() == factorial(i) + assert (ones(i, i)-eye(i)).per() == (ones(i, i)-eye(i)).T.per() == subfactorial(i) + + a1, a2, a3, a4, a5 = symbols('a_1 a_2 a_3 a_4 a_5') + M = Matrix([a1, a2, a3, a4, a5]) + assert M.per() == M.T.per() == a1 + a2 + a3 + a4 + a5 + +def test_adjugate(): + x = Symbol('x') + e = Matrix(4, 4, + [x, 1, 2, 3, 4, 5, 6, 7, 2, 9, 10, 11, 12, 13, 14, 14]) + + adj = Matrix([ + [ 4, -8, 4, 0], + [ 76, -14*x - 68, 14*x - 8, -4*x + 24], + [-122, 17*x + 142, -21*x + 4, 8*x - 48], + [ 48, -4*x - 72, 8*x, -4*x + 24]]) + assert e.adjugate() == adj + assert e.adjugate(method='bareiss') == adj + assert e.adjugate(method='berkowitz') == adj + assert e.adjugate(method='bird') == adj + assert e.adjugate(method='laplace') == adj + + a = Matrix(2, 3, [1, 2, 3, 4, 5, 6]) + raises(NonSquareMatrixError, lambda: a.adjugate()) + +def test_util(): + R = Rational + + v1 = Matrix(1, 3, [1, 2, 3]) + v2 = Matrix(1, 3, [3, 4, 5]) + assert v1.norm() == sqrt(14) + assert v1.project(v2) == Matrix(1, 3, [R(39)/25, R(52)/25, R(13)/5]) + assert Matrix.zeros(1, 2) == Matrix(1, 2, [0, 0]) + assert ones(1, 2) == Matrix(1, 2, [1, 1]) + assert v1.copy() == v1 + # cofactor + assert eye(3) == eye(3).cofactor_matrix() + test = Matrix([[1, 3, 2], [2, 6, 3], [2, 3, 6]]) + assert test.cofactor_matrix() == \ + Matrix([[27, -6, -6], [-12, 2, 3], [-3, 1, 0]]) + test = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert test.cofactor_matrix() == \ + Matrix([[-3, 6, -3], [6, -12, 6], [-3, 6, -3]]) + +def test_cofactor_and_minors(): + x = Symbol('x') + e = Matrix(4, 4, + [x, 1, 2, 3, 4, 5, 6, 7, 2, 9, 10, 11, 12, 13, 14, 14]) + + m = Matrix([ + [ x, 1, 3], + [ 2, 9, 11], + [12, 13, 14]]) + cm = Matrix([ + [ 4, 76, -122, 48], + [-8, -14*x - 68, 17*x + 142, -4*x - 72], + [ 4, 14*x - 8, -21*x + 4, 8*x], + [ 0, -4*x + 24, 8*x - 48, -4*x + 24]]) + sub = Matrix([ + [x, 1, 2], + [4, 5, 6], + [2, 9, 10]]) + + assert e.minor_submatrix(1, 2) == m + assert e.minor_submatrix(-1, -1) == sub + assert e.minor(1, 2) == -17*x - 142 + assert e.cofactor(1, 2) == 17*x + 142 + assert e.cofactor_matrix() == cm + assert e.cofactor_matrix(method="bareiss") == cm + assert e.cofactor_matrix(method="berkowitz") == cm + assert e.cofactor_matrix(method="bird") == cm + assert e.cofactor_matrix(method="laplace") == cm + + raises(ValueError, lambda: e.cofactor(4, 5)) + raises(ValueError, lambda: e.minor(4, 5)) + raises(ValueError, lambda: e.minor_submatrix(4, 5)) + + a = Matrix(2, 3, [1, 2, 3, 4, 5, 6]) + assert a.minor_submatrix(0, 0) == Matrix([[5, 6]]) + + raises(ValueError, lambda: + Matrix(0, 0, []).minor_submatrix(0, 0)) + raises(NonSquareMatrixError, lambda: a.cofactor(0, 0)) + raises(NonSquareMatrixError, lambda: a.minor(0, 0)) + raises(NonSquareMatrixError, lambda: a.cofactor_matrix()) + +def test_charpoly(): + x, y = Symbol('x'), Symbol('y') + z, t = Symbol('z'), Symbol('t') + + from sympy.abc import a,b,c + + m = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + + assert eye_Determinant(3).charpoly(x) == Poly((x - 1)**3, x) + assert eye_Determinant(3).charpoly(y) == Poly((y - 1)**3, y) + assert m.charpoly() == Poly(x**3 - 15*x**2 - 18*x, x) + raises(NonSquareMatrixError, lambda: Matrix([[1], [2]]).charpoly()) + n = Matrix(4, 4, [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + assert n.charpoly() == Poly(x**4, x) + + n = Matrix(4, 4, [45, 0, 0, 0, 0, 23, 0, 0, 0, 0, 87, 0, 0, 0, 0, 12]) + assert n.charpoly() == Poly(x**4 - 167*x**3 + 8811*x**2 - 173457*x + 1080540, x) + + n = Matrix(3, 3, [x, 0, 0, a, y, 0, b, c, z]) + assert n.charpoly() == Poly(t**3 - (x+y+z)*t**2 + t*(x*y+y*z+x*z) - x*y*z, t) diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_domains.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_domains.py new file mode 100644 index 0000000000000000000000000000000000000000..26a54b8879a5c65f3a01b4886d223c08309e733d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_domains.py @@ -0,0 +1,113 @@ +# Test Matrix/DomainMatrix interaction. + + +from sympy import GF, ZZ, QQ, EXRAW +from sympy.polys.matrices import DomainMatrix, DM + +from sympy import ( + Matrix, + MutableMatrix, + ImmutableMatrix, + SparseMatrix, + MutableDenseMatrix, + ImmutableDenseMatrix, + MutableSparseMatrix, + ImmutableSparseMatrix, +) +from sympy import symbols, S, sqrt + +from sympy.testing.pytest import raises + + +x, y = symbols('x y') + + +MATRIX_TYPES = ( + Matrix, + MutableMatrix, + ImmutableMatrix, + SparseMatrix, + MutableDenseMatrix, + ImmutableDenseMatrix, + MutableSparseMatrix, + ImmutableSparseMatrix, +) +IMMUTABLE = ( + ImmutableMatrix, + ImmutableDenseMatrix, + ImmutableSparseMatrix, +) + + +def DMs(items, domain): + return DM(items, domain).to_sparse() + + +def test_Matrix_rep_domain(): + + for Mat in MATRIX_TYPES: + + M = Mat([[1, 2], [3, 4]]) + assert M._rep == DMs([[1, 2], [3, 4]], ZZ) + assert (M / 2)._rep == DMs([[(1,2), 1], [(3,2), 2]], QQ) + if not isinstance(M, IMMUTABLE): + M[0, 0] = x + assert M._rep == DMs([[x, 2], [3, 4]], EXRAW) + + M = Mat([[S(1)/2, 2], [3, 4]]) + assert M._rep == DMs([[(1,2), 2], [3, 4]], QQ) + if not isinstance(M, IMMUTABLE): + M[0, 0] = x + assert M._rep == DMs([[x, 2], [3, 4]], EXRAW) + + dM = DMs([[1, 2], [3, 4]], ZZ) + assert Mat._fromrep(dM)._rep == dM + + # XXX: This is not intended. Perhaps it should be coerced to EXRAW? + # The private _fromrep method is never called like this but perhaps it + # should be guarded. + # + # It is not clear how to integrate domains other than ZZ, QQ and EXRAW with + # the rest of Matrix or if the public type for this needs to be something + # different from Matrix somehow. + K = QQ.algebraic_field(sqrt(2)) + dM = DM([[1, 2], [3, 4]], K) + assert Mat._fromrep(dM)._rep.domain == K + + +def test_Matrix_to_DM(): + + M = Matrix([[1, 2], [3, 4]]) + assert M.to_DM() == DMs([[1, 2], [3, 4]], ZZ) + assert M.to_DM() is not M._rep + assert M.to_DM(field=True) == DMs([[1, 2], [3, 4]], QQ) + assert M.to_DM(domain=QQ) == DMs([[1, 2], [3, 4]], QQ) + assert M.to_DM(domain=QQ[x]) == DMs([[1, 2], [3, 4]], QQ[x]) + assert M.to_DM(domain=GF(3)) == DMs([[1, 2], [0, 1]], GF(3)) + + M = Matrix([[1, 2], [3, 4]]) + M[0, 0] = x + assert M._rep.domain == EXRAW + M[0, 0] = 1 + assert M.to_DM() == DMs([[1, 2], [3, 4]], ZZ) + + M = Matrix([[S(1)/2, 2], [3, 4]]) + assert M.to_DM() == DMs([[QQ(1,2), 2], [3, 4]], QQ) + + M = Matrix([[x, 2], [3, 4]]) + assert M.to_DM() == DMs([[x, 2], [3, 4]], ZZ[x]) + assert M.to_DM(field=True) == DMs([[x, 2], [3, 4]], ZZ.frac_field(x)) + + M = Matrix([[1/x, 2], [3, 4]]) + assert M.to_DM() == DMs([[1/x, 2], [3, 4]], ZZ.frac_field(x)) + + M = Matrix([[1, sqrt(2)], [3, 4]]) + K = QQ.algebraic_field(sqrt(2)) + sqrt2 = K.from_sympy(sqrt(2)) # XXX: Maybe K(sqrt(2)) should work + M_K = DomainMatrix([[K(1), sqrt2], [K(3), K(4)]], (2, 2), K) + assert M.to_DM() == DMs([[1, sqrt(2)], [3, 4]], EXRAW) + assert M.to_DM(extension=True) == M_K.to_sparse() + + # Options cannot be used with the domain parameter + M = Matrix([[1, 2], [3, 4]]) + raises(TypeError, lambda: M.to_DM(domain=QQ, field=True)) diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_eigen.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_eigen.py new file mode 100644 index 0000000000000000000000000000000000000000..fcf96325519879e0683d29e2ddc32db7bf83baa4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_eigen.py @@ -0,0 +1,712 @@ +from sympy.core.evalf import N +from sympy.core.numbers import (Float, I, Rational) +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices import eye, Matrix +from sympy.core.singleton import S +from sympy.testing.pytest import raises, XFAIL +from sympy.matrices.exceptions import NonSquareMatrixError, MatrixError +from sympy.matrices.expressions.fourier import DFT +from sympy.simplify.simplify import simplify +from sympy.matrices.immutable import ImmutableMatrix +from sympy.testing.pytest import slow +from sympy.testing.matrices import allclose + + +def test_eigen(): + R = Rational + M = Matrix.eye(3) + assert M.eigenvals(multiple=False) == {S.One: 3} + assert M.eigenvals(multiple=True) == [1, 1, 1] + + assert M.eigenvects() == ( + [(1, 3, [Matrix([1, 0, 0]), + Matrix([0, 1, 0]), + Matrix([0, 0, 1])])]) + + assert M.left_eigenvects() == ( + [(1, 3, [Matrix([[1, 0, 0]]), + Matrix([[0, 1, 0]]), + Matrix([[0, 0, 1]])])]) + + M = Matrix([[0, 1, 1], + [1, 0, 0], + [1, 1, 1]]) + + assert M.eigenvals() == {2*S.One: 1, -S.One: 1, S.Zero: 1} + + assert M.eigenvects() == ( + [ + (-1, 1, [Matrix([-1, 1, 0])]), + ( 0, 1, [Matrix([0, -1, 1])]), + ( 2, 1, [Matrix([R(2, 3), R(1, 3), 1])]) + ]) + + assert M.left_eigenvects() == ( + [ + (-1, 1, [Matrix([[-2, 1, 1]])]), + (0, 1, [Matrix([[-1, -1, 1]])]), + (2, 1, [Matrix([[1, 1, 1]])]) + ]) + + a = Symbol('a') + M = Matrix([[a, 0], + [0, 1]]) + + assert M.eigenvals() == {a: 1, S.One: 1} + + M = Matrix([[1, -1], + [1, 3]]) + assert M.eigenvects() == ([(2, 2, [Matrix(2, 1, [-1, 1])])]) + assert M.left_eigenvects() == ([(2, 2, [Matrix([[1, 1]])])]) + + M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + a = R(15, 2) + b = 3*33**R(1, 2) + c = R(13, 2) + d = (R(33, 8) + 3*b/8) + e = (R(33, 8) - 3*b/8) + + def NS(e, n): + return str(N(e, n)) + r = [ + (a - b/2, 1, [Matrix([(12 + 24/(c - b/2))/((c - b/2)*e) + 3/(c - b/2), + (6 + 12/(c - b/2))/e, 1])]), + ( 0, 1, [Matrix([1, -2, 1])]), + (a + b/2, 1, [Matrix([(12 + 24/(c + b/2))/((c + b/2)*d) + 3/(c + b/2), + (6 + 12/(c + b/2))/d, 1])]), + ] + r1 = [(NS(r[i][0], 2), NS(r[i][1], 2), + [NS(j, 2) for j in r[i][2][0]]) for i in range(len(r))] + r = M.eigenvects() + r2 = [(NS(r[i][0], 2), NS(r[i][1], 2), + [NS(j, 2) for j in r[i][2][0]]) for i in range(len(r))] + assert sorted(r1) == sorted(r2) + + eps = Symbol('eps', real=True) + + M = Matrix([[abs(eps), I*eps ], + [-I*eps, abs(eps) ]]) + + assert M.eigenvects() == ( + [ + ( 0, 1, [Matrix([[-I*eps/abs(eps)], [1]])]), + ( 2*abs(eps), 1, [ Matrix([[I*eps/abs(eps)], [1]]) ] ), + ]) + + assert M.left_eigenvects() == ( + [ + (0, 1, [Matrix([[I*eps/Abs(eps), 1]])]), + (2*Abs(eps), 1, [Matrix([[-I*eps/Abs(eps), 1]])]) + ]) + + M = Matrix(3, 3, [1, 2, 0, 0, 3, 0, 2, -4, 2]) + M._eigenvects = M.eigenvects(simplify=False) + assert max(i.q for i in M._eigenvects[0][2][0]) > 1 + M._eigenvects = M.eigenvects(simplify=True) + assert max(i.q for i in M._eigenvects[0][2][0]) == 1 + + M = Matrix([[Rational(1, 4), 1], [1, 1]]) + assert M.eigenvects() == [ + (Rational(5, 8) - sqrt(73)/8, 1, [Matrix([[-sqrt(73)/8 - Rational(3, 8)], [1]])]), + (Rational(5, 8) + sqrt(73)/8, 1, [Matrix([[Rational(-3, 8) + sqrt(73)/8], [1]])])] + + # issue 10719 + assert Matrix([]).eigenvals() == {} + assert Matrix([]).eigenvals(multiple=True) == [] + assert Matrix([]).eigenvects() == [] + + # issue 15119 + raises(NonSquareMatrixError, + lambda: Matrix([[1, 2], [0, 4], [0, 0]]).eigenvals()) + raises(NonSquareMatrixError, + lambda: Matrix([[1, 0], [3, 4], [5, 6]]).eigenvals()) + raises(NonSquareMatrixError, + lambda: Matrix([[1, 2, 3], [0, 5, 6]]).eigenvals()) + raises(NonSquareMatrixError, + lambda: Matrix([[1, 0, 0], [4, 5, 0]]).eigenvals()) + raises(NonSquareMatrixError, + lambda: Matrix([[1, 2, 3], [0, 5, 6]]).eigenvals( + error_when_incomplete = False)) + raises(NonSquareMatrixError, + lambda: Matrix([[1, 0, 0], [4, 5, 0]]).eigenvals( + error_when_incomplete = False)) + + m = Matrix([[1, 2], [3, 4]]) + assert isinstance(m.eigenvals(simplify=True, multiple=False), dict) + assert isinstance(m.eigenvals(simplify=True, multiple=True), list) + assert isinstance(m.eigenvals(simplify=lambda x: x, multiple=False), dict) + assert isinstance(m.eigenvals(simplify=lambda x: x, multiple=True), list) + + +def test_float_eigenvals(): + m = Matrix([[1, .6, .6], [.6, .9, .9], [.9, .6, .6]]) + evals = [ + Rational(5, 4) - sqrt(385)/20, + sqrt(385)/20 + Rational(5, 4), + S.Zero] + + n_evals = m.eigenvals(rational=True, multiple=True) + n_evals = sorted(n_evals) + s_evals = [x.evalf() for x in evals] + s_evals = sorted(s_evals) + + for x, y in zip(n_evals, s_evals): + assert abs(x-y) < 10**-9 + + +@XFAIL +def test_eigen_vects(): + m = Matrix(2, 2, [1, 0, 0, I]) + raises(NotImplementedError, lambda: m.is_diagonalizable(True)) + # !!! bug because of eigenvects() or roots(x**2 + (-1 - I)*x + I, x) + # see issue 5292 + assert not m.is_diagonalizable(True) + raises(MatrixError, lambda: m.diagonalize(True)) + (P, D) = m.diagonalize(True) + +def test_issue_8240(): + # Eigenvalues of large triangular matrices + x, y = symbols('x y') + n = 200 + + diagonal_variables = [Symbol('x%s' % i) for i in range(n)] + M = [[0 for i in range(n)] for j in range(n)] + for i in range(n): + M[i][i] = diagonal_variables[i] + M = Matrix(M) + + eigenvals = M.eigenvals() + assert len(eigenvals) == n + for i in range(n): + assert eigenvals[diagonal_variables[i]] == 1 + + eigenvals = M.eigenvals(multiple=True) + assert set(eigenvals) == set(diagonal_variables) + + # with multiplicity + M = Matrix([[x, 0, 0], [1, y, 0], [2, 3, x]]) + eigenvals = M.eigenvals() + assert eigenvals == {x: 2, y: 1} + + eigenvals = M.eigenvals(multiple=True) + assert len(eigenvals) == 3 + assert eigenvals.count(x) == 2 + assert eigenvals.count(y) == 1 + + +def test_eigenvals(): + M = Matrix([[0, 1, 1], + [1, 0, 0], + [1, 1, 1]]) + assert M.eigenvals() == {2*S.One: 1, -S.One: 1, S.Zero: 1} + + m = Matrix([ + [3, 0, 0, 0, -3], + [0, -3, -3, 0, 3], + [0, 3, 0, 3, 0], + [0, 0, 3, 0, 3], + [3, 0, 0, 3, 0]]) + + # XXX Used dry-run test because arbitrary symbol that appears in + # CRootOf may not be unique. + assert m.eigenvals() + + +def test_eigenvects(): + M = Matrix([[0, 1, 1], + [1, 0, 0], + [1, 1, 1]]) + vecs = M.eigenvects() + for val, mult, vec_list in vecs: + assert len(vec_list) == 1 + assert M*vec_list[0] == val*vec_list[0] + + +def test_left_eigenvects(): + M = Matrix([[0, 1, 1], + [1, 0, 0], + [1, 1, 1]]) + vecs = M.left_eigenvects() + for val, mult, vec_list in vecs: + assert len(vec_list) == 1 + assert vec_list[0]*M == val*vec_list[0] + + +@slow +def test_bidiagonalize(): + M = Matrix([[1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + assert M.bidiagonalize() == M + assert M.bidiagonalize(upper=False) == M + assert M.bidiagonalize() == M + assert M.bidiagonal_decomposition() == (M, M, M) + assert M.bidiagonal_decomposition(upper=False) == (M, M, M) + assert M.bidiagonalize() == M + + import random + #Real Tests + for real_test in range(2): + test_values = [] + row = 2 + col = 2 + for _ in range(row * col): + value = random.randint(-1000000000, 1000000000) + test_values = test_values + [value] + # L -> Lower Bidiagonalization + # M -> Mutable Matrix + # N -> Immutable Matrix + # 0 -> Bidiagonalized form + # 1,2,3 -> Bidiagonal_decomposition matrices + # 4 -> Product of 1 2 3 + M = Matrix(row, col, test_values) + N = ImmutableMatrix(M) + + N1, N2, N3 = N.bidiagonal_decomposition() + M1, M2, M3 = M.bidiagonal_decomposition() + M0 = M.bidiagonalize() + N0 = N.bidiagonalize() + + N4 = N1 * N2 * N3 + M4 = M1 * M2 * M3 + + N2.simplify() + N4.simplify() + N0.simplify() + + M0.simplify() + M2.simplify() + M4.simplify() + + LM0 = M.bidiagonalize(upper=False) + LM1, LM2, LM3 = M.bidiagonal_decomposition(upper=False) + LN0 = N.bidiagonalize(upper=False) + LN1, LN2, LN3 = N.bidiagonal_decomposition(upper=False) + + LN4 = LN1 * LN2 * LN3 + LM4 = LM1 * LM2 * LM3 + + LN2.simplify() + LN4.simplify() + LN0.simplify() + + LM0.simplify() + LM2.simplify() + LM4.simplify() + + assert M == M4 + assert M2 == M0 + assert N == N4 + assert N2 == N0 + assert M == LM4 + assert LM2 == LM0 + assert N == LN4 + assert LN2 == LN0 + + #Complex Tests + for complex_test in range(2): + test_values = [] + size = 2 + for _ in range(size * size): + real = random.randint(-1000000000, 1000000000) + comp = random.randint(-1000000000, 1000000000) + value = real + comp * I + test_values = test_values + [value] + M = Matrix(size, size, test_values) + N = ImmutableMatrix(M) + # L -> Lower Bidiagonalization + # M -> Mutable Matrix + # N -> Immutable Matrix + # 0 -> Bidiagonalized form + # 1,2,3 -> Bidiagonal_decomposition matrices + # 4 -> Product of 1 2 3 + N1, N2, N3 = N.bidiagonal_decomposition() + M1, M2, M3 = M.bidiagonal_decomposition() + M0 = M.bidiagonalize() + N0 = N.bidiagonalize() + + N4 = N1 * N2 * N3 + M4 = M1 * M2 * M3 + + N2.simplify() + N4.simplify() + N0.simplify() + + M0.simplify() + M2.simplify() + M4.simplify() + + LM0 = M.bidiagonalize(upper=False) + LM1, LM2, LM3 = M.bidiagonal_decomposition(upper=False) + LN0 = N.bidiagonalize(upper=False) + LN1, LN2, LN3 = N.bidiagonal_decomposition(upper=False) + + LN4 = LN1 * LN2 * LN3 + LM4 = LM1 * LM2 * LM3 + + LN2.simplify() + LN4.simplify() + LN0.simplify() + + LM0.simplify() + LM2.simplify() + LM4.simplify() + + assert M == M4 + assert M2 == M0 + assert N == N4 + assert N2 == N0 + assert M == LM4 + assert LM2 == LM0 + assert N == LN4 + assert LN2 == LN0 + + M = Matrix(18, 8, range(1, 145)) + M = M.applyfunc(lambda i: Float(i)) + assert M.bidiagonal_decomposition()[1] == M.bidiagonalize() + assert M.bidiagonal_decomposition(upper=False)[1] == M.bidiagonalize(upper=False) + a, b, c = M.bidiagonal_decomposition() + diff = a * b * c - M + assert abs(max(diff)) < 10**-12 + + +def test_diagonalize(): + m = Matrix(2, 2, [0, -1, 1, 0]) + raises(MatrixError, lambda: m.diagonalize(reals_only=True)) + P, D = m.diagonalize() + assert D.is_diagonal() + assert D == Matrix([ + [-I, 0], + [ 0, I]]) + + # make sure we use floats out if floats are passed in + m = Matrix(2, 2, [0, .5, .5, 0]) + P, D = m.diagonalize() + assert all(isinstance(e, Float) for e in D.values()) + assert all(isinstance(e, Float) for e in P.values()) + + _, D2 = m.diagonalize(reals_only=True) + assert D == D2 + + m = Matrix( + [[0, 1, 0, 0], [1, 0, 0, 0.002], [0.002, 0, 0, 1], [0, 0, 1, 0]]) + P, D = m.diagonalize() + assert allclose(P*D, m*P) + + +def test_is_diagonalizable(): + a, b, c = symbols('a b c') + m = Matrix(2, 2, [a, c, c, b]) + assert m.is_symmetric() + assert m.is_diagonalizable() + assert not Matrix(2, 2, [1, 1, 0, 1]).is_diagonalizable() + + m = Matrix(2, 2, [0, -1, 1, 0]) + assert m.is_diagonalizable() + assert not m.is_diagonalizable(reals_only=True) + + +def test_jordan_form(): + m = Matrix(3, 2, [-3, 1, -3, 20, 3, 10]) + raises(NonSquareMatrixError, lambda: m.jordan_form()) + + # the next two tests test the cases where the old + # algorithm failed due to the fact that the block structure can + # *NOT* be determined from algebraic and geometric multiplicity alone + # This can be seen most easily when one lets compute the J.c.f. of a matrix that + # is in J.c.f already. + m = Matrix(4, 4, [2, 1, 0, 0, + 0, 2, 1, 0, + 0, 0, 2, 0, + 0, 0, 0, 2 + ]) + P, J = m.jordan_form() + assert m == J + + m = Matrix(4, 4, [2, 1, 0, 0, + 0, 2, 0, 0, + 0, 0, 2, 1, + 0, 0, 0, 2 + ]) + P, J = m.jordan_form() + assert m == J + + A = Matrix([[ 2, 4, 1, 0], + [-4, 2, 0, 1], + [ 0, 0, 2, 4], + [ 0, 0, -4, 2]]) + P, J = A.jordan_form() + assert simplify(P*J*P.inv()) == A + + assert Matrix(1, 1, [1]).jordan_form() == (Matrix([1]), Matrix([1])) + assert Matrix(1, 1, [1]).jordan_form(calc_transform=False) == Matrix([1]) + + # If we have eigenvalues in CRootOf form, raise errors + m = Matrix([[3, 0, 0, 0, -3], [0, -3, -3, 0, 3], [0, 3, 0, 3, 0], [0, 0, 3, 0, 3], [3, 0, 0, 3, 0]]) + raises(MatrixError, lambda: m.jordan_form()) + + # make sure that if the input has floats, the output does too + m = Matrix([ + [ 0.6875, 0.125 + 0.1875*sqrt(3)], + [0.125 + 0.1875*sqrt(3), 0.3125]]) + P, J = m.jordan_form() + assert all(isinstance(x, Float) or x == 0 for x in P) + assert all(isinstance(x, Float) or x == 0 for x in J) + + +def test_singular_values(): + x = Symbol('x', real=True) + + A = Matrix([[0, 1*I], [2, 0]]) + # if singular values can be sorted, they should be in decreasing order + assert A.singular_values() == [2, 1] + + A = eye(3) + A[1, 1] = x + A[2, 2] = 5 + vals = A.singular_values() + # since Abs(x) cannot be sorted, test set equality + assert set(vals) == {5, 1, Abs(x)} + + A = Matrix([[sin(x), cos(x)], [-cos(x), sin(x)]]) + vals = [sv.trigsimp() for sv in A.singular_values()] + assert vals == [S.One, S.One] + + A = Matrix([ + [2, 4], + [1, 3], + [0, 0], + [0, 0] + ]) + assert A.singular_values() == \ + [sqrt(sqrt(221) + 15), sqrt(15 - sqrt(221))] + assert A.T.singular_values() == \ + [sqrt(sqrt(221) + 15), sqrt(15 - sqrt(221)), 0, 0] + +def test___eq__(): + assert (Matrix( + [[0, 1, 1], + [1, 0, 0], + [1, 1, 1]]) == {}) is False + + +def test_definite(): + # Examples from Gilbert Strang, "Introduction to Linear Algebra" + # Positive definite matrices + m = Matrix([[2, -1, 0], [-1, 2, -1], [0, -1, 2]]) + assert m.is_positive_definite == True + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + m = Matrix([[5, 4], [4, 5]]) + assert m.is_positive_definite == True + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + # Positive semidefinite matrices + m = Matrix([[2, -1, -1], [-1, 2, -1], [-1, -1, 2]]) + assert m.is_positive_definite == False + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + m = Matrix([[1, 2], [2, 4]]) + assert m.is_positive_definite == False + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + # Examples from Mathematica documentation + # Non-hermitian positive definite matrices + m = Matrix([[2, 3], [4, 8]]) + assert m.is_positive_definite == True + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + # Hermetian matrices + m = Matrix([[1, 2*I], [-I, 4]]) + assert m.is_positive_definite == True + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + # Symbolic matrices examples + a = Symbol('a', positive=True) + b = Symbol('b', negative=True) + m = Matrix([[a, 0, 0], [0, a, 0], [0, 0, a]]) + assert m.is_positive_definite == True + assert m.is_positive_semidefinite == True + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == False + + m = Matrix([[b, 0, 0], [0, b, 0], [0, 0, b]]) + assert m.is_positive_definite == False + assert m.is_positive_semidefinite == False + assert m.is_negative_definite == True + assert m.is_negative_semidefinite == True + assert m.is_indefinite == False + + m = Matrix([[a, 0], [0, b]]) + assert m.is_positive_definite == False + assert m.is_positive_semidefinite == False + assert m.is_negative_definite == False + assert m.is_negative_semidefinite == False + assert m.is_indefinite == True + + m = Matrix([ + [0.0228202735623867, 0.00518748979085398, + -0.0743036351048907, -0.00709135324903921], + [0.00518748979085398, 0.0349045359786350, + 0.0830317991056637, 0.00233147902806909], + [-0.0743036351048907, 0.0830317991056637, + 1.15859676366277, 0.340359081555988], + [-0.00709135324903921, 0.00233147902806909, + 0.340359081555988, 0.928147644848199] + ]) + assert m.is_positive_definite == True + assert m.is_positive_semidefinite == True + assert m.is_indefinite == False + + # test for issue 19547: https://github.com/sympy/sympy/issues/19547 + m = Matrix([ + [0, 0, 0], + [0, 1, 2], + [0, 2, 1] + ]) + assert not m.is_positive_definite + assert not m.is_positive_semidefinite + + +def test_positive_semidefinite_cholesky(): + from sympy.matrices.eigen import _is_positive_semidefinite_cholesky + + m = Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + assert _is_positive_semidefinite_cholesky(m) == True + m = Matrix([[0, 0, 0], [0, 5, -10*I], [0, 10*I, 5]]) + assert _is_positive_semidefinite_cholesky(m) == False + m = Matrix([[1, 0, 0], [0, 0, 0], [0, 0, -1]]) + assert _is_positive_semidefinite_cholesky(m) == False + m = Matrix([[0, 1], [1, 0]]) + assert _is_positive_semidefinite_cholesky(m) == False + + # https://www.value-at-risk.net/cholesky-factorization/ + m = Matrix([[4, -2, -6], [-2, 10, 9], [-6, 9, 14]]) + assert _is_positive_semidefinite_cholesky(m) == True + m = Matrix([[9, -3, 3], [-3, 2, 1], [3, 1, 6]]) + assert _is_positive_semidefinite_cholesky(m) == True + m = Matrix([[4, -2, 2], [-2, 1, -1], [2, -1, 5]]) + assert _is_positive_semidefinite_cholesky(m) == True + m = Matrix([[1, 2, -1], [2, 5, 1], [-1, 1, 9]]) + assert _is_positive_semidefinite_cholesky(m) == False + + +def test_issue_20582(): + A = Matrix([ + [5, -5, -3, 2, -7], + [-2, -5, 0, 2, 1], + [-2, -7, -5, -2, -6], + [7, 10, 3, 9, -2], + [4, -10, 3, -8, -4] + ]) + # XXX Used dry-run test because arbitrary symbol that appears in + # CRootOf may not be unique. + assert A.eigenvects() + +def test_issue_19210(): + t = Symbol('t') + H = Matrix([[3, 0, 0, 0], [0, 1 , 2, 0], [0, 2, 2, 0], [0, 0, 0, 4]]) + A = (-I * H * t).jordan_form() + assert A == (Matrix([ + [0, 1, 0, 0], + [0, 0, -4/(-1 + sqrt(17)), 4/(1 + sqrt(17))], + [0, 0, 1, 1], + [1, 0, 0, 0]]), Matrix([ + [-4*I*t, 0, 0, 0], + [ 0, -3*I*t, 0, 0], + [ 0, 0, t*(-3*I/2 + sqrt(17)*I/2), 0], + [ 0, 0, 0, t*(-sqrt(17)*I/2 - 3*I/2)]])) + + +def test_issue_20275(): + # XXX We use complex expansions because complex exponentials are not + # recognized by polys.domains + A = DFT(3).as_explicit().expand(complex=True) + eigenvects = A.eigenvects() + assert eigenvects[0] == ( + -1, 1, + [Matrix([[1 - sqrt(3)], [1], [1]])] + ) + assert eigenvects[1] == ( + 1, 1, + [Matrix([[1 + sqrt(3)], [1], [1]])] + ) + assert eigenvects[2] == ( + -I, 1, + [Matrix([[0], [-1], [1]])] + ) + + A = DFT(4).as_explicit().expand(complex=True) + eigenvects = A.eigenvects() + assert eigenvects[0] == ( + -1, 1, + [Matrix([[-1], [1], [1], [1]])] + ) + assert eigenvects[1] == ( + 1, 2, + [Matrix([[1], [0], [1], [0]]), Matrix([[2], [1], [0], [1]])] + ) + assert eigenvects[2] == ( + -I, 1, + [Matrix([[0], [-1], [0], [1]])] + ) + + # XXX We skip test for some parts of eigenvectors which are very + # complicated and fragile under expression tree changes + A = DFT(5).as_explicit().expand(complex=True) + eigenvects = A.eigenvects() + assert eigenvects[0] == ( + -1, 1, + [Matrix([[1 - sqrt(5)], [1], [1], [1], [1]])] + ) + assert eigenvects[1] == ( + 1, 2, + [Matrix([[S(1)/2 + sqrt(5)/2], [0], [1], [1], [0]]), + Matrix([[S(1)/2 + sqrt(5)/2], [1], [0], [0], [1]])] + ) + + +def test_issue_20752(): + b = symbols('b', nonzero=True) + m = Matrix([[0, 0, 0], [0, b, 0], [0, 0, b]]) + assert m.is_positive_semidefinite is None + + +def test_issue_25282(): + dd = sd = [0] * 11 + [1] + ds = [2, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0] + ss = ds.copy() + ss[8] = 2 + + def rotate(x, i): + return x[i:] + x[:i] + + mat = [] + for i in range(12): + mat.append(rotate(ss, i) + rotate(sd, i)) + for i in range(12): + mat.append(rotate(ds, i) + rotate(dd, i)) + + assert sum(Matrix(mat).eigenvals().values()) == 24 diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_graph.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf3c819a9477387f53560a034d7949fd76a654f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_graph.py @@ -0,0 +1,108 @@ +from sympy.combinatorics import Permutation +from sympy.core.symbol import symbols +from sympy.matrices import Matrix +from sympy.matrices.expressions import ( + PermutationMatrix, BlockDiagMatrix, BlockMatrix) + + +def test_connected_components(): + a, b, c, d, e, f, g, h, i, j, k, l, m = symbols('a:m') + + M = Matrix([ + [a, 0, 0, 0, b, 0, 0, 0, 0, 0, c, 0, 0], + [0, d, 0, 0, 0, e, 0, 0, 0, 0, 0, f, 0], + [0, 0, g, 0, 0, 0, h, 0, 0, 0, 0, 0, i], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [m, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, m, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, m, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [j, 0, 0, 0, k, 0, 0, 1, 0, 0, l, 0, 0], + [0, j, 0, 0, 0, k, 0, 0, 1, 0, 0, l, 0], + [0, 0, j, 0, 0, 0, k, 0, 0, 1, 0, 0, l], + [0, 0, 0, 0, d, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, d, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, d, 0, 0, 0, 0, 0, 1]]) + cc = M.connected_components() + assert cc == [[0, 4, 7, 10], [1, 5, 8, 11], [2, 6, 9, 12], [3]] + + P, B = M.connected_components_decomposition() + p = Permutation([0, 4, 7, 10, 1, 5, 8, 11, 2, 6, 9, 12, 3]) + assert P == PermutationMatrix(p) + + B0 = Matrix([ + [a, b, 0, c], + [m, 1, 0, 0], + [j, k, 1, l], + [0, d, 0, 1]]) + B1 = Matrix([ + [d, e, 0, f], + [m, 1, 0, 0], + [j, k, 1, l], + [0, d, 0, 1]]) + B2 = Matrix([ + [g, h, 0, i], + [m, 1, 0, 0], + [j, k, 1, l], + [0, d, 0, 1]]) + B3 = Matrix([[1]]) + assert B == BlockDiagMatrix(B0, B1, B2, B3) + + +def test_strongly_connected_components(): + M = Matrix([ + [11, 14, 10, 0, 15, 0], + [0, 44, 0, 0, 45, 0], + [1, 4, 0, 0, 5, 0], + [0, 0, 0, 22, 0, 23], + [0, 54, 0, 0, 55, 0], + [0, 0, 0, 32, 0, 33]]) + scc = M.strongly_connected_components() + assert scc == [[1, 4], [0, 2], [3, 5]] + + P, B = M.strongly_connected_components_decomposition() + p = Permutation([1, 4, 0, 2, 3, 5]) + assert P == PermutationMatrix(p) + assert B == BlockMatrix([ + [ + Matrix([[44, 45], [54, 55]]), + Matrix.zeros(2, 2), + Matrix.zeros(2, 2) + ], + [ + Matrix([[14, 15], [4, 5]]), + Matrix([[11, 10], [1, 0]]), + Matrix.zeros(2, 2) + ], + [ + Matrix.zeros(2, 2), + Matrix.zeros(2, 2), + Matrix([[22, 23], [32, 33]]) + ] + ]) + P = P.as_explicit() + B = B.as_explicit() + assert P.T * B * P == M + + P, B = M.strongly_connected_components_decomposition(lower=False) + p = Permutation([3, 5, 0, 2, 1, 4]) + assert P == PermutationMatrix(p) + assert B == BlockMatrix([ + [ + Matrix([[22, 23], [32, 33]]), + Matrix.zeros(2, 2), + Matrix.zeros(2, 2) + ], + [ + Matrix.zeros(2, 2), + Matrix([[11, 10], [1, 0]]), + Matrix([[14, 15], [4, 5]]) + ], + [ + Matrix.zeros(2, 2), + Matrix.zeros(2, 2), + Matrix([[44, 45], [54, 55]]) + ] + ]) + P = P.as_explicit() + B = B.as_explicit() + assert P.T * B * P == M diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_immutable.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_immutable.py new file mode 100644 index 0000000000000000000000000000000000000000..2b83c1f9fae7f83be9d5f7dd4b484781dc128faf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_immutable.py @@ -0,0 +1,136 @@ +from itertools import product + +from sympy.core.relational import (Equality, Unequality) +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.integrals.integrals import integrate +from sympy.matrices.dense import (Matrix, eye, zeros) +from sympy.matrices.immutable import ImmutableMatrix +from sympy.matrices import SparseMatrix +from sympy.matrices.immutable import \ + ImmutableDenseMatrix, ImmutableSparseMatrix +from sympy.abc import x, y +from sympy.testing.pytest import raises + +IM = ImmutableDenseMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) +ISM = ImmutableSparseMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) +ieye = ImmutableDenseMatrix(eye(3)) + + +def test_creation(): + assert IM.shape == ISM.shape == (3, 3) + assert IM[1, 2] == ISM[1, 2] == 6 + assert IM[2, 2] == ISM[2, 2] == 9 + + +def test_immutability(): + with raises(TypeError): + IM[2, 2] = 5 + with raises(TypeError): + ISM[2, 2] = 5 + + +def test_slicing(): + assert IM[1, :] == ImmutableDenseMatrix([[4, 5, 6]]) + assert IM[:2, :2] == ImmutableDenseMatrix([[1, 2], [4, 5]]) + assert ISM[1, :] == ImmutableSparseMatrix([[4, 5, 6]]) + assert ISM[:2, :2] == ImmutableSparseMatrix([[1, 2], [4, 5]]) + + +def test_subs(): + A = ImmutableMatrix([[1, 2], [3, 4]]) + B = ImmutableMatrix([[1, 2], [x, 4]]) + C = ImmutableMatrix([[-x, x*y], [-(x + y), y**2]]) + assert B.subs(x, 3) == A + assert (x*B).subs(x, 3) == 3*A + assert (x*eye(2) + B).subs(x, 3) == 3*eye(2) + A + assert C.subs([[x, -1], [y, -2]]) == A + assert C.subs([(x, -1), (y, -2)]) == A + assert C.subs({x: -1, y: -2}) == A + assert C.subs({x: y - 1, y: x - 1}, simultaneous=True) == \ + ImmutableMatrix([[1 - y, (x - 1)*(y - 1)], [2 - x - y, (x - 1)**2]]) + + +def test_as_immutable(): + data = [[1, 2], [3, 4]] + X = Matrix(data) + assert sympify(X) == X.as_immutable() == ImmutableMatrix(data) + + data = {(0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4} + X = SparseMatrix(2, 2, data) + assert sympify(X) == X.as_immutable() == ImmutableSparseMatrix(2, 2, data) + + +def test_function_return_types(): + # Lets ensure that decompositions of immutable matrices remain immutable + # I.e. do MatrixBase methods return the correct class? + X = ImmutableMatrix([[1, 2], [3, 4]]) + Y = ImmutableMatrix([[1], [0]]) + q, r = X.QRdecomposition() + assert (type(q), type(r)) == (ImmutableMatrix, ImmutableMatrix) + + assert type(X.LUsolve(Y)) == ImmutableMatrix + assert type(X.QRsolve(Y)) == ImmutableMatrix + + X = ImmutableMatrix([[5, 2], [2, 7]]) + assert X.T == X + assert X.is_symmetric + assert type(X.cholesky()) == ImmutableMatrix + L, D = X.LDLdecomposition() + assert (type(L), type(D)) == (ImmutableMatrix, ImmutableMatrix) + + X = ImmutableMatrix([[1, 2], [2, 1]]) + assert X.is_diagonalizable() + assert X.det() == -3 + assert X.norm(2) == 3 + + assert type(X.eigenvects()[0][2][0]) == ImmutableMatrix + + assert type(zeros(3, 3).as_immutable().nullspace()[0]) == ImmutableMatrix + + X = ImmutableMatrix([[1, 0], [2, 1]]) + assert type(X.lower_triangular_solve(Y)) == ImmutableMatrix + assert type(X.T.upper_triangular_solve(Y)) == ImmutableMatrix + + assert type(X.minor_submatrix(0, 0)) == ImmutableMatrix + +# issue 6279 +# https://github.com/sympy/sympy/issues/6279 +# Test that Immutable _op_ Immutable => Immutable and not MatExpr + + +def test_immutable_evaluation(): + X = ImmutableMatrix(eye(3)) + A = ImmutableMatrix(3, 3, range(9)) + assert isinstance(X + A, ImmutableMatrix) + assert isinstance(X * A, ImmutableMatrix) + assert isinstance(X * 2, ImmutableMatrix) + assert isinstance(2 * X, ImmutableMatrix) + assert isinstance(A**2, ImmutableMatrix) + + +def test_deterimant(): + assert ImmutableMatrix(4, 4, lambda i, j: i + j).det() == 0 + + +def test_Equality(): + assert Equality(IM, IM) is S.true + assert Unequality(IM, IM) is S.false + assert Equality(IM, IM.subs(1, 2)) is S.false + assert Unequality(IM, IM.subs(1, 2)) is S.true + assert Equality(IM, 2) is S.false + assert Unequality(IM, 2) is S.true + M = ImmutableMatrix([x, y]) + assert Equality(M, IM) is S.false + assert Unequality(M, IM) is S.true + assert Equality(M, M.subs(x, 2)).subs(x, 2) is S.true + assert Unequality(M, M.subs(x, 2)).subs(x, 2) is S.false + assert Equality(M, M.subs(x, 2)).subs(x, 3) is S.false + assert Unequality(M, M.subs(x, 2)).subs(x, 3) is S.true + + +def test_integrate(): + intIM = integrate(IM, x) + assert intIM.shape == IM.shape + assert all(intIM[i, j] == (1 + j + 3*i)*x for i, j in + product(range(3), range(3))) diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_interactions.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_interactions.py new file mode 100644 index 0000000000000000000000000000000000000000..f4fc3268368e8dd632fc0df187d57ea5e845120c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_interactions.py @@ -0,0 +1,77 @@ +""" +We have a few different kind of Matrices +Matrix, ImmutableMatrix, MatrixExpr + +Here we test the extent to which they cooperate +""" + +from sympy.core.symbol import symbols +from sympy.matrices import (Matrix, MatrixSymbol, eye, Identity, + ImmutableMatrix) +from sympy.matrices.expressions import MatrixExpr, MatAdd +from sympy.matrices.matrixbase import classof +from sympy.testing.pytest import raises + +SM = MatrixSymbol('X', 3, 3) +SV = MatrixSymbol('v', 3, 1) +MM = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) +IM = ImmutableMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) +meye = eye(3) +imeye = ImmutableMatrix(eye(3)) +ideye = Identity(3) +a, b, c = symbols('a,b,c') + + +def test_IM_MM(): + assert isinstance(MM + IM, ImmutableMatrix) + assert isinstance(IM + MM, ImmutableMatrix) + assert isinstance(2*IM + MM, ImmutableMatrix) + assert MM.equals(IM) + + +def test_ME_MM(): + assert isinstance(Identity(3) + MM, MatrixExpr) + assert isinstance(SM + MM, MatAdd) + assert isinstance(MM + SM, MatAdd) + assert (Identity(3) + MM)[1, 1] == 6 + + +def test_equality(): + a, b, c = Identity(3), eye(3), ImmutableMatrix(eye(3)) + for x in [a, b, c]: + for y in [a, b, c]: + assert x.equals(y) + + +def test_matrix_symbol_MM(): + X = MatrixSymbol('X', 3, 3) + Y = eye(3) + X + assert Y[1, 1] == 1 + X[1, 1] + + +def test_matrix_symbol_vector_matrix_multiplication(): + A = MM * SV + B = IM * SV + assert A == B + C = (SV.T * MM.T).T + assert B == C + D = (SV.T * IM.T).T + assert C == D + + +def test_indexing_interactions(): + assert (a * IM)[1, 1] == 5*a + assert (SM + IM)[1, 1] == SM[1, 1] + IM[1, 1] + assert (SM * IM)[1, 1] == SM[1, 0]*IM[0, 1] + SM[1, 1]*IM[1, 1] + \ + SM[1, 2]*IM[2, 1] + + +def test_classof(): + A = Matrix(3, 3, range(9)) + B = ImmutableMatrix(3, 3, range(9)) + C = MatrixSymbol('C', 3, 3) + assert classof(A, A) == Matrix + assert classof(B, B) == ImmutableMatrix + assert classof(A, B) == ImmutableMatrix + assert classof(B, A) == ImmutableMatrix + raises(TypeError, lambda: classof(A, C)) diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_matrices.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..d9d97341de570e078d652dddce58fb8f5cb99e43 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_matrices.py @@ -0,0 +1,3487 @@ +# +# Code for testing deprecated matrix classes. New test code should not be added +# here. Instead, add it to test_matrixbase.py. +# +# This entire test module and the corresponding sympy/matrices/matrices.py +# module will be removed in a future release. +# +import random +import concurrent.futures +from collections.abc import Hashable + +from sympy.core.add import Add +from sympy.core.function import Function, diff, expand +from sympy.core.numbers import (E, Float, I, Integer, Rational, nan, oo, pi) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.integrals.integrals import integrate +from sympy.matrices.expressions.transpose import transpose +from sympy.physics.quantum.operator import HermitianOperator, Operator, Dagger +from sympy.polys.polytools import (Poly, PurePoly) +from sympy.polys.rootoftools import RootOf +from sympy.printing.str import sstr +from sympy.sets.sets import FiniteSet +from sympy.simplify.simplify import (signsimp, simplify) +from sympy.simplify.trigsimp import trigsimp +from sympy.matrices.exceptions import (ShapeError, MatrixError, + NonSquareMatrixError) +from sympy.matrices.matrixbase import DeferredVector +from sympy.matrices.determinant import _find_reasonable_pivot_naive +from sympy.matrices.utilities import _simplify +from sympy.matrices import ( + GramSchmidt, ImmutableMatrix, ImmutableSparseMatrix, Matrix, + SparseMatrix, casoratian, diag, eye, hessian, + matrix_multiply_elementwise, ones, randMatrix, rot_axis1, rot_axis2, + rot_axis3, wronskian, zeros, MutableDenseMatrix, ImmutableDenseMatrix, + MatrixSymbol, dotprodsimp, rot_ccw_axis1, rot_ccw_axis2, rot_ccw_axis3) +from sympy.matrices.utilities import _dotprodsimp_state +from sympy.core import Tuple, Wild +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.utilities.iterables import flatten, capture, iterable +from sympy.utilities.exceptions import ignore_warnings +from sympy.testing.pytest import (raises, XFAIL, slow, skip, skip_under_pyodide, + warns_deprecated_sympy) +from sympy.assumptions import Q +from sympy.tensor.array import Array +from sympy.tensor.array.array_derivatives import ArrayDerivative +from sympy.matrices.expressions import MatPow +from sympy.algebras import Quaternion + +from sympy import O + +from sympy.abc import a, b, c, d, x, y, z, t + + +# don't re-order this list +classes = (Matrix, SparseMatrix, ImmutableMatrix, ImmutableSparseMatrix) + + +# Test the deprecated matrixmixins +from sympy.matrices.common import _MinimalMatrix, _CastableMatrix +from sympy.matrices.matrices import MatrixSubspaces, MatrixReductions + + +with warns_deprecated_sympy(): + class SubspaceOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixSubspaces): + pass + + +with warns_deprecated_sympy(): + class ReductionsOnlyMatrix(_MinimalMatrix, _CastableMatrix, MatrixReductions): + pass + + +def eye_Reductions(n): + return ReductionsOnlyMatrix(n, n, lambda i, j: int(i == j)) + + +def zeros_Reductions(n): + return ReductionsOnlyMatrix(n, n, lambda i, j: 0) + + +def test_args(): + for n, cls in enumerate(classes): + m = cls.zeros(3, 2) + # all should give back the same type of arguments, e.g. ints for shape + assert m.shape == (3, 2) and all(type(i) is int for i in m.shape) + assert m.rows == 3 and type(m.rows) is int + assert m.cols == 2 and type(m.cols) is int + if not n % 2: + assert type(m.flat()) in (list, tuple, Tuple) + else: + assert type(m.todok()) is dict + + +def test_deprecated_mat_smat(): + for cls in Matrix, ImmutableMatrix: + m = cls.zeros(3, 2) + with warns_deprecated_sympy(): + mat = m._mat + assert mat == m.flat() + for cls in SparseMatrix, ImmutableSparseMatrix: + m = cls.zeros(3, 2) + with warns_deprecated_sympy(): + smat = m._smat + assert smat == m.todok() + + +def test_division(): + v = Matrix(1, 2, [x, y]) + assert v/z == Matrix(1, 2, [x/z, y/z]) + + +def test_sum(): + m = Matrix([[1, 2, 3], [x, y, x], [2*y, -50, z*x]]) + assert m + m == Matrix([[2, 4, 6], [2*x, 2*y, 2*x], [4*y, -100, 2*z*x]]) + n = Matrix(1, 2, [1, 2]) + raises(ShapeError, lambda: m + n) + +def test_abs(): + m = Matrix(1, 2, [-3, x]) + n = Matrix(1, 2, [3, Abs(x)]) + assert abs(m) == n + +def test_addition(): + a = Matrix(( + (1, 2), + (3, 1), + )) + + b = Matrix(( + (1, 2), + (3, 0), + )) + + assert a + b == a.add(b) == Matrix([[2, 4], [6, 1]]) + + +def test_fancy_index_matrix(): + for M in (Matrix, SparseMatrix): + a = M(3, 3, range(9)) + assert a == a[:, :] + assert a[1, :] == Matrix(1, 3, [3, 4, 5]) + assert a[:, 1] == Matrix([1, 4, 7]) + assert a[[0, 1], :] == Matrix([[0, 1, 2], [3, 4, 5]]) + assert a[[0, 1], 2] == a[[0, 1], [2]] + assert a[2, [0, 1]] == a[[2], [0, 1]] + assert a[:, [0, 1]] == Matrix([[0, 1], [3, 4], [6, 7]]) + assert a[0, 0] == 0 + assert a[0:2, :] == Matrix([[0, 1, 2], [3, 4, 5]]) + assert a[:, 0:2] == Matrix([[0, 1], [3, 4], [6, 7]]) + assert a[::2, 1] == a[[0, 2], 1] + assert a[1, ::2] == a[1, [0, 2]] + a = M(3, 3, range(9)) + assert a[[0, 2, 1, 2, 1], :] == Matrix([ + [0, 1, 2], + [6, 7, 8], + [3, 4, 5], + [6, 7, 8], + [3, 4, 5]]) + assert a[:, [0,2,1,2,1]] == Matrix([ + [0, 2, 1, 2, 1], + [3, 5, 4, 5, 4], + [6, 8, 7, 8, 7]]) + + a = SparseMatrix.zeros(3) + a[1, 2] = 2 + a[0, 1] = 3 + a[2, 0] = 4 + assert a.extract([1, 1], [2]) == Matrix([ + [2], + [2]]) + assert a.extract([1, 0], [2, 2, 2]) == Matrix([ + [2, 2, 2], + [0, 0, 0]]) + assert a.extract([1, 0, 1, 2], [2, 0, 1, 0]) == Matrix([ + [2, 0, 0, 0], + [0, 0, 3, 0], + [2, 0, 0, 0], + [0, 4, 0, 4]]) + + +def test_multiplication(): + a = Matrix(( + (1, 2), + (3, 1), + (0, 6), + )) + + b = Matrix(( + (1, 2), + (3, 0), + )) + + c = a*b + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + try: + eval('c = a @ b') + except SyntaxError: + pass + else: + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + h = matrix_multiply_elementwise(a, c) + assert h == a.multiply_elementwise(c) + assert h[0, 0] == 7 + assert h[0, 1] == 4 + assert h[1, 0] == 18 + assert h[1, 1] == 6 + assert h[2, 0] == 0 + assert h[2, 1] == 0 + raises(ShapeError, lambda: matrix_multiply_elementwise(a, b)) + + c = b * Symbol("x") + assert isinstance(c, Matrix) + assert c[0, 0] == x + assert c[0, 1] == 2*x + assert c[1, 0] == 3*x + assert c[1, 1] == 0 + + c2 = x * b + assert c == c2 + + c = 5 * b + assert isinstance(c, Matrix) + assert c[0, 0] == 5 + assert c[0, 1] == 2*5 + assert c[1, 0] == 3*5 + assert c[1, 1] == 0 + + try: + eval('c = 5 @ b') + except SyntaxError: + pass + else: + assert isinstance(c, Matrix) + assert c[0, 0] == 5 + assert c[0, 1] == 2*5 + assert c[1, 0] == 3*5 + assert c[1, 1] == 0 + + +def test_multiplication_inf_zero(): + + M = Matrix([[oo, 0], [0, oo]]) + assert M ** 2 == M + + M = Matrix([[oo, oo], [0, 0]]) + assert M ** 2 == Matrix([[nan, nan], [nan, nan]]) + + A = Matrix([ + [0, 0, 0, -S(1)/2], + [0, 1, 0, 0], + [0, 0, 1, 0], + [-S(1)/2, 0, 0, 0]]) + + B = Matrix([ + [pi*x**2, 0, pi*b*x**4/8 + pi*a*x**4/8 + O(x**5), pi*x**4/2 + pi*b**2*x**6/32 + pi*a*b*x**6/48 + pi*a**2*x**6/32 + O(x**7)], + [0, pi*x**4/4, O(x**6), O(x**8)], + [pi*b*x**4/8 + pi*a*x**4/8 + O(x**5), O(x**6), pi*b**2*x**6/32 + pi*a*b*x**6/48 + pi*a**2*x**6/32 + O(x**7), pi*b*x**6/12 + pi*a*x**6/12 + O(x**7)], + [pi*x**4/2 + pi*b**2*x**6/32 + pi*a*b*x**6/48 + pi*a**2*x**6/32 + O(x**7), O(x**8), pi*b*x**6/12 + pi*a*x**6/12 + O(x**7), pi*x**6/3 + 3*pi*b**2*x**8/64 + pi*a*b*x**8/32 + 3*pi*a**2*x**8/64 + O(x**9)]]) + + C = Matrix([ + [-pi*x**4/4 - pi*b**2*x**6/64 - pi*a*b*x**6/96 - pi*a**2*x**6/64 + O(x**7), O(x**8), -pi*b*x**6/24 - pi*a*x**6/24 + O(x**7), -pi*x**6/6 - 3*pi*b**2*x**8/128 - pi*a*b*x**8/64 - 3*pi*a**2*x**8/128 + O(x**9)], + [ 0, pi*x**4/4, O(x**6), O(x**8)], + [ pi*b*x**4/8 + pi*a*x**4/8 + O(x**5), O(x**6), pi*b**2*x**6/32 + pi*a*b*x**6/48 + pi*a**2*x**6/32 + O(x**7), pi*b*x**6/12 + pi*a*x**6/12 + O(x**7)], + [ -pi*x**2/2, 0, -pi*b*x**4/16 - pi*a*x**4/16 + O(x**5), -pi*x**4/4 - pi*b**2*x**6/64 - pi*a*b*x**6/96 - pi*a**2*x**6/64 + O(x**7)]]) + + C2 = Matrix(4, 4, lambda i, j: Add(*(A[i,k]*B[k,j] for k in range(4)))) + + assert A*B == C == C2 + + +def test_power(): + raises(NonSquareMatrixError, lambda: Matrix((1, 2))**2) + + R = Rational + A = Matrix([[2, 3], [4, 5]]) + assert (A**-3)[:] == [R(-269)/8, R(153)/8, R(51)/2, R(-29)/2] + assert (A**5)[:] == [6140, 8097, 10796, 14237] + A = Matrix([[2, 1, 3], [4, 2, 4], [6, 12, 1]]) + assert (A**3)[:] == [290, 262, 251, 448, 440, 368, 702, 954, 433] + assert A**0 == eye(3) + assert A**1 == A + assert (Matrix([[2]]) ** 100)[0, 0] == 2**100 + assert eye(2)**10000000 == eye(2) + assert Matrix([[1, 2], [3, 4]])**Integer(2) == Matrix([[7, 10], [15, 22]]) + + A = Matrix([[33, 24], [48, 57]]) + assert (A**S.Half)[:] == [5, 2, 4, 7] + A = Matrix([[0, 4], [-1, 5]]) + assert (A**S.Half)**2 == A + + assert Matrix([[1, 0], [1, 1]])**S.Half == Matrix([[1, 0], [S.Half, 1]]) + assert Matrix([[1, 0], [1, 1]])**0.5 == Matrix([[1, 0], [0.5, 1]]) + from sympy.abc import n + assert Matrix([[1, a], [0, 1]])**n == Matrix([[1, a*n], [0, 1]]) + assert Matrix([[b, a], [0, b]])**n == Matrix([[b**n, a*b**(n-1)*n], [0, b**n]]) + assert Matrix([ + [a**n, a**(n - 1)*n, (a**n*n**2 - a**n*n)/(2*a**2)], + [ 0, a**n, a**(n - 1)*n], + [ 0, 0, a**n]]) + assert Matrix([[a, 1, 0], [0, a, 0], [0, 0, b]])**n == Matrix([ + [a**n, a**(n-1)*n, 0], + [0, a**n, 0], + [0, 0, b**n]]) + + A = Matrix([[1, 0], [1, 7]]) + assert A._matrix_pow_by_jordan_blocks(S(3)) == A._eval_pow_by_recursion(3) + A = Matrix([[2]]) + assert A**10 == Matrix([[2**10]]) == A._matrix_pow_by_jordan_blocks(S(10)) == \ + A._eval_pow_by_recursion(10) + + # testing a matrix that cannot be jordan blocked issue 11766 + m = Matrix([[3, 0, 0, 0, -3], [0, -3, -3, 0, 3], [0, 3, 0, 3, 0], [0, 0, 3, 0, 3], [3, 0, 0, 3, 0]]) + raises(MatrixError, lambda: m._matrix_pow_by_jordan_blocks(S(10))) + + # test issue 11964 + raises(MatrixError, lambda: Matrix([[1, 1], [3, 3]])._matrix_pow_by_jordan_blocks(S(-10))) + A = Matrix([[0, 1, 0], [0, 0, 1], [0, 0, 0]]) # Nilpotent jordan block size 3 + assert A**10.0 == Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + raises(ValueError, lambda: A**2.1) + raises(ValueError, lambda: A**Rational(3, 2)) + A = Matrix([[8, 1], [3, 2]]) + assert A**10.0 == Matrix([[1760744107, 272388050], [817164150, 126415807]]) + A = Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) # Nilpotent jordan block size 1 + assert A**10.0 == Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + A = Matrix([[0, 1, 0], [0, 0, 1], [0, 0, 1]]) # Nilpotent jordan block size 2 + assert A**10.0 == Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + n = Symbol('n', integer=True) + assert isinstance(A**n, MatPow) + n = Symbol('n', integer=True, negative=True) + raises(ValueError, lambda: A**n) + n = Symbol('n', integer=True, nonnegative=True) + assert A**n == Matrix([ + [KroneckerDelta(0, n), KroneckerDelta(1, n), -KroneckerDelta(0, n) - KroneckerDelta(1, n) + 1], + [ 0, KroneckerDelta(0, n), 1 - KroneckerDelta(0, n)], + [ 0, 0, 1]]) + assert A**(n + 2) == Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + raises(ValueError, lambda: A**Rational(3, 2)) + A = Matrix([[0, 0, 1], [3, 0, 1], [4, 3, 1]]) + assert A**5.0 == Matrix([[168, 72, 89], [291, 144, 161], [572, 267, 329]]) + assert A**5.0 == A**5 + A = Matrix([[0, 1, 0],[-1, 0, 0],[0, 0, 0]]) + n = Symbol("n") + An = A**n + assert An.subs(n, 2).doit() == A**2 + raises(ValueError, lambda: An.subs(n, -2).doit()) + assert An * An == A**(2*n) + + # concretizing behavior for non-integer and complex powers + A = Matrix([[0,0,0],[0,0,0],[0,0,0]]) + n = Symbol('n', integer=True, positive=True) + assert A**n == A + n = Symbol('n', integer=True, nonnegative=True) + assert A**n == diag(0**n, 0**n, 0**n) + assert (A**n).subs(n, 0) == eye(3) + assert (A**n).subs(n, 1) == zeros(3) + A = Matrix ([[2,0,0],[0,2,0],[0,0,2]]) + assert A**2.1 == diag (2**2.1, 2**2.1, 2**2.1) + assert A**I == diag (2**I, 2**I, 2**I) + A = Matrix([[0, 1, 0], [0, 0, 1], [0, 0, 1]]) + raises(ValueError, lambda: A**2.1) + raises(ValueError, lambda: A**I) + A = Matrix([[S.Half, S.Half], [S.Half, S.Half]]) + assert A**S.Half == A + A = Matrix([[1, 1],[3, 3]]) + assert A**S.Half == Matrix ([[S.Half, S.Half], [3*S.Half, 3*S.Half]]) + + +def test_issue_17247_expression_blowup_1(): + M = Matrix([[1+x, 1-x], [1-x, 1+x]]) + with dotprodsimp(True): + assert M.exp().expand() == Matrix([ + [ (exp(2*x) + exp(2))/2, (-exp(2*x) + exp(2))/2], + [(-exp(2*x) + exp(2))/2, (exp(2*x) + exp(2))/2]]) + +def test_issue_17247_expression_blowup_2(): + M = Matrix([[1+x, 1-x], [1-x, 1+x]]) + with dotprodsimp(True): + P, J = M.jordan_form () + assert P*J*P.inv() + +def test_issue_17247_expression_blowup_3(): + M = Matrix([[1+x, 1-x], [1-x, 1+x]]) + with dotprodsimp(True): + assert M**100 == Matrix([ + [633825300114114700748351602688*x**100 + 633825300114114700748351602688, 633825300114114700748351602688 - 633825300114114700748351602688*x**100], + [633825300114114700748351602688 - 633825300114114700748351602688*x**100, 633825300114114700748351602688*x**100 + 633825300114114700748351602688]]) + +def test_issue_17247_expression_blowup_4(): +# This matrix takes extremely long on current master even with intermediate simplification so an abbreviated version is used. It is left here for test in case of future optimizations. +# M = Matrix(S('''[ +# [ -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128, 3/64 + 13*I/64, -23/32 - 59*I/256, 15/128 - 3*I/32, 19/256 + 551*I/1024], +# [-149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024, 119/128 + 143*I/128, -10879/2048 + 4343*I/4096, 129/256 - 549*I/512, 42533/16384 + 29103*I/8192], +# [ 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128, 3/64 + 13*I/64, -23/32 - 59*I/256], +# [ -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024, 119/128 + 143*I/128, -10879/2048 + 4343*I/4096], +# [ 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128], +# [ 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024], +# [ -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64], +# [ 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512], +# [ -4*I, 27/2 + 6*I, -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64], +# [ 1/4 + 5*I/2, -23/8 - 57*I/16, 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128], +# [ -4, 9 - 5*I, -4*I, 27/2 + 6*I, -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16], +# [ -2*I, 119/8 + 29*I/4, 1/4 + 5*I/2, -23/8 - 57*I/16, 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128]]''')) +# assert M**10 == Matrix([ +# [ 7*(-221393644768594642173548179825793834595 - 1861633166167425978847110897013541127952*I)/9671406556917033397649408, 15*(31670992489131684885307005100073928751695 + 10329090958303458811115024718207404523808*I)/77371252455336267181195264, 7*(-3710978679372178839237291049477017392703 + 1377706064483132637295566581525806894169*I)/19342813113834066795298816, (9727707023582419994616144751727760051598 - 59261571067013123836477348473611225724433*I)/9671406556917033397649408, (31896723509506857062605551443641668183707 + 54643444538699269118869436271152084599580*I)/38685626227668133590597632, (-2024044860947539028275487595741003997397402 + 130959428791783397562960461903698670485863*I)/309485009821345068724781056, 3*(26190251453797590396533756519358368860907 - 27221191754180839338002754608545400941638*I)/77371252455336267181195264, (1154643595139959842768960128434994698330461 + 3385496216250226964322872072260446072295634*I)/618970019642690137449562112, 3*(-31849347263064464698310044805285774295286 - 11877437776464148281991240541742691164309*I)/77371252455336267181195264, (4661330392283532534549306589669150228040221 - 4171259766019818631067810706563064103956871*I)/1237940039285380274899124224, (9598353794289061833850770474812760144506 + 358027153990999990968244906482319780943983*I)/309485009821345068724781056, (-9755135335127734571547571921702373498554177 - 4837981372692695195747379349593041939686540*I)/2475880078570760549798248448], +# [(-379516731607474268954110071392894274962069 - 422272153179747548473724096872271700878296*I)/77371252455336267181195264, (41324748029613152354787280677832014263339501 - 12715121258662668420833935373453570749288074*I)/1237940039285380274899124224, (-339216903907423793947110742819264306542397 + 494174755147303922029979279454787373566517*I)/77371252455336267181195264, (-18121350839962855576667529908850640619878381 - 37413012454129786092962531597292531089199003*I)/1237940039285380274899124224, (2489661087330511608618880408199633556675926 + 1137821536550153872137379935240732287260863*I)/309485009821345068724781056, (-136644109701594123227587016790354220062972119 + 110130123468183660555391413889600443583585272*I)/4951760157141521099596496896, (1488043981274920070468141664150073426459593 - 9691968079933445130866371609614474474327650*I)/1237940039285380274899124224, 27*(4636797403026872518131756991410164760195942 + 3369103221138229204457272860484005850416533*I)/4951760157141521099596496896, (-8534279107365915284081669381642269800472363 + 2241118846262661434336333368511372725482742*I)/1237940039285380274899124224, (60923350128174260992536531692058086830950875 - 263673488093551053385865699805250505661590126*I)/9903520314283042199192993792, (18520943561240714459282253753348921824172569 + 24846649186468656345966986622110971925703604*I)/4951760157141521099596496896, (-232781130692604829085973604213529649638644431 + 35981505277760667933017117949103953338570617*I)/9903520314283042199192993792], +# [ (8742968295129404279528270438201520488950 + 3061473358639249112126847237482570858327*I)/4835703278458516698824704, (-245657313712011778432792959787098074935273 + 253113767861878869678042729088355086740856*I)/38685626227668133590597632, (1947031161734702327107371192008011621193 - 19462330079296259148177542369999791122762*I)/9671406556917033397649408, (552856485625209001527688949522750288619217 + 392928441196156725372494335248099016686580*I)/77371252455336267181195264, (-44542866621905323121630214897126343414629 + 3265340021421335059323962377647649632959*I)/19342813113834066795298816, (136272594005759723105646069956434264218730 - 330975364731707309489523680957584684763587*I)/38685626227668133590597632, (27392593965554149283318732469825168894401 + 75157071243800133880129376047131061115278*I)/38685626227668133590597632, 7*(-357821652913266734749960136017214096276154 - 45509144466378076475315751988405961498243*I)/309485009821345068724781056, (104485001373574280824835174390219397141149 - 99041000529599568255829489765415726168162*I)/77371252455336267181195264, (1198066993119982409323525798509037696321291 + 4249784165667887866939369628840569844519936*I)/618970019642690137449562112, (-114985392587849953209115599084503853611014 - 52510376847189529234864487459476242883449*I)/77371252455336267181195264, (6094620517051332877965959223269600650951573 - 4683469779240530439185019982269137976201163*I)/1237940039285380274899124224], +# [ (611292255597977285752123848828590587708323 - 216821743518546668382662964473055912169502*I)/77371252455336267181195264, (-1144023204575811464652692396337616594307487 + 12295317806312398617498029126807758490062855*I)/309485009821345068724781056, (-374093027769390002505693378578475235158281 - 573533923565898290299607461660384634333639*I)/77371252455336267181195264, (47405570632186659000138546955372796986832987 - 2837476058950808941605000274055970055096534*I)/1237940039285380274899124224, (-571573207393621076306216726219753090535121 + 533381457185823100878764749236639320783831*I)/77371252455336267181195264, (-7096548151856165056213543560958582513797519 - 24035731898756040059329175131592138642195366*I)/618970019642690137449562112, (2396762128833271142000266170154694033849225 + 1448501087375679588770230529017516492953051*I)/309485009821345068724781056, (-150609293845161968447166237242456473262037053 + 92581148080922977153207018003184520294188436*I)/4951760157141521099596496896, 5*(270278244730804315149356082977618054486347 - 1997830155222496880429743815321662710091562*I)/1237940039285380274899124224, (62978424789588828258068912690172109324360330 + 44803641177219298311493356929537007630129097*I)/2475880078570760549798248448, 19*(-451431106327656743945775812536216598712236 + 114924966793632084379437683991151177407937*I)/1237940039285380274899124224, (63417747628891221594106738815256002143915995 - 261508229397507037136324178612212080871150958*I)/9903520314283042199192993792], +# [ (-2144231934021288786200752920446633703357 + 2305614436009705803670842248131563850246*I)/1208925819614629174706176, (-90720949337459896266067589013987007078153 - 221951119475096403601562347412753844534569*I)/19342813113834066795298816, (11590973613116630788176337262688659880376 + 6514520676308992726483494976339330626159*I)/4835703278458516698824704, 3*(-131776217149000326618649542018343107657237 + 79095042939612668486212006406818285287004*I)/38685626227668133590597632, (10100577916793945997239221374025741184951 - 28631383488085522003281589065994018550748*I)/9671406556917033397649408, 67*(10090295594251078955008130473573667572549 + 10449901522697161049513326446427839676762*I)/77371252455336267181195264, (-54270981296988368730689531355811033930513 - 3413683117592637309471893510944045467443*I)/19342813113834066795298816, (440372322928679910536575560069973699181278 - 736603803202303189048085196176918214409081*I)/77371252455336267181195264, (33220374714789391132887731139763250155295 + 92055083048787219934030779066298919603554*I)/38685626227668133590597632, 5*(-594638554579967244348856981610805281527116 - 82309245323128933521987392165716076704057*I)/309485009821345068724781056, (128056368815300084550013708313312073721955 - 114619107488668120303579745393765245911404*I)/77371252455336267181195264, 21*(59839959255173222962789517794121843393573 + 241507883613676387255359616163487405826334*I)/618970019642690137449562112], +# [ (-13454485022325376674626653802541391955147 + 184471402121905621396582628515905949793486*I)/19342813113834066795298816, (-6158730123400322562149780662133074862437105 - 3416173052604643794120262081623703514107476*I)/154742504910672534362390528, (770558003844914708453618983120686116100419 - 127758381209767638635199674005029818518766*I)/77371252455336267181195264, (-4693005771813492267479835161596671660631703 + 12703585094750991389845384539501921531449948*I)/309485009821345068724781056, (-295028157441149027913545676461260860036601 - 841544569970643160358138082317324743450770*I)/77371252455336267181195264, (56716442796929448856312202561538574275502893 + 7216818824772560379753073185990186711454778*I)/1237940039285380274899124224, 15*(-87061038932753366532685677510172566368387 + 61306141156647596310941396434445461895538*I)/154742504910672534362390528, (-3455315109680781412178133042301025723909347 - 24969329563196972466388460746447646686670670*I)/618970019642690137449562112, (2453418854160886481106557323699250865361849 + 1497886802326243014471854112161398141242514*I)/309485009821345068724781056, (-151343224544252091980004429001205664193082173 + 90471883264187337053549090899816228846836628*I)/4951760157141521099596496896, (1652018205533026103358164026239417416432989 - 9959733619236515024261775397109724431400162*I)/1237940039285380274899124224, 3*(40676374242956907656984876692623172736522006 + 31023357083037817469535762230872667581366205*I)/4951760157141521099596496896], +# [ (-1226990509403328460274658603410696548387 - 4131739423109992672186585941938392788458*I)/1208925819614629174706176, (162392818524418973411975140074368079662703 + 23706194236915374831230612374344230400704*I)/9671406556917033397649408, (-3935678233089814180000602553655565621193 + 2283744757287145199688061892165659502483*I)/1208925819614629174706176, (-2400210250844254483454290806930306285131 - 315571356806370996069052930302295432758205*I)/19342813113834066795298816, (13365917938215281056563183751673390817910 + 15911483133819801118348625831132324863881*I)/4835703278458516698824704, 3*(-215950551370668982657516660700301003897855 + 51684341999223632631602864028309400489378*I)/38685626227668133590597632, (20886089946811765149439844691320027184765 - 30806277083146786592790625980769214361844*I)/9671406556917033397649408, (562180634592713285745940856221105667874855 + 1031543963988260765153550559766662245114916*I)/77371252455336267181195264, (-65820625814810177122941758625652476012867 - 12429918324787060890804395323920477537595*I)/19342813113834066795298816, (319147848192012911298771180196635859221089 - 402403304933906769233365689834404519960394*I)/38685626227668133590597632, (23035615120921026080284733394359587955057 + 115351677687031786114651452775242461310624*I)/38685626227668133590597632, (-3426830634881892756966440108592579264936130 - 1022954961164128745603407283836365128598559*I)/309485009821345068724781056], +# [ (-192574788060137531023716449082856117537757 - 69222967328876859586831013062387845780692*I)/19342813113834066795298816, (2736383768828013152914815341491629299773262 - 2773252698016291897599353862072533475408743*I)/77371252455336267181195264, (-23280005281223837717773057436155921656805 + 214784953368021840006305033048142888879224*I)/19342813113834066795298816, (-3035247484028969580570400133318947903462326 - 2195168903335435855621328554626336958674325*I)/77371252455336267181195264, (984552428291526892214541708637840971548653 - 64006622534521425620714598573494988589378*I)/77371252455336267181195264, (-3070650452470333005276715136041262898509903 + 7286424705750810474140953092161794621989080*I)/154742504910672534362390528, (-147848877109756404594659513386972921139270 - 416306113044186424749331418059456047650861*I)/38685626227668133590597632, (55272118474097814260289392337160619494260781 + 7494019668394781211907115583302403519488058*I)/1237940039285380274899124224, (-581537886583682322424771088996959213068864 + 542191617758465339135308203815256798407429*I)/77371252455336267181195264, (-6422548983676355789975736799494791970390991 - 23524183982209004826464749309156698827737702*I)/618970019642690137449562112, 7*(180747195387024536886923192475064903482083 + 84352527693562434817771649853047924991804*I)/154742504910672534362390528, (-135485179036717001055310712747643466592387031 + 102346575226653028836678855697782273460527608*I)/4951760157141521099596496896], +# [ (3384238362616083147067025892852431152105 + 156724444932584900214919898954874618256*I)/604462909807314587353088, (-59558300950677430189587207338385764871866 + 114427143574375271097298201388331237478857*I)/4835703278458516698824704, (-1356835789870635633517710130971800616227 - 7023484098542340388800213478357340875410*I)/1208925819614629174706176, (234884918567993750975181728413524549575881 + 79757294640629983786895695752733890213506*I)/9671406556917033397649408, (-7632732774935120473359202657160313866419 + 2905452608512927560554702228553291839465*I)/1208925819614629174706176, (52291747908702842344842889809762246649489 - 520996778817151392090736149644507525892649*I)/19342813113834066795298816, (17472406829219127839967951180375981717322 + 23464704213841582137898905375041819568669*I)/4835703278458516698824704, (-911026971811893092350229536132730760943307 + 150799318130900944080399439626714846752360*I)/38685626227668133590597632, (26234457233977042811089020440646443590687 - 45650293039576452023692126463683727692890*I)/9671406556917033397649408, 3*(288348388717468992528382586652654351121357 + 454526517721403048270274049572136109264668*I)/77371252455336267181195264, (-91583492367747094223295011999405657956347 - 12704691128268298435362255538069612411331*I)/19342813113834066795298816, (411208730251327843849027957710164064354221 - 569898526380691606955496789378230959965898*I)/38685626227668133590597632], +# [ (27127513117071487872628354831658811211795 - 37765296987901990355760582016892124833857*I)/4835703278458516698824704, (1741779916057680444272938534338833170625435 + 3083041729779495966997526404685535449810378*I)/77371252455336267181195264, 3*(-60642236251815783728374561836962709533401 - 24630301165439580049891518846174101510744*I)/19342813113834066795298816, 3*(445885207364591681637745678755008757483408 - 350948497734812895032502179455610024541643*I)/38685626227668133590597632, (-47373295621391195484367368282471381775684 + 219122969294089357477027867028071400054973*I)/19342813113834066795298816, (-2801565819673198722993348253876353741520438 - 2250142129822658548391697042460298703335701*I)/77371252455336267181195264, (801448252275607253266997552356128790317119 - 50890367688077858227059515894356594900558*I)/77371252455336267181195264, (-5082187758525931944557763799137987573501207 + 11610432359082071866576699236013484487676124*I)/309485009821345068724781056, (-328925127096560623794883760398247685166830 - 643447969697471610060622160899409680422019*I)/77371252455336267181195264, 15*(2954944669454003684028194956846659916299765 + 33434406416888505837444969347824812608566*I)/1237940039285380274899124224, (-415749104352001509942256567958449835766827 + 479330966144175743357171151440020955412219*I)/77371252455336267181195264, 3*(-4639987285852134369449873547637372282914255 - 11994411888966030153196659207284951579243273*I)/1237940039285380274899124224], +# [ (-478846096206269117345024348666145495601 + 1249092488629201351470551186322814883283*I)/302231454903657293676544, (-17749319421930878799354766626365926894989 - 18264580106418628161818752318217357231971*I)/1208925819614629174706176, (2801110795431528876849623279389579072819 + 363258850073786330770713557775566973248*I)/604462909807314587353088, (-59053496693129013745775512127095650616252 + 78143588734197260279248498898321500167517*I)/4835703278458516698824704, (-283186724922498212468162690097101115349 - 6443437753863179883794497936345437398276*I)/1208925819614629174706176, (188799118826748909206887165661384998787543 + 84274736720556630026311383931055307398820*I)/9671406556917033397649408, (-5482217151670072904078758141270295025989 + 1818284338672191024475557065444481298568*I)/1208925819614629174706176, (56564463395350195513805521309731217952281 - 360208541416798112109946262159695452898431*I)/19342813113834066795298816, 11*(1259539805728870739006416869463689438068 + 1409136581547898074455004171305324917387*I)/4835703278458516698824704, 5*(-123701190701414554945251071190688818343325 + 30997157322590424677294553832111902279712*I)/38685626227668133590597632, (16130917381301373033736295883982414239781 - 32752041297570919727145380131926943374516*I)/9671406556917033397649408, (650301385108223834347093740500375498354925 + 899526407681131828596801223402866051809258*I)/77371252455336267181195264], +# [ (9011388245256140876590294262420614839483 + 8167917972423946282513000869327525382672*I)/1208925819614629174706176, (-426393174084720190126376382194036323028924 + 180692224825757525982858693158209545430621*I)/9671406556917033397649408, (24588556702197802674765733448108154175535 - 45091766022876486566421953254051868331066*I)/4835703278458516698824704, (1872113939365285277373877183750416985089691 + 3030392393733212574744122057679633775773130*I)/77371252455336267181195264, (-222173405538046189185754954524429864167549 - 75193157893478637039381059488387511299116*I)/19342813113834066795298816, (2670821320766222522963689317316937579844558 - 2645837121493554383087981511645435472169191*I)/77371252455336267181195264, 5*(-2100110309556476773796963197283876204940 + 41957457246479840487980315496957337371937*I)/19342813113834066795298816, (-5733743755499084165382383818991531258980593 - 3328949988392698205198574824396695027195732*I)/154742504910672534362390528, (707827994365259025461378911159398206329247 - 265730616623227695108042528694302299777294*I)/77371252455336267181195264, (-1442501604682933002895864804409322823788319 + 11504137805563265043376405214378288793343879*I)/309485009821345068724781056, (-56130472299445561499538726459719629522285 - 61117552419727805035810982426639329818864*I)/9671406556917033397649408, (39053692321126079849054272431599539429908717 - 10209127700342570953247177602860848130710666*I)/1237940039285380274899124224]]) + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512], + [ 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64], + [ -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128], + [ 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16], + [ 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M**10 == Matrix(S('''[ + [ 7369525394972778926719607798014571861/604462909807314587353088 - 229284202061790301477392339912557559*I/151115727451828646838272, -19704281515163975949388435612632058035/1208925819614629174706176 + 14319858347987648723768698170712102887*I/302231454903657293676544, -3623281909451783042932142262164941211/604462909807314587353088 - 6039240602494288615094338643452320495*I/604462909807314587353088, 109260497799140408739847239685705357695/2417851639229258349412352 - 7427566006564572463236368211555511431*I/2417851639229258349412352, -16095803767674394244695716092817006641/2417851639229258349412352 + 10336681897356760057393429626719177583*I/1208925819614629174706176, -42207883340488041844332828574359769743/2417851639229258349412352 - 182332262671671273188016400290188468499*I/4835703278458516698824704], + [50566491050825573392726324995779608259/1208925819614629174706176 - 90047007594468146222002432884052362145*I/2417851639229258349412352, 74273703462900000967697427843983822011/1208925819614629174706176 + 265947522682943571171988741842776095421*I/1208925819614629174706176, -116900341394390200556829767923360888429/2417851639229258349412352 - 53153263356679268823910621474478756845*I/2417851639229258349412352, 195407378023867871243426523048612490249/1208925819614629174706176 - 1242417915995360200584837585002906728929*I/9671406556917033397649408, -863597594389821970177319682495878193/302231454903657293676544 + 476936100741548328800725360758734300481*I/9671406556917033397649408, -3154451590535653853562472176601754835575/19342813113834066795298816 - 232909875490506237386836489998407329215*I/2417851639229258349412352], + [ -1715444997702484578716037230949868543/302231454903657293676544 + 5009695651321306866158517287924120777*I/302231454903657293676544, -30551582497996879620371947949342101301/604462909807314587353088 - 7632518367986526187139161303331519629*I/151115727451828646838272, 312680739924495153190604170938220575/18889465931478580854784 - 108664334509328818765959789219208459*I/75557863725914323419136, -14693696966703036206178521686918865509/604462909807314587353088 + 72345386220900843930147151999899692401*I/1208925819614629174706176, -8218872496728882299722894680635296519/1208925819614629174706176 - 16776782833358893712645864791807664983*I/1208925819614629174706176, 143237839169380078671242929143670635137/2417851639229258349412352 + 2883817094806115974748882735218469447*I/2417851639229258349412352], + [ 3087979417831061365023111800749855987/151115727451828646838272 + 34441942370802869368851419102423997089*I/604462909807314587353088, -148309181940158040917731426845476175667/604462909807314587353088 - 263987151804109387844966835369350904919*I/9671406556917033397649408, 50259518594816377378747711930008883165/1208925819614629174706176 - 95713974916869240305450001443767979653*I/2417851639229258349412352, 153466447023875527996457943521467271119/2417851639229258349412352 + 517285524891117105834922278517084871349*I/2417851639229258349412352, -29184653615412989036678939366291205575/604462909807314587353088 - 27551322282526322041080173287022121083*I/1208925819614629174706176, 196404220110085511863671393922447671649/1208925819614629174706176 - 1204712019400186021982272049902206202145*I/9671406556917033397649408], + [ -2632581805949645784625606590600098779/151115727451828646838272 - 589957435912868015140272627522612771*I/37778931862957161709568, 26727850893953715274702844733506310247/302231454903657293676544 - 10825791956782128799168209600694020481*I/302231454903657293676544, -1036348763702366164044671908440791295/151115727451828646838272 + 3188624571414467767868303105288107375*I/151115727451828646838272, -36814959939970644875593411585393242449/604462909807314587353088 - 18457555789119782404850043842902832647*I/302231454903657293676544, 12454491297984637815063964572803058647/604462909807314587353088 - 340489532842249733975074349495329171*I/302231454903657293676544, -19547211751145597258386735573258916681/604462909807314587353088 + 87299583775782199663414539883938008933*I/1208925819614629174706176], + [ -40281994229560039213253423262678393183/604462909807314587353088 - 2939986850065527327299273003299736641*I/604462909807314587353088, 331940684638052085845743020267462794181/2417851639229258349412352 - 284574901963624403933361315517248458969*I/1208925819614629174706176, 6453843623051745485064693628073010961/302231454903657293676544 + 36062454107479732681350914931391590957*I/604462909807314587353088, -147665869053634695632880753646441962067/604462909807314587353088 - 305987938660447291246597544085345123927*I/9671406556917033397649408, 107821369195275772166593879711259469423/2417851639229258349412352 - 11645185518211204108659001435013326687*I/302231454903657293676544, 64121228424717666402009446088588091619/1208925819614629174706176 + 265557133337095047883844369272389762133*I/1208925819614629174706176]]''')) + +def test_issue_17247_expression_blowup_5(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.charpoly('x') == PurePoly(x**6 + (-6 - 6*I)*x**5 + 36*I*x**4, x, domain='EX') + +def test_issue_17247_expression_blowup_6(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.det('bareiss') == 0 + +def test_issue_17247_expression_blowup_7(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.det('berkowitz') == 0 + +def test_issue_17247_expression_blowup_8(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.det('lu') == 0 + +def test_issue_17247_expression_blowup_9(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.rref() == (Matrix([ + [1, 0, -1, -2, -3, -4, -5, -6], + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0]]), (0, 1)) + +def test_issue_17247_expression_blowup_10(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.cofactor(0, 0) == 0 + +def test_issue_17247_expression_blowup_11(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.cofactor_matrix() == Matrix(6, 6, [0]*36) + +def test_issue_17247_expression_blowup_12(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.eigenvals() == {6: 1, 6*I: 1, 0: 4} + +def test_issue_17247_expression_blowup_13(): + M = Matrix([ + [ 0, 1 - x, x + 1, 1 - x], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 1 - x], + [ 0, 0, 1 - x, 0]]) + + ev = M.eigenvects() + assert ev[0] == (0, 2, [Matrix([0, -1, 0, 1])]) + assert ev[1][0] == x - sqrt(2)*(x - 1) + 1 + assert ev[1][1] == 1 + assert ev[1][2][0].expand(deep=False, numer=True) == Matrix([ + [(-x + sqrt(2)*(x - 1) - 1)/(x - 1)], + [-4*x/(x**2 - 2*x + 1) + (x + 1)*(x - sqrt(2)*(x - 1) + 1)/(x**2 - 2*x + 1)], + [(-x + sqrt(2)*(x - 1) - 1)/(x - 1)], + [1] + ]) + + assert ev[2][0] == x + sqrt(2)*(x - 1) + 1 + assert ev[2][1] == 1 + assert ev[2][2][0].expand(deep=False, numer=True) == Matrix([ + [(-x - sqrt(2)*(x - 1) - 1)/(x - 1)], + [-4*x/(x**2 - 2*x + 1) + (x + 1)*(x + sqrt(2)*(x - 1) + 1)/(x**2 - 2*x + 1)], + [(-x - sqrt(2)*(x - 1) - 1)/(x - 1)], + [1] + ]) + + +def test_issue_17247_expression_blowup_14(): + M = Matrix(8, 8, ([1+x, 1-x]*4 + [1-x, 1+x]*4)*4) + with dotprodsimp(True): + assert M.echelon_form() == Matrix([ + [x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x], + [ 0, 4*x, 0, 4*x, 0, 4*x, 0, 4*x], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0]]) + +def test_issue_17247_expression_blowup_15(): + M = Matrix(8, 8, ([1+x, 1-x]*4 + [1-x, 1+x]*4)*4) + with dotprodsimp(True): + assert M.rowspace() == [Matrix([[x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x]]), Matrix([[0, 4*x, 0, 4*x, 0, 4*x, 0, 4*x]])] + +def test_issue_17247_expression_blowup_16(): + M = Matrix(8, 8, ([1+x, 1-x]*4 + [1-x, 1+x]*4)*4) + with dotprodsimp(True): + assert M.columnspace() == [Matrix([[x + 1],[1 - x],[x + 1],[1 - x],[x + 1],[1 - x],[x + 1],[1 - x]]), Matrix([[1 - x],[x + 1],[1 - x],[x + 1],[1 - x],[x + 1],[1 - x],[x + 1]])] + +def test_issue_17247_expression_blowup_17(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.nullspace() == [ + Matrix([[1],[-2],[1],[0],[0],[0],[0],[0]]), + Matrix([[2],[-3],[0],[1],[0],[0],[0],[0]]), + Matrix([[3],[-4],[0],[0],[1],[0],[0],[0]]), + Matrix([[4],[-5],[0],[0],[0],[1],[0],[0]]), + Matrix([[5],[-6],[0],[0],[0],[0],[1],[0]]), + Matrix([[6],[-7],[0],[0],[0],[0],[0],[1]])] + +def test_issue_17247_expression_blowup_18(): + M = Matrix(6, 6, ([1+x, 1-x]*3 + [1-x, 1+x]*3)*3) + with dotprodsimp(True): + assert not M.is_nilpotent() + +def test_issue_17247_expression_blowup_19(): + M = Matrix(S('''[ + [ -3/4, 0, 1/4 + I/2, 0], + [ 0, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 1/2 - I, 0, 0, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert not M.is_diagonalizable() + +def test_issue_17247_expression_blowup_20(): + M = Matrix([ + [x + 1, 1 - x, 0, 0], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 0], + [ 0, 0, 0, x + 1]]) + with dotprodsimp(True): + assert M.diagonalize() == (Matrix([ + [1, 1, 0, (x + 1)/(x - 1)], + [1, -1, 0, 0], + [1, 1, 1, 0], + [0, 0, 0, 1]]), + Matrix([ + [2, 0, 0, 0], + [0, 2*x, 0, 0], + [0, 0, x + 1, 0], + [0, 0, 0, x + 1]])) + +def test_issue_17247_expression_blowup_21(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='GE') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + +def test_issue_17247_expression_blowup_22(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='LU') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + +def test_issue_17247_expression_blowup_23(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='ADJ').expand() == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + +def test_issue_17247_expression_blowup_24(): + M = SparseMatrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='CH') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + +def test_issue_17247_expression_blowup_25(): + M = SparseMatrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='LDL') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + +def test_issue_17247_expression_blowup_26(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024], + [ 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64], + [ -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512], + [ 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64], + [ 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128], + [ -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16], + [ 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.rank() == 4 + +def test_issue_17247_expression_blowup_27(): + M = Matrix([ + [ 0, 1 - x, x + 1, 1 - x], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 1 - x], + [ 0, 0, 1 - x, 0]]) + with dotprodsimp(True): + P, J = M.jordan_form() + assert P.expand() == Matrix(S('''[ + [ 0, 4*x/(x**2 - 2*x + 1), -(-17*x**4 + 12*sqrt(2)*x**4 - 4*sqrt(2)*x**3 + 6*x**3 - 6*x - 4*sqrt(2)*x + 12*sqrt(2) + 17)/(-7*x**4 + 5*sqrt(2)*x**4 - 6*sqrt(2)*x**3 + 8*x**3 - 2*x**2 + 8*x + 6*sqrt(2)*x - 5*sqrt(2) - 7), -(12*sqrt(2)*x**4 + 17*x**4 - 6*x**3 - 4*sqrt(2)*x**3 - 4*sqrt(2)*x + 6*x - 17 + 12*sqrt(2))/(7*x**4 + 5*sqrt(2)*x**4 - 6*sqrt(2)*x**3 - 8*x**3 + 2*x**2 - 8*x + 6*sqrt(2)*x - 5*sqrt(2) + 7)], + [x - 1, x/(x - 1) + 1/(x - 1), (-7*x**3 + 5*sqrt(2)*x**3 - x**2 + sqrt(2)*x**2 - sqrt(2)*x - x - 5*sqrt(2) - 7)/(-3*x**3 + 2*sqrt(2)*x**3 - 2*sqrt(2)*x**2 + 3*x**2 + 2*sqrt(2)*x + 3*x - 3 - 2*sqrt(2)), (7*x**3 + 5*sqrt(2)*x**3 + x**2 + sqrt(2)*x**2 - sqrt(2)*x + x - 5*sqrt(2) + 7)/(2*sqrt(2)*x**3 + 3*x**3 - 3*x**2 - 2*sqrt(2)*x**2 - 3*x + 2*sqrt(2)*x - 2*sqrt(2) + 3)], + [ 0, 1, -(-3*x**2 + 2*sqrt(2)*x**2 + 2*x - 3 - 2*sqrt(2))/(-x**2 + sqrt(2)*x**2 - 2*sqrt(2)*x + 1 + sqrt(2)), -(2*sqrt(2)*x**2 + 3*x**2 - 2*x - 2*sqrt(2) + 3)/(x**2 + sqrt(2)*x**2 - 2*sqrt(2)*x - 1 + sqrt(2))], + [1 - x, 0, 1, 1]]''')).expand() + assert J == Matrix(S('''[ + [0, 1, 0, 0], + [0, 0, 0, 0], + [0, 0, x - sqrt(2)*(x - 1) + 1, 0], + [0, 0, 0, x + sqrt(2)*(x - 1) + 1]]''')) + +def test_issue_17247_expression_blowup_28(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.singular_values() == S('''[ + sqrt(14609315/131072 + sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) + 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2 + sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2), + sqrt(14609315/131072 - sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) + 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2 + sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2), + sqrt(14609315/131072 - sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2 + sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) - 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2), + sqrt(14609315/131072 - sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2 - sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) - 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2)]''') + + +def test_issue_16823(): + # This still needs to be fixed if not using dotprodsimp. + M = Matrix(S('''[ + [1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I,-9/32-1/16*I,183/256-97/128*I,3/64+13/64*I,-23/32-59/256*I,15/128-3/32*I,19/256+551/1024*I], + [21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I,-219/128+115/256*I,6301/4096-6609/1024*I,119/128+143/128*I,-10879/2048+4343/4096*I,129/256-549/512*I,42533/16384+29103/8192*I], + [-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I,-9/32-1/16*I,183/256-97/128*I,3/64+13/64*I,-23/32-59/256*I], + [1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I,-219/128+115/256*I,6301/4096-6609/1024*I,119/128+143/128*I,-10879/2048+4343/4096*I], + [-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I,-9/32-1/16*I,183/256-97/128*I], + [1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I,-219/128+115/256*I,6301/4096-6609/1024*I], + [-4,9-5*I,-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I], + [-2*I,119/8+29/4*I,1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I], + [0,-6,-4,9-5*I,-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I], + [1,-9/4+3*I,-2*I,119/8+29/4*I,1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I], + [0,-4*I,0,-6,-4,9-5*I,-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I], + [0,1/4+1/2*I,1,-9/4+3*I,-2*I,119/8+29/4*I,1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I]]''')) + with dotprodsimp(True): + assert M.rank() == 8 + + +def test_issue_18531(): + # solve_linear_system still needs fixing but the rref works. + M = Matrix([ + [1, 1, 1, 1, 1, 0, 1, 0, 0], + [1 + sqrt(2), -1 + sqrt(2), 1 - sqrt(2), -sqrt(2) - 1, 1, 1, -1, 1, 1], + [-5 + 2*sqrt(2), -5 - 2*sqrt(2), -5 - 2*sqrt(2), -5 + 2*sqrt(2), -7, 2, -7, -2, 0], + [-3*sqrt(2) - 1, 1 - 3*sqrt(2), -1 + 3*sqrt(2), 1 + 3*sqrt(2), -7, -5, 7, -5, 3], + [7 - 4*sqrt(2), 4*sqrt(2) + 7, 4*sqrt(2) + 7, 7 - 4*sqrt(2), 7, -12, 7, 12, 0], + [-1 + 3*sqrt(2), 1 + 3*sqrt(2), -3*sqrt(2) - 1, 1 - 3*sqrt(2), 7, -5, -7, -5, 3], + [-3 + 2*sqrt(2), -3 - 2*sqrt(2), -3 - 2*sqrt(2), -3 + 2*sqrt(2), -1, 2, -1, -2, 0], + [1 - sqrt(2), -sqrt(2) - 1, 1 + sqrt(2), -1 + sqrt(2), -1, 1, 1, 1, 1] + ]) + with dotprodsimp(True): + assert M.rref() == (Matrix([ + [1, 0, 0, 0, 0, 0, 0, 0, S(1)/2], + [0, 1, 0, 0, 0, 0, 0, 0, -S(1)/2], + [0, 0, 1, 0, 0, 0, 0, 0, S(1)/2], + [0, 0, 0, 1, 0, 0, 0, 0, -S(1)/2], + [0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, -S(1)/2], + [0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, -S(1)/2]]), (0, 1, 2, 3, 4, 5, 6, 7)) + + +def test_creation(): + raises(ValueError, lambda: Matrix(5, 5, range(20))) + raises(ValueError, lambda: Matrix(5, -1, [])) + raises(IndexError, lambda: Matrix((1, 2))[2]) + with raises(IndexError): + Matrix((1, 2))[3] = 5 + + assert Matrix() == Matrix([]) == Matrix(0, 0, []) + assert Matrix([[]]) == Matrix(1, 0, []) + assert Matrix([[], []]) == Matrix(2, 0, []) + + # anything used to be allowed in a matrix + with warns_deprecated_sympy(): + assert Matrix([[[1], (2,)]]).tolist() == [[[1], (2,)]] + with warns_deprecated_sympy(): + assert Matrix([[[1], (2,)]]).T.tolist() == [[[1]], [(2,)]] + M = Matrix([[0]]) + with warns_deprecated_sympy(): + M[0, 0] = S.EmptySet + + a = Matrix([[x, 0], [0, 0]]) + m = a + assert m.cols == m.rows + assert m.cols == 2 + assert m[:] == [x, 0, 0, 0] + + b = Matrix(2, 2, [x, 0, 0, 0]) + m = b + assert m.cols == m.rows + assert m.cols == 2 + assert m[:] == [x, 0, 0, 0] + + assert a == b + + assert Matrix(b) == b + + c23 = Matrix(2, 3, range(1, 7)) + c13 = Matrix(1, 3, range(7, 10)) + c = Matrix([c23, c13]) + assert c.cols == 3 + assert c.rows == 3 + assert c[:] == [1, 2, 3, 4, 5, 6, 7, 8, 9] + + assert Matrix(eye(2)) == eye(2) + assert ImmutableMatrix(ImmutableMatrix(eye(2))) == ImmutableMatrix(eye(2)) + assert ImmutableMatrix(c) == c.as_immutable() + assert Matrix(ImmutableMatrix(c)) == ImmutableMatrix(c).as_mutable() + + assert c is not Matrix(c) + + dat = [[ones(3,2), ones(3,3)*2], [ones(2,3)*3, ones(2,2)*4]] + M = Matrix(dat) + assert M == Matrix([ + [1, 1, 2, 2, 2], + [1, 1, 2, 2, 2], + [1, 1, 2, 2, 2], + [3, 3, 3, 4, 4], + [3, 3, 3, 4, 4]]) + assert M.tolist() != dat + # keep block form if evaluate=False + assert Matrix(dat, evaluate=False).tolist() == dat + A = MatrixSymbol("A", 2, 2) + dat = [ones(2), A] + assert Matrix(dat) == Matrix([ + [ 1, 1], + [ 1, 1], + [A[0, 0], A[0, 1]], + [A[1, 0], A[1, 1]]]) + with warns_deprecated_sympy(): + assert Matrix(dat, evaluate=False).tolist() == [[i] for i in dat] + + # 0-dim tolerance + assert Matrix([ones(2), ones(0)]) == Matrix([ones(2)]) + raises(ValueError, lambda: Matrix([ones(2), ones(0, 3)])) + raises(ValueError, lambda: Matrix([ones(2), ones(3, 0)])) + + # mix of Matrix and iterable + M = Matrix([[1, 2], [3, 4]]) + M2 = Matrix([M, (5, 6)]) + assert M2 == Matrix([[1, 2], [3, 4], [5, 6]]) + + +def test_irregular_block(): + assert Matrix.irregular(3, ones(2,1), ones(3,3)*2, ones(2,2)*3, + ones(1,1)*4, ones(2,2)*5, ones(1,2)*6, ones(1,2)*7) == Matrix([ + [1, 2, 2, 2, 3, 3], + [1, 2, 2, 2, 3, 3], + [4, 2, 2, 2, 5, 5], + [6, 6, 7, 7, 5, 5]]) + + +def test_tolist(): + lst = [[S.One, S.Half, x*y, S.Zero], [x, y, z, x**2], [y, -S.One, z*x, 3]] + m = Matrix(lst) + assert m.tolist() == lst + + +def test_as_mutable(): + assert zeros(0, 3).as_mutable() == zeros(0, 3) + assert zeros(0, 3).as_immutable() == ImmutableMatrix(zeros(0, 3)) + assert zeros(3, 0).as_immutable() == ImmutableMatrix(zeros(3, 0)) + + +def test_slicing(): + m0 = eye(4) + assert m0[:3, :3] == eye(3) + assert m0[2:4, 0:2] == zeros(2) + + m1 = Matrix(3, 3, lambda i, j: i + j) + assert m1[0, :] == Matrix(1, 3, (0, 1, 2)) + assert m1[1:3, 1] == Matrix(2, 1, (2, 3)) + + m2 = Matrix([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]) + assert m2[:, -1] == Matrix(4, 1, [3, 7, 11, 15]) + assert m2[-2:, :] == Matrix([[8, 9, 10, 11], [12, 13, 14, 15]]) + + +def test_submatrix_assignment(): + m = zeros(4) + m[2:4, 2:4] = eye(2) + assert m == Matrix(((0, 0, 0, 0), + (0, 0, 0, 0), + (0, 0, 1, 0), + (0, 0, 0, 1))) + m[:2, :2] = eye(2) + assert m == eye(4) + m[:, 0] = Matrix(4, 1, (1, 2, 3, 4)) + assert m == Matrix(((1, 0, 0, 0), + (2, 1, 0, 0), + (3, 0, 1, 0), + (4, 0, 0, 1))) + m[:, :] = zeros(4) + assert m == zeros(4) + m[:, :] = [(1, 2, 3, 4), (5, 6, 7, 8), (9, 10, 11, 12), (13, 14, 15, 16)] + assert m == Matrix(((1, 2, 3, 4), + (5, 6, 7, 8), + (9, 10, 11, 12), + (13, 14, 15, 16))) + m[:2, 0] = [0, 0] + assert m == Matrix(((0, 2, 3, 4), + (0, 6, 7, 8), + (9, 10, 11, 12), + (13, 14, 15, 16))) + + +def test_extract(): + m = Matrix(4, 3, lambda i, j: i*3 + j) + assert m.extract([0, 1, 3], [0, 1]) == Matrix(3, 2, [0, 1, 3, 4, 9, 10]) + assert m.extract([0, 3], [0, 0, 2]) == Matrix(2, 3, [0, 0, 2, 9, 9, 11]) + assert m.extract(range(4), range(3)) == m + raises(IndexError, lambda: m.extract([4], [0])) + raises(IndexError, lambda: m.extract([0], [3])) + + +def test_reshape(): + m0 = eye(3) + assert m0.reshape(1, 9) == Matrix(1, 9, (1, 0, 0, 0, 1, 0, 0, 0, 1)) + m1 = Matrix(3, 4, lambda i, j: i + j) + assert m1.reshape( + 4, 3) == Matrix(((0, 1, 2), (3, 1, 2), (3, 4, 2), (3, 4, 5))) + assert m1.reshape(2, 6) == Matrix(((0, 1, 2, 3, 1, 2), (3, 4, 2, 3, 4, 5))) + + +def test_applyfunc(): + m0 = eye(3) + assert m0.applyfunc(lambda x: 2*x) == eye(3)*2 + assert m0.applyfunc(lambda x: 0) == zeros(3) + + +def test_expand(): + m0 = Matrix([[x*(x + y), 2], [((x + y)*y)*x, x*(y + x*(x + y))]]) + # Test if expand() returns a matrix + m1 = m0.expand() + assert m1 == Matrix( + [[x*y + x**2, 2], [x*y**2 + y*x**2, x*y + y*x**2 + x**3]]) + + a = Symbol('a', real=True) + + assert Matrix([exp(I*a)]).expand(complex=True) == \ + Matrix([cos(a) + I*sin(a)]) + + assert Matrix([[0, 1, 2], [0, 0, -1], [0, 0, 0]]).exp() == Matrix([ + [1, 1, Rational(3, 2)], + [0, 1, -1], + [0, 0, 1]] + ) + +def test_refine(): + m0 = Matrix([[Abs(x)**2, sqrt(x**2)], + [sqrt(x**2)*Abs(y)**2, sqrt(y**2)*Abs(x)**2]]) + m1 = m0.refine(Q.real(x) & Q.real(y)) + assert m1 == Matrix([[x**2, Abs(x)], [y**2*Abs(x), x**2*Abs(y)]]) + + m1 = m0.refine(Q.positive(x) & Q.positive(y)) + assert m1 == Matrix([[x**2, x], [x*y**2, x**2*y]]) + + m1 = m0.refine(Q.negative(x) & Q.negative(y)) + assert m1 == Matrix([[x**2, -x], [-x*y**2, -x**2*y]]) + +def test_random(): + M = randMatrix(3, 3) + M = randMatrix(3, 3, seed=3) + assert M == randMatrix(3, 3, seed=3) + + M = randMatrix(3, 4, 0, 150) + M = randMatrix(3, seed=4, symmetric=True) + assert M == randMatrix(3, seed=4, symmetric=True) + + S = M.copy() + S.simplify() + assert S == M # doesn't fail when elements are Numbers, not int + + rng = random.Random(4) + assert M == randMatrix(3, symmetric=True, prng=rng) + + # Ensure symmetry + for size in (10, 11): # Test odd and even + for percent in (100, 70, 30): + M = randMatrix(size, symmetric=True, percent=percent, prng=rng) + assert M == M.T + + M = randMatrix(10, min=1, percent=70) + zero_count = 0 + for i in range(M.shape[0]): + for j in range(M.shape[1]): + if M[i, j] == 0: + zero_count += 1 + assert zero_count == 30 + +def test_inverse(): + A = eye(4) + assert A.inv() == eye(4) + assert A.inv(method="LU") == eye(4) + assert A.inv(method="ADJ") == eye(4) + assert A.inv(method="CH") == eye(4) + assert A.inv(method="LDL") == eye(4) + assert A.inv(method="QR") == eye(4) + A = Matrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + Ainv = A.inv() + assert A*Ainv == eye(3) + assert A.inv(method="LU") == Ainv + assert A.inv(method="ADJ") == Ainv + assert A.inv(method="CH") == Ainv + assert A.inv(method="LDL") == Ainv + assert A.inv(method="QR") == Ainv + + AA = Matrix([[0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0], + [1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0], + [1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], + [1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0], + [1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1], + [0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0], + [1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1], + [0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1], + [1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0], + [0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0], + [1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0], + [0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1], + [1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0], + [0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0], + [1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1], + [0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1], + [1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1], + [0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1], + [0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1], + [0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0]]) + assert AA.inv(method="BLOCK") * AA == eye(AA.shape[0]) + # test that immutability is not a problem + cls = ImmutableMatrix + m = cls([[48, 49, 31], + [ 9, 71, 94], + [59, 28, 65]]) + assert all(type(m.inv(s)) is cls for s in 'GE ADJ LU CH LDL QR'.split()) + cls = ImmutableSparseMatrix + m = cls([[48, 49, 31], + [ 9, 71, 94], + [59, 28, 65]]) + assert all(type(m.inv(s)) is cls for s in 'GE ADJ LU CH LDL QR'.split()) + + +def test_jacobian_hessian(): + L = Matrix(1, 2, [x**2*y, 2*y**2 + x*y]) + syms = [x, y] + assert L.jacobian(syms) == Matrix([[2*x*y, x**2], [y, 4*y + x]]) + + L = Matrix(1, 2, [x, x**2*y**3]) + assert L.jacobian(syms) == Matrix([[1, 0], [2*x*y**3, x**2*3*y**2]]) + + f = x**2*y + syms = [x, y] + assert hessian(f, syms) == Matrix([[2*y, 2*x], [2*x, 0]]) + + f = x**2*y**3 + assert hessian(f, syms) == \ + Matrix([[2*y**3, 6*x*y**2], [6*x*y**2, 6*x**2*y]]) + + f = z + x*y**2 + g = x**2 + 2*y**3 + ans = Matrix([[0, 2*y], + [2*y, 2*x]]) + assert ans == hessian(f, Matrix([x, y])) + assert ans == hessian(f, Matrix([x, y]).T) + assert hessian(f, (y, x), [g]) == Matrix([ + [ 0, 6*y**2, 2*x], + [6*y**2, 2*x, 2*y], + [ 2*x, 2*y, 0]]) + + +def test_wronskian(): + assert wronskian([cos(x), sin(x)], x) == cos(x)**2 + sin(x)**2 + assert wronskian([exp(x), exp(2*x)], x) == exp(3*x) + assert wronskian([exp(x), x], x) == exp(x) - x*exp(x) + assert wronskian([1, x, x**2], x) == 2 + w1 = -6*exp(x)*sin(x)*x + 6*cos(x)*exp(x)*x**2 - 6*exp(x)*cos(x)*x - \ + exp(x)*cos(x)*x**3 + exp(x)*sin(x)*x**3 + assert wronskian([exp(x), cos(x), x**3], x).expand() == w1 + assert wronskian([exp(x), cos(x), x**3], x, method='berkowitz').expand() \ + == w1 + w2 = -x**3*cos(x)**2 - x**3*sin(x)**2 - 6*x*cos(x)**2 - 6*x*sin(x)**2 + assert wronskian([sin(x), cos(x), x**3], x).expand() == w2 + assert wronskian([sin(x), cos(x), x**3], x, method='berkowitz').expand() \ + == w2 + assert wronskian([], x) == 1 + + +def test_subs(): + assert Matrix([[1, x], [x, 4]]).subs(x, 5) == Matrix([[1, 5], [5, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).subs([[x, -1], [y, -2]]) == \ + Matrix([[-1, 2], [-3, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).subs([(x, -1), (y, -2)]) == \ + Matrix([[-1, 2], [-3, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).subs({x: -1, y: -2}) == \ + Matrix([[-1, 2], [-3, 4]]) + assert Matrix([x*y]).subs({x: y - 1, y: x - 1}, simultaneous=True) == \ + Matrix([(x - 1)*(y - 1)]) + + for cls in classes: + assert Matrix([[2, 0], [0, 2]]) == cls.eye(2).subs(1, 2) + +def test_xreplace(): + assert Matrix([[1, x], [x, 4]]).xreplace({x: 5}) == \ + Matrix([[1, 5], [5, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).xreplace({x: -1, y: -2}) == \ + Matrix([[-1, 2], [-3, 4]]) + for cls in classes: + assert Matrix([[2, 0], [0, 2]]) == cls.eye(2).xreplace({1: 2}) + +def test_simplify(): + n = Symbol('n') + f = Function('f') + + M = Matrix([[ 1/x + 1/y, (x + x*y) / x ], + [ (f(x) + y*f(x))/f(x), 2 * (1/n - cos(n * pi)/n) / pi ]]) + M.simplify() + assert M == Matrix([[ (x + y)/(x * y), 1 + y ], + [ 1 + y, 2*((1 - 1*cos(pi*n))/(pi*n)) ]]) + eq = (1 + x)**2 + M = Matrix([[eq]]) + M.simplify() + assert M == Matrix([[eq]]) + M.simplify(ratio=oo) + assert M == Matrix([[eq.simplify(ratio=oo)]]) + + +def test_transpose(): + M = Matrix([[1, 2, 3, 4, 5, 6, 7, 8, 9, 0], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]]) + assert M.T == Matrix( [ [1, 1], + [2, 2], + [3, 3], + [4, 4], + [5, 5], + [6, 6], + [7, 7], + [8, 8], + [9, 9], + [0, 0] ]) + assert M.T.T == M + assert M.T == M.transpose() + + +def test_conjugate(): + M = Matrix([[0, I, 5], + [1, 2, 0]]) + + assert M.T == Matrix([[0, 1], + [I, 2], + [5, 0]]) + + assert M.C == Matrix([[0, -I, 5], + [1, 2, 0]]) + assert M.C == M.conjugate() + + assert M.H == M.T.C + assert M.H == Matrix([[ 0, 1], + [-I, 2], + [ 5, 0]]) + + +def test_conj_dirac(): + raises(AttributeError, lambda: eye(3).D) + + M = Matrix([[1, I, I, I], + [0, 1, I, I], + [0, 0, 1, I], + [0, 0, 0, 1]]) + + assert M.D == Matrix([[ 1, 0, 0, 0], + [-I, 1, 0, 0], + [-I, -I, -1, 0], + [-I, -I, I, -1]]) + + +def test_trace(): + M = Matrix([[1, 0, 0], + [0, 5, 0], + [0, 0, 8]]) + assert M.trace() == 14 + + +def test_shape(): + M = Matrix([[x, 0, 0], + [0, y, 0]]) + assert M.shape == (2, 3) + + +def test_col_row_op(): + M = Matrix([[x, 0, 0], + [0, y, 0]]) + M.row_op(1, lambda r, j: r + j + 1) + assert M == Matrix([[x, 0, 0], + [1, y + 2, 3]]) + + M.col_op(0, lambda c, j: c + y**j) + assert M == Matrix([[x + 1, 0, 0], + [1 + y, y + 2, 3]]) + + # neither row nor slice give copies that allow the original matrix to + # be changed + assert M.row(0) == Matrix([[x + 1, 0, 0]]) + r1 = M.row(0) + r1[0] = 42 + assert M[0, 0] == x + 1 + r1 = M[0, :-1] # also testing negative slice + r1[0] = 42 + assert M[0, 0] == x + 1 + c1 = M.col(0) + assert c1 == Matrix([x + 1, 1 + y]) + c1[0] = 0 + assert M[0, 0] == x + 1 + c1 = M[:, 0] + c1[0] = 42 + assert M[0, 0] == x + 1 + + +def test_row_mult(): + M = Matrix([[1,2,3], + [4,5,6]]) + M.row_mult(1,3) + assert M[1,0] == 12 + assert M[0,0] == 1 + assert M[1,2] == 18 + + +def test_row_add(): + M = Matrix([[1,2,3], + [4,5,6], + [1,1,1]]) + M.row_add(2,0,5) + assert M[0,0] == 6 + assert M[1,0] == 4 + assert M[0,2] == 8 + + +def test_zip_row_op(): + for cls in classes[:2]: # XXX: immutable matrices don't support row ops + M = cls.eye(3) + M.zip_row_op(1, 0, lambda v, u: v + 2*u) + assert M == cls([[1, 0, 0], + [2, 1, 0], + [0, 0, 1]]) + + M = cls.eye(3)*2 + M[0, 1] = -1 + M.zip_row_op(1, 0, lambda v, u: v + 2*u); M + assert M == cls([[2, -1, 0], + [4, 0, 0], + [0, 0, 2]]) + +def test_issue_3950(): + m = Matrix([1, 2, 3]) + a = Matrix([1, 2, 3]) + b = Matrix([2, 2, 3]) + assert not (m in []) + assert not (m in [1]) + assert m != 1 + assert m == a + assert m != b + + +def test_issue_3981(): + class Index1: + def __index__(self): + return 1 + + class Index2: + def __index__(self): + return 2 + index1 = Index1() + index2 = Index2() + + m = Matrix([1, 2, 3]) + + assert m[index2] == 3 + + m[index2] = 5 + assert m[2] == 5 + + m = Matrix([[1, 2, 3], [4, 5, 6]]) + assert m[index1, index2] == 6 + assert m[1, index2] == 6 + assert m[index1, 2] == 6 + + m[index1, index2] = 4 + assert m[1, 2] == 4 + m[1, index2] = 6 + assert m[1, 2] == 6 + m[index1, 2] = 8 + assert m[1, 2] == 8 + + +def test_evalf(): + a = Matrix([sqrt(5), 6]) + assert all(a.evalf()[i] == a[i].evalf() for i in range(2)) + assert all(a.evalf(2)[i] == a[i].evalf(2) for i in range(2)) + assert all(a.n(2)[i] == a[i].n(2) for i in range(2)) + + +def test_is_symbolic(): + a = Matrix([[x, x], [x, x]]) + assert a.is_symbolic() is True + a = Matrix([[1, 2, 3, 4], [5, 6, 7, 8]]) + assert a.is_symbolic() is False + a = Matrix([[1, 2, 3, 4], [5, 6, x, 8]]) + assert a.is_symbolic() is True + a = Matrix([[1, x, 3]]) + assert a.is_symbolic() is True + a = Matrix([[1, 2, 3]]) + assert a.is_symbolic() is False + a = Matrix([[1], [x], [3]]) + assert a.is_symbolic() is True + a = Matrix([[1], [2], [3]]) + assert a.is_symbolic() is False + + +def test_is_upper(): + a = Matrix([[1, 2, 3]]) + assert a.is_upper is True + a = Matrix([[1], [2], [3]]) + assert a.is_upper is False + a = zeros(4, 2) + assert a.is_upper is True + + +def test_is_lower(): + a = Matrix([[1, 2, 3]]) + assert a.is_lower is False + a = Matrix([[1], [2], [3]]) + assert a.is_lower is True + + +def test_is_nilpotent(): + a = Matrix(4, 4, [0, 2, 1, 6, 0, 0, 1, 2, 0, 0, 0, 3, 0, 0, 0, 0]) + assert a.is_nilpotent() + a = Matrix([[1, 0], [0, 1]]) + assert not a.is_nilpotent() + a = Matrix([]) + assert a.is_nilpotent() + + +def test_zeros_ones_fill(): + n, m = 3, 5 + + a = zeros(n, m) + a.fill( 5 ) + + b = 5 * ones(n, m) + + assert a == b + assert a.rows == b.rows == 3 + assert a.cols == b.cols == 5 + assert a.shape == b.shape == (3, 5) + assert zeros(2) == zeros(2, 2) + assert ones(2) == ones(2, 2) + assert zeros(2, 3) == Matrix(2, 3, [0]*6) + assert ones(2, 3) == Matrix(2, 3, [1]*6) + + a.fill(0) + assert a == zeros(n, m) + + +def test_empty_zeros(): + a = zeros(0) + assert a == Matrix() + a = zeros(0, 2) + assert a.rows == 0 + assert a.cols == 2 + a = zeros(2, 0) + assert a.rows == 2 + assert a.cols == 0 + + +def test_issue_3749(): + a = Matrix([[x**2, x*y], [x*sin(y), x*cos(y)]]) + assert a.diff(x) == Matrix([[2*x, y], [sin(y), cos(y)]]) + assert Matrix([ + [x, -x, x**2], + [exp(x), 1/x - exp(-x), x + 1/x]]).limit(x, oo) == \ + Matrix([[oo, -oo, oo], [oo, 0, oo]]) + assert Matrix([ + [(exp(x) - 1)/x, 2*x + y*x, x**x ], + [1/x, abs(x), abs(sin(x + 1))]]).limit(x, 0) == \ + Matrix([[1, 0, 1], [oo, 0, sin(1)]]) + assert a.integrate(x) == Matrix([ + [Rational(1, 3)*x**3, y*x**2/2], + [x**2*sin(y)/2, x**2*cos(y)/2]]) + + +def test_inv_iszerofunc(): + A = eye(4) + A.col_swap(0, 1) + for method in "GE", "LU": + assert A.inv(method=method, iszerofunc=lambda x: x == 0) == \ + A.inv(method="ADJ") + + +def test_jacobian_metrics(): + rho, phi = symbols("rho,phi") + X = Matrix([rho*cos(phi), rho*sin(phi)]) + Y = Matrix([rho, phi]) + J = X.jacobian(Y) + assert J == X.jacobian(Y.T) + assert J == (X.T).jacobian(Y) + assert J == (X.T).jacobian(Y.T) + g = J.T*eye(J.shape[0])*J + g = g.applyfunc(trigsimp) + assert g == Matrix([[1, 0], [0, rho**2]]) + + +def test_jacobian2(): + rho, phi = symbols("rho,phi") + X = Matrix([rho*cos(phi), rho*sin(phi), rho**2]) + Y = Matrix([rho, phi]) + J = Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)], + [ 2*rho, 0], + ]) + assert X.jacobian(Y) == J + + +def test_issue_4564(): + X = Matrix([exp(x + y + z), exp(x + y + z), exp(x + y + z)]) + Y = Matrix([x, y, z]) + for i in range(1, 3): + for j in range(1, 3): + X_slice = X[:i, :] + Y_slice = Y[:j, :] + J = X_slice.jacobian(Y_slice) + assert J.rows == i + assert J.cols == j + for k in range(j): + assert J[:, k] == X_slice + + +def test_nonvectorJacobian(): + X = Matrix([[exp(x + y + z), exp(x + y + z)], + [exp(x + y + z), exp(x + y + z)]]) + raises(TypeError, lambda: X.jacobian(Matrix([x, y, z]))) + X = X[0, :] + Y = Matrix([[x, y], [x, z]]) + raises(TypeError, lambda: X.jacobian(Y)) + raises(TypeError, lambda: X.jacobian(Matrix([ [x, y], [x, z] ]))) + + +def test_vec(): + m = Matrix([[1, 3], [2, 4]]) + m_vec = m.vec() + assert m_vec.cols == 1 + for i in range(4): + assert m_vec[i] == i + 1 + + +def test_vech(): + m = Matrix([[1, 2], [2, 3]]) + m_vech = m.vech() + assert m_vech.cols == 1 + for i in range(3): + assert m_vech[i] == i + 1 + m_vech = m.vech(diagonal=False) + assert m_vech[0] == 2 + + m = Matrix([[1, x*(x + y)], [y*x + x**2, 1]]) + m_vech = m.vech(diagonal=False) + assert m_vech[0] == y*x + x**2 + + m = Matrix([[1, x*(x + y)], [y*x, 1]]) + m_vech = m.vech(diagonal=False, check_symmetry=False) + assert m_vech[0] == y*x + + raises(ShapeError, lambda: Matrix([[1, 3]]).vech()) + raises(ValueError, lambda: Matrix([[1, 3], [2, 4]]).vech()) + raises(ShapeError, lambda: Matrix([[1, 3]]).vech()) + raises(ValueError, lambda: Matrix([[1, 3], [2, 4]]).vech()) + + +def test_diag(): + # mostly tested in testcommonmatrix.py + assert diag([1, 2, 3]) == Matrix([1, 2, 3]) + m = [1, 2, [3]] + raises(ValueError, lambda: diag(m)) + assert diag(m, strict=False) == Matrix([1, 2, 3]) + + +def test_get_diag_blocks1(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + assert a.get_diag_blocks() == [a] + assert b.get_diag_blocks() == [b] + assert c.get_diag_blocks() == [c] + + +def test_get_diag_blocks2(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + assert diag(a, b, b).get_diag_blocks() == [a, b, b] + assert diag(a, b, c).get_diag_blocks() == [a, b, c] + assert diag(a, c, b).get_diag_blocks() == [a, c, b] + assert diag(c, c, b).get_diag_blocks() == [c, c, b] + + +def test_inv_block(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + A = diag(a, b, b) + assert A.inv(try_block_diag=True) == diag(a.inv(), b.inv(), b.inv()) + A = diag(a, b, c) + assert A.inv(try_block_diag=True) == diag(a.inv(), b.inv(), c.inv()) + A = diag(a, c, b) + assert A.inv(try_block_diag=True) == diag(a.inv(), c.inv(), b.inv()) + A = diag(a, a, b, a, c, a) + assert A.inv(try_block_diag=True) == diag( + a.inv(), a.inv(), b.inv(), a.inv(), c.inv(), a.inv()) + assert A.inv(try_block_diag=True, method="ADJ") == diag( + a.inv(method="ADJ"), a.inv(method="ADJ"), b.inv(method="ADJ"), + a.inv(method="ADJ"), c.inv(method="ADJ"), a.inv(method="ADJ")) + + +def test_creation_args(): + """ + Check that matrix dimensions can be specified using any reasonable type + (see issue 4614). + """ + raises(ValueError, lambda: zeros(3, -1)) + raises(TypeError, lambda: zeros(1, 2, 3, 4)) + assert zeros(int(3)) == zeros(3) + assert zeros(Integer(3)) == zeros(3) + raises(ValueError, lambda: zeros(3.)) + assert eye(int(3)) == eye(3) + assert eye(Integer(3)) == eye(3) + raises(ValueError, lambda: eye(3.)) + assert ones(int(3), Integer(4)) == ones(3, 4) + raises(TypeError, lambda: Matrix(5)) + raises(TypeError, lambda: Matrix(1, 2)) + raises(ValueError, lambda: Matrix([1, [2]])) + + +def test_diagonal_symmetrical(): + m = Matrix(2, 2, [0, 1, 1, 0]) + assert not m.is_diagonal() + assert m.is_symmetric() + assert m.is_symmetric(simplify=False) + + m = Matrix(2, 2, [1, 0, 0, 1]) + assert m.is_diagonal() + + m = diag(1, 2, 3) + assert m.is_diagonal() + assert m.is_symmetric() + + m = Matrix(3, 3, [1, 0, 0, 0, 2, 0, 0, 0, 3]) + assert m == diag(1, 2, 3) + + m = Matrix(2, 3, zeros(2, 3)) + assert not m.is_symmetric() + assert m.is_diagonal() + + m = Matrix(((5, 0), (0, 6), (0, 0))) + assert m.is_diagonal() + + m = Matrix(((5, 0, 0), (0, 6, 0))) + assert m.is_diagonal() + + m = Matrix(3, 3, [1, x**2 + 2*x + 1, y, (x + 1)**2, 2, 0, y, 0, 3]) + assert m.is_symmetric() + assert not m.is_symmetric(simplify=False) + assert m.expand().is_symmetric(simplify=False) + + +def test_diagonalization(): + m = Matrix([[1, 2+I], [2-I, 3]]) + assert m.is_diagonalizable() + + m = Matrix(3, 2, [-3, 1, -3, 20, 3, 10]) + assert not m.is_diagonalizable() + assert not m.is_symmetric() + raises(NonSquareMatrixError, lambda: m.diagonalize()) + + # diagonalizable + m = diag(1, 2, 3) + (P, D) = m.diagonalize() + assert P == eye(3) + assert D == m + + m = Matrix(2, 2, [0, 1, 1, 0]) + assert m.is_symmetric() + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + + m = Matrix(2, 2, [1, 0, 0, 3]) + assert m.is_symmetric() + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + assert P == eye(2) + assert D == m + + m = Matrix(2, 2, [1, 1, 0, 0]) + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + + m = Matrix(3, 3, [1, 2, 0, 0, 3, 0, 2, -4, 2]) + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + for i in P: + assert i.as_numer_denom()[1] == 1 + + m = Matrix(2, 2, [1, 0, 0, 0]) + assert m.is_diagonal() + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + assert P == Matrix([[0, 1], [1, 0]]) + + # diagonalizable, complex only + m = Matrix(2, 2, [0, 1, -1, 0]) + assert not m.is_diagonalizable(True) + raises(MatrixError, lambda: m.diagonalize(True)) + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + + # not diagonalizable + m = Matrix(2, 2, [0, 1, 0, 0]) + assert not m.is_diagonalizable() + raises(MatrixError, lambda: m.diagonalize()) + + m = Matrix(3, 3, [-3, 1, -3, 20, 3, 10, 2, -2, 4]) + assert not m.is_diagonalizable() + raises(MatrixError, lambda: m.diagonalize()) + + # symbolic + a, b, c, d = symbols('a b c d') + m = Matrix(2, 2, [a, c, c, b]) + assert m.is_symmetric() + assert m.is_diagonalizable() + + +def test_issue_15887(): + # Mutable matrix should not use cache + a = MutableDenseMatrix([[0, 1], [1, 0]]) + assert a.is_diagonalizable() is True + a[1, 0] = 0 + assert a.is_diagonalizable() is False + + a = MutableDenseMatrix([[0, 1], [1, 0]]) + a.diagonalize() + a[1, 0] = 0 + raises(MatrixError, lambda: a.diagonalize()) + + +def test_jordan_form(): + + m = Matrix(3, 2, [-3, 1, -3, 20, 3, 10]) + raises(NonSquareMatrixError, lambda: m.jordan_form()) + + # diagonalizable + m = Matrix(3, 3, [7, -12, 6, 10, -19, 10, 12, -24, 13]) + Jmust = Matrix(3, 3, [-1, 0, 0, 0, 1, 0, 0, 0, 1]) + P, J = m.jordan_form() + assert Jmust == J + assert Jmust == m.diagonalize()[1] + + # m = Matrix(3, 3, [0, 6, 3, 1, 3, 1, -2, 2, 1]) + # m.jordan_form() # very long + # m.jordan_form() # + + # diagonalizable, complex only + + # Jordan cells + # complexity: one of eigenvalues is zero + m = Matrix(3, 3, [0, 1, 0, -4, 4, 0, -2, 1, 2]) + # The blocks are ordered according to the value of their eigenvalues, + # in order to make the matrix compatible with .diagonalize() + Jmust = Matrix(3, 3, [2, 1, 0, 0, 2, 0, 0, 0, 2]) + P, J = m.jordan_form() + assert Jmust == J + + # complexity: all of eigenvalues are equal + m = Matrix(3, 3, [2, 6, -15, 1, 1, -5, 1, 2, -6]) + # Jmust = Matrix(3, 3, [-1, 0, 0, 0, -1, 1, 0, 0, -1]) + # same here see 1456ff + Jmust = Matrix(3, 3, [-1, 1, 0, 0, -1, 0, 0, 0, -1]) + P, J = m.jordan_form() + assert Jmust == J + + # complexity: two of eigenvalues are zero + m = Matrix(3, 3, [4, -5, 2, 5, -7, 3, 6, -9, 4]) + Jmust = Matrix(3, 3, [0, 1, 0, 0, 0, 0, 0, 0, 1]) + P, J = m.jordan_form() + assert Jmust == J + + m = Matrix(4, 4, [6, 5, -2, -3, -3, -1, 3, 3, 2, 1, -2, -3, -1, 1, 5, 5]) + Jmust = Matrix(4, 4, [2, 1, 0, 0, + 0, 2, 0, 0, + 0, 0, 2, 1, + 0, 0, 0, 2] + ) + P, J = m.jordan_form() + assert Jmust == J + + m = Matrix(4, 4, [6, 2, -8, -6, -3, 2, 9, 6, 2, -2, -8, -6, -1, 0, 3, 4]) + # Jmust = Matrix(4, 4, [2, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0, -2]) + # same here see 1456ff + Jmust = Matrix(4, 4, [-2, 0, 0, 0, + 0, 2, 1, 0, + 0, 0, 2, 0, + 0, 0, 0, 2]) + P, J = m.jordan_form() + assert Jmust == J + + m = Matrix(4, 4, [5, 4, 2, 1, 0, 1, -1, -1, -1, -1, 3, 0, 1, 1, -1, 2]) + assert not m.is_diagonalizable() + Jmust = Matrix(4, 4, [1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 4, 1, 0, 0, 0, 4]) + P, J = m.jordan_form() + assert Jmust == J + + # checking for maximum precision to remain unchanged + m = Matrix([[Float('1.0', precision=110), Float('2.0', precision=110)], + [Float('3.14159265358979323846264338327', precision=110), Float('4.0', precision=110)]]) + P, J = m.jordan_form() + for term in J.values(): + if isinstance(term, Float): + assert term._prec == 110 + + +def test_jordan_form_complex_issue_9274(): + A = Matrix([[ 2, 4, 1, 0], + [-4, 2, 0, 1], + [ 0, 0, 2, 4], + [ 0, 0, -4, 2]]) + p = 2 - 4*I + q = 2 + 4*I + Jmust1 = Matrix([[p, 1, 0, 0], + [0, p, 0, 0], + [0, 0, q, 1], + [0, 0, 0, q]]) + Jmust2 = Matrix([[q, 1, 0, 0], + [0, q, 0, 0], + [0, 0, p, 1], + [0, 0, 0, p]]) + P, J = A.jordan_form() + assert J == Jmust1 or J == Jmust2 + assert simplify(P*J*P.inv()) == A + +def test_issue_10220(): + # two non-orthogonal Jordan blocks with eigenvalue 1 + M = Matrix([[1, 0, 0, 1], + [0, 1, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1]]) + P, J = M.jordan_form() + assert P == Matrix([[0, 1, 0, 1], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0]]) + assert J == Matrix([ + [1, 1, 0, 0], + [0, 1, 1, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + +def test_jordan_form_issue_15858(): + A = Matrix([ + [1, 1, 1, 0], + [-2, -1, 0, -1], + [0, 0, -1, -1], + [0, 0, 2, 1]]) + (P, J) = A.jordan_form() + assert P.expand() == Matrix([ + [ -I, -I/2, I, I/2], + [-1 + I, 0, -1 - I, 0], + [ 0, -S(1)/2 - I/2, 0, -S(1)/2 + I/2], + [ 0, 1, 0, 1]]) + assert J == Matrix([ + [-I, 1, 0, 0], + [0, -I, 0, 0], + [0, 0, I, 1], + [0, 0, 0, I]]) + +def test_Matrix_berkowitz_charpoly(): + UA, K_i, K_w = symbols('UA K_i K_w') + + A = Matrix([[-K_i - UA + K_i**2/(K_i + K_w), K_i*K_w/(K_i + K_w)], + [ K_i*K_w/(K_i + K_w), -K_w + K_w**2/(K_i + K_w)]]) + + charpoly = A.charpoly(x) + + assert charpoly == \ + Poly(x**2 + (K_i*UA + K_w*UA + 2*K_i*K_w)/(K_i + K_w)*x + + K_i*K_w*UA/(K_i + K_w), x, domain='ZZ(K_i,K_w,UA)') + + assert type(charpoly) is PurePoly + + A = Matrix([[1, 3], [2, 0]]) + assert A.charpoly() == A.charpoly(x) == PurePoly(x**2 - x - 6) + + A = Matrix([[1, 2], [x, 0]]) + p = A.charpoly(x) + assert p.gen != x + assert p.as_expr().subs(p.gen, x) == x**2 - 3*x + + +def test_exp_jordan_block(): + l = Symbol('lamda') + + m = Matrix.jordan_block(1, l) + assert m._eval_matrix_exp_jblock() == Matrix([[exp(l)]]) + + m = Matrix.jordan_block(3, l) + assert m._eval_matrix_exp_jblock() == \ + Matrix([ + [exp(l), exp(l), exp(l)/2], + [0, exp(l), exp(l)], + [0, 0, exp(l)]]) + + +def test_exp(): + m = Matrix([[3, 4], [0, -2]]) + m_exp = Matrix([[exp(3), -4*exp(-2)/5 + 4*exp(3)/5], [0, exp(-2)]]) + assert m.exp() == m_exp + assert exp(m) == m_exp + + m = Matrix([[1, 0], [0, 1]]) + assert m.exp() == Matrix([[E, 0], [0, E]]) + assert exp(m) == Matrix([[E, 0], [0, E]]) + + m = Matrix([[1, -1], [1, 1]]) + assert m.exp() == Matrix([[E*cos(1), -E*sin(1)], [E*sin(1), E*cos(1)]]) + + +def test_log(): + l = Symbol('lamda') + + m = Matrix.jordan_block(1, l) + assert m._eval_matrix_log_jblock() == Matrix([[log(l)]]) + + m = Matrix.jordan_block(4, l) + assert m._eval_matrix_log_jblock() == \ + Matrix( + [ + [log(l), 1/l, -1/(2*l**2), 1/(3*l**3)], + [0, log(l), 1/l, -1/(2*l**2)], + [0, 0, log(l), 1/l], + [0, 0, 0, log(l)] + ] + ) + + m = Matrix( + [[0, 0, 1], + [0, 0, 0], + [-1, 0, 0]] + ) + raises(MatrixError, lambda: m.log()) + + +def test_has(): + A = Matrix(((x, y), (2, 3))) + assert A.has(x) + assert not A.has(z) + assert A.has(Symbol) + + A = A.subs(x, 2) + assert not A.has(x) + + +def test_find_reasonable_pivot_naive_finds_guaranteed_nonzero1(): + # Test if matrices._find_reasonable_pivot_naive() + # finds a guaranteed non-zero pivot when the + # some of the candidate pivots are symbolic expressions. + # Keyword argument: simpfunc=None indicates that no simplifications + # should be performed during the search. + x = Symbol('x') + column = Matrix(3, 1, [x, cos(x)**2 + sin(x)**2, S.Half]) + pivot_offset, pivot_val, pivot_assumed_nonzero, simplified =\ + _find_reasonable_pivot_naive(column) + assert pivot_val == S.Half + +def test_find_reasonable_pivot_naive_finds_guaranteed_nonzero2(): + # Test if matrices._find_reasonable_pivot_naive() + # finds a guaranteed non-zero pivot when the + # some of the candidate pivots are symbolic expressions. + # Keyword argument: simpfunc=_simplify indicates that the search + # should attempt to simplify candidate pivots. + x = Symbol('x') + column = Matrix(3, 1, + [x, + cos(x)**2+sin(x)**2+x**2, + cos(x)**2+sin(x)**2]) + pivot_offset, pivot_val, pivot_assumed_nonzero, simplified =\ + _find_reasonable_pivot_naive(column, simpfunc=_simplify) + assert pivot_val == 1 + +def test_find_reasonable_pivot_naive_simplifies(): + # Test if matrices._find_reasonable_pivot_naive() + # simplifies candidate pivots, and reports + # their offsets correctly. + x = Symbol('x') + column = Matrix(3, 1, + [x, + cos(x)**2+sin(x)**2+x, + cos(x)**2+sin(x)**2]) + pivot_offset, pivot_val, pivot_assumed_nonzero, simplified =\ + _find_reasonable_pivot_naive(column, simpfunc=_simplify) + + assert len(simplified) == 2 + assert simplified[0][0] == 1 + assert simplified[0][1] == 1+x + assert simplified[1][0] == 2 + assert simplified[1][1] == 1 + +def test_errors(): + raises(ValueError, lambda: Matrix([[1, 2], [1]])) + raises(IndexError, lambda: Matrix([[1, 2]])[1.2, 5]) + raises(IndexError, lambda: Matrix([[1, 2]])[1, 5.2]) + raises(ValueError, lambda: randMatrix(3, c=4, symmetric=True)) + raises(ValueError, lambda: Matrix([1, 2]).reshape(4, 6)) + raises(ShapeError, + lambda: Matrix([[1, 2], [3, 4]]).copyin_matrix([1, 0], Matrix([1, 2]))) + raises(TypeError, lambda: Matrix([[1, 2], [3, 4]]).copyin_list([0, + 1], set())) + raises(NonSquareMatrixError, lambda: Matrix([[1, 2, 3], [2, 3, 0]]).inv()) + raises(ShapeError, + lambda: Matrix(1, 2, [1, 2]).row_join(Matrix([[1, 2], [3, 4]]))) + raises( + ShapeError, lambda: Matrix([1, 2]).col_join(Matrix([[1, 2], [3, 4]]))) + raises(ShapeError, lambda: Matrix([1]).row_insert(1, Matrix([[1, + 2], [3, 4]]))) + raises(ShapeError, lambda: Matrix([1]).col_insert(1, Matrix([[1, + 2], [3, 4]]))) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).trace()) + raises(TypeError, lambda: Matrix([1]).applyfunc(1)) + raises(ValueError, lambda: Matrix([[1, 2], [3, 4]]).minor(4, 5)) + raises(ValueError, lambda: Matrix([[1, 2], [3, 4]]).minor_submatrix(4, 5)) + raises(TypeError, lambda: Matrix([1, 2, 3]).cross(1)) + raises(TypeError, lambda: Matrix([1, 2, 3]).dot(1)) + raises(ShapeError, lambda: Matrix([1, 2, 3]).dot(Matrix([1, 2]))) + raises(ShapeError, lambda: Matrix([1, 2]).dot([])) + raises(TypeError, lambda: Matrix([1, 2]).dot('a')) + raises(ShapeError, lambda: Matrix([1, 2]).dot([1, 2, 3])) + raises(NonSquareMatrixError, lambda: Matrix([1, 2, 3]).exp()) + raises(ShapeError, lambda: Matrix([[1, 2], [3, 4]]).normalized()) + raises(ValueError, lambda: Matrix([1, 2]).inv(method='not a method')) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).inverse_GE()) + raises(ValueError, lambda: Matrix([[1, 2], [1, 2]]).inverse_GE()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).inverse_ADJ()) + raises(ValueError, lambda: Matrix([[1, 2], [1, 2]]).inverse_ADJ()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).inverse_LU()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).is_nilpotent()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).det()) + raises(ValueError, + lambda: Matrix([[1, 2], [3, 4]]).det(method='Not a real method')) + raises(ValueError, + lambda: Matrix([[1, 2, 3, 4], [5, 6, 7, 8], + [9, 10, 11, 12], [13, 14, 15, 16]]).det(iszerofunc="Not function")) + raises(ValueError, + lambda: Matrix([[1, 2, 3, 4], [5, 6, 7, 8], + [9, 10, 11, 12], [13, 14, 15, 16]]).det(iszerofunc=False)) + raises(ValueError, + lambda: hessian(Matrix([[1, 2], [3, 4]]), Matrix([[1, 2], [2, 1]]))) + raises(ValueError, lambda: hessian(Matrix([[1, 2], [3, 4]]), [])) + raises(ValueError, lambda: hessian(Symbol('x')**2, 'a')) + raises(IndexError, lambda: eye(3)[5, 2]) + raises(IndexError, lambda: eye(3)[2, 5]) + M = Matrix(((1, 2, 3, 4), (5, 6, 7, 8), (9, 10, 11, 12), (13, 14, 15, 16))) + raises(ValueError, lambda: M.det('method=LU_decomposition()')) + V = Matrix([[10, 10, 10]]) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(ValueError, lambda: M.row_insert(4.7, V)) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(ValueError, lambda: M.col_insert(-4.2, V)) + +def test_len(): + assert len(Matrix()) == 0 + assert len(Matrix([[1, 2]])) == len(Matrix([[1], [2]])) == 2 + assert len(Matrix(0, 2, lambda i, j: 0)) == \ + len(Matrix(2, 0, lambda i, j: 0)) == 0 + assert len(Matrix([[0, 1, 2], [3, 4, 5]])) == 6 + assert Matrix([1]) == Matrix([[1]]) + assert not Matrix() + assert Matrix() == Matrix([]) + + +def test_integrate(): + A = Matrix(((1, 4, x), (y, 2, 4), (10, 5, x**2))) + assert A.integrate(x) == \ + Matrix(((x, 4*x, x**2/2), (x*y, 2*x, 4*x), (10*x, 5*x, x**3/3))) + assert A.integrate(y) == \ + Matrix(((y, 4*y, x*y), (y**2/2, 2*y, 4*y), (10*y, 5*y, y*x**2))) + + +def test_limit(): + A = Matrix(((1, 4, sin(x)/x), (y, 2, 4), (10, 5, x**2 + 1))) + assert A.limit(x, 0) == Matrix(((1, 4, 1), (y, 2, 4), (10, 5, 1))) + + +def test_diff(): + A = MutableDenseMatrix(((1, 4, x), (y, 2, 4), (10, 5, x**2 + 1))) + assert isinstance(A.diff(x), type(A)) + assert A.diff(x) == MutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert A.diff(y) == MutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + assert diff(A, x) == MutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert diff(A, y) == MutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + A_imm = A.as_immutable() + assert isinstance(A_imm.diff(x), type(A_imm)) + assert A_imm.diff(x) == ImmutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert A_imm.diff(y) == ImmutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + assert diff(A_imm, x) == ImmutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert diff(A_imm, y) == ImmutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + assert A.diff(x, evaluate=False) == ArrayDerivative(A, x, evaluate=False) + assert diff(A, x, evaluate=False) == ArrayDerivative(A, x, evaluate=False) + + +def test_diff_by_matrix(): + + # Derive matrix by matrix: + + A = MutableDenseMatrix([[x, y], [z, t]]) + assert A.diff(A) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + assert diff(A, A) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + A_imm = A.as_immutable() + assert A_imm.diff(A_imm) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + assert diff(A_imm, A_imm) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + # Derive a constant matrix: + assert A.diff(a) == MutableDenseMatrix([[0, 0], [0, 0]]) + + B = ImmutableDenseMatrix([a, b]) + assert A.diff(B) == Array.zeros(2, 1, 2, 2) + assert A.diff(A) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + # Test diff with tuples: + + dB = B.diff([[a, b]]) + assert dB.shape == (2, 2, 1) + assert dB == Array([[[1], [0]], [[0], [1]]]) + + f = Function("f") + fxyz = f(x, y, z) + assert fxyz.diff([[x, y, z]]) == Array([fxyz.diff(x), fxyz.diff(y), fxyz.diff(z)]) + assert fxyz.diff(([x, y, z], 2)) == Array([ + [fxyz.diff(x, 2), fxyz.diff(x, y), fxyz.diff(x, z)], + [fxyz.diff(x, y), fxyz.diff(y, 2), fxyz.diff(y, z)], + [fxyz.diff(x, z), fxyz.diff(z, y), fxyz.diff(z, 2)], + ]) + + expr = sin(x)*exp(y) + assert expr.diff([[x, y]]) == Array([cos(x)*exp(y), sin(x)*exp(y)]) + assert expr.diff(y, ((x, y),)) == Array([cos(x)*exp(y), sin(x)*exp(y)]) + assert expr.diff(x, ((x, y),)) == Array([-sin(x)*exp(y), cos(x)*exp(y)]) + assert expr.diff(((y, x),), [[x, y]]) == Array([[cos(x)*exp(y), -sin(x)*exp(y)], [sin(x)*exp(y), cos(x)*exp(y)]]) + + # Test different notations: + + assert fxyz.diff(x).diff(y).diff(x) == fxyz.diff(((x, y, z),), 3)[0, 1, 0] + assert fxyz.diff(z).diff(y).diff(x) == fxyz.diff(((x, y, z),), 3)[2, 1, 0] + assert fxyz.diff([[x, y, z]], ((z, y, x),)) == Array([[fxyz.diff(i).diff(j) for i in (x, y, z)] for j in (z, y, x)]) + + # Test scalar derived by matrix remains matrix: + res = x.diff(Matrix([[x, y]])) + assert isinstance(res, ImmutableDenseMatrix) + assert res == Matrix([[1, 0]]) + res = (x**3).diff(Matrix([[x, y]])) + assert isinstance(res, ImmutableDenseMatrix) + assert res == Matrix([[3*x**2, 0]]) + + +def test_getattr(): + A = Matrix(((1, 4, x), (y, 2, 4), (10, 5, x**2 + 1))) + raises(AttributeError, lambda: A.nonexistantattribute) + assert getattr(A, 'diff')(x) == Matrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + + +def test_hessenberg(): + A = Matrix([[3, 4, 1], [2, 4, 5], [0, 1, 2]]) + assert A.is_upper_hessenberg + A = A.T + assert A.is_lower_hessenberg + A[0, -1] = 1 + assert A.is_lower_hessenberg is False + + A = Matrix([[3, 4, 1], [2, 4, 5], [3, 1, 2]]) + assert not A.is_upper_hessenberg + + A = zeros(5, 2) + assert A.is_upper_hessenberg + + +def test_cholesky(): + raises(NonSquareMatrixError, lambda: Matrix((1, 2)).cholesky()) + raises(ValueError, lambda: Matrix(((1, 2), (3, 4))).cholesky()) + raises(ValueError, lambda: Matrix(((5 + I, 0), (0, 1))).cholesky()) + raises(ValueError, lambda: Matrix(((1, 5), (5, 1))).cholesky()) + raises(ValueError, lambda: Matrix(((1, 2), (3, 4))).cholesky(hermitian=False)) + assert Matrix(((5 + I, 0), (0, 1))).cholesky(hermitian=False) == Matrix([ + [sqrt(5 + I), 0], [0, 1]]) + A = Matrix(((1, 5), (5, 1))) + L = A.cholesky(hermitian=False) + assert L == Matrix([[1, 0], [5, 2*sqrt(6)*I]]) + assert L*L.T == A + A = Matrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L = A.cholesky() + assert L * L.T == A + assert L.is_lower + assert L == Matrix([[5, 0, 0], [3, 3, 0], [-1, 1, 3]]) + A = Matrix(((4, -2*I, 2 + 2*I), (2*I, 2, -1 + I), (2 - 2*I, -1 - I, 11))) + assert A.cholesky().expand() == Matrix(((2, 0, 0), (I, 1, 0), (1 - I, 0, 3))) + + raises(NonSquareMatrixError, lambda: SparseMatrix((1, 2)).cholesky()) + raises(ValueError, lambda: SparseMatrix(((1, 2), (3, 4))).cholesky()) + raises(ValueError, lambda: SparseMatrix(((5 + I, 0), (0, 1))).cholesky()) + raises(ValueError, lambda: SparseMatrix(((1, 5), (5, 1))).cholesky()) + raises(ValueError, lambda: SparseMatrix(((1, 2), (3, 4))).cholesky(hermitian=False)) + assert SparseMatrix(((5 + I, 0), (0, 1))).cholesky(hermitian=False) == Matrix([ + [sqrt(5 + I), 0], [0, 1]]) + A = SparseMatrix(((1, 5), (5, 1))) + L = A.cholesky(hermitian=False) + assert L == Matrix([[1, 0], [5, 2*sqrt(6)*I]]) + assert L*L.T == A + A = SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L = A.cholesky() + assert L * L.T == A + assert L.is_lower + assert L == Matrix([[5, 0, 0], [3, 3, 0], [-1, 1, 3]]) + A = SparseMatrix(((4, -2*I, 2 + 2*I), (2*I, 2, -1 + I), (2 - 2*I, -1 - I, 11))) + assert A.cholesky() == Matrix(((2, 0, 0), (I, 1, 0), (1 - I, 0, 3))) + + +def test_matrix_norm(): + # Vector Tests + # Test columns and symbols + x = Symbol('x', real=True) + v = Matrix([cos(x), sin(x)]) + assert trigsimp(v.norm(2)) == 1 + assert v.norm(10) == Pow(cos(x)**10 + sin(x)**10, Rational(1, 10)) + + # Test Rows + A = Matrix([[5, Rational(3, 2)]]) + assert A.norm() == Pow(25 + Rational(9, 4), S.Half) + assert A.norm(oo) == max(A) + assert A.norm(-oo) == min(A) + + # Matrix Tests + # Intuitive test + A = Matrix([[1, 1], [1, 1]]) + assert A.norm(2) == 2 + assert A.norm(-2) == 0 + assert A.norm('frobenius') == 2 + assert eye(10).norm(2) == eye(10).norm(-2) == 1 + assert A.norm(oo) == 2 + + # Test with Symbols and more complex entries + A = Matrix([[3, y, y], [x, S.Half, -pi]]) + assert (A.norm('fro') + == sqrt(Rational(37, 4) + 2*abs(y)**2 + pi**2 + x**2)) + + # Check non-square + A = Matrix([[1, 2, -3], [4, 5, Rational(13, 2)]]) + assert A.norm(2) == sqrt(Rational(389, 8) + sqrt(78665)/8) + assert A.norm(-2) is S.Zero + assert A.norm('frobenius') == sqrt(389)/2 + + # Test properties of matrix norms + # https://en.wikipedia.org/wiki/Matrix_norm#Definition + # Two matrices + A = Matrix([[1, 2], [3, 4]]) + B = Matrix([[5, 5], [-2, 2]]) + C = Matrix([[0, -I], [I, 0]]) + D = Matrix([[1, 0], [0, -1]]) + L = [A, B, C, D] + alpha = Symbol('alpha', real=True) + + for order in ['fro', 2, -2]: + # Zero Check + assert zeros(3).norm(order) is S.Zero + # Check Triangle Inequality for all Pairs of Matrices + for X in L: + for Y in L: + dif = (X.norm(order) + Y.norm(order) - + (X + Y).norm(order)) + assert (dif >= 0) + # Scalar multiplication linearity + for M in [A, B, C, D]: + dif = simplify((alpha*M).norm(order) - + abs(alpha) * M.norm(order)) + assert dif == 0 + + # Test Properties of Vector Norms + # https://en.wikipedia.org/wiki/Vector_norm + # Two column vectors + a = Matrix([1, 1 - 1*I, -3]) + b = Matrix([S.Half, 1*I, 1]) + c = Matrix([-1, -1, -1]) + d = Matrix([3, 2, I]) + e = Matrix([Integer(1e2), Rational(1, 1e2), 1]) + L = [a, b, c, d, e] + alpha = Symbol('alpha', real=True) + + for order in [1, 2, -1, -2, S.Infinity, S.NegativeInfinity, pi]: + # Zero Check + if order > 0: + assert Matrix([0, 0, 0]).norm(order) is S.Zero + # Triangle inequality on all pairs + if order >= 1: # Triangle InEq holds only for these norms + for X in L: + for Y in L: + dif = (X.norm(order) + Y.norm(order) - + (X + Y).norm(order)) + assert simplify(dif >= 0) is S.true + # Linear to scalar multiplication + if order in [1, 2, -1, -2, S.Infinity, S.NegativeInfinity]: + for X in L: + dif = simplify((alpha*X).norm(order) - + (abs(alpha) * X.norm(order))) + assert dif == 0 + + # ord=1 + M = Matrix(3, 3, [1, 3, 0, -2, -1, 0, 3, 9, 6]) + assert M.norm(1) == 13 + + +def test_condition_number(): + x = Symbol('x', real=True) + A = eye(3) + A[0, 0] = 10 + A[2, 2] = Rational(1, 10) + assert A.condition_number() == 100 + + A[1, 1] = x + assert A.condition_number() == Max(10, Abs(x)) / Min(Rational(1, 10), Abs(x)) + + M = Matrix([[cos(x), sin(x)], [-sin(x), cos(x)]]) + Mc = M.condition_number() + assert all(Float(1.).epsilon_eq(Mc.subs(x, val).evalf()) for val in + [Rational(1, 5), S.Half, Rational(1, 10), pi/2, pi, pi*Rational(7, 4) ]) + + #issue 10782 + assert Matrix([]).condition_number() == 0 + + +def test_equality(): + A = Matrix(((1, 2, 3), (4, 5, 6), (7, 8, 9))) + B = Matrix(((9, 8, 7), (6, 5, 4), (3, 2, 1))) + assert A == A[:, :] + assert not A != A[:, :] + assert not A == B + assert A != B + assert A != 10 + assert not A == 10 + + # A SparseMatrix can be equal to a Matrix + C = SparseMatrix(((1, 0, 0), (0, 1, 0), (0, 0, 1))) + D = Matrix(((1, 0, 0), (0, 1, 0), (0, 0, 1))) + assert C == D + assert not C != D + + +def test_col_join(): + assert eye(3).col_join(Matrix([[7, 7, 7]])) == \ + Matrix([[1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [7, 7, 7]]) + + +def test_row_insert(): + r4 = Matrix([[4, 4, 4]]) + for i in range(-4, 5): + l = [1, 0, 0] + l.insert(i, 4) + assert flatten(eye(3).row_insert(i, r4).col(0).tolist()) == l + + +def test_col_insert(): + c4 = Matrix([4, 4, 4]) + for i in range(-4, 5): + l = [0, 0, 0] + l.insert(i, 4) + assert flatten(zeros(3).col_insert(i, c4).row(0).tolist()) == l + + +def test_normalized(): + assert Matrix([3, 4]).normalized() == \ + Matrix([Rational(3, 5), Rational(4, 5)]) + + # Zero vector trivial cases + assert Matrix([0, 0, 0]).normalized() == Matrix([0, 0, 0]) + + # Machine precision error truncation trivial cases + m = Matrix([0,0,1.e-100]) + assert m.normalized( + iszerofunc=lambda x: x.evalf(n=10, chop=True).is_zero + ) == Matrix([0, 0, 0]) + + +def test_print_nonzero(): + assert capture(lambda: eye(3).print_nonzero()) == \ + '[X ]\n[ X ]\n[ X]\n' + assert capture(lambda: eye(3).print_nonzero('.')) == \ + '[. ]\n[ . ]\n[ .]\n' + + +def test_zeros_eye(): + assert Matrix.eye(3) == eye(3) + assert Matrix.zeros(3) == zeros(3) + assert ones(3, 4) == Matrix(3, 4, [1]*12) + + i = Matrix([[1, 0], [0, 1]]) + z = Matrix([[0, 0], [0, 0]]) + for cls in classes: + m = cls.eye(2) + assert i == m # but m == i will fail if m is immutable + assert i == eye(2, cls=cls) + assert type(m) == cls + m = cls.zeros(2) + assert z == m + assert z == zeros(2, cls=cls) + assert type(m) == cls + + +def test_is_zero(): + assert Matrix().is_zero_matrix + assert Matrix([[0, 0], [0, 0]]).is_zero_matrix + assert zeros(3, 4).is_zero_matrix + assert not eye(3).is_zero_matrix + assert Matrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert SparseMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert ImmutableMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert ImmutableSparseMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert Matrix([[x, 1], [0, 0]]).is_zero_matrix == False + a = Symbol('a', nonzero=True) + assert Matrix([[a, 0], [0, 0]]).is_zero_matrix == False + + +def test_rotation_matrices(): + # This tests the rotation matrices by rotating about an axis and back. + theta = pi/3 + r3_plus = rot_axis3(theta) + r3_minus = rot_axis3(-theta) + r2_plus = rot_axis2(theta) + r2_minus = rot_axis2(-theta) + r1_plus = rot_axis1(theta) + r1_minus = rot_axis1(-theta) + assert r3_minus*r3_plus*eye(3) == eye(3) + assert r2_minus*r2_plus*eye(3) == eye(3) + assert r1_minus*r1_plus*eye(3) == eye(3) + + # Check the correctness of the trace of the rotation matrix + assert r1_plus.trace() == 1 + 2*cos(theta) + assert r2_plus.trace() == 1 + 2*cos(theta) + assert r3_plus.trace() == 1 + 2*cos(theta) + + # Check that a rotation with zero angle doesn't change anything. + assert rot_axis1(0) == eye(3) + assert rot_axis2(0) == eye(3) + assert rot_axis3(0) == eye(3) + + # Check left-hand convention + # see Issue #24529 + q1 = Quaternion.from_axis_angle([1, 0, 0], pi / 2) + q2 = Quaternion.from_axis_angle([0, 1, 0], pi / 2) + q3 = Quaternion.from_axis_angle([0, 0, 1], pi / 2) + assert rot_axis1(- pi / 2) == q1.to_rotation_matrix() + assert rot_axis2(- pi / 2) == q2.to_rotation_matrix() + assert rot_axis3(- pi / 2) == q3.to_rotation_matrix() + # Check right-hand convention + assert rot_ccw_axis1(+ pi / 2) == q1.to_rotation_matrix() + assert rot_ccw_axis2(+ pi / 2) == q2.to_rotation_matrix() + assert rot_ccw_axis3(+ pi / 2) == q3.to_rotation_matrix() + + +def test_DeferredVector(): + assert str(DeferredVector("vector")[4]) == "vector[4]" + assert sympify(DeferredVector("d")) == DeferredVector("d") + raises(IndexError, lambda: DeferredVector("d")[-1]) + assert str(DeferredVector("d")) == "d" + assert repr(DeferredVector("test")) == "DeferredVector('test')" + +def test_DeferredVector_not_iterable(): + assert not iterable(DeferredVector('X')) + +def test_DeferredVector_Matrix(): + raises(TypeError, lambda: Matrix(DeferredVector("V"))) + +def test_GramSchmidt(): + R = Rational + m1 = Matrix(1, 2, [1, 2]) + m2 = Matrix(1, 2, [2, 3]) + assert GramSchmidt([m1, m2]) == \ + [Matrix(1, 2, [1, 2]), Matrix(1, 2, [R(2)/5, R(-1)/5])] + assert GramSchmidt([m1.T, m2.T]) == \ + [Matrix(2, 1, [1, 2]), Matrix(2, 1, [R(2)/5, R(-1)/5])] + # from wikipedia + assert GramSchmidt([Matrix([3, 1]), Matrix([2, 2])], True) == [ + Matrix([3*sqrt(10)/10, sqrt(10)/10]), + Matrix([-sqrt(10)/10, 3*sqrt(10)/10])] + # https://github.com/sympy/sympy/issues/9488 + L = FiniteSet(Matrix([1])) + assert GramSchmidt(L) == [Matrix([[1]])] + + +def test_casoratian(): + assert casoratian([1, 2, 3, 4], 1) == 0 + assert casoratian([1, 2, 3, 4], 1, zero=False) == 0 + + +def test_zero_dimension_multiply(): + assert (Matrix()*zeros(0, 3)).shape == (0, 3) + assert zeros(3, 0)*zeros(0, 3) == zeros(3, 3) + assert zeros(0, 3)*zeros(3, 0) == Matrix() + + +def test_slice_issue_2884(): + m = Matrix(2, 2, range(4)) + assert m[1, :] == Matrix([[2, 3]]) + assert m[-1, :] == Matrix([[2, 3]]) + assert m[:, 1] == Matrix([[1, 3]]).T + assert m[:, -1] == Matrix([[1, 3]]).T + raises(IndexError, lambda: m[2, :]) + raises(IndexError, lambda: m[2, 2]) + + +def test_slice_issue_3401(): + assert zeros(0, 3)[:, -1].shape == (0, 1) + assert zeros(3, 0)[0, :] == Matrix(1, 0, []) + + +def test_copyin(): + s = zeros(3, 3) + s[3] = 1 + assert s[:, 0] == Matrix([0, 1, 0]) + assert s[3] == 1 + assert s[3: 4] == [1] + s[1, 1] = 42 + assert s[1, 1] == 42 + assert s[1, 1:] == Matrix([[42, 0]]) + s[1, 1:] = Matrix([[5, 6]]) + assert s[1, :] == Matrix([[1, 5, 6]]) + s[1, 1:] = [[42, 43]] + assert s[1, :] == Matrix([[1, 42, 43]]) + s[0, 0] = 17 + assert s[:, :1] == Matrix([17, 1, 0]) + s[0, 0] = [1, 1, 1] + assert s[:, 0] == Matrix([1, 1, 1]) + s[0, 0] = Matrix([1, 1, 1]) + assert s[:, 0] == Matrix([1, 1, 1]) + s[0, 0] = SparseMatrix([1, 1, 1]) + assert s[:, 0] == Matrix([1, 1, 1]) + + +def test_invertible_check(): + # sometimes a singular matrix will have a pivot vector shorter than + # the number of rows in a matrix... + assert Matrix([[1, 2], [1, 2]]).rref() == (Matrix([[1, 2], [0, 0]]), (0,)) + raises(ValueError, lambda: Matrix([[1, 2], [1, 2]]).inv()) + m = Matrix([ + [-1, -1, 0], + [ x, 1, 1], + [ 1, x, -1], + ]) + assert len(m.rref()[1]) != m.rows + # in addition, unless simplify=True in the call to rref, the identity + # matrix will be returned even though m is not invertible + assert m.rref()[0] != eye(3) + assert m.rref(simplify=signsimp)[0] != eye(3) + raises(ValueError, lambda: m.inv(method="ADJ")) + raises(ValueError, lambda: m.inv(method="GE")) + raises(ValueError, lambda: m.inv(method="LU")) + + +def test_issue_3959(): + x, y = symbols('x, y') + e = x*y + assert e.subs(x, Matrix([3, 5, 3])) == Matrix([3, 5, 3])*y + + +def test_issue_5964(): + assert str(Matrix([[1, 2], [3, 4]])) == 'Matrix([[1, 2], [3, 4]])' + + +def test_issue_7604(): + x, y = symbols("x y") + assert sstr(Matrix([[x, 2*y], [y**2, x + 3]])) == \ + 'Matrix([\n[ x, 2*y],\n[y**2, x + 3]])' + + +def test_is_Identity(): + assert eye(3).is_Identity + assert eye(3).as_immutable().is_Identity + assert not zeros(3).is_Identity + assert not ones(3).is_Identity + # issue 6242 + assert not Matrix([[1, 0, 0]]).is_Identity + # issue 8854 + assert SparseMatrix(3,3, {(0,0):1, (1,1):1, (2,2):1}).is_Identity + assert not SparseMatrix(2,3, range(6)).is_Identity + assert not SparseMatrix(3,3, {(0,0):1, (1,1):1}).is_Identity + assert not SparseMatrix(3,3, {(0,0):1, (1,1):1, (2,2):1, (0,1):2, (0,2):3}).is_Identity + + +def test_dot(): + assert ones(1, 3).dot(ones(3, 1)) == 3 + assert ones(1, 3).dot([1, 1, 1]) == 3 + assert Matrix([1, 2, 3]).dot(Matrix([1, 2, 3])) == 14 + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I])) == -5 + I + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I]), hermitian=False) == -5 + I + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I]), hermitian=True) == 13 + I + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I]), hermitian=True, conjugate_convention="physics") == 13 - I + assert Matrix([1, 2, 3*I]).dot(Matrix([4, 5*I, 6]), hermitian=True, conjugate_convention="right") == 4 + 8*I + assert Matrix([1, 2, 3*I]).dot(Matrix([4, 5*I, 6]), hermitian=True, conjugate_convention="left") == 4 - 8*I + assert Matrix([I, 2*I]).dot(Matrix([I, 2*I]), hermitian=False, conjugate_convention="left") == -5 + assert Matrix([I, 2*I]).dot(Matrix([I, 2*I]), conjugate_convention="left") == 5 + raises(ValueError, lambda: Matrix([1, 2]).dot(Matrix([3, 4]), hermitian=True, conjugate_convention="test")) + + +def test_dual(): + B_x, B_y, B_z, E_x, E_y, E_z = symbols( + 'B_x B_y B_z E_x E_y E_z', real=True) + F = Matrix(( + ( 0, E_x, E_y, E_z), + (-E_x, 0, B_z, -B_y), + (-E_y, -B_z, 0, B_x), + (-E_z, B_y, -B_x, 0) + )) + Fd = Matrix(( + ( 0, -B_x, -B_y, -B_z), + (B_x, 0, E_z, -E_y), + (B_y, -E_z, 0, E_x), + (B_z, E_y, -E_x, 0) + )) + assert F.dual().equals(Fd) + assert eye(3).dual().equals(zeros(3)) + assert F.dual().dual().equals(-F) + + +def test_anti_symmetric(): + assert Matrix([1, 2]).is_anti_symmetric() is False + m = Matrix(3, 3, [0, x**2 + 2*x + 1, y, -(x + 1)**2, 0, x*y, -y, -x*y, 0]) + assert m.is_anti_symmetric() is True + assert m.is_anti_symmetric(simplify=False) is None + assert m.is_anti_symmetric(simplify=lambda x: x) is None + + # tweak to fail + m[2, 1] = -m[2, 1] + assert m.is_anti_symmetric() is None + # untweak + m[2, 1] = -m[2, 1] + + m = m.expand() + assert m.is_anti_symmetric(simplify=False) is True + m[0, 0] = 1 + assert m.is_anti_symmetric() is False + + +def test_normalize_sort_diogonalization(): + A = Matrix(((1, 2), (2, 1))) + P, Q = A.diagonalize(normalize=True) + assert P*P.T == P.T*P == eye(P.cols) + P, Q = A.diagonalize(normalize=True, sort=True) + assert P*P.T == P.T*P == eye(P.cols) + assert P*Q*P.inv() == A + + +def test_issue_5321(): + raises(ValueError, lambda: Matrix([[1, 2, 3], Matrix(0, 1, [])])) + + +def test_issue_5320(): + assert Matrix.hstack(eye(2), 2*eye(2)) == Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2] + ]) + assert Matrix.vstack(eye(2), 2*eye(2)) == Matrix([ + [1, 0], + [0, 1], + [2, 0], + [0, 2] + ]) + cls = SparseMatrix + assert cls.hstack(cls(eye(2)), cls(2*eye(2))) == Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2] + ]) + +def test_issue_11944(): + A = Matrix([[1]]) + AIm = sympify(A) + assert Matrix.hstack(AIm, A) == Matrix([[1, 1]]) + assert Matrix.vstack(AIm, A) == Matrix([[1], [1]]) + +def test_cross(): + a = [1, 2, 3] + b = [3, 4, 5] + col = Matrix([-2, 4, -2]) + row = col.T + + def test(M, ans): + assert ans == M + assert type(M) == cls + for cls in classes: + A = cls(a) + B = cls(b) + test(A.cross(B), col) + test(A.cross(B.T), col) + test(A.T.cross(B.T), row) + test(A.T.cross(B), row) + raises(ShapeError, lambda: + Matrix(1, 2, [1, 1]).cross(Matrix(1, 2, [1, 1]))) + +def test_hat_vee(): + v1 = Matrix([x, y, z]) + v2 = Matrix([a, b, c]) + assert v1.hat() * v2 == v1.cross(v2) + assert v1.hat().is_anti_symmetric() + assert v1.hat().vee() == v1 + +def test_hash(): + for cls in classes[-2:]: + s = {cls.eye(1), cls.eye(1)} + assert len(s) == 1 and s.pop() == cls.eye(1) + # issue 3979 + for cls in classes[:2]: + assert not isinstance(cls.eye(1), Hashable) + + +@XFAIL +def test_issue_3979(): + # when this passes, delete this and change the [1:2] + # to [:2] in the test_hash above for issue 3979 + cls = classes[0] + raises(AttributeError, lambda: hash(cls.eye(1))) + + +def test_adjoint(): + dat = [[0, I], [1, 0]] + ans = Matrix([[0, 1], [-I, 0]]) + for cls in classes: + assert ans == cls(dat).adjoint() + + +def test_adjoint_with_operator(): + # Regression test for issue 25130: adjoint() should propagate to operators + import sympy.physics.quantum + a = sympy.physics.quantum.operator.Operator('a') + a_dag = sympy.physics.quantum.Dagger(a) + dat = [[0, I * a], [0, a_dag]] + ans = Matrix([[0, 0], [-I * a_dag, a]]) + for cls in classes: + assert ans == cls(dat).adjoint() + + +def test_simplify_immutable(): + assert simplify(ImmutableMatrix([[sin(x)**2 + cos(x)**2]])) == \ + ImmutableMatrix([[1]]) + +def test_replace(): + F, G = symbols('F, G', cls=Function) + K = Matrix(2, 2, lambda i, j: G(i+j)) + M = Matrix(2, 2, lambda i, j: F(i+j)) + N = M.replace(F, G) + assert N == K + + +def test_atoms(): + m = Matrix([[1, 2], [x, 1 - 1/x]]) + assert m.atoms() == {S.One,S(2),S.NegativeOne, x} + assert m.atoms(Symbol) == {x} + + +def test_pinv(): + # Pseudoinverse of an invertible matrix is the inverse. + A1 = Matrix([[a, b], [c, d]]) + assert simplify(A1.pinv(method="RD")) == simplify(A1.inv()) + + # Test the four properties of the pseudoinverse for various matrices. + As = [Matrix([[13, 104], [2212, 3], [-3, 5]]), + Matrix([[1, 7, 9], [11, 17, 19]]), + Matrix([a, b])] + + for A in As: + A_pinv = A.pinv(method="RD") + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + + # XXX Pinv with diagonalization makes expression too complicated. + for A in As: + A_pinv = simplify(A.pinv(method="ED")) + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + + # XXX Computing pinv using diagonalization makes an expression that + # is too complicated to simplify. + # A1 = Matrix([[a, b], [c, d]]) + # assert simplify(A1.pinv(method="ED")) == simplify(A1.inv()) + # so this is tested numerically at a fixed random point + + from sympy.core.numbers import comp + q = A1.pinv(method="ED") + w = A1.inv() + reps = {a: -73633, b: 11362, c: 55486, d: 62570} + assert all( + comp(i.n(), j.n()) + for i, j in zip(q.subs(reps), w.subs(reps)) + ) + + +@slow +def test_pinv_rank_deficient_when_diagonalization_fails(): + # Test the four properties of the pseudoinverse for matrices when + # diagonalization of A.H*A fails. + As = [ + Matrix([ + [61, 89, 55, 20, 71, 0], + [62, 96, 85, 85, 16, 0], + [69, 56, 17, 4, 54, 0], + [10, 54, 91, 41, 71, 0], + [ 7, 30, 10, 48, 90, 0], + [0, 0, 0, 0, 0, 0]]) + ] + for A in As: + A_pinv = A.pinv(method="ED") + AAp = A * A_pinv + ApA = A_pinv * A + assert AAp.H == AAp + + # Here ApA.H and ApA are equivalent expressions but they are very + # complicated expressions involving RootOfs. Using simplify would be + # too slow and so would evalf so we substitute approximate values for + # the RootOfs and then evalf which is less accurate but good enough to + # confirm that these two matrices are equivalent. + # + # assert ApA.H == ApA # <--- would fail (structural equality) + # assert simplify(ApA.H - ApA).is_zero_matrix # <--- too slow + # (ApA.H - ApA).evalf() # <--- too slow + + def allclose(M1, M2): + rootofs = M1.atoms(RootOf) + rootofs_approx = {r: r.evalf() for r in rootofs} + diff_approx = (M1 - M2).xreplace(rootofs_approx).evalf() + return all(abs(e) < 1e-10 for e in diff_approx) + + assert allclose(ApA.H, ApA) + + +def test_issue_7201(): + assert ones(0, 1) + ones(0, 1) == Matrix(0, 1, []) + assert ones(1, 0) + ones(1, 0) == Matrix(1, 0, []) + +def test_free_symbols(): + for M in ImmutableMatrix, ImmutableSparseMatrix, Matrix, SparseMatrix: + assert M([[x], [0]]).free_symbols == {x} + +def test_from_ndarray(): + """See issue 7465.""" + try: + from numpy import array + except ImportError: + skip('NumPy must be available to test creating matrices from ndarrays') + + assert Matrix(array([1, 2, 3])) == Matrix([1, 2, 3]) + assert Matrix(array([[1, 2, 3]])) == Matrix([[1, 2, 3]]) + assert Matrix(array([[1, 2, 3], [4, 5, 6]])) == \ + Matrix([[1, 2, 3], [4, 5, 6]]) + assert Matrix(array([x, y, z])) == Matrix([x, y, z]) + raises(NotImplementedError, + lambda: Matrix(array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))) + assert Matrix([array([1, 2]), array([3, 4])]) == Matrix([[1, 2], [3, 4]]) + assert Matrix([array([1, 2]), [3, 4]]) == Matrix([[1, 2], [3, 4]]) + assert Matrix([array([]), array([])]) == Matrix(2, 0, []) != Matrix(0, 0, []) + +def test_17522_numpy(): + from sympy.matrices.common import _matrixify + try: + from numpy import array, matrix + except ImportError: + skip('NumPy must be available to test indexing matrixified NumPy ndarrays and matrices') + + m = _matrixify(array([[1, 2], [3, 4]])) + assert m[3] == 4 + assert list(m) == [1, 2, 3, 4] + + with ignore_warnings(PendingDeprecationWarning): + m = _matrixify(matrix([[1, 2], [3, 4]])) + assert m[3] == 4 + assert list(m) == [1, 2, 3, 4] + +def test_17522_mpmath(): + from sympy.matrices.common import _matrixify + try: + from mpmath import matrix + except ImportError: + skip('mpmath must be available to test indexing matrixified mpmath matrices') + + m = _matrixify(matrix([[1, 2], [3, 4]])) + assert m[3] == 4.0 + assert list(m) == [1.0, 2.0, 3.0, 4.0] + +def test_17522_scipy(): + from sympy.matrices.common import _matrixify + try: + from scipy.sparse import csr_matrix + except ImportError: + skip('SciPy must be available to test indexing matrixified SciPy sparse matrices') + + m = _matrixify(csr_matrix([[1, 2], [3, 4]])) + assert m[3] == 4 + assert list(m) == [1, 2, 3, 4] + +def test_hermitian(): + a = Matrix([[1, I], [-I, 1]]) + assert a.is_hermitian + a[0, 0] = 2*I + assert a.is_hermitian is False + a[0, 0] = x + assert a.is_hermitian is None + a[0, 1] = a[1, 0]*I + assert a.is_hermitian is False + b = HermitianOperator("b") + c = Operator("c") + assert Matrix([[b]]).is_hermitian is True + assert Matrix([[b, c], [Dagger(c), b]]).is_hermitian is True + assert Matrix([[b, c], [c, b]]).is_hermitian is False + assert Matrix([[b, c], [transpose(c), b]]).is_hermitian is False + +def test_doit(): + a = Matrix([[Add(x,x, evaluate=False)]]) + assert a[0] != 2*x + assert a.doit() == Matrix([[2*x]]) + +def test_issue_9457_9467_9876(): + # for row_del(index) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + M.row_del(1) + assert M == Matrix([[1, 2, 3], [3, 4, 5]]) + N = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + N.row_del(-2) + assert N == Matrix([[1, 2, 3], [3, 4, 5]]) + O = Matrix([[1, 2, 3], [5, 6, 7], [9, 10, 11]]) + O.row_del(-1) + assert O == Matrix([[1, 2, 3], [5, 6, 7]]) + P = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: P.row_del(10)) + Q = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: Q.row_del(-10)) + + # for col_del(index) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + M.col_del(1) + assert M == Matrix([[1, 3], [2, 4], [3, 5]]) + N = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + N.col_del(-2) + assert N == Matrix([[1, 3], [2, 4], [3, 5]]) + P = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: P.col_del(10)) + Q = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: Q.col_del(-10)) + +def test_issue_9422(): + x, y = symbols('x y', commutative=False) + a, b = symbols('a b') + M = eye(2) + M1 = Matrix(2, 2, [x, y, y, z]) + assert y*x*M != x*y*M + assert b*a*M == a*b*M + assert x*M1 != M1*x + assert a*M1 == M1*a + assert y*x*M == Matrix([[y*x, 0], [0, y*x]]) + + +def test_issue_10770(): + M = Matrix([]) + a = ['col_insert', 'row_join'], Matrix([9, 6, 3]) + b = ['row_insert', 'col_join'], a[1].T + c = ['row_insert', 'col_insert'], Matrix([[1, 2], [3, 4]]) + for ops, m in (a, b, c): + for op in ops: + f = getattr(M, op) + new = f(m) if 'join' in op else f(42, m) + assert new == m and id(new) != id(m) + + +def test_issue_10658(): + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert A.extract([0, 1, 2], [True, True, False]) == \ + Matrix([[1, 2], [4, 5], [7, 8]]) + assert A.extract([0, 1, 2], [True, False, False]) == Matrix([[1], [4], [7]]) + assert A.extract([True, False, False], [0, 1, 2]) == Matrix([[1, 2, 3]]) + assert A.extract([True, False, True], [0, 1, 2]) == \ + Matrix([[1, 2, 3], [7, 8, 9]]) + assert A.extract([0, 1, 2], [False, False, False]) == Matrix(3, 0, []) + assert A.extract([False, False, False], [0, 1, 2]) == Matrix(0, 3, []) + assert A.extract([True, False, True], [False, True, False]) == \ + Matrix([[2], [8]]) + +def test_opportunistic_simplification(): + # this test relates to issue #10718, #9480, #11434 + + # issue #9480 + m = Matrix([[-5 + 5*sqrt(2), -5], [-5*sqrt(2)/2 + 5, -5*sqrt(2)/2]]) + assert m.rank() == 1 + + # issue #10781 + m = Matrix([[3+3*sqrt(3)*I, -9],[4,-3+3*sqrt(3)*I]]) + assert simplify(m.rref()[0] - Matrix([[1, -9/(3 + 3*sqrt(3)*I)], [0, 0]])) == zeros(2, 2) + + # issue #11434 + ax,ay,bx,by,cx,cy,dx,dy,ex,ey,t0,t1 = symbols('a_x a_y b_x b_y c_x c_y d_x d_y e_x e_y t_0 t_1') + m = Matrix([[ax,ay,ax*t0,ay*t0,0],[bx,by,bx*t0,by*t0,0],[cx,cy,cx*t0,cy*t0,1],[dx,dy,dx*t0,dy*t0,1],[ex,ey,2*ex*t1-ex*t0,2*ey*t1-ey*t0,0]]) + assert m.rank() == 4 + +def test_partial_pivoting(): + # example from https://en.wikipedia.org/wiki/Pivot_element + # partial pivoting with back substitution gives a perfect result + # naive pivoting give an error ~1e-13, so anything better than + # 1e-15 is good + mm=Matrix([[0.003, 59.14, 59.17], [5.291, -6.13, 46.78]]) + assert (mm.rref()[0] - Matrix([[1.0, 0, 10.0], + [ 0, 1.0, 1.0]])).norm() < 1e-15 + + # issue #11549 + m_mixed = Matrix([[6e-17, 1.0, 4], + [ -1.0, 0, 8], + [ 0, 0, 1]]) + m_float = Matrix([[6e-17, 1.0, 4.], + [ -1.0, 0., 8.], + [ 0., 0., 1.]]) + m_inv = Matrix([[ 0, -1.0, 8.0], + [1.0, 6.0e-17, -4.0], + [ 0, 0, 1]]) + # this example is numerically unstable and involves a matrix with a norm >= 8, + # this comparing the difference of the results with 1e-15 is numerically sound. + assert (m_mixed.inv() - m_inv).norm() < 1e-15 + assert (m_float.inv() - m_inv).norm() < 1e-15 + +def test_iszero_substitution(): + """ When doing numerical computations, all elements that pass + the iszerofunc test should be set to numerically zero if they + aren't already. """ + + # Matrix from issue #9060 + m = Matrix([[0.9, -0.1, -0.2, 0],[-0.8, 0.9, -0.4, 0],[-0.1, -0.8, 0.6, 0]]) + m_rref = m.rref(iszerofunc=lambda x: abs(x)<6e-15)[0] + m_correct = Matrix([[1.0, 0, -0.301369863013699, 0],[ 0, 1.0, -0.712328767123288, 0],[ 0, 0, 0, 0]]) + m_diff = m_rref - m_correct + assert m_diff.norm() < 1e-15 + # if a zero-substitution wasn't made, this entry will be -1.11022302462516e-16 + assert m_rref[2,2] == 0 + +def test_issue_11238(): + from sympy.geometry.point import Point + xx = 8*tan(pi*Rational(13, 45))/(tan(pi*Rational(13, 45)) + sqrt(3)) + yy = (-8*sqrt(3)*tan(pi*Rational(13, 45))**2 + 24*tan(pi*Rational(13, 45)))/(-3 + tan(pi*Rational(13, 45))**2) + p1 = Point(0, 0) + p2 = Point(1, -sqrt(3)) + p0 = Point(xx,yy) + m1 = Matrix([p1 - simplify(p0), p2 - simplify(p0)]) + m2 = Matrix([p1 - p0, p2 - p0]) + m3 = Matrix([simplify(p1 - p0), simplify(p2 - p0)]) + + # This system has expressions which are zero and + # cannot be easily proved to be such, so without + # numerical testing, these assertions will fail. + Z = lambda x: abs(x.n()) < 1e-20 + assert m1.rank(simplify=True, iszerofunc=Z) == 1 + assert m2.rank(simplify=True, iszerofunc=Z) == 1 + assert m3.rank(simplify=True, iszerofunc=Z) == 1 + +def test_as_real_imag(): + m1 = Matrix(2,2,[1,2,3,4]) + m2 = m1*S.ImaginaryUnit + m3 = m1 + m2 + + for kls in classes: + a,b = kls(m3).as_real_imag() + assert list(a) == list(m1) + assert list(b) == list(m1) + +def test_deprecated(): + # Maintain tests for deprecated functions. We must capture + # the deprecation warnings. When the deprecated functionality is + # removed, the corresponding tests should be removed. + + m = Matrix(3, 3, [0, 1, 0, -4, 4, 0, -2, 1, 2]) + P, Jcells = m.jordan_cells() + assert Jcells[1] == Matrix(1, 1, [2]) + assert Jcells[0] == Matrix(2, 2, [2, 1, 0, 2]) + + +def test_issue_14489(): + from sympy.core.mod import Mod + A = Matrix([-1, 1, 2]) + B = Matrix([10, 20, -15]) + + assert Mod(A, 3) == Matrix([2, 1, 2]) + assert Mod(B, 4) == Matrix([2, 0, 1]) + +def test_issue_14943(): + # Test that __array__ accepts the optional dtype argument + try: + from numpy import array + except ImportError: + skip('NumPy must be available to test creating matrices from ndarrays') + + M = Matrix([[1,2], [3,4]]) + assert array(M, dtype=float).dtype.name == 'float64' + +def test_case_6913(): + m = MatrixSymbol('m', 1, 1) + a = Symbol("a") + a = m[0, 0]>0 + assert str(a) == 'm[0, 0] > 0' + +def test_issue_11948(): + A = MatrixSymbol('A', 3, 3) + a = Wild('a') + assert A.match(a) == {a: A} + +def test_gramschmidt_conjugate_dot(): + vecs = [Matrix([1, I]), Matrix([1, -I])] + assert Matrix.orthogonalize(*vecs) == \ + [Matrix([[1], [I]]), Matrix([[1], [-I]])] + + vecs = [Matrix([1, I, 0]), Matrix([I, 0, -I])] + assert Matrix.orthogonalize(*vecs) == \ + [Matrix([[1], [I], [0]]), Matrix([[I/2], [S(1)/2], [-I]])] + + mat = Matrix([[1, I], [1, -I]]) + Q, R = mat.QRdecomposition() + assert Q * Q.H == Matrix.eye(2) + +def test_issue_8207(): + a = Matrix(MatrixSymbol('a', 3, 1)) + b = Matrix(MatrixSymbol('b', 3, 1)) + c = a.dot(b) + d = diff(c, a[0, 0]) + e = diff(d, a[0, 0]) + assert d == b[0, 0] + assert e == 0 + +def test_func(): + from sympy.simplify.simplify import nthroot + + A = Matrix([[1, 2],[0, 3]]) + assert A.analytic_func(sin(x*t), x) == Matrix([[sin(t), sin(3*t) - sin(t)], [0, sin(3*t)]]) + + A = Matrix([[2, 1],[1, 2]]) + assert (pi * A / 6).analytic_func(cos(x), x) == Matrix([[sqrt(3)/4, -sqrt(3)/4], [-sqrt(3)/4, sqrt(3)/4]]) + + + raises(ValueError, lambda : zeros(5).analytic_func(log(x), x)) + raises(ValueError, lambda : (A*x).analytic_func(log(x), x)) + + A = Matrix([[0, -1, -2, 3], [0, -1, -2, 3], [0, 1, 0, -1], [0, 0, -1, 1]]) + assert A.analytic_func(exp(x), x) == A.exp() + raises(ValueError, lambda : A.analytic_func(sqrt(x), x)) + + A = Matrix([[41, 12],[12, 34]]) + assert simplify(A.analytic_func(sqrt(x), x)**2) == A + + A = Matrix([[3, -12, 4], [-1, 0, -2], [-1, 5, -1]]) + assert simplify(A.analytic_func(nthroot(x, 3), x)**3) == A + + A = Matrix([[2, 0, 0, 0], [1, 2, 0, 0], [0, 1, 3, 0], [0, 0, 1, 3]]) + assert A.analytic_func(exp(x), x) == A.exp() + + A = Matrix([[0, 2, 1, 6], [0, 0, 1, 2], [0, 0, 0, 3], [0, 0, 0, 0]]) + assert A.analytic_func(exp(x*t), x) == expand(simplify((A*t).exp())) + + +@skip_under_pyodide("Cannot create threads under pyodide.") +def test_issue_19809(): + + def f(): + assert _dotprodsimp_state.state == None + m = Matrix([[1]]) + m = m * m + return True + + with dotprodsimp(True): + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(f) + assert future.result() + + +def test_issue_23276(): + M = Matrix([x, y]) + assert integrate(M, (x, 0, 1), (y, 0, 1)) == Matrix([ + [S.Half], + [S.Half]]) + + +# SubspaceOnlyMatrix tests +def test_columnspace_one(): + m = SubspaceOnlyMatrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + basis = m.columnspace() + assert basis[0] == Matrix([1, -2, 0, 3]) + assert basis[1] == Matrix([2, -5, -3, 6]) + assert basis[2] == Matrix([2, -1, 4, -7]) + + assert len(basis) == 3 + assert Matrix.hstack(m, *basis).columnspace() == basis + + +def test_rowspace(): + m = SubspaceOnlyMatrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + basis = m.rowspace() + assert basis[0] == Matrix([[1, 2, 0, 2, 5]]) + assert basis[1] == Matrix([[0, -1, 1, 3, 2]]) + assert basis[2] == Matrix([[0, 0, 0, 5, 5]]) + + assert len(basis) == 3 + + +def test_nullspace_one(): + m = SubspaceOnlyMatrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + basis = m.nullspace() + assert basis[0] == Matrix([-2, 1, 1, 0, 0]) + assert basis[1] == Matrix([-1, -1, 0, -1, 1]) + # make sure the null space is really gets zeroed + assert all(e.is_zero for e in m*basis[0]) + assert all(e.is_zero for e in m*basis[1]) + + +# ReductionsOnlyMatrix tests +def test_row_op(): + e = eye_Reductions(3) + + raises(ValueError, lambda: e.elementary_row_op("abc")) + raises(ValueError, lambda: e.elementary_row_op()) + raises(ValueError, lambda: e.elementary_row_op('n->kn', row=5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->kn', row=-5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=1, row2=5)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=5, row2=1)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=-5, row2=1)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=1, row2=-5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=1, row2=5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=5, row2=1, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=-5, row2=1, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=1, row2=-5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=1, row2=1, k=5)) + + # test various ways to set arguments + assert e.elementary_row_op("n->kn", 0, 5) == Matrix([[5, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->kn", 1, 5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->kn", row=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->kn", row1=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_row_op("n<->m", 0, 1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_row_op("n<->m", row1=0, row2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_row_op("n<->m", row=0, row2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->n+km", 0, 5, 1) == Matrix([[1, 5, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->n+km", row=0, k=5, row2=1) == Matrix([[1, 5, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->n+km", row1=0, k=5, row2=1) == Matrix([[1, 5, 0], [0, 1, 0], [0, 0, 1]]) + + # make sure the matrix doesn't change size + a = ReductionsOnlyMatrix(2, 3, [0]*6) + assert a.elementary_row_op("n->kn", 1, 5) == Matrix(2, 3, [0]*6) + assert a.elementary_row_op("n<->m", 0, 1) == Matrix(2, 3, [0]*6) + assert a.elementary_row_op("n->n+km", 0, 5, 1) == Matrix(2, 3, [0]*6) + + +def test_col_op(): + e = eye_Reductions(3) + + raises(ValueError, lambda: e.elementary_col_op("abc")) + raises(ValueError, lambda: e.elementary_col_op()) + raises(ValueError, lambda: e.elementary_col_op('n->kn', col=5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->kn', col=-5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=1, col2=5)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=5, col2=1)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=-5, col2=1)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=1, col2=-5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=1, col2=5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=5, col2=1, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=-5, col2=1, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=1, col2=-5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=1, col2=1, k=5)) + + # test various ways to set arguments + assert e.elementary_col_op("n->kn", 0, 5) == Matrix([[5, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->kn", 1, 5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->kn", col=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->kn", col1=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_col_op("n<->m", 0, 1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_col_op("n<->m", col1=0, col2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_col_op("n<->m", col=0, col2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->n+km", 0, 5, 1) == Matrix([[1, 0, 0], [5, 1, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->n+km", col=0, k=5, col2=1) == Matrix([[1, 0, 0], [5, 1, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->n+km", col1=0, k=5, col2=1) == Matrix([[1, 0, 0], [5, 1, 0], [0, 0, 1]]) + + # make sure the matrix doesn't change size + a = ReductionsOnlyMatrix(2, 3, [0]*6) + assert a.elementary_col_op("n->kn", 1, 5) == Matrix(2, 3, [0]*6) + assert a.elementary_col_op("n<->m", 0, 1) == Matrix(2, 3, [0]*6) + assert a.elementary_col_op("n->n+km", 0, 5, 1) == Matrix(2, 3, [0]*6) + + +def test_is_echelon(): + zro = zeros_Reductions(3) + ident = eye_Reductions(3) + + assert zro.is_echelon + assert ident.is_echelon + + a = ReductionsOnlyMatrix(0, 0, []) + assert a.is_echelon + + a = ReductionsOnlyMatrix(2, 3, [3, 2, 1, 0, 0, 6]) + assert a.is_echelon + + a = ReductionsOnlyMatrix(2, 3, [0, 0, 6, 3, 2, 1]) + assert not a.is_echelon + + x = Symbol('x') + a = ReductionsOnlyMatrix(3, 1, [x, 0, 0]) + assert a.is_echelon + + a = ReductionsOnlyMatrix(3, 1, [x, x, 0]) + assert not a.is_echelon + + a = ReductionsOnlyMatrix(3, 3, [0, 0, 0, 1, 2, 3, 0, 0, 0]) + assert not a.is_echelon + + +def test_echelon_form(): + # echelon form is not unique, but the result + # must be row-equivalent to the original matrix + # and it must be in echelon form. + + a = zeros_Reductions(3) + e = eye_Reductions(3) + + # we can assume the zero matrix and the identity matrix shouldn't change + assert a.echelon_form() == a + assert e.echelon_form() == e + + a = ReductionsOnlyMatrix(0, 0, []) + assert a.echelon_form() == a + + a = ReductionsOnlyMatrix(1, 1, [5]) + assert a.echelon_form() == a + + # now we get to the real tests + + def verify_row_null_space(mat, rows, nulls): + for v in nulls: + assert all(t.is_zero for t in a_echelon*v) + for v in rows: + if not all(t.is_zero for t in v): + assert not all(t.is_zero for t in a_echelon*v.transpose()) + + a = ReductionsOnlyMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + nulls = [Matrix([ + [ 1], + [-2], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + + a = ReductionsOnlyMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 8]) + nulls = [] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + a = ReductionsOnlyMatrix(3, 3, [2, 1, 3, 0, 0, 0, 2, 1, 3]) + nulls = [Matrix([ + [Rational(-1, 2)], + [ 1], + [ 0]]), + Matrix([ + [Rational(-3, 2)], + [ 0], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + # this one requires a row swap + a = ReductionsOnlyMatrix(3, 3, [2, 1, 3, 0, 0, 0, 1, 1, 3]) + nulls = [Matrix([ + [ 0], + [ -3], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + a = ReductionsOnlyMatrix(3, 3, [0, 3, 3, 0, 2, 2, 0, 1, 1]) + nulls = [Matrix([ + [1], + [0], + [0]]), + Matrix([ + [ 0], + [-1], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + a = ReductionsOnlyMatrix(2, 3, [2, 2, 3, 3, 3, 0]) + nulls = [Matrix([ + [-1], + [1], + [0]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + +def test_rref(): + e = ReductionsOnlyMatrix(0, 0, []) + assert e.rref(pivots=False) == e + + e = ReductionsOnlyMatrix(1, 1, [1]) + a = ReductionsOnlyMatrix(1, 1, [5]) + assert e.rref(pivots=False) == a.rref(pivots=False) == e + + a = ReductionsOnlyMatrix(3, 1, [1, 2, 3]) + assert a.rref(pivots=False) == Matrix([[1], [0], [0]]) + + a = ReductionsOnlyMatrix(1, 3, [1, 2, 3]) + assert a.rref(pivots=False) == Matrix([[1, 2, 3]]) + + a = ReductionsOnlyMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + assert a.rref(pivots=False) == Matrix([ + [1, 0, -1], + [0, 1, 2], + [0, 0, 0]]) + + a = ReductionsOnlyMatrix(3, 3, [1, 2, 3, 1, 2, 3, 1, 2, 3]) + b = ReductionsOnlyMatrix(3, 3, [1, 2, 3, 0, 0, 0, 0, 0, 0]) + c = ReductionsOnlyMatrix(3, 3, [0, 0, 0, 1, 2, 3, 0, 0, 0]) + d = ReductionsOnlyMatrix(3, 3, [0, 0, 0, 0, 0, 0, 1, 2, 3]) + assert a.rref(pivots=False) == \ + b.rref(pivots=False) == \ + c.rref(pivots=False) == \ + d.rref(pivots=False) == b + + e = eye_Reductions(3) + z = zeros_Reductions(3) + assert e.rref(pivots=False) == e + assert z.rref(pivots=False) == z + + a = ReductionsOnlyMatrix([ + [ 0, 0, 1, 2, 2, -5, 3], + [-1, 5, 2, 2, 1, -7, 5], + [ 0, 0, -2, -3, -3, 8, -5], + [-1, 5, 0, -1, -2, 1, 0]]) + mat, pivot_offsets = a.rref() + assert mat == Matrix([ + [1, -5, 0, 0, 1, 1, -1], + [0, 0, 1, 0, 0, -1, 1], + [0, 0, 0, 1, 1, -2, 1], + [0, 0, 0, 0, 0, 0, 0]]) + assert pivot_offsets == (0, 2, 3) + + a = ReductionsOnlyMatrix([[Rational(1, 19), Rational(1, 5), 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11], + [ 12, 13, 14, 15]]) + assert a.rref(pivots=False) == Matrix([ + [1, 0, 0, Rational(-76, 157)], + [0, 1, 0, Rational(-5, 157)], + [0, 0, 1, Rational(238, 157)], + [0, 0, 0, 0]]) + + x = Symbol('x') + a = ReductionsOnlyMatrix(2, 3, [x, 1, 1, sqrt(x), x, 1]) + for i, j in zip(a.rref(pivots=False), + [1, 0, sqrt(x)*(-x + 1)/(-x**Rational(5, 2) + x), + 0, 1, 1/(sqrt(x) + x + 1)]): + assert simplify(i - j).is_zero + + +def test_rref_rhs(): + a, b, c, d = symbols('a b c d') + A = Matrix([[0, 0], [0, 0], [1, 2], [3, 4]]) + B = Matrix([a, b, c, d]) + assert A.rref_rhs(B) == (Matrix([ + [1, 0], + [0, 1], + [0, 0], + [0, 0]]), Matrix([ + [ -2*c + d], + [3*c/2 - d/2], + [ a], + [ b]])) + + +def test_issue_17827(): + C = Matrix([ + [3, 4, -1, 1], + [9, 12, -3, 3], + [0, 2, 1, 3], + [2, 3, 0, -2], + [0, 3, 3, -5], + [8, 15, 0, 6] + ]) + # Tests for row/col within valid range + D = C.elementary_row_op('n<->m', row1=2, row2=5) + E = C.elementary_row_op('n->n+km', row1=5, row2=3, k=-4) + F = C.elementary_row_op('n->kn', row=5, k=2) + assert(D[5, :] == Matrix([[0, 2, 1, 3]])) + assert(E[5, :] == Matrix([[0, 3, 0, 14]])) + assert(F[5, :] == Matrix([[16, 30, 0, 12]])) + # Tests for row/col out of range + raises(ValueError, lambda: C.elementary_row_op('n<->m', row1=2, row2=6)) + raises(ValueError, lambda: C.elementary_row_op('n->kn', row=7, k=2)) + raises(ValueError, lambda: C.elementary_row_op('n->n+km', row1=-1, row2=5, k=2)) + +def test_rank(): + m = Matrix([[1, 2], [x, 1 - 1/x]]) + assert m.rank() == 2 + n = Matrix(3, 3, range(1, 10)) + assert n.rank() == 2 + p = zeros(3) + assert p.rank() == 0 + +def test_issue_11434(): + ax, ay, bx, by, cx, cy, dx, dy, ex, ey, t0, t1 = \ + symbols('a_x a_y b_x b_y c_x c_y d_x d_y e_x e_y t_0 t_1') + M = Matrix([[ax, ay, ax*t0, ay*t0, 0], + [bx, by, bx*t0, by*t0, 0], + [cx, cy, cx*t0, cy*t0, 1], + [dx, dy, dx*t0, dy*t0, 1], + [ex, ey, 2*ex*t1 - ex*t0, 2*ey*t1 - ey*t0, 0]]) + assert M.rank() == 4 + +def test_rank_regression_from_so(): + # see: + # https://stackoverflow.com/questions/19072700/why-does-sympy-give-me-the-wrong-answer-when-i-row-reduce-a-symbolic-matrix + + nu, lamb = symbols('nu, lambda') + A = Matrix([[-3*nu, 1, 0, 0], + [ 3*nu, -2*nu - 1, 2, 0], + [ 0, 2*nu, (-1*nu) - lamb - 2, 3], + [ 0, 0, nu + lamb, -3]]) + expected_reduced = Matrix([[1, 0, 0, 1/(nu**2*(-lamb - nu))], + [0, 1, 0, 3/(nu*(-lamb - nu))], + [0, 0, 1, 3/(-lamb - nu)], + [0, 0, 0, 0]]) + expected_pivots = (0, 1, 2) + + reduced, pivots = A.rref() + + assert simplify(expected_reduced - reduced) == zeros(*A.shape) + assert pivots == expected_pivots + +def test_issue_15872(): + A = Matrix([[1, 1, 1, 0], [-2, -1, 0, -1], [0, 0, -1, -1], [0, 0, 2, 1]]) + B = A - Matrix.eye(4) * I + assert B.rank() == 3 + assert (B**2).rank() == 2 + assert (B**3).rank() == 2 diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_matrixbase.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_matrixbase.py new file mode 100644 index 0000000000000000000000000000000000000000..a77f51596c6622dc427feeeb9383214592fab632 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_matrixbase.py @@ -0,0 +1,3795 @@ +import concurrent.futures +import random +from collections.abc import Hashable + +from sympy import ( + Abs, Add, Array, DeferredVector, E, Expr, FiniteSet, Float, Function, + GramSchmidt, I, ImmutableDenseMatrix, ImmutableMatrix, + ImmutableSparseMatrix, Integer, KroneckerDelta, MatPow, Matrix, + MatrixSymbol, Max, Min, MutableDenseMatrix, MutableSparseMatrix, Poly, Pow, + PurePoly, Q, Quaternion, Rational, RootOf, S, SparseMatrix, Symbol, Tuple, + Wild, banded, casoratian, cos, diag, diff, exp, expand, eye, floor, hessian, + integrate, log, matrix_multiply_elementwise, nan, ones, oo, pi, randMatrix, + rot_axis1, rot_axis2, rot_axis3, rot_ccw_axis1, rot_ccw_axis2, + rot_ccw_axis3, signsimp, simplify, sin, sqrt, sstr, symbols, sympify, tan, + trigsimp, wronskian, zeros, cancel) +from sympy.abc import a, b, c, d, t, x, y, z +from sympy.core.kind import NumberKind, UndefinedKind +from sympy.matrices.determinant import _find_reasonable_pivot_naive +from sympy.matrices.exceptions import ( + MatrixError, NonSquareMatrixError, ShapeError) +from sympy.matrices.kind import MatrixKind +from sympy.matrices.utilities import _dotprodsimp_state, _simplify, dotprodsimp +from sympy.tensor.array.array_derivatives import ArrayDerivative +from sympy.testing.pytest import ( + ignore_warnings, raises, skip, skip_under_pyodide, slow, + warns_deprecated_sympy) +from sympy.utilities.iterables import capture, iterable +from importlib.metadata import version + +all_classes = (Matrix, SparseMatrix, ImmutableMatrix, ImmutableSparseMatrix) +mutable_classes = (Matrix, SparseMatrix) +immutable_classes = (ImmutableMatrix, ImmutableSparseMatrix) + + +def test__MinimalMatrix(): + x = Matrix(2, 3, [1, 2, 3, 4, 5, 6]) + assert x.rows == 2 + assert x.cols == 3 + assert x[2] == 3 + assert x[1, 1] == 5 + assert list(x) == [1, 2, 3, 4, 5, 6] + assert list(x[1, :]) == [4, 5, 6] + assert list(x[:, 1]) == [2, 5] + assert list(x[:, :]) == list(x) + assert x[:, :] == x + assert Matrix(x) == x + assert Matrix([[1, 2, 3], [4, 5, 6]]) == x + assert Matrix(([1, 2, 3], [4, 5, 6])) == x + assert Matrix([(1, 2, 3), (4, 5, 6)]) == x + assert Matrix(((1, 2, 3), (4, 5, 6))) == x + assert not (Matrix([[1, 2], [3, 4], [5, 6]]) == x) + + +def test_kind(): + assert Matrix([[1, 2], [3, 4]]).kind == MatrixKind(NumberKind) + assert Matrix([[0, 0], [0, 0]]).kind == MatrixKind(NumberKind) + assert Matrix(0, 0, []).kind == MatrixKind(NumberKind) + assert Matrix([[x]]).kind == MatrixKind(NumberKind) + assert Matrix([[1, Matrix([[1]])]]).kind == MatrixKind(UndefinedKind) + assert SparseMatrix([[1]]).kind == MatrixKind(NumberKind) + assert SparseMatrix([[1, Matrix([[1]])]]).kind == MatrixKind(UndefinedKind) + + +def test_todok(): + a, b, c, d = symbols('a:d') + m1 = MutableDenseMatrix([[a, b], [c, d]]) + m2 = ImmutableDenseMatrix([[a, b], [c, d]]) + m3 = MutableSparseMatrix([[a, b], [c, d]]) + m4 = ImmutableSparseMatrix([[a, b], [c, d]]) + assert m1.todok() == m2.todok() == m3.todok() == m4.todok() == \ + {(0, 0): a, (0, 1): b, (1, 0): c, (1, 1): d} + + +def test_tolist(): + lst = [[S.One, S.Half, x*y, S.Zero], [x, y, z, x**2], [y, -S.One, z*x, 3]] + flat_lst = [S.One, S.Half, x*y, S.Zero, x, y, z, x**2, y, -S.One, z*x, 3] + m = Matrix(3, 4, flat_lst) + assert m.tolist() == lst + + +def test_todod(): + m = Matrix([[S.One, 0], [0, S.Half], [x, 0]]) + dict = {0: {0: S.One}, 1: {1: S.Half}, 2: {0: x}} + assert m.todod() == dict + + +def test_row_col_del(): + e = ImmutableMatrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + raises(IndexError, lambda: e.row_del(5)) + raises(IndexError, lambda: e.row_del(-5)) + raises(IndexError, lambda: e.col_del(5)) + raises(IndexError, lambda: e.col_del(-5)) + + assert e.row_del(2) == e.row_del(-1) == Matrix([[1, 2, 3], [4, 5, 6]]) + assert e.col_del(2) == e.col_del(-1) == Matrix([[1, 2], [4, 5], [7, 8]]) + + assert e.row_del(1) == e.row_del(-2) == Matrix([[1, 2, 3], [7, 8, 9]]) + assert e.col_del(1) == e.col_del(-2) == Matrix([[1, 3], [4, 6], [7, 9]]) + + +def test_get_diag_blocks1(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + assert a.get_diag_blocks() == [a] + assert b.get_diag_blocks() == [b] + assert c.get_diag_blocks() == [c] + + +def test_get_diag_blocks2(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + A, B, C, D = diag(a, b, b), diag(a, b, c), diag(a, c, b), diag(c, c, b) + A = Matrix(A.rows, A.cols, A) + B = Matrix(B.rows, B.cols, B) + C = Matrix(C.rows, C.cols, C) + D = Matrix(D.rows, D.cols, D) + + assert A.get_diag_blocks() == [a, b, b] + assert B.get_diag_blocks() == [a, b, c] + assert C.get_diag_blocks() == [a, c, b] + assert D.get_diag_blocks() == [c, c, b] + + +def test_row_col(): + m = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + assert m.row(0) == Matrix(1, 3, [1, 2, 3]) + assert m.col(0) == Matrix(3, 1, [1, 4, 7]) + + +def test_row_join(): + assert eye(3).row_join(Matrix([7, 7, 7])) == \ + Matrix([[1, 0, 0, 7], + [0, 1, 0, 7], + [0, 0, 1, 7]]) + + +def test_col_join(): + assert eye(3).col_join(Matrix([[7, 7, 7]])) == \ + Matrix([[1, 0, 0], + [0, 1, 0], + [0, 0, 1], + [7, 7, 7]]) + + +def test_row_insert(): + r4 = Matrix([[4, 4, 4]]) + for i in range(-4, 5): + l = [1, 0, 0] + l.insert(i, 4) + assert eye(3).row_insert(i, r4).col(0).flat() == l + + +def test_col_insert(): + c4 = Matrix([4, 4, 4]) + for i in range(-4, 5): + l = [0, 0, 0] + l.insert(i, 4) + assert zeros(3).col_insert(i, c4).row(0).flat() == l + # issue 13643 + assert eye(6).col_insert(3, Matrix([[2, 2], [2, 2], [2, 2], [2, 2], [2, 2], [2, 2]])) == \ + Matrix([[1, 0, 0, 2, 2, 0, 0, 0], + [0, 1, 0, 2, 2, 0, 0, 0], + [0, 0, 1, 2, 2, 0, 0, 0], + [0, 0, 0, 2, 2, 1, 0, 0], + [0, 0, 0, 2, 2, 0, 1, 0], + [0, 0, 0, 2, 2, 0, 0, 1]]) + + +def test_extract(): + m = Matrix(4, 3, lambda i, j: i*3 + j) + assert m.extract([0, 1, 3], [0, 1]) == Matrix(3, 2, [0, 1, 3, 4, 9, 10]) + assert m.extract([0, 3], [0, 0, 2]) == Matrix(2, 3, [0, 0, 2, 9, 9, 11]) + assert m.extract(range(4), range(3)) == m + raises(IndexError, lambda: m.extract([4], [0])) + raises(IndexError, lambda: m.extract([0], [3])) + + +def test_hstack(): + m = Matrix(4, 3, lambda i, j: i*3 + j) + m2 = Matrix(3, 4, lambda i, j: i*3 + j) + assert m == m.hstack(m) + assert m.hstack(m, m, m) == Matrix.hstack(m, m, m) == Matrix([ + [0, 1, 2, 0, 1, 2, 0, 1, 2], + [3, 4, 5, 3, 4, 5, 3, 4, 5], + [6, 7, 8, 6, 7, 8, 6, 7, 8], + [9, 10, 11, 9, 10, 11, 9, 10, 11]]) + raises(ShapeError, lambda: m.hstack(m, m2)) + assert Matrix.hstack() == Matrix() + + # test regression #12938 + M1 = Matrix.zeros(0, 0) + M2 = Matrix.zeros(0, 1) + M3 = Matrix.zeros(0, 2) + M4 = Matrix.zeros(0, 3) + m = Matrix.hstack(M1, M2, M3, M4) + assert m.rows == 0 and m.cols == 6 + + +def test_vstack(): + m = Matrix(4, 3, lambda i, j: i*3 + j) + m2 = Matrix(3, 4, lambda i, j: i*3 + j) + assert m == m.vstack(m) + assert m.vstack(m, m, m) == Matrix.vstack(m, m, m) == Matrix([ + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11], + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11], + [0, 1, 2], + [3, 4, 5], + [6, 7, 8], + [9, 10, 11]]) + raises(ShapeError, lambda: m.vstack(m, m2)) + assert Matrix.vstack() == Matrix() + + +def test_has(): + A = Matrix(((x, y), (2, 3))) + assert A.has(x) + assert not A.has(z) + assert A.has(Symbol) + + A = Matrix(((2, y), (2, 3))) + assert not A.has(x) + + +def test_is_anti_symmetric(): + x = symbols('x') + assert Matrix(2, 1, [1, 2]).is_anti_symmetric() is False + m = Matrix(3, 3, [0, x**2 + 2*x + 1, y, -(x + 1)**2, 0, x*y, -y, -x*y, 0]) + assert m.is_anti_symmetric() is True + assert m.is_anti_symmetric(simplify=False) is None + assert m.is_anti_symmetric(simplify=lambda x: x) is None + + m = Matrix(3, 3, [x.expand() for x in m]) + assert m.is_anti_symmetric(simplify=False) is True + m = Matrix(3, 3, [x.expand() for x in [S.One] + list(m)[1:]]) + assert m.is_anti_symmetric() is False + + +def test_is_hermitian(): + a = Matrix([[1, I], [-I, 1]]) + assert a.is_hermitian + a = Matrix([[2*I, I], [-I, 1]]) + assert a.is_hermitian is False + a = Matrix([[x, I], [-I, 1]]) + assert a.is_hermitian is None + a = Matrix([[x, 1], [-I, 1]]) + assert a.is_hermitian is False + + +def test_is_symbolic(): + a = Matrix([[x, x], [x, x]]) + assert a.is_symbolic() is True + a = Matrix([[1, 2, 3, 4], [5, 6, 7, 8]]) + assert a.is_symbolic() is False + a = Matrix([[1, 2, 3, 4], [5, 6, x, 8]]) + assert a.is_symbolic() is True + a = Matrix([[1, x, 3]]) + assert a.is_symbolic() is True + a = Matrix([[1, 2, 3]]) + assert a.is_symbolic() is False + a = Matrix([[1], [x], [3]]) + assert a.is_symbolic() is True + a = Matrix([[1], [2], [3]]) + assert a.is_symbolic() is False + + +def test_is_square(): + m = Matrix([[1], [1]]) + m2 = Matrix([[2, 2], [2, 2]]) + assert not m.is_square + assert m2.is_square + + +def test_is_symmetric(): + m = Matrix(2, 2, [0, 1, 1, 0]) + assert m.is_symmetric() + m = Matrix(2, 2, [0, 1, 0, 1]) + assert not m.is_symmetric() + + +def test_is_hessenberg(): + A = Matrix([[3, 4, 1], [2, 4, 5], [0, 1, 2]]) + assert A.is_upper_hessenberg + A = Matrix(3, 3, [3, 2, 0, 4, 4, 1, 1, 5, 2]) + assert A.is_lower_hessenberg + A = Matrix(3, 3, [3, 2, -1, 4, 4, 1, 1, 5, 2]) + assert A.is_lower_hessenberg is False + assert A.is_upper_hessenberg is False + + A = Matrix([[3, 4, 1], [2, 4, 5], [3, 1, 2]]) + assert not A.is_upper_hessenberg + + +def test_values(): + assert set(Matrix(2, 2, [0, 1, 2, 3] + ).values()) == {1, 2, 3} + x = Symbol('x', real=True) + assert set(Matrix(2, 2, [x, 0, 0, 1] + ).values()) == {x, 1} + + +def test_conjugate(): + M = Matrix([[0, I, 5], + [1, 2, 0]]) + + assert M.T == Matrix([[0, 1], + [I, 2], + [5, 0]]) + + assert M.C == Matrix([[0, -I, 5], + [1, 2, 0]]) + assert M.C == M.conjugate() + + assert M.H == M.T.C + assert M.H == Matrix([[ 0, 1], + [-I, 2], + [ 5, 0]]) + + +def test_doit(): + a = Matrix([[Add(x, x, evaluate=False)]]) + assert a[0] != 2*x + assert a.doit() == Matrix([[2*x]]) + + +def test_evalf(): + a = Matrix(2, 1, [sqrt(5), 6]) + assert all(a.evalf()[i] == a[i].evalf() for i in range(2)) + assert all(a.evalf(2)[i] == a[i].evalf(2) for i in range(2)) + assert all(a.n(2)[i] == a[i].n(2) for i in range(2)) + + +def test_replace(): + F, G = symbols('F, G', cls=Function) + K = Matrix(2, 2, lambda i, j: G(i+j)) + M = Matrix(2, 2, lambda i, j: F(i+j)) + N = M.replace(F, G) + assert N == K + + +def test_replace_map(): + F, G = symbols('F, G', cls=Function) + M = Matrix(2, 2, lambda i, j: F(i+j)) + N, d = M.replace(F, G, True) + assert N == Matrix(2, 2, lambda i, j: G(i+j)) + assert d == {F(0): G(0), F(1): G(1), F(2): G(2)} + +def test_numpy_conversion(): + try: + from numpy import array, array_equal + except ImportError: + skip('NumPy must be available to test creating matrices from ndarrays') + A = Matrix([[1,2], [3,4]]) + np_array = array([[1,2], [3,4]]) + assert array_equal(array(A), np_array) + assert array_equal(array(A, copy=True), np_array) + if(int(version('numpy').split('.')[0]) >= 2): #run this test only if numpy is new enough that copy variable is passed properly. + raises(TypeError, lambda: array(A, copy=False)) + +def test_rot90(): + A = Matrix([[1, 2], [3, 4]]) + assert A == A.rot90(0) == A.rot90(4) + assert A.rot90(2) == A.rot90(-2) == A.rot90(6) == Matrix(((4, 3), (2, 1))) + assert A.rot90(3) == A.rot90(-1) == A.rot90(7) == Matrix(((2, 4), (1, 3))) + assert A.rot90() == A.rot90(-7) == A.rot90(-3) == Matrix(((3, 1), (4, 2))) + + +def test_subs(): + assert Matrix([[1, x], [x, 4]]).subs(x, 5) == Matrix([[1, 5], [5, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).subs([[x, -1], [y, -2]]) == \ + Matrix([[-1, 2], [-3, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).subs([(x, -1), (y, -2)]) == \ + Matrix([[-1, 2], [-3, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).subs({x: -1, y: -2}) == \ + Matrix([[-1, 2], [-3, 4]]) + assert Matrix([[x*y]]).subs({x: y - 1, y: x - 1}, simultaneous=True) == \ + Matrix([[(x - 1)*(y - 1)]]) + + +def test_permute(): + a = Matrix(3, 4, [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + + raises(IndexError, lambda: a.permute([[0, 5]])) + raises(ValueError, lambda: a.permute(Symbol('x'))) + b = a.permute_rows([[0, 2], [0, 1]]) + assert a.permute([[0, 2], [0, 1]]) == b == Matrix([ + [5, 6, 7, 8], + [9, 10, 11, 12], + [1, 2, 3, 4]]) + + b = a.permute_cols([[0, 2], [0, 1]]) + assert a.permute([[0, 2], [0, 1]], orientation='cols') == b ==\ + Matrix([ + [ 2, 3, 1, 4], + [ 6, 7, 5, 8], + [10, 11, 9, 12]]) + + b = a.permute_cols([[0, 2], [0, 1]], direction='backward') + assert a.permute([[0, 2], [0, 1]], orientation='cols', direction='backward') == b ==\ + Matrix([ + [ 3, 1, 2, 4], + [ 7, 5, 6, 8], + [11, 9, 10, 12]]) + + assert a.permute([1, 2, 0, 3]) == Matrix([ + [5, 6, 7, 8], + [9, 10, 11, 12], + [1, 2, 3, 4]]) + + from sympy.combinatorics import Permutation + assert a.permute(Permutation([1, 2, 0, 3])) == Matrix([ + [5, 6, 7, 8], + [9, 10, 11, 12], + [1, 2, 3, 4]]) + +def test_upper_triangular(): + + A = Matrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] + ]) + + R = A.upper_triangular(2) + assert R == Matrix([ + [0, 0, 1, 1], + [0, 0, 0, 1], + [0, 0, 0, 0], + [0, 0, 0, 0] + ]) + + R = A.upper_triangular(-2) + assert R == Matrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [0, 1, 1, 1] + ]) + + R = A.upper_triangular() + assert R == Matrix([ + [1, 1, 1, 1], + [0, 1, 1, 1], + [0, 0, 1, 1], + [0, 0, 0, 1] + ]) + + +def test_lower_triangular(): + A = Matrix([ + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] + ]) + + L = A.lower_triangular() + assert L == Matrix([ + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 1]]) + + L = A.lower_triangular(2) + assert L == Matrix([ + [1, 1, 1, 0], + [1, 1, 1, 1], + [1, 1, 1, 1], + [1, 1, 1, 1] + ]) + + L = A.lower_triangular(-2) + assert L == Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [1, 1, 0, 0] + ]) + + +def test_add(): + m = Matrix([[1, 2, 3], [x, y, x], [2*y, -50, z*x]]) + assert m + m == Matrix([[2, 4, 6], [2*x, 2*y, 2*x], [4*y, -100, 2*z*x]]) + n = Matrix(1, 2, [1, 2]) + raises(ShapeError, lambda: m + n) + + +def test_matmul(): + a = Matrix([[1, 2], [3, 4]]) + + assert a.__matmul__(2) == NotImplemented + + assert a.__rmatmul__(2) == NotImplemented + + #This is done this way because @ is only supported in Python 3.5+ + #To check 2@a case + try: + eval('2 @ a') + except SyntaxError: + pass + except TypeError: #TypeError is raised in case of NotImplemented is returned + pass + + #Check a@2 case + try: + eval('a @ 2') + except SyntaxError: + pass + except TypeError: #TypeError is raised in case of NotImplemented is returned + pass + + +def test_non_matmul(): + """ + Test that if explicitly specified as non-matrix, mul reverts + to scalar multiplication. + """ + class foo(Expr): + is_Matrix=False + is_MatrixLike=False + shape = (1, 1) + + A = Matrix([[1, 2], [3, 4]]) + b = foo() + assert b*A == Matrix([[b, 2*b], [3*b, 4*b]]) + assert A*b == Matrix([[b, 2*b], [3*b, 4*b]]) + + +def test_neg(): + n = Matrix(1, 2, [1, 2]) + assert -n == Matrix(1, 2, [-1, -2]) + + +def test_sub(): + n = Matrix(1, 2, [1, 2]) + assert n - n == Matrix(1, 2, [0, 0]) + + +def test_div(): + n = Matrix(1, 2, [1, 2]) + assert n/2 == Matrix(1, 2, [S.Half, S(2)/2]) + + +def test_eye(): + assert list(Matrix.eye(2, 2)) == [1, 0, 0, 1] + assert list(Matrix.eye(2)) == [1, 0, 0, 1] + assert type(Matrix.eye(2)) == Matrix + assert type(Matrix.eye(2, cls=Matrix)) == Matrix + + +def test_ones(): + assert list(Matrix.ones(2, 2)) == [1, 1, 1, 1] + assert list(Matrix.ones(2)) == [1, 1, 1, 1] + assert Matrix.ones(2, 3) == Matrix([[1, 1, 1], [1, 1, 1]]) + assert type(Matrix.ones(2)) == Matrix + assert type(Matrix.ones(2, cls=Matrix)) == Matrix + + +def test_zeros(): + assert list(Matrix.zeros(2, 2)) == [0, 0, 0, 0] + assert list(Matrix.zeros(2)) == [0, 0, 0, 0] + assert Matrix.zeros(2, 3) == Matrix([[0, 0, 0], [0, 0, 0]]) + assert type(Matrix.zeros(2)) == Matrix + assert type(Matrix.zeros(2, cls=Matrix)) == Matrix + + +def test_diag_make(): + diag = Matrix.diag + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + assert diag(a, b, b) == Matrix([ + [1, 2, 0, 0, 0, 0], + [2, 3, 0, 0, 0, 0], + [0, 0, 3, x, 0, 0], + [0, 0, y, 3, 0, 0], + [0, 0, 0, 0, 3, x], + [0, 0, 0, 0, y, 3], + ]) + assert diag(a, b, c) == Matrix([ + [1, 2, 0, 0, 0, 0, 0], + [2, 3, 0, 0, 0, 0, 0], + [0, 0, 3, x, 0, 0, 0], + [0, 0, y, 3, 0, 0, 0], + [0, 0, 0, 0, 3, x, 3], + [0, 0, 0, 0, y, 3, z], + [0, 0, 0, 0, x, y, z], + ]) + assert diag(a, c, b) == Matrix([ + [1, 2, 0, 0, 0, 0, 0], + [2, 3, 0, 0, 0, 0, 0], + [0, 0, 3, x, 3, 0, 0], + [0, 0, y, 3, z, 0, 0], + [0, 0, x, y, z, 0, 0], + [0, 0, 0, 0, 0, 3, x], + [0, 0, 0, 0, 0, y, 3], + ]) + a = Matrix([x, y, z]) + b = Matrix([[1, 2], [3, 4]]) + c = Matrix([[5, 6]]) + # this "wandering diagonal" is what makes this + # a block diagonal where each block is independent + # of the others + assert diag(a, 7, b, c) == Matrix([ + [x, 0, 0, 0, 0, 0], + [y, 0, 0, 0, 0, 0], + [z, 0, 0, 0, 0, 0], + [0, 7, 0, 0, 0, 0], + [0, 0, 1, 2, 0, 0], + [0, 0, 3, 4, 0, 0], + [0, 0, 0, 0, 5, 6]]) + raises(ValueError, lambda: diag(a, 7, b, c, rows=5)) + assert diag(1) == Matrix([[1]]) + assert diag(1, rows=2) == Matrix([[1, 0], [0, 0]]) + assert diag(1, cols=2) == Matrix([[1, 0], [0, 0]]) + assert diag(1, rows=3, cols=2) == Matrix([[1, 0], [0, 0], [0, 0]]) + assert diag(*[2, 3]) == Matrix([ + [2, 0], + [0, 3]]) + assert diag(Matrix([2, 3])) == Matrix([ + [2], + [3]]) + assert diag([1, [2, 3], 4], unpack=False) == \ + diag([[1], [2, 3], [4]], unpack=False) == Matrix([ + [1, 0], + [2, 3], + [4, 0]]) + assert type(diag(1)) == Matrix + assert type(diag(1, cls=Matrix)) == Matrix + assert Matrix.diag([1, 2, 3]) == Matrix.diag(1, 2, 3) + assert Matrix.diag([1, 2, 3], unpack=False).shape == (3, 1) + assert Matrix.diag([[1, 2, 3]]).shape == (3, 1) + assert Matrix.diag([[1, 2, 3]], unpack=False).shape == (1, 3) + assert Matrix.diag([[[1, 2, 3]]]).shape == (1, 3) + # kerning can be used to move the starting point + assert Matrix.diag(ones(0, 2), 1, 2) == Matrix([ + [0, 0, 1, 0], + [0, 0, 0, 2]]) + assert Matrix.diag(ones(2, 0), 1, 2) == Matrix([ + [0, 0], + [0, 0], + [1, 0], + [0, 2]]) + + +def test_diagonal(): + m = Matrix(3, 3, range(9)) + d = m.diagonal() + assert d == m.diagonal(0) + assert tuple(d) == (0, 4, 8) + assert tuple(m.diagonal(1)) == (1, 5) + assert tuple(m.diagonal(-1)) == (3, 7) + assert tuple(m.diagonal(2)) == (2,) + assert type(m.diagonal()) == type(m) + s = SparseMatrix(3, 3, {(1, 1): 1}) + assert type(s.diagonal()) == type(s) + assert type(m) != type(s) + raises(ValueError, lambda: m.diagonal(3)) + raises(ValueError, lambda: m.diagonal(-3)) + raises(ValueError, lambda: m.diagonal(pi)) + M = ones(2, 3) + assert banded({i: list(M.diagonal(i)) + for i in range(1-M.rows, M.cols)}) == M + + +def test_jordan_block(): + assert Matrix.jordan_block(3, 2) == Matrix.jordan_block(3, eigenvalue=2) \ + == Matrix.jordan_block(size=3, eigenvalue=2) \ + == Matrix.jordan_block(3, 2, band='upper') \ + == Matrix.jordan_block( + size=3, eigenval=2, eigenvalue=2) \ + == Matrix([ + [2, 1, 0], + [0, 2, 1], + [0, 0, 2]]) + + assert Matrix.jordan_block(3, 2, band='lower') == Matrix([ + [2, 0, 0], + [1, 2, 0], + [0, 1, 2]]) + # missing eigenvalue + raises(ValueError, lambda: Matrix.jordan_block(2)) + # non-integral size + raises(ValueError, lambda: Matrix.jordan_block(3.5, 2)) + # size not specified + raises(ValueError, lambda: Matrix.jordan_block(eigenvalue=2)) + # inconsistent eigenvalue + raises(ValueError, + lambda: Matrix.jordan_block( + eigenvalue=2, eigenval=4)) + + # Using alias keyword + assert Matrix.jordan_block(size=3, eigenvalue=2) == \ + Matrix.jordan_block(size=3, eigenval=2) + + +def test_orthogonalize(): + m = Matrix([[1, 2], [3, 4]]) + assert m.orthogonalize(Matrix([[2], [1]])) == [Matrix([[2], [1]])] + assert m.orthogonalize(Matrix([[2], [1]]), normalize=True) == \ + [Matrix([[2*sqrt(5)/5], [sqrt(5)/5]])] + assert m.orthogonalize(Matrix([[1], [2]]), Matrix([[-1], [4]])) == \ + [Matrix([[1], [2]]), Matrix([[Rational(-12, 5)], [Rational(6, 5)]])] + assert m.orthogonalize(Matrix([[0], [0]]), Matrix([[-1], [4]])) == \ + [Matrix([[-1], [4]])] + assert m.orthogonalize(Matrix([[0], [0]])) == [] + + n = Matrix([[9, 1, 9], [3, 6, 10], [8, 5, 2]]) + vecs = [Matrix([[-5], [1]]), Matrix([[-5], [2]]), Matrix([[-5], [-2]])] + assert n.orthogonalize(*vecs) == \ + [Matrix([[-5], [1]]), Matrix([[Rational(5, 26)], [Rational(25, 26)]])] + + vecs = [Matrix([0, 0, 0]), Matrix([1, 2, 3]), Matrix([1, 4, 5])] + raises(ValueError, lambda: Matrix.orthogonalize(*vecs, rankcheck=True)) + + vecs = [Matrix([1, 2, 3]), Matrix([4, 5, 6]), Matrix([7, 8, 9])] + raises(ValueError, lambda: Matrix.orthogonalize(*vecs, rankcheck=True)) + +def test_wilkinson(): + + wminus, wplus = Matrix.wilkinson(1) + assert wminus == Matrix([ + [-1, 1, 0], + [1, 0, 1], + [0, 1, 1]]) + assert wplus == Matrix([ + [1, 1, 0], + [1, 0, 1], + [0, 1, 1]]) + + wminus, wplus = Matrix.wilkinson(3) + assert wminus == Matrix([ + [-3, 1, 0, 0, 0, 0, 0], + [1, -2, 1, 0, 0, 0, 0], + [0, 1, -1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 2, 1], + + [0, 0, 0, 0, 0, 1, 3]]) + + assert wplus == Matrix([ + [3, 1, 0, 0, 0, 0, 0], + [1, 2, 1, 0, 0, 0, 0], + [0, 1, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 1, 2, 1], + [0, 0, 0, 0, 0, 1, 3]]) + + +def test_limit(): + x, y = symbols('x y') + m = Matrix(2, 1, [1/x, y]) + assert m.limit(x, 5) == Matrix(2, 1, [Rational(1, 5), y]) + A = Matrix(((1, 4, sin(x)/x), (y, 2, 4), (10, 5, x**2 + 1))) + assert A.limit(x, 0) == Matrix(((1, 4, 1), (y, 2, 4), (10, 5, 1))) + + +def test_issue_13774(): + M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + v = [1, 1, 1] + raises(TypeError, lambda: M*v) + raises(TypeError, lambda: v*M) + + +def test_companion(): + x = Symbol('x') + y = Symbol('y') + raises(ValueError, lambda: Matrix.companion(1)) + raises(ValueError, lambda: Matrix.companion(Poly([1], x))) + raises(ValueError, lambda: Matrix.companion(Poly([2, 1], x))) + raises(ValueError, lambda: Matrix.companion(Poly(x*y, [x, y]))) + + c0, c1, c2 = symbols('c0:3') + assert Matrix.companion(Poly([1, c0], x)) == Matrix([-c0]) + assert Matrix.companion(Poly([1, c1, c0], x)) == \ + Matrix([[0, -c0], [1, -c1]]) + assert Matrix.companion(Poly([1, c2, c1, c0], x)) == \ + Matrix([[0, 0, -c0], [1, 0, -c1], [0, 1, -c2]]) + + +def test_issue_10589(): + x, y, z = symbols("x, y z") + M1 = Matrix([x, y, z]) + M1 = M1.subs(zip([x, y, z], [1, 2, 3])) + assert M1 == Matrix([[1], [2], [3]]) + + M2 = Matrix([[x, x, x, x, x], [x, x, x, x, x], [x, x, x, x, x]]) + M2 = M2.subs(zip([x], [1])) + assert M2 == Matrix([[1, 1, 1, 1, 1], [1, 1, 1, 1, 1], [1, 1, 1, 1, 1]]) + + +def test_rmul_pr19860(): + class Foo(ImmutableDenseMatrix): + _op_priority = MutableDenseMatrix._op_priority + 0.01 + + a = Matrix(2, 2, [1, 2, 3, 4]) + b = Foo(2, 2, [1, 2, 3, 4]) + + # This would throw a RecursionError: maximum recursion depth + # since b always has higher priority even after a.as_mutable() + c = a*b + + assert isinstance(c, Foo) + assert c == Matrix([[7, 10], [15, 22]]) + + +def test_issue_18956(): + A = Array([[1, 2], [3, 4]]) + B = Matrix([[1,2],[3,4]]) + raises(TypeError, lambda: B + A) + raises(TypeError, lambda: A + B) + + +def test__eq__(): + class My(object): + def __iter__(self): + yield 1 + yield 2 + return + def __getitem__(self, i): + return list(self)[i] + a = Matrix(2, 1, [1, 2]) + assert a != My() + class My_sympy(My): + def _sympy_(self): + return Matrix(self) + assert a == My_sympy() + + +def test_args(): + for n, cls in enumerate(all_classes): + m = cls.zeros(3, 2) + # all should give back the same type of arguments, e.g. ints for shape + assert m.shape == (3, 2) and all(type(i) is int for i in m.shape) + assert m.rows == 3 and type(m.rows) is int + assert m.cols == 2 and type(m.cols) is int + if not n % 2: + assert type(m.flat()) in (list, tuple, Tuple) + else: + assert type(m.todok()) is dict + + +def test_deprecated_mat_smat(): + for cls in Matrix, ImmutableMatrix: + m = cls.zeros(3, 2) + with warns_deprecated_sympy(): + mat = m._mat + assert mat == m.flat() + for cls in SparseMatrix, ImmutableSparseMatrix: + m = cls.zeros(3, 2) + with warns_deprecated_sympy(): + smat = m._smat + assert smat == m.todok() + + +def test_division(): + v = Matrix(1, 2, [x, y]) + assert v/z == Matrix(1, 2, [x/z, y/z]) + + +def test_sum(): + m = Matrix([[1, 2, 3], [x, y, x], [2*y, -50, z*x]]) + assert m + m == Matrix([[2, 4, 6], [2*x, 2*y, 2*x], [4*y, -100, 2*z*x]]) + n = Matrix(1, 2, [1, 2]) + raises(ShapeError, lambda: m + n) + + +def test_abs(): + m = Matrix([[1, -2], [x, y]]) + assert abs(m) == Matrix([[1, 2], [Abs(x), Abs(y)]]) + m = Matrix(1, 2, [-3, x]) + n = Matrix(1, 2, [3, Abs(x)]) + assert abs(m) == n + + +def test_addition(): + a = Matrix(( + (1, 2), + (3, 1), + )) + + b = Matrix(( + (1, 2), + (3, 0), + )) + + assert a + b == a.add(b) == Matrix([[2, 4], [6, 1]]) + + +def test_fancy_index_matrix(): + for M in (Matrix, SparseMatrix): + a = M(3, 3, range(9)) + assert a == a[:, :] + assert a[1, :] == Matrix(1, 3, [3, 4, 5]) + assert a[:, 1] == Matrix([1, 4, 7]) + assert a[[0, 1], :] == Matrix([[0, 1, 2], [3, 4, 5]]) + assert a[[0, 1], 2] == a[[0, 1], [2]] + assert a[2, [0, 1]] == a[[2], [0, 1]] + assert a[:, [0, 1]] == Matrix([[0, 1], [3, 4], [6, 7]]) + assert a[0, 0] == 0 + assert a[0:2, :] == Matrix([[0, 1, 2], [3, 4, 5]]) + assert a[:, 0:2] == Matrix([[0, 1], [3, 4], [6, 7]]) + assert a[::2, 1] == a[[0, 2], 1] + assert a[1, ::2] == a[1, [0, 2]] + a = M(3, 3, range(9)) + assert a[[0, 2, 1, 2, 1], :] == Matrix([ + [0, 1, 2], + [6, 7, 8], + [3, 4, 5], + [6, 7, 8], + [3, 4, 5]]) + assert a[:, [0,2,1,2,1]] == Matrix([ + [0, 2, 1, 2, 1], + [3, 5, 4, 5, 4], + [6, 8, 7, 8, 7]]) + + a = SparseMatrix.zeros(3) + a[1, 2] = 2 + a[0, 1] = 3 + a[2, 0] = 4 + assert a.extract([1, 1], [2]) == Matrix([ + [2], + [2]]) + assert a.extract([1, 0], [2, 2, 2]) == Matrix([ + [2, 2, 2], + [0, 0, 0]]) + assert a.extract([1, 0, 1, 2], [2, 0, 1, 0]) == Matrix([ + [2, 0, 0, 0], + [0, 0, 3, 0], + [2, 0, 0, 0], + [0, 4, 0, 4]]) + + +def test_multiplication(): + a = Matrix(( + (1, 2), + (3, 1), + (0, 6), + )) + + b = Matrix(( + (1, 2), + (3, 0), + )) + + raises(ShapeError, lambda: b*a) + raises(TypeError, lambda: a*{}) + + c = a*b + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + c = a @ b + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + h = matrix_multiply_elementwise(a, c) + assert h == a.multiply_elementwise(c) + assert h[0, 0] == 7 + assert h[0, 1] == 4 + assert h[1, 0] == 18 + assert h[1, 1] == 6 + assert h[2, 0] == 0 + assert h[2, 1] == 0 + raises(ShapeError, lambda: matrix_multiply_elementwise(a, b)) + + c = b * Symbol("x") + assert isinstance(c, Matrix) + assert c[0, 0] == x + assert c[0, 1] == 2*x + assert c[1, 0] == 3*x + assert c[1, 1] == 0 + + c2 = x * b + assert c == c2 + + c = 5 * b + assert isinstance(c, Matrix) + assert c[0, 0] == 5 + assert c[0, 1] == 2*5 + assert c[1, 0] == 3*5 + assert c[1, 1] == 0 + + M = Matrix([[oo, 0], [0, oo]]) + assert M ** 2 == M + + M = Matrix([[oo, oo], [0, 0]]) + assert M ** 2 == Matrix([[nan, nan], [nan, nan]]) + + # https://github.com/sympy/sympy/issues/22353 + A = Matrix(ones(3, 1)) + _h = -Rational(1, 2) + B = Matrix([_h, _h, _h]) + assert A.multiply_elementwise(B) == Matrix([ + [_h], + [_h], + [_h]]) + + +def test_power(): + raises(NonSquareMatrixError, lambda: Matrix((1, 2))**2) + + A = Matrix([[2, 3], [4, 5]]) + assert A**5 == Matrix([[6140, 8097], [10796, 14237]]) + A = Matrix([[2, 1, 3], [4, 2, 4], [6, 12, 1]]) + assert A**3 == Matrix([[290, 262, 251], [448, 440, 368], [702, 954, 433]]) + assert A**0 == eye(3) + assert A**1 == A + assert (Matrix([[2]]) ** 100)[0, 0] == 2**100 + assert Matrix([[1, 2], [3, 4]])**Integer(2) == Matrix([[7, 10], [15, 22]]) + A = Matrix([[1,2],[4,5]]) + assert A.pow(20, method='cayley') == A.pow(20, method='multiply') + assert A**Integer(2) == Matrix([[9, 12], [24, 33]]) + assert eye(2)**10000000 == eye(2) + + A = Matrix([[33, 24], [48, 57]]) + assert (A**S.Half)[:] == [5, 2, 4, 7] + A = Matrix([[0, 4], [-1, 5]]) + assert (A**S.Half)**2 == A + + assert Matrix([[1, 0], [1, 1]])**S.Half == Matrix([[1, 0], [S.Half, 1]]) + assert Matrix([[1, 0], [1, 1]])**0.5 == Matrix([[1, 0], [0.5, 1]]) + from sympy.abc import n + assert Matrix([[1, a], [0, 1]])**n == Matrix([[1, a*n], [0, 1]]) + assert Matrix([[b, a], [0, b]])**n == Matrix([[b**n, a*b**(n-1)*n], [0, b**n]]) + assert Matrix([ + [a**n, a**(n - 1)*n, (a**n*n**2 - a**n*n)/(2*a**2)], + [ 0, a**n, a**(n - 1)*n], + [ 0, 0, a**n]]) + assert Matrix([[a, 1, 0], [0, a, 0], [0, 0, b]])**n == Matrix([ + [a**n, a**(n-1)*n, 0], + [0, a**n, 0], + [0, 0, b**n]]) + + A = Matrix([[1, 0], [1, 7]]) + assert A._matrix_pow_by_jordan_blocks(S(3)) == A._eval_pow_by_recursion(3) + A = Matrix([[2]]) + assert A**10 == Matrix([[2**10]]) == A._matrix_pow_by_jordan_blocks(S(10)) == \ + A._eval_pow_by_recursion(10) + + # testing a matrix that cannot be jordan blocked issue 11766 + m = Matrix([[3, 0, 0, 0, -3], [0, -3, -3, 0, 3], [0, 3, 0, 3, 0], [0, 0, 3, 0, 3], [3, 0, 0, 3, 0]]) + raises(MatrixError, lambda: m._matrix_pow_by_jordan_blocks(S(10))) + + # test issue 11964 + raises(MatrixError, lambda: Matrix([[1, 1], [3, 3]])._matrix_pow_by_jordan_blocks(S(-10))) + A = Matrix([[0, 1, 0], [0, 0, 1], [0, 0, 0]]) # Nilpotent jordan block size 3 + assert A**10.0 == Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + raises(ValueError, lambda: A**2.1) + raises(ValueError, lambda: A**Rational(3, 2)) + A = Matrix([[8, 1], [3, 2]]) + assert A**10.0 == Matrix([[1760744107, 272388050], [817164150, 126415807]]) + A = Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) # Nilpotent jordan block size 1 + assert A**10.0 == Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + A = Matrix([[0, 1, 0], [0, 0, 1], [0, 0, 1]]) # Nilpotent jordan block size 2 + assert A**10.0 == Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + n = Symbol('n', integer=True) + assert isinstance(A**n, MatPow) + n = Symbol('n', integer=True, negative=True) + raises(ValueError, lambda: A**n) + n = Symbol('n', integer=True, nonnegative=True) + assert A**n == Matrix([ + [KroneckerDelta(0, n), KroneckerDelta(1, n), -KroneckerDelta(0, n) - KroneckerDelta(1, n) + 1], + [ 0, KroneckerDelta(0, n), 1 - KroneckerDelta(0, n)], + [ 0, 0, 1]]) + assert A**(n + 2) == Matrix([[0, 0, 1], [0, 0, 1], [0, 0, 1]]) + raises(ValueError, lambda: A**Rational(3, 2)) + A = Matrix([[0, 0, 1], [3, 0, 1], [4, 3, 1]]) + assert A**5.0 == Matrix([[168, 72, 89], [291, 144, 161], [572, 267, 329]]) + assert A**5.0 == A**5 + A = Matrix([[0, 1, 0],[-1, 0, 0],[0, 0, 0]]) + n = Symbol("n") + An = A**n + assert An.subs(n, 2).doit() == A**2 + raises(ValueError, lambda: An.subs(n, -2).doit()) + assert An * An == A**(2*n) + + # concretizing behavior for non-integer and complex powers + A = Matrix([[0,0,0],[0,0,0],[0,0,0]]) + n = Symbol('n', integer=True, positive=True) + assert A**n == A + n = Symbol('n', integer=True, nonnegative=True) + assert A**n == diag(0**n, 0**n, 0**n) + assert (A**n).subs(n, 0) == eye(3) + assert (A**n).subs(n, 1) == zeros(3) + A = Matrix ([[2,0,0],[0,2,0],[0,0,2]]) + assert A**2.1 == diag (2**2.1, 2**2.1, 2**2.1) + assert A**I == diag (2**I, 2**I, 2**I) + A = Matrix([[0, 1, 0], [0, 0, 1], [0, 0, 1]]) + raises(ValueError, lambda: A**2.1) + raises(ValueError, lambda: A**I) + A = Matrix([[S.Half, S.Half], [S.Half, S.Half]]) + assert A**S.Half == A + A = Matrix([[1, 1],[3, 3]]) + assert A**S.Half == Matrix ([[S.Half, S.Half], [3*S.Half, 3*S.Half]]) + + +def test_issue_17247_expression_blowup_1(): + M = Matrix([[1+x, 1-x], [1-x, 1+x]]) + with dotprodsimp(True): + assert M.exp().expand() == Matrix([ + [ (exp(2*x) + exp(2))/2, (-exp(2*x) + exp(2))/2], + [(-exp(2*x) + exp(2))/2, (exp(2*x) + exp(2))/2]]) + + +def test_issue_17247_expression_blowup_2(): + M = Matrix([[1+x, 1-x], [1-x, 1+x]]) + with dotprodsimp(True): + P, J = M.jordan_form () + assert P*J*P.inv() + + +def test_issue_17247_expression_blowup_3(): + M = Matrix([[1+x, 1-x], [1-x, 1+x]]) + with dotprodsimp(True): + assert M**100 == Matrix([ + [633825300114114700748351602688*x**100 + 633825300114114700748351602688, 633825300114114700748351602688 - 633825300114114700748351602688*x**100], + [633825300114114700748351602688 - 633825300114114700748351602688*x**100, 633825300114114700748351602688*x**100 + 633825300114114700748351602688]]) + + +def test_issue_17247_expression_blowup_4(): +# This matrix takes extremely long on current master even with intermediate simplification so an abbreviated version is used. It is left here for test in case of future optimizations. +# M = Matrix(S('''[ +# [ -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128, 3/64 + 13*I/64, -23/32 - 59*I/256, 15/128 - 3*I/32, 19/256 + 551*I/1024], +# [-149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024, 119/128 + 143*I/128, -10879/2048 + 4343*I/4096, 129/256 - 549*I/512, 42533/16384 + 29103*I/8192], +# [ 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128, 3/64 + 13*I/64, -23/32 - 59*I/256], +# [ -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024, 119/128 + 143*I/128, -10879/2048 + 4343*I/4096], +# [ 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128], +# [ 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024], +# [ -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64], +# [ 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512], +# [ -4*I, 27/2 + 6*I, -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64], +# [ 1/4 + 5*I/2, -23/8 - 57*I/16, 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128], +# [ -4, 9 - 5*I, -4*I, 27/2 + 6*I, -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16], +# [ -2*I, 119/8 + 29*I/4, 1/4 + 5*I/2, -23/8 - 57*I/16, 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128]]''')) +# assert M**10 == Matrix([ +# [ 7*(-221393644768594642173548179825793834595 - 1861633166167425978847110897013541127952*I)/9671406556917033397649408, 15*(31670992489131684885307005100073928751695 + 10329090958303458811115024718207404523808*I)/77371252455336267181195264, 7*(-3710978679372178839237291049477017392703 + 1377706064483132637295566581525806894169*I)/19342813113834066795298816, (9727707023582419994616144751727760051598 - 59261571067013123836477348473611225724433*I)/9671406556917033397649408, (31896723509506857062605551443641668183707 + 54643444538699269118869436271152084599580*I)/38685626227668133590597632, (-2024044860947539028275487595741003997397402 + 130959428791783397562960461903698670485863*I)/309485009821345068724781056, 3*(26190251453797590396533756519358368860907 - 27221191754180839338002754608545400941638*I)/77371252455336267181195264, (1154643595139959842768960128434994698330461 + 3385496216250226964322872072260446072295634*I)/618970019642690137449562112, 3*(-31849347263064464698310044805285774295286 - 11877437776464148281991240541742691164309*I)/77371252455336267181195264, (4661330392283532534549306589669150228040221 - 4171259766019818631067810706563064103956871*I)/1237940039285380274899124224, (9598353794289061833850770474812760144506 + 358027153990999990968244906482319780943983*I)/309485009821345068724781056, (-9755135335127734571547571921702373498554177 - 4837981372692695195747379349593041939686540*I)/2475880078570760549798248448], +# [(-379516731607474268954110071392894274962069 - 422272153179747548473724096872271700878296*I)/77371252455336267181195264, (41324748029613152354787280677832014263339501 - 12715121258662668420833935373453570749288074*I)/1237940039285380274899124224, (-339216903907423793947110742819264306542397 + 494174755147303922029979279454787373566517*I)/77371252455336267181195264, (-18121350839962855576667529908850640619878381 - 37413012454129786092962531597292531089199003*I)/1237940039285380274899124224, (2489661087330511608618880408199633556675926 + 1137821536550153872137379935240732287260863*I)/309485009821345068724781056, (-136644109701594123227587016790354220062972119 + 110130123468183660555391413889600443583585272*I)/4951760157141521099596496896, (1488043981274920070468141664150073426459593 - 9691968079933445130866371609614474474327650*I)/1237940039285380274899124224, 27*(4636797403026872518131756991410164760195942 + 3369103221138229204457272860484005850416533*I)/4951760157141521099596496896, (-8534279107365915284081669381642269800472363 + 2241118846262661434336333368511372725482742*I)/1237940039285380274899124224, (60923350128174260992536531692058086830950875 - 263673488093551053385865699805250505661590126*I)/9903520314283042199192993792, (18520943561240714459282253753348921824172569 + 24846649186468656345966986622110971925703604*I)/4951760157141521099596496896, (-232781130692604829085973604213529649638644431 + 35981505277760667933017117949103953338570617*I)/9903520314283042199192993792], +# [ (8742968295129404279528270438201520488950 + 3061473358639249112126847237482570858327*I)/4835703278458516698824704, (-245657313712011778432792959787098074935273 + 253113767861878869678042729088355086740856*I)/38685626227668133590597632, (1947031161734702327107371192008011621193 - 19462330079296259148177542369999791122762*I)/9671406556917033397649408, (552856485625209001527688949522750288619217 + 392928441196156725372494335248099016686580*I)/77371252455336267181195264, (-44542866621905323121630214897126343414629 + 3265340021421335059323962377647649632959*I)/19342813113834066795298816, (136272594005759723105646069956434264218730 - 330975364731707309489523680957584684763587*I)/38685626227668133590597632, (27392593965554149283318732469825168894401 + 75157071243800133880129376047131061115278*I)/38685626227668133590597632, 7*(-357821652913266734749960136017214096276154 - 45509144466378076475315751988405961498243*I)/309485009821345068724781056, (104485001373574280824835174390219397141149 - 99041000529599568255829489765415726168162*I)/77371252455336267181195264, (1198066993119982409323525798509037696321291 + 4249784165667887866939369628840569844519936*I)/618970019642690137449562112, (-114985392587849953209115599084503853611014 - 52510376847189529234864487459476242883449*I)/77371252455336267181195264, (6094620517051332877965959223269600650951573 - 4683469779240530439185019982269137976201163*I)/1237940039285380274899124224], +# [ (611292255597977285752123848828590587708323 - 216821743518546668382662964473055912169502*I)/77371252455336267181195264, (-1144023204575811464652692396337616594307487 + 12295317806312398617498029126807758490062855*I)/309485009821345068724781056, (-374093027769390002505693378578475235158281 - 573533923565898290299607461660384634333639*I)/77371252455336267181195264, (47405570632186659000138546955372796986832987 - 2837476058950808941605000274055970055096534*I)/1237940039285380274899124224, (-571573207393621076306216726219753090535121 + 533381457185823100878764749236639320783831*I)/77371252455336267181195264, (-7096548151856165056213543560958582513797519 - 24035731898756040059329175131592138642195366*I)/618970019642690137449562112, (2396762128833271142000266170154694033849225 + 1448501087375679588770230529017516492953051*I)/309485009821345068724781056, (-150609293845161968447166237242456473262037053 + 92581148080922977153207018003184520294188436*I)/4951760157141521099596496896, 5*(270278244730804315149356082977618054486347 - 1997830155222496880429743815321662710091562*I)/1237940039285380274899124224, (62978424789588828258068912690172109324360330 + 44803641177219298311493356929537007630129097*I)/2475880078570760549798248448, 19*(-451431106327656743945775812536216598712236 + 114924966793632084379437683991151177407937*I)/1237940039285380274899124224, (63417747628891221594106738815256002143915995 - 261508229397507037136324178612212080871150958*I)/9903520314283042199192993792], +# [ (-2144231934021288786200752920446633703357 + 2305614436009705803670842248131563850246*I)/1208925819614629174706176, (-90720949337459896266067589013987007078153 - 221951119475096403601562347412753844534569*I)/19342813113834066795298816, (11590973613116630788176337262688659880376 + 6514520676308992726483494976339330626159*I)/4835703278458516698824704, 3*(-131776217149000326618649542018343107657237 + 79095042939612668486212006406818285287004*I)/38685626227668133590597632, (10100577916793945997239221374025741184951 - 28631383488085522003281589065994018550748*I)/9671406556917033397649408, 67*(10090295594251078955008130473573667572549 + 10449901522697161049513326446427839676762*I)/77371252455336267181195264, (-54270981296988368730689531355811033930513 - 3413683117592637309471893510944045467443*I)/19342813113834066795298816, (440372322928679910536575560069973699181278 - 736603803202303189048085196176918214409081*I)/77371252455336267181195264, (33220374714789391132887731139763250155295 + 92055083048787219934030779066298919603554*I)/38685626227668133590597632, 5*(-594638554579967244348856981610805281527116 - 82309245323128933521987392165716076704057*I)/309485009821345068724781056, (128056368815300084550013708313312073721955 - 114619107488668120303579745393765245911404*I)/77371252455336267181195264, 21*(59839959255173222962789517794121843393573 + 241507883613676387255359616163487405826334*I)/618970019642690137449562112], +# [ (-13454485022325376674626653802541391955147 + 184471402121905621396582628515905949793486*I)/19342813113834066795298816, (-6158730123400322562149780662133074862437105 - 3416173052604643794120262081623703514107476*I)/154742504910672534362390528, (770558003844914708453618983120686116100419 - 127758381209767638635199674005029818518766*I)/77371252455336267181195264, (-4693005771813492267479835161596671660631703 + 12703585094750991389845384539501921531449948*I)/309485009821345068724781056, (-295028157441149027913545676461260860036601 - 841544569970643160358138082317324743450770*I)/77371252455336267181195264, (56716442796929448856312202561538574275502893 + 7216818824772560379753073185990186711454778*I)/1237940039285380274899124224, 15*(-87061038932753366532685677510172566368387 + 61306141156647596310941396434445461895538*I)/154742504910672534362390528, (-3455315109680781412178133042301025723909347 - 24969329563196972466388460746447646686670670*I)/618970019642690137449562112, (2453418854160886481106557323699250865361849 + 1497886802326243014471854112161398141242514*I)/309485009821345068724781056, (-151343224544252091980004429001205664193082173 + 90471883264187337053549090899816228846836628*I)/4951760157141521099596496896, (1652018205533026103358164026239417416432989 - 9959733619236515024261775397109724431400162*I)/1237940039285380274899124224, 3*(40676374242956907656984876692623172736522006 + 31023357083037817469535762230872667581366205*I)/4951760157141521099596496896], +# [ (-1226990509403328460274658603410696548387 - 4131739423109992672186585941938392788458*I)/1208925819614629174706176, (162392818524418973411975140074368079662703 + 23706194236915374831230612374344230400704*I)/9671406556917033397649408, (-3935678233089814180000602553655565621193 + 2283744757287145199688061892165659502483*I)/1208925819614629174706176, (-2400210250844254483454290806930306285131 - 315571356806370996069052930302295432758205*I)/19342813113834066795298816, (13365917938215281056563183751673390817910 + 15911483133819801118348625831132324863881*I)/4835703278458516698824704, 3*(-215950551370668982657516660700301003897855 + 51684341999223632631602864028309400489378*I)/38685626227668133590597632, (20886089946811765149439844691320027184765 - 30806277083146786592790625980769214361844*I)/9671406556917033397649408, (562180634592713285745940856221105667874855 + 1031543963988260765153550559766662245114916*I)/77371252455336267181195264, (-65820625814810177122941758625652476012867 - 12429918324787060890804395323920477537595*I)/19342813113834066795298816, (319147848192012911298771180196635859221089 - 402403304933906769233365689834404519960394*I)/38685626227668133590597632, (23035615120921026080284733394359587955057 + 115351677687031786114651452775242461310624*I)/38685626227668133590597632, (-3426830634881892756966440108592579264936130 - 1022954961164128745603407283836365128598559*I)/309485009821345068724781056], +# [ (-192574788060137531023716449082856117537757 - 69222967328876859586831013062387845780692*I)/19342813113834066795298816, (2736383768828013152914815341491629299773262 - 2773252698016291897599353862072533475408743*I)/77371252455336267181195264, (-23280005281223837717773057436155921656805 + 214784953368021840006305033048142888879224*I)/19342813113834066795298816, (-3035247484028969580570400133318947903462326 - 2195168903335435855621328554626336958674325*I)/77371252455336267181195264, (984552428291526892214541708637840971548653 - 64006622534521425620714598573494988589378*I)/77371252455336267181195264, (-3070650452470333005276715136041262898509903 + 7286424705750810474140953092161794621989080*I)/154742504910672534362390528, (-147848877109756404594659513386972921139270 - 416306113044186424749331418059456047650861*I)/38685626227668133590597632, (55272118474097814260289392337160619494260781 + 7494019668394781211907115583302403519488058*I)/1237940039285380274899124224, (-581537886583682322424771088996959213068864 + 542191617758465339135308203815256798407429*I)/77371252455336267181195264, (-6422548983676355789975736799494791970390991 - 23524183982209004826464749309156698827737702*I)/618970019642690137449562112, 7*(180747195387024536886923192475064903482083 + 84352527693562434817771649853047924991804*I)/154742504910672534362390528, (-135485179036717001055310712747643466592387031 + 102346575226653028836678855697782273460527608*I)/4951760157141521099596496896], +# [ (3384238362616083147067025892852431152105 + 156724444932584900214919898954874618256*I)/604462909807314587353088, (-59558300950677430189587207338385764871866 + 114427143574375271097298201388331237478857*I)/4835703278458516698824704, (-1356835789870635633517710130971800616227 - 7023484098542340388800213478357340875410*I)/1208925819614629174706176, (234884918567993750975181728413524549575881 + 79757294640629983786895695752733890213506*I)/9671406556917033397649408, (-7632732774935120473359202657160313866419 + 2905452608512927560554702228553291839465*I)/1208925819614629174706176, (52291747908702842344842889809762246649489 - 520996778817151392090736149644507525892649*I)/19342813113834066795298816, (17472406829219127839967951180375981717322 + 23464704213841582137898905375041819568669*I)/4835703278458516698824704, (-911026971811893092350229536132730760943307 + 150799318130900944080399439626714846752360*I)/38685626227668133590597632, (26234457233977042811089020440646443590687 - 45650293039576452023692126463683727692890*I)/9671406556917033397649408, 3*(288348388717468992528382586652654351121357 + 454526517721403048270274049572136109264668*I)/77371252455336267181195264, (-91583492367747094223295011999405657956347 - 12704691128268298435362255538069612411331*I)/19342813113834066795298816, (411208730251327843849027957710164064354221 - 569898526380691606955496789378230959965898*I)/38685626227668133590597632], +# [ (27127513117071487872628354831658811211795 - 37765296987901990355760582016892124833857*I)/4835703278458516698824704, (1741779916057680444272938534338833170625435 + 3083041729779495966997526404685535449810378*I)/77371252455336267181195264, 3*(-60642236251815783728374561836962709533401 - 24630301165439580049891518846174101510744*I)/19342813113834066795298816, 3*(445885207364591681637745678755008757483408 - 350948497734812895032502179455610024541643*I)/38685626227668133590597632, (-47373295621391195484367368282471381775684 + 219122969294089357477027867028071400054973*I)/19342813113834066795298816, (-2801565819673198722993348253876353741520438 - 2250142129822658548391697042460298703335701*I)/77371252455336267181195264, (801448252275607253266997552356128790317119 - 50890367688077858227059515894356594900558*I)/77371252455336267181195264, (-5082187758525931944557763799137987573501207 + 11610432359082071866576699236013484487676124*I)/309485009821345068724781056, (-328925127096560623794883760398247685166830 - 643447969697471610060622160899409680422019*I)/77371252455336267181195264, 15*(2954944669454003684028194956846659916299765 + 33434406416888505837444969347824812608566*I)/1237940039285380274899124224, (-415749104352001509942256567958449835766827 + 479330966144175743357171151440020955412219*I)/77371252455336267181195264, 3*(-4639987285852134369449873547637372282914255 - 11994411888966030153196659207284951579243273*I)/1237940039285380274899124224], +# [ (-478846096206269117345024348666145495601 + 1249092488629201351470551186322814883283*I)/302231454903657293676544, (-17749319421930878799354766626365926894989 - 18264580106418628161818752318217357231971*I)/1208925819614629174706176, (2801110795431528876849623279389579072819 + 363258850073786330770713557775566973248*I)/604462909807314587353088, (-59053496693129013745775512127095650616252 + 78143588734197260279248498898321500167517*I)/4835703278458516698824704, (-283186724922498212468162690097101115349 - 6443437753863179883794497936345437398276*I)/1208925819614629174706176, (188799118826748909206887165661384998787543 + 84274736720556630026311383931055307398820*I)/9671406556917033397649408, (-5482217151670072904078758141270295025989 + 1818284338672191024475557065444481298568*I)/1208925819614629174706176, (56564463395350195513805521309731217952281 - 360208541416798112109946262159695452898431*I)/19342813113834066795298816, 11*(1259539805728870739006416869463689438068 + 1409136581547898074455004171305324917387*I)/4835703278458516698824704, 5*(-123701190701414554945251071190688818343325 + 30997157322590424677294553832111902279712*I)/38685626227668133590597632, (16130917381301373033736295883982414239781 - 32752041297570919727145380131926943374516*I)/9671406556917033397649408, (650301385108223834347093740500375498354925 + 899526407681131828596801223402866051809258*I)/77371252455336267181195264], +# [ (9011388245256140876590294262420614839483 + 8167917972423946282513000869327525382672*I)/1208925819614629174706176, (-426393174084720190126376382194036323028924 + 180692224825757525982858693158209545430621*I)/9671406556917033397649408, (24588556702197802674765733448108154175535 - 45091766022876486566421953254051868331066*I)/4835703278458516698824704, (1872113939365285277373877183750416985089691 + 3030392393733212574744122057679633775773130*I)/77371252455336267181195264, (-222173405538046189185754954524429864167549 - 75193157893478637039381059488387511299116*I)/19342813113834066795298816, (2670821320766222522963689317316937579844558 - 2645837121493554383087981511645435472169191*I)/77371252455336267181195264, 5*(-2100110309556476773796963197283876204940 + 41957457246479840487980315496957337371937*I)/19342813113834066795298816, (-5733743755499084165382383818991531258980593 - 3328949988392698205198574824396695027195732*I)/154742504910672534362390528, (707827994365259025461378911159398206329247 - 265730616623227695108042528694302299777294*I)/77371252455336267181195264, (-1442501604682933002895864804409322823788319 + 11504137805563265043376405214378288793343879*I)/309485009821345068724781056, (-56130472299445561499538726459719629522285 - 61117552419727805035810982426639329818864*I)/9671406556917033397649408, (39053692321126079849054272431599539429908717 - 10209127700342570953247177602860848130710666*I)/1237940039285380274899124224]]) + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512], + [ 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64], + [ -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128], + [ 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16], + [ 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M**10 == Matrix(S('''[ + [ 7369525394972778926719607798014571861/604462909807314587353088 - 229284202061790301477392339912557559*I/151115727451828646838272, -19704281515163975949388435612632058035/1208925819614629174706176 + 14319858347987648723768698170712102887*I/302231454903657293676544, -3623281909451783042932142262164941211/604462909807314587353088 - 6039240602494288615094338643452320495*I/604462909807314587353088, 109260497799140408739847239685705357695/2417851639229258349412352 - 7427566006564572463236368211555511431*I/2417851639229258349412352, -16095803767674394244695716092817006641/2417851639229258349412352 + 10336681897356760057393429626719177583*I/1208925819614629174706176, -42207883340488041844332828574359769743/2417851639229258349412352 - 182332262671671273188016400290188468499*I/4835703278458516698824704], + [50566491050825573392726324995779608259/1208925819614629174706176 - 90047007594468146222002432884052362145*I/2417851639229258349412352, 74273703462900000967697427843983822011/1208925819614629174706176 + 265947522682943571171988741842776095421*I/1208925819614629174706176, -116900341394390200556829767923360888429/2417851639229258349412352 - 53153263356679268823910621474478756845*I/2417851639229258349412352, 195407378023867871243426523048612490249/1208925819614629174706176 - 1242417915995360200584837585002906728929*I/9671406556917033397649408, -863597594389821970177319682495878193/302231454903657293676544 + 476936100741548328800725360758734300481*I/9671406556917033397649408, -3154451590535653853562472176601754835575/19342813113834066795298816 - 232909875490506237386836489998407329215*I/2417851639229258349412352], + [ -1715444997702484578716037230949868543/302231454903657293676544 + 5009695651321306866158517287924120777*I/302231454903657293676544, -30551582497996879620371947949342101301/604462909807314587353088 - 7632518367986526187139161303331519629*I/151115727451828646838272, 312680739924495153190604170938220575/18889465931478580854784 - 108664334509328818765959789219208459*I/75557863725914323419136, -14693696966703036206178521686918865509/604462909807314587353088 + 72345386220900843930147151999899692401*I/1208925819614629174706176, -8218872496728882299722894680635296519/1208925819614629174706176 - 16776782833358893712645864791807664983*I/1208925819614629174706176, 143237839169380078671242929143670635137/2417851639229258349412352 + 2883817094806115974748882735218469447*I/2417851639229258349412352], + [ 3087979417831061365023111800749855987/151115727451828646838272 + 34441942370802869368851419102423997089*I/604462909807314587353088, -148309181940158040917731426845476175667/604462909807314587353088 - 263987151804109387844966835369350904919*I/9671406556917033397649408, 50259518594816377378747711930008883165/1208925819614629174706176 - 95713974916869240305450001443767979653*I/2417851639229258349412352, 153466447023875527996457943521467271119/2417851639229258349412352 + 517285524891117105834922278517084871349*I/2417851639229258349412352, -29184653615412989036678939366291205575/604462909807314587353088 - 27551322282526322041080173287022121083*I/1208925819614629174706176, 196404220110085511863671393922447671649/1208925819614629174706176 - 1204712019400186021982272049902206202145*I/9671406556917033397649408], + [ -2632581805949645784625606590600098779/151115727451828646838272 - 589957435912868015140272627522612771*I/37778931862957161709568, 26727850893953715274702844733506310247/302231454903657293676544 - 10825791956782128799168209600694020481*I/302231454903657293676544, -1036348763702366164044671908440791295/151115727451828646838272 + 3188624571414467767868303105288107375*I/151115727451828646838272, -36814959939970644875593411585393242449/604462909807314587353088 - 18457555789119782404850043842902832647*I/302231454903657293676544, 12454491297984637815063964572803058647/604462909807314587353088 - 340489532842249733975074349495329171*I/302231454903657293676544, -19547211751145597258386735573258916681/604462909807314587353088 + 87299583775782199663414539883938008933*I/1208925819614629174706176], + [ -40281994229560039213253423262678393183/604462909807314587353088 - 2939986850065527327299273003299736641*I/604462909807314587353088, 331940684638052085845743020267462794181/2417851639229258349412352 - 284574901963624403933361315517248458969*I/1208925819614629174706176, 6453843623051745485064693628073010961/302231454903657293676544 + 36062454107479732681350914931391590957*I/604462909807314587353088, -147665869053634695632880753646441962067/604462909807314587353088 - 305987938660447291246597544085345123927*I/9671406556917033397649408, 107821369195275772166593879711259469423/2417851639229258349412352 - 11645185518211204108659001435013326687*I/302231454903657293676544, 64121228424717666402009446088588091619/1208925819614629174706176 + 265557133337095047883844369272389762133*I/1208925819614629174706176]]''')) + + +def test_issue_17247_expression_blowup_5(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.charpoly('x') == PurePoly(x**6 + (-6 - 6*I)*x**5 + 36*I*x**4, x, domain='EX') + + +def test_issue_17247_expression_blowup_6(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.det('bareiss') == 0 + + +def test_issue_17247_expression_blowup_7(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.det('berkowitz') == 0 + + +def test_issue_17247_expression_blowup_8(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.det('lu') == 0 + + +def test_issue_17247_expression_blowup_9(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.rref() == (Matrix([ + [1, 0, -1, -2, -3, -4, -5, -6], + [0, 1, 2, 3, 4, 5, 6, 7], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0]]), (0, 1)) + + +def test_issue_17247_expression_blowup_10(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.cofactor(0, 0) == 0 + + +def test_issue_17247_expression_blowup_11(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.cofactor_matrix() == Matrix(6, 6, [0]*36) + + +def test_issue_17247_expression_blowup_12(): + M = Matrix(6, 6, lambda i, j: 1 + (-1)**(i+j)*I) + with dotprodsimp(True): + assert M.eigenvals() == {6: 1, 6*I: 1, 0: 4} + + +def test_issue_17247_expression_blowup_13(): + M = Matrix([ + [ 0, 1 - x, x + 1, 1 - x], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 1 - x], + [ 0, 0, 1 - x, 0]]) + + ev = M.eigenvects() + assert ev[0] == (0, 2, [Matrix([0, -1, 0, 1])]) + assert ev[1][0] == x - sqrt(2)*(x - 1) + 1 + assert ev[1][1] == 1 + assert ev[1][2][0].expand(deep=False, numer=True) == Matrix([ + [(-x + sqrt(2)*(x - 1) - 1)/(x - 1)], + [-4*x/(x**2 - 2*x + 1) + (x + 1)*(x - sqrt(2)*(x - 1) + 1)/(x**2 - 2*x + 1)], + [(-x + sqrt(2)*(x - 1) - 1)/(x - 1)], + [1] + ]) + + assert ev[2][0] == x + sqrt(2)*(x - 1) + 1 + assert ev[2][1] == 1 + assert ev[2][2][0].expand(deep=False, numer=True) == Matrix([ + [(-x - sqrt(2)*(x - 1) - 1)/(x - 1)], + [-4*x/(x**2 - 2*x + 1) + (x + 1)*(x + sqrt(2)*(x - 1) + 1)/(x**2 - 2*x + 1)], + [(-x - sqrt(2)*(x - 1) - 1)/(x - 1)], + [1] + ]) + + +def test_issue_17247_expression_blowup_14(): + M = Matrix(8, 8, ([1+x, 1-x]*4 + [1-x, 1+x]*4)*4) + with dotprodsimp(True): + assert M.echelon_form() == Matrix([ + [x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x], + [ 0, 4*x, 0, 4*x, 0, 4*x, 0, 4*x], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0]]) + + +def test_issue_17247_expression_blowup_15(): + M = Matrix(8, 8, ([1+x, 1-x]*4 + [1-x, 1+x]*4)*4) + with dotprodsimp(True): + assert M.rowspace() == [Matrix([[x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x, x + 1, 1 - x]]), Matrix([[0, 4*x, 0, 4*x, 0, 4*x, 0, 4*x]])] + + +def test_issue_17247_expression_blowup_16(): + M = Matrix(8, 8, ([1+x, 1-x]*4 + [1-x, 1+x]*4)*4) + with dotprodsimp(True): + assert M.columnspace() == [Matrix([[x + 1],[1 - x],[x + 1],[1 - x],[x + 1],[1 - x],[x + 1],[1 - x]]), Matrix([[1 - x],[x + 1],[1 - x],[x + 1],[1 - x],[x + 1],[1 - x],[x + 1]])] + + +def test_issue_17247_expression_blowup_17(): + M = Matrix(8, 8, [x+i for i in range (64)]) + with dotprodsimp(True): + assert M.nullspace() == [ + Matrix([[1],[-2],[1],[0],[0],[0],[0],[0]]), + Matrix([[2],[-3],[0],[1],[0],[0],[0],[0]]), + Matrix([[3],[-4],[0],[0],[1],[0],[0],[0]]), + Matrix([[4],[-5],[0],[0],[0],[1],[0],[0]]), + Matrix([[5],[-6],[0],[0],[0],[0],[1],[0]]), + Matrix([[6],[-7],[0],[0],[0],[0],[0],[1]])] + + +def test_issue_17247_expression_blowup_18(): + M = Matrix(6, 6, ([1+x, 1-x]*3 + [1-x, 1+x]*3)*3) + with dotprodsimp(True): + assert not M.is_nilpotent() + + +def test_issue_17247_expression_blowup_19(): + M = Matrix(S('''[ + [ -3/4, 0, 1/4 + I/2, 0], + [ 0, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 1/2 - I, 0, 0, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert not M.is_diagonalizable() + + +def test_issue_17247_expression_blowup_20(): + M = Matrix([ + [x + 1, 1 - x, 0, 0], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 0], + [ 0, 0, 0, x + 1]]) + with dotprodsimp(True): + assert M.diagonalize() == (Matrix([ + [1, 1, 0, (x + 1)/(x - 1)], + [1, -1, 0, 0], + [1, 1, 1, 0], + [0, 0, 0, 1]]), + Matrix([ + [2, 0, 0, 0], + [0, 2*x, 0, 0], + [0, 0, x + 1, 0], + [0, 0, 0, x + 1]])) + + +def test_issue_17247_expression_blowup_21(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='GE') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + + +def test_issue_17247_expression_blowup_22(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='LU') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + + +def test_issue_17247_expression_blowup_23(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='ADJ').expand() == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + + +def test_issue_17247_expression_blowup_24(): + M = SparseMatrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='CH') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + + +def test_issue_17247_expression_blowup_25(): + M = SparseMatrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.inv(method='LDL') == Matrix(S('''[ + [-26194832/3470993 - 31733264*I/3470993, 156352/3470993 + 10325632*I/3470993, 0, -7741283181072/3306971225785 + 2999007604624*I/3306971225785], + [4408224/3470993 - 9675328*I/3470993, -2422272/3470993 + 1523712*I/3470993, 0, -1824666489984/3306971225785 - 1401091949952*I/3306971225785], + [-26406945676288/22270005630769 + 10245925485056*I/22270005630769, 7453523312640/22270005630769 + 1601616519168*I/22270005630769, 633088/6416033 - 140288*I/6416033, 872209227109521408/21217636514687010905 + 6066405081802389504*I/21217636514687010905], + [0, 0, 0, -11328/952745 + 87616*I/952745]]''')) + + +def test_issue_17247_expression_blowup_26(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64, -9/32 - I/16, 183/256 - 97*I/128], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512, -219/128 + 115*I/256, 6301/4096 - 6609*I/1024], + [ 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64, 1/4 - 5*I/16, 65/128 + 87*I/64], + [ -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128, 85/256 - 33*I/16, 805/128 + 2415*I/512], + [ 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16, 1/4 + I/2, -129/64 - 9*I/64], + [ 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128, 125/64 + 87*I/64, -2063/256 + 541*I/128], + [ -2, 17/4 - 13*I/2, 1 + I, -19/4 + 5*I/4, 1/2 - I, 9/4 + 55*I/16, -3/4, 45/32 - 37*I/16], + [ 1/4 + 13*I/4, -825/64 - 147*I/32, 21/8 + I, -537/64 + 143*I/16, -5/8 - 39*I/16, 2473/256 + 137*I/64, -149/64 + 49*I/32, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.rank() == 4 + + +def test_issue_17247_expression_blowup_27(): + M = Matrix([ + [ 0, 1 - x, x + 1, 1 - x], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 1 - x], + [ 0, 0, 1 - x, 0]]) + with dotprodsimp(True): + P, J = M.jordan_form() + assert P.expand() == Matrix(S('''[ + [ 0, 4*x/(x**2 - 2*x + 1), -(-17*x**4 + 12*sqrt(2)*x**4 - 4*sqrt(2)*x**3 + 6*x**3 - 6*x - 4*sqrt(2)*x + 12*sqrt(2) + 17)/(-7*x**4 + 5*sqrt(2)*x**4 - 6*sqrt(2)*x**3 + 8*x**3 - 2*x**2 + 8*x + 6*sqrt(2)*x - 5*sqrt(2) - 7), -(12*sqrt(2)*x**4 + 17*x**4 - 6*x**3 - 4*sqrt(2)*x**3 - 4*sqrt(2)*x + 6*x - 17 + 12*sqrt(2))/(7*x**4 + 5*sqrt(2)*x**4 - 6*sqrt(2)*x**3 - 8*x**3 + 2*x**2 - 8*x + 6*sqrt(2)*x - 5*sqrt(2) + 7)], + [x - 1, x/(x - 1) + 1/(x - 1), (-7*x**3 + 5*sqrt(2)*x**3 - x**2 + sqrt(2)*x**2 - sqrt(2)*x - x - 5*sqrt(2) - 7)/(-3*x**3 + 2*sqrt(2)*x**3 - 2*sqrt(2)*x**2 + 3*x**2 + 2*sqrt(2)*x + 3*x - 3 - 2*sqrt(2)), (7*x**3 + 5*sqrt(2)*x**3 + x**2 + sqrt(2)*x**2 - sqrt(2)*x + x - 5*sqrt(2) + 7)/(2*sqrt(2)*x**3 + 3*x**3 - 3*x**2 - 2*sqrt(2)*x**2 - 3*x + 2*sqrt(2)*x - 2*sqrt(2) + 3)], + [ 0, 1, -(-3*x**2 + 2*sqrt(2)*x**2 + 2*x - 3 - 2*sqrt(2))/(-x**2 + sqrt(2)*x**2 - 2*sqrt(2)*x + 1 + sqrt(2)), -(2*sqrt(2)*x**2 + 3*x**2 - 2*x - 2*sqrt(2) + 3)/(x**2 + sqrt(2)*x**2 - 2*sqrt(2)*x - 1 + sqrt(2))], + [1 - x, 0, 1, 1]]''')).expand() + assert J == Matrix(S('''[ + [0, 1, 0, 0], + [0, 0, 0, 0], + [0, 0, x - sqrt(2)*(x - 1) + 1, 0], + [0, 0, 0, x + sqrt(2)*(x - 1) + 1]]''')) + + +def test_issue_17247_expression_blowup_28(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.singular_values() == S('''[ + sqrt(14609315/131072 + sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) + 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2 + sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2), + sqrt(14609315/131072 - sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) + 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2 + sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2), + sqrt(14609315/131072 - sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2 + sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) - 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2), + sqrt(14609315/131072 - sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))/2 - sqrt(64789115132571/2147483648 - 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3) - 76627253330829751075/(35184372088832*sqrt(64789115132571/4294967296 + 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)) + 2*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3))) - 3546944054712886603889144627/(110680464442257309696*(25895222463957462655758224991455280215303/633825300114114700748351602688 + sqrt(1213909058710955930446995195883114969038524625997915131236390724543989220134670)*I/22282920707136844948184236032)**(1/3)))/2)]''') + + +def test_issue_16823(): + # This still needs to be fixed if not using dotprodsimp. + M = Matrix(S('''[ + [1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I,-9/32-1/16*I,183/256-97/128*I,3/64+13/64*I,-23/32-59/256*I,15/128-3/32*I,19/256+551/1024*I], + [21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I,-219/128+115/256*I,6301/4096-6609/1024*I,119/128+143/128*I,-10879/2048+4343/4096*I,129/256-549/512*I,42533/16384+29103/8192*I], + [-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I,-9/32-1/16*I,183/256-97/128*I,3/64+13/64*I,-23/32-59/256*I], + [1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I,-219/128+115/256*I,6301/4096-6609/1024*I,119/128+143/128*I,-10879/2048+4343/4096*I], + [-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I,-9/32-1/16*I,183/256-97/128*I], + [1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I,-219/128+115/256*I,6301/4096-6609/1024*I], + [-4,9-5*I,-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I,1/4-5/16*I,65/128+87/64*I], + [-2*I,119/8+29/4*I,1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I,85/256-33/16*I,805/128+2415/512*I], + [0,-6,-4,9-5*I,-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I,1/4+1/2*I,-129/64-9/64*I], + [1,-9/4+3*I,-2*I,119/8+29/4*I,1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I,125/64+87/64*I,-2063/256+541/128*I], + [0,-4*I,0,-6,-4,9-5*I,-4*I,27/2+6*I,-2,17/4-13/2*I,1+I,-19/4+5/4*I,1/2-I,9/4+55/16*I,-3/4,45/32-37/16*I], + [0,1/4+1/2*I,1,-9/4+3*I,-2*I,119/8+29/4*I,1/4+5/2*I,-23/8-57/16*I,1/4+13/4*I,-825/64-147/32*I,21/8+I,-537/64+143/16*I,-5/8-39/16*I,2473/256+137/64*I,-149/64+49/32*I,-177/128-1369/128*I]]''')) + with dotprodsimp(True): + assert M.rank() == 8 + + +def test_issue_18531(): + # solve_linear_system still needs fixing but the rref works. + M = Matrix([ + [1, 1, 1, 1, 1, 0, 1, 0, 0], + [1 + sqrt(2), -1 + sqrt(2), 1 - sqrt(2), -sqrt(2) - 1, 1, 1, -1, 1, 1], + [-5 + 2*sqrt(2), -5 - 2*sqrt(2), -5 - 2*sqrt(2), -5 + 2*sqrt(2), -7, 2, -7, -2, 0], + [-3*sqrt(2) - 1, 1 - 3*sqrt(2), -1 + 3*sqrt(2), 1 + 3*sqrt(2), -7, -5, 7, -5, 3], + [7 - 4*sqrt(2), 4*sqrt(2) + 7, 4*sqrt(2) + 7, 7 - 4*sqrt(2), 7, -12, 7, 12, 0], + [-1 + 3*sqrt(2), 1 + 3*sqrt(2), -3*sqrt(2) - 1, 1 - 3*sqrt(2), 7, -5, -7, -5, 3], + [-3 + 2*sqrt(2), -3 - 2*sqrt(2), -3 - 2*sqrt(2), -3 + 2*sqrt(2), -1, 2, -1, -2, 0], + [1 - sqrt(2), -sqrt(2) - 1, 1 + sqrt(2), -1 + sqrt(2), -1, 1, 1, 1, 1] + ]) + with dotprodsimp(True): + assert M.rref() == (Matrix([ + [1, 0, 0, 0, 0, 0, 0, 0, S(1)/2], + [0, 1, 0, 0, 0, 0, 0, 0, -S(1)/2], + [0, 0, 1, 0, 0, 0, 0, 0, S(1)/2], + [0, 0, 0, 1, 0, 0, 0, 0, -S(1)/2], + [0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, -S(1)/2], + [0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, -S(1)/2]]), (0, 1, 2, 3, 4, 5, 6, 7)) + + +def test_creation(): + raises(ValueError, lambda: Matrix(5, 5, range(20))) + raises(ValueError, lambda: Matrix(5, -1, [])) + raises(IndexError, lambda: Matrix((1, 2))[2]) + with raises(IndexError): + Matrix((1, 2))[3] = 5 + + assert Matrix() == Matrix([]) == Matrix(0, 0, []) + assert Matrix([[]]) == Matrix(1, 0, []) + assert Matrix([[], []]) == Matrix(2, 0, []) + + # anything used to be allowed in a matrix + with warns_deprecated_sympy(): + assert Matrix([[[1], (2,)]]).tolist() == [[[1], (2,)]] + with warns_deprecated_sympy(): + assert Matrix([[[1], (2,)]]).T.tolist() == [[[1]], [(2,)]] + M = Matrix([[0]]) + with warns_deprecated_sympy(): + M[0, 0] = S.EmptySet + + a = Matrix([[x, 0], [0, 0]]) + m = a + assert m.cols == m.rows + assert m.cols == 2 + assert m[:] == [x, 0, 0, 0] + + b = Matrix(2, 2, [x, 0, 0, 0]) + m = b + assert m.cols == m.rows + assert m.cols == 2 + assert m[:] == [x, 0, 0, 0] + + assert a == b + + assert Matrix(b) == b + + c23 = Matrix(2, 3, range(1, 7)) + c13 = Matrix(1, 3, range(7, 10)) + c = Matrix([c23, c13]) + assert c.cols == 3 + assert c.rows == 3 + assert c[:] == [1, 2, 3, 4, 5, 6, 7, 8, 9] + + assert Matrix(eye(2)) == eye(2) + assert ImmutableMatrix(ImmutableMatrix(eye(2))) == ImmutableMatrix(eye(2)) + assert ImmutableMatrix(c) == c.as_immutable() + assert Matrix(ImmutableMatrix(c)) == ImmutableMatrix(c).as_mutable() + + assert c is not Matrix(c) + + dat = [[ones(3,2), ones(3,3)*2], [ones(2,3)*3, ones(2,2)*4]] + M = Matrix(dat) + assert M == Matrix([ + [1, 1, 2, 2, 2], + [1, 1, 2, 2, 2], + [1, 1, 2, 2, 2], + [3, 3, 3, 4, 4], + [3, 3, 3, 4, 4]]) + assert M.tolist() != dat + # keep block form if evaluate=False + assert Matrix(dat, evaluate=False).tolist() == dat + A = MatrixSymbol("A", 2, 2) + dat = [ones(2), A] + assert Matrix(dat) == Matrix([ + [ 1, 1], + [ 1, 1], + [A[0, 0], A[0, 1]], + [A[1, 0], A[1, 1]]]) + with warns_deprecated_sympy(): + assert Matrix(dat, evaluate=False).tolist() == [[i] for i in dat] + + # 0-dim tolerance + assert Matrix([ones(2), ones(0)]) == Matrix([ones(2)]) + raises(ValueError, lambda: Matrix([ones(2), ones(0, 3)])) + raises(ValueError, lambda: Matrix([ones(2), ones(3, 0)])) + + # mix of Matrix and iterable + M = Matrix([[1, 2], [3, 4]]) + M2 = Matrix([M, (5, 6)]) + assert M2 == Matrix([[1, 2], [3, 4], [5, 6]]) + + +def test_irregular_block(): + assert Matrix.irregular(3, ones(2,1), ones(3,3)*2, ones(2,2)*3, + ones(1,1)*4, ones(2,2)*5, ones(1,2)*6, ones(1,2)*7) == Matrix([ + [1, 2, 2, 2, 3, 3], + [1, 2, 2, 2, 3, 3], + [4, 2, 2, 2, 5, 5], + [6, 6, 7, 7, 5, 5]]) + + +def test_slicing(): + m0 = eye(4) + assert m0[:3, :3] == eye(3) + assert m0[2:4, 0:2] == zeros(2) + + m1 = Matrix(3, 3, lambda i, j: i + j) + assert m1[0, :] == Matrix(1, 3, (0, 1, 2)) + assert m1[1:3, 1] == Matrix(2, 1, (2, 3)) + + m2 = Matrix([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]) + assert m2[:, -1] == Matrix(4, 1, [3, 7, 11, 15]) + assert m2[-2:, :] == Matrix([[8, 9, 10, 11], [12, 13, 14, 15]]) + + +def test_submatrix_assignment(): + m = zeros(4) + m[2:4, 2:4] = eye(2) + assert m == Matrix(((0, 0, 0, 0), + (0, 0, 0, 0), + (0, 0, 1, 0), + (0, 0, 0, 1))) + m[:2, :2] = eye(2) + assert m == eye(4) + m[:, 0] = Matrix(4, 1, (1, 2, 3, 4)) + assert m == Matrix(((1, 0, 0, 0), + (2, 1, 0, 0), + (3, 0, 1, 0), + (4, 0, 0, 1))) + m[:, :] = zeros(4) + assert m == zeros(4) + m[:, :] = [(1, 2, 3, 4), (5, 6, 7, 8), (9, 10, 11, 12), (13, 14, 15, 16)] + assert m == Matrix(((1, 2, 3, 4), + (5, 6, 7, 8), + (9, 10, 11, 12), + (13, 14, 15, 16))) + m[:2, 0] = [0, 0] + assert m == Matrix(((0, 2, 3, 4), + (0, 6, 7, 8), + (9, 10, 11, 12), + (13, 14, 15, 16))) + + +def test_reshape(): + m0 = eye(3) + assert m0.reshape(1, 9) == Matrix(1, 9, (1, 0, 0, 0, 1, 0, 0, 0, 1)) + m1 = Matrix(3, 4, lambda i, j: i + j) + assert m1.reshape( + 4, 3) == Matrix(((0, 1, 2), (3, 1, 2), (3, 4, 2), (3, 4, 5))) + assert m1.reshape(2, 6) == Matrix(((0, 1, 2, 3, 1, 2), (3, 4, 2, 3, 4, 5))) + + +def test_applyfunc(): + m0 = eye(3) + assert m0.applyfunc(lambda x: 2*x) == eye(3)*2 + assert m0.applyfunc(lambda x: 0) == zeros(3) + + +def test_expand(): + m0 = Matrix([[x*(x + y), 2], [((x + y)*y)*x, x*(y + x*(x + y))]]) + # Test if expand() returns a matrix + m1 = m0.expand() + assert m1 == Matrix( + [[x*y + x**2, 2], [x*y**2 + y*x**2, x*y + y*x**2 + x**3]]) + + a = Symbol('a', real=True) + + assert Matrix([exp(I*a)]).expand(complex=True) == \ + Matrix([cos(a) + I*sin(a)]) + + assert Matrix([[0, 1, 2], [0, 0, -1], [0, 0, 0]]).exp() == Matrix([ + [1, 1, Rational(3, 2)], + [0, 1, -1], + [0, 0, 1]] + ) + + +def test_refine(): + m0 = Matrix([[Abs(x)**2, sqrt(x**2)], + [sqrt(x**2)*Abs(y)**2, sqrt(y**2)*Abs(x)**2]]) + m1 = m0.refine(Q.real(x) & Q.real(y)) + assert m1 == Matrix([[x**2, Abs(x)], [y**2*Abs(x), x**2*Abs(y)]]) + + m1 = m0.refine(Q.positive(x) & Q.positive(y)) + assert m1 == Matrix([[x**2, x], [x*y**2, x**2*y]]) + + m1 = m0.refine(Q.negative(x) & Q.negative(y)) + assert m1 == Matrix([[x**2, -x], [-x*y**2, -x**2*y]]) + + +def test_random(): + M = randMatrix(3, 3) + M = randMatrix(3, 3, seed=3) + assert M == randMatrix(3, 3, seed=3) + + M = randMatrix(3, 4, 0, 150) + M = randMatrix(3, seed=4, symmetric=True) + assert M == randMatrix(3, seed=4, symmetric=True) + + S = M.copy() + S.simplify() + assert S == M # doesn't fail when elements are Numbers, not int + + rng = random.Random(4) + assert M == randMatrix(3, symmetric=True, prng=rng) + + # Ensure symmetry + for size in (10, 11): # Test odd and even + for percent in (100, 70, 30): + M = randMatrix(size, symmetric=True, percent=percent, prng=rng) + assert M == M.T + + M = randMatrix(10, min=1, percent=70) + zero_count = 0 + for i in range(M.shape[0]): + for j in range(M.shape[1]): + if M[i, j] == 0: + zero_count += 1 + assert zero_count == 30 + + +def test_inverse(): + A = eye(4) + assert A.inv() == eye(4) + assert A.inv(method="LU") == eye(4) + assert A.inv(method="ADJ") == eye(4) + assert A.inv(method="CH") == eye(4) + assert A.inv(method="LDL") == eye(4) + assert A.inv(method="QR") == eye(4) + A = Matrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + Ainv = A.inv() + assert A*Ainv == eye(3) + assert A.inv(method="LU") == Ainv + assert A.inv(method="ADJ") == Ainv + assert A.inv(method="CH") == Ainv + assert A.inv(method="LDL") == Ainv + assert A.inv(method="QR") == Ainv + + AA = Matrix([[0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0], + [1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0], + [1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1], + [1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0], + [1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1], + [0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0], + [1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1], + [0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1], + [1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0], + [0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0], + [1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0], + [0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 1], + [1, 1, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0], + [0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 0], + [1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1], + [0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1], + [1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1], + [0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1], + [0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1], + [0, 0, 0, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0], + [0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0]]) + assert AA.inv(method="BLOCK") * AA == eye(AA.shape[0]) + # test that immutability is not a problem + cls = ImmutableMatrix + m = cls([[48, 49, 31], + [ 9, 71, 94], + [59, 28, 65]]) + assert all(type(m.inv(s)) is cls for s in 'GE ADJ LU CH LDL QR'.split()) + cls = ImmutableSparseMatrix + m = cls([[48, 49, 31], + [ 9, 71, 94], + [59, 28, 65]]) + assert all(type(m.inv(s)) is cls for s in 'GE ADJ LU CH LDL QR'.split()) + + +def test_inverse_symbolic_float_issue_26821(): + Tau, Tau_syn_in, Tau_syn_ex, C_m, Tau_syn_gap = symbols("Tau Tau_syn_in Tau_syn_ex C_m Tau_syn_gap") + __h = symbols("__h") + + M = Matrix([ + [0,0,0,0,0,(1.0*Tau*__h-1.0*Tau_syn_in*__h)/(2.0*Tau-1.0*Tau_syn_in),-1.0*Tau*Tau_syn_in/(2.0*Tau-1.0*Tau_syn_in)], + [0,0,0,0,0,(-1.0*Tau*__h+1.0*Tau_syn_in*__h)/(2.0*Tau*Tau_syn_in-1.0*Tau_syn_in**2),1.0], + [0,(1.0*Tau*__h-1.0*Tau_syn_ex*__h)/(2.0*Tau-1.0*Tau_syn_ex),-1.0*Tau*Tau_syn_ex/(2.0*Tau-1.0*Tau_syn_ex),0,0,0,0], + [0,(-1.0*Tau*__h+1.0*Tau_syn_ex*__h)/(2.0*Tau*Tau_syn_ex-1.0*Tau_syn_ex**2),1.0,0,0,0,0], + [0,0,0,(1.0*Tau*__h-1.0*Tau_syn_gap*__h)/(2.0*Tau-1.0*Tau_syn_gap),-1.0*Tau*Tau_syn_gap/(2.0*Tau-1.0*Tau_syn_gap),0,0], + [0,0,0,(-1.0*Tau*__h+1.0*Tau_syn_gap*__h)/(2.0*Tau*Tau_syn_gap-1.0*Tau_syn_gap**2),1.0,0,0], + [1.0,-1.0*Tau*Tau_syn_ex*__h/(2.0*C_m*Tau-1.0*C_m*Tau_syn_ex),0,-1.0*Tau*Tau_syn_gap*__h/(2.0*C_m*Tau-1.0*C_m*Tau_syn_gap),0,-1.0*Tau*Tau_syn_in*__h/(2.0*C_m*Tau-1.0*C_m*Tau_syn_in),0] + ]) + + Mi = M.inv() + + assert (M*Mi - eye(7)).applyfunc(cancel) == zeros(7) + + # https://github.com/sympy/sympy/issues/26821 + # Previously very large floats were in the result. + assert max(abs(f) for f in Mi.atoms(Float)) < 1e3 + + +@slow +def test_matrix_exponential_issue_26821(): + # The symbol names matter in the original bug... + a, b, c, d, e = symbols("Tau, Tau_syn_in, Tau_syn_ex, C_m, Tau_syn_gap") + t = symbols("__h") + M = Matrix([ + [ 0, 1.0, 0, 0, 0, 0, 0], + [-1/b**2, -2/b, 0, 0, 0, 0, 0], + [ 0, 0, 0, 1.0, 0, 0, 0], + [ 0, 0, -1/c**2, -2/c, 0, 0, 0], + [ 0, 0, 0, 0, 0, 1, 0], + [ 0, 0, 0, 0, -1/e**2, -2/e, 0], + [ 1/d, 0, 1/d, 0, 1/d, 0, -1/a] + ]) + + Me = (t*M).exp() + assert (Me.diff(t) - M*Me).applyfunc(cancel) == zeros(7) + # https://github.com/sympy/sympy/issues/26821 + # Previously very large floats were in the result. + assert max(abs(f) for f in Me.atoms(Float)) < 1e3 + + +def test_jacobian_hessian(): + L = Matrix(1, 2, [x**2*y, 2*y**2 + x*y]) + syms = [x, y] + assert L.jacobian(syms) == Matrix([[2*x*y, x**2], [y, 4*y + x]]) + + L = Matrix(1, 2, [x, x**2*y**3]) + assert L.jacobian(syms) == Matrix([[1, 0], [2*x*y**3, x**2*3*y**2]]) + + f = x**2*y + syms = [x, y] + assert hessian(f, syms) == Matrix([[2*y, 2*x], [2*x, 0]]) + + f = x**2*y**3 + assert hessian(f, syms) == \ + Matrix([[2*y**3, 6*x*y**2], [6*x*y**2, 6*x**2*y]]) + + f = z + x*y**2 + g = x**2 + 2*y**3 + ans = Matrix([[0, 2*y], + [2*y, 2*x]]) + assert ans == hessian(f, Matrix([x, y])) + assert ans == hessian(f, Matrix([x, y]).T) + assert hessian(f, (y, x), [g]) == Matrix([ + [ 0, 6*y**2, 2*x], + [6*y**2, 2*x, 2*y], + [ 2*x, 2*y, 0]]) + + +def test_wronskian(): + assert wronskian([cos(x), sin(x)], x) == cos(x)**2 + sin(x)**2 + assert wronskian([exp(x), exp(2*x)], x) == exp(3*x) + assert wronskian([exp(x), x], x) == exp(x) - x*exp(x) + assert wronskian([1, x, x**2], x) == 2 + w1 = -6*exp(x)*sin(x)*x + 6*cos(x)*exp(x)*x**2 - 6*exp(x)*cos(x)*x - \ + exp(x)*cos(x)*x**3 + exp(x)*sin(x)*x**3 + assert wronskian([exp(x), cos(x), x**3], x).expand() == w1 + assert wronskian([exp(x), cos(x), x**3], x, method='berkowitz').expand() \ + == w1 + w2 = -x**3*cos(x)**2 - x**3*sin(x)**2 - 6*x*cos(x)**2 - 6*x*sin(x)**2 + assert wronskian([sin(x), cos(x), x**3], x).expand() == w2 + assert wronskian([sin(x), cos(x), x**3], x, method='berkowitz').expand() \ + == w2 + assert wronskian([], x) == 1 + + +def test_xreplace(): + assert Matrix([[1, x], [x, 4]]).xreplace({x: 5}) == \ + Matrix([[1, 5], [5, 4]]) + assert Matrix([[x, 2], [x + y, 4]]).xreplace({x: -1, y: -2}) == \ + Matrix([[-1, 2], [-3, 4]]) + for cls in all_classes: + assert Matrix([[2, 0], [0, 2]]) == cls.eye(2).xreplace({1: 2}) + + +def test_simplify(): + n = Symbol('n') + f = Function('f') + + M = Matrix([[ 1/x + 1/y, (x + x*y) / x ], + [ (f(x) + y*f(x))/f(x), 2 * (1/n - cos(n * pi)/n) / pi ]]) + M.simplify() + assert M == Matrix([[ (x + y)/(x * y), 1 + y ], + [ 1 + y, 2*((1 - 1*cos(pi*n))/(pi*n)) ]]) + eq = (1 + x)**2 + M = Matrix([[eq]]) + M.simplify() + assert M == Matrix([[eq]]) + M.simplify(ratio=oo) + assert M == Matrix([[eq.simplify(ratio=oo)]]) + + n = Symbol('n') + f = Function('f') + + M = ImmutableMatrix([ + [ 1/x + 1/y, (x + x*y) / x ], + [ (f(x) + y*f(x))/f(x), 2 * (1/n - cos(n * pi)/n) / pi ] + ]) + assert M.simplify() == Matrix([ + [ (x + y)/(x * y), 1 + y ], + [ 1 + y, 2*((1 - 1*cos(pi*n))/(pi*n)) ] + ]) + + eq = (1 + x)**2 + M = ImmutableMatrix([[eq]]) + assert M.simplify() == Matrix([[eq]]) + assert M.simplify(ratio=oo) == Matrix([[eq.simplify(ratio=oo)]]) + + assert simplify(ImmutableMatrix([[sin(x)**2 + cos(x)**2]])) == \ + ImmutableMatrix([[1]]) + + # https://github.com/sympy/sympy/issues/19353 + m = Matrix([[30, 2], [3, 4]]) + assert (1/(m.trace())).simplify() == Rational(1, 34) + +def test_transpose(): + M = Matrix([[1, 2, 3, 4, 5, 6, 7, 8, 9, 0], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]]) + assert M.T == Matrix( [ [1, 1], + [2, 2], + [3, 3], + [4, 4], + [5, 5], + [6, 6], + [7, 7], + [8, 8], + [9, 9], + [0, 0] ]) + assert M.T.T == M + assert M.T == M.transpose() + + +def test_conj_dirac(): + raises(AttributeError, lambda: eye(3).D) + + M = Matrix([[1, I, I, I], + [0, 1, I, I], + [0, 0, 1, I], + [0, 0, 0, 1]]) + + assert M.D == Matrix([[ 1, 0, 0, 0], + [-I, 1, 0, 0], + [-I, -I, -1, 0], + [-I, -I, I, -1]]) + + +def test_trace(): + M = Matrix([[1, 0, 0], + [0, 5, 0], + [0, 0, 8]]) + assert M.trace() == 14 + + +def test_shape(): + m = Matrix(1, 2, [0, 0]) + assert m.shape == (1, 2) + M = Matrix([[x, 0, 0], + [0, y, 0]]) + assert M.shape == (2, 3) + + +def test_col_row_op(): + M = Matrix([[x, 0, 0], + [0, y, 0]]) + M.row_op(1, lambda r, j: r + j + 1) + assert M == Matrix([[x, 0, 0], + [1, y + 2, 3]]) + + M.col_op(0, lambda c, j: c + y**j) + assert M == Matrix([[x + 1, 0, 0], + [1 + y, y + 2, 3]]) + + # neither row nor slice give copies that allow the original matrix to + # be changed + assert M.row(0) == Matrix([[x + 1, 0, 0]]) + r1 = M.row(0) + r1[0] = 42 + assert M[0, 0] == x + 1 + r1 = M[0, :-1] # also testing negative slice + r1[0] = 42 + assert M[0, 0] == x + 1 + c1 = M.col(0) + assert c1 == Matrix([x + 1, 1 + y]) + c1[0] = 0 + assert M[0, 0] == x + 1 + c1 = M[:, 0] + c1[0] = 42 + assert M[0, 0] == x + 1 + + +def test_row_mult(): + M = Matrix([[1,2,3], + [4,5,6]]) + M.row_mult(1,3) + assert M[1,0] == 12 + assert M[0,0] == 1 + assert M[1,2] == 18 + + +def test_row_add(): + M = Matrix([[1,2,3], + [4,5,6], + [1,1,1]]) + M.row_add(2,0,5) + assert M[0,0] == 6 + assert M[1,0] == 4 + assert M[0,2] == 8 + + +def test_zip_row_op(): + for cls in mutable_classes: # XXX: immutable matrices don't support row ops + M = cls.eye(3) + M.zip_row_op(1, 0, lambda v, u: v + 2*u) + assert M == cls([[1, 0, 0], + [2, 1, 0], + [0, 0, 1]]) + + M = cls.eye(3)*2 + M[0, 1] = -1 + M.zip_row_op(1, 0, lambda v, u: v + 2*u); M + assert M == cls([[2, -1, 0], + [4, 0, 0], + [0, 0, 2]]) + + +def test_issue_3950(): + m = Matrix([1, 2, 3]) + a = Matrix([1, 2, 3]) + b = Matrix([2, 2, 3]) + assert not (m in []) + assert not (m in [1]) + assert m != 1 + assert m == a + assert m != b + + +def test_issue_3981(): + class Index1: + def __index__(self): + return 1 + + class Index2: + def __index__(self): + return 2 + index1 = Index1() + index2 = Index2() + + m = Matrix([1, 2, 3]) + + assert m[index2] == 3 + + m[index2] = 5 + assert m[2] == 5 + + m = Matrix([[1, 2, 3], [4, 5, 6]]) + assert m[index1, index2] == 6 + assert m[1, index2] == 6 + assert m[index1, 2] == 6 + + m[index1, index2] = 4 + assert m[1, 2] == 4 + m[1, index2] = 6 + assert m[1, 2] == 6 + m[index1, 2] = 8 + assert m[1, 2] == 8 + + +def test_is_upper(): + a = Matrix([[1, 2, 3]]) + assert a.is_upper is True + a = Matrix([[1], [2], [3]]) + assert a.is_upper is False + a = zeros(4, 2) + assert a.is_upper is True + + +def test_is_lower(): + a = Matrix([[1, 2, 3]]) + assert a.is_lower is False + a = Matrix([[1], [2], [3]]) + assert a.is_lower is True + + +def test_is_nilpotent(): + a = Matrix(4, 4, [0, 2, 1, 6, 0, 0, 1, 2, 0, 0, 0, 3, 0, 0, 0, 0]) + assert a.is_nilpotent() + a = Matrix([[1, 0], [0, 1]]) + assert not a.is_nilpotent() + a = Matrix([]) + assert a.is_nilpotent() + + +def test_zeros_ones_fill(): + n, m = 3, 5 + + a = zeros(n, m) + a.fill( 5 ) + + b = 5 * ones(n, m) + + assert a == b + assert a.rows == b.rows == 3 + assert a.cols == b.cols == 5 + assert a.shape == b.shape == (3, 5) + assert zeros(2) == zeros(2, 2) + assert ones(2) == ones(2, 2) + assert zeros(2, 3) == Matrix(2, 3, [0]*6) + assert ones(2, 3) == Matrix(2, 3, [1]*6) + + a.fill(0) + assert a == zeros(n, m) + + +def test_empty_zeros(): + a = zeros(0) + assert a == Matrix() + a = zeros(0, 2) + assert a.rows == 0 + assert a.cols == 2 + a = zeros(2, 0) + assert a.rows == 2 + assert a.cols == 0 + + +def test_issue_3749(): + a = Matrix([[x**2, x*y], [x*sin(y), x*cos(y)]]) + assert a.diff(x) == Matrix([[2*x, y], [sin(y), cos(y)]]) + assert Matrix([ + [x, -x, x**2], + [exp(x), 1/x - exp(-x), x + 1/x]]).limit(x, oo) == \ + Matrix([[oo, -oo, oo], [oo, 0, oo]]) + assert Matrix([ + [(exp(x) - 1)/x, 2*x + y*x, x**x ], + [1/x, abs(x), abs(sin(x + 1))]]).limit(x, 0) == \ + Matrix([[1, 0, 1], [oo, 0, sin(1)]]) + assert a.integrate(x) == Matrix([ + [Rational(1, 3)*x**3, y*x**2/2], + [x**2*sin(y)/2, x**2*cos(y)/2]]) + + +def test_inv_iszerofunc(): + A = eye(4) + A.col_swap(0, 1) + for method in "GE", "LU": + assert A.inv(method=method, iszerofunc=lambda x: x == 0) == \ + A.inv(method="ADJ") + + +def test_jacobian_metrics(): + rho, phi = symbols("rho,phi") + X = Matrix([rho*cos(phi), rho*sin(phi)]) + Y = Matrix([rho, phi]) + J = X.jacobian(Y) + assert J == X.jacobian(Y.T) + assert J == (X.T).jacobian(Y) + assert J == (X.T).jacobian(Y.T) + g = J.T*eye(J.shape[0])*J + g = g.applyfunc(trigsimp) + assert g == Matrix([[1, 0], [0, rho**2]]) + + +def test_jacobian2(): + rho, phi = symbols("rho,phi") + X = Matrix([rho*cos(phi), rho*sin(phi), rho**2]) + Y = Matrix([rho, phi]) + J = Matrix([ + [cos(phi), -rho*sin(phi)], + [sin(phi), rho*cos(phi)], + [ 2*rho, 0], + ]) + assert X.jacobian(Y) == J + + +def test_issue_4564(): + X = Matrix([exp(x + y + z), exp(x + y + z), exp(x + y + z)]) + Y = Matrix([x, y, z]) + for i in range(1, 3): + for j in range(1, 3): + X_slice = X[:i, :] + Y_slice = Y[:j, :] + J = X_slice.jacobian(Y_slice) + assert J.rows == i + assert J.cols == j + for k in range(j): + assert J[:, k] == X_slice + + +def test_nonvectorJacobian(): + X = Matrix([[exp(x + y + z), exp(x + y + z)], + [exp(x + y + z), exp(x + y + z)]]) + raises(TypeError, lambda: X.jacobian(Matrix([x, y, z]))) + X = X[0, :] + Y = Matrix([[x, y], [x, z]]) + raises(TypeError, lambda: X.jacobian(Y)) + raises(TypeError, lambda: X.jacobian(Matrix([ [x, y], [x, z] ]))) + + +def test_vec(): + m = Matrix([[1, 3], [2, 4]]) + m_vec = m.vec() + assert m_vec.cols == 1 + for i in range(4): + assert m_vec[i] == i + 1 + + +def test_vech(): + m = Matrix([[1, 2], [2, 3]]) + m_vech = m.vech() + assert m_vech.cols == 1 + for i in range(3): + assert m_vech[i] == i + 1 + m_vech = m.vech(diagonal=False) + assert m_vech[0] == 2 + + m = Matrix([[1, x*(x + y)], [y*x + x**2, 1]]) + m_vech = m.vech(diagonal=False) + assert m_vech[0] == y*x + x**2 + + m = Matrix([[1, x*(x + y)], [y*x, 1]]) + m_vech = m.vech(diagonal=False, check_symmetry=False) + assert m_vech[0] == y*x + + raises(ShapeError, lambda: Matrix([[1, 3]]).vech()) + raises(ValueError, lambda: Matrix([[1, 3], [2, 4]]).vech()) + raises(ShapeError, lambda: Matrix([[1, 3]]).vech()) + raises(ValueError, lambda: Matrix([[1, 3], [2, 4]]).vech()) + + +def test_diag(): + # mostly tested in testcommonmatrix.py + assert diag([1, 2, 3]) == Matrix([1, 2, 3]) + m = [1, 2, [3]] + raises(ValueError, lambda: diag(m)) + assert diag(m, strict=False) == Matrix([1, 2, 3]) + + +def test_inv_block(): + a = Matrix([[1, 2], [2, 3]]) + b = Matrix([[3, x], [y, 3]]) + c = Matrix([[3, x, 3], [y, 3, z], [x, y, z]]) + A = diag(a, b, b) + assert A.inv(try_block_diag=True) == diag(a.inv(), b.inv(), b.inv()) + A = diag(a, b, c) + assert A.inv(try_block_diag=True) == diag(a.inv(), b.inv(), c.inv()) + A = diag(a, c, b) + assert A.inv(try_block_diag=True) == diag(a.inv(), c.inv(), b.inv()) + A = diag(a, a, b, a, c, a) + assert A.inv(try_block_diag=True) == diag( + a.inv(), a.inv(), b.inv(), a.inv(), c.inv(), a.inv()) + assert A.inv(try_block_diag=True, method="ADJ") == diag( + a.inv(method="ADJ"), a.inv(method="ADJ"), b.inv(method="ADJ"), + a.inv(method="ADJ"), c.inv(method="ADJ"), a.inv(method="ADJ")) + + +def test_creation_args(): + """ + Check that matrix dimensions can be specified using any reasonable type + (see issue 4614). + """ + raises(ValueError, lambda: zeros(3, -1)) + raises(TypeError, lambda: zeros(1, 2, 3, 4)) + assert zeros(int(3)) == zeros(3) + assert zeros(Integer(3)) == zeros(3) + raises(ValueError, lambda: zeros(3.)) + assert eye(int(3)) == eye(3) + assert eye(Integer(3)) == eye(3) + raises(ValueError, lambda: eye(3.)) + assert ones(int(3), Integer(4)) == ones(3, 4) + raises(TypeError, lambda: Matrix(5)) + raises(TypeError, lambda: Matrix(1, 2)) + raises(ValueError, lambda: Matrix([1, [2]])) + + +def test_diagonal_symmetrical(): + m = Matrix(2, 2, [0, 1, 1, 0]) + assert not m.is_diagonal() + assert m.is_symmetric() + assert m.is_symmetric(simplify=False) + + m = Matrix(2, 2, [1, 0, 0, 1]) + assert m.is_diagonal() + + m = diag(1, 2, 3) + assert m.is_diagonal() + assert m.is_symmetric() + + m = Matrix(3, 3, [1, 0, 0, 0, 2, 0, 0, 0, 3]) + assert m == diag(1, 2, 3) + + m = Matrix(2, 3, zeros(2, 3)) + assert not m.is_symmetric() + assert m.is_diagonal() + + m = Matrix(((5, 0), (0, 6), (0, 0))) + assert m.is_diagonal() + + m = Matrix(((5, 0, 0), (0, 6, 0))) + assert m.is_diagonal() + + m = Matrix(3, 3, [1, x**2 + 2*x + 1, y, (x + 1)**2, 2, 0, y, 0, 3]) + assert m.is_symmetric() + assert not m.is_symmetric(simplify=False) + assert m.expand().is_symmetric(simplify=False) + + +def test_diagonalization(): + m = Matrix([[1, 2+I], [2-I, 3]]) + assert m.is_diagonalizable() + + m = Matrix(3, 2, [-3, 1, -3, 20, 3, 10]) + assert not m.is_diagonalizable() + assert not m.is_symmetric() + raises(NonSquareMatrixError, lambda: m.diagonalize()) + + # diagonalizable + m = diag(1, 2, 3) + (P, D) = m.diagonalize() + assert P == eye(3) + assert D == m + + m = Matrix(2, 2, [0, 1, 1, 0]) + assert m.is_symmetric() + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + + m = Matrix(2, 2, [1, 0, 0, 3]) + assert m.is_symmetric() + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + assert P == eye(2) + assert D == m + + m = Matrix(2, 2, [1, 1, 0, 0]) + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + + m = Matrix(3, 3, [1, 2, 0, 0, 3, 0, 2, -4, 2]) + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + for i in P: + assert i.as_numer_denom()[1] == 1 + + m = Matrix(2, 2, [1, 0, 0, 0]) + assert m.is_diagonal() + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + assert P == Matrix([[0, 1], [1, 0]]) + + # diagonalizable, complex only + m = Matrix(2, 2, [0, 1, -1, 0]) + assert not m.is_diagonalizable(True) + raises(MatrixError, lambda: m.diagonalize(True)) + assert m.is_diagonalizable() + (P, D) = m.diagonalize() + assert P.inv() * m * P == D + + # not diagonalizable + m = Matrix(2, 2, [0, 1, 0, 0]) + assert not m.is_diagonalizable() + raises(MatrixError, lambda: m.diagonalize()) + + m = Matrix(3, 3, [-3, 1, -3, 20, 3, 10, 2, -2, 4]) + assert not m.is_diagonalizable() + raises(MatrixError, lambda: m.diagonalize()) + + # symbolic + a, b, c, d = symbols('a b c d') + m = Matrix(2, 2, [a, c, c, b]) + assert m.is_symmetric() + assert m.is_diagonalizable() + + +def test_issue_15887(): + # Mutable matrix should not use cache + a = MutableDenseMatrix([[0, 1], [1, 0]]) + assert a.is_diagonalizable() is True + a[1, 0] = 0 + assert a.is_diagonalizable() is False + + a = MutableDenseMatrix([[0, 1], [1, 0]]) + a.diagonalize() + a[1, 0] = 0 + raises(MatrixError, lambda: a.diagonalize()) + + +def test_jordan_form(): + + m = Matrix(3, 2, [-3, 1, -3, 20, 3, 10]) + raises(NonSquareMatrixError, lambda: m.jordan_form()) + + # diagonalizable + m = Matrix(3, 3, [7, -12, 6, 10, -19, 10, 12, -24, 13]) + Jmust = Matrix(3, 3, [-1, 0, 0, 0, 1, 0, 0, 0, 1]) + P, J = m.jordan_form() + assert Jmust == J + assert Jmust == m.diagonalize()[1] + + # m = Matrix(3, 3, [0, 6, 3, 1, 3, 1, -2, 2, 1]) + # m.jordan_form() # very long + # m.jordan_form() # + + # diagonalizable, complex only + + # Jordan cells + # complexity: one of eigenvalues is zero + m = Matrix(3, 3, [0, 1, 0, -4, 4, 0, -2, 1, 2]) + # The blocks are ordered according to the value of their eigenvalues, + # in order to make the matrix compatible with .diagonalize() + Jmust = Matrix(3, 3, [2, 1, 0, 0, 2, 0, 0, 0, 2]) + P, J = m.jordan_form() + assert Jmust == J + + # complexity: all of eigenvalues are equal + m = Matrix(3, 3, [2, 6, -15, 1, 1, -5, 1, 2, -6]) + # Jmust = Matrix(3, 3, [-1, 0, 0, 0, -1, 1, 0, 0, -1]) + # same here see 1456ff + Jmust = Matrix(3, 3, [-1, 1, 0, 0, -1, 0, 0, 0, -1]) + P, J = m.jordan_form() + assert Jmust == J + + # complexity: two of eigenvalues are zero + m = Matrix(3, 3, [4, -5, 2, 5, -7, 3, 6, -9, 4]) + Jmust = Matrix(3, 3, [0, 1, 0, 0, 0, 0, 0, 0, 1]) + P, J = m.jordan_form() + assert Jmust == J + + m = Matrix(4, 4, [6, 5, -2, -3, -3, -1, 3, 3, 2, 1, -2, -3, -1, 1, 5, 5]) + Jmust = Matrix(4, 4, [2, 1, 0, 0, + 0, 2, 0, 0, + 0, 0, 2, 1, + 0, 0, 0, 2] + ) + P, J = m.jordan_form() + assert Jmust == J + + m = Matrix(4, 4, [6, 2, -8, -6, -3, 2, 9, 6, 2, -2, -8, -6, -1, 0, 3, 4]) + # Jmust = Matrix(4, 4, [2, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0, -2]) + # same here see 1456ff + Jmust = Matrix(4, 4, [-2, 0, 0, 0, + 0, 2, 1, 0, + 0, 0, 2, 0, + 0, 0, 0, 2]) + P, J = m.jordan_form() + assert Jmust == J + + m = Matrix(4, 4, [5, 4, 2, 1, 0, 1, -1, -1, -1, -1, 3, 0, 1, 1, -1, 2]) + assert not m.is_diagonalizable() + Jmust = Matrix(4, 4, [1, 0, 0, 0, 0, 2, 0, 0, 0, 0, 4, 1, 0, 0, 0, 4]) + P, J = m.jordan_form() + assert Jmust == J + + # checking for maximum precision to remain unchanged + m = Matrix([[Float('1.0', precision=110), Float('2.0', precision=110)], + [Float('3.14159265358979323846264338327', precision=110), Float('4.0', precision=110)]]) + P, J = m.jordan_form() + for term in J.values(): + if isinstance(term, Float): + assert term._prec == 110 + + +def test_jordan_form_complex_issue_9274(): + A = Matrix([[ 2, 4, 1, 0], + [-4, 2, 0, 1], + [ 0, 0, 2, 4], + [ 0, 0, -4, 2]]) + p = 2 - 4*I + q = 2 + 4*I + Jmust1 = Matrix([[p, 1, 0, 0], + [0, p, 0, 0], + [0, 0, q, 1], + [0, 0, 0, q]]) + Jmust2 = Matrix([[q, 1, 0, 0], + [0, q, 0, 0], + [0, 0, p, 1], + [0, 0, 0, p]]) + P, J = A.jordan_form() + assert J == Jmust1 or J == Jmust2 + assert simplify(P*J*P.inv()) == A + + +def test_issue_10220(): + # two non-orthogonal Jordan blocks with eigenvalue 1 + M = Matrix([[1, 0, 0, 1], + [0, 1, 1, 0], + [0, 0, 1, 1], + [0, 0, 0, 1]]) + P, J = M.jordan_form() + assert P == Matrix([[0, 1, 0, 1], + [1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0]]) + assert J == Matrix([ + [1, 1, 0, 0], + [0, 1, 1, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]) + + +def test_jordan_form_issue_15858(): + A = Matrix([ + [1, 1, 1, 0], + [-2, -1, 0, -1], + [0, 0, -1, -1], + [0, 0, 2, 1]]) + (P, J) = A.jordan_form() + assert P.expand() == Matrix([ + [ -I, -I/2, I, I/2], + [-1 + I, 0, -1 - I, 0], + [ 0, -S(1)/2 - I/2, 0, -S(1)/2 + I/2], + [ 0, 1, 0, 1]]) + assert J == Matrix([ + [-I, 1, 0, 0], + [0, -I, 0, 0], + [0, 0, I, 1], + [0, 0, 0, I]]) + + +def test_Matrix_berkowitz_charpoly(): + UA, K_i, K_w = symbols('UA K_i K_w') + + A = Matrix([[-K_i - UA + K_i**2/(K_i + K_w), K_i*K_w/(K_i + K_w)], + [ K_i*K_w/(K_i + K_w), -K_w + K_w**2/(K_i + K_w)]]) + + charpoly = A.charpoly(x) + + assert charpoly == \ + Poly(x**2 + (K_i*UA + K_w*UA + 2*K_i*K_w)/(K_i + K_w)*x + + K_i*K_w*UA/(K_i + K_w), x, domain='ZZ(K_i,K_w,UA)') + + assert type(charpoly) is PurePoly + + A = Matrix([[1, 3], [2, 0]]) + assert A.charpoly() == A.charpoly(x) == PurePoly(x**2 - x - 6) + + A = Matrix([[1, 2], [x, 0]]) + p = A.charpoly(x) + assert p.gen != x + assert p.as_expr().subs(p.gen, x) == x**2 - 3*x + + +def test_exp_jordan_block(): + l = Symbol('lamda') + + m = Matrix.jordan_block(1, l) + assert m._eval_matrix_exp_jblock() == Matrix([[exp(l)]]) + + m = Matrix.jordan_block(3, l) + assert m._eval_matrix_exp_jblock() == \ + Matrix([ + [exp(l), exp(l), exp(l)/2], + [0, exp(l), exp(l)], + [0, 0, exp(l)]]) + + +def test_exp(): + m = Matrix([[3, 4], [0, -2]]) + m_exp = Matrix([[exp(3), -4*exp(-2)/5 + 4*exp(3)/5], [0, exp(-2)]]) + assert m.exp() == m_exp + assert exp(m) == m_exp + + m = Matrix([[1, 0], [0, 1]]) + assert m.exp() == Matrix([[E, 0], [0, E]]) + assert exp(m) == Matrix([[E, 0], [0, E]]) + + m = Matrix([[1, -1], [1, 1]]) + assert m.exp() == Matrix([[E*cos(1), -E*sin(1)], [E*sin(1), E*cos(1)]]) + + +def test_log(): + l = Symbol('lamda') + + m = Matrix.jordan_block(1, l) + assert m._eval_matrix_log_jblock() == Matrix([[log(l)]]) + + m = Matrix.jordan_block(4, l) + assert m._eval_matrix_log_jblock() == \ + Matrix( + [ + [log(l), 1/l, -1/(2*l**2), 1/(3*l**3)], + [0, log(l), 1/l, -1/(2*l**2)], + [0, 0, log(l), 1/l], + [0, 0, 0, log(l)] + ] + ) + + m = Matrix( + [[0, 0, 1], + [0, 0, 0], + [-1, 0, 0]] + ) + raises(MatrixError, lambda: m.log()) + + +def test_find_reasonable_pivot_naive_finds_guaranteed_nonzero1(): + # Test if matrices._find_reasonable_pivot_naive() + # finds a guaranteed non-zero pivot when the + # some of the candidate pivots are symbolic expressions. + # Keyword argument: simpfunc=None indicates that no simplifications + # should be performed during the search. + x = Symbol('x') + column = Matrix(3, 1, [x, cos(x)**2 + sin(x)**2, S.Half]) + pivot_offset, pivot_val, pivot_assumed_nonzero, simplified =\ + _find_reasonable_pivot_naive(column) + assert pivot_val == S.Half + + +def test_find_reasonable_pivot_naive_finds_guaranteed_nonzero2(): + # Test if matrices._find_reasonable_pivot_naive() + # finds a guaranteed non-zero pivot when the + # some of the candidate pivots are symbolic expressions. + # Keyword argument: simpfunc=_simplify indicates that the search + # should attempt to simplify candidate pivots. + x = Symbol('x') + column = Matrix(3, 1, + [x, + cos(x)**2+sin(x)**2+x**2, + cos(x)**2+sin(x)**2]) + pivot_offset, pivot_val, pivot_assumed_nonzero, simplified =\ + _find_reasonable_pivot_naive(column, simpfunc=_simplify) + assert pivot_val == 1 + + +def test_find_reasonable_pivot_naive_simplifies(): + # Test if matrices._find_reasonable_pivot_naive() + # simplifies candidate pivots, and reports + # their offsets correctly. + x = Symbol('x') + column = Matrix(3, 1, + [x, + cos(x)**2+sin(x)**2+x, + cos(x)**2+sin(x)**2]) + pivot_offset, pivot_val, pivot_assumed_nonzero, simplified =\ + _find_reasonable_pivot_naive(column, simpfunc=_simplify) + + assert len(simplified) == 2 + assert simplified[0][0] == 1 + assert simplified[0][1] == 1+x + assert simplified[1][0] == 2 + assert simplified[1][1] == 1 + + +def test_errors(): + raises(ValueError, lambda: Matrix([[1, 2], [1]])) + raises(IndexError, lambda: Matrix([[1, 2]])[1.2, 5]) + raises(IndexError, lambda: Matrix([[1, 2]])[1, 5.2]) + raises(ValueError, lambda: randMatrix(3, c=4, symmetric=True)) + raises(ValueError, lambda: Matrix([1, 2]).reshape(4, 6)) + raises(ShapeError, + lambda: Matrix([[1, 2], [3, 4]]).copyin_matrix([1, 0], Matrix([1, 2]))) + raises(TypeError, lambda: Matrix([[1, 2], [3, 4]]).copyin_list([0, + 1], set())) + raises(NonSquareMatrixError, lambda: Matrix([[1, 2, 3], [2, 3, 0]]).inv()) + raises(ShapeError, + lambda: Matrix(1, 2, [1, 2]).row_join(Matrix([[1, 2], [3, 4]]))) + raises( + ShapeError, lambda: Matrix([1, 2]).col_join(Matrix([[1, 2], [3, 4]]))) + raises(ShapeError, lambda: Matrix([1]).row_insert(1, Matrix([[1, + 2], [3, 4]]))) + raises(ShapeError, lambda: Matrix([1]).col_insert(1, Matrix([[1, + 2], [3, 4]]))) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).trace()) + raises(TypeError, lambda: Matrix([1]).applyfunc(1)) + raises(ValueError, lambda: Matrix([[1, 2], [3, 4]]).minor(4, 5)) + raises(ValueError, lambda: Matrix([[1, 2], [3, 4]]).minor_submatrix(4, 5)) + raises(TypeError, lambda: Matrix([1, 2, 3]).cross(1)) + raises(TypeError, lambda: Matrix([1, 2, 3]).dot(1)) + raises(ShapeError, lambda: Matrix([1, 2, 3]).dot(Matrix([1, 2]))) + raises(ShapeError, lambda: Matrix([1, 2]).dot([])) + raises(TypeError, lambda: Matrix([1, 2]).dot('a')) + raises(ShapeError, lambda: Matrix([1, 2]).dot([1, 2, 3])) + raises(NonSquareMatrixError, lambda: Matrix([1, 2, 3]).exp()) + raises(ShapeError, lambda: Matrix([[1, 2], [3, 4]]).normalized()) + raises(ValueError, lambda: Matrix([1, 2]).inv(method='not a method')) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).inverse_GE()) + raises(ValueError, lambda: Matrix([[1, 2], [1, 2]]).inverse_GE()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).inverse_ADJ()) + raises(ValueError, lambda: Matrix([[1, 2], [1, 2]]).inverse_ADJ()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).inverse_LU()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).is_nilpotent()) + raises(NonSquareMatrixError, lambda: Matrix([1, 2]).det()) + raises(ValueError, + lambda: Matrix([[1, 2], [3, 4]]).det(method='Not a real method')) + raises(ValueError, + lambda: Matrix([[1, 2, 3, 4], [5, 6, 7, 8], + [9, 10, 11, 12], [13, 14, 15, 16]]).det(iszerofunc="Not function")) + raises(ValueError, + lambda: Matrix([[1, 2, 3, 4], [5, 6, 7, 8], + [9, 10, 11, 12], [13, 14, 15, 16]]).det(iszerofunc=False)) + raises(ValueError, + lambda: hessian(Matrix([[1, 2], [3, 4]]), Matrix([[1, 2], [2, 1]]))) + raises(ValueError, lambda: hessian(Matrix([[1, 2], [3, 4]]), [])) + raises(ValueError, lambda: hessian(Symbol('x')**2, 'a')) + raises(IndexError, lambda: eye(3)[5, 2]) + raises(IndexError, lambda: eye(3)[2, 5]) + M = Matrix(((1, 2, 3, 4), (5, 6, 7, 8), (9, 10, 11, 12), (13, 14, 15, 16))) + raises(ValueError, lambda: M.det('method=LU_decomposition()')) + V = Matrix([[10, 10, 10]]) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(ValueError, lambda: M.row_insert(4.7, V)) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(ValueError, lambda: M.col_insert(-4.2, V)) + + +def test_len(): + assert len(Matrix()) == 0 + assert len(Matrix([[1, 2]])) == len(Matrix([[1], [2]])) == 2 + assert len(Matrix(0, 2, lambda i, j: 0)) == \ + len(Matrix(2, 0, lambda i, j: 0)) == 0 + assert len(Matrix([[0, 1, 2], [3, 4, 5]])) == 6 + assert Matrix([1]) == Matrix([[1]]) + assert not Matrix() + assert Matrix() == Matrix([]) + + +def test_integrate(): + A = Matrix(((1, 4, x), (y, 2, 4), (10, 5, x**2))) + assert A.integrate(x) == \ + Matrix(((x, 4*x, x**2/2), (x*y, 2*x, 4*x), (10*x, 5*x, x**3/3))) + assert A.integrate(y) == \ + Matrix(((y, 4*y, x*y), (y**2/2, 2*y, 4*y), (10*y, 5*y, y*x**2))) + m = Matrix(2, 1, [x, y]) + assert m.integrate(x) == Matrix(2, 1, [x**2/2, y*x]) + + +def test_diff(): + A = MutableDenseMatrix(((1, 4, x), (y, 2, 4), (10, 5, x**2 + 1))) + assert isinstance(A.diff(x), type(A)) + assert A.diff(x) == MutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert A.diff(y) == MutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + assert diff(A, x) == MutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert diff(A, y) == MutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + A_imm = A.as_immutable() + assert isinstance(A_imm.diff(x), type(A_imm)) + assert A_imm.diff(x) == ImmutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert A_imm.diff(y) == ImmutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + assert diff(A_imm, x) == ImmutableDenseMatrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + assert diff(A_imm, y) == ImmutableDenseMatrix(((0, 0, 0), (1, 0, 0), (0, 0, 0))) + + assert A.diff(x, evaluate=False) == ArrayDerivative(A, x, evaluate=False) + assert diff(A, x, evaluate=False) == ArrayDerivative(A, x, evaluate=False) + + +def test_diff_by_matrix(): + + # Derive matrix by matrix: + + A = MutableDenseMatrix([[x, y], [z, t]]) + assert A.diff(A) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + assert diff(A, A) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + A_imm = A.as_immutable() + assert A_imm.diff(A_imm) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + assert diff(A_imm, A_imm) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + # Derive a constant matrix: + assert A.diff(a) == MutableDenseMatrix([[0, 0], [0, 0]]) + + B = ImmutableDenseMatrix([a, b]) + assert A.diff(B) == Array.zeros(2, 1, 2, 2) + assert A.diff(A) == Array([[[[1, 0], [0, 0]], [[0, 1], [0, 0]]], [[[0, 0], [1, 0]], [[0, 0], [0, 1]]]]) + + # Test diff with tuples: + + dB = B.diff([[a, b]]) + assert dB.shape == (2, 2, 1) + assert dB == Array([[[1], [0]], [[0], [1]]]) + + f = Function("f") + fxyz = f(x, y, z) + assert fxyz.diff([[x, y, z]]) == Array([fxyz.diff(x), fxyz.diff(y), fxyz.diff(z)]) + assert fxyz.diff(([x, y, z], 2)) == Array([ + [fxyz.diff(x, 2), fxyz.diff(x, y), fxyz.diff(x, z)], + [fxyz.diff(x, y), fxyz.diff(y, 2), fxyz.diff(y, z)], + [fxyz.diff(x, z), fxyz.diff(z, y), fxyz.diff(z, 2)], + ]) + + expr = sin(x)*exp(y) + assert expr.diff([[x, y]]) == Array([cos(x)*exp(y), sin(x)*exp(y)]) + assert expr.diff(y, ((x, y),)) == Array([cos(x)*exp(y), sin(x)*exp(y)]) + assert expr.diff(x, ((x, y),)) == Array([-sin(x)*exp(y), cos(x)*exp(y)]) + assert expr.diff(((y, x),), [[x, y]]) == Array([[cos(x)*exp(y), -sin(x)*exp(y)], [sin(x)*exp(y), cos(x)*exp(y)]]) + + # Test different notations: + + assert fxyz.diff(x).diff(y).diff(x) == fxyz.diff(((x, y, z),), 3)[0, 1, 0] + assert fxyz.diff(z).diff(y).diff(x) == fxyz.diff(((x, y, z),), 3)[2, 1, 0] + assert fxyz.diff([[x, y, z]], ((z, y, x),)) == Array([[fxyz.diff(i).diff(j) for i in (x, y, z)] for j in (z, y, x)]) + + # Test scalar derived by matrix remains matrix: + res = x.diff(Matrix([[x, y]])) + assert isinstance(res, ImmutableDenseMatrix) + assert res == Matrix([[1, 0]]) + res = (x**3).diff(Matrix([[x, y]])) + assert isinstance(res, ImmutableDenseMatrix) + assert res == Matrix([[3*x**2, 0]]) + + +def test_getattr(): + A = Matrix(((1, 4, x), (y, 2, 4), (10, 5, x**2 + 1))) + raises(AttributeError, lambda: A.nonexistantattribute) + assert getattr(A, 'diff')(x) == Matrix(((0, 0, 1), (0, 0, 0), (0, 0, 2*x))) + + +def test_hessenberg(): + A = Matrix([[3, 4, 1], [2, 4, 5], [0, 1, 2]]) + assert A.is_upper_hessenberg + A = A.T + assert A.is_lower_hessenberg + A[0, -1] = 1 + assert A.is_lower_hessenberg is False + + A = Matrix([[3, 4, 1], [2, 4, 5], [3, 1, 2]]) + assert not A.is_upper_hessenberg + + A = zeros(5, 2) + assert A.is_upper_hessenberg + + +def test_cholesky(): + raises(NonSquareMatrixError, lambda: Matrix((1, 2)).cholesky()) + raises(ValueError, lambda: Matrix(((1, 2), (3, 4))).cholesky()) + raises(ValueError, lambda: Matrix(((5 + I, 0), (0, 1))).cholesky()) + raises(ValueError, lambda: Matrix(((1, 5), (5, 1))).cholesky()) + raises(ValueError, lambda: Matrix(((1, 2), (3, 4))).cholesky(hermitian=False)) + assert Matrix(((5 + I, 0), (0, 1))).cholesky(hermitian=False) == Matrix([ + [sqrt(5 + I), 0], [0, 1]]) + A = Matrix(((1, 5), (5, 1))) + L = A.cholesky(hermitian=False) + assert L == Matrix([[1, 0], [5, 2*sqrt(6)*I]]) + assert L*L.T == A + A = Matrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L = A.cholesky() + assert L * L.T == A + assert L.is_lower + assert L == Matrix([[5, 0, 0], [3, 3, 0], [-1, 1, 3]]) + A = Matrix(((4, -2*I, 2 + 2*I), (2*I, 2, -1 + I), (2 - 2*I, -1 - I, 11))) + assert A.cholesky().expand() == Matrix(((2, 0, 0), (I, 1, 0), (1 - I, 0, 3))) + + raises(NonSquareMatrixError, lambda: SparseMatrix((1, 2)).cholesky()) + raises(ValueError, lambda: SparseMatrix(((1, 2), (3, 4))).cholesky()) + raises(ValueError, lambda: SparseMatrix(((5 + I, 0), (0, 1))).cholesky()) + raises(ValueError, lambda: SparseMatrix(((1, 5), (5, 1))).cholesky()) + raises(ValueError, lambda: SparseMatrix(((1, 2), (3, 4))).cholesky(hermitian=False)) + assert SparseMatrix(((5 + I, 0), (0, 1))).cholesky(hermitian=False) == Matrix([ + [sqrt(5 + I), 0], [0, 1]]) + A = SparseMatrix(((1, 5), (5, 1))) + L = A.cholesky(hermitian=False) + assert L == Matrix([[1, 0], [5, 2*sqrt(6)*I]]) + assert L*L.T == A + A = SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L = A.cholesky() + assert L * L.T == A + assert L.is_lower + assert L == Matrix([[5, 0, 0], [3, 3, 0], [-1, 1, 3]]) + A = SparseMatrix(((4, -2*I, 2 + 2*I), (2*I, 2, -1 + I), (2 - 2*I, -1 - I, 11))) + assert A.cholesky() == Matrix(((2, 0, 0), (I, 1, 0), (1 - I, 0, 3))) + + +def test_matrix_norm(): + # Vector Tests + # Test columns and symbols + x = Symbol('x', real=True) + v = Matrix([cos(x), sin(x)]) + assert trigsimp(v.norm(2)) == 1 + assert v.norm(10) == Pow(cos(x)**10 + sin(x)**10, Rational(1, 10)) + + # Test Rows + A = Matrix([[5, Rational(3, 2)]]) + assert A.norm() == Pow(25 + Rational(9, 4), S.Half) + assert A.norm(oo) == max(A) + assert A.norm(-oo) == min(A) + + # Matrix Tests + # Intuitive test + A = Matrix([[1, 1], [1, 1]]) + assert A.norm(2) == 2 + assert A.norm(-2) == 0 + assert A.norm('frobenius') == 2 + assert eye(10).norm(2) == eye(10).norm(-2) == 1 + assert A.norm(oo) == 2 + + # Test with Symbols and more complex entries + A = Matrix([[3, y, y], [x, S.Half, -pi]]) + assert (A.norm('fro') + == sqrt(Rational(37, 4) + 2*abs(y)**2 + pi**2 + x**2)) + + # Check non-square + A = Matrix([[1, 2, -3], [4, 5, Rational(13, 2)]]) + assert A.norm(2) == sqrt(Rational(389, 8) + sqrt(78665)/8) + assert A.norm(-2) is S.Zero + assert A.norm('frobenius') == sqrt(389)/2 + + # Test properties of matrix norms + # https://en.wikipedia.org/wiki/Matrix_norm#Definition + # Two matrices + A = Matrix([[1, 2], [3, 4]]) + B = Matrix([[5, 5], [-2, 2]]) + C = Matrix([[0, -I], [I, 0]]) + D = Matrix([[1, 0], [0, -1]]) + L = [A, B, C, D] + alpha = Symbol('alpha', real=True) + + for order in ['fro', 2, -2]: + # Zero Check + assert zeros(3).norm(order) is S.Zero + # Check Triangle Inequality for all Pairs of Matrices + for X in L: + for Y in L: + dif = (X.norm(order) + Y.norm(order) - + (X + Y).norm(order)) + assert (dif >= 0) + # Scalar multiplication linearity + for M in [A, B, C, D]: + dif = simplify((alpha*M).norm(order) - + abs(alpha) * M.norm(order)) + assert dif == 0 + + # Test Properties of Vector Norms + # https://en.wikipedia.org/wiki/Vector_norm + # Two column vectors + a = Matrix([1, 1 - 1*I, -3]) + b = Matrix([S.Half, 1*I, 1]) + c = Matrix([-1, -1, -1]) + d = Matrix([3, 2, I]) + e = Matrix([Integer(1e2), Rational(1, 1e2), 1]) + L = [a, b, c, d, e] + alpha = Symbol('alpha', real=True) + + for order in [1, 2, -1, -2, S.Infinity, S.NegativeInfinity, pi]: + # Zero Check + if order > 0: + assert Matrix([0, 0, 0]).norm(order) is S.Zero + # Triangle inequality on all pairs + if order >= 1: # Triangle InEq holds only for these norms + for X in L: + for Y in L: + dif = (X.norm(order) + Y.norm(order) - + (X + Y).norm(order)) + assert simplify(dif >= 0) is S.true + # Linear to scalar multiplication + if order in [1, 2, -1, -2, S.Infinity, S.NegativeInfinity]: + for X in L: + dif = simplify((alpha*X).norm(order) - + (abs(alpha) * X.norm(order))) + assert dif == 0 + + # ord=1 + M = Matrix(3, 3, [1, 3, 0, -2, -1, 0, 3, 9, 6]) + assert M.norm(1) == 13 + + +def test_condition_number(): + x = Symbol('x', real=True) + A = eye(3) + A[0, 0] = 10 + A[2, 2] = Rational(1, 10) + assert A.condition_number() == 100 + + A[1, 1] = x + assert A.condition_number() == Max(10, Abs(x)) / Min(Rational(1, 10), Abs(x)) + + M = Matrix([[cos(x), sin(x)], [-sin(x), cos(x)]]) + Mc = M.condition_number() + assert all(Float(1.).epsilon_eq(Mc.subs(x, val).evalf()) for val in + [Rational(1, 5), S.Half, Rational(1, 10), pi/2, pi, pi*Rational(7, 4) ]) + + #issue 10782 + assert Matrix([]).condition_number() == 0 + + +def test_equality(): + A = Matrix(((1, 2, 3), (4, 5, 6), (7, 8, 9))) + B = Matrix(((9, 8, 7), (6, 5, 4), (3, 2, 1))) + assert A == A[:, :] + assert not A != A[:, :] + assert not A == B + assert A != B + assert A != 10 + assert not A == 10 + + # A SparseMatrix can be equal to a Matrix + C = SparseMatrix(((1, 0, 0), (0, 1, 0), (0, 0, 1))) + D = Matrix(((1, 0, 0), (0, 1, 0), (0, 0, 1))) + assert C == D + assert not C != D + + +def test_normalized(): + assert Matrix([3, 4]).normalized() == \ + Matrix([Rational(3, 5), Rational(4, 5)]) + + # Zero vector trivial cases + assert Matrix([0, 0, 0]).normalized() == Matrix([0, 0, 0]) + + # Machine precision error truncation trivial cases + m = Matrix([0,0,1.e-100]) + assert m.normalized( + iszerofunc=lambda x: x.evalf(n=10, chop=True).is_zero + ) == Matrix([0, 0, 0]) + + +def test_print_nonzero(): + assert capture(lambda: eye(3).print_nonzero()) == \ + '[X ]\n[ X ]\n[ X]\n' + assert capture(lambda: eye(3).print_nonzero('.')) == \ + '[. ]\n[ . ]\n[ .]\n' + + +def test_zeros_eye(): + assert Matrix.eye(3) == eye(3) + assert Matrix.zeros(3) == zeros(3) + assert ones(3, 4) == Matrix(3, 4, [1]*12) + + i = Matrix([[1, 0], [0, 1]]) + z = Matrix([[0, 0], [0, 0]]) + for cls in all_classes: + m = cls.eye(2) + assert i == m # but m == i will fail if m is immutable + assert i == eye(2, cls=cls) + assert type(m) == cls + m = cls.zeros(2) + assert z == m + assert z == zeros(2, cls=cls) + assert type(m) == cls + + +def test_is_zero(): + assert Matrix().is_zero_matrix + assert Matrix([[0, 0], [0, 0]]).is_zero_matrix + assert zeros(3, 4).is_zero_matrix + assert not eye(3).is_zero_matrix + assert Matrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert SparseMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert ImmutableMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert ImmutableSparseMatrix([[x, 0], [0, 0]]).is_zero_matrix == None + assert Matrix([[x, 1], [0, 0]]).is_zero_matrix == False + a = Symbol('a', nonzero=True) + assert Matrix([[a, 0], [0, 0]]).is_zero_matrix == False + + +def test_rotation_matrices(): + # This tests the rotation matrices by rotating about an axis and back. + theta = pi/3 + r3_plus = rot_axis3(theta) + r3_minus = rot_axis3(-theta) + r2_plus = rot_axis2(theta) + r2_minus = rot_axis2(-theta) + r1_plus = rot_axis1(theta) + r1_minus = rot_axis1(-theta) + assert r3_minus*r3_plus*eye(3) == eye(3) + assert r2_minus*r2_plus*eye(3) == eye(3) + assert r1_minus*r1_plus*eye(3) == eye(3) + + # Check the correctness of the trace of the rotation matrix + assert r1_plus.trace() == 1 + 2*cos(theta) + assert r2_plus.trace() == 1 + 2*cos(theta) + assert r3_plus.trace() == 1 + 2*cos(theta) + + # Check that a rotation with zero angle doesn't change anything. + assert rot_axis1(0) == eye(3) + assert rot_axis2(0) == eye(3) + assert rot_axis3(0) == eye(3) + + # Check left-hand convention + # see Issue #24529 + q1 = Quaternion.from_axis_angle([1, 0, 0], pi / 2) + q2 = Quaternion.from_axis_angle([0, 1, 0], pi / 2) + q3 = Quaternion.from_axis_angle([0, 0, 1], pi / 2) + assert rot_axis1(- pi / 2) == q1.to_rotation_matrix() + assert rot_axis2(- pi / 2) == q2.to_rotation_matrix() + assert rot_axis3(- pi / 2) == q3.to_rotation_matrix() + # Check right-hand convention + assert rot_ccw_axis1(+ pi / 2) == q1.to_rotation_matrix() + assert rot_ccw_axis2(+ pi / 2) == q2.to_rotation_matrix() + assert rot_ccw_axis3(+ pi / 2) == q3.to_rotation_matrix() + + +def test_DeferredVector(): + assert str(DeferredVector("vector")[4]) == "vector[4]" + assert sympify(DeferredVector("d")) == DeferredVector("d") + raises(IndexError, lambda: DeferredVector("d")[-1]) + assert str(DeferredVector("d")) == "d" + assert repr(DeferredVector("test")) == "DeferredVector('test')" + + +def test_DeferredVector_not_iterable(): + assert not iterable(DeferredVector('X')) + + +def test_DeferredVector_Matrix(): + raises(TypeError, lambda: Matrix(DeferredVector("V"))) + + +def test_GramSchmidt(): + R = Rational + m1 = Matrix(1, 2, [1, 2]) + m2 = Matrix(1, 2, [2, 3]) + assert GramSchmidt([m1, m2]) == \ + [Matrix(1, 2, [1, 2]), Matrix(1, 2, [R(2)/5, R(-1)/5])] + assert GramSchmidt([m1.T, m2.T]) == \ + [Matrix(2, 1, [1, 2]), Matrix(2, 1, [R(2)/5, R(-1)/5])] + # from wikipedia + assert GramSchmidt([Matrix([3, 1]), Matrix([2, 2])], True) == [ + Matrix([3*sqrt(10)/10, sqrt(10)/10]), + Matrix([-sqrt(10)/10, 3*sqrt(10)/10])] + # https://github.com/sympy/sympy/issues/9488 + L = FiniteSet(Matrix([1])) + assert GramSchmidt(L) == [Matrix([[1]])] + + +def test_casoratian(): + assert casoratian([1, 2, 3, 4], 1) == 0 + assert casoratian([1, 2, 3, 4], 1, zero=False) == 0 + + +def test_zero_dimension_multiply(): + assert (Matrix()*zeros(0, 3)).shape == (0, 3) + assert zeros(3, 0)*zeros(0, 3) == zeros(3, 3) + assert zeros(0, 3)*zeros(3, 0) == Matrix() + + +def test_slice_issue_2884(): + m = Matrix(2, 2, range(4)) + assert m[1, :] == Matrix([[2, 3]]) + assert m[-1, :] == Matrix([[2, 3]]) + assert m[:, 1] == Matrix([[1, 3]]).T + assert m[:, -1] == Matrix([[1, 3]]).T + raises(IndexError, lambda: m[2, :]) + raises(IndexError, lambda: m[2, 2]) + + +def test_slice_issue_3401(): + assert zeros(0, 3)[:, -1].shape == (0, 1) + assert zeros(3, 0)[0, :] == Matrix(1, 0, []) + + +def test_copyin(): + s = zeros(3, 3) + s[3] = 1 + assert s[:, 0] == Matrix([0, 1, 0]) + assert s[3] == 1 + assert s[3: 4] == [1] + s[1, 1] = 42 + assert s[1, 1] == 42 + assert s[1, 1:] == Matrix([[42, 0]]) + s[1, 1:] = Matrix([[5, 6]]) + assert s[1, :] == Matrix([[1, 5, 6]]) + s[1, 1:] = [[42, 43]] + assert s[1, :] == Matrix([[1, 42, 43]]) + s[0, 0] = 17 + assert s[:, :1] == Matrix([17, 1, 0]) + s[0, 0] = [1, 1, 1] + assert s[:, 0] == Matrix([1, 1, 1]) + s[0, 0] = Matrix([1, 1, 1]) + assert s[:, 0] == Matrix([1, 1, 1]) + s[0, 0] = SparseMatrix([1, 1, 1]) + assert s[:, 0] == Matrix([1, 1, 1]) + + +def test_invertible_check(): + # sometimes a singular matrix will have a pivot vector shorter than + # the number of rows in a matrix... + assert Matrix([[1, 2], [1, 2]]).rref() == (Matrix([[1, 2], [0, 0]]), (0,)) + raises(ValueError, lambda: Matrix([[1, 2], [1, 2]]).inv()) + m = Matrix([ + [-1, -1, 0], + [ x, 1, 1], + [ 1, x, -1], + ]) + assert len(m.rref()[1]) != m.rows + # in addition, unless simplify=True in the call to rref, the identity + # matrix will be returned even though m is not invertible + assert m.rref()[0] != eye(3) + assert m.rref(simplify=signsimp)[0] != eye(3) + raises(ValueError, lambda: m.inv(method="ADJ")) + raises(ValueError, lambda: m.inv(method="GE")) + raises(ValueError, lambda: m.inv(method="LU")) + + +def test_issue_3959(): + x, y = symbols('x, y') + e = x*y + assert e.subs(x, Matrix([3, 5, 3])) == Matrix([3, 5, 3])*y + + +def test_issue_5964(): + assert str(Matrix([[1, 2], [3, 4]])) == 'Matrix([[1, 2], [3, 4]])' + + +def test_issue_7604(): + x, y = symbols("x y") + assert sstr(Matrix([[x, 2*y], [y**2, x + 3]])) == \ + 'Matrix([\n[ x, 2*y],\n[y**2, x + 3]])' + + +def test_is_Identity(): + assert eye(3).is_Identity + assert eye(3).as_immutable().is_Identity + assert not zeros(3).is_Identity + assert not ones(3).is_Identity + # issue 6242 + assert not Matrix([[1, 0, 0]]).is_Identity + # issue 8854 + assert SparseMatrix(3,3, {(0,0):1, (1,1):1, (2,2):1}).is_Identity + assert not SparseMatrix(2,3, range(6)).is_Identity + assert not SparseMatrix(3,3, {(0,0):1, (1,1):1}).is_Identity + assert not SparseMatrix(3,3, {(0,0):1, (1,1):1, (2,2):1, (0,1):2, (0,2):3}).is_Identity + + +def test_dot(): + assert ones(1, 3).dot(ones(3, 1)) == 3 + assert ones(1, 3).dot([1, 1, 1]) == 3 + assert Matrix([1, 2, 3]).dot(Matrix([1, 2, 3])) == 14 + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I])) == -5 + I + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I]), hermitian=False) == -5 + I + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I]), hermitian=True) == 13 + I + assert Matrix([1, 2, 3*I]).dot(Matrix([I, 2, 3*I]), hermitian=True, conjugate_convention="physics") == 13 - I + assert Matrix([1, 2, 3*I]).dot(Matrix([4, 5*I, 6]), hermitian=True, conjugate_convention="right") == 4 + 8*I + assert Matrix([1, 2, 3*I]).dot(Matrix([4, 5*I, 6]), hermitian=True, conjugate_convention="left") == 4 - 8*I + assert Matrix([I, 2*I]).dot(Matrix([I, 2*I]), hermitian=False, conjugate_convention="left") == -5 + assert Matrix([I, 2*I]).dot(Matrix([I, 2*I]), conjugate_convention="left") == 5 + raises(ValueError, lambda: Matrix([1, 2]).dot(Matrix([3, 4]), hermitian=True, conjugate_convention="test")) + + +def test_dual(): + B_x, B_y, B_z, E_x, E_y, E_z = symbols( + 'B_x B_y B_z E_x E_y E_z', real=True) + F = Matrix(( + ( 0, E_x, E_y, E_z), + (-E_x, 0, B_z, -B_y), + (-E_y, -B_z, 0, B_x), + (-E_z, B_y, -B_x, 0) + )) + Fd = Matrix(( + ( 0, -B_x, -B_y, -B_z), + (B_x, 0, E_z, -E_y), + (B_y, -E_z, 0, E_x), + (B_z, E_y, -E_x, 0) + )) + assert F.dual().equals(Fd) + assert eye(3).dual().equals(zeros(3)) + assert F.dual().dual().equals(-F) + + +def test_anti_symmetric(): + assert Matrix([1, 2]).is_anti_symmetric() is False + m = Matrix(3, 3, [0, x**2 + 2*x + 1, y, -(x + 1)**2, 0, x*y, -y, -x*y, 0]) + assert m.is_anti_symmetric() is True + assert m.is_anti_symmetric(simplify=False) is None + assert m.is_anti_symmetric(simplify=lambda x: x) is None + + # tweak to fail + m[2, 1] = -m[2, 1] + assert m.is_anti_symmetric() is None + # untweak + m[2, 1] = -m[2, 1] + + m = m.expand() + assert m.is_anti_symmetric(simplify=False) is True + m[0, 0] = 1 + assert m.is_anti_symmetric() is False + + +def test_normalize_sort_diogonalization(): + A = Matrix(((1, 2), (2, 1))) + P, Q = A.diagonalize(normalize=True) + assert P*P.T == P.T*P == eye(P.cols) + P, Q = A.diagonalize(normalize=True, sort=True) + assert P*P.T == P.T*P == eye(P.cols) + assert P*Q*P.inv() == A + + +def test_issue_5321(): + raises(ValueError, lambda: Matrix([[1, 2, 3], Matrix(0, 1, [])])) + + +def test_issue_5320(): + assert Matrix.hstack(eye(2), 2*eye(2)) == Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2] + ]) + assert Matrix.vstack(eye(2), 2*eye(2)) == Matrix([ + [1, 0], + [0, 1], + [2, 0], + [0, 2] + ]) + cls = SparseMatrix + assert cls.hstack(cls(eye(2)), cls(2*eye(2))) == Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2] + ]) + + +def test_issue_11944(): + A = Matrix([[1]]) + AIm = sympify(A) + assert Matrix.hstack(AIm, A) == Matrix([[1, 1]]) + assert Matrix.vstack(AIm, A) == Matrix([[1], [1]]) + + +def test_cross(): + a = [1, 2, 3] + b = [3, 4, 5] + col = Matrix([-2, 4, -2]) + row = col.T + + def test(M, ans): + assert ans == M + assert type(M) == cls + for cls in all_classes: + A = cls(a) + B = cls(b) + test(A.cross(B), col) + test(A.cross(B.T), col) + test(A.T.cross(B.T), row) + test(A.T.cross(B), row) + raises(ShapeError, lambda: + Matrix(1, 2, [1, 1]).cross(Matrix(1, 2, [1, 1]))) + + +def test_hat_vee(): + v1 = Matrix([x, y, z]) + v2 = Matrix([a, b, c]) + assert v1.hat() * v2 == v1.cross(v2) + assert v1.hat().is_anti_symmetric() + assert v1.hat().vee() == v1 + + +def test_hash(): + for cls in immutable_classes: + s = {cls.eye(1), cls.eye(1)} + assert len(s) == 1 and s.pop() == cls.eye(1) + # issue 3979 + for cls in mutable_classes: + assert not isinstance(cls.eye(1), Hashable) + + +def test_adjoint(): + dat = [[0, I], [1, 0]] + ans = Matrix([[0, 1], [-I, 0]]) + for cls in all_classes: + assert ans == cls(dat).adjoint() + + +def test_atoms(): + m = Matrix([[1, 2], [x, 1 - 1/x]]) + assert m.atoms() == {S.One,S(2),S.NegativeOne, x} + assert m.atoms(Symbol) == {x} + + +def test_pinv(): + # Pseudoinverse of an invertible matrix is the inverse. + A1 = Matrix([[a, b], [c, d]]) + assert simplify(A1.pinv(method="RD")) == simplify(A1.inv()) + + # Test the four properties of the pseudoinverse for various matrices. + As = [Matrix([[13, 104], [2212, 3], [-3, 5]]), + Matrix([[1, 7, 9], [11, 17, 19]]), + Matrix([a, b])] + + for A in As: + A_pinv = A.pinv(method="RD") + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + + # XXX Pinv with diagonalization makes expression too complicated. + for A in As: + A_pinv = simplify(A.pinv(method="ED")) + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + + # XXX Computing pinv using diagonalization makes an expression that + # is too complicated to simplify. + # A1 = Matrix([[a, b], [c, d]]) + # assert simplify(A1.pinv(method="ED")) == simplify(A1.inv()) + # so this is tested numerically at a fixed random point + + from sympy.core.numbers import comp + q = A1.pinv(method="ED") + w = A1.inv() + reps = {a: -73633, b: 11362, c: 55486, d: 62570} + assert all( + comp(i.n(), j.n()) + for i, j in zip(q.subs(reps), w.subs(reps)) + ) + + +@slow +def test_pinv_rank_deficient_when_diagonalization_fails(): + # Test the four properties of the pseudoinverse for matrices when + # diagonalization of A.H*A fails. + As = [ + Matrix([ + [61, 89, 55, 20, 71, 0], + [62, 96, 85, 85, 16, 0], + [69, 56, 17, 4, 54, 0], + [10, 54, 91, 41, 71, 0], + [ 7, 30, 10, 48, 90, 0], + [0, 0, 0, 0, 0, 0]]) + ] + for A in As: + A_pinv = A.pinv(method="ED") + AAp = A * A_pinv + ApA = A_pinv * A + assert AAp.H == AAp + + # Here ApA.H and ApA are equivalent expressions but they are very + # complicated expressions involving RootOfs. Using simplify would be + # too slow and so would evalf so we substitute approximate values for + # the RootOfs and then evalf which is less accurate but good enough to + # confirm that these two matrices are equivalent. + # + # assert ApA.H == ApA # <--- would fail (structural equality) + # assert simplify(ApA.H - ApA).is_zero_matrix # <--- too slow + # (ApA.H - ApA).evalf() # <--- too slow + + def allclose(M1, M2): + rootofs = M1.atoms(RootOf) + rootofs_approx = {r: r.evalf() for r in rootofs} + diff_approx = (M1 - M2).xreplace(rootofs_approx).evalf() + return all(abs(e) < 1e-10 for e in diff_approx) + + assert allclose(ApA.H, ApA) + + +def test_issue_7201(): + assert ones(0, 1) + ones(0, 1) == Matrix(0, 1, []) + assert ones(1, 0) + ones(1, 0) == Matrix(1, 0, []) + + +def test_free_symbols(): + for M in ImmutableMatrix, ImmutableSparseMatrix, Matrix, SparseMatrix: + assert M([[x], [0]]).free_symbols == {x} + + +def test_from_ndarray(): + """See issue 7465.""" + try: + from numpy import array + except ImportError: + skip('NumPy must be available to test creating matrices from ndarrays') + + assert Matrix(array([1, 2, 3])) == Matrix([1, 2, 3]) + assert Matrix(array([[1, 2, 3]])) == Matrix([[1, 2, 3]]) + assert Matrix(array([[1, 2, 3], [4, 5, 6]])) == \ + Matrix([[1, 2, 3], [4, 5, 6]]) + assert Matrix(array([x, y, z])) == Matrix([x, y, z]) + raises(NotImplementedError, + lambda: Matrix(array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))) + assert Matrix([array([1, 2]), array([3, 4])]) == Matrix([[1, 2], [3, 4]]) + assert Matrix([array([1, 2]), [3, 4]]) == Matrix([[1, 2], [3, 4]]) + assert Matrix([array([]), array([])]) == Matrix(2, 0, []) != Matrix([]) + + +def test_17522_numpy(): + from sympy.matrices.common import _matrixify + try: + from numpy import array, matrix + except ImportError: + skip('NumPy must be available to test indexing matrixified NumPy ndarrays and matrices') + + m = _matrixify(array([[1, 2], [3, 4]])) + assert m[3] == 4 + assert list(m) == [1, 2, 3, 4] + + with ignore_warnings(PendingDeprecationWarning): + m = _matrixify(matrix([[1, 2], [3, 4]])) + assert m[3] == 4 + assert list(m) == [1, 2, 3, 4] + + +def test_17522_mpmath(): + from sympy.matrices.common import _matrixify + try: + from mpmath import matrix + except ImportError: + skip('mpmath must be available to test indexing matrixified mpmath matrices') + + m = _matrixify(matrix([[1, 2], [3, 4]])) + assert m[3] == 4.0 + assert list(m) == [1.0, 2.0, 3.0, 4.0] + + +def test_17522_scipy(): + from sympy.matrices.common import _matrixify + try: + from scipy.sparse import csr_matrix + except ImportError: + skip('SciPy must be available to test indexing matrixified SciPy sparse matrices') + + m = _matrixify(csr_matrix([[1, 2], [3, 4]])) + assert m[3] == 4 + assert list(m) == [1, 2, 3, 4] + + +def test_hermitian(): + a = Matrix([[1, I], [-I, 1]]) + assert a.is_hermitian + a[0, 0] = 2*I + assert a.is_hermitian is False + a[0, 0] = x + assert a.is_hermitian is None + a[0, 1] = a[1, 0]*I + assert a.is_hermitian is False + + +def test_issue_9457_9467_9876(): + # for row_del(index) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + M.row_del(1) + assert M == Matrix([[1, 2, 3], [3, 4, 5]]) + N = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + N.row_del(-2) + assert N == Matrix([[1, 2, 3], [3, 4, 5]]) + O = Matrix([[1, 2, 3], [5, 6, 7], [9, 10, 11]]) + O.row_del(-1) + assert O == Matrix([[1, 2, 3], [5, 6, 7]]) + P = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: P.row_del(10)) + Q = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: Q.row_del(-10)) + + # for col_del(index) + M = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + M.col_del(1) + assert M == Matrix([[1, 3], [2, 4], [3, 5]]) + N = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + N.col_del(-2) + assert N == Matrix([[1, 3], [2, 4], [3, 5]]) + P = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: P.col_del(10)) + Q = Matrix([[1, 2, 3], [2, 3, 4], [3, 4, 5]]) + raises(IndexError, lambda: Q.col_del(-10)) + + +def test_issue_9422(): + x, y = symbols('x y', commutative=False) + a, b = symbols('a b') + M = eye(2) + M1 = Matrix(2, 2, [x, y, y, z]) + assert y*x*M != x*y*M + assert b*a*M == a*b*M + assert x*M1 != M1*x + assert a*M1 == M1*a + assert y*x*M == Matrix([[y*x, 0], [0, y*x]]) + + +def test_issue_10770(): + M = Matrix([]) + a = ['col_insert', 'row_join'], Matrix([9, 6, 3]) + b = ['row_insert', 'col_join'], a[1].T + c = ['row_insert', 'col_insert'], Matrix([[1, 2], [3, 4]]) + for ops, m in (a, b, c): + for op in ops: + f = getattr(M, op) + new = f(m) if 'join' in op else f(42, m) + assert new == m and id(new) != id(m) + + +def test_issue_10658(): + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert A.extract([0, 1, 2], [True, True, False]) == \ + Matrix([[1, 2], [4, 5], [7, 8]]) + assert A.extract([0, 1, 2], [True, False, False]) == Matrix([[1], [4], [7]]) + assert A.extract([True, False, False], [0, 1, 2]) == Matrix([[1, 2, 3]]) + assert A.extract([True, False, True], [0, 1, 2]) == \ + Matrix([[1, 2, 3], [7, 8, 9]]) + assert A.extract([0, 1, 2], [False, False, False]) == Matrix(3, 0, []) + assert A.extract([False, False, False], [0, 1, 2]) == Matrix(0, 3, []) + assert A.extract([True, False, True], [False, True, False]) == \ + Matrix([[2], [8]]) + + +def test_opportunistic_simplification(): + # this test relates to issue #10718, #9480, #11434 + + # issue #9480 + m = Matrix([[-5 + 5*sqrt(2), -5], [-5*sqrt(2)/2 + 5, -5*sqrt(2)/2]]) + assert m.rank() == 1 + + # issue #10781 + m = Matrix([[3+3*sqrt(3)*I, -9],[4,-3+3*sqrt(3)*I]]) + assert simplify(m.rref()[0] - Matrix([[1, -9/(3 + 3*sqrt(3)*I)], [0, 0]])) == zeros(2, 2) + + # issue #11434 + ax,ay,bx,by,cx,cy,dx,dy,ex,ey,t0,t1 = symbols('a_x a_y b_x b_y c_x c_y d_x d_y e_x e_y t_0 t_1') + m = Matrix([[ax,ay,ax*t0,ay*t0,0],[bx,by,bx*t0,by*t0,0],[cx,cy,cx*t0,cy*t0,1],[dx,dy,dx*t0,dy*t0,1],[ex,ey,2*ex*t1-ex*t0,2*ey*t1-ey*t0,0]]) + assert m.rank() == 4 + + +def test_partial_pivoting(): + # example from https://en.wikipedia.org/wiki/Pivot_element + # partial pivoting with back substitution gives a perfect result + # naive pivoting give an error ~1e-13, so anything better than + # 1e-15 is good + mm=Matrix([[0.003, 59.14, 59.17], [5.291, -6.13, 46.78]]) + assert (mm.rref()[0] - Matrix([[1.0, 0, 10.0], + [ 0, 1.0, 1.0]])).norm() < 1e-15 + + # issue #11549 + m_mixed = Matrix([[6e-17, 1.0, 4], + [ -1.0, 0, 8], + [ 0, 0, 1]]) + m_float = Matrix([[6e-17, 1.0, 4.], + [ -1.0, 0., 8.], + [ 0., 0., 1.]]) + m_inv = Matrix([[ 0, -1.0, 8.0], + [1.0, 6.0e-17, -4.0], + [ 0, 0, 1]]) + # this example is numerically unstable and involves a matrix with a norm >= 8, + # this comparing the difference of the results with 1e-15 is numerically sound. + assert (m_mixed.inv() - m_inv).norm() < 1e-15 + assert (m_float.inv() - m_inv).norm() < 1e-15 + + +def test_iszero_substitution(): + """ When doing numerical computations, all elements that pass + the iszerofunc test should be set to numerically zero if they + aren't already. """ + + # Matrix from issue #9060 + m = Matrix([[0.9, -0.1, -0.2, 0],[-0.8, 0.9, -0.4, 0],[-0.1, -0.8, 0.6, 0]]) + m_rref = m.rref(iszerofunc=lambda x: abs(x)<6e-15)[0] + m_correct = Matrix([[1.0, 0, -0.301369863013699, 0],[ 0, 1.0, -0.712328767123288, 0],[ 0, 0, 0, 0]]) + m_diff = m_rref - m_correct + assert m_diff.norm() < 1e-15 + # if a zero-substitution wasn't made, this entry will be -1.11022302462516e-16 + assert m_rref[2,2] == 0 + + +def test_issue_11238(): + from sympy.geometry.point import Point + xx = 8*tan(pi*Rational(13, 45))/(tan(pi*Rational(13, 45)) + sqrt(3)) + yy = (-8*sqrt(3)*tan(pi*Rational(13, 45))**2 + 24*tan(pi*Rational(13, 45)))/(-3 + tan(pi*Rational(13, 45))**2) + p1 = Point(0, 0) + p2 = Point(1, -sqrt(3)) + p0 = Point(xx,yy) + m1 = Matrix([p1 - simplify(p0), p2 - simplify(p0)]) + m2 = Matrix([p1 - p0, p2 - p0]) + m3 = Matrix([simplify(p1 - p0), simplify(p2 - p0)]) + + # This system has expressions which are zero and + # cannot be easily proved to be such, so without + # numerical testing, these assertions will fail. + Z = lambda x: abs(x.n()) < 1e-20 + assert m1.rank(simplify=True, iszerofunc=Z) == 1 + assert m2.rank(simplify=True, iszerofunc=Z) == 1 + assert m3.rank(simplify=True, iszerofunc=Z) == 1 + + +def test_as_real_imag(): + m1 = Matrix(2,2,[1,2,3,4]) + m2 = m1*S.ImaginaryUnit + m3 = m1 + m2 + + for kls in all_classes: + a,b = kls(m3).as_real_imag() + assert list(a) == list(m1) + assert list(b) == list(m1) + + +def test_deprecated(): + # Maintain tests for deprecated functions. We must capture + # the deprecation warnings. When the deprecated functionality is + # removed, the corresponding tests should be removed. + + m = Matrix(3, 3, [0, 1, 0, -4, 4, 0, -2, 1, 2]) + P, Jcells = m.jordan_cells() + assert Jcells[1] == Matrix(1, 1, [2]) + assert Jcells[0] == Matrix(2, 2, [2, 1, 0, 2]) + + +def test_issue_14489(): + from sympy.core.mod import Mod + A = Matrix([-1, 1, 2]) + B = Matrix([10, 20, -15]) + + assert Mod(A, 3) == Matrix([2, 1, 2]) + assert Mod(B, 4) == Matrix([2, 0, 1]) + + +def test_issue_14943(): + # Test that __array__ accepts the optional dtype argument + try: + from numpy import array + except ImportError: + skip('NumPy must be available to test creating matrices from ndarrays') + + M = Matrix([[1,2], [3,4]]) + assert array(M, dtype=float).dtype.name == 'float64' + + +def test_case_6913(): + m = MatrixSymbol('m', 1, 1) + a = Symbol("a") + a = m[0, 0]>0 + assert str(a) == 'm[0, 0] > 0' + + +def test_issue_11948(): + A = MatrixSymbol('A', 3, 3) + a = Wild('a') + assert A.match(a) == {a: A} + + +def test_gramschmidt_conjugate_dot(): + vecs = [Matrix([1, I]), Matrix([1, -I])] + assert Matrix.orthogonalize(*vecs) == \ + [Matrix([[1], [I]]), Matrix([[1], [-I]])] + + vecs = [Matrix([1, I, 0]), Matrix([I, 0, -I])] + assert Matrix.orthogonalize(*vecs) == \ + [Matrix([[1], [I], [0]]), Matrix([[I/2], [S(1)/2], [-I]])] + + mat = Matrix([[1, I], [1, -I]]) + Q, R = mat.QRdecomposition() + assert Q * Q.H == Matrix.eye(2) + + +def test_issue_8207(): + a = Matrix(MatrixSymbol('a', 3, 1)) + b = Matrix(MatrixSymbol('b', 3, 1)) + c = a.dot(b) + d = diff(c, a[0, 0]) + e = diff(d, a[0, 0]) + assert d == b[0, 0] + assert e == 0 + + +def test_func(): + from sympy.simplify.simplify import nthroot + + A = Matrix([[1, 2],[0, 3]]) + assert A.analytic_func(sin(x*t), x) == Matrix([[sin(t), sin(3*t) - sin(t)], [0, sin(3*t)]]) + + A = Matrix([[2, 1],[1, 2]]) + assert (pi * A / 6).analytic_func(cos(x), x) == Matrix([[sqrt(3)/4, -sqrt(3)/4], [-sqrt(3)/4, sqrt(3)/4]]) + + + raises(ValueError, lambda : zeros(5).analytic_func(log(x), x)) + raises(ValueError, lambda : (A*x).analytic_func(log(x), x)) + + A = Matrix([[0, -1, -2, 3], [0, -1, -2, 3], [0, 1, 0, -1], [0, 0, -1, 1]]) + assert A.analytic_func(exp(x), x) == A.exp() + raises(ValueError, lambda : A.analytic_func(sqrt(x), x)) + + A = Matrix([[41, 12],[12, 34]]) + assert simplify(A.analytic_func(sqrt(x), x)**2) == A + + A = Matrix([[3, -12, 4], [-1, 0, -2], [-1, 5, -1]]) + assert simplify(A.analytic_func(nthroot(x, 3), x)**3) == A + + A = Matrix([[2, 0, 0, 0], [1, 2, 0, 0], [0, 1, 3, 0], [0, 0, 1, 3]]) + assert A.analytic_func(exp(x), x) == A.exp() + + A = Matrix([[0, 2, 1, 6], [0, 0, 1, 2], [0, 0, 0, 3], [0, 0, 0, 0]]) + assert A.analytic_func(exp(x*t), x) == expand(simplify((A*t).exp())) + + +@skip_under_pyodide("Cannot create threads under pyodide.") +def test_issue_19809(): + + def f(): + assert _dotprodsimp_state.state == None + m = Matrix([[1]]) + m = m * m + return True + + with dotprodsimp(True): + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(f) + assert future.result() + + +def test_issue_23276(): + M = Matrix([x, y]) + assert integrate(M, (x, 0, 1), (y, 0, 1)) == Matrix([ + [S.Half], + [S.Half]]) + + +def test_issue_27225(): + # https://github.com/sympy/sympy/issues/27225 + raises(TypeError, lambda : floor(Matrix([1, 1, 0]))) diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_normalforms.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_normalforms.py new file mode 100644 index 0000000000000000000000000000000000000000..47ee52d73539f7fb79295443e1cf7e0a49e30a5e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_normalforms.py @@ -0,0 +1,111 @@ +from sympy.testing.pytest import warns_deprecated_sympy + +from sympy.core.symbol import Symbol +from sympy.polys.polytools import Poly +from sympy.matrices import Matrix, randMatrix +from sympy.matrices.normalforms import ( + invariant_factors, + smith_normal_form, + smith_normal_decomp, + hermite_normal_form, + is_smith_normal_form, +) +from sympy.polys.domains import ZZ, QQ +from sympy.core.numbers import Integer + +import random + + +def test_smith_normal(): + m = Matrix([[12,6,4,8],[3,9,6,12],[2,16,14,28],[20,10,10,20]]) + smf = Matrix([[1, 0, 0, 0], [0, 10, 0, 0], [0, 0, 30, 0], [0, 0, 0, 0]]) + assert smith_normal_form(m) == smf + + a, s, t = smith_normal_decomp(m) + assert a == s * m * t + + x = Symbol('x') + with warns_deprecated_sympy(): + m = Matrix([[Poly(x-1), Poly(1, x),Poly(-1,x)], + [0, Poly(x), Poly(-1,x)], + [Poly(0,x),Poly(-1,x),Poly(x)]]) + invs = 1, x - 1, x**2 - 1 + assert invariant_factors(m, domain=QQ[x]) == invs + + m = Matrix([[2, 4]]) + smf = Matrix([[2, 0]]) + assert smith_normal_form(m) == smf + + prng = random.Random(0) + for i in range(6): + for j in range(6): + for _ in range(10 if i*j else 1): + m = randMatrix(i, j, max=5, percent=50, prng=prng) + a, s, t = smith_normal_decomp(m) + assert a == s * m * t + assert is_smith_normal_form(a) + s.inv().to_DM(ZZ) + t.inv().to_DM(ZZ) + + a, s, t = smith_normal_decomp(m, QQ) + assert a == s * m * t + assert is_smith_normal_form(a) + s.inv() + t.inv() + + +def test_smith_normal_deprecated(): + from sympy.polys.solvers import RawMatrix as Matrix + + with warns_deprecated_sympy(): + m = Matrix([[12, 6, 4,8],[3,9,6,12],[2,16,14,28],[20,10,10,20]]) + setattr(m, 'ring', ZZ) + with warns_deprecated_sympy(): + smf = Matrix([[1, 0, 0, 0], [0, 10, 0, 0], [0, 0, 30, 0], [0, 0, 0, 0]]) + assert smith_normal_form(m) == smf + + x = Symbol('x') + with warns_deprecated_sympy(): + m = Matrix([[Poly(x-1), Poly(1, x),Poly(-1,x)], + [0, Poly(x), Poly(-1,x)], + [Poly(0,x),Poly(-1,x),Poly(x)]]) + setattr(m, 'ring', QQ[x]) + invs = (Poly(1, x, domain='QQ'), Poly(x - 1, domain='QQ'), Poly(x**2 - 1, domain='QQ')) + assert invariant_factors(m) == invs + + with warns_deprecated_sympy(): + m = Matrix([[2, 4]]) + setattr(m, 'ring', ZZ) + with warns_deprecated_sympy(): + smf = Matrix([[2, 0]]) + assert smith_normal_form(m) == smf + + +def test_hermite_normal(): + m = Matrix([[2, 7, 17, 29, 41], [3, 11, 19, 31, 43], [5, 13, 23, 37, 47]]) + hnf = Matrix([[1, 0, 0], [0, 2, 1], [0, 0, 1]]) + assert hermite_normal_form(m) == hnf + + tr_hnf = Matrix([[37, 0, 19], [222, -6, 113], [48, 0, 25], [0, 2, 1], [0, 0, 1]]) + assert hermite_normal_form(m.transpose()) == tr_hnf + + m = Matrix([[8, 28, 68, 116, 164], [3, 11, 19, 31, 43], [5, 13, 23, 37, 47]]) + hnf = Matrix([[4, 0, 0], [0, 2, 1], [0, 0, 1]]) + assert hermite_normal_form(m) == hnf + assert hermite_normal_form(m, D=8) == hnf + assert hermite_normal_form(m, D=ZZ(8)) == hnf + assert hermite_normal_form(m, D=Integer(8)) == hnf + + m = Matrix([[10, 8, 6, 30, 2], [45, 36, 27, 18, 9], [5, 4, 3, 2, 1]]) + hnf = Matrix([[26, 2], [0, 9], [0, 1]]) + assert hermite_normal_form(m) == hnf + + m = Matrix([[2, 7], [0, 0], [0, 0]]) + hnf = Matrix([[1], [0], [0]]) + assert hermite_normal_form(m) == hnf + + +def test_issue_23410(): + A = Matrix([[1, 12], [0, 8], [0, 5]]) + H = Matrix([[1, 0], [0, 8], [0, 5]]) + assert hermite_normal_form(A) == H diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_reductions.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_reductions.py new file mode 100644 index 0000000000000000000000000000000000000000..32c98c6f249b1afafc8193f4248dc9493bb803e0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_reductions.py @@ -0,0 +1,351 @@ +from sympy.core.numbers import I +from sympy.core.symbol import symbols +from sympy.testing.pytest import raises +from sympy.matrices import Matrix, zeros, eye +from sympy.core.symbol import Symbol +from sympy.core.numbers import Rational +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.simplify.simplify import simplify +from sympy.abc import x + + +# Matrix tests +def test_row_op(): + e = eye(3) + + raises(ValueError, lambda: e.elementary_row_op("abc")) + raises(ValueError, lambda: e.elementary_row_op()) + raises(ValueError, lambda: e.elementary_row_op('n->kn', row=5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->kn', row=-5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=1, row2=5)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=5, row2=1)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=-5, row2=1)) + raises(ValueError, lambda: e.elementary_row_op('n<->m', row1=1, row2=-5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=1, row2=5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=5, row2=1, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=-5, row2=1, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=1, row2=-5, k=5)) + raises(ValueError, lambda: e.elementary_row_op('n->n+km', row1=1, row2=1, k=5)) + + # test various ways to set arguments + assert e.elementary_row_op("n->kn", 0, 5) == Matrix([[5, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->kn", 1, 5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->kn", row=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->kn", row1=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_row_op("n<->m", 0, 1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_row_op("n<->m", row1=0, row2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_row_op("n<->m", row=0, row2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->n+km", 0, 5, 1) == Matrix([[1, 5, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->n+km", row=0, k=5, row2=1) == Matrix([[1, 5, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_row_op("n->n+km", row1=0, k=5, row2=1) == Matrix([[1, 5, 0], [0, 1, 0], [0, 0, 1]]) + + # make sure the matrix doesn't change size + a = Matrix(2, 3, [0]*6) + assert a.elementary_row_op("n->kn", 1, 5) == Matrix(2, 3, [0]*6) + assert a.elementary_row_op("n<->m", 0, 1) == Matrix(2, 3, [0]*6) + assert a.elementary_row_op("n->n+km", 0, 5, 1) == Matrix(2, 3, [0]*6) + + +def test_col_op(): + e = eye(3) + + raises(ValueError, lambda: e.elementary_col_op("abc")) + raises(ValueError, lambda: e.elementary_col_op()) + raises(ValueError, lambda: e.elementary_col_op('n->kn', col=5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->kn', col=-5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=1, col2=5)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=5, col2=1)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=-5, col2=1)) + raises(ValueError, lambda: e.elementary_col_op('n<->m', col1=1, col2=-5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=1, col2=5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=5, col2=1, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=-5, col2=1, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=1, col2=-5, k=5)) + raises(ValueError, lambda: e.elementary_col_op('n->n+km', col1=1, col2=1, k=5)) + + # test various ways to set arguments + assert e.elementary_col_op("n->kn", 0, 5) == Matrix([[5, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->kn", 1, 5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->kn", col=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->kn", col1=1, k=5) == Matrix([[1, 0, 0], [0, 5, 0], [0, 0, 1]]) + assert e.elementary_col_op("n<->m", 0, 1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_col_op("n<->m", col1=0, col2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_col_op("n<->m", col=0, col2=1) == Matrix([[0, 1, 0], [1, 0, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->n+km", 0, 5, 1) == Matrix([[1, 0, 0], [5, 1, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->n+km", col=0, k=5, col2=1) == Matrix([[1, 0, 0], [5, 1, 0], [0, 0, 1]]) + assert e.elementary_col_op("n->n+km", col1=0, k=5, col2=1) == Matrix([[1, 0, 0], [5, 1, 0], [0, 0, 1]]) + + # make sure the matrix doesn't change size + a = Matrix(2, 3, [0]*6) + assert a.elementary_col_op("n->kn", 1, 5) == Matrix(2, 3, [0]*6) + assert a.elementary_col_op("n<->m", 0, 1) == Matrix(2, 3, [0]*6) + assert a.elementary_col_op("n->n+km", 0, 5, 1) == Matrix(2, 3, [0]*6) + + +def test_is_echelon(): + zro = zeros(3) + ident = eye(3) + + assert zro.is_echelon + assert ident.is_echelon + + a = Matrix(0, 0, []) + assert a.is_echelon + + a = Matrix(2, 3, [3, 2, 1, 0, 0, 6]) + assert a.is_echelon + + a = Matrix(2, 3, [0, 0, 6, 3, 2, 1]) + assert not a.is_echelon + + x = Symbol('x') + a = Matrix(3, 1, [x, 0, 0]) + assert a.is_echelon + + a = Matrix(3, 1, [x, x, 0]) + assert not a.is_echelon + + a = Matrix(3, 3, [0, 0, 0, 1, 2, 3, 0, 0, 0]) + assert not a.is_echelon + + +def test_echelon_form(): + # echelon form is not unique, but the result + # must be row-equivalent to the original matrix + # and it must be in echelon form. + + a = zeros(3) + e = eye(3) + + # we can assume the zero matrix and the identity matrix shouldn't change + assert a.echelon_form() == a + assert e.echelon_form() == e + + a = Matrix(0, 0, []) + assert a.echelon_form() == a + + a = Matrix(1, 1, [5]) + assert a.echelon_form() == a + + # now we get to the real tests + + def verify_row_null_space(mat, rows, nulls): + for v in nulls: + assert all(t.is_zero for t in a_echelon*v) + for v in rows: + if not all(t.is_zero for t in v): + assert not all(t.is_zero for t in a_echelon*v.transpose()) + + a = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + nulls = [Matrix([ + [ 1], + [-2], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + + a = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 8]) + nulls = [] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + a = Matrix(3, 3, [2, 1, 3, 0, 0, 0, 2, 1, 3]) + nulls = [Matrix([ + [Rational(-1, 2)], + [ 1], + [ 0]]), + Matrix([ + [Rational(-3, 2)], + [ 0], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + # this one requires a row swap + a = Matrix(3, 3, [2, 1, 3, 0, 0, 0, 1, 1, 3]) + nulls = [Matrix([ + [ 0], + [ -3], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + a = Matrix(3, 3, [0, 3, 3, 0, 2, 2, 0, 1, 1]) + nulls = [Matrix([ + [1], + [0], + [0]]), + Matrix([ + [ 0], + [-1], + [ 1]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + a = Matrix(2, 3, [2, 2, 3, 3, 3, 0]) + nulls = [Matrix([ + [-1], + [1], + [0]])] + rows = [a[i, :] for i in range(a.rows)] + a_echelon = a.echelon_form() + assert a_echelon.is_echelon + verify_row_null_space(a, rows, nulls) + + +def test_rref(): + e = Matrix(0, 0, []) + assert e.rref(pivots=False) == e + + e = Matrix(1, 1, [1]) + a = Matrix(1, 1, [5]) + assert e.rref(pivots=False) == a.rref(pivots=False) == e + + a = Matrix(3, 1, [1, 2, 3]) + assert a.rref(pivots=False) == Matrix([[1], [0], [0]]) + + a = Matrix(1, 3, [1, 2, 3]) + assert a.rref(pivots=False) == Matrix([[1, 2, 3]]) + + a = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + assert a.rref(pivots=False) == Matrix([ + [1, 0, -1], + [0, 1, 2], + [0, 0, 0]]) + + a = Matrix(3, 3, [1, 2, 3, 1, 2, 3, 1, 2, 3]) + b = Matrix(3, 3, [1, 2, 3, 0, 0, 0, 0, 0, 0]) + c = Matrix(3, 3, [0, 0, 0, 1, 2, 3, 0, 0, 0]) + d = Matrix(3, 3, [0, 0, 0, 0, 0, 0, 1, 2, 3]) + assert a.rref(pivots=False) == \ + b.rref(pivots=False) == \ + c.rref(pivots=False) == \ + d.rref(pivots=False) == b + + e = eye(3) + z = zeros(3) + assert e.rref(pivots=False) == e + assert z.rref(pivots=False) == z + + a = Matrix([ + [ 0, 0, 1, 2, 2, -5, 3], + [-1, 5, 2, 2, 1, -7, 5], + [ 0, 0, -2, -3, -3, 8, -5], + [-1, 5, 0, -1, -2, 1, 0]]) + mat, pivot_offsets = a.rref() + assert mat == Matrix([ + [1, -5, 0, 0, 1, 1, -1], + [0, 0, 1, 0, 0, -1, 1], + [0, 0, 0, 1, 1, -2, 1], + [0, 0, 0, 0, 0, 0, 0]]) + assert pivot_offsets == (0, 2, 3) + + a = Matrix([[Rational(1, 19), Rational(1, 5), 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11], + [ 12, 13, 14, 15]]) + assert a.rref(pivots=False) == Matrix([ + [1, 0, 0, Rational(-76, 157)], + [0, 1, 0, Rational(-5, 157)], + [0, 0, 1, Rational(238, 157)], + [0, 0, 0, 0]]) + + x = Symbol('x') + a = Matrix(2, 3, [x, 1, 1, sqrt(x), x, 1]) + for i, j in zip(a.rref(pivots=False), + [1, 0, sqrt(x)*(-x + 1)/(-x**Rational(5, 2) + x), + 0, 1, 1/(sqrt(x) + x + 1)]): + assert simplify(i - j).is_zero + + +def test_rref_rhs(): + a, b, c, d = symbols('a b c d') + A = Matrix([[0, 0], [0, 0], [1, 2], [3, 4]]) + B = Matrix([a, b, c, d]) + assert A.rref_rhs(B) == (Matrix([ + [1, 0], + [0, 1], + [0, 0], + [0, 0]]), Matrix([ + [ -2*c + d], + [3*c/2 - d/2], + [ a], + [ b]])) + + +def test_issue_17827(): + C = Matrix([ + [3, 4, -1, 1], + [9, 12, -3, 3], + [0, 2, 1, 3], + [2, 3, 0, -2], + [0, 3, 3, -5], + [8, 15, 0, 6] + ]) + # Tests for row/col within valid range + D = C.elementary_row_op('n<->m', row1=2, row2=5) + E = C.elementary_row_op('n->n+km', row1=5, row2=3, k=-4) + F = C.elementary_row_op('n->kn', row=5, k=2) + assert(D[5, :] == Matrix([[0, 2, 1, 3]])) + assert(E[5, :] == Matrix([[0, 3, 0, 14]])) + assert(F[5, :] == Matrix([[16, 30, 0, 12]])) + # Tests for row/col out of range + raises(ValueError, lambda: C.elementary_row_op('n<->m', row1=2, row2=6)) + raises(ValueError, lambda: C.elementary_row_op('n->kn', row=7, k=2)) + raises(ValueError, lambda: C.elementary_row_op('n->n+km', row1=-1, row2=5, k=2)) + +def test_rank(): + m = Matrix([[1, 2], [x, 1 - 1/x]]) + assert m.rank() == 2 + n = Matrix(3, 3, range(1, 10)) + assert n.rank() == 2 + p = zeros(3) + assert p.rank() == 0 + +def test_issue_11434(): + ax, ay, bx, by, cx, cy, dx, dy, ex, ey, t0, t1 = \ + symbols('a_x a_y b_x b_y c_x c_y d_x d_y e_x e_y t_0 t_1') + M = Matrix([[ax, ay, ax*t0, ay*t0, 0], + [bx, by, bx*t0, by*t0, 0], + [cx, cy, cx*t0, cy*t0, 1], + [dx, dy, dx*t0, dy*t0, 1], + [ex, ey, 2*ex*t1 - ex*t0, 2*ey*t1 - ey*t0, 0]]) + assert M.rank() == 4 + +def test_rank_regression_from_so(): + # see: + # https://stackoverflow.com/questions/19072700/why-does-sympy-give-me-the-wrong-answer-when-i-row-reduce-a-symbolic-matrix + + nu, lamb = symbols('nu, lambda') + A = Matrix([[-3*nu, 1, 0, 0], + [ 3*nu, -2*nu - 1, 2, 0], + [ 0, 2*nu, (-1*nu) - lamb - 2, 3], + [ 0, 0, nu + lamb, -3]]) + expected_reduced = Matrix([[1, 0, 0, 1/(nu**2*(-lamb - nu))], + [0, 1, 0, 3/(nu*(-lamb - nu))], + [0, 0, 1, 3/(-lamb - nu)], + [0, 0, 0, 0]]) + expected_pivots = (0, 1, 2) + + reduced, pivots = A.rref() + + assert simplify(expected_reduced - reduced) == zeros(*A.shape) + assert pivots == expected_pivots + +def test_issue_15872(): + A = Matrix([[1, 1, 1, 0], [-2, -1, 0, -1], [0, 0, -1, -1], [0, 0, 2, 1]]) + B = A - Matrix.eye(4) * I + assert B.rank() == 3 + assert (B**2).rank() == 2 + assert (B**3).rank() == 2 diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_repmatrix.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_repmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..ee36de004705f29eaa49ea8e06fd65a8a2baa718 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_repmatrix.py @@ -0,0 +1,62 @@ +from sympy.testing.pytest import raises +from sympy.matrices.exceptions import NonSquareMatrixError, NonInvertibleMatrixError + +from sympy import Matrix, Rational + + +def test_lll(): + A = Matrix([[1, 0, 0, 0, -20160], + [0, 1, 0, 0, 33768], + [0, 0, 1, 0, 39578], + [0, 0, 0, 1, 47757]]) + L = Matrix([[ 10, -3, -2, 8, -4], + [ 3, -9, 8, 1, -11], + [ -3, 13, -9, -3, -9], + [-12, -7, -11, 9, -1]]) + T = Matrix([[ 10, -3, -2, 8], + [ 3, -9, 8, 1], + [ -3, 13, -9, -3], + [-12, -7, -11, 9]]) + assert A.lll() == L + assert A.lll_transform() == (L, T) + assert T * A == L + + +def test_matrix_inv_mod(): + A = Matrix(2, 1, [1, 0]) + raises(NonSquareMatrixError, lambda: A.inv_mod(2)) + A = Matrix(2, 2, [1, 0, 0, 0]) + raises(NonInvertibleMatrixError, lambda: A.inv_mod(2)) + A = Matrix(2, 2, [1, 2, 3, 4]) + Ai = Matrix(2, 2, [1, 1, 0, 1]) + assert A.inv_mod(3) == Ai + A = Matrix(2, 2, [1, 0, 0, 1]) + assert A.inv_mod(2) == A + A = Matrix(3, 3, [1, 2, 3, 4, 5, 6, 7, 8, 9]) + raises(NonInvertibleMatrixError, lambda: A.inv_mod(5)) + A = Matrix(3, 3, [5, 1, 3, 2, 6, 0, 2, 1, 1]) + Ai = Matrix(3, 3, [6, 8, 0, 1, 5, 6, 5, 6, 4]) + assert A.inv_mod(9) == Ai + A = Matrix(3, 3, [1, 6, -3, 4, 1, -5, 3, -5, 5]) + Ai = Matrix(3, 3, [4, 3, 3, 1, 2, 5, 1, 5, 1]) + assert A.inv_mod(6) == Ai + A = Matrix(3, 3, [1, 6, 1, 4, 1, 5, 3, 2, 5]) + Ai = Matrix(3, 3, [6, 0, 3, 6, 6, 4, 1, 6, 1]) + assert A.inv_mod(7) == Ai + A = Matrix([[1, 2], [3, Rational(3,4)]]) + raises(ValueError, lambda: A.inv_mod(2)) + A = Matrix([[1, 2], [3, 4]]) + raises(TypeError, lambda: A.inv_mod(Rational(1, 2))) + # https://github.com/sympy/sympy/issues/27663 + M = Matrix([ + [2, 3, 1, 4], + [1, 5, 3, 2], + [3, 2, 4, 1], + [4, 1, 2, 5], + ]) + assert M.inv_mod(26) == Matrix([ + [7, 21, 10, 10], + [1, 7, 19, 3], + [14, 1, 15, 1], + [25, 23, 3, 12], + ]) diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_solvers.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..c1347062c0482336affbeb4bb9a95aedfcc0ae53 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_solvers.py @@ -0,0 +1,615 @@ +import pytest +from sympy.core.function import expand_mul +from sympy.core.numbers import (I, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.simplify.simplify import simplify +from sympy.matrices.exceptions import (ShapeError, NonSquareMatrixError) +from sympy.matrices import ( + ImmutableMatrix, Matrix, eye, ones, ImmutableDenseMatrix, dotprodsimp) +from sympy.matrices.determinant import _det_laplace +from sympy.testing.pytest import raises +from sympy.matrices.exceptions import NonInvertibleMatrixError +from sympy.polys.matrices.exceptions import DMShapeError +from sympy.solvers.solveset import linsolve +from sympy.abc import x, y + +def test_issue_17247_expression_blowup_29(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.gauss_jordan_solve(ones(4, 1)) == (Matrix(S('''[ + [ -32549314808672/3306971225785 - 17397006745216*I/3306971225785], + [ 67439348256/3306971225785 - 9167503335872*I/3306971225785], + [-15091965363354518272/21217636514687010905 + 16890163109293858304*I/21217636514687010905], + [ -11328/952745 + 87616*I/952745]]''')), Matrix(0, 1, [])) + +def test_issue_17247_expression_blowup_30(): + M = Matrix(S('''[ + [ -3/4, 45/32 - 37*I/16, 0, 0], + [-149/64 + 49*I/32, -177/128 - 1369*I/128, 0, -2063/256 + 541*I/128], + [ 0, 9/4 + 55*I/16, 2473/256 + 137*I/64, 0], + [ 0, 0, 0, -177/128 - 1369*I/128]]''')) + with dotprodsimp(True): + assert M.cholesky_solve(ones(4, 1)) == Matrix(S('''[ + [ -32549314808672/3306971225785 - 17397006745216*I/3306971225785], + [ 67439348256/3306971225785 - 9167503335872*I/3306971225785], + [-15091965363354518272/21217636514687010905 + 16890163109293858304*I/21217636514687010905], + [ -11328/952745 + 87616*I/952745]]''')) + +# @XFAIL # This calculation hangs with dotprodsimp. +# def test_issue_17247_expression_blowup_31(): +# M = Matrix([ +# [x + 1, 1 - x, 0, 0], +# [1 - x, x + 1, 0, x + 1], +# [ 0, 1 - x, x + 1, 0], +# [ 0, 0, 0, x + 1]]) +# with dotprodsimp(True): +# assert M.LDLsolve(ones(4, 1)) == Matrix([ +# [(x + 1)/(4*x)], +# [(x - 1)/(4*x)], +# [(x + 1)/(4*x)], +# [ 1/(x + 1)]]) + + +def test_LUsolve_iszerofunc(): + # taken from https://github.com/sympy/sympy/issues/24679 + + M = Matrix([[(x + 1)**2 - (x**2 + 2*x + 1), x], [x, 0]]) + b = Matrix([1, 1]) + is_zero_func = lambda e: False if e._random() else True + + x_exp = Matrix([1/x, (1-(-x**2 - 2*x + (x+1)**2 - 1)/x)/x]) + + assert (x_exp - M.LUsolve(b, iszerofunc=is_zero_func)) == Matrix([0, 0]) + + +def test_issue_17247_expression_blowup_32(): + M = Matrix([ + [x + 1, 1 - x, 0, 0], + [1 - x, x + 1, 0, x + 1], + [ 0, 1 - x, x + 1, 0], + [ 0, 0, 0, x + 1]]) + with dotprodsimp(True): + assert M.LUsolve(ones(4, 1)) == Matrix([ + [(x + 1)/(4*x)], + [(x - 1)/(4*x)], + [(x + 1)/(4*x)], + [ 1/(x + 1)]]) + +def test_LUsolve(): + A = Matrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + x = Matrix(3, 1, [3, 7, 5]) + b = A*x + soln = A.LUsolve(b) + assert soln == x + A = Matrix([[0, -1, 2], + [5, 10, 7], + [8, 3, 4]]) + x = Matrix(3, 1, [-1, 2, 5]) + b = A*x + soln = A.LUsolve(b) + assert soln == x + A = Matrix([[2, 1], [1, 0], [1, 0]]) # issue 14548 + b = Matrix([3, 1, 1]) + assert A.LUsolve(b) == Matrix([1, 1]) + b = Matrix([3, 1, 2]) # inconsistent + raises(ValueError, lambda: A.LUsolve(b)) + A = Matrix([[0, -1, 2], + [5, 10, 7], + [8, 3, 4], + [2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + x = Matrix([2, 1, -4]) + b = A*x + soln = A.LUsolve(b) + assert soln == x + A = Matrix([[0, -1, 2], [5, 10, 7]]) # underdetermined + x = Matrix([-1, 2, 0]) + b = A*x + raises(NotImplementedError, lambda: A.LUsolve(b)) + + A = Matrix(4, 4, lambda i, j: 1/(i+j+1) if i != 3 else 0) + b = Matrix.zeros(4, 1) + raises(NonInvertibleMatrixError, lambda: A.LUsolve(b)) + + +def test_LUsolve_noncommutative(): + a0, a1, a2, a3 = symbols("a:4", commutative=False) + b0, b1 = symbols("b:2", commutative=False) + A = Matrix([[a0, a1], [a2, a3]]) + check = A * A.LUsolve(Matrix([b0, b1])) + assert check[0, 0].expand() == b0 + # Because sympy simplification is very limited with noncommutative expressions, + # perform an explicit check with the second element + assert check[1, 0] == ( + a2*a0**(-1)*(-a1*(-a2*a0**(-1)*a1 + a3)**(-1)*(-a2*a0**(-1)*b0 + b1) + b0) + + a3*(-a2*a0**(-1)*a1 + a3)**(-1)*(-a2*a0**(-1)*b0 + b1) + ) + + +def test_QRsolve(): + A = Matrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + x = Matrix(3, 1, [3, 7, 5]) + b = A*x + soln = A.QRsolve(b) + assert soln == x + x = Matrix([[1, 2], [3, 4], [5, 6]]) + b = A*x + soln = A.QRsolve(b) + assert soln == x + + A = Matrix([[0, -1, 2], + [5, 10, 7], + [8, 3, 4]]) + x = Matrix(3, 1, [-1, 2, 5]) + b = A*x + soln = A.QRsolve(b) + assert soln == x + x = Matrix([[7, 8], [9, 10], [11, 12]]) + b = A*x + soln = A.QRsolve(b) + assert soln == x + +def test_errors(): + raises(ShapeError, lambda: Matrix([1]).LUsolve(Matrix([[1, 2], [3, 4]]))) + +def test_cholesky_solve(): + A = Matrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + x = Matrix(3, 1, [3, 7, 5]) + b = A*x + soln = A.cholesky_solve(b) + assert soln == x + A = Matrix([[0, -1, 2], + [5, 10, 7], + [8, 3, 4]]) + x = Matrix(3, 1, [-1, 2, 5]) + b = A*x + soln = A.cholesky_solve(b) + assert soln == x + A = Matrix(((1, 5), (5, 1))) + x = Matrix((4, -3)) + b = A*x + soln = A.cholesky_solve(b) + assert soln == x + A = Matrix(((9, 3*I), (-3*I, 5))) + x = Matrix((-2, 1)) + b = A*x + soln = A.cholesky_solve(b) + assert expand_mul(soln) == x + A = Matrix(((9*I, 3), (-3 + I, 5))) + x = Matrix((2 + 3*I, -1)) + b = A*x + soln = A.cholesky_solve(b) + assert expand_mul(soln) == x + a00, a01, a11, b0, b1 = symbols('a00, a01, a11, b0, b1') + A = Matrix(((a00, a01), (a01, a11))) + b = Matrix((b0, b1)) + x = A.cholesky_solve(b) + assert simplify(A*x) == b + + +def test_LDLsolve(): + A = Matrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + x = Matrix(3, 1, [3, 7, 5]) + b = A*x + soln = A.LDLsolve(b) + assert soln == x + + A = Matrix([[0, -1, 2], + [5, 10, 7], + [8, 3, 4]]) + x = Matrix(3, 1, [-1, 2, 5]) + b = A*x + soln = A.LDLsolve(b) + assert soln == x + + A = Matrix(((9, 3*I), (-3*I, 5))) + x = Matrix((-2, 1)) + b = A*x + soln = A.LDLsolve(b) + assert expand_mul(soln) == x + + A = Matrix(((9*I, 3), (-3 + I, 5))) + x = Matrix((2 + 3*I, -1)) + b = A*x + soln = A.LDLsolve(b) + assert expand_mul(soln) == x + + A = Matrix(((9, 3), (3, 9))) + x = Matrix((1, 1)) + b = A * x + soln = A.LDLsolve(b) + assert expand_mul(soln) == x + + A = Matrix([[-5, -3, -4], [-3, -7, 7]]) + x = Matrix([[8], [7], [-2]]) + b = A * x + raises(NotImplementedError, lambda: A.LDLsolve(b)) + + +def test_lower_triangular_solve(): + + raises(NonSquareMatrixError, + lambda: Matrix([1, 0]).lower_triangular_solve(Matrix([0, 1]))) + raises(ShapeError, + lambda: Matrix([[1, 0], [0, 1]]).lower_triangular_solve(Matrix([1]))) + raises(ValueError, + lambda: Matrix([[2, 1], [1, 2]]).lower_triangular_solve( + Matrix([[1, 0], [0, 1]]))) + + A = Matrix([[1, 0], [0, 1]]) + B = Matrix([[x, y], [y, x]]) + C = Matrix([[4, 8], [2, 9]]) + + assert A.lower_triangular_solve(B) == B + assert A.lower_triangular_solve(C) == C + + +def test_upper_triangular_solve(): + + raises(NonSquareMatrixError, + lambda: Matrix([1, 0]).upper_triangular_solve(Matrix([0, 1]))) + raises(ShapeError, + lambda: Matrix([[1, 0], [0, 1]]).upper_triangular_solve(Matrix([1]))) + raises(TypeError, + lambda: Matrix([[2, 1], [1, 2]]).upper_triangular_solve( + Matrix([[1, 0], [0, 1]]))) + + A = Matrix([[1, 0], [0, 1]]) + B = Matrix([[x, y], [y, x]]) + C = Matrix([[2, 4], [3, 8]]) + + assert A.upper_triangular_solve(B) == B + assert A.upper_triangular_solve(C) == C + + +def test_diagonal_solve(): + raises(TypeError, lambda: Matrix([1, 1]).diagonal_solve(Matrix([1]))) + A = Matrix([[1, 0], [0, 1]])*2 + B = Matrix([[x, y], [y, x]]) + assert A.diagonal_solve(B) == B/2 + + A = Matrix([[1, 0], [1, 2]]) + raises(TypeError, lambda: A.diagonal_solve(B)) + +def test_pinv_solve(): + # Fully determined system (unique result, identical to other solvers). + A = Matrix([[1, 5], [7, 9]]) + B = Matrix([12, 13]) + assert A.pinv_solve(B) == A.cholesky_solve(B) + assert A.pinv_solve(B) == A.LDLsolve(B) + assert A.pinv_solve(B) == Matrix([sympify('-43/26'), sympify('71/26')]) + assert A * A.pinv() * B == B + # Fully determined, with two-dimensional B matrix. + B = Matrix([[12, 13, 14], [15, 16, 17]]) + assert A.pinv_solve(B) == A.cholesky_solve(B) + assert A.pinv_solve(B) == A.LDLsolve(B) + assert A.pinv_solve(B) == Matrix([[-33, -37, -41], [69, 75, 81]]) / 26 + assert A * A.pinv() * B == B + # Underdetermined system (infinite results). + A = Matrix([[1, 0, 1], [0, 1, 1]]) + B = Matrix([5, 7]) + solution = A.pinv_solve(B) + w = {} + for s in solution.atoms(Symbol): + # Extract dummy symbols used in the solution. + w[s.name] = s + assert solution == Matrix([[w['w0_0']/3 + w['w1_0']/3 - w['w2_0']/3 + 1], + [w['w0_0']/3 + w['w1_0']/3 - w['w2_0']/3 + 3], + [-w['w0_0']/3 - w['w1_0']/3 + w['w2_0']/3 + 4]]) + assert A * A.pinv() * B == B + # Overdetermined system (least squares results). + A = Matrix([[1, 0], [0, 0], [0, 1]]) + B = Matrix([3, 2, 1]) + assert A.pinv_solve(B) == Matrix([3, 1]) + # Proof the solution is not exact. + assert A * A.pinv() * B != B + +def test_pinv_rank_deficient(): + # Test the four properties of the pseudoinverse for various matrices. + As = [Matrix([[1, 1, 1], [2, 2, 2]]), + Matrix([[1, 0], [0, 0]]), + Matrix([[1, 2], [2, 4], [3, 6]])] + + for A in As: + A_pinv = A.pinv(method="RD") + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + + for A in As: + A_pinv = A.pinv(method="ED") + AAp = A * A_pinv + ApA = A_pinv * A + assert simplify(AAp * A) == A + assert simplify(ApA * A_pinv) == A_pinv + assert AAp.H == AAp + assert ApA.H == ApA + + # Test solving with rank-deficient matrices. + A = Matrix([[1, 0], [0, 0]]) + # Exact, non-unique solution. + B = Matrix([3, 0]) + solution = A.pinv_solve(B) + w1 = solution.atoms(Symbol).pop() + assert w1.name == 'w1_0' + assert solution == Matrix([3, w1]) + assert A * A.pinv() * B == B + # Least squares, non-unique solution. + B = Matrix([3, 1]) + solution = A.pinv_solve(B) + w1 = solution.atoms(Symbol).pop() + assert w1.name == 'w1_0' + assert solution == Matrix([3, w1]) + assert A * A.pinv() * B != B + +def test_gauss_jordan_solve(): + + # Square, full rank, unique solution + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + b = Matrix([3, 6, 9]) + sol, params = A.gauss_jordan_solve(b) + assert sol == Matrix([[-1], [2], [0]]) + assert params == Matrix(0, 1, []) + + # Square, full rank, unique solution, B has more columns than rows + A = eye(3) + B = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]) + sol, params = A.gauss_jordan_solve(B) + assert sol == B + assert params == Matrix(0, 4, []) + + # Square, reduced rank, parametrized solution + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + b = Matrix([3, 6, 9]) + sol, params, freevar = A.gauss_jordan_solve(b, freevar=True) + w = {} + for s in sol.atoms(Symbol): + # Extract dummy symbols used in the solution. + w[s.name] = s + assert sol == Matrix([[w['tau0'] - 1], [-2*w['tau0'] + 2], [w['tau0']]]) + assert params == Matrix([[w['tau0']]]) + assert freevar == [2] + + # Square, reduced rank, parametrized solution, B has two columns + A = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + B = Matrix([[3, 4], [6, 8], [9, 12]]) + sol, params, freevar = A.gauss_jordan_solve(B, freevar=True) + w = {} + for s in sol.atoms(Symbol): + # Extract dummy symbols used in the solution. + w[s.name] = s + assert sol == Matrix([[w['tau0'] - 1, w['tau1'] - Rational(4, 3)], + [-2*w['tau0'] + 2, -2*w['tau1'] + Rational(8, 3)], + [w['tau0'], w['tau1']],]) + assert params == Matrix([[w['tau0'], w['tau1']]]) + assert freevar == [2] + + # Square, reduced rank, parametrized solution + A = Matrix([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) + b = Matrix([0, 0, 0]) + sol, params = A.gauss_jordan_solve(b) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert sol == Matrix([[-2*w['tau0'] - 3*w['tau1']], + [w['tau0']], [w['tau1']]]) + assert params == Matrix([[w['tau0']], [w['tau1']]]) + + # Square, reduced rank, parametrized solution + A = Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + b = Matrix([0, 0, 0]) + sol, params = A.gauss_jordan_solve(b) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert sol == Matrix([[w['tau0']], [w['tau1']], [w['tau2']]]) + assert params == Matrix([[w['tau0']], [w['tau1']], [w['tau2']]]) + + # Square, reduced rank, no solution + A = Matrix([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) + b = Matrix([0, 0, 1]) + raises(ValueError, lambda: A.gauss_jordan_solve(b)) + + # Rectangular, tall, full rank, unique solution + A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]]) + b = Matrix([0, 0, 1, 0]) + sol, params = A.gauss_jordan_solve(b) + assert sol == Matrix([[Rational(-1, 2)], [0], [Rational(1, 6)]]) + assert params == Matrix(0, 1, []) + + # Rectangular, tall, full rank, unique solution, B has less columns than rows + A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]]) + B = Matrix([[0,0], [0, 0], [1, 2], [0, 0]]) + sol, params = A.gauss_jordan_solve(B) + assert sol == Matrix([[Rational(-1, 2), Rational(-2, 2)], [0, 0], [Rational(1, 6), Rational(2, 6)]]) + assert params == Matrix(0, 2, []) + + # Rectangular, tall, full rank, no solution + A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]]) + b = Matrix([0, 0, 0, 1]) + raises(ValueError, lambda: A.gauss_jordan_solve(b)) + + # Rectangular, tall, full rank, no solution, B has two columns (2nd has no solution) + A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]]) + B = Matrix([[0,0], [0, 0], [1, 0], [0, 1]]) + raises(ValueError, lambda: A.gauss_jordan_solve(B)) + + # Rectangular, tall, full rank, no solution, B has two columns (1st has no solution) + A = Matrix([[1, 5, 3], [2, 1, 6], [1, 7, 9], [1, 4, 3]]) + B = Matrix([[0,0], [0, 0], [0, 1], [1, 0]]) + raises(ValueError, lambda: A.gauss_jordan_solve(B)) + + # Rectangular, tall, reduced rank, parametrized solution + A = Matrix([[1, 5, 3], [2, 10, 6], [3, 15, 9], [1, 4, 3]]) + b = Matrix([0, 0, 0, 1]) + sol, params = A.gauss_jordan_solve(b) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert sol == Matrix([[-3*w['tau0'] + 5], [-1], [w['tau0']]]) + assert params == Matrix([[w['tau0']]]) + + # Rectangular, tall, reduced rank, no solution + A = Matrix([[1, 5, 3], [2, 10, 6], [3, 15, 9], [1, 4, 3]]) + b = Matrix([0, 0, 1, 1]) + raises(ValueError, lambda: A.gauss_jordan_solve(b)) + + # Rectangular, wide, full rank, parametrized solution + A = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 1, 12]]) + b = Matrix([1, 1, 1]) + sol, params = A.gauss_jordan_solve(b) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert sol == Matrix([[2*w['tau0'] - 1], [-3*w['tau0'] + 1], [0], + [w['tau0']]]) + assert params == Matrix([[w['tau0']]]) + + # Rectangular, wide, reduced rank, parametrized solution + A = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [2, 4, 6, 8]]) + b = Matrix([0, 1, 0]) + sol, params = A.gauss_jordan_solve(b) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert sol == Matrix([[w['tau0'] + 2*w['tau1'] + S.Half], + [-2*w['tau0'] - 3*w['tau1'] - Rational(1, 4)], + [w['tau0']], [w['tau1']]]) + assert params == Matrix([[w['tau0']], [w['tau1']]]) + # watch out for clashing symbols + x0, x1, x2, _x0 = symbols('_tau0 _tau1 _tau2 tau1') + M = Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, _x0]]) + A = M[:, :-1] + b = M[:, -1:] + sol, params = A.gauss_jordan_solve(b) + assert params == Matrix(3, 1, [x0, x1, x2]) + assert sol == Matrix(5, 1, [x0, 0, x1, _x0, x2]) + + # Rectangular, wide, reduced rank, no solution + A = Matrix([[1, 2, 3, 4], [5, 6, 7, 8], [2, 4, 6, 8]]) + b = Matrix([1, 1, 1]) + raises(ValueError, lambda: A.gauss_jordan_solve(b)) + + # Test for immutable matrix + A = ImmutableMatrix([[1, 0], [0, 1]]) + B = ImmutableMatrix([1, 2]) + sol, params = A.gauss_jordan_solve(B) + assert sol == ImmutableMatrix([1, 2]) + assert params == ImmutableMatrix(0, 1, []) + assert sol.__class__ == ImmutableDenseMatrix + assert params.__class__ == ImmutableDenseMatrix + + # Test placement of free variables + A = Matrix([[1, 0, 0, 0], [0, 0, 0, 1]]) + b = Matrix([1, 1]) + sol, params = A.gauss_jordan_solve(b) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert sol == Matrix([[1], [w['tau0']], [w['tau1']], [1]]) + assert params == Matrix([[w['tau0']], [w['tau1']]]) + + +def test_linsolve_underdetermined_AND_gauss_jordan_solve(): + #Test placement of free variables as per issue 19815 + A = Matrix([[1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1]]) + B = Matrix([1, 2, 1, 1, 1, 1, 1, 2]) + sol, params = A.gauss_jordan_solve(B) + w = {} + for s in sol.atoms(Symbol): + w[s.name] = s + assert params == Matrix([[w['tau0']], [w['tau1']], [w['tau2']], + [w['tau3']], [w['tau4']], [w['tau5']]]) + assert sol == Matrix([[1 - 1*w['tau2']], + [w['tau2']], + [1 - 1*w['tau0'] + w['tau1']], + [w['tau0']], + [w['tau3'] + w['tau4']], + [-1*w['tau3'] - 1*w['tau4'] - 1*w['tau1']], + [1 - 1*w['tau2']], + [w['tau1']], + [w['tau2']], + [w['tau3']], + [w['tau4']], + [1 - 1*w['tau5']], + [w['tau5']], + [1]]) + + from sympy.abc import j,f + # https://github.com/sympy/sympy/issues/20046 + A = Matrix([ + [1, 1, 1, 1, 1, 1, 1, 1, 1], + [0, -1, 0, -1, 0, -1, 0, -1, -j], + [0, 0, 0, 0, 1, 1, 1, 1, f] + ]) + + sol_1=Matrix(list(linsolve(A))[0]) + + tau0, tau1, tau2, tau3, tau4 = symbols('tau:5') + + assert sol_1 == Matrix([[-f - j - tau0 + tau2 + tau4 + 1], + [j - tau1 - tau2 - tau4], + [tau0], + [tau1], + [f - tau2 - tau3 - tau4], + [tau2], + [tau3], + [tau4]]) + + # https://github.com/sympy/sympy/issues/19815 + sol_2 = A[:, : -1 ] * sol_1 - A[:, -1 ] + assert sol_2 == Matrix([[0], [0], [0]]) + + +@pytest.mark.parametrize("det_method", ["bird", "laplace"]) +@pytest.mark.parametrize("M, rhs", [ + (Matrix([[2, 3, 5], [3, 6, 2], [8, 3, 6]]), Matrix(3, 1, [3, 7, 5])), + (Matrix([[2, 3, 5], [3, 6, 2], [8, 3, 6]]), + Matrix([[1, 2], [3, 4], [5, 6]])), + (Matrix(2, 2, symbols("a:4")), Matrix(2, 1, symbols("b:2"))), +]) +def test_cramer_solve(det_method, M, rhs): + assert simplify(M.cramer_solve(rhs, det_method=det_method) - M.LUsolve(rhs) + ) == Matrix.zeros(M.rows, rhs.cols) + + +@pytest.mark.parametrize("det_method, error", [ + ("bird", DMShapeError), (_det_laplace, NonSquareMatrixError)]) +def test_cramer_solve_errors(det_method, error): + # Non-square matrix + A = Matrix([[0, -1, 2], [5, 10, 7]]) + b = Matrix([-2, 15]) + raises(error, lambda: A.cramer_solve(b, det_method=det_method)) + + +def test_solve(): + A = Matrix([[1,2], [2,4]]) + b = Matrix([[3], [4]]) + raises(ValueError, lambda: A.solve(b)) #no solution + b = Matrix([[ 4], [8]]) + raises(ValueError, lambda: A.solve(b)) #infinite solution diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_sparse.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_sparse.py new file mode 100644 index 0000000000000000000000000000000000000000..4d257c8062f220cc06bc0dabdc7ac40ce9dc4adc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_sparse.py @@ -0,0 +1,745 @@ +from sympy.core.numbers import (Float, I, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.polys.polytools import PurePoly +from sympy.matrices import \ + Matrix, MutableSparseMatrix, ImmutableSparseMatrix, SparseMatrix, eye, \ + ones, zeros, ShapeError, NonSquareMatrixError +from sympy.testing.pytest import raises + + +def test_sparse_creation(): + a = SparseMatrix(2, 2, {(0, 0): [[1, 2], [3, 4]]}) + assert a == SparseMatrix([[1, 2], [3, 4]]) + a = SparseMatrix(2, 2, {(0, 0): [[1, 2]]}) + assert a == SparseMatrix([[1, 2], [0, 0]]) + a = SparseMatrix(2, 2, {(0, 0): [1, 2]}) + assert a == SparseMatrix([[1, 0], [2, 0]]) + + +def test_sparse_matrix(): + def sparse_eye(n): + return SparseMatrix.eye(n) + + def sparse_zeros(n): + return SparseMatrix.zeros(n) + + # creation args + raises(TypeError, lambda: SparseMatrix(1, 2)) + + a = SparseMatrix(( + (1, 0), + (0, 1) + )) + assert SparseMatrix(a) == a + + from sympy.matrices import MutableDenseMatrix + a = MutableSparseMatrix([]) + b = MutableDenseMatrix([1, 2]) + assert a.row_join(b) == b + assert a.col_join(b) == b + assert type(a.row_join(b)) == type(a) + assert type(a.col_join(b)) == type(a) + + # make sure 0 x n matrices get stacked correctly + sparse_matrices = [SparseMatrix.zeros(0, n) for n in range(4)] + assert SparseMatrix.hstack(*sparse_matrices) == Matrix(0, 6, []) + sparse_matrices = [SparseMatrix.zeros(n, 0) for n in range(4)] + assert SparseMatrix.vstack(*sparse_matrices) == Matrix(6, 0, []) + + # test element assignment + a = SparseMatrix(( + (1, 0), + (0, 1) + )) + + a[3] = 4 + assert a[1, 1] == 4 + a[3] = 1 + + a[0, 0] = 2 + assert a == SparseMatrix(( + (2, 0), + (0, 1) + )) + a[1, 0] = 5 + assert a == SparseMatrix(( + (2, 0), + (5, 1) + )) + a[1, 1] = 0 + assert a == SparseMatrix(( + (2, 0), + (5, 0) + )) + assert a.todok() == {(0, 0): 2, (1, 0): 5} + + # test_multiplication + a = SparseMatrix(( + (1, 2), + (3, 1), + (0, 6), + )) + + b = SparseMatrix(( + (1, 2), + (3, 0), + )) + + c = a*b + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + try: + eval('c = a @ b') + except SyntaxError: + pass + else: + assert c[0, 0] == 7 + assert c[0, 1] == 2 + assert c[1, 0] == 6 + assert c[1, 1] == 6 + assert c[2, 0] == 18 + assert c[2, 1] == 0 + + x = Symbol("x") + + c = b * Symbol("x") + assert isinstance(c, SparseMatrix) + assert c[0, 0] == x + assert c[0, 1] == 2*x + assert c[1, 0] == 3*x + assert c[1, 1] == 0 + + c = 5 * b + assert isinstance(c, SparseMatrix) + assert c[0, 0] == 5 + assert c[0, 1] == 2*5 + assert c[1, 0] == 3*5 + assert c[1, 1] == 0 + + #test_power + A = SparseMatrix([[2, 3], [4, 5]]) + assert (A**5)[:] == [6140, 8097, 10796, 14237] + A = SparseMatrix([[2, 1, 3], [4, 2, 4], [6, 12, 1]]) + assert (A**3)[:] == [290, 262, 251, 448, 440, 368, 702, 954, 433] + + # test_creation + x = Symbol("x") + a = SparseMatrix([[x, 0], [0, 0]]) + m = a + assert m.cols == m.rows + assert m.cols == 2 + assert m[:] == [x, 0, 0, 0] + b = SparseMatrix(2, 2, [x, 0, 0, 0]) + m = b + assert m.cols == m.rows + assert m.cols == 2 + assert m[:] == [x, 0, 0, 0] + + assert a == b + S = sparse_eye(3) + S.row_del(1) + assert S == SparseMatrix([ + [1, 0, 0], + [0, 0, 1]]) + S = sparse_eye(3) + S.col_del(1) + assert S == SparseMatrix([ + [1, 0], + [0, 0], + [0, 1]]) + S = SparseMatrix.eye(3) + S[2, 1] = 2 + S.col_swap(1, 0) + assert S == SparseMatrix([ + [0, 1, 0], + [1, 0, 0], + [2, 0, 1]]) + S.row_swap(0, 1) + assert S == SparseMatrix([ + [1, 0, 0], + [0, 1, 0], + [2, 0, 1]]) + + a = SparseMatrix(1, 2, [1, 2]) + b = a.copy() + c = a.copy() + assert a[0] == 1 + a.row_del(0) + assert a == SparseMatrix(0, 2, []) + b.col_del(1) + assert b == SparseMatrix(1, 1, [1]) + + assert SparseMatrix([[1, 2, 3], [1, 2], [1]]) == Matrix([ + [1, 2, 3], + [1, 2, 0], + [1, 0, 0]]) + assert SparseMatrix(4, 4, {(1, 1): sparse_eye(2)}) == Matrix([ + [0, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 0]]) + raises(ValueError, lambda: SparseMatrix(1, 1, {(1, 1): 1})) + assert SparseMatrix(1, 2, [1, 2]).tolist() == [[1, 2]] + assert SparseMatrix(2, 2, [1, [2, 3]]).tolist() == [[1, 0], [2, 3]] + raises(ValueError, lambda: SparseMatrix(2, 2, [1])) + raises(ValueError, lambda: SparseMatrix(1, 1, [[1, 2]])) + assert SparseMatrix([.1]).has(Float) + # autosizing + assert SparseMatrix(None, {(0, 1): 0}).shape == (0, 0) + assert SparseMatrix(None, {(0, 1): 1}).shape == (1, 2) + assert SparseMatrix(None, None, {(0, 1): 1}).shape == (1, 2) + raises(ValueError, lambda: SparseMatrix(None, 1, [[1, 2]])) + raises(ValueError, lambda: SparseMatrix(1, None, [[1, 2]])) + raises(ValueError, lambda: SparseMatrix(3, 3, {(0, 0): ones(2), (1, 1): 2})) + + # test_determinant + x, y = Symbol('x'), Symbol('y') + + assert SparseMatrix(1, 1, [0]).det() == 0 + + assert SparseMatrix([[1]]).det() == 1 + + assert SparseMatrix(((-3, 2), (8, -5))).det() == -1 + + assert SparseMatrix(((x, 1), (y, 2*y))).det() == 2*x*y - y + + assert SparseMatrix(( (1, 1, 1), + (1, 2, 3), + (1, 3, 6) )).det() == 1 + + assert SparseMatrix(( ( 3, -2, 0, 5), + (-2, 1, -2, 2), + ( 0, -2, 5, 0), + ( 5, 0, 3, 4) )).det() == -289 + + assert SparseMatrix(( ( 1, 2, 3, 4), + ( 5, 6, 7, 8), + ( 9, 10, 11, 12), + (13, 14, 15, 16) )).det() == 0 + + assert SparseMatrix(( (3, 2, 0, 0, 0), + (0, 3, 2, 0, 0), + (0, 0, 3, 2, 0), + (0, 0, 0, 3, 2), + (2, 0, 0, 0, 3) )).det() == 275 + + assert SparseMatrix(( (1, 0, 1, 2, 12), + (2, 0, 1, 1, 4), + (2, 1, 1, -1, 3), + (3, 2, -1, 1, 8), + (1, 1, 1, 0, 6) )).det() == -55 + + assert SparseMatrix(( (-5, 2, 3, 4, 5), + ( 1, -4, 3, 4, 5), + ( 1, 2, -3, 4, 5), + ( 1, 2, 3, -2, 5), + ( 1, 2, 3, 4, -1) )).det() == 11664 + + assert SparseMatrix(( ( 3, 0, 0, 0), + (-2, 1, 0, 0), + ( 0, -2, 5, 0), + ( 5, 0, 3, 4) )).det() == 60 + + assert SparseMatrix(( ( 1, 0, 0, 0), + ( 5, 0, 0, 0), + ( 9, 10, 11, 0), + (13, 14, 15, 16) )).det() == 0 + + assert SparseMatrix(( (3, 2, 0, 0, 0), + (0, 3, 2, 0, 0), + (0, 0, 3, 2, 0), + (0, 0, 0, 3, 2), + (0, 0, 0, 0, 3) )).det() == 243 + + assert SparseMatrix(( ( 2, 7, -1, 3, 2), + ( 0, 0, 1, 0, 1), + (-2, 0, 7, 0, 2), + (-3, -2, 4, 5, 3), + ( 1, 0, 0, 0, 1) )).det() == 123 + + # test_slicing + m0 = sparse_eye(4) + assert m0[:3, :3] == sparse_eye(3) + assert m0[2:4, 0:2] == sparse_zeros(2) + + m1 = SparseMatrix(3, 3, lambda i, j: i + j) + assert m1[0, :] == SparseMatrix(1, 3, (0, 1, 2)) + assert m1[1:3, 1] == SparseMatrix(2, 1, (2, 3)) + + m2 = SparseMatrix( + [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]) + assert m2[:, -1] == SparseMatrix(4, 1, [3, 7, 11, 15]) + assert m2[-2:, :] == SparseMatrix([[8, 9, 10, 11], [12, 13, 14, 15]]) + + assert SparseMatrix([[1, 2], [3, 4]])[[1], [1]] == Matrix([[4]]) + + # test_submatrix_assignment + m = sparse_zeros(4) + m[2:4, 2:4] = sparse_eye(2) + assert m == SparseMatrix([(0, 0, 0, 0), + (0, 0, 0, 0), + (0, 0, 1, 0), + (0, 0, 0, 1)]) + assert len(m.todok()) == 2 + m[:2, :2] = sparse_eye(2) + assert m == sparse_eye(4) + m[:, 0] = SparseMatrix(4, 1, (1, 2, 3, 4)) + assert m == SparseMatrix([(1, 0, 0, 0), + (2, 1, 0, 0), + (3, 0, 1, 0), + (4, 0, 0, 1)]) + m[:, :] = sparse_zeros(4) + assert m == sparse_zeros(4) + m[:, :] = ((1, 2, 3, 4), (5, 6, 7, 8), (9, 10, 11, 12), (13, 14, 15, 16)) + assert m == SparseMatrix((( 1, 2, 3, 4), + ( 5, 6, 7, 8), + ( 9, 10, 11, 12), + (13, 14, 15, 16))) + m[:2, 0] = [0, 0] + assert m == SparseMatrix((( 0, 2, 3, 4), + ( 0, 6, 7, 8), + ( 9, 10, 11, 12), + (13, 14, 15, 16))) + + # test_reshape + m0 = sparse_eye(3) + assert m0.reshape(1, 9) == SparseMatrix(1, 9, (1, 0, 0, 0, 1, 0, 0, 0, 1)) + m1 = SparseMatrix(3, 4, lambda i, j: i + j) + assert m1.reshape(4, 3) == \ + SparseMatrix([(0, 1, 2), (3, 1, 2), (3, 4, 2), (3, 4, 5)]) + assert m1.reshape(2, 6) == \ + SparseMatrix([(0, 1, 2, 3, 1, 2), (3, 4, 2, 3, 4, 5)]) + + # test_applyfunc + m0 = sparse_eye(3) + assert m0.applyfunc(lambda x: 2*x) == sparse_eye(3)*2 + assert m0.applyfunc(lambda x: 0 ) == sparse_zeros(3) + + # test__eval_Abs + assert abs(SparseMatrix(((x, 1), (y, 2*y)))) == SparseMatrix(((Abs(x), 1), (Abs(y), 2*Abs(y)))) + + # test_LUdecomp + testmat = SparseMatrix([[ 0, 2, 5, 3], + [ 3, 3, 7, 4], + [ 8, 4, 0, 2], + [-2, 6, 3, 4]]) + L, U, p = testmat.LUdecomposition() + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - testmat == sparse_zeros(4) + + testmat = SparseMatrix([[ 6, -2, 7, 4], + [ 0, 3, 6, 7], + [ 1, -2, 7, 4], + [-9, 2, 6, 3]]) + L, U, p = testmat.LUdecomposition() + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - testmat == sparse_zeros(4) + + x, y, z = Symbol('x'), Symbol('y'), Symbol('z') + M = Matrix(((1, x, 1), (2, y, 0), (y, 0, z))) + L, U, p = M.LUdecomposition() + assert L.is_lower + assert U.is_upper + assert (L*U).permute_rows(p, 'backward') - M == sparse_zeros(3) + + # test_LUsolve + A = SparseMatrix([[2, 3, 5], + [3, 6, 2], + [8, 3, 6]]) + x = SparseMatrix(3, 1, [3, 7, 5]) + b = A*x + soln = A.LUsolve(b) + assert soln == x + A = SparseMatrix([[0, -1, 2], + [5, 10, 7], + [8, 3, 4]]) + x = SparseMatrix(3, 1, [-1, 2, 5]) + b = A*x + soln = A.LUsolve(b) + assert soln == x + + # test_inverse + A = sparse_eye(4) + assert A.inv() == sparse_eye(4) + assert A.inv(method="CH") == sparse_eye(4) + assert A.inv(method="LDL") == sparse_eye(4) + + A = SparseMatrix([[2, 3, 5], + [3, 6, 2], + [7, 2, 6]]) + Ainv = SparseMatrix(Matrix(A).inv()) + assert A*Ainv == sparse_eye(3) + assert A.inv(method="CH") == Ainv + assert A.inv(method="LDL") == Ainv + + A = SparseMatrix([[2, 3, 5], + [3, 6, 2], + [5, 2, 6]]) + Ainv = SparseMatrix(Matrix(A).inv()) + assert A*Ainv == sparse_eye(3) + assert A.inv(method="CH") == Ainv + assert A.inv(method="LDL") == Ainv + + # test_cross + v1 = Matrix(1, 3, [1, 2, 3]) + v2 = Matrix(1, 3, [3, 4, 5]) + assert v1.cross(v2) == Matrix(1, 3, [-2, 4, -2]) + assert v1.norm(2)**2 == 14 + + # conjugate + a = SparseMatrix(((1, 2 + I), (3, 4))) + assert a.C == SparseMatrix([ + [1, 2 - I], + [3, 4] + ]) + + # mul + assert a*Matrix(2, 2, [1, 0, 0, 1]) == a + assert a + Matrix(2, 2, [1, 1, 1, 1]) == SparseMatrix([ + [2, 3 + I], + [4, 5] + ]) + + # col join + assert a.col_join(sparse_eye(2)) == SparseMatrix([ + [1, 2 + I], + [3, 4], + [1, 0], + [0, 1] + ]) + + # row insert + assert a.row_insert(2, sparse_eye(2)) == SparseMatrix([ + [1, 2 + I], + [3, 4], + [1, 0], + [0, 1] + ]) + + # col insert + assert a.col_insert(2, SparseMatrix.zeros(2, 1)) == SparseMatrix([ + [1, 2 + I, 0], + [3, 4, 0], + ]) + + # symmetric + assert not a.is_symmetric(simplify=False) + + # col op + M = SparseMatrix.eye(3)*2 + M[1, 0] = -1 + M.col_op(1, lambda v, i: v + 2*M[i, 0]) + assert M == SparseMatrix([ + [ 2, 4, 0], + [-1, 0, 0], + [ 0, 0, 2] + ]) + + # fill + M = SparseMatrix.eye(3) + M.fill(2) + assert M == SparseMatrix([ + [2, 2, 2], + [2, 2, 2], + [2, 2, 2], + ]) + + # test_cofactor + assert sparse_eye(3) == sparse_eye(3).cofactor_matrix() + test = SparseMatrix([[1, 3, 2], [2, 6, 3], [2, 3, 6]]) + assert test.cofactor_matrix() == \ + SparseMatrix([[27, -6, -6], [-12, 2, 3], [-3, 1, 0]]) + test = SparseMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert test.cofactor_matrix() == \ + SparseMatrix([[-3, 6, -3], [6, -12, 6], [-3, 6, -3]]) + + # test_jacobian + x = Symbol('x') + y = Symbol('y') + L = SparseMatrix(1, 2, [x**2*y, 2*y**2 + x*y]) + syms = [x, y] + assert L.jacobian(syms) == Matrix([[2*x*y, x**2], [y, 4*y + x]]) + + L = SparseMatrix(1, 2, [x, x**2*y**3]) + assert L.jacobian(syms) == SparseMatrix([[1, 0], [2*x*y**3, x**2*3*y**2]]) + + # test_QR + A = Matrix([[1, 2], [2, 3]]) + Q, S = A.QRdecomposition() + R = Rational + assert Q == Matrix([ + [ 5**R(-1, 2), (R(2)/5)*(R(1)/5)**R(-1, 2)], + [2*5**R(-1, 2), (-R(1)/5)*(R(1)/5)**R(-1, 2)]]) + assert S == Matrix([ + [5**R(1, 2), 8*5**R(-1, 2)], + [ 0, (R(1)/5)**R(1, 2)]]) + assert Q*S == A + assert Q.T * Q == sparse_eye(2) + + R = Rational + # test nullspace + # first test reduced row-ech form + + M = SparseMatrix([[5, 7, 2, 1], + [1, 6, 2, -1]]) + out, tmp = M.rref() + assert out == Matrix([[1, 0, -R(2)/23, R(13)/23], + [0, 1, R(8)/23, R(-6)/23]]) + + M = SparseMatrix([[ 1, 3, 0, 2, 6, 3, 1], + [-2, -6, 0, -2, -8, 3, 1], + [ 3, 9, 0, 0, 6, 6, 2], + [-1, -3, 0, 1, 0, 9, 3]]) + + out, tmp = M.rref() + assert out == Matrix([[1, 3, 0, 0, 2, 0, 0], + [0, 0, 0, 1, 2, 0, 0], + [0, 0, 0, 0, 0, 1, R(1)/3], + [0, 0, 0, 0, 0, 0, 0]]) + # now check the vectors + basis = M.nullspace() + assert basis[0] == Matrix([-3, 1, 0, 0, 0, 0, 0]) + assert basis[1] == Matrix([0, 0, 1, 0, 0, 0, 0]) + assert basis[2] == Matrix([-2, 0, 0, -2, 1, 0, 0]) + assert basis[3] == Matrix([0, 0, 0, 0, 0, R(-1)/3, 1]) + + # test eigen + x = Symbol('x') + y = Symbol('y') + sparse_eye3 = sparse_eye(3) + assert sparse_eye3.charpoly(x) == PurePoly((x - 1)**3) + assert sparse_eye3.charpoly(y) == PurePoly((y - 1)**3) + + # test values + M = Matrix([( 0, 1, -1), + ( 1, 1, 0), + (-1, 0, 1)]) + vals = M.eigenvals() + assert sorted(vals.keys()) == [-1, 1, 2] + + R = Rational + M = Matrix([[1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + assert M.eigenvects() == [(1, 3, [ + Matrix([1, 0, 0]), + Matrix([0, 1, 0]), + Matrix([0, 0, 1])])] + M = Matrix([[5, 0, 2], + [3, 2, 0], + [0, 0, 1]]) + assert M.eigenvects() == [(1, 1, [Matrix([R(-1)/2, R(3)/2, 1])]), + (2, 1, [Matrix([0, 1, 0])]), + (5, 1, [Matrix([1, 1, 0])])] + + assert M.zeros(3, 5) == SparseMatrix(3, 5, {}) + A = SparseMatrix(10, 10, {(0, 0): 18, (0, 9): 12, (1, 4): 18, (2, 7): 16, (3, 9): 12, (4, 2): 19, (5, 7): 16, (6, 2): 12, (9, 7): 18}) + assert A.row_list() == [(0, 0, 18), (0, 9, 12), (1, 4, 18), (2, 7, 16), (3, 9, 12), (4, 2, 19), (5, 7, 16), (6, 2, 12), (9, 7, 18)] + assert A.col_list() == [(0, 0, 18), (4, 2, 19), (6, 2, 12), (1, 4, 18), (2, 7, 16), (5, 7, 16), (9, 7, 18), (0, 9, 12), (3, 9, 12)] + assert SparseMatrix.eye(2).nnz() == 2 + + +def test_scalar_multiply(): + assert SparseMatrix([[1, 2]]).scalar_multiply(3) == SparseMatrix([[3, 6]]) + + +def test_transpose(): + assert SparseMatrix(((1, 2), (3, 4))).transpose() == \ + SparseMatrix(((1, 3), (2, 4))) + + +def test_trace(): + assert SparseMatrix(((1, 2), (3, 4))).trace() == 5 + assert SparseMatrix(((0, 0), (0, 4))).trace() == 4 + + +def test_CL_RL(): + assert SparseMatrix(((1, 2), (3, 4))).row_list() == \ + [(0, 0, 1), (0, 1, 2), (1, 0, 3), (1, 1, 4)] + assert SparseMatrix(((1, 2), (3, 4))).col_list() == \ + [(0, 0, 1), (1, 0, 3), (0, 1, 2), (1, 1, 4)] + + +def test_add(): + assert SparseMatrix(((1, 0), (0, 1))) + SparseMatrix(((0, 1), (1, 0))) == \ + SparseMatrix(((1, 1), (1, 1))) + a = SparseMatrix(100, 100, lambda i, j: int(j != 0 and i % j == 0)) + b = SparseMatrix(100, 100, lambda i, j: int(i != 0 and j % i == 0)) + assert (len(a.todok()) + len(b.todok()) - len((a + b).todok()) > 0) + + +def test_errors(): + raises(ValueError, lambda: SparseMatrix(1.4, 2, lambda i, j: 0)) + raises(TypeError, lambda: SparseMatrix([1, 2, 3], [1, 2])) + raises(ValueError, lambda: SparseMatrix([[1, 2], [3, 4]])[(1, 2, 3)]) + raises(IndexError, lambda: SparseMatrix([[1, 2], [3, 4]])[5]) + raises(ValueError, lambda: SparseMatrix([[1, 2], [3, 4]])[1, 2, 3]) + raises(TypeError, + lambda: SparseMatrix([[1, 2], [3, 4]]).copyin_list([0, 1], set())) + raises( + IndexError, lambda: SparseMatrix([[1, 2], [3, 4]])[1, 2]) + raises(TypeError, lambda: SparseMatrix([1, 2, 3]).cross(1)) + raises(IndexError, lambda: SparseMatrix(1, 2, [1, 2])[3]) + raises(ShapeError, + lambda: SparseMatrix(1, 2, [1, 2]) + SparseMatrix(2, 1, [2, 1])) + + +def test_len(): + assert not SparseMatrix() + assert SparseMatrix() == SparseMatrix([]) + assert SparseMatrix() == SparseMatrix([[]]) + + +def test_sparse_zeros_sparse_eye(): + assert SparseMatrix.eye(3) == eye(3, cls=SparseMatrix) + assert len(SparseMatrix.eye(3).todok()) == 3 + assert SparseMatrix.zeros(3) == zeros(3, cls=SparseMatrix) + assert len(SparseMatrix.zeros(3).todok()) == 0 + + +def test_copyin(): + s = SparseMatrix(3, 3, {}) + s[1, 0] = 1 + assert s[:, 0] == SparseMatrix(Matrix([0, 1, 0])) + assert s[3] == 1 + assert s[3: 4] == [1] + s[1, 1] = 42 + assert s[1, 1] == 42 + assert s[1, 1:] == SparseMatrix([[42, 0]]) + s[1, 1:] = Matrix([[5, 6]]) + assert s[1, :] == SparseMatrix([[1, 5, 6]]) + s[1, 1:] = [[42, 43]] + assert s[1, :] == SparseMatrix([[1, 42, 43]]) + s[0, 0] = 17 + assert s[:, :1] == SparseMatrix([17, 1, 0]) + s[0, 0] = [1, 1, 1] + assert s[:, 0] == SparseMatrix([1, 1, 1]) + s[0, 0] = Matrix([1, 1, 1]) + assert s[:, 0] == SparseMatrix([1, 1, 1]) + s[0, 0] = SparseMatrix([1, 1, 1]) + assert s[:, 0] == SparseMatrix([1, 1, 1]) + + +def test_sparse_solve(): + A = SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + assert A.cholesky() == Matrix([ + [ 5, 0, 0], + [ 3, 3, 0], + [-1, 1, 3]]) + assert A.cholesky() * A.cholesky().T == Matrix([ + [25, 15, -5], + [15, 18, 0], + [-5, 0, 11]]) + + A = SparseMatrix(((25, 15, -5), (15, 18, 0), (-5, 0, 11))) + L, D = A.LDLdecomposition() + assert 15*L == Matrix([ + [15, 0, 0], + [ 9, 15, 0], + [-3, 5, 15]]) + assert D == Matrix([ + [25, 0, 0], + [ 0, 9, 0], + [ 0, 0, 9]]) + assert L * D * L.T == A + + A = SparseMatrix(((3, 0, 2), (0, 0, 1), (1, 2, 0))) + assert A.inv() * A == SparseMatrix(eye(3)) + + A = SparseMatrix([ + [ 2, -1, 0], + [-1, 2, -1], + [ 0, 0, 2]]) + ans = SparseMatrix([ + [Rational(2, 3), Rational(1, 3), Rational(1, 6)], + [Rational(1, 3), Rational(2, 3), Rational(1, 3)], + [ 0, 0, S.Half]]) + assert A.inv(method='CH') == ans + assert A.inv(method='LDL') == ans + assert A * ans == SparseMatrix(eye(3)) + + s = A.solve(A[:, 0], 'LDL') + assert A*s == A[:, 0] + s = A.solve(A[:, 0], 'CH') + assert A*s == A[:, 0] + A = A.col_join(A) + s = A.solve_least_squares(A[:, 0], 'CH') + assert A*s == A[:, 0] + s = A.solve_least_squares(A[:, 0], 'LDL') + assert A*s == A[:, 0] + + +def test_lower_triangular_solve(): + raises(NonSquareMatrixError, lambda: + SparseMatrix([[1, 2]]).lower_triangular_solve(Matrix([[1, 2]]))) + raises(ShapeError, lambda: + SparseMatrix([[1, 2], [0, 4]]).lower_triangular_solve(Matrix([1]))) + raises(ValueError, lambda: + SparseMatrix([[1, 2], [3, 4]]).lower_triangular_solve(Matrix([[1, 2], [3, 4]]))) + + a, b, c, d = symbols('a:d') + u, v, w, x = symbols('u:x') + + A = SparseMatrix([[a, 0], [c, d]]) + B = MutableSparseMatrix([[u, v], [w, x]]) + C = ImmutableSparseMatrix([[u, v], [w, x]]) + + sol = Matrix([[u/a, v/a], [(w - c*u/a)/d, (x - c*v/a)/d]]) + assert A.lower_triangular_solve(B) == sol + assert A.lower_triangular_solve(C) == sol + + +def test_upper_triangular_solve(): + raises(NonSquareMatrixError, lambda: + SparseMatrix([[1, 2]]).upper_triangular_solve(Matrix([[1, 2]]))) + raises(ShapeError, lambda: + SparseMatrix([[1, 2], [0, 4]]).upper_triangular_solve(Matrix([1]))) + raises(TypeError, lambda: + SparseMatrix([[1, 2], [3, 4]]).upper_triangular_solve(Matrix([[1, 2], [3, 4]]))) + + a, b, c, d = symbols('a:d') + u, v, w, x = symbols('u:x') + + A = SparseMatrix([[a, b], [0, d]]) + B = MutableSparseMatrix([[u, v], [w, x]]) + C = ImmutableSparseMatrix([[u, v], [w, x]]) + + sol = Matrix([[(u - b*w/d)/a, (v - b*x/d)/a], [w/d, x/d]]) + assert A.upper_triangular_solve(B) == sol + assert A.upper_triangular_solve(C) == sol + + +def test_diagonal_solve(): + a, d = symbols('a d') + u, v, w, x = symbols('u:x') + + A = SparseMatrix([[a, 0], [0, d]]) + B = MutableSparseMatrix([[u, v], [w, x]]) + C = ImmutableSparseMatrix([[u, v], [w, x]]) + + sol = Matrix([[u/a, v/a], [w/d, x/d]]) + assert A.diagonal_solve(B) == sol + assert A.diagonal_solve(C) == sol + + +def test_hermitian(): + x = Symbol('x') + a = SparseMatrix([[0, I], [-I, 0]]) + assert a.is_hermitian + a = SparseMatrix([[1, I], [-I, 1]]) + assert a.is_hermitian + a[0, 0] = 2*I + assert a.is_hermitian is False + a[0, 0] = x + assert a.is_hermitian is None + a[0, 1] = a[1, 0]*I + assert a.is_hermitian is False diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_sparsetools.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_sparsetools.py new file mode 100644 index 0000000000000000000000000000000000000000..244944c31da06460d4bc7beff8bce0f91fea9f14 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_sparsetools.py @@ -0,0 +1,132 @@ +from sympy.matrices.sparsetools import _doktocsr, _csrtodok, banded +from sympy.matrices.dense import (Matrix, eye, ones, zeros) +from sympy.matrices import SparseMatrix +from sympy.testing.pytest import raises + + +def test_doktocsr(): + a = SparseMatrix([[1, 2, 0, 0], [0, 3, 9, 0], [0, 1, 4, 0]]) + b = SparseMatrix(4, 6, [10, 20, 0, 0, 0, 0, 0, 30, 0, 40, 0, 0, 0, 0, 50, + 60, 70, 0, 0, 0, 0, 0, 0, 80]) + c = SparseMatrix(4, 4, [0, 0, 0, 0, 0, 12, 0, 2, 15, 0, 12, 0, 0, 0, 0, 4]) + d = SparseMatrix(10, 10, {(1, 1): 12, (3, 5): 7, (7, 8): 12}) + e = SparseMatrix([[0, 0, 0], [1, 0, 2], [3, 0, 0]]) + f = SparseMatrix(7, 8, {(2, 3): 5, (4, 5):12}) + assert _doktocsr(a) == [[1, 2, 3, 9, 1, 4], [0, 1, 1, 2, 1, 2], + [0, 2, 4, 6], [3, 4]] + assert _doktocsr(b) == [[10, 20, 30, 40, 50, 60, 70, 80], + [0, 1, 1, 3, 2, 3, 4, 5], [0, 2, 4, 7, 8], [4, 6]] + assert _doktocsr(c) == [[12, 2, 15, 12, 4], [1, 3, 0, 2, 3], + [0, 0, 2, 4, 5], [4, 4]] + assert _doktocsr(d) == [[12, 7, 12], [1, 5, 8], + [0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3], [10, 10]] + assert _doktocsr(e) == [[1, 2, 3], [0, 2, 0], [0, 0, 2, 3], [3, 3]] + assert _doktocsr(f) == [[5, 12], [3, 5], [0, 0, 0, 1, 1, 2, 2, 2], [7, 8]] + + +def test_csrtodok(): + h = [[5, 7, 5], [2, 1, 3], [0, 1, 1, 3], [3, 4]] + g = [[12, 5, 4], [2, 4, 2], [0, 1, 2, 3], [3, 7]] + i = [[1, 3, 12], [0, 2, 4], [0, 2, 3], [2, 5]] + j = [[11, 15, 12, 15], [2, 4, 1, 2], [0, 1, 1, 2, 3, 4], [5, 8]] + k = [[1, 3], [2, 1], [0, 1, 1, 2], [3, 3]] + m = _csrtodok(h) + assert isinstance(m, SparseMatrix) + assert m == SparseMatrix(3, 4, + {(0, 2): 5, (2, 1): 7, (2, 3): 5}) + assert _csrtodok(g) == SparseMatrix(3, 7, + {(0, 2): 12, (1, 4): 5, (2, 2): 4}) + assert _csrtodok(i) == SparseMatrix([[1, 0, 3, 0, 0], [0, 0, 0, 0, 12]]) + assert _csrtodok(j) == SparseMatrix(5, 8, + {(0, 2): 11, (2, 4): 15, (3, 1): 12, (4, 2): 15}) + assert _csrtodok(k) == SparseMatrix(3, 3, {(0, 2): 1, (2, 1): 3}) + + +def test_banded(): + raises(TypeError, lambda: banded()) + raises(TypeError, lambda: banded(1)) + raises(TypeError, lambda: banded(1, 2)) + raises(TypeError, lambda: banded(1, 2, 3)) + raises(TypeError, lambda: banded(1, 2, 3, 4)) + raises(ValueError, lambda: banded({0: (1, 2)}, rows=1)) + raises(ValueError, lambda: banded({0: (1, 2)}, cols=1)) + raises(ValueError, lambda: banded(1, {0: (1, 2)})) + raises(ValueError, lambda: banded(2, 1, {0: (1, 2)})) + raises(ValueError, lambda: banded(1, 2, {0: (1, 2)})) + + assert isinstance(banded(2, 4, {}), SparseMatrix) + assert banded(2, 4, {}) == zeros(2, 4) + assert banded({0: 0, 1: 0}) == zeros(0) + assert banded({0: Matrix([1, 2])}) == Matrix([1, 2]) + assert banded({1: [1, 2, 3, 0], -1: [4, 5, 6]}) == \ + banded({1: (1, 2, 3), -1: (4, 5, 6)}) == \ + Matrix([ + [0, 1, 0, 0], + [4, 0, 2, 0], + [0, 5, 0, 3], + [0, 0, 6, 0]]) + assert banded(3, 4, {-1: 1, 0: 2, 1: 3}) == \ + Matrix([ + [2, 3, 0, 0], + [1, 2, 3, 0], + [0, 1, 2, 3]]) + s = lambda d: (1 + d)**2 + assert banded(5, {0: s, 2: s}) == \ + Matrix([ + [1, 0, 1, 0, 0], + [0, 4, 0, 4, 0], + [0, 0, 9, 0, 9], + [0, 0, 0, 16, 0], + [0, 0, 0, 0, 25]]) + assert banded(2, {0: 1}) == \ + Matrix([ + [1, 0], + [0, 1]]) + assert banded(2, 3, {0: 1}) == \ + Matrix([ + [1, 0, 0], + [0, 1, 0]]) + vert = Matrix([1, 2, 3]) + assert banded({0: vert}, cols=3) == \ + Matrix([ + [1, 0, 0], + [2, 1, 0], + [3, 2, 1], + [0, 3, 2], + [0, 0, 3]]) + assert banded(4, {0: ones(2)}) == \ + Matrix([ + [1, 1, 0, 0], + [1, 1, 0, 0], + [0, 0, 1, 1], + [0, 0, 1, 1]]) + raises(ValueError, lambda: banded({0: 2, 1: ones(2)}, rows=5)) + assert banded({0: 2, 2: (ones(2),)*3}) == \ + Matrix([ + [2, 0, 1, 1, 0, 0, 0, 0], + [0, 2, 1, 1, 0, 0, 0, 0], + [0, 0, 2, 0, 1, 1, 0, 0], + [0, 0, 0, 2, 1, 1, 0, 0], + [0, 0, 0, 0, 2, 0, 1, 1], + [0, 0, 0, 0, 0, 2, 1, 1]]) + raises(ValueError, lambda: banded({0: (2,)*5, 1: (ones(2),)*3})) + u2 = Matrix([[1, 1], [0, 1]]) + assert banded({0: (2,)*5, 1: (u2,)*3}) == \ + Matrix([ + [2, 1, 1, 0, 0, 0, 0], + [0, 2, 1, 0, 0, 0, 0], + [0, 0, 2, 1, 1, 0, 0], + [0, 0, 0, 2, 1, 0, 0], + [0, 0, 0, 0, 2, 1, 1], + [0, 0, 0, 0, 0, 0, 1]]) + assert banded({0:(0, ones(2)), 2: 2}) == \ + Matrix([ + [0, 0, 2], + [0, 1, 1], + [0, 1, 1]]) + raises(ValueError, lambda: banded({0: (0, ones(2)), 1: 2})) + assert banded({0: 1}, cols=3) == banded({0: 1}, rows=3) == eye(3) + assert banded({1: 1}, rows=3) == Matrix([ + [0, 1, 0], + [0, 0, 1], + [0, 0, 0]]) diff --git a/phivenv/Lib/site-packages/sympy/matrices/tests/test_subspaces.py b/phivenv/Lib/site-packages/sympy/matrices/tests/test_subspaces.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd853e321eb06f754c17e7bd0c11deb870506f5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/tests/test_subspaces.py @@ -0,0 +1,109 @@ +from sympy.matrices import Matrix +from sympy.core.numbers import Rational +from sympy.core.symbol import symbols +from sympy.solvers import solve + + +def test_columnspace_one(): + m = Matrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + basis = m.columnspace() + assert basis[0] == Matrix([1, -2, 0, 3]) + assert basis[1] == Matrix([2, -5, -3, 6]) + assert basis[2] == Matrix([2, -1, 4, -7]) + + assert len(basis) == 3 + assert Matrix.hstack(m, *basis).columnspace() == basis + + +def test_rowspace(): + m = Matrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + basis = m.rowspace() + assert basis[0] == Matrix([[1, 2, 0, 2, 5]]) + assert basis[1] == Matrix([[0, -1, 1, 3, 2]]) + assert basis[2] == Matrix([[0, 0, 0, 5, 5]]) + + assert len(basis) == 3 + + +def test_nullspace_one(): + m = Matrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + basis = m.nullspace() + assert basis[0] == Matrix([-2, 1, 1, 0, 0]) + assert basis[1] == Matrix([-1, -1, 0, -1, 1]) + # make sure the null space is really gets zeroed + assert all(e.is_zero for e in m*basis[0]) + assert all(e.is_zero for e in m*basis[1]) + +def test_nullspace_second(): + # first test reduced row-ech form + R = Rational + + M = Matrix([[5, 7, 2, 1], + [1, 6, 2, -1]]) + out, tmp = M.rref() + assert out == Matrix([[1, 0, -R(2)/23, R(13)/23], + [0, 1, R(8)/23, R(-6)/23]]) + + M = Matrix([[-5, -1, 4, -3, -1], + [ 1, -1, -1, 1, 0], + [-1, 0, 0, 0, 0], + [ 4, 1, -4, 3, 1], + [-2, 0, 2, -2, -1]]) + assert M*M.nullspace()[0] == Matrix(5, 1, [0]*5) + + M = Matrix([[ 1, 3, 0, 2, 6, 3, 1], + [-2, -6, 0, -2, -8, 3, 1], + [ 3, 9, 0, 0, 6, 6, 2], + [-1, -3, 0, 1, 0, 9, 3]]) + out, tmp = M.rref() + assert out == Matrix([[1, 3, 0, 0, 2, 0, 0], + [0, 0, 0, 1, 2, 0, 0], + [0, 0, 0, 0, 0, 1, R(1)/3], + [0, 0, 0, 0, 0, 0, 0]]) + + # now check the vectors + basis = M.nullspace() + assert basis[0] == Matrix([-3, 1, 0, 0, 0, 0, 0]) + assert basis[1] == Matrix([0, 0, 1, 0, 0, 0, 0]) + assert basis[2] == Matrix([-2, 0, 0, -2, 1, 0, 0]) + assert basis[3] == Matrix([0, 0, 0, 0, 0, R(-1)/3, 1]) + + # issue 4797; just see that we can do it when rows > cols + M = Matrix([[1, 2], [2, 4], [3, 6]]) + assert M.nullspace() + + +def test_columnspace_second(): + M = Matrix([[ 1, 2, 0, 2, 5], + [-2, -5, 1, -1, -8], + [ 0, -3, 3, 4, 1], + [ 3, 6, 0, -7, 2]]) + + # now check the vectors + basis = M.columnspace() + assert basis[0] == Matrix([1, -2, 0, 3]) + assert basis[1] == Matrix([2, -5, -3, 6]) + assert basis[2] == Matrix([2, -1, 4, -7]) + + #check by columnspace definition + a, b, c, d, e = symbols('a b c d e') + X = Matrix([a, b, c, d, e]) + for i in range(len(basis)): + eq=M*X-basis[i] + assert len(solve(eq, X)) != 0 + + #check if rank-nullity theorem holds + assert M.rank() == len(basis) + assert len(M.nullspace()) + len(M.columnspace()) == M.cols diff --git a/phivenv/Lib/site-packages/sympy/matrices/utilities.py b/phivenv/Lib/site-packages/sympy/matrices/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..b8a680b47e63615e210e561639a192ba47c642d3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/matrices/utilities.py @@ -0,0 +1,72 @@ +from contextlib import contextmanager +from threading import local + +from sympy.core.function import expand_mul + + +class DotProdSimpState(local): + def __init__(self): + self.state = None + +_dotprodsimp_state = DotProdSimpState() + +@contextmanager +def dotprodsimp(x): + old = _dotprodsimp_state.state + + try: + _dotprodsimp_state.state = x + yield + finally: + _dotprodsimp_state.state = old + + +def _dotprodsimp(expr, withsimp=False): + """Wrapper for simplify.dotprodsimp to avoid circular imports.""" + from sympy.simplify.simplify import dotprodsimp as dps + return dps(expr, withsimp=withsimp) + + +def _get_intermediate_simp(deffunc=lambda x: x, offfunc=lambda x: x, + onfunc=_dotprodsimp, dotprodsimp=None): + """Support function for controlling intermediate simplification. Returns a + simplification function according to the global setting of dotprodsimp + operation. + + ``deffunc`` - Function to be used by default. + ``offfunc`` - Function to be used if dotprodsimp has been turned off. + ``onfunc`` - Function to be used if dotprodsimp has been turned on. + ``dotprodsimp`` - True, False or None. Will be overridden by global + _dotprodsimp_state.state if that is not None. + """ + + if dotprodsimp is False or _dotprodsimp_state.state is False: + return offfunc + if dotprodsimp is True or _dotprodsimp_state.state is True: + return onfunc + + return deffunc # None, None + + +def _get_intermediate_simp_bool(default=False, dotprodsimp=None): + """Same as ``_get_intermediate_simp`` but returns bools instead of functions + by default.""" + + return _get_intermediate_simp(default, False, True, dotprodsimp) + + +def _iszero(x): + """Returns True if x is zero.""" + return getattr(x, 'is_zero', None) + + +def _is_zero_after_expand_mul(x): + """Tests by expand_mul only, suitable for polynomials and rational + functions.""" + return expand_mul(x) == 0 + + +def _simplify(expr): + """ Wrapper to avoid circular imports. """ + from sympy.simplify.simplify import simplify + return simplify(expr) diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/__init__.py b/phivenv/Lib/site-packages/sympy/multipledispatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5447651645e3e2e92df3002822e87a773ade0df8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/multipledispatch/__init__.py @@ -0,0 +1,11 @@ +from .core import dispatch +from .dispatcher import (Dispatcher, halt_ordering, restart_ordering, + MDNotImplementedError) + +__version__ = '0.4.9' + +__all__ = [ + 'dispatch', + + 'Dispatcher', 'halt_ordering', 'restart_ordering', 'MDNotImplementedError', +] diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e755c9b9ddac2e394603a8eeba601b87f546baf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/conflict.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/conflict.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2989fe6d682a622d0c0779bc36a6273521b0cea Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/conflict.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/core.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a591ef4606243d348a4a691f76e1de4c72f31ccd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/dispatcher.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/dispatcher.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2905238168e3a16c8729afcaaeecc985432213df Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/dispatcher.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74f2a3672a2ec3b4cddeb9a3766237d2b432f654 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/multipledispatch/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/conflict.py b/phivenv/Lib/site-packages/sympy/multipledispatch/conflict.py new file mode 100644 index 0000000000000000000000000000000000000000..98c6742c9c03860233ef0004b241ea3944ac6d4d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/multipledispatch/conflict.py @@ -0,0 +1,68 @@ +from .utils import _toposort, groupby + +class AmbiguityWarning(Warning): + pass + + +def supercedes(a, b): + """ A is consistent and strictly more specific than B """ + return len(a) == len(b) and all(map(issubclass, a, b)) + + +def consistent(a, b): + """ It is possible for an argument list to satisfy both A and B """ + return (len(a) == len(b) and + all(issubclass(aa, bb) or issubclass(bb, aa) + for aa, bb in zip(a, b))) + + +def ambiguous(a, b): + """ A is consistent with B but neither is strictly more specific """ + return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) + + +def ambiguities(signatures): + """ All signature pairs such that A is ambiguous with B """ + signatures = list(map(tuple, signatures)) + return {(a, b) for a in signatures for b in signatures + if hash(a) < hash(b) + and ambiguous(a, b) + and not any(supercedes(c, a) and supercedes(c, b) + for c in signatures)} + + +def super_signature(signatures): + """ A signature that would break ambiguities """ + n = len(signatures[0]) + assert all(len(s) == n for s in signatures) + + return [max([type.mro(sig[i]) for sig in signatures], key=len)[0] + for i in range(n)] + + +def edge(a, b, tie_breaker=hash): + """ A should be checked before B + + Tie broken by tie_breaker, defaults to ``hash`` + """ + if supercedes(a, b): + if supercedes(b, a): + return tie_breaker(a) > tie_breaker(b) + else: + return True + return False + + +def ordering(signatures): + """ A sane ordering of signatures to check, first to last + + Topoological sort of edges as given by ``edge`` and ``supercedes`` + """ + signatures = list(map(tuple, signatures)) + edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] + edges = groupby(lambda x: x[0], edges) + for s in signatures: + if s not in edges: + edges[s] = [] + edges = {k: [b for a, b in v] for k, v in edges.items()} + return _toposort(edges) diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/core.py b/phivenv/Lib/site-packages/sympy/multipledispatch/core.py new file mode 100644 index 0000000000000000000000000000000000000000..2856ff728c4eb97c5a59fffabddb4bf3c8b4baf2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/multipledispatch/core.py @@ -0,0 +1,83 @@ +from __future__ import annotations +from typing import Any + +import inspect + +from .dispatcher import Dispatcher, MethodDispatcher, ambiguity_warn + +# XXX: This parameter to dispatch isn't documented and isn't used anywhere in +# sympy. Maybe it should just be removed. +global_namespace: dict[str, Any] = {} + + +def dispatch(*types, namespace=global_namespace, on_ambiguity=ambiguity_warn): + """ Dispatch function on the types of the inputs + + Supports dispatch on all non-keyword arguments. + + Collects implementations based on the function name. Ignores namespaces. + + If ambiguous type signatures occur a warning is raised when the function is + defined suggesting the additional method to break the ambiguity. + + Examples + -------- + + >>> from sympy.multipledispatch import dispatch + >>> @dispatch(int) + ... def f(x): + ... return x + 1 + + >>> @dispatch(float) + ... def f(x): # noqa: F811 + ... return x - 1 + + >>> f(3) + 4 + >>> f(3.0) + 2.0 + + Specify an isolated namespace with the namespace keyword argument + + >>> my_namespace = dict() + >>> @dispatch(int, namespace=my_namespace) + ... def foo(x): + ... return x + 1 + + Dispatch on instance methods within classes + + >>> class MyClass(object): + ... @dispatch(list) + ... def __init__(self, data): + ... self.data = data + ... @dispatch(int) + ... def __init__(self, datum): # noqa: F811 + ... self.data = [datum] + """ + types = tuple(types) + + def _(func): + name = func.__name__ + + if ismethod(func): + dispatcher = inspect.currentframe().f_back.f_locals.get( + name, + MethodDispatcher(name)) + else: + if name not in namespace: + namespace[name] = Dispatcher(name) + dispatcher = namespace[name] + + dispatcher.add(types, func, on_ambiguity=on_ambiguity) + return dispatcher + return _ + + +def ismethod(func): + """ Is func a method? + + Note that this has to work as the method is defined but before the class is + defined. At this stage methods look like functions. + """ + signature = inspect.signature(func) + return signature.parameters.get('self', None) is not None diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/dispatcher.py b/phivenv/Lib/site-packages/sympy/multipledispatch/dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..89471d678e1c330138a91ec6a41a324d29a037d7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/multipledispatch/dispatcher.py @@ -0,0 +1,413 @@ +from __future__ import annotations + +from warnings import warn +import inspect +from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning +from .utils import expand_tuples +import itertools as itl + + +class MDNotImplementedError(NotImplementedError): + """ A NotImplementedError for multiple dispatch """ + + +### Functions for on_ambiguity + +def ambiguity_warn(dispatcher, ambiguities): + """ Raise warning when ambiguity is detected + + Parameters + ---------- + dispatcher : Dispatcher + The dispatcher on which the ambiguity was detected + ambiguities : set + Set of type signature pairs that are ambiguous within this dispatcher + + See Also: + Dispatcher.add + warning_text + """ + warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) + + +class RaiseNotImplementedError: + """Raise ``NotImplementedError`` when called.""" + + def __init__(self, dispatcher): + self.dispatcher = dispatcher + + def __call__(self, *args, **kwargs): + types = tuple(type(a) for a in args) + raise NotImplementedError( + "Ambiguous signature for %s: <%s>" % ( + self.dispatcher.name, str_signature(types) + )) + +def ambiguity_register_error_ignore_dup(dispatcher, ambiguities): + """ + If super signature for ambiguous types is duplicate types, ignore it. + Else, register instance of ``RaiseNotImplementedError`` for ambiguous types. + + Parameters + ---------- + dispatcher : Dispatcher + The dispatcher on which the ambiguity was detected + ambiguities : set + Set of type signature pairs that are ambiguous within this dispatcher + + See Also: + Dispatcher.add + ambiguity_warn + """ + for amb in ambiguities: + signature = tuple(super_signature(amb)) + if len(set(signature)) == 1: + continue + dispatcher.add( + signature, RaiseNotImplementedError(dispatcher), + on_ambiguity=ambiguity_register_error_ignore_dup + ) + +### + + +_unresolved_dispatchers: set[Dispatcher] = set() +_resolve = [True] + + +def halt_ordering(): + _resolve[0] = False + + +def restart_ordering(on_ambiguity=ambiguity_warn): + _resolve[0] = True + while _unresolved_dispatchers: + dispatcher = _unresolved_dispatchers.pop() + dispatcher.reorder(on_ambiguity=on_ambiguity) + + +class Dispatcher: + """ Dispatch methods based on type signature + + Use ``dispatch`` to add implementations + + Examples + -------- + + >>> from sympy.multipledispatch import dispatch + >>> @dispatch(int) + ... def f(x): + ... return x + 1 + + >>> @dispatch(float) + ... def f(x): # noqa: F811 + ... return x - 1 + + >>> f(3) + 4 + >>> f(3.0) + 2.0 + """ + __slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc' + + def __init__(self, name, doc=None): + self.name = self.__name__ = name + self.funcs = {} + self._cache = {} + self.ordering = [] + self.doc = doc + + def register(self, *types, **kwargs): + """ Register dispatcher with new implementation + + >>> from sympy.multipledispatch.dispatcher import Dispatcher + >>> f = Dispatcher('f') + >>> @f.register(int) + ... def inc(x): + ... return x + 1 + + >>> @f.register(float) + ... def dec(x): + ... return x - 1 + + >>> @f.register(list) + ... @f.register(tuple) + ... def reverse(x): + ... return x[::-1] + + >>> f(1) + 2 + + >>> f(1.0) + 0.0 + + >>> f([1, 2, 3]) + [3, 2, 1] + """ + def _(func): + self.add(types, func, **kwargs) + return func + return _ + + @classmethod + def get_func_params(cls, func): + if hasattr(inspect, "signature"): + sig = inspect.signature(func) + return sig.parameters.values() + + @classmethod + def get_func_annotations(cls, func): + """ Get annotations of function positional parameters + """ + params = cls.get_func_params(func) + if params: + Parameter = inspect.Parameter + + params = (param for param in params + if param.kind in + (Parameter.POSITIONAL_ONLY, + Parameter.POSITIONAL_OR_KEYWORD)) + + annotations = tuple( + param.annotation + for param in params) + + if not any(ann is Parameter.empty for ann in annotations): + return annotations + + def add(self, signature, func, on_ambiguity=ambiguity_warn): + """ Add new types/method pair to dispatcher + + >>> from sympy.multipledispatch import Dispatcher + >>> D = Dispatcher('add') + >>> D.add((int, int), lambda x, y: x + y) + >>> D.add((float, float), lambda x, y: x + y) + + >>> D(1, 2) + 3 + >>> D(1, 2.0) + Traceback (most recent call last): + ... + NotImplementedError: Could not find signature for add: + + When ``add`` detects a warning it calls the ``on_ambiguity`` callback + with a dispatcher/itself, and a set of ambiguous type signature pairs + as inputs. See ``ambiguity_warn`` for an example. + """ + # Handle annotations + if not signature: + annotations = self.get_func_annotations(func) + if annotations: + signature = annotations + + # Handle union types + if any(isinstance(typ, tuple) for typ in signature): + for typs in expand_tuples(signature): + self.add(typs, func, on_ambiguity) + return + + for typ in signature: + if not isinstance(typ, type): + str_sig = ', '.join(c.__name__ if isinstance(c, type) + else str(c) for c in signature) + raise TypeError("Tried to dispatch on non-type: %s\n" + "In signature: <%s>\n" + "In function: %s" % + (typ, str_sig, self.name)) + + self.funcs[signature] = func + self.reorder(on_ambiguity=on_ambiguity) + self._cache.clear() + + def reorder(self, on_ambiguity=ambiguity_warn): + if _resolve[0]: + self.ordering = ordering(self.funcs) + amb = ambiguities(self.funcs) + if amb: + on_ambiguity(self, amb) + else: + _unresolved_dispatchers.add(self) + + def __call__(self, *args, **kwargs): + types = tuple([type(arg) for arg in args]) + try: + func = self._cache[types] + except KeyError: + func = self.dispatch(*types) + if not func: + raise NotImplementedError( + 'Could not find signature for %s: <%s>' % + (self.name, str_signature(types))) + self._cache[types] = func + try: + return func(*args, **kwargs) + + except MDNotImplementedError: + funcs = self.dispatch_iter(*types) + next(funcs) # burn first + for func in funcs: + try: + return func(*args, **kwargs) + except MDNotImplementedError: + pass + raise NotImplementedError("Matching functions for " + "%s: <%s> found, but none completed successfully" + % (self.name, str_signature(types))) + + def __str__(self): + return "" % self.name + __repr__ = __str__ + + def dispatch(self, *types): + """ Deterimine appropriate implementation for this type signature + + This method is internal. Users should call this object as a function. + Implementation resolution occurs within the ``__call__`` method. + + >>> from sympy.multipledispatch import dispatch + >>> @dispatch(int) + ... def inc(x): + ... return x + 1 + + >>> implementation = inc.dispatch(int) + >>> implementation(3) + 4 + + >>> print(inc.dispatch(float)) + None + + See Also: + ``sympy.multipledispatch.conflict`` - module to determine resolution order + """ + + if types in self.funcs: + return self.funcs[types] + + try: + return next(self.dispatch_iter(*types)) + except StopIteration: + return None + + def dispatch_iter(self, *types): + n = len(types) + for signature in self.ordering: + if len(signature) == n and all(map(issubclass, types, signature)): + result = self.funcs[signature] + yield result + + def resolve(self, types): + """ Deterimine appropriate implementation for this type signature + + .. deprecated:: 0.4.4 + Use ``dispatch(*types)`` instead + """ + warn("resolve() is deprecated, use dispatch(*types)", + DeprecationWarning) + + return self.dispatch(*types) + + def __getstate__(self): + return {'name': self.name, + 'funcs': self.funcs} + + def __setstate__(self, d): + self.name = d['name'] + self.funcs = d['funcs'] + self.ordering = ordering(self.funcs) + self._cache = {} + + @property + def __doc__(self): + docs = ["Multiply dispatched method: %s" % self.name] + + if self.doc: + docs.append(self.doc) + + other = [] + for sig in self.ordering[::-1]: + func = self.funcs[sig] + if func.__doc__: + s = 'Inputs: <%s>\n' % str_signature(sig) + s += '-' * len(s) + '\n' + s += func.__doc__.strip() + docs.append(s) + else: + other.append(str_signature(sig)) + + if other: + docs.append('Other signatures:\n ' + '\n '.join(other)) + + return '\n\n'.join(docs) + + def _help(self, *args): + return self.dispatch(*map(type, args)).__doc__ + + def help(self, *args, **kwargs): + """ Print docstring for the function corresponding to inputs """ + print(self._help(*args)) + + def _source(self, *args): + func = self.dispatch(*map(type, args)) + if not func: + raise TypeError("No function found") + return source(func) + + def source(self, *args, **kwargs): + """ Print source code for the function corresponding to inputs """ + print(self._source(*args)) + + +def source(func): + s = 'File: %s\n\n' % inspect.getsourcefile(func) + s = s + inspect.getsource(func) + return s + + +class MethodDispatcher(Dispatcher): + """ Dispatch methods based on type signature + + See Also: + Dispatcher + """ + + @classmethod + def get_func_params(cls, func): + if hasattr(inspect, "signature"): + sig = inspect.signature(func) + return itl.islice(sig.parameters.values(), 1, None) + + def __get__(self, instance, owner): + self.obj = instance + self.cls = owner + return self + + def __call__(self, *args, **kwargs): + types = tuple([type(arg) for arg in args]) + func = self.dispatch(*types) + if not func: + raise NotImplementedError('Could not find signature for %s: <%s>' % + (self.name, str_signature(types))) + return func(self.obj, *args, **kwargs) + + +def str_signature(sig): + """ String representation of type signature + + >>> from sympy.multipledispatch.dispatcher import str_signature + >>> str_signature((int, float)) + 'int, float' + """ + return ', '.join(cls.__name__ for cls in sig) + + +def warning_text(name, amb): + """ The text for ambiguity warnings """ + text = "\nAmbiguities exist in dispatched function %s\n\n" % (name) + text += "The following signatures may result in ambiguous behavior:\n" + for pair in amb: + text += "\t" + \ + ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n" + text += "\n\nConsider making the following additions:\n\n" + text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s)) + + ')\ndef %s(...)' % name for s in amb]) + return text diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__init__.py b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd77078b50b474a0ec4e5ae24257d363948339f1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__pycache__/test_conflict.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__pycache__/test_conflict.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06fc1c3182f83e79c24cf20f411fa65e239e9a70 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__pycache__/test_conflict.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__pycache__/test_core.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__pycache__/test_core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38fe08fc892df90c86eadb4907d3283c36b691ad Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__pycache__/test_core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__pycache__/test_dispatcher.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__pycache__/test_dispatcher.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56d2d2b93ce53c7f1fe4722cd32dada2db1ed9be Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/__pycache__/test_dispatcher.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/tests/test_conflict.py b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/test_conflict.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2292c460585ae2a65a01795b38499e67706ff0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/test_conflict.py @@ -0,0 +1,62 @@ +from sympy.multipledispatch.conflict import (supercedes, ordering, ambiguities, + ambiguous, super_signature, consistent) + + +class A: pass +class B(A): pass +class C: pass + + +def test_supercedes(): + assert supercedes([B], [A]) + assert supercedes([B, A], [A, A]) + assert not supercedes([B, A], [A, B]) + assert not supercedes([A], [B]) + + +def test_consistent(): + assert consistent([A], [A]) + assert consistent([B], [B]) + assert not consistent([A], [C]) + assert consistent([A, B], [A, B]) + assert consistent([B, A], [A, B]) + assert not consistent([B, A], [B]) + assert not consistent([B, A], [B, C]) + + +def test_super_signature(): + assert super_signature([[A]]) == [A] + assert super_signature([[A], [B]]) == [B] + assert super_signature([[A, B], [B, A]]) == [B, B] + assert super_signature([[A, A, B], [A, B, A], [B, A, A]]) == [B, B, B] + + +def test_ambiguous(): + assert not ambiguous([A], [A]) + assert not ambiguous([A], [B]) + assert not ambiguous([B], [B]) + assert not ambiguous([A, B], [B, B]) + assert ambiguous([A, B], [B, A]) + + +def test_ambiguities(): + signatures = [[A], [B], [A, B], [B, A], [A, C]] + expected = {((A, B), (B, A))} + result = ambiguities(signatures) + assert set(map(frozenset, expected)) == set(map(frozenset, result)) + + signatures = [[A], [B], [A, B], [B, A], [A, C], [B, B]] + expected = set() + result = ambiguities(signatures) + assert set(map(frozenset, expected)) == set(map(frozenset, result)) + + +def test_ordering(): + signatures = [[A, A], [A, B], [B, A], [B, B], [A, C]] + ord = ordering(signatures) + assert ord[0] == (B, B) or ord[0] == (A, C) + assert ord[-1] == (A, A) or ord[-1] == (A, C) + + +def test_type_mro(): + assert super_signature([[object], [type]]) == [type] diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/tests/test_core.py b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/test_core.py new file mode 100644 index 0000000000000000000000000000000000000000..016270fecc8cda644fc71b5c310b1430b50361f6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/test_core.py @@ -0,0 +1,213 @@ +from __future__ import annotations +from typing import Any + +from sympy.multipledispatch import dispatch +from sympy.multipledispatch.conflict import AmbiguityWarning +from sympy.testing.pytest import raises, warns +from functools import partial + +test_namespace: dict[str, Any] = {} + +orig_dispatch = dispatch +dispatch = partial(dispatch, namespace=test_namespace) + + +def test_singledispatch(): + @dispatch(int) + def f(x): # noqa:F811 + return x + 1 + + @dispatch(int) + def g(x): # noqa:F811 + return x + 2 + + @dispatch(float) # noqa:F811 + def f(x): # noqa:F811 + return x - 1 + + assert f(1) == 2 + assert g(1) == 3 + assert f(1.0) == 0 + + assert raises(NotImplementedError, lambda: f('hello')) + + +def test_multipledispatch(): + @dispatch(int, int) + def f(x, y): # noqa:F811 + return x + y + + @dispatch(float, float) # noqa:F811 + def f(x, y): # noqa:F811 + return x - y + + assert f(1, 2) == 3 + assert f(1.0, 2.0) == -1.0 + + +class A: pass +class B: pass +class C(A): pass +class D(C): pass +class E(C): pass + + +def test_inheritance(): + @dispatch(A) + def f(x): # noqa:F811 + return 'a' + + @dispatch(B) # noqa:F811 + def f(x): # noqa:F811 + return 'b' + + assert f(A()) == 'a' + assert f(B()) == 'b' + assert f(C()) == 'a' + + +def test_inheritance_and_multiple_dispatch(): + @dispatch(A, A) + def f(x, y): # noqa:F811 + return type(x), type(y) + + @dispatch(A, B) # noqa:F811 + def f(x, y): # noqa:F811 + return 0 + + assert f(A(), A()) == (A, A) + assert f(A(), C()) == (A, C) + assert f(A(), B()) == 0 + assert f(C(), B()) == 0 + assert raises(NotImplementedError, lambda: f(B(), B())) + + +def test_competing_solutions(): + @dispatch(A) + def h(x): # noqa:F811 + return 1 + + @dispatch(C) # noqa:F811 + def h(x): # noqa:F811 + return 2 + + assert h(D()) == 2 + + +def test_competing_multiple(): + @dispatch(A, B) + def h(x, y): # noqa:F811 + return 1 + + @dispatch(C, B) # noqa:F811 + def h(x, y): # noqa:F811 + return 2 + + assert h(D(), B()) == 2 + + +def test_competing_ambiguous(): + test_namespace = {} + dispatch = partial(orig_dispatch, namespace=test_namespace) + + @dispatch(A, C) + def f(x, y): # noqa:F811 + return 2 + + with warns(AmbiguityWarning, test_stacklevel=False): + @dispatch(C, A) # noqa:F811 + def f(x, y): # noqa:F811 + return 2 + + assert f(A(), C()) == f(C(), A()) == 2 + # assert raises(Warning, lambda : f(C(), C())) + + +def test_caching_correct_behavior(): + @dispatch(A) + def f(x): # noqa:F811 + return 1 + + assert f(C()) == 1 + + @dispatch(C) + def f(x): # noqa:F811 + return 2 + + assert f(C()) == 2 + + +def test_union_types(): + @dispatch((A, C)) + def f(x): # noqa:F811 + return 1 + + assert f(A()) == 1 + assert f(C()) == 1 + + +def test_namespaces(): + ns1 = {} + ns2 = {} + + def foo(x): + return 1 + foo1 = orig_dispatch(int, namespace=ns1)(foo) + + def foo(x): + return 2 + foo2 = orig_dispatch(int, namespace=ns2)(foo) + + assert foo1(0) == 1 + assert foo2(0) == 2 + + +""" +Fails +def test_dispatch_on_dispatch(): + @dispatch(A) + @dispatch(C) + def q(x): # noqa:F811 + return 1 + + assert q(A()) == 1 + assert q(C()) == 1 +""" + + +def test_methods(): + class Foo: + @dispatch(float) + def f(self, x): # noqa:F811 + return x - 1 + + @dispatch(int) # noqa:F811 + def f(self, x): # noqa:F811 + return x + 1 + + @dispatch(int) + def g(self, x): # noqa:F811 + return x + 3 + + + foo = Foo() + assert foo.f(1) == 2 + assert foo.f(1.0) == 0.0 + assert foo.g(1) == 4 + + +def test_methods_multiple_dispatch(): + class Foo: + @dispatch(A, A) + def f(x, y): # noqa:F811 + return 1 + + @dispatch(A, C) # noqa:F811 + def f(x, y): # noqa:F811 + return 2 + + + foo = Foo() + assert foo.f(A(), A()) == 1 + assert foo.f(A(), C()) == 2 + assert foo.f(C(), C()) == 2 diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/tests/test_dispatcher.py b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/test_dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..e31ca8a5486b87eb43fc5e6f887caf50d6bfbe20 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/multipledispatch/tests/test_dispatcher.py @@ -0,0 +1,284 @@ +from sympy.multipledispatch.dispatcher import (Dispatcher, MDNotImplementedError, + MethodDispatcher, halt_ordering, + restart_ordering, + ambiguity_register_error_ignore_dup) +from sympy.testing.pytest import raises, warns + + +def identity(x): + return x + + +def inc(x): + return x + 1 + + +def dec(x): + return x - 1 + + +def test_dispatcher(): + f = Dispatcher('f') + f.add((int,), inc) + f.add((float,), dec) + + with warns(DeprecationWarning, test_stacklevel=False): + assert f.resolve((int,)) == inc + assert f.dispatch(int) is inc + + assert f(1) == 2 + assert f(1.0) == 0.0 + + +def test_union_types(): + f = Dispatcher('f') + f.register((int, float))(inc) + + assert f(1) == 2 + assert f(1.0) == 2.0 + + +def test_dispatcher_as_decorator(): + f = Dispatcher('f') + + @f.register(int) + def inc(x): # noqa:F811 + return x + 1 + + @f.register(float) # noqa:F811 + def inc(x): # noqa:F811 + return x - 1 + + assert f(1) == 2 + assert f(1.0) == 0.0 + + +def test_register_instance_method(): + + class Test: + __init__ = MethodDispatcher('f') + + @__init__.register(list) + def _init_list(self, data): + self.data = data + + @__init__.register(object) + def _init_obj(self, datum): + self.data = [datum] + + a = Test(3) + b = Test([3]) + assert a.data == b.data + + +def test_on_ambiguity(): + f = Dispatcher('f') + + def identity(x): return x + + ambiguities = [False] + + def on_ambiguity(dispatcher, amb): + ambiguities[0] = True + + f.add((object, object), identity, on_ambiguity=on_ambiguity) + assert not ambiguities[0] + f.add((object, float), identity, on_ambiguity=on_ambiguity) + assert not ambiguities[0] + f.add((float, object), identity, on_ambiguity=on_ambiguity) + assert ambiguities[0] + + +def test_raise_error_on_non_class(): + f = Dispatcher('f') + assert raises(TypeError, lambda: f.add((1,), inc)) + + +def test_docstring(): + + def one(x, y): + """ Docstring number one """ + return x + y + + def two(x, y): + """ Docstring number two """ + return x + y + + def three(x, y): + return x + y + + master_doc = 'Doc of the multimethod itself' + + f = Dispatcher('f', doc=master_doc) + f.add((object, object), one) + f.add((int, int), two) + f.add((float, float), three) + + assert one.__doc__.strip() in f.__doc__ + assert two.__doc__.strip() in f.__doc__ + assert f.__doc__.find(one.__doc__.strip()) < \ + f.__doc__.find(two.__doc__.strip()) + assert 'object, object' in f.__doc__ + assert master_doc in f.__doc__ + + +def test_help(): + def one(x, y): + """ Docstring number one """ + return x + y + + def two(x, y): + """ Docstring number two """ + return x + y + + def three(x, y): + """ Docstring number three """ + return x + y + + master_doc = 'Doc of the multimethod itself' + + f = Dispatcher('f', doc=master_doc) + f.add((object, object), one) + f.add((int, int), two) + f.add((float, float), three) + + assert f._help(1, 1) == two.__doc__ + assert f._help(1.0, 2.0) == three.__doc__ + + +def test_source(): + def one(x, y): + """ Docstring number one """ + return x + y + + def two(x, y): + """ Docstring number two """ + return x - y + + master_doc = 'Doc of the multimethod itself' + + f = Dispatcher('f', doc=master_doc) + f.add((int, int), one) + f.add((float, float), two) + + assert 'x + y' in f._source(1, 1) + assert 'x - y' in f._source(1.0, 1.0) + + +def test_source_raises_on_missing_function(): + f = Dispatcher('f') + + assert raises(TypeError, lambda: f.source(1)) + + +def test_halt_method_resolution(): + g = [0] + + def on_ambiguity(a, b): + g[0] += 1 + + f = Dispatcher('f') + + halt_ordering() + + def func(*args): + pass + + f.add((int, object), func) + f.add((object, int), func) + + assert g == [0] + + restart_ordering(on_ambiguity=on_ambiguity) + + assert g == [1] + + assert set(f.ordering) == {(int, object), (object, int)} + + +def test_no_implementations(): + f = Dispatcher('f') + assert raises(NotImplementedError, lambda: f('hello')) + + +def test_register_stacking(): + f = Dispatcher('f') + + @f.register(list) + @f.register(tuple) + def rev(x): + return x[::-1] + + assert f((1, 2, 3)) == (3, 2, 1) + assert f([1, 2, 3]) == [3, 2, 1] + + assert raises(NotImplementedError, lambda: f('hello')) + assert rev('hello') == 'olleh' + + +def test_dispatch_method(): + f = Dispatcher('f') + + @f.register(list) + def rev(x): + return x[::-1] + + @f.register(int, int) + def add(x, y): + return x + y + + class MyList(list): + pass + + assert f.dispatch(list) is rev + assert f.dispatch(MyList) is rev + assert f.dispatch(int, int) is add + + +def test_not_implemented(): + f = Dispatcher('f') + + @f.register(object) + def _(x): + return 'default' + + @f.register(int) + def _(x): + if x % 2 == 0: + return 'even' + else: + raise MDNotImplementedError() + + assert f('hello') == 'default' # default behavior + assert f(2) == 'even' # specialized behavior + assert f(3) == 'default' # fall bac to default behavior + assert raises(NotImplementedError, lambda: f(1, 2)) + + +def test_not_implemented_error(): + f = Dispatcher('f') + + @f.register(float) + def _(a): + raise MDNotImplementedError() + + assert raises(NotImplementedError, lambda: f(1.0)) + +def test_ambiguity_register_error_ignore_dup(): + f = Dispatcher('f') + + class A: + pass + class B(A): + pass + class C(A): + pass + + # suppress warning for registering ambiguous signal + f.add((A, B), lambda x,y: None, ambiguity_register_error_ignore_dup) + f.add((B, A), lambda x,y: None, ambiguity_register_error_ignore_dup) + f.add((A, C), lambda x,y: None, ambiguity_register_error_ignore_dup) + f.add((C, A), lambda x,y: None, ambiguity_register_error_ignore_dup) + + # raises error if ambiguous signal is passed + assert raises(NotImplementedError, lambda: f(B(), C())) diff --git a/phivenv/Lib/site-packages/sympy/multipledispatch/utils.py b/phivenv/Lib/site-packages/sympy/multipledispatch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..11f563772385124c2fc0d285f7aa6e0747b8b412 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/multipledispatch/utils.py @@ -0,0 +1,105 @@ +from collections import OrderedDict + + +def expand_tuples(L): + """ + >>> from sympy.multipledispatch.utils import expand_tuples + >>> expand_tuples([1, (2, 3)]) + [(1, 2), (1, 3)] + + >>> expand_tuples([1, 2]) + [(1, 2)] + """ + if not L: + return [()] + elif not isinstance(L[0], tuple): + rest = expand_tuples(L[1:]) + return [(L[0],) + t for t in rest] + else: + rest = expand_tuples(L[1:]) + return [(item,) + t for t in rest for item in L[0]] + + +# Taken from theano/theano/gof/sched.py +# Avoids licensing issues because this was written by Matthew Rocklin +def _toposort(edges): + """ Topological sort algorithm by Kahn [1] - O(nodes + vertices) + + inputs: + edges - a dict of the form {a: {b, c}} where b and c depend on a + outputs: + L - an ordered list of nodes that satisfy the dependencies of edges + + >>> from sympy.multipledispatch.utils import _toposort + >>> _toposort({1: (2, 3), 2: (3, )}) + [1, 2, 3] + + Closely follows the wikipedia page [2] + + [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", + Communications of the ACM + [2] https://en.wikipedia.org/wiki/Toposort#Algorithms + """ + incoming_edges = reverse_dict(edges) + incoming_edges = {k: set(val) for k, val in incoming_edges.items()} + S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) + L = [] + + while S: + n, _ = S.popitem() + L.append(n) + for m in edges.get(n, ()): + assert n in incoming_edges[m] + incoming_edges[m].remove(n) + if not incoming_edges[m]: + S[m] = None + if any(incoming_edges.get(v, None) for v in edges): + raise ValueError("Input has cycles") + return L + + +def reverse_dict(d): + """Reverses direction of dependence dict + + >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()} + >>> reverse_dict(d) # doctest: +SKIP + {1: ('a',), 2: ('a', 'b'), 3: ('b',)} + + :note: dict order are not deterministic. As we iterate on the + input dict, it make the output of this function depend on the + dict order. So this function output order should be considered + as undeterministic. + + """ + result = {} + for key in d: + for val in d[key]: + result[val] = result.get(val, ()) + (key, ) + return result + + +# Taken from toolz +# Avoids licensing issues because this version was authored by Matthew Rocklin +def groupby(func, seq): + """ Group a collection by a key function + + >>> from sympy.multipledispatch.utils import groupby + >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank'] + >>> groupby(len, names) # doctest: +SKIP + {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} + + >>> iseven = lambda x: x % 2 == 0 + >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP + {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} + + See Also: + ``countby`` + """ + + d = {} + for item in seq: + key = func(item) + if key not in d: + d[key] = [] + d[key].append(item) + return d diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__init__.py b/phivenv/Lib/site-packages/sympy/ntheory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..19576c8935da455743d27f0a263caecca94f59f8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/__init__.py @@ -0,0 +1,67 @@ +""" +Number theory module (primes, etc) +""" + +from .generate import nextprime, prevprime, prime, primepi, primerange, \ + randprime, Sieve, sieve, primorial, cycle_length, composite, compositepi +from .primetest import isprime, is_gaussian_prime, is_mersenne_prime +from .factor_ import divisors, proper_divisors, factorint, multiplicity, \ + multiplicity_in_factorial, perfect_power, factor_cache, pollard_pm1, \ + pollard_rho, primefactors, totient, \ + divisor_count, proper_divisor_count, divisor_sigma, factorrat, \ + reduced_totient, primenu, primeomega, mersenne_prime_exponent, \ + is_perfect, is_abundant, is_deficient, is_amicable, is_carmichael, \ + abundance, dra, drm + +from .partitions_ import npartitions +from .residue_ntheory import is_primitive_root, is_quad_residue, \ + legendre_symbol, jacobi_symbol, n_order, sqrt_mod, quadratic_residues, \ + primitive_root, nthroot_mod, is_nthpow_residue, sqrt_mod_iter, mobius, \ + discrete_log, quadratic_congruence, polynomial_congruence +from .multinomial import binomial_coefficients, binomial_coefficients_list, \ + multinomial_coefficients +from .continued_fraction import continued_fraction_periodic, \ + continued_fraction_iterator, continued_fraction_reduce, \ + continued_fraction_convergents, continued_fraction +from .digits import count_digits, digits, is_palindromic +from .egyptian_fraction import egyptian_fraction +from .ecm import ecm +from .qs import qs, qs_factor +__all__ = [ + 'nextprime', 'prevprime', 'prime', 'primepi', 'primerange', 'randprime', + 'Sieve', 'sieve', 'primorial', 'cycle_length', 'composite', 'compositepi', + + 'isprime', 'is_gaussian_prime', 'is_mersenne_prime', + + + 'divisors', 'proper_divisors', 'factorint', 'multiplicity', 'perfect_power', + 'pollard_pm1', 'factor_cache', 'pollard_rho', 'primefactors', 'totient', + 'divisor_count', 'proper_divisor_count', 'divisor_sigma', 'factorrat', + 'reduced_totient', 'primenu', 'primeomega', 'mersenne_prime_exponent', + 'is_perfect', 'is_abundant', 'is_deficient', 'is_amicable', + 'is_carmichael', 'abundance', 'dra', 'drm', 'multiplicity_in_factorial', + + 'npartitions', + + 'is_primitive_root', 'is_quad_residue', 'legendre_symbol', + 'jacobi_symbol', 'n_order', 'sqrt_mod', 'quadratic_residues', + 'primitive_root', 'nthroot_mod', 'is_nthpow_residue', 'sqrt_mod_iter', + 'mobius', 'discrete_log', 'quadratic_congruence', 'polynomial_congruence', + + 'binomial_coefficients', 'binomial_coefficients_list', + 'multinomial_coefficients', + + 'continued_fraction_periodic', 'continued_fraction_iterator', + 'continued_fraction_reduce', 'continued_fraction_convergents', + 'continued_fraction', + + 'digits', + 'count_digits', + 'is_palindromic', + + 'egyptian_fraction', + + 'ecm', + + 'qs', 'qs_factor', +] diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..463a222c9a80657c4452d9661f387d0af3ac0cd3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/bbp_pi.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/bbp_pi.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbabeac1df94bc6c70447488a9343a66f4d1b950 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/bbp_pi.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/continued_fraction.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/continued_fraction.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d70b0b5ea196ff594d7d4aec517a7c2d7e90123a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/continued_fraction.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/digits.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/digits.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7fd5c37114ab12a1a9dce35b57c4889ebeb279ee Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/digits.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/ecm.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/ecm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cd2f2b7b71e838b2c772dc54ce9fc97cd864dbf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/ecm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/egyptian_fraction.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/egyptian_fraction.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ec1bf5db46f3ee4a1be38958fc9200cb8230aff Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/egyptian_fraction.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/elliptic_curve.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/elliptic_curve.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06c6fc0a7a042e55333762c20b3a8fe4ef89d89d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/elliptic_curve.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/factor_.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/factor_.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17aa4db3b67ac1a2d537cc335e3e09354d43de3d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/factor_.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/generate.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/generate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dae60cdc9de8c2141a78211d150dd97f1bd05d5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/generate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/modular.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/modular.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..894976dabe8a7e80ec4aeba1daa0cb2f2a47d8d2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/modular.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/multinomial.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/multinomial.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d15c279fed30cefe27795712a5876a0daa0a1e39 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/multinomial.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/partitions_.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/partitions_.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..399eab2844e3da4ed3fb5c1e2aae35cbbd002a38 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/partitions_.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/primetest.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/primetest.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5a7286231490f1ab596548667ea44a6ccf651d1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/primetest.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/qs.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/qs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09389bb37cec13633538478be61177365d5fe131 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/qs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/residue_ntheory.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/residue_ntheory.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..749f0cc03ab9350017ea9e0877d83eea1a26680c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/__pycache__/residue_ntheory.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/bbp_pi.py b/phivenv/Lib/site-packages/sympy/ntheory/bbp_pi.py new file mode 100644 index 0000000000000000000000000000000000000000..e2ff4b755d74d4e075ac7195f991c8182d175693 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/bbp_pi.py @@ -0,0 +1,190 @@ +''' +This implementation is a heavily modified fixed point implementation of +BBP_formula for calculating the nth position of pi. The original hosted +at: https://web.archive.org/web/20151116045029/http://en.literateprograms.org/Pi_with_the_BBP_formula_(Python) + +# Permission is hereby granted, free of charge, to any person obtaining +# a copy of this software and associated documentation files (the +# "Software"), to deal in the Software without restriction, including +# without limitation the rights to use, copy, modify, merge, publish, +# distribute, sub-license, and/or sell copies of the Software, and to +# permit persons to whom the Software is furnished to do so, subject to +# the following conditions: +# +# The above copyright notice and this permission notice shall be +# included in all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +Modifications: + +1.Once the nth digit and desired number of digits is selected, the +number of digits of working precision is calculated to ensure that +the hexadecimal digits returned are accurate. This is calculated as + + int(math.log(start + prec)/math.log(16) + prec + 3) + --------------------------------------- -------- + / / + number of hex digits additional digits + +This was checked by the following code which completed without +errors (and dig are the digits included in the test_bbp.py file): + + for i in range(0,1000): + for j in range(1,1000): + a, b = pi_hex_digits(i, j), dig[i:i+j] + if a != b: + print('%s\n%s'%(a,b)) + +Deceasing the additional digits by 1 generated errors, so '3' is +the smallest additional precision needed to calculate the above +loop without errors. The following trailing 10 digits were also +checked to be accurate (and the times were slightly faster with +some of the constant modifications that were made): + + >> from time import time + >> t=time();pi_hex_digits(10**2-10 + 1, 10), time()-t + ('e90c6cc0ac', 0.0) + >> t=time();pi_hex_digits(10**4-10 + 1, 10), time()-t + ('26aab49ec6', 0.17100000381469727) + >> t=time();pi_hex_digits(10**5-10 + 1, 10), time()-t + ('a22673c1a5', 4.7109999656677246) + >> t=time();pi_hex_digits(10**6-10 + 1, 10), time()-t + ('9ffd342362', 59.985999822616577) + >> t=time();pi_hex_digits(10**7-10 + 1, 10), time()-t + ('c1a42e06a1', 689.51800012588501) + +2. The while loop to evaluate whether the series has converged quits +when the addition amount `dt` has dropped to zero. + +3. the formatting string to convert the decimal to hexadecimal is +calculated for the given precision. + +4. pi_hex_digits(n) changed to have coefficient to the formula in an +array (perhaps just a matter of preference). + +''' + +from sympy.utilities.misc import as_int + + +def _series(j, n, prec=14): + + # Left sum from the bbp algorithm + s = 0 + D = _dn(n, prec) + D4 = 4 * D + d = j + for k in range(n + 1): + s += (pow(16, n - k, d) << D4) // d + d += 8 + + # Right sum iterates to infinity for full precision, but we + # stop at the point where one iteration is beyond the precision + # specified. + + t = 0 + k = n + 1 + e = D4 - 4 # 4*(D + n - k) + d = 8 * k + j + while True: + dt = (1 << e) // d + if not dt: + break + t += dt + # k += 1 + e -= 4 + d += 8 + total = s + t + + return total + + +def pi_hex_digits(n, prec=14): + """Returns a string containing ``prec`` (default 14) digits + starting at the nth digit of pi in hex. Counting of digits + starts at 0 and the decimal is not counted, so for n = 0 the + returned value starts with 3; n = 1 corresponds to the first + digit past the decimal point (which in hex is 2). + + Parameters + ========== + + n : non-negative integer + prec : non-negative integer. default = 14 + + Returns + ======= + + str : Returns a string containing ``prec`` digits + starting at the nth digit of pi in hex. + If ``prec`` = 0, returns empty string. + + Raises + ====== + + ValueError + If ``n`` < 0 or ``prec`` < 0. + Or ``n`` or ``prec`` is not an integer. + + Examples + ======== + + >>> from sympy.ntheory.bbp_pi import pi_hex_digits + >>> pi_hex_digits(0) + '3243f6a8885a30' + >>> pi_hex_digits(0, 3) + '324' + + These are consistent with the following results + + >>> import math + >>> hex(int(math.pi * 2**((14-1)*4))) + '0x3243f6a8885a30' + >>> hex(int(math.pi * 2**((3-1)*4))) + '0x324' + + References + ========== + + .. [1] http://www.numberworld.org/digits/Pi/ + """ + n, prec = as_int(n), as_int(prec) + if n < 0: + raise ValueError('n cannot be negative') + if prec < 0: + raise ValueError('prec cannot be negative') + if prec == 0: + return '' + + # main of implementation arrays holding formulae coefficients + n -= 1 + a = [4, 2, 1, 1] + j = [1, 4, 5, 6] + + #formulae + D = _dn(n, prec) + x = + (a[0]*_series(j[0], n, prec) + - a[1]*_series(j[1], n, prec) + - a[2]*_series(j[2], n, prec) + - a[3]*_series(j[3], n, prec)) & (16**D - 1) + + s = ("%0" + "%ix" % prec) % (x // 16**(D - prec)) + return s + + +def _dn(n, prec): + # controller for n dependence on precision + # n = starting digit index + # prec = the number of total digits to compute + n += 1 # because we subtract 1 for _series + + # assert int(math.log(n + prec)/math.log(16)) ==\ + # ((n + prec).bit_length() - 1) // 4 + return ((n + prec).bit_length() - 1) // 4 + prec + 3 diff --git a/phivenv/Lib/site-packages/sympy/ntheory/continued_fraction.py b/phivenv/Lib/site-packages/sympy/ntheory/continued_fraction.py new file mode 100644 index 0000000000000000000000000000000000000000..62f8e2d729ada3414a87d6f0583e06bee2a2b220 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/continued_fraction.py @@ -0,0 +1,369 @@ +from __future__ import annotations +import itertools +from sympy.core.exprtools import factor_terms +from sympy.core.numbers import Integer, Rational +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import _sympify +from sympy.utilities.misc import as_int + + +def continued_fraction(a) -> list: + """Return the continued fraction representation of a Rational or + quadratic irrational. + + Examples + ======== + + >>> from sympy.ntheory.continued_fraction import continued_fraction + >>> from sympy import sqrt + >>> continued_fraction((1 + 2*sqrt(3))/5) + [0, 1, [8, 3, 34, 3]] + + See Also + ======== + continued_fraction_periodic, continued_fraction_reduce, continued_fraction_convergents + """ + e = _sympify(a) + if all(i.is_Rational for i in e.atoms()): + if e.is_Integer: + return continued_fraction_periodic(e, 1, 0) + elif e.is_Rational: + return continued_fraction_periodic(e.p, e.q, 0) + elif e.is_Pow and e.exp is S.Half and e.base.is_Integer: + return continued_fraction_periodic(0, 1, e.base) + elif e.is_Mul and len(e.args) == 2 and ( + e.args[0].is_Rational and + e.args[1].is_Pow and + e.args[1].base.is_Integer and + e.args[1].exp is S.Half): + a, b = e.args + return continued_fraction_periodic(0, a.q, b.base, a.p) + else: + # this should not have to work very hard- no + # simplification, cancel, etc... which should be + # done by the user. e.g. This is a fancy 1 but + # the user should simplify it first: + # sqrt(2)*(1 + sqrt(2))/(sqrt(2) + 2) + p, d = e.expand().as_numer_denom() + if d.is_Integer: + if p.is_Rational: + return continued_fraction_periodic(p, d) + # look for a + b*c + # with c = sqrt(s) + if p.is_Add and len(p.args) == 2: + a, bc = p.args + else: + a = S.Zero + bc = p + if a.is_Integer: + b = S.NaN + if bc.is_Mul and len(bc.args) == 2: + b, c = bc.args + elif bc.is_Pow: + b = Integer(1) + c = bc + if b.is_Integer and ( + c.is_Pow and c.exp is S.Half and + c.base.is_Integer): + # (a + b*sqrt(c))/d + c = c.base + return continued_fraction_periodic(a, d, c, b) + raise ValueError( + 'expecting a rational or quadratic irrational, not %s' % e) + + +def continued_fraction_periodic(p, q, d=0, s=1) -> list: + r""" + Find the periodic continued fraction expansion of a quadratic irrational. + + Compute the continued fraction expansion of a rational or a + quadratic irrational number, i.e. `\frac{p + s\sqrt{d}}{q}`, where + `p`, `q \ne 0` and `d \ge 0` are integers. + + Returns the continued fraction representation (canonical form) as + a list of integers, optionally ending (for quadratic irrationals) + with list of integers representing the repeating digits. + + Parameters + ========== + + p : int + the rational part of the number's numerator + q : int + the denominator of the number + d : int, optional + the irrational part (discriminator) of the number's numerator + s : int, optional + the coefficient of the irrational part + + Examples + ======== + + >>> from sympy.ntheory.continued_fraction import continued_fraction_periodic + >>> continued_fraction_periodic(3, 2, 7) + [2, [1, 4, 1, 1]] + + Golden ratio has the simplest continued fraction expansion: + + >>> continued_fraction_periodic(1, 2, 5) + [[1]] + + If the discriminator is zero or a perfect square then the number will be a + rational number: + + >>> continued_fraction_periodic(4, 3, 0) + [1, 3] + >>> continued_fraction_periodic(4, 3, 49) + [3, 1, 2] + + See Also + ======== + + continued_fraction_iterator, continued_fraction_reduce + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Periodic_continued_fraction + .. [2] K. Rosen. Elementary Number theory and its applications. + Addison-Wesley, 3 Sub edition, pages 379-381, January 1992. + + """ + from sympy.functions import sqrt, floor + + p, q, d, s = list(map(as_int, [p, q, d, s])) + + if d < 0: + raise ValueError("expected non-negative for `d` but got %s" % d) + + if q == 0: + raise ValueError("The denominator cannot be 0.") + + if not s: + d = 0 + + # check for rational case + sd = sqrt(d) + if sd.is_Integer: + return list(continued_fraction_iterator(Rational(p + s*sd, q))) + + # irrational case with sd != Integer + if q < 0: + p, q, s = -p, -q, -s + + n = (p + s*sd)/q + if n < 0: + w = floor(-n) + f = -n - w + one_f = continued_fraction(1 - f) # 1-f < 1 so cf is [0 ... [...]] + one_f[0] -= w + 1 + return one_f + + d *= s**2 + sd *= s + + if (d - p**2)%q: + d *= q**2 + sd *= q + p *= q + q *= q + + terms: list[int] = [] + pq = {} + + while (p, q) not in pq: + pq[(p, q)] = len(terms) + terms.append((p + sd)//q) + p = terms[-1]*q - p + q = (d - p**2)//q + + i = pq[(p, q)] + return terms[:i] + [terms[i:]] # type: ignore + + +def continued_fraction_reduce(cf): + """ + Reduce a continued fraction to a rational or quadratic irrational. + + Compute the rational or quadratic irrational number from its + terminating or periodic continued fraction expansion. The + continued fraction expansion (cf) should be supplied as a + terminating iterator supplying the terms of the expansion. For + terminating continued fractions, this is equivalent to + ``list(continued_fraction_convergents(cf))[-1]``, only a little more + efficient. If the expansion has a repeating part, a list of the + repeating terms should be returned as the last element from the + iterator. This is the format returned by + continued_fraction_periodic. + + For quadratic irrationals, returns the largest solution found, + which is generally the one sought, if the fraction is in canonical + form (all terms positive except possibly the first). + + Examples + ======== + + >>> from sympy.ntheory.continued_fraction import continued_fraction_reduce + >>> continued_fraction_reduce([1, 2, 3, 4, 5]) + 225/157 + >>> continued_fraction_reduce([-2, 1, 9, 7, 1, 2]) + -256/233 + >>> continued_fraction_reduce([2, 1, 2, 1, 1, 4, 1, 1, 6, 1, 1, 8]).n(10) + 2.718281835 + >>> continued_fraction_reduce([1, 4, 2, [3, 1]]) + (sqrt(21) + 287)/238 + >>> continued_fraction_reduce([[1]]) + (1 + sqrt(5))/2 + >>> from sympy.ntheory.continued_fraction import continued_fraction_periodic + >>> continued_fraction_reduce(continued_fraction_periodic(8, 5, 13)) + (sqrt(13) + 8)/5 + + See Also + ======== + + continued_fraction_periodic + + """ + from sympy.solvers import solve + + period = [] + x = Dummy('x') + + def untillist(cf): + for nxt in cf: + if isinstance(nxt, list): + period.extend(nxt) + yield x + break + yield nxt + + a = S.Zero + for a in continued_fraction_convergents(untillist(cf)): + pass + + if period: + y = Dummy('y') + solns = solve(continued_fraction_reduce(period + [y]) - y, y) + solns.sort() + pure = solns[-1] + rv = a.subs(x, pure).radsimp() + else: + rv = a + if rv.is_Add: + rv = factor_terms(rv) + if rv.is_Mul and rv.args[0] == -1: + rv = rv.func(*rv.args) + return rv + + +def continued_fraction_iterator(x): + """ + Return continued fraction expansion of x as iterator. + + Examples + ======== + + >>> from sympy import Rational, pi + >>> from sympy.ntheory.continued_fraction import continued_fraction_iterator + + >>> list(continued_fraction_iterator(Rational(3, 8))) + [0, 2, 1, 2] + >>> list(continued_fraction_iterator(Rational(-3, 8))) + [-1, 1, 1, 1, 2] + + >>> for i, v in enumerate(continued_fraction_iterator(pi)): + ... if i > 7: + ... break + ... print(v) + 3 + 7 + 15 + 1 + 292 + 1 + 1 + 1 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Continued_fraction + + """ + from sympy.functions import floor + while True: + i = floor(x) + yield i + x -= i + if not x: + break + x = 1/x + + +def continued_fraction_convergents(cf): + """ + Return an iterator over the convergents of a continued fraction (cf). + + The parameter should be in either of the following to forms: + - A list of partial quotients, possibly with the last element being a list + of repeating partial quotients, such as might be returned by + continued_fraction and continued_fraction_periodic. + - An iterable returning successive partial quotients of the continued + fraction, such as might be returned by continued_fraction_iterator. + + In computing the convergents, the continued fraction need not be strictly + in canonical form (all integers, all but the first positive). + Rational and negative elements may be present in the expansion. + + Examples + ======== + + >>> from sympy.core import pi + >>> from sympy import S + >>> from sympy.ntheory.continued_fraction import \ + continued_fraction_convergents, continued_fraction_iterator + + >>> list(continued_fraction_convergents([0, 2, 1, 2])) + [0, 1/2, 1/3, 3/8] + + >>> list(continued_fraction_convergents([1, S('1/2'), -7, S('1/4')])) + [1, 3, 19/5, 7] + + >>> it = continued_fraction_convergents(continued_fraction_iterator(pi)) + >>> for n in range(7): + ... print(next(it)) + 3 + 22/7 + 333/106 + 355/113 + 103993/33102 + 104348/33215 + 208341/66317 + + >>> it = continued_fraction_convergents([1, [1, 2]]) # sqrt(3) + >>> for n in range(7): + ... print(next(it)) + 1 + 2 + 5/3 + 7/4 + 19/11 + 26/15 + 71/41 + + See Also + ======== + + continued_fraction_iterator, continued_fraction, continued_fraction_periodic + + """ + if isinstance(cf, list) and isinstance(cf[-1], list): + cf = itertools.chain(cf[:-1], itertools.cycle(cf[-1])) + p_2, q_2 = S.Zero, S.One + p_1, q_1 = S.One, S.Zero + for a in cf: + p, q = a*p_1 + p_2, a*q_1 + q_2 + p_2, q_2 = p_1, q_1 + p_1, q_1 = p, q + yield p/q diff --git a/phivenv/Lib/site-packages/sympy/ntheory/digits.py b/phivenv/Lib/site-packages/sympy/ntheory/digits.py new file mode 100644 index 0000000000000000000000000000000000000000..a0414815871f6f888ccd2823546ab2b0c2c9f515 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/digits.py @@ -0,0 +1,150 @@ +from collections import defaultdict + +from sympy.utilities.iterables import multiset, is_palindromic as _palindromic +from sympy.utilities.misc import as_int + + +def digits(n, b=10, digits=None): + """ + Return a list of the digits of ``n`` in base ``b``. The first + element in the list is ``b`` (or ``-b`` if ``n`` is negative). + + Examples + ======== + + >>> from sympy.ntheory.digits import digits + >>> digits(35) + [10, 3, 5] + + If the number is negative, the negative sign will be placed on the + base (which is the first element in the returned list): + + >>> digits(-35) + [-10, 3, 5] + + Bases other than 10 (and greater than 1) can be selected with ``b``: + + >>> digits(27, b=2) + [2, 1, 1, 0, 1, 1] + + Use the ``digits`` keyword if a certain number of digits is desired: + + >>> digits(35, digits=4) + [10, 0, 0, 3, 5] + + Parameters + ========== + + n: integer + The number whose digits are returned. + + b: integer + The base in which digits are computed. + + digits: integer (or None for all digits) + The number of digits to be returned (padded with zeros, if + necessary). + + See Also + ======== + sympy.core.intfunc.num_digits, count_digits + """ + + b = as_int(b) + n = as_int(n) + if b < 2: + raise ValueError("b must be greater than 1") + else: + x, y = abs(n), [] + while x >= b: + x, r = divmod(x, b) + y.append(r) + y.append(x) + y.append(-b if n < 0 else b) + y.reverse() + ndig = len(y) - 1 + if digits is not None: + if ndig > digits: + raise ValueError( + "For %s, at least %s digits are needed." % (n, ndig)) + elif ndig < digits: + y[1:1] = [0]*(digits - ndig) + return y + + +def count_digits(n, b=10): + """ + Return a dictionary whose keys are the digits of ``n`` in the + given base, ``b``, with keys indicating the digits appearing in the + number and values indicating how many times that digit appeared. + + Examples + ======== + + >>> from sympy.ntheory import count_digits + + >>> count_digits(1111339) + {1: 4, 3: 2, 9: 1} + + The digits returned are always represented in base-10 + but the number itself can be entered in any format that is + understood by Python; the base of the number can also be + given if it is different than 10: + + >>> n = 0xFA; n + 250 + >>> count_digits(_) + {0: 1, 2: 1, 5: 1} + >>> count_digits(n, 16) + {10: 1, 15: 1} + + The default dictionary will return a 0 for any digit that did + not appear in the number. For example, which digits appear 7 + times in ``77!``: + + >>> from sympy import factorial + >>> c77 = count_digits(factorial(77)) + >>> [i for i in range(10) if c77[i] == 7] + [1, 3, 7, 9] + + See Also + ======== + sympy.core.intfunc.num_digits, digits + """ + rv = defaultdict(int, multiset(digits(n, b)).items()) + rv.pop(b) if b in rv else rv.pop(-b) # b or -b is there + return rv + + +def is_palindromic(n, b=10): + """return True if ``n`` is the same when read from left to right + or right to left in the given base, ``b``. + + Examples + ======== + + >>> from sympy.ntheory import is_palindromic + + >>> all(is_palindromic(i) for i in (-11, 1, 22, 121)) + True + + The second argument allows you to test numbers in other + bases. For example, 88 is palindromic in base-10 but not + in base-8: + + >>> is_palindromic(88, 8) + False + + On the other hand, a number can be palindromic in base-8 but + not in base-10: + + >>> 0o121, is_palindromic(0o121) + (81, False) + + Or it might be palindromic in both bases: + + >>> oct(121), is_palindromic(121, 8) and is_palindromic(121) + ('0o171', True) + + """ + return _palindromic(digits(n, b), 1) diff --git a/phivenv/Lib/site-packages/sympy/ntheory/ecm.py b/phivenv/Lib/site-packages/sympy/ntheory/ecm.py new file mode 100644 index 0000000000000000000000000000000000000000..498c0c8fdf8478688465c4bae307818e9685b686 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/ecm.py @@ -0,0 +1,348 @@ +from math import log + +from sympy.core.random import _randint +from sympy.external.gmpy import gcd, invert, sqrt +from sympy.utilities.misc import as_int +from .generate import sieve, primerange +from .primetest import isprime + + +#----------------------------------------------------------------------------# +# # +# Lenstra's Elliptic Curve Factorization # +# # +#----------------------------------------------------------------------------# + + +class Point: + """Montgomery form of Points in an elliptic curve. + In this form, the addition and doubling of points + does not need any y-coordinate information thus + decreasing the number of operations. + Using Montgomery form we try to perform point addition + and doubling in least amount of multiplications. + + The elliptic curve used here is of the form + (E : b*y**2*z = x**3 + a*x**2*z + x*z**2). + The a_24 parameter is equal to (a + 2)/4. + + References + ========== + + .. [1] Kris Gaj, Soonhak Kwon, Patrick Baier, Paul Kohlbrenner, Hoang Le, Mohammed Khaleeluddin, Ramakrishna Bachimanchi, + Implementing the Elliptic Curve Method of Factoring in Reconfigurable Hardware, + Cryptographic Hardware and Embedded Systems - CHES 2006 (2006), pp. 119-133, + https://doi.org/10.1007/11894063_10 + https://www.hyperelliptic.org/tanja/SHARCS/talks06/Gaj.pdf + + """ + + def __init__(self, x_cord, z_cord, a_24, mod): + """ + Initial parameters for the Point class. + + Parameters + ========== + + x_cord : X coordinate of the Point + z_cord : Z coordinate of the Point + a_24 : Parameter of the elliptic curve in Montgomery form + mod : modulus + """ + self.x_cord = x_cord + self.z_cord = z_cord + self.a_24 = a_24 + self.mod = mod + + def __eq__(self, other): + """Two points are equal if X/Z of both points are equal + """ + if self.a_24 != other.a_24 or self.mod != other.mod: + return False + return self.x_cord * other.z_cord % self.mod ==\ + other.x_cord * self.z_cord % self.mod + + def add(self, Q, diff): + """ + Add two points self and Q where diff = self - Q. Moreover the assumption + is self.x_cord*Q.x_cord*(self.x_cord - Q.x_cord) != 0. This algorithm + requires 6 multiplications. Here the difference between the points + is already known and using this algorithm speeds up the addition + by reducing the number of multiplication required. Also in the + mont_ladder algorithm is constructed in a way so that the difference + between intermediate points is always equal to the initial point. + So, we always know what the difference between the point is. + + + Parameters + ========== + + Q : point on the curve in Montgomery form + diff : self - Q + + Examples + ======== + + >>> from sympy.ntheory.ecm import Point + >>> p1 = Point(11, 16, 7, 29) + >>> p2 = Point(13, 10, 7, 29) + >>> p3 = p2.add(p1, p1) + >>> p3.x_cord + 23 + >>> p3.z_cord + 17 + """ + u = (self.x_cord - self.z_cord)*(Q.x_cord + Q.z_cord) + v = (self.x_cord + self.z_cord)*(Q.x_cord - Q.z_cord) + add, subt = u + v, u - v + x_cord = diff.z_cord * add * add % self.mod + z_cord = diff.x_cord * subt * subt % self.mod + return Point(x_cord, z_cord, self.a_24, self.mod) + + def double(self): + """ + Doubles a point in an elliptic curve in Montgomery form. + This algorithm requires 5 multiplications. + + Examples + ======== + + >>> from sympy.ntheory.ecm import Point + >>> p1 = Point(11, 16, 7, 29) + >>> p2 = p1.double() + >>> p2.x_cord + 13 + >>> p2.z_cord + 10 + """ + u = pow(self.x_cord + self.z_cord, 2, self.mod) + v = pow(self.x_cord - self.z_cord, 2, self.mod) + diff = u - v + x_cord = u*v % self.mod + z_cord = diff*(v + self.a_24*diff) % self.mod + return Point(x_cord, z_cord, self.a_24, self.mod) + + def mont_ladder(self, k): + """ + Scalar multiplication of a point in Montgomery form + using Montgomery Ladder Algorithm. + A total of 11 multiplications are required in each step of this + algorithm. + + Parameters + ========== + + k : The positive integer multiplier + + Examples + ======== + + >>> from sympy.ntheory.ecm import Point + >>> p1 = Point(11, 16, 7, 29) + >>> p3 = p1.mont_ladder(3) + >>> p3.x_cord + 23 + >>> p3.z_cord + 17 + """ + Q = self + R = self.double() + for i in bin(k)[3:]: + if i == '1': + Q = R.add(Q, self) + R = R.double() + else: + R = Q.add(R, self) + Q = Q.double() + return Q + + +def _ecm_one_factor(n, B1=10000, B2=100000, max_curve=200, seed=None): + """Returns one factor of n using + Lenstra's 2 Stage Elliptic curve Factorization + with Suyama's Parameterization. Here Montgomery + arithmetic is used for fast computation of addition + and doubling of points in elliptic curve. + + Explanation + =========== + + This ECM method considers elliptic curves in Montgomery + form (E : b*y**2*z = x**3 + a*x**2*z + x*z**2) and involves + elliptic curve operations (mod N), where the elements in + Z are reduced (mod N). Since N is not a prime, E over FF(N) + is not really an elliptic curve but we can still do point additions + and doubling as if FF(N) was a field. + + Stage 1 : The basic algorithm involves taking a random point (P) on an + elliptic curve in FF(N). The compute k*P using Montgomery ladder algorithm. + Let q be an unknown factor of N. Then the order of the curve E, |E(FF(q))|, + might be a smooth number that divides k. Then we have k = l * |E(FF(q))| + for some l. For any point belonging to the curve E, |E(FF(q))|*P = O, + hence k*P = l*|E(FF(q))|*P. Thus kP.z_cord = 0 (mod q), and the unknownn + factor of N (q) can be recovered by taking gcd(kP.z_cord, N). + + Stage 2 : This is a continuation of Stage 1 if k*P != O. The idea utilize + the fact that even if kP != 0, the value of k might miss just one large + prime divisor of |E(FF(q))|. In this case we only need to compute the + scalar multiplication by p to get p*k*P = O. Here a second bound B2 + restrict the size of possible values of p. + + Parameters + ========== + + n : Number to be Factored. Assume that it is a composite number. + B1 : Stage 1 Bound. Must be an even number. + B2 : Stage 2 Bound. Must be an even number. + max_curve : Maximum number of curves generated + + Returns + ======= + + integer | None : a non-trivial divisor of ``n``. ``None`` if not found + + References + ========== + + .. [1] Carl Pomerance, Richard Crandall, Prime Numbers: A Computational Perspective, + 2nd Edition (2005), page 344, ISBN:978-0387252827 + """ + randint = _randint(seed) + + # When calculating T, if (B1 - 2*D) is negative, it cannot be calculated. + D = min(sqrt(B2), B1 // 2 - 1) + sieve.extend(D) + beta = [0] * D + S = [0] * D + k = 1 + for p in primerange(2, B1 + 1): + k *= pow(p, int(log(B1, p))) + + # Pre-calculate the prime numbers to be used in stage 2. + # Using the fact that the x-coordinates of point P and its + # inverse -P coincide, the number of primes to be checked + # in stage 2 can be reduced. + deltas_list = [] + for r in range(B1 + 2*D, B2 + 2*D, 4*D): + # d in deltas iff r+(2d+1) and/or r-(2d+1) is prime + deltas = {abs(q - r) >> 1 for q in primerange(r - 2*D, r + 2*D)} + deltas_list.append(list(deltas)) + + for _ in range(max_curve): + #Suyama's Parametrization + sigma = randint(6, n - 1) + u = (sigma**2 - 5) % n + v = (4*sigma) % n + u_3 = pow(u, 3, n) + + try: + # We use the elliptic curve y**2 = x**3 + a*x**2 + x + # where a = pow(v - u, 3, n)*(3*u + v)*invert(4*u_3*v, n) - 2 + # However, we do not declare a because it is more convenient + # to use a24 = (a + 2)*invert(4, n) in the calculation. + a24 = pow(v - u, 3, n)*(3*u + v)*invert(16*u_3*v, n) % n + except ZeroDivisionError: + #If the invert(16*u_3*v, n) doesn't exist (i.e., g != 1) + g = gcd(2*u_3*v, n) + #If g = n, try another curve + if g == n: + continue + return g + + Q = Point(u_3, pow(v, 3, n), a24, n) + Q = Q.mont_ladder(k) + g = gcd(Q.z_cord, n) + + #Stage 1 factor + if g != 1 and g != n: + return g + #Stage 1 failure. Q.z = 0, Try another curve + elif g == n: + continue + + #Stage 2 - Improved Standard Continuation + S[0] = Q + Q2 = Q.double() + S[1] = Q2.add(Q, Q) + beta[0] = (S[0].x_cord*S[0].z_cord) % n + beta[1] = (S[1].x_cord*S[1].z_cord) % n + for d in range(2, D): + S[d] = S[d - 1].add(Q2, S[d - 2]) + beta[d] = (S[d].x_cord*S[d].z_cord) % n + # i.e., S[i] = Q.mont_ladder(2*i + 1) + + g = 1 + W = Q.mont_ladder(4*D) + T = Q.mont_ladder(B1 - 2*D) + R = Q.mont_ladder(B1 + 2*D) + for deltas in deltas_list: + # R = Q.mont_ladder(r) where r in range(B1 + 2*D, B2 + 2*D, 4*D) + alpha = (R.x_cord*R.z_cord) % n + for delta in deltas: + # We want to calculate + # f = R.x_cord * S[delta].z_cord - S[delta].x_cord * R.z_cord + f = (R.x_cord - S[delta].x_cord)*\ + (R.z_cord + S[delta].z_cord) - alpha + beta[delta] + g = (g*f) % n + T, R = R, R.add(W, T) + g = gcd(n, g) + + #Stage 2 Factor found + if g != 1 and g != n: + return g + + +def ecm(n, B1=10000, B2=100000, max_curve=200, seed=1234): + """Performs factorization using Lenstra's Elliptic curve method. + + This function repeatedly calls ``_ecm_one_factor`` to compute the factors + of n. First all the small factors are taken out using trial division. + Then ``_ecm_one_factor`` is used to compute one factor at a time. + + Parameters + ========== + + n : Number to be Factored + B1 : Stage 1 Bound. Must be an even number. + B2 : Stage 2 Bound. Must be an even number. + max_curve : Maximum number of curves generated + seed : Initialize pseudorandom generator + + Examples + ======== + + >>> from sympy.ntheory import ecm + >>> ecm(25645121643901801) + {5394769, 4753701529} + >>> ecm(9804659461513846513) + {4641991, 2112166839943} + """ + from .factor_ import _perfect_power + n = as_int(n) + if B1 % 2 != 0 or B2 % 2 != 0: + raise ValueError("both bounds must be even") + TF_LIMIT = 100000 + factors = set() + for prime in sieve.primerange(2, TF_LIMIT): + if n % prime == 0: + factors.add(prime) + while(n % prime == 0): + n //= prime + + queue = [] + def check(m): + if isprime(m): + factors.add(m) + return + if result := _perfect_power(m, TF_LIMIT): + return check(result[0]) + queue.append(m) + check(n) + while queue: + n = queue.pop() + factor = _ecm_one_factor(n, B1, B2, max_curve, seed) + if factor is None: + raise ValueError("Increase the bounds") + check(factor) + check(n // factor) + return factors diff --git a/phivenv/Lib/site-packages/sympy/ntheory/egyptian_fraction.py b/phivenv/Lib/site-packages/sympy/ntheory/egyptian_fraction.py new file mode 100644 index 0000000000000000000000000000000000000000..8a42540b372042f596808684fef8e3fc57935b74 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/egyptian_fraction.py @@ -0,0 +1,223 @@ +from sympy.core.containers import Tuple +from sympy.core.numbers import (Integer, Rational) +from sympy.core.singleton import S +import sympy.polys + +from math import gcd + + +def egyptian_fraction(r, algorithm="Greedy"): + """ + Return the list of denominators of an Egyptian fraction + expansion [1]_ of the said rational `r`. + + Parameters + ========== + + r : Rational or (p, q) + a positive rational number, ``p/q``. + algorithm : { "Greedy", "Graham Jewett", "Takenouchi", "Golomb" }, optional + Denotes the algorithm to be used (the default is "Greedy"). + + Examples + ======== + + >>> from sympy import Rational + >>> from sympy.ntheory.egyptian_fraction import egyptian_fraction + >>> egyptian_fraction(Rational(3, 7)) + [3, 11, 231] + >>> egyptian_fraction((3, 7), "Graham Jewett") + [7, 8, 9, 56, 57, 72, 3192] + >>> egyptian_fraction((3, 7), "Takenouchi") + [4, 7, 28] + >>> egyptian_fraction((3, 7), "Golomb") + [3, 15, 35] + >>> egyptian_fraction((11, 5), "Golomb") + [1, 2, 3, 4, 9, 234, 1118, 2580] + + See Also + ======== + + sympy.core.numbers.Rational + + Notes + ===== + + Currently the following algorithms are supported: + + 1) Greedy Algorithm + + Also called the Fibonacci-Sylvester algorithm [2]_. + At each step, extract the largest unit fraction less + than the target and replace the target with the remainder. + + It has some distinct properties: + + a) Given `p/q` in lowest terms, generates an expansion of maximum + length `p`. Even as the numerators get large, the number of + terms is seldom more than a handful. + + b) Uses minimal memory. + + c) The terms can blow up (standard examples of this are 5/121 and + 31/311). The denominator is at most squared at each step + (doubly-exponential growth) and typically exhibits + singly-exponential growth. + + 2) Graham Jewett Algorithm + + The algorithm suggested by the result of Graham and Jewett. + Note that this has a tendency to blow up: the length of the + resulting expansion is always ``2**(x/gcd(x, y)) - 1``. See [3]_. + + 3) Takenouchi Algorithm + + The algorithm suggested by Takenouchi (1921). + Differs from the Graham-Jewett algorithm only in the handling + of duplicates. See [3]_. + + 4) Golomb's Algorithm + + A method given by Golumb (1962), using modular arithmetic and + inverses. It yields the same results as a method using continued + fractions proposed by Bleicher (1972). See [4]_. + + If the given rational is greater than or equal to 1, a greedy algorithm + of summing the harmonic sequence 1/1 + 1/2 + 1/3 + ... is used, taking + all the unit fractions of this sequence until adding one more would be + greater than the given number. This list of denominators is prefixed + to the result from the requested algorithm used on the remainder. For + example, if r is 8/3, using the Greedy algorithm, we get [1, 2, 3, 4, + 5, 6, 7, 14, 420], where the beginning of the sequence, [1, 2, 3, 4, 5, + 6, 7] is part of the harmonic sequence summing to 363/140, leaving a + remainder of 31/420, which yields [14, 420] by the Greedy algorithm. + The result of egyptian_fraction(Rational(8, 3), "Golomb") is [1, 2, 3, + 4, 5, 6, 7, 14, 574, 2788, 6460, 11590, 33062, 113820], and so on. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Egyptian_fraction + .. [2] https://en.wikipedia.org/wiki/Greedy_algorithm_for_Egyptian_fractions + .. [3] https://www.ics.uci.edu/~eppstein/numth/egypt/conflict.html + .. [4] https://web.archive.org/web/20180413004012/https://ami.ektf.hu/uploads/papers/finalpdf/AMI_42_from129to134.pdf + + """ + + if not isinstance(r, Rational): + if isinstance(r, (Tuple, tuple)) and len(r) == 2: + r = Rational(*r) + else: + raise ValueError("Value must be a Rational or tuple of ints") + if r <= 0: + raise ValueError("Value must be positive") + + # common cases that all methods agree on + x, y = r.as_numer_denom() + if y == 1 and x == 2: + return [Integer(i) for i in [1, 2, 3, 6]] + if x == y + 1: + return [S.One, y] + + prefix, rem = egypt_harmonic(r) + if rem == 0: + return prefix + # work in Python ints + x, y = rem.p, rem.q + # assert x < y and gcd(x, y) = 1 + + if algorithm == "Greedy": + postfix = egypt_greedy(x, y) + elif algorithm == "Graham Jewett": + postfix = egypt_graham_jewett(x, y) + elif algorithm == "Takenouchi": + postfix = egypt_takenouchi(x, y) + elif algorithm == "Golomb": + postfix = egypt_golomb(x, y) + else: + raise ValueError("Entered invalid algorithm") + return prefix + [Integer(i) for i in postfix] + + +def egypt_greedy(x, y): + # assumes gcd(x, y) == 1 + if x == 1: + return [y] + else: + a = (-y) % x + b = y*(y//x + 1) + c = gcd(a, b) + if c > 1: + num, denom = a//c, b//c + else: + num, denom = a, b + return [y//x + 1] + egypt_greedy(num, denom) + + +def egypt_graham_jewett(x, y): + # assumes gcd(x, y) == 1 + l = [y] * x + + # l is now a list of integers whose reciprocals sum to x/y. + # we shall now proceed to manipulate the elements of l without + # changing the reciprocated sum until all elements are unique. + + while len(l) != len(set(l)): + l.sort() # so the list has duplicates. find a smallest pair + for i in range(len(l) - 1): + if l[i] == l[i + 1]: + break + # we have now identified a pair of identical + # elements: l[i] and l[i + 1]. + # now comes the application of the result of graham and jewett: + l[i + 1] = l[i] + 1 + # and we just iterate that until the list has no duplicates. + l.append(l[i]*(l[i] + 1)) + return sorted(l) + + +def egypt_takenouchi(x, y): + # assumes gcd(x, y) == 1 + # special cases for 3/y + if x == 3: + if y % 2 == 0: + return [y//2, y] + i = (y - 1)//2 + j = i + 1 + k = j + i + return [j, k, j*k] + l = [y] * x + while len(l) != len(set(l)): + l.sort() + for i in range(len(l) - 1): + if l[i] == l[i + 1]: + break + k = l[i] + if k % 2 == 0: + l[i] = l[i] // 2 + del l[i + 1] + else: + l[i], l[i + 1] = (k + 1)//2, k*(k + 1)//2 + return sorted(l) + + +def egypt_golomb(x, y): + # assumes x < y and gcd(x, y) == 1 + if x == 1: + return [y] + xp = sympy.polys.ZZ.invert(int(x), int(y)) + rv = [xp*y] + rv.extend(egypt_golomb((x*xp - 1)//y, xp)) + return sorted(rv) + + +def egypt_harmonic(r): + # assumes r is Rational + rv = [] + d = S.One + acc = S.Zero + while acc + 1/d <= r: + acc += 1/d + rv.append(d) + d += 1 + return (rv, r - acc) diff --git a/phivenv/Lib/site-packages/sympy/ntheory/elliptic_curve.py b/phivenv/Lib/site-packages/sympy/ntheory/elliptic_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..c969470a6c19a3d17e637529b6615eeba326e84a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/elliptic_curve.py @@ -0,0 +1,397 @@ +from sympy.core.numbers import oo +from sympy.core.symbol import symbols +from sympy.polys.domains import FiniteField, QQ, RationalField, FF +from sympy.polys.polytools import Poly +from sympy.solvers.solvers import solve +from sympy.utilities.iterables import is_sequence +from sympy.utilities.misc import as_int +from .factor_ import divisors +from .residue_ntheory import polynomial_congruence + + +class EllipticCurve: + """ + Create the following Elliptic Curve over domain. + + `y^{2} + a_{1} x y + a_{3} y = x^{3} + a_{2} x^{2} + a_{4} x + a_{6}` + + The default domain is ``QQ``. If no coefficient ``a1``, ``a2``, ``a3``, + is given then it creates a curve with the following form: + + `y^{2} = x^{3} + a_{4} x + a_{6}` + + Examples + ======== + + References + ========== + + .. [1] J. Silverman "A Friendly Introduction to Number Theory" Third Edition + .. [2] https://mathworld.wolfram.com/EllipticDiscriminant.html + .. [3] G. Hardy, E. Wright "An Introduction to the Theory of Numbers" Sixth Edition + + """ + + def __init__(self, a4, a6, a1=0, a2=0, a3=0, modulus=0): + if modulus == 0: + domain = QQ + else: + domain = FF(modulus) + a1, a2, a3, a4, a6 = map(domain.convert, (a1, a2, a3, a4, a6)) + self._domain = domain + self.modulus = modulus + # Calculate discriminant + b2 = a1**2 + 4 * a2 + b4 = 2 * a4 + a1 * a3 + b6 = a3**2 + 4 * a6 + b8 = a1**2 * a6 + 4 * a2 * a6 - a1 * a3 * a4 + a2 * a3**2 - a4**2 + self._b2, self._b4, self._b6, self._b8 = b2, b4, b6, b8 + self._discrim = -b2**2 * b8 - 8 * b4**3 - 27 * b6**2 + 9 * b2 * b4 * b6 + self._a1 = a1 + self._a2 = a2 + self._a3 = a3 + self._a4 = a4 + self._a6 = a6 + x, y, z = symbols('x y z') + self.x, self.y, self.z = x, y, z + self._poly = Poly(y**2*z + a1*x*y*z + a3*y*z**2 - x**3 - a2*x**2*z - a4*x*z**2 - a6*z**3, domain=domain) + if isinstance(self._domain, FiniteField): + self._rank = 0 + elif isinstance(self._domain, RationalField): + self._rank = None + + def __call__(self, x, y, z=1): + return EllipticCurvePoint(x, y, z, self) + + def __contains__(self, point): + if is_sequence(point): + if len(point) == 2: + z1 = 1 + else: + z1 = point[2] + x1, y1 = point[:2] + elif isinstance(point, EllipticCurvePoint): + x1, y1, z1 = point.x, point.y, point.z + else: + raise ValueError('Invalid point.') + if self.characteristic == 0 and z1 == 0: + return True + return self._poly.subs({self.x: x1, self.y: y1, self.z: z1}) == 0 + + def __repr__(self): + return self._poly.__repr__() + + def minimal(self): + """ + Return minimal Weierstrass equation. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + + >>> e1 = EllipticCurve(-10, -20, 0, -1, 1) + >>> e1.minimal() + Poly(-x**3 + 13392*x*z**2 + y**2*z + 1080432*z**3, x, y, z, domain='QQ') + + """ + char = self.characteristic + if char == 2: + return self + if char == 3: + return EllipticCurve(self._b4/2, self._b6/4, a2=self._b2/4, modulus=self.modulus) + c4 = self._b2**2 - 24*self._b4 + c6 = -self._b2**3 + 36*self._b2*self._b4 - 216*self._b6 + return EllipticCurve(-27*c4, -54*c6, modulus=self.modulus) + + def points(self): + """ + Return points of curve over Finite Field. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e2 = EllipticCurve(1, 1, 1, 1, 1, modulus=5) + >>> e2.points() + {(0, 2), (1, 4), (2, 0), (2, 2), (3, 0), (3, 1), (4, 0)} + + """ + + char = self.characteristic + all_pt = set() + if char >= 1: + for i in range(char): + congruence_eq = self._poly.subs({self.x: i, self.z: 1}).expr + sol = polynomial_congruence(congruence_eq, char) + all_pt.update((i, num) for num in sol) + return all_pt + else: + raise ValueError("Infinitely many points") + + def points_x(self, x): + """Returns points on the curve for the given x-coordinate.""" + pt = [] + if self._domain == QQ: + for y in solve(self._poly.subs(self.x, x)): + pt.append((x, y)) + else: + congruence_eq = self._poly.subs({self.x: x, self.z: 1}).expr + for y in polynomial_congruence(congruence_eq, self.characteristic): + pt.append((x, y)) + return pt + + def torsion_points(self): + """ + Return torsion points of curve over Rational number. + + Return point objects those are finite order. + According to Nagell-Lutz theorem, torsion point p(x, y) + x and y are integers, either y = 0 or y**2 is divisor + of discriminent. According to Mazur's theorem, there are + at most 15 points in torsion collection. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e2 = EllipticCurve(-43, 166) + >>> sorted(e2.torsion_points()) + [(-5, -16), (-5, 16), O, (3, -8), (3, 8), (11, -32), (11, 32)] + + """ + if self.characteristic > 0: + raise ValueError("No torsion point for Finite Field.") + l = [EllipticCurvePoint.point_at_infinity(self)] + for xx in solve(self._poly.subs({self.y: 0, self.z: 1})): + if xx.is_rational: + l.append(self(xx, 0)) + for i in divisors(self.discriminant, generator=True): + j = int(i**.5) + if j**2 == i: + for xx in solve(self._poly.subs({self.y: j, self.z: 1})): + if not xx.is_rational: + continue + p = self(xx, j) + if p.order() != oo: + l.extend([p, -p]) + return l + + @property + def characteristic(self): + """ + Return domain characteristic. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e2 = EllipticCurve(-43, 166) + >>> e2.characteristic + 0 + + """ + return self._domain.characteristic() + + @property + def discriminant(self): + """ + Return curve discriminant. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e2 = EllipticCurve(0, 17) + >>> e2.discriminant + -124848 + + """ + return int(self._discrim) + + @property + def is_singular(self): + """ + Return True if curve discriminant is equal to zero. + """ + return self.discriminant == 0 + + @property + def j_invariant(self): + """ + Return curve j-invariant. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e1 = EllipticCurve(-2, 0, 0, 1, 1) + >>> e1.j_invariant + 1404928/389 + + """ + c4 = self._b2**2 - 24*self._b4 + return self._domain.to_sympy(c4**3 / self._discrim) + + @property + def order(self): + """ + Number of points in Finite field. + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e2 = EllipticCurve(1, 0, modulus=19) + >>> e2.order + 19 + + """ + if self.characteristic == 0: + raise NotImplementedError("Still not implemented") + return len(self.points()) + + @property + def rank(self): + """ + Number of independent points of infinite order. + + For Finite field, it must be 0. + """ + if self._rank is not None: + return self._rank + raise NotImplementedError("Still not implemented") + + +class EllipticCurvePoint: + """ + Point of Elliptic Curve + + Examples + ======== + + >>> from sympy.ntheory.elliptic_curve import EllipticCurve + >>> e1 = EllipticCurve(-17, 16) + >>> p1 = e1(0, -4, 1) + >>> p2 = e1(1, 0) + >>> p1 + p2 + (15, -56) + >>> e3 = EllipticCurve(-1, 9) + >>> e3(1, -3) * 3 + (664/169, 17811/2197) + >>> (e3(1, -3) * 3).order() + oo + >>> e2 = EllipticCurve(-2, 0, 0, 1, 1) + >>> p = e2(-1,1) + >>> q = e2(0, -1) + >>> p+q + (4, 8) + >>> p-q + (1, 0) + >>> 3*p-5*q + (328/361, -2800/6859) + """ + + @staticmethod + def point_at_infinity(curve): + return EllipticCurvePoint(0, 1, 0, curve) + + def __init__(self, x, y, z, curve): + dom = curve._domain.convert + self.x = dom(x) + self.y = dom(y) + self.z = dom(z) + self._curve = curve + self._domain = self._curve._domain + if not self._curve.__contains__(self): + raise ValueError("The curve does not contain this point") + + def __add__(self, p): + if self.z == 0: + return p + if p.z == 0: + return self + x1, y1 = self.x/self.z, self.y/self.z + x2, y2 = p.x/p.z, p.y/p.z + a1 = self._curve._a1 + a2 = self._curve._a2 + a3 = self._curve._a3 + a4 = self._curve._a4 + a6 = self._curve._a6 + if x1 != x2: + slope = (y1 - y2) / (x1 - x2) + yint = (y1 * x2 - y2 * x1) / (x2 - x1) + else: + if (y1 + y2) == 0: + return self.point_at_infinity(self._curve) + slope = (3 * x1**2 + 2*a2*x1 + a4 - a1*y1) / (a1 * x1 + a3 + 2 * y1) + yint = (-x1**3 + a4*x1 + 2*a6 - a3*y1) / (a1*x1 + a3 + 2*y1) + x3 = slope**2 + a1*slope - a2 - x1 - x2 + y3 = -(slope + a1) * x3 - yint - a3 + return self._curve(x3, y3, 1) + + def __lt__(self, other): + return (self.x, self.y, self.z) < (other.x, other.y, other.z) + + def __mul__(self, n): + n = as_int(n) + r = self.point_at_infinity(self._curve) + if n == 0: + return r + if n < 0: + return -self * -n + p = self + while n: + if n & 1: + r = r + p + n >>= 1 + p = p + p + return r + + def __rmul__(self, n): + return self * n + + def __neg__(self): + return EllipticCurvePoint(self.x, -self.y - self._curve._a1*self.x - self._curve._a3, self.z, self._curve) + + def __repr__(self): + if self.z == 0: + return 'O' + dom = self._curve._domain + try: + return '({}, {})'.format(dom.to_sympy(self.x), dom.to_sympy(self.y)) + except TypeError: + pass + return '({}, {})'.format(self.x, self.y) + + def __sub__(self, other): + return self.__add__(-other) + + def order(self): + """ + Return point order n where nP = 0. + + """ + if self.z == 0: + return 1 + if self.y == 0: # P = -P + return 2 + p = self * 2 + if p.y == -self.y: # 2P = -P + return 3 + i = 2 + if self._domain != QQ: + while int(p.x) == p.x and int(p.y) == p.y: + p = self + p + i += 1 + if p.z == 0: + return i + return oo + while p.x.numerator == p.x and p.y.numerator == p.y: + p = self + p + i += 1 + if i > 12: + return oo + if p.z == 0: + return i + return oo diff --git a/phivenv/Lib/site-packages/sympy/ntheory/factor_.py b/phivenv/Lib/site-packages/sympy/ntheory/factor_.py new file mode 100644 index 0000000000000000000000000000000000000000..2dc6ac81c237f000e55014f5e170b27b41335786 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/factor_.py @@ -0,0 +1,2841 @@ +""" +Integer factorization +""" +from __future__ import annotations + +from bisect import bisect_left +from collections import defaultdict, OrderedDict +from collections.abc import MutableMapping +import math + +from sympy.core.containers import Dict +from sympy.core.mul import Mul +from sympy.core.numbers import Rational, Integer +from sympy.core.intfunc import num_digits +from sympy.core.power import Pow +from sympy.core.random import _randint +from sympy.core.singleton import S +from sympy.external.gmpy import (SYMPY_INTS, gcd, sqrt as isqrt, + sqrtrem, iroot, bit_scan1, remove) +from .primetest import isprime, MERSENNE_PRIME_EXPONENTS, is_mersenne_prime +from .generate import sieve, primerange, nextprime +from .digits import digits +from sympy.utilities.decorator import deprecated +from sympy.utilities.iterables import flatten +from sympy.utilities.misc import as_int, filldedent +from .ecm import _ecm_one_factor + + +def smoothness(n): + """ + Return the B-smooth and B-power smooth values of n. + + The smoothness of n is the largest prime factor of n; the power- + smoothness is the largest divisor raised to its multiplicity. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import smoothness + >>> smoothness(2**7*3**2) + (3, 128) + >>> smoothness(2**4*13) + (13, 16) + >>> smoothness(2) + (2, 2) + + See Also + ======== + + factorint, smoothness_p + """ + + if n == 1: + return (1, 1) # not prime, but otherwise this causes headaches + facs = factorint(n) + return max(facs), max(m**facs[m] for m in facs) + + +def smoothness_p(n, m=-1, power=0, visual=None): + """ + Return a list of [m, (p, (M, sm(p + m), psm(p + m)))...] + where: + + 1. p**M is the base-p divisor of n + 2. sm(p + m) is the smoothness of p + m (m = -1 by default) + 3. psm(p + m) is the power smoothness of p + m + + The list is sorted according to smoothness (default) or by power smoothness + if power=1. + + The smoothness of the numbers to the left (m = -1) or right (m = 1) of a + factor govern the results that are obtained from the p +/- 1 type factoring + methods. + + >>> from sympy.ntheory.factor_ import smoothness_p, factorint + >>> smoothness_p(10431, m=1) + (1, [(3, (2, 2, 4)), (19, (1, 5, 5)), (61, (1, 31, 31))]) + >>> smoothness_p(10431) + (-1, [(3, (2, 2, 2)), (19, (1, 3, 9)), (61, (1, 5, 5))]) + >>> smoothness_p(10431, power=1) + (-1, [(3, (2, 2, 2)), (61, (1, 5, 5)), (19, (1, 3, 9))]) + + If visual=True then an annotated string will be returned: + + >>> print(smoothness_p(21477639576571, visual=1)) + p**i=4410317**1 has p-1 B=1787, B-pow=1787 + p**i=4869863**1 has p-1 B=2434931, B-pow=2434931 + + This string can also be generated directly from a factorization dictionary + and vice versa: + + >>> factorint(17*9) + {3: 2, 17: 1} + >>> smoothness_p(_) + 'p**i=3**2 has p-1 B=2, B-pow=2\\np**i=17**1 has p-1 B=2, B-pow=16' + >>> smoothness_p(_) + {3: 2, 17: 1} + + The table of the output logic is: + + ====== ====== ======= ======= + | Visual + ------ ---------------------- + Input True False other + ====== ====== ======= ======= + dict str tuple str + str str tuple dict + tuple str tuple str + n str tuple tuple + mul str tuple tuple + ====== ====== ======= ======= + + See Also + ======== + + factorint, smoothness + """ + + # visual must be True, False or other (stored as None) + if visual in (1, 0): + visual = bool(visual) + elif visual not in (True, False): + visual = None + + if isinstance(n, str): + if visual: + return n + d = {} + for li in n.splitlines(): + k, v = [int(i) for i in + li.split('has')[0].split('=')[1].split('**')] + d[k] = v + if visual is not True and visual is not False: + return d + return smoothness_p(d, visual=False) + elif not isinstance(n, tuple): + facs = factorint(n, visual=False) + + if power: + k = -1 + else: + k = 1 + if isinstance(n, tuple): + rv = n + else: + rv = (m, sorted([(f, + tuple([M] + list(smoothness(f + m)))) + for f, M in list(facs.items())], + key=lambda x: (x[1][k], x[0]))) + + if visual is False or (visual is not True) and (type(n) in [int, Mul]): + return rv + lines = [] + for dat in rv[1]: + dat = flatten(dat) + dat.insert(2, m) + lines.append('p**i=%i**%i has p%+i B=%i, B-pow=%i' % tuple(dat)) + return '\n'.join(lines) + + +def multiplicity(p, n): + """ + Find the greatest integer m such that p**m divides n. + + Examples + ======== + + >>> from sympy import multiplicity, Rational + >>> [multiplicity(5, n) for n in [8, 5, 25, 125, 250]] + [0, 1, 2, 3, 3] + >>> multiplicity(3, Rational(1, 9)) + -2 + + Note: when checking for the multiplicity of a number in a + large factorial it is most efficient to send it as an unevaluated + factorial or to call ``multiplicity_in_factorial`` directly: + + >>> from sympy.ntheory import multiplicity_in_factorial + >>> from sympy import factorial + >>> p = factorial(25) + >>> n = 2**100 + >>> nfac = factorial(n, evaluate=False) + >>> multiplicity(p, nfac) + 52818775009509558395695966887 + >>> _ == multiplicity_in_factorial(p, n) + True + + See Also + ======== + + trailing + + """ + try: + p, n = as_int(p), as_int(n) + except ValueError: + from sympy.functions.combinatorial.factorials import factorial + if all(isinstance(i, (SYMPY_INTS, Rational)) for i in (p, n)): + p = Rational(p) + n = Rational(n) + if p.q == 1: + if n.p == 1: + return -multiplicity(p.p, n.q) + return multiplicity(p.p, n.p) - multiplicity(p.p, n.q) + elif p.p == 1: + return multiplicity(p.q, n.q) + else: + like = min( + multiplicity(p.p, n.p), + multiplicity(p.q, n.q)) + cross = min( + multiplicity(p.q, n.p), + multiplicity(p.p, n.q)) + return like - cross + elif (isinstance(p, (SYMPY_INTS, Integer)) and + isinstance(n, factorial) and + isinstance(n.args[0], Integer) and + n.args[0] >= 0): + return multiplicity_in_factorial(p, n.args[0]) + raise ValueError('expecting ints or fractions, got %s and %s' % (p, n)) + + if n == 0: + raise ValueError('no such integer exists: multiplicity of %s is not-defined' %(n)) + return remove(n, p)[1] + + +def multiplicity_in_factorial(p, n): + """return the largest integer ``m`` such that ``p**m`` divides ``n!`` + without calculating the factorial of ``n``. + + Parameters + ========== + + p : Integer + positive integer + n : Integer + non-negative integer + + Examples + ======== + + >>> from sympy.ntheory import multiplicity_in_factorial + >>> from sympy import factorial + + >>> multiplicity_in_factorial(2, 3) + 1 + + An instructive use of this is to tell how many trailing zeros + a given factorial has. For example, there are 6 in 25!: + + >>> factorial(25) + 15511210043330985984000000 + >>> multiplicity_in_factorial(10, 25) + 6 + + For large factorials, it is much faster/feasible to use + this function rather than computing the actual factorial: + + >>> multiplicity_in_factorial(factorial(25), 2**100) + 52818775009509558395695966887 + + See Also + ======== + + multiplicity + + """ + + p, n = as_int(p), as_int(n) + + if p <= 0: + raise ValueError('expecting positive integer got %s' % p ) + + if n < 0: + raise ValueError('expecting non-negative integer got %s' % n ) + + # keep only the largest of a given multiplicity since those + # of a given multiplicity will be goverened by the behavior + # of the largest factor + f = defaultdict(int) + for k, v in factorint(p).items(): + f[v] = max(k, f[v]) + # multiplicity of p in n! depends on multiplicity + # of prime `k` in p, so we floor divide by `v` + # and keep it if smaller than the multiplicity of p + # seen so far + return min((n + k - sum(digits(n, k)))//(k - 1)//v for v, k in f.items()) + + +def _perfect_power(n, next_p=2): + """ Return integers ``(b, e)`` such that ``n == b**e`` if ``n`` is a unique + perfect power with ``e > 1``, else ``False`` (e.g. 1 is not a perfect power). + + Explanation + =========== + + This is a low-level helper for ``perfect_power``, for internal use. + + Parameters + ========== + + n : int + assume that n is a nonnegative integer + next_p : int + Assume that n has no factor less than next_p. + i.e., all(n % p for p in range(2, next_p)) is True + + Examples + ======== + >>> from sympy.ntheory.factor_ import _perfect_power + >>> _perfect_power(16) + (2, 4) + >>> _perfect_power(17) + False + + """ + if n <= 3: + return False + + factors = {} + g = 0 + multi = 1 + + def done(n, factors, g, multi): + g = gcd(g, multi) + if g == 1: + return False + factors[n] = multi + return math.prod(p**(e//g) for p, e in factors.items()), g + + # If n is small, only trial factoring is faster + if n <= 1_000_000: + n = _factorint_small(factors, n, 1_000, 1_000, next_p)[0] + if n > 1: + return False + g = gcd(*factors.values()) + if g == 1: + return False + return math.prod(p**(e//g) for p, e in factors.items()), g + + # divide by 2 + if next_p < 3: + g = bit_scan1(n) + if g: + if g == 1: + return False + n >>= g + factors[2] = g + if n == 1: + return 2, g + else: + # If `m**g`, then we have found perfect power. + # Otherwise, there is no possibility of perfect power, especially if `g` is prime. + m, _exact = iroot(n, g) + if _exact: + return 2*m, g + elif isprime(g): + return False + next_p = 3 + + # square number? + while n & 7 == 1: # n % 8 == 1: + m, _exact = iroot(n, 2) + if _exact: + n = m + multi <<= 1 + else: + break + if n < next_p**3: + return done(n, factors, g, multi) + + # trial factoring + # Since the maximum value an exponent can take is `log_{next_p}(n)`, + # the number of exponents to be checked can be reduced by performing a trial factoring. + # The value of `tf_max` needs more consideration. + tf_max = n.bit_length()//27 + 24 + if next_p < tf_max: + for p in primerange(next_p, tf_max): + m, t = remove(n, p) + if t: + n = m + t *= multi + _g = gcd(g, t) + if _g == 1: + return False + factors[p] = t + if n == 1: + return math.prod(p**(e//_g) + for p, e in factors.items()), _g + elif g == 0 or _g < g: # If g is updated + g = _g + m, _exact = iroot(n**multi, g) + if _exact: + return m * math.prod(p**(e//g) + for p, e in factors.items()), g + elif isprime(g): + return False + next_p = tf_max + if n < next_p**3: + return done(n, factors, g, multi) + + # check iroot + if g: + # If g is non-zero, the exponent is a divisor of g. + # 2 can be omitted since it has already been checked. + prime_iter = sorted(factorint(g >> bit_scan1(g)).keys()) + else: + # The maximum possible value of the exponent is `log_{next_p}(n)`. + # To compensate for the presence of computational error, 2 is added. + prime_iter = primerange(3, int(math.log(n, next_p)) + 2) + logn = math.log2(n) + threshold = logn / 40 # Threshold for direct calculation + for p in prime_iter: + if threshold < p: + # If p is large, find the power root p directly without `iroot`. + while True: + b = pow(2, logn / p) + rb = int(b + 0.5) + if abs(rb - b) < 0.01 and rb**p == n: + n = rb + multi *= p + logn = math.log2(n) + else: + break + else: + while True: + m, _exact = iroot(n, p) + if _exact: + n = m + multi *= p + logn = math.log2(n) + else: + break + if n < next_p**(p + 2): + break + return done(n, factors, g, multi) + + +def perfect_power(n, candidates=None, big=True, factor=True): + """ + Return ``(b, e)`` such that ``n`` == ``b**e`` if ``n`` is a unique + perfect power with ``e > 1``, else ``False`` (e.g. 1 is not a + perfect power). A ValueError is raised if ``n`` is not Rational. + + By default, the base is recursively decomposed and the exponents + collected so the largest possible ``e`` is sought. If ``big=False`` + then the smallest possible ``e`` (thus prime) will be chosen. + + If ``factor=True`` then simultaneous factorization of ``n`` is + attempted since finding a factor indicates the only possible root + for ``n``. This is True by default since only a few small factors will + be tested in the course of searching for the perfect power. + + The use of ``candidates`` is primarily for internal use; if provided, + False will be returned if ``n`` cannot be written as a power with one + of the candidates as an exponent and factoring (beyond testing for + a factor of 2) will not be attempted. + + Examples + ======== + + >>> from sympy import perfect_power, Rational + >>> perfect_power(16) + (2, 4) + >>> perfect_power(16, big=False) + (4, 2) + + Negative numbers can only have odd perfect powers: + + >>> perfect_power(-4) + False + >>> perfect_power(-8) + (-2, 3) + + Rationals are also recognized: + + >>> perfect_power(Rational(1, 2)**3) + (1/2, 3) + >>> perfect_power(Rational(-3, 2)**3) + (-3/2, 3) + + Notes + ===== + + To know whether an integer is a perfect power of 2 use + + >>> is2pow = lambda n: bool(n and not n & (n - 1)) + >>> [(i, is2pow(i)) for i in range(5)] + [(0, False), (1, True), (2, True), (3, False), (4, True)] + + It is not necessary to provide ``candidates``. When provided + it will be assumed that they are ints. The first one that is + larger than the computed maximum possible exponent will signal + failure for the routine. + + >>> perfect_power(3**8, [9]) + False + >>> perfect_power(3**8, [2, 4, 8]) + (3, 8) + >>> perfect_power(3**8, [4, 8], big=False) + (9, 4) + + See Also + ======== + sympy.core.intfunc.integer_nthroot + sympy.ntheory.primetest.is_square + """ + # negative handling + if n < 0: + if candidates is None: + pp = perfect_power(-n, big=True, factor=factor) + if not pp: + return False + + b, e = pp + e2 = e & (-e) + b, e = b ** e2, e // e2 + + if e <= 1: + return False + + if big or isprime(e): + return -b, e + + for p in primerange(3, e + 1): + if e % p == 0: + return - b ** (e // p), p + + odd_candidates = {i for i in candidates if i % 2} + if not odd_candidates: + return False + + pp = perfect_power(-n, odd_candidates, big, factor) + if pp: + return -pp[0], pp[1] + + return False + + # non-integer handling + if isinstance(n, Rational) and not isinstance(n, Integer): + p, q = n.p, n.q + + if p == 1: + qq = perfect_power(q, candidates, big, factor) + return (S.One / qq[0], qq[1]) if qq is not False else False + + if not (pp:=perfect_power(p, factor=factor)): + return False + if not (qq:=perfect_power(q, factor=factor)): + return False + (num_base, num_exp), (den_base, den_exp) = pp, qq + + def compute_tuple(exponent): + """Helper to compute final result given an exponent""" + new_num = num_base ** (num_exp // exponent) + new_den = den_base ** (den_exp // exponent) + return n.func(new_num, new_den), exponent + + if candidates: + valid_candidates = [i for i in candidates + if num_exp % i == 0 and den_exp % i == 0] + if not valid_candidates: + return False + + e = max(valid_candidates) if big else min(valid_candidates) + return compute_tuple(e) + + g = math.gcd(num_exp, den_exp) + if g == 1: + return False + + if big: + return compute_tuple(g) + + e = next(p for p in primerange(2, g + 1) if g % p == 0) + return compute_tuple(e) + + if candidates is not None: + candidates = set(candidates) + + # positive integer handling + n = as_int(n) + + if candidates is None and big: + return _perfect_power(n) + + if n <= 3: + # no unique exponent for 0, 1 + # 2 and 3 have exponents of 1 + return False + logn = math.log2(n) + max_possible = int(logn) + 2 # only check values less than this + not_square = n % 10 in [2, 3, 7, 8] # squares cannot end in 2, 3, 7, 8 + min_possible = 2 + not_square + if not candidates: + candidates = primerange(min_possible, max_possible) + else: + candidates = sorted([i for i in candidates + if min_possible <= i < max_possible]) + if n%2 == 0: + e = bit_scan1(n) + candidates = [i for i in candidates if e%i == 0] + if big: + candidates = reversed(candidates) + for e in candidates: + r, ok = iroot(n, e) + if ok: + return int(r), e + return False + + def _factors(): + rv = 2 + n % 2 + while True: + yield rv + rv = nextprime(rv) + + for fac, e in zip(_factors(), candidates): + # see if there is a factor present + if factor and n % fac == 0: + # find what the potential power is + e = remove(n, fac)[1] + # if it's a trivial power we are done + if e == 1: + return False + + # maybe the e-th root of n is exact + r, exact = iroot(n, e) + if not exact: + # Having a factor, we know that e is the maximal + # possible value for a root of n. + # If n = fac**e*m can be written as a perfect + # power then see if m can be written as r**E where + # gcd(e, E) != 1 so n = (fac**(e//E)*r)**E + m = n//fac**e + rE = perfect_power(m, candidates=divisors(e, generator=True)) + if not rE: + return False + else: + r, E = rE + r, e = fac**(e//E)*r, E + if not big: + e0 = primefactors(e) + if e0[0] != e: + r, e = r**(e//e0[0]), e0[0] + return int(r), e + + # Weed out downright impossible candidates + if logn/e < 40: + b = 2.0**(logn/e) + if abs(int(b + 0.5) - b) > 0.01: + continue + + # now see if the plausible e makes a perfect power + r, exact = iroot(n, e) + if exact: + if big: + m = perfect_power(r, big=big, factor=factor) + if m: + r, e = m[0], e*m[1] + return int(r), e + + return False + + +class FactorCache(MutableMapping): + """ Provides a cache for prime factors. + ``factor_cache`` is pre-prepared as an instance of ``FactorCache``, + and ``factorint`` internally references it to speed up + the factorization of prime factors. + + While cache is automatically added during the execution of ``factorint``, + users can also manually add prime factors independently. + + >>> from sympy import factor_cache + >>> factor_cache[15] = 5 + + Furthermore, by customizing ``get_external``, + it is also possible to use external databases. + The following is an example using http://factordb.com . + + .. code-block:: python + + import requests + from sympy import factor_cache + + def get_external(self, n: int) -> list[int] | None: + res = requests.get("http://factordb.com/api", params={"query": str(n)}) + if res.status_code != requests.codes.ok: + return None + j = res.json() + if j.get("status") in ["FF", "P"]: + return list(int(p) for p, _ in j.get("factors")) + + factor_cache.get_external = get_external + + Be aware that writing this code will trigger internet access + to factordb.com when calling ``factorint``. + + """ + def __init__(self, maxsize: int | None = None): + self._cache: OrderedDict[int, int] = OrderedDict() + self.maxsize = maxsize + + def __len__(self) -> int: + return len(self._cache) + + def __contains__(self, n) -> bool: + return n in self._cache + + def __getitem__(self, n: int) -> int: + factor = self.get(n) + if factor is None: + raise KeyError(f"{n} does not exist.") + return factor + + def __setitem__(self, n: int, factor: int): + if not (1 < factor <= n and n % factor == 0 and isprime(factor)): + raise ValueError(f"{factor} is not a prime factor of {n}") + self._cache[n] = max(self._cache.get(n, 0), factor) + if self.maxsize is not None and len(self._cache) > self.maxsize: + self._cache.popitem(False) + + def __delitem__(self, n: int): + if n not in self._cache: + raise KeyError(f"{n} does not exist.") + del self._cache[n] + + def __iter__(self): + return self._cache.__iter__() + + def cache_clear(self) -> None: + """ Clear the cache """ + self._cache = OrderedDict() + + @property + def maxsize(self) -> int | None: + """ Returns the maximum cache size; if ``None``, it is unlimited. """ + return self._maxsize + + @maxsize.setter + def maxsize(self, value: int | None) -> None: + if value is not None and value <= 0: + raise ValueError("maxsize must be None or a non-negative integer.") + self._maxsize = value + if value is not None: + while len(self._cache) > value: + self._cache.popitem(False) + + def get(self, n: int, default=None): + """ Return the prime factor of ``n``. + If it does not exist in the cache, return the value of ``default``. + """ + if n <= sieve._list[-1]: + if sieve._list[bisect_left(sieve._list, n)] == n: + return n + if n in self._cache: + self._cache.move_to_end(n) + return self._cache[n] + if factors := self.get_external(n): + self.add(n, factors) + return self._cache[n] + return default + + def add(self, n: int, factors: list[int]) -> None: + for p in sorted(factors, reverse=True): + self[n] = p + n, _ = remove(n, p) + + def get_external(self, n: int) -> list[int] | None: + return None + + +factor_cache = FactorCache(maxsize=1000) + + +def pollard_rho(n, s=2, a=1, retries=5, seed=1234, max_steps=None, F=None): + r""" + Use Pollard's rho method to try to extract a nontrivial factor + of ``n``. The returned factor may be a composite number. If no + factor is found, ``None`` is returned. + + The algorithm generates pseudo-random values of x with a generator + function, replacing x with F(x). If F is not supplied then the + function x**2 + ``a`` is used. The first value supplied to F(x) is ``s``. + Upon failure (if ``retries`` is > 0) a new ``a`` and ``s`` will be + supplied; the ``a`` will be ignored if F was supplied. + + The sequence of numbers generated by such functions generally have a + a lead-up to some number and then loop around back to that number and + begin to repeat the sequence, e.g. 1, 2, 3, 4, 5, 3, 4, 5 -- this leader + and loop look a bit like the Greek letter rho, and thus the name, 'rho'. + + For a given function, very different leader-loop values can be obtained + so it is a good idea to allow for retries: + + >>> from sympy.ntheory.generate import cycle_length + >>> n = 16843009 + >>> F = lambda x:(2048*pow(x, 2, n) + 32767) % n + >>> for s in range(5): + ... print('loop length = %4i; leader length = %3i' % next(cycle_length(F, s))) + ... + loop length = 2489; leader length = 43 + loop length = 78; leader length = 121 + loop length = 1482; leader length = 100 + loop length = 1482; leader length = 286 + loop length = 1482; leader length = 101 + + Here is an explicit example where there is a three element leadup to + a sequence of 3 numbers (11, 14, 4) that then repeat: + + >>> x=2 + >>> for i in range(9): + ... print(x) + ... x=(x**2+12)%17 + ... + 2 + 16 + 13 + 11 + 14 + 4 + 11 + 14 + 4 + >>> next(cycle_length(lambda x: (x**2+12)%17, 2)) + (3, 3) + >>> list(cycle_length(lambda x: (x**2+12)%17, 2, values=True)) + [2, 16, 13, 11, 14, 4] + + Instead of checking the differences of all generated values for a gcd + with n, only the kth and 2*kth numbers are checked, e.g. 1st and 2nd, + 2nd and 4th, 3rd and 6th until it has been detected that the loop has been + traversed. Loops may be many thousands of steps long before rho finds a + factor or reports failure. If ``max_steps`` is specified, the iteration + is cancelled with a failure after the specified number of steps. + + Examples + ======== + + >>> from sympy import pollard_rho + >>> n=16843009 + >>> F=lambda x:(2048*pow(x,2,n) + 32767) % n + >>> pollard_rho(n, F=F) + 257 + + Use the default setting with a bad value of ``a`` and no retries: + + >>> pollard_rho(n, a=n-2, retries=0) + + If retries is > 0 then perhaps the problem will correct itself when + new values are generated for a: + + >>> pollard_rho(n, a=n-2, retries=1) + 257 + + References + ========== + + .. [1] Richard Crandall & Carl Pomerance (2005), "Prime Numbers: + A Computational Perspective", Springer, 2nd edition, 229-231 + + """ + n = int(n) + if n < 5: + raise ValueError('pollard_rho should receive n > 4') + randint = _randint(seed + retries) + V = s + for i in range(retries + 1): + U = V + if not F: + F = lambda x: (pow(x, 2, n) + a) % n + j = 0 + while 1: + if max_steps and (j > max_steps): + break + j += 1 + U = F(U) + V = F(F(V)) # V is 2x further along than U + g = gcd(U - V, n) + if g == 1: + continue + if g == n: + break + return int(g) + V = randint(0, n - 1) + a = randint(1, n - 3) # for x**2 + a, a%n should not be 0 or -2 + F = None + return None + + +def pollard_pm1(n, B=10, a=2, retries=0, seed=1234): + """ + Use Pollard's p-1 method to try to extract a nontrivial factor + of ``n``. Either a divisor (perhaps composite) or ``None`` is returned. + + The value of ``a`` is the base that is used in the test gcd(a**M - 1, n). + The default is 2. If ``retries`` > 0 then if no factor is found after the + first attempt, a new ``a`` will be generated randomly (using the ``seed``) + and the process repeated. + + Note: the value of M is lcm(1..B) = reduce(ilcm, range(2, B + 1)). + + A search is made for factors next to even numbers having a power smoothness + less than ``B``. Choosing a larger B increases the likelihood of finding a + larger factor but takes longer. Whether a factor of n is found or not + depends on ``a`` and the power smoothness of the even number just less than + the factor p (hence the name p - 1). + + Although some discussion of what constitutes a good ``a`` some + descriptions are hard to interpret. At the modular.math site referenced + below it is stated that if gcd(a**M - 1, n) = N then a**M % q**r is 1 + for every prime power divisor of N. But consider the following: + + >>> from sympy.ntheory.factor_ import smoothness_p, pollard_pm1 + >>> n=257*1009 + >>> smoothness_p(n) + (-1, [(257, (1, 2, 256)), (1009, (1, 7, 16))]) + + So we should (and can) find a root with B=16: + + >>> pollard_pm1(n, B=16, a=3) + 1009 + + If we attempt to increase B to 256 we find that it does not work: + + >>> pollard_pm1(n, B=256) + >>> + + But if the value of ``a`` is changed we find that only multiples of + 257 work, e.g.: + + >>> pollard_pm1(n, B=256, a=257) + 1009 + + Checking different ``a`` values shows that all the ones that did not + work had a gcd value not equal to ``n`` but equal to one of the + factors: + + >>> from sympy import ilcm, igcd, factorint, Pow + >>> M = 1 + >>> for i in range(2, 256): + ... M = ilcm(M, i) + ... + >>> set([igcd(pow(a, M, n) - 1, n) for a in range(2, 256) if + ... igcd(pow(a, M, n) - 1, n) != n]) + {1009} + + But does aM % d for every divisor of n give 1? + + >>> aM = pow(255, M, n) + >>> [(d, aM%Pow(*d.args)) for d in factorint(n, visual=True).args] + [(257**1, 1), (1009**1, 1)] + + No, only one of them. So perhaps the principle is that a root will + be found for a given value of B provided that: + + 1) the power smoothness of the p - 1 value next to the root + does not exceed B + 2) a**M % p != 1 for any of the divisors of n. + + By trying more than one ``a`` it is possible that one of them + will yield a factor. + + Examples + ======== + + With the default smoothness bound, this number cannot be cracked: + + >>> from sympy.ntheory import pollard_pm1 + >>> pollard_pm1(21477639576571) + + Increasing the smoothness bound helps: + + >>> pollard_pm1(21477639576571, B=2000) + 4410317 + + Looking at the smoothness of the factors of this number we find: + + >>> from sympy.ntheory.factor_ import smoothness_p, factorint + >>> print(smoothness_p(21477639576571, visual=1)) + p**i=4410317**1 has p-1 B=1787, B-pow=1787 + p**i=4869863**1 has p-1 B=2434931, B-pow=2434931 + + The B and B-pow are the same for the p - 1 factorizations of the divisors + because those factorizations had a very large prime factor: + + >>> factorint(4410317 - 1) + {2: 2, 617: 1, 1787: 1} + >>> factorint(4869863-1) + {2: 1, 2434931: 1} + + Note that until B reaches the B-pow value of 1787, the number is not cracked; + + >>> pollard_pm1(21477639576571, B=1786) + >>> pollard_pm1(21477639576571, B=1787) + 4410317 + + The B value has to do with the factors of the number next to the divisor, + not the divisors themselves. A worst case scenario is that the number next + to the factor p has a large prime divisisor or is a perfect power. If these + conditions apply then the power-smoothness will be about p/2 or p. The more + realistic is that there will be a large prime factor next to p requiring + a B value on the order of p/2. Although primes may have been searched for + up to this level, the p/2 is a factor of p - 1, something that we do not + know. The modular.math reference below states that 15% of numbers in the + range of 10**15 to 15**15 + 10**4 are 10**6 power smooth so a B of 10**6 + will fail 85% of the time in that range. From 10**8 to 10**8 + 10**3 the + percentages are nearly reversed...but in that range the simple trial + division is quite fast. + + References + ========== + + .. [1] Richard Crandall & Carl Pomerance (2005), "Prime Numbers: + A Computational Perspective", Springer, 2nd edition, 236-238 + .. [2] https://web.archive.org/web/20150716201437/http://modular.math.washington.edu/edu/2007/spring/ent/ent-html/node81.html + .. [3] https://www.cs.toronto.edu/~yuvalf/Factorization.pdf + """ + + n = int(n) + if n < 4 or B < 3: + raise ValueError('pollard_pm1 should receive n > 3 and B > 2') + randint = _randint(seed + B) + + # computing a**lcm(1,2,3,..B) % n for B > 2 + # it looks weird, but it's right: primes run [2, B] + # and the answer's not right until the loop is done. + for i in range(retries + 1): + aM = a + for p in sieve.primerange(2, B + 1): + e = int(math.log(B, p)) + aM = pow(aM, pow(p, e), n) + g = gcd(aM - 1, n) + if 1 < g < n: + return int(g) + + # get a new a: + # since the exponent, lcm(1..B), is even, if we allow 'a' to be 'n-1' + # then (n - 1)**even % n will be 1 which will give a g of 0 and 1 will + # give a zero, too, so we set the range as [2, n-2]. Some references + # say 'a' should be coprime to n, but either will detect factors. + a = randint(2, n - 2) + + +def _trial(factors, n, candidates, verbose=False): + """ + Helper function for integer factorization. Trial factors ``n` + against all integers given in the sequence ``candidates`` + and updates the dict ``factors`` in-place. Returns the reduced + value of ``n`` and a flag indicating whether any factors were found. + """ + if verbose: + factors0 = list(factors.keys()) + nfactors = len(factors) + for d in candidates: + if n % d == 0: + if n != d: + factor_cache[n] = d + n, m = remove(n // d, d) + factors[d] = m + 1 + if verbose: + for k in sorted(set(factors).difference(set(factors0))): + print(factor_msg % (k, factors[k])) + return int(n), len(factors) != nfactors + + +def _check_termination(factors, n, limit, use_trial, use_rho, use_pm1, + verbose, next_p): + """ + Helper function for integer factorization. Checks if ``n`` + is a prime or a perfect power, and in those cases updates the factorization. + """ + if verbose: + print('Check for termination') + if n == 1: + if verbose: + print(complete_msg) + return True + if n < next_p**2 or isprime(n): + factor_cache[n] = n + factors[int(n)] = 1 + if verbose: + print(complete_msg) + return True + + # since we've already been factoring there is no need to do + # simultaneous factoring with the power check + p = _perfect_power(n, next_p) + if not p: + return False + base, exp = p + if base < next_p**2 or isprime(base): + factor_cache[n] = base + factors[base] = exp + else: + facs = factorint(base, limit, use_trial, use_rho, use_pm1, + verbose=False) + for b, e in facs.items(): + if verbose: + print(factor_msg % (b, e)) + factors[b] = exp*e + if verbose: + print(complete_msg) + return True + + +trial_int_msg = "Trial division with ints [%i ... %i] and fail_max=%i" +trial_msg = "Trial division with primes [%i ... %i]" +rho_msg = "Pollard's rho with retries %i, max_steps %i and seed %i" +pm1_msg = "Pollard's p-1 with smoothness bound %i and seed %i" +ecm_msg = "Elliptic Curve with B1 bound %i, B2 bound %i, num_curves %i" +factor_msg = '\t%i ** %i' +fermat_msg = 'Close factors satisfying Fermat condition found.' +complete_msg = 'Factorization is complete.' + + +def _factorint_small(factors, n, limit, fail_max, next_p=2): + """ + Return the value of n and either a 0 (indicating that factorization up + to the limit was complete) or else the next near-prime that would have + been tested. + + Factoring stops if there are fail_max unsuccessful tests in a row. + + If factors of n were found they will be in the factors dictionary as + {factor: multiplicity} and the returned value of n will have had those + factors removed. The factors dictionary is modified in-place. + + """ + + def done(n, d): + """return n, d if the sqrt(n) was not reached yet, else + n, 0 indicating that factoring is done. + """ + if d*d <= n: + return n, d + return n, 0 + + limit2 = limit**2 + threshold2 = min(n, limit2) + + if next_p < 3: + if not n & 1: + m = bit_scan1(n) + factors[2] = m + n >>= m + threshold2 = min(n, limit2) + next_p = 3 + if threshold2 < 9: # next_p**2 = 9 + return done(n, next_p) + + if next_p < 5: + if not n % 3: + n //= 3 + m = 1 + while not n % 3: + n //= 3 + m += 1 + if m == 20: + n, mm = remove(n, 3) + m += mm + break + factors[3] = m + threshold2 = min(n, limit2) + next_p = 5 + if threshold2 < 25: # next_p**2 = 25 + return done(n, next_p) + + # Because of the order of checks, starting from `min_p = 6k+5`, + # useless checks are caused. + # We want to calculate + # next_p += [-1, -2, 3, 2, 1, 0][next_p % 6] + p6 = next_p % 6 + next_p += (-1 if p6 < 2 else 5) - p6 + + fails = 0 + while fails < fail_max: + # next_p % 6 == 5 + if n % next_p: + fails += 1 + else: + n //= next_p + m = 1 + while not n % next_p: + n //= next_p + m += 1 + if m == 20: + n, mm = remove(n, next_p) + m += mm + break + factors[next_p] = m + fails = 0 + threshold2 = min(n, limit2) + next_p += 2 + if threshold2 < next_p**2: + return done(n, next_p) + + # next_p % 6 == 1 + if n % next_p: + fails += 1 + else: + n //= next_p + m = 1 + while not n % next_p: + n //= next_p + m += 1 + if m == 20: + n, mm = remove(n, next_p) + m += mm + break + factors[next_p] = m + fails = 0 + threshold2 = min(n, limit2) + next_p += 4 + if threshold2 < next_p**2: + return done(n, next_p) + return done(n, next_p) + + +def factorint(n, limit=None, use_trial=True, use_rho=True, use_pm1=True, + use_ecm=True, verbose=False, visual=None, multiple=False): + r""" + Given a positive integer ``n``, ``factorint(n)`` returns a dict containing + the prime factors of ``n`` as keys and their respective multiplicities + as values. For example: + + >>> from sympy.ntheory import factorint + >>> factorint(2000) # 2000 = (2**4) * (5**3) + {2: 4, 5: 3} + >>> factorint(65537) # This number is prime + {65537: 1} + + For input less than 2, factorint behaves as follows: + + - ``factorint(1)`` returns the empty factorization, ``{}`` + - ``factorint(0)`` returns ``{0:1}`` + - ``factorint(-n)`` adds ``-1:1`` to the factors and then factors ``n`` + + Partial Factorization: + + If ``limit`` (> 3) is specified, the search is stopped after performing + trial division up to (and including) the limit (or taking a + corresponding number of rho/p-1 steps). This is useful if one has + a large number and only is interested in finding small factors (if + any). Note that setting a limit does not prevent larger factors + from being found early; it simply means that the largest factor may + be composite. Since checking for perfect power is relatively cheap, it is + done regardless of the limit setting. + + This number, for example, has two small factors and a huge + semi-prime factor that cannot be reduced easily: + + >>> from sympy.ntheory import isprime + >>> a = 1407633717262338957430697921446883 + >>> f = factorint(a, limit=10000) + >>> f == {991: 1, int(202916782076162456022877024859): 1, 7: 1} + True + >>> isprime(max(f)) + False + + This number has a small factor and a residual perfect power whose + base is greater than the limit: + + >>> factorint(3*101**7, limit=5) + {3: 1, 101: 7} + + List of Factors: + + If ``multiple`` is set to ``True`` then a list containing the + prime factors including multiplicities is returned. + + >>> factorint(24, multiple=True) + [2, 2, 2, 3] + + Visual Factorization: + + If ``visual`` is set to ``True``, then it will return a visual + factorization of the integer. For example: + + >>> from sympy import pprint + >>> pprint(factorint(4200, visual=True)) + 3 1 2 1 + 2 *3 *5 *7 + + Note that this is achieved by using the evaluate=False flag in Mul + and Pow. If you do other manipulations with an expression where + evaluate=False, it may evaluate. Therefore, you should use the + visual option only for visualization, and use the normal dictionary + returned by visual=False if you want to perform operations on the + factors. + + You can easily switch between the two forms by sending them back to + factorint: + + >>> from sympy import Mul + >>> regular = factorint(1764); regular + {2: 2, 3: 2, 7: 2} + >>> pprint(factorint(regular)) + 2 2 2 + 2 *3 *7 + + >>> visual = factorint(1764, visual=True); pprint(visual) + 2 2 2 + 2 *3 *7 + >>> print(factorint(visual)) + {2: 2, 3: 2, 7: 2} + + If you want to send a number to be factored in a partially factored form + you can do so with a dictionary or unevaluated expression: + + >>> factorint(factorint({4: 2, 12: 3})) # twice to toggle to dict form + {2: 10, 3: 3} + >>> factorint(Mul(4, 12, evaluate=False)) + {2: 4, 3: 1} + + The table of the output logic is: + + ====== ====== ======= ======= + Visual + ------ ---------------------- + Input True False other + ====== ====== ======= ======= + dict mul dict mul + n mul dict dict + mul mul dict dict + ====== ====== ======= ======= + + Notes + ===== + + Algorithm: + + The function switches between multiple algorithms. Trial division + quickly finds small factors (of the order 1-5 digits), and finds + all large factors if given enough time. The Pollard rho and p-1 + algorithms are used to find large factors ahead of time; they + will often find factors of the order of 10 digits within a few + seconds: + + >>> factors = factorint(12345678910111213141516) + >>> for base, exp in sorted(factors.items()): + ... print('%s %s' % (base, exp)) + ... + 2 2 + 2507191691 1 + 1231026625769 1 + + Any of these methods can optionally be disabled with the following + boolean parameters: + + - ``use_trial``: Toggle use of trial division + - ``use_rho``: Toggle use of Pollard's rho method + - ``use_pm1``: Toggle use of Pollard's p-1 method + + ``factorint`` also periodically checks if the remaining part is + a prime number or a perfect power, and in those cases stops. + + For unevaluated factorial, it uses Legendre's formula(theorem). + + + If ``verbose`` is set to ``True``, detailed progress is printed. + + See Also + ======== + + smoothness, smoothness_p, divisors + + """ + if isinstance(n, Dict): + n = dict(n) + if multiple: + fac = factorint(n, limit=limit, use_trial=use_trial, + use_rho=use_rho, use_pm1=use_pm1, + verbose=verbose, visual=False, multiple=False) + factorlist = sum(([p] * fac[p] if fac[p] > 0 else [S.One/p]*(-fac[p]) + for p in sorted(fac)), []) + return factorlist + + factordict = {} + if visual and not isinstance(n, (Mul, dict)): + factordict = factorint(n, limit=limit, use_trial=use_trial, + use_rho=use_rho, use_pm1=use_pm1, + verbose=verbose, visual=False) + elif isinstance(n, Mul): + factordict = {int(k): int(v) for k, v in + n.as_powers_dict().items()} + elif isinstance(n, dict): + factordict = n + if factordict and isinstance(n, (Mul, dict)): + # check it + for key in list(factordict.keys()): + if isprime(key): + continue + e = factordict.pop(key) + d = factorint(key, limit=limit, use_trial=use_trial, use_rho=use_rho, + use_pm1=use_pm1, verbose=verbose, visual=False) + for k, v in d.items(): + if k in factordict: + factordict[k] += v*e + else: + factordict[k] = v*e + if visual or (type(n) is dict and + visual is not True and + visual is not False): + if factordict == {}: + return S.One + if -1 in factordict: + factordict.pop(-1) + args = [S.NegativeOne] + else: + args = [] + args.extend([Pow(*i, evaluate=False) + for i in sorted(factordict.items())]) + return Mul(*args, evaluate=False) + elif isinstance(n, (dict, Mul)): + return factordict + + assert use_trial or use_rho or use_pm1 or use_ecm + + from sympy.functions.combinatorial.factorials import factorial + if isinstance(n, factorial): + x = as_int(n.args[0]) + if x >= 20: + factors = {} + m = 2 # to initialize the if condition below + for p in sieve.primerange(2, x + 1): + if m > 1: + m, q = 0, x // p + while q != 0: + m += q + q //= p + factors[p] = m + if factors and verbose: + for k in sorted(factors): + print(factor_msg % (k, factors[k])) + if verbose: + print(complete_msg) + return factors + else: + # if n < 20!, direct computation is faster + # since it uses a lookup table + n = n.func(x) + + n = as_int(n) + if limit: + limit = int(limit) + use_ecm = False + + # special cases + if n < 0: + factors = factorint( + -n, limit=limit, use_trial=use_trial, use_rho=use_rho, + use_pm1=use_pm1, verbose=verbose, visual=False) + factors[-1] = 1 + return factors + + if limit and limit < 2: + if n == 1: + return {} + return {n: 1} + elif n < 10: + # doing this we are assured of getting a limit > 2 + # when we have to compute it later + return [{0: 1}, {}, {2: 1}, {3: 1}, {2: 2}, {5: 1}, + {2: 1, 3: 1}, {7: 1}, {2: 3}, {3: 2}][n] + + factors = {} + + # do simplistic factorization + if verbose: + sn = str(n) + if len(sn) > 50: + print('Factoring %s' % sn[:5] + \ + '..(%i other digits)..' % (len(sn) - 10) + sn[-5:]) + else: + print('Factoring', n) + + # this is the preliminary factorization for small factors + # We want to guarantee that there are no small prime factors, + # so we run even if `use_trial` is False. + small = 2**15 + fail_max = 600 + small = min(small, limit or small) + if verbose: + print(trial_int_msg % (2, small, fail_max)) + n, next_p = _factorint_small(factors, n, small, fail_max) + if factors and verbose: + for k in sorted(factors): + print(factor_msg % (k, factors[k])) + if next_p == 0: + if n > 1: + factors[int(n)] = 1 + if verbose: + print(complete_msg) + return factors + # Check if it exists in the cache + while p := factor_cache.get(n): + n, e = remove(n, p) + factors[int(p)] = int(e) + # first check if the simplistic run didn't finish + # because of the limit and check for a perfect + # power before exiting + if limit and next_p > limit: + if verbose: + print('Exceeded limit:', limit) + if _check_termination(factors, n, limit, use_trial, + use_rho, use_pm1, verbose, next_p): + return factors + if n > 1: + factors[int(n)] = 1 + return factors + if _check_termination(factors, n, limit, use_trial, + use_rho, use_pm1, verbose, next_p): + return factors + + # continue with more advanced factorization methods + # ...do a Fermat test since it's so easy and we need the + # square root anyway. Finding 2 factors is easy if they are + # "close enough." This is the big root equivalent of dividing by + # 2, 3, 5. + sqrt_n = isqrt(n) + a = sqrt_n + 1 + # If `n % 4 == 1`, `a` must be odd for `a**2 - n` to be a square number. + if (n % 4 == 1) ^ (a & 1): + a += 1 + a2 = a**2 + b2 = a2 - n + for _ in range(3): + b, fermat = sqrtrem(b2) + if not fermat: + if verbose: + print(fermat_msg) + for r in [a - b, a + b]: + facs = factorint(r, limit=limit, use_trial=use_trial, + use_rho=use_rho, use_pm1=use_pm1, + verbose=verbose) + for k, v in facs.items(): + factors[k] = factors.get(k, 0) + v + factor_cache.add(n, facs) + if verbose: + print(complete_msg) + return factors + b2 += (a + 1) << 2 # equiv to (a + 2)**2 - n + a += 2 + + # these are the limits for trial division which will + # be attempted in parallel with pollard methods + low, high = next_p, 2*next_p + + # add 1 to make sure limit is reached in primerange calls + _limit = (limit or sqrt_n) + 1 + iteration = 0 + while 1: + high_ = min(high, _limit) + + # Trial division + if use_trial: + if verbose: + print(trial_msg % (low, high_)) + ps = sieve.primerange(low, high_) + n, found_trial = _trial(factors, n, ps, verbose) + next_p = high_ + if found_trial and _check_termination(factors, n, limit, use_trial, + use_rho, use_pm1, verbose, next_p): + return factors + else: + found_trial = False + + if high > _limit: + if verbose: + print('Exceeded limit:', _limit) + if n > 1: + factors[int(n)] = 1 + if verbose: + print(complete_msg) + return factors + + # Only used advanced methods when no small factors were found + if not found_trial: + # Pollard p-1 + if use_pm1: + if verbose: + print(pm1_msg % (low, high_)) + c = pollard_pm1(n, B=low, seed=high_) + if c: + if c < next_p**2 or isprime(c): + ps = [c] + else: + ps = factorint(c, limit=limit, + use_trial=use_trial, + use_rho=use_rho, + use_pm1=use_pm1, + use_ecm=use_ecm, + verbose=verbose) + n, _ = _trial(factors, n, ps, verbose=False) + if _check_termination(factors, n, limit, use_trial, + use_rho, use_pm1, verbose, next_p): + return factors + + # Pollard rho + if use_rho: + if verbose: + print(rho_msg % (1, low, high_)) + c = pollard_rho(n, retries=1, max_steps=low, seed=high_) + if c: + if c < next_p**2 or isprime(c): + ps = [c] + else: + ps = factorint(c, limit=limit, + use_trial=use_trial, + use_rho=use_rho, + use_pm1=use_pm1, + use_ecm=use_ecm, + verbose=verbose) + n, _ = _trial(factors, n, ps, verbose=False) + if _check_termination(factors, n, limit, use_trial, + use_rho, use_pm1, verbose, next_p): + return factors + # Use subexponential algorithms if use_ecm + # Use pollard algorithms for finding small factors for 3 iterations + # if after small factors the number of digits of n >= 25 then use ecm + iteration += 1 + if use_ecm and iteration >= 3 and num_digits(n) >= 24: + break + low, high = high, high*2 + + B1 = 10000 + B2 = 100*B1 + num_curves = 50 + while(1): + if verbose: + print(ecm_msg % (B1, B2, num_curves)) + factor = _ecm_one_factor(n, B1, B2, num_curves, seed=B1) + if factor: + if factor < next_p**2 or isprime(factor): + ps = [factor] + else: + ps = factorint(factor, limit=limit, + use_trial=use_trial, + use_rho=use_rho, + use_pm1=use_pm1, + use_ecm=use_ecm, + verbose=verbose) + n, _ = _trial(factors, n, ps, verbose=False) + if _check_termination(factors, n, limit, use_trial, + use_rho, use_pm1, verbose, next_p): + return factors + B1 *= 5 + B2 = 100*B1 + num_curves *= 4 + + +def factorrat(rat, limit=None, use_trial=True, use_rho=True, use_pm1=True, + verbose=False, visual=None, multiple=False): + r""" + Given a Rational ``r``, ``factorrat(r)`` returns a dict containing + the prime factors of ``r`` as keys and their respective multiplicities + as values. For example: + + >>> from sympy import factorrat, S + >>> factorrat(S(8)/9) # 8/9 = (2**3) * (3**-2) + {2: 3, 3: -2} + >>> factorrat(S(-1)/987) # -1/789 = -1 * (3**-1) * (7**-1) * (47**-1) + {-1: 1, 3: -1, 7: -1, 47: -1} + + Please see the docstring for ``factorint`` for detailed explanations + and examples of the following keywords: + + - ``limit``: Integer limit up to which trial division is done + - ``use_trial``: Toggle use of trial division + - ``use_rho``: Toggle use of Pollard's rho method + - ``use_pm1``: Toggle use of Pollard's p-1 method + - ``verbose``: Toggle detailed printing of progress + - ``multiple``: Toggle returning a list of factors or dict + - ``visual``: Toggle product form of output + """ + if multiple: + fac = factorrat(rat, limit=limit, use_trial=use_trial, + use_rho=use_rho, use_pm1=use_pm1, + verbose=verbose, visual=False, multiple=False) + factorlist = sum(([p] * fac[p] if fac[p] > 0 else [S.One/p]*(-fac[p]) + for p, _ in sorted(fac.items(), + key=lambda elem: elem[0] + if elem[1] > 0 + else 1/elem[0])), []) + return factorlist + + f = factorint(rat.p, limit=limit, use_trial=use_trial, + use_rho=use_rho, use_pm1=use_pm1, + verbose=verbose).copy() + f = defaultdict(int, f) + for p, e in factorint(rat.q, limit=limit, + use_trial=use_trial, + use_rho=use_rho, + use_pm1=use_pm1, + verbose=verbose).items(): + f[p] += -e + + if len(f) > 1 and 1 in f: + del f[1] + if not visual: + return dict(f) + else: + if -1 in f: + f.pop(-1) + args = [S.NegativeOne] + else: + args = [] + args.extend([Pow(*i, evaluate=False) + for i in sorted(f.items())]) + return Mul(*args, evaluate=False) + + +def primefactors(n, limit=None, verbose=False, **kwargs): + """Return a sorted list of n's prime factors, ignoring multiplicity + and any composite factor that remains if the limit was set too low + for complete factorization. Unlike factorint(), primefactors() does + not return -1 or 0. + + Parameters + ========== + + n : integer + limit, verbose, **kwargs : + Additional keyword arguments to be passed to ``factorint``. + Since ``kwargs`` is new in version 1.13, + ``limit`` and ``verbose`` are retained for compatibility purposes. + + Returns + ======= + + list(int) : List of prime numbers dividing ``n`` + + Examples + ======== + + >>> from sympy.ntheory import primefactors, factorint, isprime + >>> primefactors(6) + [2, 3] + >>> primefactors(-5) + [5] + + >>> sorted(factorint(123456).items()) + [(2, 6), (3, 1), (643, 1)] + >>> primefactors(123456) + [2, 3, 643] + + >>> sorted(factorint(10000000001, limit=200).items()) + [(101, 1), (99009901, 1)] + >>> isprime(99009901) + False + >>> primefactors(10000000001, limit=300) + [101] + + See Also + ======== + + factorint, divisors + + """ + n = int(n) + kwargs.update({"visual": None, "multiple": False, + "limit": limit, "verbose": verbose}) + factors = sorted(factorint(n=n, **kwargs).keys()) + # We want to calculate + # s = [f for f in factors if isprime(f)] + s = [f for f in factors[:-1:] if f not in [-1, 0, 1]] + if factors and isprime(factors[-1]): + s += [factors[-1]] + return s + + +def _divisors(n, proper=False): + """Helper function for divisors which generates the divisors. + + Parameters + ========== + + n : int + a nonnegative integer + proper: bool + If `True`, returns the generator that outputs only the proper divisor (i.e., excluding n). + + """ + if n <= 1: + if not proper and n: + yield 1 + return + + factordict = factorint(n) + ps = sorted(factordict.keys()) + + def rec_gen(n=0): + if n == len(ps): + yield 1 + else: + pows = [1] + for _ in range(factordict[ps[n]]): + pows.append(pows[-1] * ps[n]) + yield from (p * q for q in rec_gen(n + 1) for p in pows) + + if proper: + yield from (p for p in rec_gen() if p != n) + else: + yield from rec_gen() + + +def divisors(n, generator=False, proper=False): + r""" + Return all divisors of n sorted from 1..n by default. + If generator is ``True`` an unordered generator is returned. + + The number of divisors of n can be quite large if there are many + prime factors (counting repeated factors). If only the number of + factors is desired use divisor_count(n). + + Examples + ======== + + >>> from sympy import divisors, divisor_count + >>> divisors(24) + [1, 2, 3, 4, 6, 8, 12, 24] + >>> divisor_count(24) + 8 + + >>> list(divisors(120, generator=True)) + [1, 2, 4, 8, 3, 6, 12, 24, 5, 10, 20, 40, 15, 30, 60, 120] + + Notes + ===== + + This is a slightly modified version of Tim Peters referenced at: + https://stackoverflow.com/questions/1010381/python-factorization + + See Also + ======== + + primefactors, factorint, divisor_count + """ + rv = _divisors(as_int(abs(n)), proper) + return rv if generator else sorted(rv) + + +def divisor_count(n, modulus=1, proper=False): + """ + Return the number of divisors of ``n``. If ``modulus`` is not 1 then only + those that are divisible by ``modulus`` are counted. If ``proper`` is True + then the divisor of ``n`` will not be counted. + + Examples + ======== + + >>> from sympy import divisor_count + >>> divisor_count(6) + 4 + >>> divisor_count(6, 2) + 2 + >>> divisor_count(6, proper=True) + 3 + + See Also + ======== + + factorint, divisors, totient, proper_divisor_count + + """ + + if not modulus: + return 0 + elif modulus != 1: + n, r = divmod(n, modulus) + if r: + return 0 + if n == 0: + return 0 + n = Mul(*[v + 1 for k, v in factorint(n).items() if k > 1]) + if n and proper: + n -= 1 + return n + + +def proper_divisors(n, generator=False): + """ + Return all divisors of n except n, sorted by default. + If generator is ``True`` an unordered generator is returned. + + Examples + ======== + + >>> from sympy import proper_divisors, proper_divisor_count + >>> proper_divisors(24) + [1, 2, 3, 4, 6, 8, 12] + >>> proper_divisor_count(24) + 7 + >>> list(proper_divisors(120, generator=True)) + [1, 2, 4, 8, 3, 6, 12, 24, 5, 10, 20, 40, 15, 30, 60] + + See Also + ======== + + factorint, divisors, proper_divisor_count + + """ + return divisors(n, generator=generator, proper=True) + + +def proper_divisor_count(n, modulus=1): + """ + Return the number of proper divisors of ``n``. + + Examples + ======== + + >>> from sympy import proper_divisor_count + >>> proper_divisor_count(6) + 3 + >>> proper_divisor_count(6, modulus=2) + 1 + + See Also + ======== + + divisors, proper_divisors, divisor_count + + """ + return divisor_count(n, modulus=modulus, proper=True) + + +def _udivisors(n): + """Helper function for udivisors which generates the unitary divisors. + + Parameters + ========== + + n : int + a nonnegative integer + + """ + if n <= 1: + if n == 1: + yield 1 + return + + factorpows = [p**e for p, e in factorint(n).items()] + # We want to calculate + # yield from (math.prod(s) for s in powersets(factorpows)) + for i in range(2**len(factorpows)): + d = 1 + for k in range(i.bit_length()): + if i & 1: + d *= factorpows[k] + i >>= 1 + yield d + + +def udivisors(n, generator=False): + r""" + Return all unitary divisors of n sorted from 1..n by default. + If generator is ``True`` an unordered generator is returned. + + The number of unitary divisors of n can be quite large if there are many + prime factors. If only the number of unitary divisors is desired use + udivisor_count(n). + + Examples + ======== + + >>> from sympy.ntheory.factor_ import udivisors, udivisor_count + >>> udivisors(15) + [1, 3, 5, 15] + >>> udivisor_count(15) + 4 + + >>> sorted(udivisors(120, generator=True)) + [1, 3, 5, 8, 15, 24, 40, 120] + + See Also + ======== + + primefactors, factorint, divisors, divisor_count, udivisor_count + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Unitary_divisor + .. [2] https://mathworld.wolfram.com/UnitaryDivisor.html + + """ + rv = _udivisors(as_int(abs(n))) + return rv if generator else sorted(rv) + + +def udivisor_count(n): + """ + Return the number of unitary divisors of ``n``. + + Parameters + ========== + + n : integer + + Examples + ======== + + >>> from sympy.ntheory.factor_ import udivisor_count + >>> udivisor_count(120) + 8 + + See Also + ======== + + factorint, divisors, udivisors, divisor_count, totient + + References + ========== + + .. [1] https://mathworld.wolfram.com/UnitaryDivisorFunction.html + + """ + + if n == 0: + return 0 + return 2**len([p for p in factorint(n) if p > 1]) + + +def _antidivisors(n): + """Helper function for antidivisors which generates the antidivisors. + + Parameters + ========== + + n : int + a nonnegative integer + + """ + if n <= 2: + return + for d in _divisors(n): + y = 2*d + if n > y and n % y: + yield y + for d in _divisors(2*n-1): + if n > d >= 2 and n % d: + yield d + for d in _divisors(2*n+1): + if n > d >= 2 and n % d: + yield d + + +def antidivisors(n, generator=False): + r""" + Return all antidivisors of n sorted from 1..n by default. + + Antidivisors [1]_ of n are numbers that do not divide n by the largest + possible margin. If generator is True an unordered generator is returned. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import antidivisors + >>> antidivisors(24) + [7, 16] + + >>> sorted(antidivisors(128, generator=True)) + [3, 5, 15, 17, 51, 85] + + See Also + ======== + + primefactors, factorint, divisors, divisor_count, antidivisor_count + + References + ========== + + .. [1] definition is described in https://oeis.org/A066272/a066272a.html + + """ + rv = _antidivisors(as_int(abs(n))) + return rv if generator else sorted(rv) + + +def antidivisor_count(n): + """ + Return the number of antidivisors [1]_ of ``n``. + + Parameters + ========== + + n : integer + + Examples + ======== + + >>> from sympy.ntheory.factor_ import antidivisor_count + >>> antidivisor_count(13) + 4 + >>> antidivisor_count(27) + 5 + + See Also + ======== + + factorint, divisors, antidivisors, divisor_count, totient + + References + ========== + + .. [1] formula from https://oeis.org/A066272 + + """ + + n = as_int(abs(n)) + if n <= 2: + return 0 + return divisor_count(2*n - 1) + divisor_count(2*n + 1) + \ + divisor_count(n) - divisor_count(n, 2) - 5 + +@deprecated("""\ +The `sympy.ntheory.factor_.totient` has been moved to `sympy.functions.combinatorial.numbers.totient`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def totient(n): + r""" + Calculate the Euler totient function phi(n) + + .. deprecated:: 1.13 + + The ``totient`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.totient` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + ``totient(n)`` or `\phi(n)` is the number of positive integers `\leq` n + that are relatively prime to n. + + Parameters + ========== + + n : integer + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import totient + >>> totient(1) + 1 + >>> totient(25) + 20 + >>> totient(45) == totient(5)*totient(9) + True + + See Also + ======== + + divisor_count + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Euler%27s_totient_function + .. [2] https://mathworld.wolfram.com/TotientFunction.html + + """ + from sympy.functions.combinatorial.numbers import totient as _totient + return _totient(n) + + +@deprecated("""\ +The `sympy.ntheory.factor_.reduced_totient` has been moved to `sympy.functions.combinatorial.numbers.reduced_totient`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def reduced_totient(n): + r""" + Calculate the Carmichael reduced totient function lambda(n) + + .. deprecated:: 1.13 + + The ``reduced_totient`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.reduced_totient` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + ``reduced_totient(n)`` or `\lambda(n)` is the smallest m > 0 such that + `k^m \equiv 1 \mod n` for all k relatively prime to n. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import reduced_totient + >>> reduced_totient(1) + 1 + >>> reduced_totient(8) + 2 + >>> reduced_totient(30) + 4 + + See Also + ======== + + totient + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Carmichael_function + .. [2] https://mathworld.wolfram.com/CarmichaelFunction.html + + """ + from sympy.functions.combinatorial.numbers import reduced_totient as _reduced_totient + return _reduced_totient(n) + + +@deprecated("""\ +The `sympy.ntheory.factor_.divisor_sigma` has been moved to `sympy.functions.combinatorial.numbers.divisor_sigma`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def divisor_sigma(n, k=1): + r""" + Calculate the divisor function `\sigma_k(n)` for positive integer n + + .. deprecated:: 1.13 + + The ``divisor_sigma`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.divisor_sigma` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + ``divisor_sigma(n, k)`` is equal to ``sum([x**k for x in divisors(n)])`` + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^\omega p_i^{m_i}, + + then + + .. math :: + \sigma_k(n) = \prod_{i=1}^\omega (1+p_i^k+p_i^{2k}+\cdots + + p_i^{m_ik}). + + Parameters + ========== + + n : integer + + k : integer, optional + power of divisors in the sum + + for k = 0, 1: + ``divisor_sigma(n, 0)`` is equal to ``divisor_count(n)`` + ``divisor_sigma(n, 1)`` is equal to ``sum(divisors(n))`` + + Default for k is 1. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import divisor_sigma + >>> divisor_sigma(18, 0) + 6 + >>> divisor_sigma(39, 1) + 56 + >>> divisor_sigma(12, 2) + 210 + >>> divisor_sigma(37) + 38 + + See Also + ======== + + divisor_count, totient, divisors, factorint + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Divisor_function + + """ + from sympy.functions.combinatorial.numbers import divisor_sigma as func_divisor_sigma + return func_divisor_sigma(n, k) + + +def _divisor_sigma(n:int, k:int=1) -> int: + r""" Calculate the divisor function `\sigma_k(n)` for positive integer n + + Parameters + ========== + + n : int + positive integer + k : int + nonnegative integer + + See Also + ======== + + sympy.functions.combinatorial.numbers.divisor_sigma + + """ + if k == 0: + return math.prod(e + 1 for e in factorint(n).values()) + return math.prod((p**(k*(e + 1)) - 1)//(p**k - 1) for p, e in factorint(n).items()) + + +def core(n, t=2): + r""" + Calculate core(n, t) = `core_t(n)` of a positive integer n + + ``core_2(n)`` is equal to the squarefree part of n + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^\omega p_i^{m_i}, + + then + + .. math :: + core_t(n) = \prod_{i=1}^\omega p_i^{m_i \mod t}. + + Parameters + ========== + + n : integer + + t : integer + core(n, t) calculates the t-th power free part of n + + ``core(n, 2)`` is the squarefree part of ``n`` + ``core(n, 3)`` is the cubefree part of ``n`` + + Default for t is 2. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import core + >>> core(24, 2) + 6 + >>> core(9424, 3) + 1178 + >>> core(379238) + 379238 + >>> core(15**11, 10) + 15 + + See Also + ======== + + factorint, sympy.solvers.diophantine.diophantine.square_factor + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Square-free_integer#Squarefree_core + + """ + + n = as_int(n) + t = as_int(t) + if n <= 0: + raise ValueError("n must be a positive integer") + elif t <= 1: + raise ValueError("t must be >= 2") + else: + y = 1 + for p, e in factorint(n).items(): + y *= p**(e % t) + return y + + +@deprecated("""\ +The `sympy.ntheory.factor_.udivisor_sigma` has been moved to `sympy.functions.combinatorial.numbers.udivisor_sigma`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def udivisor_sigma(n, k=1): + r""" + Calculate the unitary divisor function `\sigma_k^*(n)` for positive integer n + + .. deprecated:: 1.13 + + The ``udivisor_sigma`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.udivisor_sigma` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + ``udivisor_sigma(n, k)`` is equal to ``sum([x**k for x in udivisors(n)])`` + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^\omega p_i^{m_i}, + + then + + .. math :: + \sigma_k^*(n) = \prod_{i=1}^\omega (1+ p_i^{m_ik}). + + Parameters + ========== + + k : power of divisors in the sum + + for k = 0, 1: + ``udivisor_sigma(n, 0)`` is equal to ``udivisor_count(n)`` + ``udivisor_sigma(n, 1)`` is equal to ``sum(udivisors(n))`` + + Default for k is 1. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import udivisor_sigma + >>> udivisor_sigma(18, 0) + 4 + >>> udivisor_sigma(74, 1) + 114 + >>> udivisor_sigma(36, 3) + 47450 + >>> udivisor_sigma(111) + 152 + + See Also + ======== + + divisor_count, totient, divisors, udivisors, udivisor_count, divisor_sigma, + factorint + + References + ========== + + .. [1] https://mathworld.wolfram.com/UnitaryDivisorFunction.html + + """ + from sympy.functions.combinatorial.numbers import udivisor_sigma as _udivisor_sigma + return _udivisor_sigma(n, k) + + +@deprecated("""\ +The `sympy.ntheory.factor_.primenu` has been moved to `sympy.functions.combinatorial.numbers.primenu`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def primenu(n): + r""" + Calculate the number of distinct prime factors for a positive integer n. + + .. deprecated:: 1.13 + + The ``primenu`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.primenu` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^k p_i^{m_i}, + + then ``primenu(n)`` or `\nu(n)` is: + + .. math :: + \nu(n) = k. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import primenu + >>> primenu(1) + 0 + >>> primenu(30) + 3 + + See Also + ======== + + factorint + + References + ========== + + .. [1] https://mathworld.wolfram.com/PrimeFactor.html + + """ + from sympy.functions.combinatorial.numbers import primenu as _primenu + return _primenu(n) + + +@deprecated("""\ +The `sympy.ntheory.factor_.primeomega` has been moved to `sympy.functions.combinatorial.numbers.primeomega`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def primeomega(n): + r""" + Calculate the number of prime factors counting multiplicities for a + positive integer n. + + .. deprecated:: 1.13 + + The ``primeomega`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.primeomega` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + If n's prime factorization is: + + .. math :: + n = \prod_{i=1}^k p_i^{m_i}, + + then ``primeomega(n)`` or `\Omega(n)` is: + + .. math :: + \Omega(n) = \sum_{i=1}^k m_i. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import primeomega + >>> primeomega(1) + 0 + >>> primeomega(20) + 3 + + See Also + ======== + + factorint + + References + ========== + + .. [1] https://mathworld.wolfram.com/PrimeFactor.html + + """ + from sympy.functions.combinatorial.numbers import primeomega as _primeomega + return _primeomega(n) + + +def mersenne_prime_exponent(nth): + """Returns the exponent ``i`` for the nth Mersenne prime (which + has the form `2^i - 1`). + + Examples + ======== + + >>> from sympy.ntheory.factor_ import mersenne_prime_exponent + >>> mersenne_prime_exponent(1) + 2 + >>> mersenne_prime_exponent(20) + 4423 + """ + n = as_int(nth) + if n < 1: + raise ValueError("nth must be a positive integer; mersenne_prime_exponent(1) == 2") + if n > 51: + raise ValueError("There are only 51 perfect numbers; nth must be less than or equal to 51") + return MERSENNE_PRIME_EXPONENTS[n - 1] + + +def is_perfect(n): + """Returns True if ``n`` is a perfect number, else False. + + A perfect number is equal to the sum of its positive, proper divisors. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import divisor_sigma + >>> from sympy.ntheory.factor_ import is_perfect, divisors + >>> is_perfect(20) + False + >>> is_perfect(6) + True + >>> 6 == divisor_sigma(6) - 6 == sum(divisors(6)[:-1]) + True + + References + ========== + + .. [1] https://mathworld.wolfram.com/PerfectNumber.html + .. [2] https://en.wikipedia.org/wiki/Perfect_number + + """ + n = as_int(n) + if n < 1: + return False + if n % 2 == 0: + m = (n.bit_length() + 1) >> 1 + if (1 << (m - 1)) * ((1 << m) - 1) != n: + # Even perfect numbers must be of the form `2^{m-1}(2^m-1)` + return False + return m in MERSENNE_PRIME_EXPONENTS or is_mersenne_prime(2**m - 1) + + # n is an odd integer + if n < 10**2000: # https://www.lirmm.fr/~ochem/opn/ + return False + if n % 105 == 0: # not divis by 105 + return False + if all(n % m != r for m, r in [(12, 1), (468, 117), (324, 81)]): + return False + # there are many criteria that the factor structure of n + # must meet; since we will have to factor it to test the + # structure we will have the factors and can then check + # to see whether it is a perfect number or not. So we + # skip the structure checks and go straight to the final + # test below. + result = abundance(n) == 0 + if result: + raise ValueError(filldedent('''In 1888, Sylvester stated: " + ...a prolonged meditation on the subject has satisfied + me that the existence of any one such [odd perfect number] + -- its escape, so to say, from the complex web of conditions + which hem it in on all sides -- would be little short of a + miracle." I guess SymPy just found that miracle and it + factors like this: %s''' % factorint(n))) + return result + + +def abundance(n): + """Returns the difference between the sum of the positive + proper divisors of a number and the number. + + Examples + ======== + + >>> from sympy.ntheory import abundance, is_perfect, is_abundant + >>> abundance(6) + 0 + >>> is_perfect(6) + True + >>> abundance(10) + -2 + >>> is_abundant(10) + False + """ + return _divisor_sigma(n) - 2 * n + + +def is_abundant(n): + """Returns True if ``n`` is an abundant number, else False. + + A abundant number is smaller than the sum of its positive proper divisors. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import is_abundant + >>> is_abundant(20) + True + >>> is_abundant(15) + False + + References + ========== + + .. [1] https://mathworld.wolfram.com/AbundantNumber.html + + """ + n = as_int(n) + if is_perfect(n): + return False + return n % 6 == 0 or bool(abundance(n) > 0) + + +def is_deficient(n): + """Returns True if ``n`` is a deficient number, else False. + + A deficient number is greater than the sum of its positive proper divisors. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import is_deficient + >>> is_deficient(20) + False + >>> is_deficient(15) + True + + References + ========== + + .. [1] https://mathworld.wolfram.com/DeficientNumber.html + + """ + n = as_int(n) + if is_perfect(n): + return False + return bool(abundance(n) < 0) + + +def is_amicable(m, n): + """Returns True if the numbers `m` and `n` are "amicable", else False. + + Amicable numbers are two different numbers so related that the sum + of the proper divisors of each is equal to that of the other. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import divisor_sigma + >>> from sympy.ntheory.factor_ import is_amicable + >>> is_amicable(220, 284) + True + >>> divisor_sigma(220) == divisor_sigma(284) + True + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Amicable_numbers + + """ + return m != n and m + n == _divisor_sigma(m) == _divisor_sigma(n) + + +def is_carmichael(n): + """ Returns True if the numbers `n` is Carmichael number, else False. + + Parameters + ========== + + n : Integer + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Carmichael_number + .. [2] https://oeis.org/A002997 + + """ + if n < 561: + return False + return n % 2 and not isprime(n) and \ + all(e == 1 and (n - 1) % (p - 1) == 0 for p, e in factorint(n).items()) + + +def find_carmichael_numbers_in_range(x, y): + """ Returns a list of the number of Carmichael in the range + + See Also + ======== + + is_carmichael + + """ + if 0 <= x <= y: + if x % 2 == 0: + return [i for i in range(x + 1, y, 2) if is_carmichael(i)] + else: + return [i for i in range(x, y, 2) if is_carmichael(i)] + else: + raise ValueError('The provided range is not valid. x and y must be non-negative integers and x <= y') + + +def find_first_n_carmichaels(n): + """ Returns the first n Carmichael numbers. + + Parameters + ========== + + n : Integer + + See Also + ======== + + is_carmichael + + """ + i = 561 + carmichaels = [] + + while len(carmichaels) < n: + if is_carmichael(i): + carmichaels.append(i) + i += 2 + + return carmichaels + + +def dra(n, b): + """ + Returns the additive digital root of a natural number ``n`` in base ``b`` + which is a single digit value obtained by an iterative process of summing + digits, on each iteration using the result from the previous iteration to + compute a digit sum. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import dra + >>> dra(3110, 12) + 8 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Digital_root + + """ + + num = abs(as_int(n)) + b = as_int(b) + if b <= 1: + raise ValueError("Base should be an integer greater than 1") + + if num == 0: + return 0 + + return (1 + (num - 1) % (b - 1)) + + +def drm(n, b): + """ + Returns the multiplicative digital root of a natural number ``n`` in a given + base ``b`` which is a single digit value obtained by an iterative process of + multiplying digits, on each iteration using the result from the previous + iteration to compute the digit multiplication. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import drm + >>> drm(9876, 10) + 0 + + >>> drm(49, 10) + 8 + + References + ========== + + .. [1] https://mathworld.wolfram.com/MultiplicativeDigitalRoot.html + + """ + + n = abs(as_int(n)) + b = as_int(b) + if b <= 1: + raise ValueError("Base should be an integer greater than 1") + while n > b: + mul = 1 + while n > 1: + n, r = divmod(n, b) + if r == 0: + return 0 + mul *= r + n = mul + return n diff --git a/phivenv/Lib/site-packages/sympy/ntheory/generate.py b/phivenv/Lib/site-packages/sympy/ntheory/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..855bb44acfcb6241e6b0bcb81e7a2cfc8ced861f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/generate.py @@ -0,0 +1,1157 @@ +""" +Generating and counting primes. + +""" + +from bisect import bisect, bisect_left +from itertools import count +# Using arrays for sieving instead of lists greatly reduces +# memory consumption +from array import array as _array + +from sympy.core.random import randint +from sympy.external.gmpy import sqrt +from .primetest import isprime +from sympy.utilities.decorator import deprecated +from sympy.utilities.misc import as_int + + +def _as_int_ceiling(a): + """ Wrapping ceiling in as_int will raise an error if there was a problem + determining whether the expression was exactly an integer or not.""" + from sympy.functions.elementary.integers import ceiling + return as_int(ceiling(a)) + + +class Sieve: + """A list of prime numbers, implemented as a dynamically + growing sieve of Eratosthenes. When a lookup is requested involving + an odd number that has not been sieved, the sieve is automatically + extended up to that number. Implementation details limit the number of + primes to ``2^32-1``. + + Examples + ======== + + >>> from sympy import sieve + >>> sieve._reset() # this line for doctest only + >>> 25 in sieve + False + >>> sieve._list + array('L', [2, 3, 5, 7, 11, 13, 17, 19, 23]) + """ + + # data shared (and updated) by all Sieve instances + def __init__(self, sieve_interval=1_000_000): + """ Initial parameters for the Sieve class. + + Parameters + ========== + + sieve_interval (int): Amount of memory to be used + + Raises + ====== + + ValueError + If ``sieve_interval`` is not positive. + + """ + self._n = 6 + self._list = _array('L', [2, 3, 5, 7, 11, 13]) # primes + self._tlist = _array('L', [0, 1, 1, 2, 2, 4]) # totient + self._mlist = _array('i', [0, 1, -1, -1, 0, -1]) # mobius + if sieve_interval <= 0: + raise ValueError("sieve_interval should be a positive integer") + self.sieve_interval = sieve_interval + assert all(len(i) == self._n for i in (self._list, self._tlist, self._mlist)) + + def __repr__(self): + return ("<%s sieve (%i): %i, %i, %i, ... %i, %i\n" + "%s sieve (%i): %i, %i, %i, ... %i, %i\n" + "%s sieve (%i): %i, %i, %i, ... %i, %i>") % ( + 'prime', len(self._list), + self._list[0], self._list[1], self._list[2], + self._list[-2], self._list[-1], + 'totient', len(self._tlist), + self._tlist[0], self._tlist[1], + self._tlist[2], self._tlist[-2], self._tlist[-1], + 'mobius', len(self._mlist), + self._mlist[0], self._mlist[1], + self._mlist[2], self._mlist[-2], self._mlist[-1]) + + def _reset(self, prime=None, totient=None, mobius=None): + """Reset all caches (default). To reset one or more set the + desired keyword to True.""" + if all(i is None for i in (prime, totient, mobius)): + prime = totient = mobius = True + if prime: + self._list = self._list[:self._n] + if totient: + self._tlist = self._tlist[:self._n] + if mobius: + self._mlist = self._mlist[:self._n] + + def extend(self, n): + """Grow the sieve to cover all primes <= n. + + Examples + ======== + + >>> from sympy import sieve + >>> sieve._reset() # this line for doctest only + >>> sieve.extend(30) + >>> sieve[10] == 29 + True + """ + n = int(n) + # `num` is even at any point in the function. + # This satisfies the condition required by `self._primerange`. + num = self._list[-1] + 1 + if n < num: + return + num2 = num**2 + while num2 <= n: + self._list += _array('L', self._primerange(num, num2)) + num, num2 = num2, num2**2 + # Merge the sieves + self._list += _array('L', self._primerange(num, n + 1)) + + def _primerange(self, a, b): + """ Generate all prime numbers in the range (a, b). + + Parameters + ========== + + a, b : positive integers assuming the following conditions + * a is an even number + * 2 < self._list[-1] < a < b < nextprime(self._list[-1])**2 + + Yields + ====== + + p (int): prime numbers such that ``a < p < b`` + + Examples + ======== + + >>> from sympy.ntheory.generate import Sieve + >>> s = Sieve() + >>> s._list[-1] + 13 + >>> list(s._primerange(18, 31)) + [19, 23, 29] + + """ + if b % 2: + b -= 1 + while a < b: + block_size = min(self.sieve_interval, (b - a) // 2) + # Create the list such that block[x] iff (a + 2x + 1) is prime. + # Note that even numbers are not considered here. + block = [True] * block_size + for p in self._list[1:bisect(self._list, sqrt(a + 2 * block_size + 1))]: + for t in range((-(a + 1 + p) // 2) % p, block_size, p): + block[t] = False + for idx, p in enumerate(block): + if p: + yield a + 2 * idx + 1 + a += 2 * block_size + + def extend_to_no(self, i): + """Extend to include the ith prime number. + + Parameters + ========== + + i : integer + + Examples + ======== + + >>> from sympy import sieve + >>> sieve._reset() # this line for doctest only + >>> sieve.extend_to_no(9) + >>> sieve._list + array('L', [2, 3, 5, 7, 11, 13, 17, 19, 23]) + + Notes + ===== + + The list is extended by 50% if it is too short, so it is + likely that it will be longer than requested. + """ + i = as_int(i) + while len(self._list) < i: + self.extend(int(self._list[-1] * 1.5)) + + def primerange(self, a, b=None): + """Generate all prime numbers in the range [2, a) or [a, b). + + Examples + ======== + + >>> from sympy import sieve, prime + + All primes less than 19: + + >>> print([i for i in sieve.primerange(19)]) + [2, 3, 5, 7, 11, 13, 17] + + All primes greater than or equal to 7 and less than 19: + + >>> print([i for i in sieve.primerange(7, 19)]) + [7, 11, 13, 17] + + All primes through the 10th prime + + >>> list(sieve.primerange(prime(10) + 1)) + [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] + + """ + if b is None: + b = _as_int_ceiling(a) + a = 2 + else: + a = max(2, _as_int_ceiling(a)) + b = _as_int_ceiling(b) + if a >= b: + return + self.extend(b) + yield from self._list[bisect_left(self._list, a): + bisect_left(self._list, b)] + + def totientrange(self, a, b): + """Generate all totient numbers for the range [a, b). + + Examples + ======== + + >>> from sympy import sieve + >>> print([i for i in sieve.totientrange(7, 18)]) + [6, 4, 6, 4, 10, 4, 12, 6, 8, 8, 16] + """ + a = max(1, _as_int_ceiling(a)) + b = _as_int_ceiling(b) + n = len(self._tlist) + if a >= b: + return + elif b <= n: + for i in range(a, b): + yield self._tlist[i] + else: + self._tlist += _array('L', range(n, b)) + for i in range(1, n): + ti = self._tlist[i] + if ti == i - 1: + startindex = (n + i - 1) // i * i + for j in range(startindex, b, i): + self._tlist[j] -= self._tlist[j] // i + if i >= a: + yield ti + + for i in range(n, b): + ti = self._tlist[i] + if ti == i: + for j in range(i, b, i): + self._tlist[j] -= self._tlist[j] // i + if i >= a: + yield self._tlist[i] + + def mobiusrange(self, a, b): + """Generate all mobius numbers for the range [a, b). + + Parameters + ========== + + a : integer + First number in range + + b : integer + First number outside of range + + Examples + ======== + + >>> from sympy import sieve + >>> print([i for i in sieve.mobiusrange(7, 18)]) + [-1, 0, 0, 1, -1, 0, -1, 1, 1, 0, -1] + """ + a = max(1, _as_int_ceiling(a)) + b = _as_int_ceiling(b) + n = len(self._mlist) + if a >= b: + return + elif b <= n: + for i in range(a, b): + yield self._mlist[i] + else: + self._mlist += _array('i', [0]*(b - n)) + for i in range(1, n): + mi = self._mlist[i] + startindex = (n + i - 1) // i * i + for j in range(startindex, b, i): + self._mlist[j] -= mi + if i >= a: + yield mi + + for i in range(n, b): + mi = self._mlist[i] + for j in range(2 * i, b, i): + self._mlist[j] -= mi + if i >= a: + yield mi + + def search(self, n): + """Return the indices i, j of the primes that bound n. + + If n is prime then i == j. + + Although n can be an expression, if ceiling cannot convert + it to an integer then an n error will be raised. + + Examples + ======== + + >>> from sympy import sieve + >>> sieve.search(25) + (9, 10) + >>> sieve.search(23) + (9, 9) + """ + test = _as_int_ceiling(n) + n = as_int(n) + if n < 2: + raise ValueError("n should be >= 2 but got: %s" % n) + if n > self._list[-1]: + self.extend(n) + b = bisect(self._list, n) + if self._list[b - 1] == test: + return b, b + else: + return b, b + 1 + + def __contains__(self, n): + try: + n = as_int(n) + assert n >= 2 + except (ValueError, AssertionError): + return False + if n % 2 == 0: + return n == 2 + a, b = self.search(n) + return a == b + + def __iter__(self): + for n in count(1): + yield self[n] + + def __getitem__(self, n): + """Return the nth prime number""" + if isinstance(n, slice): + self.extend_to_no(n.stop) + start = n.start if n.start is not None else 0 + if start < 1: + # sieve[:5] would be empty (starting at -1), let's + # just be explicit and raise. + raise IndexError("Sieve indices start at 1.") + return self._list[start - 1:n.stop - 1:n.step] + else: + if n < 1: + # offset is one, so forbid explicit access to sieve[0] + # (would surprisingly return the last one). + raise IndexError("Sieve indices start at 1.") + n = as_int(n) + self.extend_to_no(n) + return self._list[n - 1] + +# Generate a global object for repeated use in trial division etc +sieve = Sieve() + +def prime(nth): + r""" + Return the nth prime number, where primes are indexed starting from 1: + prime(1) = 2, prime(2) = 3, etc. + + Parameters + ========== + + nth : int + The position of the prime number to return (must be a positive integer). + + Returns + ======= + + int + The nth prime number. + + Examples + ======== + + >>> from sympy import prime + >>> prime(10) + 29 + >>> prime(1) + 2 + >>> prime(100000) + 1299709 + + See Also + ======== + + sympy.ntheory.primetest.isprime : Test if a number is prime. + primerange : Generate all primes in a given range. + primepi : Return the number of primes less than or equal to a given number. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Prime_number_theorem + .. [2] https://en.wikipedia.org/wiki/Logarithmic_integral_function + .. [3] https://en.wikipedia.org/wiki/Skewes%27_number + """ + n = as_int(nth) + if n < 1: + raise ValueError("nth must be a positive integer; prime(1) == 2") + + # Check if n is within the sieve range + if n <= len(sieve._list): + return sieve[n] + + from sympy.functions.elementary.exponential import log + from sympy.functions.special.error_functions import li + + if n < 1000: + # Extend sieve up to 8*n as this is empirically sufficient + sieve.extend(8 * n) + return sieve[n] + + a = 2 + # Estimate an upper bound for the nth prime using the prime number theorem + b = int(n * (log(n).evalf() + log(log(n)).evalf())) + + # Binary search for the least m such that li(m) > n + while a < b: + mid = (a + b) >> 1 + if li(mid).evalf() > n: + b = mid + else: + a = mid + 1 + + return nextprime(a - 1, n - _primepi(a - 1)) + + +@deprecated("""\ +The `sympy.ntheory.generate.primepi` has been moved to `sympy.functions.combinatorial.numbers.primepi`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def primepi(n): + r""" Represents the prime counting function pi(n) = the number + of prime numbers less than or equal to n. + + .. deprecated:: 1.13 + + The ``primepi`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.primepi` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + Algorithm Description: + + In sieve method, we remove all multiples of prime p + except p itself. + + Let phi(i,j) be the number of integers 2 <= k <= i + which remain after sieving from primes less than + or equal to j. + Clearly, pi(n) = phi(n, sqrt(n)) + + If j is not a prime, + phi(i,j) = phi(i, j - 1) + + if j is a prime, + We remove all numbers(except j) whose + smallest prime factor is j. + + Let $x= j \times a$ be such a number, where $2 \le a \le i / j$ + Now, after sieving from primes $\le j - 1$, + a must remain + (because x, and hence a has no prime factor $\le j - 1$) + Clearly, there are phi(i / j, j - 1) such a + which remain on sieving from primes $\le j - 1$ + + Now, if a is a prime less than equal to j - 1, + $x= j \times a$ has smallest prime factor = a, and + has already been removed(by sieving from a). + So, we do not need to remove it again. + (Note: there will be pi(j - 1) such x) + + Thus, number of x, that will be removed are: + phi(i / j, j - 1) - phi(j - 1, j - 1) + (Note that pi(j - 1) = phi(j - 1, j - 1)) + + $\Rightarrow$ phi(i,j) = phi(i, j - 1) - phi(i / j, j - 1) + phi(j - 1, j - 1) + + So,following recursion is used and implemented as dp: + + phi(a, b) = phi(a, b - 1), if b is not a prime + phi(a, b) = phi(a, b-1)-phi(a / b, b-1) + phi(b-1, b-1), if b is prime + + Clearly a is always of the form floor(n / k), + which can take at most $2\sqrt{n}$ values. + Two arrays arr1,arr2 are maintained + arr1[i] = phi(i, j), + arr2[i] = phi(n // i, j) + + Finally the answer is arr2[1] + + Examples + ======== + + >>> from sympy import primepi, prime, prevprime, isprime + >>> primepi(25) + 9 + + So there are 9 primes less than or equal to 25. Is 25 prime? + + >>> isprime(25) + False + + It is not. So the first prime less than 25 must be the + 9th prime: + + >>> prevprime(25) == prime(9) + True + + See Also + ======== + + sympy.ntheory.primetest.isprime : Test if n is prime + primerange : Generate all primes in a given range + prime : Return the nth prime + """ + from sympy.functions.combinatorial.numbers import primepi as func_primepi + return func_primepi(n) + + +def _primepi(n:int) -> int: + r""" Represents the prime counting function pi(n) = the number + of prime numbers less than or equal to n. + + Explanation + =========== + + In sieve method, we remove all multiples of prime p + except p itself. + + Let phi(i,j) be the number of integers 2 <= k <= i + which remain after sieving from primes less than + or equal to j. + Clearly, pi(n) = phi(n, sqrt(n)) + + If j is not a prime, + phi(i,j) = phi(i, j - 1) + + if j is a prime, + We remove all numbers(except j) whose + smallest prime factor is j. + + Let $x= j \times a$ be such a number, where $2 \le a \le i / j$ + Now, after sieving from primes $\le j - 1$, + a must remain + (because x, and hence a has no prime factor $\le j - 1$) + Clearly, there are phi(i / j, j - 1) such a + which remain on sieving from primes $\le j - 1$ + + Now, if a is a prime less than equal to j - 1, + $x= j \times a$ has smallest prime factor = a, and + has already been removed(by sieving from a). + So, we do not need to remove it again. + (Note: there will be pi(j - 1) such x) + + Thus, number of x, that will be removed are: + phi(i / j, j - 1) - phi(j - 1, j - 1) + (Note that pi(j - 1) = phi(j - 1, j - 1)) + + $\Rightarrow$ phi(i,j) = phi(i, j - 1) - phi(i / j, j - 1) + phi(j - 1, j - 1) + + So,following recursion is used and implemented as dp: + + phi(a, b) = phi(a, b - 1), if b is not a prime + phi(a, b) = phi(a, b-1)-phi(a / b, b-1) + phi(b-1, b-1), if b is prime + + Clearly a is always of the form floor(n / k), + which can take at most $2\sqrt{n}$ values. + Two arrays arr1,arr2 are maintained + arr1[i] = phi(i, j), + arr2[i] = phi(n // i, j) + + Finally the answer is arr2[1] + + Parameters + ========== + + n : int + + """ + if n < 2: + return 0 + if n <= sieve._list[-1]: + return sieve.search(n)[0] + lim = sqrt(n) + arr1 = [(i + 1) >> 1 for i in range(lim + 1)] + arr2 = [0] + [(n//i + 1) >> 1 for i in range(1, lim + 1)] + skip = [False] * (lim + 1) + for i in range(3, lim + 1, 2): + # Presently, arr1[k]=phi(k,i - 1), + # arr2[k] = phi(n // k,i - 1) # not all k's do this + if skip[i]: + # skip if i is a composite number + continue + p = arr1[i - 1] + for j in range(i, lim + 1, i): + skip[j] = True + # update arr2 + # phi(n/j, i) = phi(n/j, i-1) - phi(n/(i*j), i-1) + phi(i-1, i-1) + for j in range(1, min(n // (i * i), lim) + 1, 2): + # No need for arr2[j] in j such that skip[j] is True to + # compute the final required arr2[1]. + if skip[j]: + continue + st = i * j + if st <= lim: + arr2[j] -= arr2[st] - p + else: + arr2[j] -= arr1[n // st] - p + # update arr1 + # phi(j, i) = phi(j, i-1) - phi(j/i, i-1) + phi(i-1, i-1) + # where the range below i**2 is fixed and + # does not need to be calculated. + for j in range(lim, min(lim, i*i - 1), -1): + arr1[j] -= arr1[j // i] - p + return arr2[1] + + +def nextprime(n, ith=1): + """ Return the ith prime greater than n. + + Parameters + ========== + + n : integer + ith : positive integer + + Returns + ======= + + int : Return the ith prime greater than n + + Raises + ====== + + ValueError + If ``ith <= 0``. + If ``n`` or ``ith`` is not an integer. + + Notes + ===== + + Potential primes are located at 6*j +/- 1. This + property is used during searching. + + >>> from sympy import nextprime + >>> [(i, nextprime(i)) for i in range(10, 15)] + [(10, 11), (11, 13), (12, 13), (13, 17), (14, 17)] + >>> nextprime(2, ith=2) # the 2nd prime after 2 + 5 + + See Also + ======== + + prevprime : Return the largest prime smaller than n + primerange : Generate all primes in a given range + + """ + n = int(n) + i = as_int(ith) + if i <= 0: + raise ValueError("ith should be positive") + if n < 2: + n = 2 + i -= 1 + if n <= sieve._list[-2]: + l, _ = sieve.search(n) + if l + i - 1 < len(sieve._list): + return sieve._list[l + i - 1] + n = sieve._list[-1] + i += l - len(sieve._list) + nn = 6*(n//6) + if nn == n: + n += 1 + if isprime(n): + i -= 1 + if not i: + return n + n += 4 + elif n - nn == 5: + n += 2 + if isprime(n): + i -= 1 + if not i: + return n + n += 4 + else: + n = nn + 5 + while 1: + if isprime(n): + i -= 1 + if not i: + return n + n += 2 + if isprime(n): + i -= 1 + if not i: + return n + n += 4 + + +def prevprime(n): + """ Return the largest prime smaller than n. + + Notes + ===== + + Potential primes are located at 6*j +/- 1. This + property is used during searching. + + >>> from sympy import prevprime + >>> [(i, prevprime(i)) for i in range(10, 15)] + [(10, 7), (11, 7), (12, 11), (13, 11), (14, 13)] + + See Also + ======== + + nextprime : Return the ith prime greater than n + primerange : Generates all primes in a given range + """ + n = _as_int_ceiling(n) + if n < 3: + raise ValueError("no preceding primes") + if n < 8: + return {3: 2, 4: 3, 5: 3, 6: 5, 7: 5}[n] + if n <= sieve._list[-1]: + l, u = sieve.search(n) + if l == u: + return sieve[l-1] + else: + return sieve[l] + nn = 6*(n//6) + if n - nn <= 1: + n = nn - 1 + if isprime(n): + return n + n -= 4 + else: + n = nn + 1 + while 1: + if isprime(n): + return n + n -= 2 + if isprime(n): + return n + n -= 4 + + +def primerange(a, b=None): + """ Generate a list of all prime numbers in the range [2, a), + or [a, b). + + If the range exists in the default sieve, the values will + be returned from there; otherwise values will be returned + but will not modify the sieve. + + Examples + ======== + + >>> from sympy import primerange, prime + + All primes less than 19: + + >>> list(primerange(19)) + [2, 3, 5, 7, 11, 13, 17] + + All primes greater than or equal to 7 and less than 19: + + >>> list(primerange(7, 19)) + [7, 11, 13, 17] + + All primes through the 10th prime + + >>> list(primerange(prime(10) + 1)) + [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] + + The Sieve method, primerange, is generally faster but it will + occupy more memory as the sieve stores values. The default + instance of Sieve, named sieve, can be used: + + >>> from sympy import sieve + >>> list(sieve.primerange(1, 30)) + [2, 3, 5, 7, 11, 13, 17, 19, 23, 29] + + Notes + ===== + + Some famous conjectures about the occurrence of primes in a given + range are [1]: + + - Twin primes: though often not, the following will give 2 primes + an infinite number of times: + primerange(6*n - 1, 6*n + 2) + - Legendre's: the following always yields at least one prime + primerange(n**2, (n+1)**2+1) + - Bertrand's (proven): there is always a prime in the range + primerange(n, 2*n) + - Brocard's: there are at least four primes in the range + primerange(prime(n)**2, prime(n+1)**2) + + The average gap between primes is log(n) [2]; the gap between + primes can be arbitrarily large since sequences of composite + numbers are arbitrarily large, e.g. the numbers in the sequence + n! + 2, n! + 3 ... n! + n are all composite. + + See Also + ======== + + prime : Return the nth prime + nextprime : Return the ith prime greater than n + prevprime : Return the largest prime smaller than n + randprime : Returns a random prime in a given range + primorial : Returns the product of primes based on condition + Sieve.primerange : return range from already computed primes + or extend the sieve to contain the requested + range. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Prime_number + .. [2] https://primes.utm.edu/notes/gaps.html + """ + if b is None: + a, b = 2, a + if a >= b: + return + # If we already have the range, return it. + largest_known_prime = sieve._list[-1] + if b <= largest_known_prime: + yield from sieve.primerange(a, b) + return + # If we know some of it, return it. + if a <= largest_known_prime: + yield from sieve._list[bisect_left(sieve._list, a):] + a = largest_known_prime + 1 + elif a % 2: + a -= 1 + tail = min(b, (largest_known_prime)**2) + if a < tail: + yield from sieve._primerange(a, tail) + a = tail + if b <= a: + return + # otherwise compute, without storing, the desired range. + while 1: + a = nextprime(a) + if a < b: + yield a + else: + return + + +def randprime(a, b): + """ Return a random prime number in the range [a, b). + + Bertrand's postulate assures that + randprime(a, 2*a) will always succeed for a > 1. + + Note that due to implementation difficulties, + the prime numbers chosen are not uniformly random. + For example, there are two primes in the range [112, 128), + ``113`` and ``127``, but ``randprime(112, 128)`` returns ``127`` + with a probability of 15/17. + + Examples + ======== + + >>> from sympy import randprime, isprime + >>> randprime(1, 30) #doctest: +SKIP + 13 + >>> isprime(randprime(1, 30)) + True + + See Also + ======== + + primerange : Generate all primes in a given range + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bertrand's_postulate + + """ + if a >= b: + return + a, b = map(int, (a, b)) + n = randint(a - 1, b) + p = nextprime(n) + if p >= b: + p = prevprime(b) + if p < a: + raise ValueError("no primes exist in the specified range") + return p + + +def primorial(n, nth=True): + """ + Returns the product of the first n primes (default) or + the primes less than or equal to n (when ``nth=False``). + + Examples + ======== + + >>> from sympy.ntheory.generate import primorial, primerange + >>> from sympy import factorint, Mul, primefactors, sqrt + >>> primorial(4) # the first 4 primes are 2, 3, 5, 7 + 210 + >>> primorial(4, nth=False) # primes <= 4 are 2 and 3 + 6 + >>> primorial(1) + 2 + >>> primorial(1, nth=False) + 1 + >>> primorial(sqrt(101), nth=False) + 210 + + One can argue that the primes are infinite since if you take + a set of primes and multiply them together (e.g. the primorial) and + then add or subtract 1, the result cannot be divided by any of the + original factors, hence either 1 or more new primes must divide this + product of primes. + + In this case, the number itself is a new prime: + + >>> factorint(primorial(4) + 1) + {211: 1} + + In this case two new primes are the factors: + + >>> factorint(primorial(4) - 1) + {11: 1, 19: 1} + + Here, some primes smaller and larger than the primes multiplied together + are obtained: + + >>> p = list(primerange(10, 20)) + >>> sorted(set(primefactors(Mul(*p) + 1)).difference(set(p))) + [2, 5, 31, 149] + + See Also + ======== + + primerange : Generate all primes in a given range + + """ + if nth: + n = as_int(n) + else: + n = int(n) + if n < 1: + raise ValueError("primorial argument must be >= 1") + p = 1 + if nth: + for i in range(1, n + 1): + p *= prime(i) + else: + for i in primerange(2, n + 1): + p *= i + return p + + +def cycle_length(f, x0, nmax=None, values=False): + """For a given iterated sequence, return a generator that gives + the length of the iterated cycle (lambda) and the length of terms + before the cycle begins (mu); if ``values`` is True then the + terms of the sequence will be returned instead. The sequence is + started with value ``x0``. + + Note: more than the first lambda + mu terms may be returned and this + is the cost of cycle detection with Brent's method; there are, however, + generally less terms calculated than would have been calculated if the + proper ending point were determined, e.g. by using Floyd's method. + + >>> from sympy.ntheory.generate import cycle_length + + This will yield successive values of i <-- func(i): + + >>> def gen(func, i): + ... while 1: + ... yield i + ... i = func(i) + ... + + A function is defined: + + >>> func = lambda i: (i**2 + 1) % 51 + + and given a seed of 4 and the mu and lambda terms calculated: + + >>> next(cycle_length(func, 4)) + (6, 3) + + We can see what is meant by looking at the output: + + >>> iter = cycle_length(func, 4, values=True) + >>> list(iter) + [4, 17, 35, 2, 5, 26, 14, 44, 50, 2, 5, 26, 14] + + There are 6 repeating values after the first 3. + + If a sequence is suspected of being longer than you might wish, ``nmax`` + can be used to exit early (and mu will be returned as None): + + >>> next(cycle_length(func, 4, nmax = 4)) + (4, None) + >>> list(cycle_length(func, 4, nmax = 4, values=True)) + [4, 17, 35, 2] + + Code modified from: + https://en.wikipedia.org/wiki/Cycle_detection. + """ + + nmax = int(nmax or 0) + + # main phase: search successive powers of two + power = lam = 1 + tortoise, hare = x0, f(x0) # f(x0) is the element/node next to x0. + i = 1 + if values: + yield tortoise + while tortoise != hare and (not nmax or i < nmax): + i += 1 + if power == lam: # time to start a new power of two? + tortoise = hare + power *= 2 + lam = 0 + if values: + yield hare + hare = f(hare) + lam += 1 + if nmax and i == nmax: + if values: + return + else: + yield nmax, None + return + if not values: + # Find the position of the first repetition of length lambda + mu = 0 + tortoise = hare = x0 + for i in range(lam): + hare = f(hare) + while tortoise != hare: + tortoise = f(tortoise) + hare = f(hare) + mu += 1 + yield lam, mu + + +def composite(nth): + """ Return the nth composite number, with the composite numbers indexed as + composite(1) = 4, composite(2) = 6, etc.... + + Examples + ======== + + >>> from sympy import composite + >>> composite(36) + 52 + >>> composite(1) + 4 + >>> composite(17737) + 20000 + + See Also + ======== + + sympy.ntheory.primetest.isprime : Test if n is prime + primerange : Generate all primes in a given range + primepi : Return the number of primes less than or equal to n + prime : Return the nth prime + compositepi : Return the number of positive composite numbers less than or equal to n + """ + n = as_int(nth) + if n < 1: + raise ValueError("nth must be a positive integer; composite(1) == 4") + composite_arr = [4, 6, 8, 9, 10, 12, 14, 15, 16, 18] + if n <= 10: + return composite_arr[n - 1] + + a, b = 4, sieve._list[-1] + if n <= b - _primepi(b) - 1: + while a < b - 1: + mid = (a + b) >> 1 + if mid - _primepi(mid) - 1 > n: + b = mid + else: + a = mid + if isprime(a): + a -= 1 + return a + + from sympy.functions.elementary.exponential import log + from sympy.functions.special.error_functions import li + a = 4 # Lower bound for binary search + b = int(n*(log(n) + log(log(n)))) # Upper bound for the search. + + while a < b: + mid = (a + b) >> 1 + if mid - li(mid) - 1 > n: + b = mid + else: + a = mid + 1 + + n_composites = a - _primepi(a) - 1 + while n_composites > n: + if not isprime(a): + n_composites -= 1 + a -= 1 + if isprime(a): + a -= 1 + return a + + +def compositepi(n): + """ Return the number of positive composite numbers less than or equal to n. + The first positive composite is 4, i.e. compositepi(4) = 1. + + Examples + ======== + + >>> from sympy import compositepi + >>> compositepi(25) + 15 + >>> compositepi(1000) + 831 + + See Also + ======== + + sympy.ntheory.primetest.isprime : Test if n is prime + primerange : Generate all primes in a given range + prime : Return the nth prime + primepi : Return the number of primes less than or equal to n + composite : Return the nth composite number + """ + n = int(n) + if n < 4: + return 0 + return n - _primepi(n) - 1 diff --git a/phivenv/Lib/site-packages/sympy/ntheory/modular.py b/phivenv/Lib/site-packages/sympy/ntheory/modular.py new file mode 100644 index 0000000000000000000000000000000000000000..628a3d8c5a7fb4b6c51ad337df66d74f90282496 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/modular.py @@ -0,0 +1,291 @@ +from math import prod + +from sympy.external.gmpy import gcd, gcdext +from sympy.ntheory.primetest import isprime +from sympy.polys.domains import ZZ +from sympy.polys.galoistools import gf_crt, gf_crt1, gf_crt2 +from sympy.utilities.misc import as_int + + +def symmetric_residue(a, m): + """Return the residual mod m such that it is within half of the modulus. + + >>> from sympy.ntheory.modular import symmetric_residue + >>> symmetric_residue(1, 6) + 1 + >>> symmetric_residue(4, 6) + -2 + """ + if a <= m // 2: + return a + return a - m + + +def crt(m, v, symmetric=False, check=True): + r"""Chinese Remainder Theorem. + + The moduli in m are assumed to be pairwise coprime. The output + is then an integer f, such that f = v_i mod m_i for each pair out + of v and m. If ``symmetric`` is False a positive integer will be + returned, else \|f\| will be less than or equal to the LCM of the + moduli, and thus f may be negative. + + If the moduli are not co-prime the correct result will be returned + if/when the test of the result is found to be incorrect. This result + will be None if there is no solution. + + The keyword ``check`` can be set to False if it is known that the moduli + are coprime. + + Examples + ======== + + As an example consider a set of residues ``U = [49, 76, 65]`` + and a set of moduli ``M = [99, 97, 95]``. Then we have:: + + >>> from sympy.ntheory.modular import crt + + >>> crt([99, 97, 95], [49, 76, 65]) + (639985, 912285) + + This is the correct result because:: + + >>> [639985 % m for m in [99, 97, 95]] + [49, 76, 65] + + If the moduli are not co-prime, you may receive an incorrect result + if you use ``check=False``: + + >>> crt([12, 6, 17], [3, 4, 2], check=False) + (954, 1224) + >>> [954 % m for m in [12, 6, 17]] + [6, 0, 2] + >>> crt([12, 6, 17], [3, 4, 2]) is None + True + >>> crt([3, 6], [2, 5]) + (5, 6) + + Note: the order of gf_crt's arguments is reversed relative to crt, + and that solve_congruence takes residue, modulus pairs. + + Programmer's note: rather than checking that all pairs of moduli share + no GCD (an O(n**2) test) and rather than factoring all moduli and seeing + that there is no factor in common, a check that the result gives the + indicated residuals is performed -- an O(n) operation. + + See Also + ======== + + solve_congruence + sympy.polys.galoistools.gf_crt : low level crt routine used by this routine + """ + if check: + m = list(map(as_int, m)) + v = list(map(as_int, v)) + + result = gf_crt(v, m, ZZ) + mm = prod(m) + + if check: + if not all(v % m == result % m for v, m in zip(v, m)): + result = solve_congruence(*list(zip(v, m)), + check=False, symmetric=symmetric) + if result is None: + return result + result, mm = result + + if symmetric: + return int(symmetric_residue(result, mm)), int(mm) + return int(result), int(mm) + + +def crt1(m): + """First part of Chinese Remainder Theorem, for multiple application. + + Examples + ======== + + >>> from sympy.ntheory.modular import crt, crt1, crt2 + >>> m = [99, 97, 95] + >>> v = [49, 76, 65] + + The following two codes have the same result. + + >>> crt(m, v) + (639985, 912285) + + >>> mm, e, s = crt1(m) + >>> crt2(m, v, mm, e, s) + (639985, 912285) + + However, it is faster when we want to fix ``m`` and + compute for multiple ``v``, i.e. the following cases: + + >>> mm, e, s = crt1(m) + >>> vs = [[52, 21, 37], [19, 46, 76]] + >>> for v in vs: + ... print(crt2(m, v, mm, e, s)) + (397042, 912285) + (803206, 912285) + + See Also + ======== + + sympy.polys.galoistools.gf_crt1 : low level crt routine used by this routine + sympy.ntheory.modular.crt + sympy.ntheory.modular.crt2 + + """ + + return gf_crt1(m, ZZ) + + +def crt2(m, v, mm, e, s, symmetric=False): + """Second part of Chinese Remainder Theorem, for multiple application. + + See ``crt1`` for usage. + + Examples + ======== + + >>> from sympy.ntheory.modular import crt1, crt2 + >>> mm, e, s = crt1([18, 42, 6]) + >>> crt2([18, 42, 6], [0, 0, 0], mm, e, s) + (0, 4536) + + See Also + ======== + + sympy.polys.galoistools.gf_crt2 : low level crt routine used by this routine + sympy.ntheory.modular.crt + sympy.ntheory.modular.crt1 + + """ + + result = gf_crt2(v, m, mm, e, s, ZZ) + + if symmetric: + return int(symmetric_residue(result, mm)), int(mm) + return int(result), int(mm) + + +def solve_congruence(*remainder_modulus_pairs, **hint): + """Compute the integer ``n`` that has the residual ``ai`` when it is + divided by ``mi`` where the ``ai`` and ``mi`` are given as pairs to + this function: ((a1, m1), (a2, m2), ...). If there is no solution, + return None. Otherwise return ``n`` and its modulus. + + The ``mi`` values need not be co-prime. If it is known that the moduli are + not co-prime then the hint ``check`` can be set to False (default=True) and + the check for a quicker solution via crt() (valid when the moduli are + co-prime) will be skipped. + + If the hint ``symmetric`` is True (default is False), the value of ``n`` + will be within 1/2 of the modulus, possibly negative. + + Examples + ======== + + >>> from sympy.ntheory.modular import solve_congruence + + What number is 2 mod 3, 3 mod 5 and 2 mod 7? + + >>> solve_congruence((2, 3), (3, 5), (2, 7)) + (23, 105) + >>> [23 % m for m in [3, 5, 7]] + [2, 3, 2] + + If you prefer to work with all remainder in one list and + all moduli in another, send the arguments like this: + + >>> solve_congruence(*zip((2, 3, 2), (3, 5, 7))) + (23, 105) + + The moduli need not be co-prime; in this case there may or + may not be a solution: + + >>> solve_congruence((2, 3), (4, 6)) is None + True + + >>> solve_congruence((2, 3), (5, 6)) + (5, 6) + + The symmetric flag will make the result be within 1/2 of the modulus: + + >>> solve_congruence((2, 3), (5, 6), symmetric=True) + (-1, 6) + + See Also + ======== + + crt : high level routine implementing the Chinese Remainder Theorem + + """ + def combine(c1, c2): + """Return the tuple (a, m) which satisfies the requirement + that n = a + i*m satisfy n = a1 + j*m1 and n = a2 = k*m2. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Method_of_successive_substitution + """ + a1, m1 = c1 + a2, m2 = c2 + a, b, c = m1, a2 - a1, m2 + g = gcd(a, b, c) + a, b, c = [i//g for i in [a, b, c]] + if a != 1: + g, inv_a, _ = gcdext(a, c) + if g != 1: + return None + b *= inv_a + a, m = a1 + m1*b, m1*c + return a, m + + rm = remainder_modulus_pairs + symmetric = hint.get('symmetric', False) + + if hint.get('check', True): + rm = [(as_int(r), as_int(m)) for r, m in rm] + + # ignore redundant pairs but raise an error otherwise; also + # make sure that a unique set of bases is sent to gf_crt if + # they are all prime. + # + # The routine will work out less-trivial violations and + # return None, e.g. for the pairs (1,3) and (14,42) there + # is no answer because 14 mod 42 (having a gcd of 14) implies + # (14/2) mod (42/2), (14/7) mod (42/7) and (14/14) mod (42/14) + # which, being 0 mod 3, is inconsistent with 1 mod 3. But to + # preprocess the input beyond checking of another pair with 42 + # or 3 as the modulus (for this example) is not necessary. + uniq = {} + for r, m in rm: + r %= m + if m in uniq: + if r != uniq[m]: + return None + continue + uniq[m] = r + rm = [(r, m) for m, r in uniq.items()] + del uniq + + # if the moduli are co-prime, the crt will be significantly faster; + # checking all pairs for being co-prime gets to be slow but a prime + # test is a good trade-off + if all(isprime(m) for r, m in rm): + r, m = list(zip(*rm)) + return crt(m, r, symmetric=symmetric, check=False) + + rv = (0, 1) + for rmi in rm: + rv = combine(rv, rmi) + if rv is None: + break + n, m = rv + n = n % m + else: + if symmetric: + return symmetric_residue(n, m), m + return n, m diff --git a/phivenv/Lib/site-packages/sympy/ntheory/multinomial.py b/phivenv/Lib/site-packages/sympy/ntheory/multinomial.py new file mode 100644 index 0000000000000000000000000000000000000000..8ec50fdb533be547b9a8e60dc47568965bf89436 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/multinomial.py @@ -0,0 +1,188 @@ +from sympy.utilities.misc import as_int + + +def binomial_coefficients(n): + """Return a dictionary containing pairs :math:`{(k1,k2) : C_kn}` where + :math:`C_kn` are binomial coefficients and :math:`n=k1+k2`. + + Examples + ======== + + >>> from sympy.ntheory import binomial_coefficients + >>> binomial_coefficients(9) + {(0, 9): 1, (1, 8): 9, (2, 7): 36, (3, 6): 84, + (4, 5): 126, (5, 4): 126, (6, 3): 84, (7, 2): 36, (8, 1): 9, (9, 0): 1} + + See Also + ======== + + binomial_coefficients_list, multinomial_coefficients + """ + n = as_int(n) + d = {(0, n): 1, (n, 0): 1} + a = 1 + for k in range(1, n//2 + 1): + a = (a * (n - k + 1))//k + d[k, n - k] = d[n - k, k] = a + return d + + +def binomial_coefficients_list(n): + """ Return a list of binomial coefficients as rows of the Pascal's + triangle. + + Examples + ======== + + >>> from sympy.ntheory import binomial_coefficients_list + >>> binomial_coefficients_list(9) + [1, 9, 36, 84, 126, 126, 84, 36, 9, 1] + + See Also + ======== + + binomial_coefficients, multinomial_coefficients + """ + n = as_int(n) + d = [1] * (n + 1) + a = 1 + for k in range(1, n//2 + 1): + a = (a * (n - k + 1))//k + d[k] = d[n - k] = a + return d + + +def multinomial_coefficients(m, n): + r"""Return a dictionary containing pairs ``{(k1,k2,..,km) : C_kn}`` + where ``C_kn`` are multinomial coefficients such that + ``n=k1+k2+..+km``. + + Examples + ======== + + >>> from sympy.ntheory import multinomial_coefficients + >>> multinomial_coefficients(2, 5) # indirect doctest + {(0, 5): 1, (1, 4): 5, (2, 3): 10, (3, 2): 10, (4, 1): 5, (5, 0): 1} + + Notes + ===== + + The algorithm is based on the following result: + + .. math:: + \binom{n}{k_1, \ldots, k_m} = + \frac{k_1 + 1}{n - k_1} \sum_{i=2}^m \binom{n}{k_1 + 1, \ldots, k_i - 1, \ldots} + + Code contributed to Sage by Yann Laigle-Chapuy, copied with permission + of the author. + + See Also + ======== + + binomial_coefficients_list, binomial_coefficients + """ + m = as_int(m) + n = as_int(n) + if not m: + if n: + return {} + return {(): 1} + if m == 2: + return binomial_coefficients(n) + if m >= 2*n and n > 1: + return dict(multinomial_coefficients_iterator(m, n)) + t = [n] + [0] * (m - 1) + r = {tuple(t): 1} + if n: + j = 0 # j will be the leftmost nonzero position + else: + j = m + # enumerate tuples in co-lex order + while j < m - 1: + # compute next tuple + tj = t[j] + if j: + t[j] = 0 + t[0] = tj + if tj > 1: + t[j + 1] += 1 + j = 0 + start = 1 + v = 0 + else: + j += 1 + start = j + 1 + v = r[tuple(t)] + t[j] += 1 + # compute the value + # NB: the initialization of v was done above + for k in range(start, m): + if t[k]: + t[k] -= 1 + v += r[tuple(t)] + t[k] += 1 + t[0] -= 1 + r[tuple(t)] = (v * tj) // (n - t[0]) + return r + + +def multinomial_coefficients_iterator(m, n, _tuple=tuple): + """multinomial coefficient iterator + + This routine has been optimized for `m` large with respect to `n` by taking + advantage of the fact that when the monomial tuples `t` are stripped of + zeros, their coefficient is the same as that of the monomial tuples from + ``multinomial_coefficients(n, n)``. Therefore, the latter coefficients are + precomputed to save memory and time. + + >>> from sympy.ntheory.multinomial import multinomial_coefficients + >>> m53, m33 = multinomial_coefficients(5,3), multinomial_coefficients(3,3) + >>> m53[(0,0,0,1,2)] == m53[(0,0,1,0,2)] == m53[(1,0,2,0,0)] == m33[(0,1,2)] + True + + Examples + ======== + + >>> from sympy.ntheory.multinomial import multinomial_coefficients_iterator + >>> it = multinomial_coefficients_iterator(20,3) + >>> next(it) + ((3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0), 1) + """ + m = as_int(m) + n = as_int(n) + if m < 2*n or n == 1: + mc = multinomial_coefficients(m, n) + yield from mc.items() + else: + mc = multinomial_coefficients(n, n) + mc1 = {} + for k, v in mc.items(): + mc1[_tuple(filter(None, k))] = v + mc = mc1 + + t = [n] + [0] * (m - 1) + t1 = _tuple(t) + b = _tuple(filter(None, t1)) + yield (t1, mc[b]) + if n: + j = 0 # j will be the leftmost nonzero position + else: + j = m + # enumerate tuples in co-lex order + while j < m - 1: + # compute next tuple + tj = t[j] + if j: + t[j] = 0 + t[0] = tj + if tj > 1: + t[j + 1] += 1 + j = 0 + else: + j += 1 + t[j] += 1 + + t[0] -= 1 + t1 = _tuple(t) + b = _tuple(filter(None, t1)) + yield (t1, mc[b]) diff --git a/phivenv/Lib/site-packages/sympy/ntheory/partitions_.py b/phivenv/Lib/site-packages/sympy/ntheory/partitions_.py new file mode 100644 index 0000000000000000000000000000000000000000..953fa9e2fef146b0d3a9baad0ec5e1353ad6f237 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/partitions_.py @@ -0,0 +1,277 @@ +from mpmath.libmp import (fzero, from_int, from_rational, + fone, fhalf, bitcount, to_int, mpf_mul, mpf_div, mpf_sub, + mpf_add, mpf_sqrt, mpf_pi, mpf_cosh_sinh, mpf_cos, mpf_sin) +from .residue_ntheory import _sqrt_mod_prime_power, is_quad_residue +from sympy.utilities.decorator import deprecated +from sympy.utilities.memoization import recurrence_memo + +import math +from itertools import count + +def _pre(): + maxn = 10**5 + global _factor, _totient + _factor = [0]*maxn + _totient = [1]*maxn + lim = int(maxn**0.5) + 5 + for i in range(2, lim): + if _factor[i] == 0: + for j in range(i*i, maxn, i): + if _factor[j] == 0: + _factor[j] = i + for i in range(2, maxn): + if _factor[i] == 0: + _factor[i] = i + _totient[i] = i-1 + continue + x = _factor[i] + y = i//x + if y % x == 0: + _totient[i] = _totient[y]*x + else: + _totient[i] = _totient[y]*(x - 1) + +def _a(n, k, prec): + """ Compute the inner sum in HRR formula [1]_ + + References + ========== + + .. [1] https://msp.org/pjm/1956/6-1/pjm-v6-n1-p18-p.pdf + + """ + if k == 1: + return fone + + k1 = k + e = 0 + p = _factor[k] + while k1 % p == 0: + k1 //= p + e += 1 + k2 = k//k1 # k2 = p^e + v = 1 - 24*n + pi = mpf_pi(prec) + + if k1 == 1: + # k = p^e + if p == 2: + mod = 8*k + v = mod + v % mod + v = (v*pow(9, k - 1, mod)) % mod + m = _sqrt_mod_prime_power(v, 2, e + 3)[0] + arg = mpf_div(mpf_mul( + from_int(4*m), pi, prec), from_int(mod), prec) + return mpf_mul(mpf_mul( + from_int((-1)**e*(2 - (m % 4))), + mpf_sqrt(from_int(k), prec), prec), + mpf_sin(arg, prec), prec) + if p == 3: + mod = 3*k + v = mod + v % mod + if e > 1: + v = (v*pow(64, k//3 - 1, mod)) % mod + m = _sqrt_mod_prime_power(v, 3, e + 1)[0] + arg = mpf_div(mpf_mul(from_int(4*m), pi, prec), + from_int(mod), prec) + return mpf_mul(mpf_mul( + from_int(2*(-1)**(e + 1)*(3 - 2*(m % 3))), + mpf_sqrt(from_int(k//3), prec), prec), + mpf_sin(arg, prec), prec) + v = k + v % k + jacobi3 = -1 if k % 12 in [5, 7] else 1 + if v % p == 0: + if e == 1: + return mpf_mul( + from_int(jacobi3), + mpf_sqrt(from_int(k), prec), prec) + return fzero + if not is_quad_residue(v, p): + return fzero + _phi = p**(e - 1)*(p - 1) + v = (v*pow(576, _phi - 1, k)) + m = _sqrt_mod_prime_power(v, p, e)[0] + arg = mpf_div( + mpf_mul(from_int(4*m), pi, prec), + from_int(k), prec) + return mpf_mul(mpf_mul( + from_int(2*jacobi3), + mpf_sqrt(from_int(k), prec), prec), + mpf_cos(arg, prec), prec) + + if p != 2 or e >= 3: + d1, d2 = math.gcd(k1, 24), math.gcd(k2, 24) + e = 24//(d1*d2) + n1 = ((d2*e*n + (k2**2 - 1)//d1)* + pow(e*k2*k2*d2, _totient[k1] - 1, k1)) % k1 + n2 = ((d1*e*n + (k1**2 - 1)//d2)* + pow(e*k1*k1*d1, _totient[k2] - 1, k2)) % k2 + return mpf_mul(_a(n1, k1, prec), _a(n2, k2, prec), prec) + if e == 2: + n1 = ((8*n + 5)*pow(128, _totient[k1] - 1, k1)) % k1 + n2 = (4 + ((n - 2 - (k1**2 - 1)//8)*(k1**2)) % 4) % 4 + return mpf_mul(mpf_mul( + from_int(-1), + _a(n1, k1, prec), prec), + _a(n2, k2, prec)) + n1 = ((8*n + 1)*pow(32, _totient[k1] - 1, k1)) % k1 + n2 = (2 + (n - (k1**2 - 1)//8) % 2) % 2 + return mpf_mul(_a(n1, k1, prec), _a(n2, k2, prec), prec) + +def _d(n, j, prec, sq23pi, sqrt8): + """ + Compute the sinh term in the outer sum of the HRR formula. + The constants sqrt(2/3*pi) and sqrt(8) must be precomputed. + """ + j = from_int(j) + pi = mpf_pi(prec) + a = mpf_div(sq23pi, j, prec) + b = mpf_sub(from_int(n), from_rational(1, 24, prec), prec) + c = mpf_sqrt(b, prec) + ch, sh = mpf_cosh_sinh(mpf_mul(a, c), prec) + D = mpf_div( + mpf_sqrt(j, prec), + mpf_mul(mpf_mul(sqrt8, b), pi), prec) + E = mpf_sub(mpf_mul(a, ch), mpf_div(sh, c, prec), prec) + return mpf_mul(D, E) + + +@recurrence_memo([1, 1]) +def _partition_rec(n: int, prev) -> int: + """ Calculate the partition function P(n) + + Parameters + ========== + + n : int + nonnegative integer + + """ + v = 0 + penta = 0 # pentagonal number: 1, 5, 12, ... + for i in count(): + penta += 3*i + 1 + np = n - penta + if np < 0: + break + s = prev[np] + np -= i + 1 + # np = n - gp where gp = generalized pentagonal: 2, 7, 15, ... + if 0 <= np: + s += prev[np] + v += -s if i % 2 else s + return v + + +def _partition(n: int) -> int: + """ Calculate the partition function P(n) + + Parameters + ========== + + n : int + + """ + if n < 0: + return 0 + if (n <= 200_000 and n - _partition_rec.cache_length() < 70 or + _partition_rec.cache_length() == 2 and n < 14_400): + # There will be 2*10**5 elements created here + # and n elements created by partition, so in case we + # are going to be working with small n, we just + # use partition to calculate (and cache) the values + # since lookup is used there while summation, using + # _factor and _totient, will be used below. But we + # only do so if n is relatively close to the length + # of the cache since doing 1 calculation here is about + # the same as adding 70 elements to the cache. In addition, + # the startup here costs about the same as calculating the first + # 14,400 values via partition, so we delay startup here unless n + # is smaller than that. + return _partition_rec(n) + if '_factor' not in globals(): + _pre() + # Estimate number of bits in p(n). This formula could be tidied + pbits = int(( + math.pi*(2*n/3.)**0.5 - + math.log(4*n))/math.log(10) + 1) * \ + math.log2(10) + prec = p = int(pbits*1.1 + 100) + + # find the number of terms needed so rounded sum will be accurate + # using Rademacher's bound M(n, N) for the remainder after a partial + # sum of N terms (https://arxiv.org/pdf/1205.5991.pdf, (1.8)) + c1 = 44*math.pi**2/(225*math.sqrt(3)) + c2 = math.pi*math.sqrt(2)/75 + c3 = math.pi*math.sqrt(2/3) + def _M(n, N): + sqrt = math.sqrt + return c1/sqrt(N) + c2*sqrt(N/(n - 1))*math.sinh(c3*sqrt(n)/N) + big = max(9, math.ceil(n**0.5)) # should be too large (for n > 65, ceil should work) + assert _M(n, big) < 0.5 # else double big until too large + while big > 40 and _M(n, big) < 0.5: + big //= 2 + small = big + big = small*2 + while big - small > 1: + N = (big + small)//2 + if (er := _M(n, N)) < 0.5: + big = N + elif er >= 0.5: + small = N + M = big # done with function M; now have value + + # sanity check for expected size of answer + if M > 10**5: # i.e. M > maxn + raise ValueError("Input too big") # i.e. n > 149832547102 + + # calculate it + s = fzero + sq23pi = mpf_mul(mpf_sqrt(from_rational(2, 3, p), p), mpf_pi(p), p) + sqrt8 = mpf_sqrt(from_int(8), p) + for q in range(1, M): + a = _a(n, q, p) + d = _d(n, q, p, sq23pi, sqrt8) + s = mpf_add(s, mpf_mul(a, d), prec) + # On average, the terms decrease rapidly in magnitude. + # Dynamically reducing the precision greatly improves + # performance. + p = bitcount(abs(to_int(d))) + 50 + return int(to_int(mpf_add(s, fhalf, prec))) + + +@deprecated("""\ +The `sympy.ntheory.partitions_.npartitions` has been moved to `sympy.functions.combinatorial.numbers.partition`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def npartitions(n, verbose=False): + """ + Calculate the partition function P(n), i.e. the number of ways that + n can be written as a sum of positive integers. + + .. deprecated:: 1.13 + + The ``npartitions`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.partition` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + P(n) is computed using the Hardy-Ramanujan-Rademacher formula [1]_. + + + The correctness of this implementation has been tested through $10^{10}$. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import partition + >>> partition(25) + 1958 + + References + ========== + + .. [1] https://mathworld.wolfram.com/PartitionFunctionP.html + + """ + from sympy.functions.combinatorial.numbers import partition as func_partition + return func_partition(n) diff --git a/phivenv/Lib/site-packages/sympy/ntheory/primetest.py b/phivenv/Lib/site-packages/sympy/ntheory/primetest.py new file mode 100644 index 0000000000000000000000000000000000000000..ff3cb82cc51bf57ca345a7d72ee715c861f62e2a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/primetest.py @@ -0,0 +1,830 @@ +""" +Primality testing + +""" + +from itertools import count + +from sympy.core.sympify import sympify +from sympy.external.gmpy import (gmpy as _gmpy, gcd, jacobi, + is_square as gmpy_is_square, + bit_scan1, is_fermat_prp, is_euler_prp, + is_selfridge_prp, is_strong_selfridge_prp, + is_strong_bpsw_prp) +from sympy.external.ntheory import _lucas_sequence +from sympy.utilities.misc import as_int, filldedent + +# Note: This list should be updated whenever new Mersenne primes are found. +# Refer: https://www.mersenne.org/ +MERSENNE_PRIME_EXPONENTS = (2, 3, 5, 7, 13, 17, 19, 31, 61, 89, 107, 127, 521, 607, 1279, 2203, + 2281, 3217, 4253, 4423, 9689, 9941, 11213, 19937, 21701, 23209, 44497, 86243, 110503, 132049, + 216091, 756839, 859433, 1257787, 1398269, 2976221, 3021377, 6972593, 13466917, 20996011, 24036583, + 25964951, 30402457, 32582657, 37156667, 42643801, 43112609, 57885161, 74207281, 77232917, 82589933, + 136279841) + + +def is_fermat_pseudoprime(n, a): + r"""Returns True if ``n`` is prime or is an odd composite integer that + is coprime to ``a`` and satisfy the modular arithmetic congruence relation: + + .. math :: + a^{n-1} \equiv 1 \pmod{n} + + (where mod refers to the modulo operation). + + Parameters + ========== + + n : Integer + ``n`` is a positive integer. + a : Integer + ``a`` is a positive integer. + ``a`` and ``n`` should be relatively prime. + + Returns + ======= + + bool : If ``n`` is prime, it always returns ``True``. + The composite number that returns ``True`` is called an Fermat pseudoprime. + + Examples + ======== + + >>> from sympy.ntheory.primetest import is_fermat_pseudoprime + >>> from sympy.ntheory.factor_ import isprime + >>> for n in range(1, 1000): + ... if is_fermat_pseudoprime(n, 2) and not isprime(n): + ... print(n) + 341 + 561 + 645 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fermat_pseudoprime + """ + n, a = as_int(n), as_int(a) + if a == 1: + return n == 2 or bool(n % 2) + return is_fermat_prp(n, a) + + +def is_euler_pseudoprime(n, a): + r"""Returns True if ``n`` is prime or is an odd composite integer that + is coprime to ``a`` and satisfy the modular arithmetic congruence relation: + + .. math :: + a^{(n-1)/2} \equiv \pm 1 \pmod{n} + + (where mod refers to the modulo operation). + + Parameters + ========== + + n : Integer + ``n`` is a positive integer. + a : Integer + ``a`` is a positive integer. + ``a`` and ``n`` should be relatively prime. + + Returns + ======= + + bool : If ``n`` is prime, it always returns ``True``. + The composite number that returns ``True`` is called an Euler pseudoprime. + + Examples + ======== + + >>> from sympy.ntheory.primetest import is_euler_pseudoprime + >>> from sympy.ntheory.factor_ import isprime + >>> for n in range(1, 1000): + ... if is_euler_pseudoprime(n, 2) and not isprime(n): + ... print(n) + 341 + 561 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Euler_pseudoprime + """ + n, a = as_int(n), as_int(a) + if a < 1: + raise ValueError("a should be an integer greater than 0") + if n < 1: + raise ValueError("n should be an integer greater than 0") + if n == 1: + return False + if a == 1: + return n == 2 or bool(n % 2) # (prime or odd composite) + if n % 2 == 0: + return n == 2 + if gcd(n, a) != 1: + raise ValueError("The two numbers should be relatively prime") + return pow(a, (n - 1) // 2, n) in [1, n - 1] + + +def is_euler_jacobi_pseudoprime(n, a): + r"""Returns True if ``n`` is prime or is an odd composite integer that + is coprime to ``a`` and satisfy the modular arithmetic congruence relation: + + .. math :: + a^{(n-1)/2} \equiv \left(\frac{a}{n}\right) \pmod{n} + + (where mod refers to the modulo operation). + + Parameters + ========== + + n : Integer + ``n`` is a positive integer. + a : Integer + ``a`` is a positive integer. + ``a`` and ``n`` should be relatively prime. + + Returns + ======= + + bool : If ``n`` is prime, it always returns ``True``. + The composite number that returns ``True`` is called an Euler-Jacobi pseudoprime. + + Examples + ======== + + >>> from sympy.ntheory.primetest import is_euler_jacobi_pseudoprime + >>> from sympy.ntheory.factor_ import isprime + >>> for n in range(1, 1000): + ... if is_euler_jacobi_pseudoprime(n, 2) and not isprime(n): + ... print(n) + 561 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Euler%E2%80%93Jacobi_pseudoprime + """ + n, a = as_int(n), as_int(a) + if a == 1: + return n == 2 or bool(n % 2) + return is_euler_prp(n, a) + + +def is_square(n, prep=True): + """Return True if n == a * a for some integer a, else False. + If n is suspected of *not* being a square then this is a + quick method of confirming that it is not. + + Examples + ======== + + >>> from sympy.ntheory.primetest import is_square + >>> is_square(25) + True + >>> is_square(2) + False + + References + ========== + + .. [1] https://mersenneforum.org/showpost.php?p=110896 + + See Also + ======== + sympy.core.intfunc.isqrt + """ + if prep: + n = as_int(n) + if n < 0: + return False + if n in (0, 1): + return True + return gmpy_is_square(n) + + +def _test(n, base, s, t): + """Miller-Rabin strong pseudoprime test for one base. + Return False if n is definitely composite, True if n is + probably prime, with a probability greater than 3/4. + + """ + # do the Fermat test + b = pow(base, t, n) + if b == 1 or b == n - 1: + return True + for _ in range(s - 1): + b = pow(b, 2, n) + if b == n - 1: + return True + # see I. Niven et al. "An Introduction to Theory of Numbers", page 78 + if b == 1: + return False + return False + + +def mr(n, bases): + """Perform a Miller-Rabin strong pseudoprime test on n using a + given list of bases/witnesses. + + References + ========== + + .. [1] Richard Crandall & Carl Pomerance (2005), "Prime Numbers: + A Computational Perspective", Springer, 2nd edition, 135-138 + + A list of thresholds and the bases they require are here: + https://en.wikipedia.org/wiki/Miller%E2%80%93Rabin_primality_test#Deterministic_variants + + Examples + ======== + + >>> from sympy.ntheory.primetest import mr + >>> mr(1373651, [2, 3]) + False + >>> mr(479001599, [31, 73]) + True + + """ + from sympy.polys.domains import ZZ + + n = as_int(n) + if n < 2 or (n > 2 and n % 2 == 0): + return False + # remove powers of 2 from n-1 (= t * 2**s) + s = bit_scan1(n - 1) + t = n >> s + for base in bases: + # Bases >= n are wrapped, bases < 2 are invalid + if base >= n: + base %= n + if base >= 2: + base = ZZ(base) + if not _test(n, base, s, t): + return False + return True + + +def _lucas_extrastrong_params(n): + """Calculates the "extra strong" parameters (D, P, Q) for n. + + Parameters + ========== + + n : int + positive odd integer + + Returns + ======= + + D, P, Q: "extra strong" parameters. + ``(0, 0, 0)`` if we find a nontrivial divisor of ``n``. + + Examples + ======== + + >>> from sympy.ntheory.primetest import _lucas_extrastrong_params + >>> _lucas_extrastrong_params(101) + (12, 4, 1) + >>> _lucas_extrastrong_params(15) + (0, 0, 0) + + References + ========== + .. [1] OEIS A217719: Extra Strong Lucas Pseudoprimes + https://oeis.org/A217719 + .. [2] https://en.wikipedia.org/wiki/Lucas_pseudoprime + + """ + for P in count(3): + D = P**2 - 4 + j = jacobi(D, n) + if j == -1: + return (D, P, 1) + elif j == 0 and D % n: + return (0, 0, 0) + + +def is_lucas_prp(n): + """Standard Lucas compositeness test with Selfridge parameters. Returns + False if n is definitely composite, and True if n is a Lucas probable + prime. + + This is typically used in combination with the Miller-Rabin test. + + References + ========== + .. [1] Robert Baillie, Samuel S. Wagstaff, Lucas Pseudoprimes, + Math. Comp. Vol 35, Number 152 (1980), pp. 1391-1417, + https://doi.org/10.1090%2FS0025-5718-1980-0583518-6 + http://mpqs.free.fr/LucasPseudoprimes.pdf + .. [2] OEIS A217120: Lucas Pseudoprimes + https://oeis.org/A217120 + .. [3] https://en.wikipedia.org/wiki/Lucas_pseudoprime + + Examples + ======== + + >>> from sympy.ntheory.primetest import isprime, is_lucas_prp + >>> for i in range(10000): + ... if is_lucas_prp(i) and not isprime(i): + ... print(i) + 323 + 377 + 1159 + 1829 + 3827 + 5459 + 5777 + 9071 + 9179 + """ + n = as_int(n) + if n < 2: + return False + return is_selfridge_prp(n) + + +def is_strong_lucas_prp(n): + """Strong Lucas compositeness test with Selfridge parameters. Returns + False if n is definitely composite, and True if n is a strong Lucas + probable prime. + + This is often used in combination with the Miller-Rabin test, and + in particular, when combined with M-R base 2 creates the strong BPSW test. + + References + ========== + .. [1] Robert Baillie, Samuel S. Wagstaff, Lucas Pseudoprimes, + Math. Comp. Vol 35, Number 152 (1980), pp. 1391-1417, + https://doi.org/10.1090%2FS0025-5718-1980-0583518-6 + http://mpqs.free.fr/LucasPseudoprimes.pdf + .. [2] OEIS A217255: Strong Lucas Pseudoprimes + https://oeis.org/A217255 + .. [3] https://en.wikipedia.org/wiki/Lucas_pseudoprime + .. [4] https://en.wikipedia.org/wiki/Baillie-PSW_primality_test + + Examples + ======== + + >>> from sympy.ntheory.primetest import isprime, is_strong_lucas_prp + >>> for i in range(20000): + ... if is_strong_lucas_prp(i) and not isprime(i): + ... print(i) + 5459 + 5777 + 10877 + 16109 + 18971 + """ + n = as_int(n) + if n < 2: + return False + return is_strong_selfridge_prp(n) + + +def is_extra_strong_lucas_prp(n): + """Extra Strong Lucas compositeness test. Returns False if n is + definitely composite, and True if n is an "extra strong" Lucas probable + prime. + + The parameters are selected using P = 3, Q = 1, then incrementing P until + (D|n) == -1. The test itself is as defined in [1]_, from the + Mo and Jones preprint. The parameter selection and test are the same as + used in OEIS A217719, Perl's Math::Prime::Util, and the Lucas pseudoprime + page on Wikipedia. + + It is 20-50% faster than the strong test. + + Because of the different parameters selected, there is no relationship + between the strong Lucas pseudoprimes and extra strong Lucas pseudoprimes. + In particular, one is not a subset of the other. + + References + ========== + .. [1] Jon Grantham, Frobenius Pseudoprimes, + Math. Comp. Vol 70, Number 234 (2001), pp. 873-891, + https://doi.org/10.1090%2FS0025-5718-00-01197-2 + .. [2] OEIS A217719: Extra Strong Lucas Pseudoprimes + https://oeis.org/A217719 + .. [3] https://en.wikipedia.org/wiki/Lucas_pseudoprime + + Examples + ======== + + >>> from sympy.ntheory.primetest import isprime, is_extra_strong_lucas_prp + >>> for i in range(20000): + ... if is_extra_strong_lucas_prp(i) and not isprime(i): + ... print(i) + 989 + 3239 + 5777 + 10877 + """ + # Implementation notes: + # 1) the parameters differ from Thomas R. Nicely's. His parameter + # selection leads to pseudoprimes that overlap M-R tests, and + # contradict Baillie and Wagstaff's suggestion of (D|n) = -1. + # 2) The MathWorld page as of June 2013 specifies Q=-1. The Lucas + # sequence must have Q=1. See Grantham theorem 2.3, any of the + # references on the MathWorld page, or run it and see Q=-1 is wrong. + n = as_int(n) + if n == 2: + return True + if n < 2 or (n % 2) == 0: + return False + if gmpy_is_square(n): + return False + + D, P, Q = _lucas_extrastrong_params(n) + if D == 0: + return False + + # remove powers of 2 from n+1 (= k * 2**s) + s = bit_scan1(n + 1) + k = (n + 1) >> s + + U, V, _ = _lucas_sequence(n, P, Q, k) + + if U == 0 and (V == 2 or V == n - 2): + return True + for _ in range(1, s): + if V == 0: + return True + V = (V*V - 2) % n + return False + + +def proth_test(n): + r""" Test if the Proth number `n = k2^m + 1` is prime. where k is a positive odd number and `2^m > k`. + + Parameters + ========== + + n : Integer + ``n`` is Proth number + + Returns + ======= + + bool : If ``True``, then ``n`` is the Proth prime + + Raises + ====== + + ValueError + If ``n`` is not Proth number. + + Examples + ======== + + >>> from sympy.ntheory.primetest import proth_test + >>> proth_test(41) + True + >>> proth_test(57) + False + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Proth_prime + + """ + n = as_int(n) + if n < 3: + raise ValueError("n is not Proth number") + m = bit_scan1(n - 1) + k = n >> m + if m < k.bit_length(): + raise ValueError("n is not Proth number") + if n % 3 == 0: + return n == 3 + if k % 3: # n % 12 == 5 + return pow(3, n >> 1, n) == n - 1 + # If `n` is a square number, then `jacobi(a, n) = 1` for any `a` + if gmpy_is_square(n): + return False + # `a` may be chosen at random. + # In any case, we want to find `a` such that `jacobi(a, n) = -1`. + for a in range(5, n): + j = jacobi(a, n) + if j == -1: + return pow(a, n >> 1, n) == n - 1 + if j == 0: + return False + + +def _lucas_lehmer_primality_test(p): + r""" Test if the Mersenne number `M_p = 2^p-1` is prime. + + Parameters + ========== + + p : int + ``p`` is an odd prime number + + Returns + ======= + + bool : If ``True``, then `M_p` is the Mersenne prime + + Examples + ======== + + >>> from sympy.ntheory.primetest import _lucas_lehmer_primality_test + >>> _lucas_lehmer_primality_test(5) # 2**5 - 1 = 31 is prime + True + >>> _lucas_lehmer_primality_test(11) # 2**11 - 1 = 2047 is not prime + False + + See Also + ======== + + is_mersenne_prime + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lucas%E2%80%93Lehmer_primality_test + + """ + v = 4 + m = 2**p - 1 + for _ in range(p - 2): + v = pow(v, 2, m) - 2 + return v == 0 + + +def is_mersenne_prime(n): + """Returns True if ``n`` is a Mersenne prime, else False. + + A Mersenne prime is a prime number having the form `2^i - 1`. + + Examples + ======== + + >>> from sympy.ntheory.factor_ import is_mersenne_prime + >>> is_mersenne_prime(6) + False + >>> is_mersenne_prime(127) + True + + References + ========== + + .. [1] https://mathworld.wolfram.com/MersennePrime.html + + """ + n = as_int(n) + if n < 1: + return False + if n & (n + 1): + # n is not Mersenne number + return False + p = n.bit_length() + if p in MERSENNE_PRIME_EXPONENTS: + return True + if p < 65_000_000 or not isprime(p): + # According to GIMPS, verification was completed on September 19, 2023 for p less than 65 million. + # https://www.mersenne.org/report_milestones/ + # If p is composite number, then n=2**p-1 is composite number. + return False + result = _lucas_lehmer_primality_test(p) + if result: + raise ValueError(filldedent(''' + This Mersenne Prime, 2^%s - 1, should + be added to SymPy's known values.''' % p)) + return result + + +_MR_BASES_32 = [15591, 2018, 166, 7429, 8064, 16045, 10503, 4399, 1949, 1295, + 2776, 3620, 560, 3128, 5212, 2657, 2300, 2021, 4652, 1471, + 9336, 4018, 2398, 20462, 10277, 8028, 2213, 6219, 620, 3763, + 4852, 5012, 3185, 1333, 6227,5298, 1074, 2391, 5113, 7061, + 803, 1269, 3875, 422, 751, 580, 4729, 10239, 746, 2951, 556, + 2206, 3778, 481, 1522, 3476, 481, 2487, 3266, 5633, 488, 3373, + 6441, 3344, 17, 15105, 1490, 4154, 2036, 1882, 1813, 467, + 3307, 14042, 6371, 658, 1005, 903, 737, 1887, 7447, 1888, + 2848, 1784, 7559, 3400, 951, 13969, 4304, 177, 41, 19875, + 3110, 13221, 8726, 571, 7043, 6943, 1199, 352, 6435, 165, + 1169, 3315, 978, 233, 3003, 2562, 2994, 10587, 10030, 2377, + 1902, 5354, 4447, 1555, 263, 27027, 2283, 305, 669, 1912, 601, + 6186, 429, 1930, 14873, 1784, 1661, 524, 3577, 236, 2360, + 6146, 2850, 55637, 1753, 4178, 8466, 222, 2579, 2743, 2031, + 2226, 2276, 374, 2132, 813, 23788, 1610, 4422, 5159, 1725, + 3597, 3366, 14336, 579, 165, 1375, 10018, 12616, 9816, 1371, + 536, 1867, 10864, 857, 2206, 5788, 434, 8085, 17618, 727, + 3639, 1595, 4944, 2129, 2029, 8195, 8344, 6232, 9183, 8126, + 1870, 3296, 7455, 8947, 25017, 541, 19115, 368, 566, 5674, + 411, 522, 1027, 8215, 2050, 6544, 10049, 614, 774, 2333, 3007, + 35201, 4706, 1152, 1785, 1028, 1540, 3743, 493, 4474, 2521, + 26845, 8354, 864, 18915, 5465, 2447, 42, 4511, 1660, 166, + 1249, 6259, 2553, 304, 272, 7286, 73, 6554, 899, 2816, 5197, + 13330, 7054, 2818, 3199, 811, 922, 350, 7514, 4452, 3449, + 2663, 4708, 418, 1621, 1171, 3471, 88, 11345, 412, 1559, 194] + + +def isprime(n): + """ + Test if n is a prime number (True) or not (False). For n < 2^64 the + answer is definitive; larger n values have a small probability of actually + being pseudoprimes. + + Negative numbers (e.g. -2) are not considered prime. + + The first step is looking for trivial factors, which if found enables + a quick return. Next, if the sieve is large enough, use bisection search + on the sieve. For small numbers, a set of deterministic Miller-Rabin + tests are performed with bases that are known to have no counterexamples + in their range. Finally if the number is larger than 2^64, a strong + BPSW test is performed. While this is a probable prime test and we + believe counterexamples exist, there are no known counterexamples. + + Examples + ======== + + >>> from sympy.ntheory import isprime + >>> isprime(13) + True + >>> isprime(15) + False + + Notes + ===== + + This routine is intended only for integer input, not numerical + expressions which may represent numbers. Floats are also + rejected as input because they represent numbers of limited + precision. While it is tempting to permit 7.0 to represent an + integer there are errors that may "pass silently" if this is + allowed: + + >>> from sympy import Float, S + >>> int(1e3) == 1e3 == 10**3 + True + >>> int(1e23) == 1e23 + True + >>> int(1e23) == 10**23 + False + + >>> near_int = 1 + S(1)/10**19 + >>> near_int == int(near_int) + False + >>> n = Float(near_int, 10) # truncated by precision + >>> n % 1 == 0 + True + >>> n = Float(near_int, 20) + >>> n % 1 == 0 + False + + See Also + ======== + + sympy.ntheory.generate.primerange : Generates all primes in a given range + sympy.functions.combinatorial.numbers.primepi : Return the number of primes less than or equal to n + sympy.ntheory.generate.prime : Return the nth prime + + References + ========== + .. [1] https://en.wikipedia.org/wiki/Strong_pseudoprime + .. [2] Robert Baillie, Samuel S. Wagstaff, Lucas Pseudoprimes, + Math. Comp. Vol 35, Number 152 (1980), pp. 1391-1417, + https://doi.org/10.1090%2FS0025-5718-1980-0583518-6 + http://mpqs.free.fr/LucasPseudoprimes.pdf + .. [3] https://en.wikipedia.org/wiki/Baillie-PSW_primality_test + """ + n = as_int(n) + + # Step 1, do quick composite testing via trial division. The individual + # modulo tests benchmark faster than one or two primorial igcds for me. + # The point here is just to speedily handle small numbers and many + # composites. Step 2 only requires that n <= 2 get handled here. + if n in [2, 3, 5]: + return True + if n < 2 or (n % 2) == 0 or (n % 3) == 0 or (n % 5) == 0: + return False + if n < 49: + return True + if (n % 7) == 0 or (n % 11) == 0 or (n % 13) == 0 or (n % 17) == 0 or \ + (n % 19) == 0 or (n % 23) == 0 or (n % 29) == 0 or (n % 31) == 0 or \ + (n % 37) == 0 or (n % 41) == 0 or (n % 43) == 0 or (n % 47) == 0: + return False + if n < 2809: + return True + if n < 65077: + # There are only five Euler pseudoprimes with a least prime factor greater than 47 + return pow(2, n >> 1, n) in [1, n - 1] and n not in [8321, 31621, 42799, 49141, 49981] + + # bisection search on the sieve if the sieve is large enough + from sympy.ntheory.generate import sieve as s + if n <= s._list[-1]: + l, u = s.search(n) + return l == u + from sympy.ntheory.factor_ import factor_cache + if (ret := factor_cache.get(n)) is not None: + return ret == n + + # If we have GMPY2, skip straight to step 3 and do a strong BPSW test. + # This should be a bit faster than our step 2, and for large values will + # be a lot faster than our step 3 (C+GMP vs. Python). + if _gmpy is not None: + return is_strong_bpsw_prp(n) + + + # Step 2: deterministic Miller-Rabin testing for numbers < 2^64. See: + # https://miller-rabin.appspot.com/ + # for lists. We have made sure the M-R routine will successfully handle + # bases larger than n, so we can use the minimal set. + # In September 2015 deterministic numbers were extended to over 2^81. + # https://arxiv.org/pdf/1509.00864.pdf + # https://oeis.org/A014233 + if n < 341531: + return mr(n, [9345883071009581737]) + if n < 4296595241: + # Michal Forisek and Jakub Jancina, + # Fast Primality Testing for Integers That Fit into a Machine Word + # https://ceur-ws.org/Vol-1326/020-Forisek.pdf + h = ((n >> 16) ^ n) * 0x45d9f3b + h = ((h >> 16) ^ h) * 0x45d9f3b + h = ((h >> 16) ^ h) & 255 + return mr(n, [_MR_BASES_32[h]]) + if n < 350269456337: + return mr(n, [4230279247111683200, 14694767155120705706, 16641139526367750375]) + if n < 55245642489451: + return mr(n, [2, 141889084524735, 1199124725622454117, 11096072698276303650]) + if n < 7999252175582851: + return mr(n, [2, 4130806001517, 149795463772692060, 186635894390467037, 3967304179347715805]) + if n < 585226005592931977: + return mr(n, [2, 123635709730000, 9233062284813009, 43835965440333360, 761179012939631437, 1263739024124850375]) + if n < 18446744073709551616: + return mr(n, [2, 325, 9375, 28178, 450775, 9780504, 1795265022]) + if n < 318665857834031151167461: + return mr(n, [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]) + if n < 3317044064679887385961981: + return mr(n, [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41]) + + # We could do this instead at any point: + #if n < 18446744073709551616: + # return mr(n, [2]) and is_extra_strong_lucas_prp(n) + + # Here are tests that are safe for MR routines that don't understand + # large bases. + #if n < 9080191: + # return mr(n, [31, 73]) + #if n < 19471033: + # return mr(n, [2, 299417]) + #if n < 38010307: + # return mr(n, [2, 9332593]) + #if n < 316349281: + # return mr(n, [11000544, 31481107]) + #if n < 4759123141: + # return mr(n, [2, 7, 61]) + #if n < 105936894253: + # return mr(n, [2, 1005905886, 1340600841]) + #if n < 31858317218647: + # return mr(n, [2, 642735, 553174392, 3046413974]) + #if n < 3071837692357849: + # return mr(n, [2, 75088, 642735, 203659041, 3613982119]) + #if n < 18446744073709551616: + # return mr(n, [2, 325, 9375, 28178, 450775, 9780504, 1795265022]) + + # Step 3: BPSW. + # + # Time for isprime(10**2000 + 4561), no gmpy or gmpy2 installed + # 44.0s old isprime using 46 bases + # 5.3s strong BPSW + one random base + # 4.3s extra strong BPSW + one random base + # 4.1s strong BPSW + # 3.2s extra strong BPSW + + # Classic BPSW from page 1401 of the paper. See alternate ideas below. + return is_strong_bpsw_prp(n) + + # Using extra strong test, which is somewhat faster + #return mr(n, [2]) and is_extra_strong_lucas_prp(n) + + # Add a random M-R base + #import random + #return mr(n, [2, random.randint(3, n-1)]) and is_strong_lucas_prp(n) + + +def is_gaussian_prime(num): + r"""Test if num is a Gaussian prime number. + + References + ========== + + .. [1] https://oeis.org/wiki/Gaussian_primes + """ + + num = sympify(num) + a, b = num.as_real_imag() + a = as_int(a, strict=False) + b = as_int(b, strict=False) + if a == 0: + b = abs(b) + return isprime(b) and b % 4 == 3 + elif b == 0: + a = abs(a) + return isprime(a) and a % 4 == 3 + return isprime(a**2 + b**2) diff --git a/phivenv/Lib/site-packages/sympy/ntheory/qs.py b/phivenv/Lib/site-packages/sympy/ntheory/qs.py new file mode 100644 index 0000000000000000000000000000000000000000..acc9a7b6e0151695538a99a738ef397166497ba5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/qs.py @@ -0,0 +1,451 @@ +from math import exp, log +from sympy.core.random import _randint +from sympy.external.gmpy import bit_scan1, gcd, invert, sqrt as isqrt +from sympy.ntheory.factor_ import _perfect_power +from sympy.ntheory.primetest import isprime +from sympy.ntheory.residue_ntheory import _sqrt_mod_prime_power + + +class SievePolynomial: + def __init__(self, a, b, N): + """This class denotes the sieve polynomial. + Provide methods to compute `(a*x + b)**2 - N` and + `a*x + b` when given `x`. + + Parameters + ========== + + a : parameter of the sieve polynomial + b : parameter of the sieve polynomial + N : number to be factored + + """ + self.a = a + self.b = b + self.a2 = a**2 + self.ab = 2*a*b + self.b2 = b**2 - N + + def eval_u(self, x): + return self.a*x + self.b + + def eval_v(self, x): + return (self.a2*x + self.ab)*x + self.b2 + + +class FactorBaseElem: + """This class stores an element of the `factor_base`. + """ + def __init__(self, prime, tmem_p, log_p): + """ + Initialization of factor_base_elem. + + Parameters + ========== + + prime : prime number of the factor_base + tmem_p : Integer square root of x**2 = n mod prime + log_p : Compute Natural Logarithm of the prime + """ + self.prime = prime + self.tmem_p = tmem_p + self.log_p = log_p + # `soln1` and `soln2` are solutions to + # the equation `(a*x + b)**2 - N = 0 (mod p)`. + self.soln1 = None + self.soln2 = None + self.b_ainv = None + + +def _generate_factor_base(prime_bound, n): + """Generate `factor_base` for Quadratic Sieve. The `factor_base` + consists of all the points whose ``legendre_symbol(n, p) == 1`` + and ``p < num_primes``. Along with the prime `factor_base` also stores + natural logarithm of prime and the residue n modulo p. + It also returns the of primes numbers in the `factor_base` which are + close to 1000 and 5000. + + Parameters + ========== + + prime_bound : upper prime bound of the factor_base + n : integer to be factored + """ + from sympy.ntheory.generate import sieve + factor_base = [] + idx_1000, idx_5000 = None, None + for prime in sieve.primerange(1, prime_bound): + if pow(n, (prime - 1) // 2, prime) == 1: + if prime > 1000 and idx_1000 is None: + idx_1000 = len(factor_base) - 1 + if prime > 5000 and idx_5000 is None: + idx_5000 = len(factor_base) - 1 + residue = _sqrt_mod_prime_power(n, prime, 1)[0] + log_p = round(log(prime)*2**10) + factor_base.append(FactorBaseElem(prime, residue, log_p)) + return idx_1000, idx_5000, factor_base + + +def _generate_polynomial(N, M, factor_base, idx_1000, idx_5000, randint): + """ Generate sieve polynomials indefinitely. + Information such as `soln1` in the `factor_base` associated with + the polynomial is modified in place. + + Parameters + ========== + + N : Number to be factored + M : sieve interval + factor_base : factor_base primes + idx_1000 : index of prime number in the factor_base near 1000 + idx_5000 : index of prime number in the factor_base near to 5000 + randint : A callable that takes two integers (a, b) and returns a random integer + n such that a <= n <= b, similar to `random.randint`. + """ + approx_val = log(2*N)/2 - log(M) + start = idx_1000 or 0 + end = idx_5000 or (len(factor_base) - 1) + while True: + # Choose `a` that is close to `sqrt(2*N) / M` + best_a, best_q, best_ratio = None, None, None + for _ in range(50): + a = 1 + q = [] + while log(a) < approx_val: + rand_p = 0 + while(rand_p == 0 or rand_p in q): + rand_p = randint(start, end) + p = factor_base[rand_p].prime + a *= p + q.append(rand_p) + ratio = exp(log(a) - approx_val) + if best_ratio is None or abs(ratio - 1) < abs(best_ratio - 1): + best_q = q + best_a = a + best_ratio = ratio + + # Set `b` using the Chinese remainder theorem + a = best_a + q = best_q + B = [] + for val in q: + q_l = factor_base[val].prime + gamma = factor_base[val].tmem_p * invert(a // q_l, q_l) % q_l + if 2*gamma > q_l: + gamma = q_l - gamma + B.append(a//q_l*gamma) + b = sum(B) + g = SievePolynomial(a, b, N) + for fb in factor_base: + if a % fb.prime == 0: + fb.soln1 = None + continue + a_inv = invert(a, fb.prime) + fb.b_ainv = [2*b_elem*a_inv % fb.prime for b_elem in B] + fb.soln1 = (a_inv*(fb.tmem_p - b)) % fb.prime + fb.soln2 = (a_inv*(-fb.tmem_p - b)) % fb.prime + yield g + + # Update `b` with Gray code + for i in range(1, 2**(len(B)-1)): + v = bit_scan1(i) + neg_pow = 2*((i >> (v + 1)) % 2) - 1 + b = g.b + 2*neg_pow*B[v] + a = g.a + g = SievePolynomial(a, b, N) + for fb in factor_base: + if fb.soln1 is None: + continue + fb.soln1 = (fb.soln1 - neg_pow*fb.b_ainv[v]) % fb.prime + fb.soln2 = (fb.soln2 - neg_pow*fb.b_ainv[v]) % fb.prime + yield g + + +def _gen_sieve_array(M, factor_base): + """Sieve Stage of the Quadratic Sieve. For every prime in the factor_base + that does not divide the coefficient `a` we add log_p over the sieve_array + such that ``-M <= soln1 + i*p <= M`` and ``-M <= soln2 + i*p <= M`` where `i` + is an integer. When p = 2 then log_p is only added using + ``-M <= soln1 + i*p <= M``. + + Parameters + ========== + + M : sieve interval + factor_base : factor_base primes + """ + sieve_array = [0]*(2*M + 1) + for factor in factor_base: + if factor.soln1 is None: #The prime does not divides a + continue + for idx in range((M + factor.soln1) % factor.prime, 2*M, factor.prime): + sieve_array[idx] += factor.log_p + if factor.prime == 2: + continue + #if prime is 2 then sieve only with soln_1_p + for idx in range((M + factor.soln2) % factor.prime, 2*M, factor.prime): + sieve_array[idx] += factor.log_p + return sieve_array + + +def _check_smoothness(num, factor_base): + r""" Check if `num` is smooth with respect to the given `factor_base` + and compute its factorization vector. + + Parameters + ========== + + num : integer whose smootheness is to be checked + factor_base : factor_base primes + """ + if num < 0: + num *= -1 + vec = 1 + else: + vec = 0 + for i, fb in enumerate(factor_base, 1): + if num % fb.prime: + continue + e = 1 + num //= fb.prime + while num % fb.prime == 0: + e += 1 + num //= fb.prime + if e % 2: + vec += 1 << i + return vec, num + + +def _trial_division_stage(N, M, factor_base, sieve_array, sieve_poly, partial_relations, ERROR_TERM): + """Trial division stage. Here we trial divide the values generetated + by sieve_poly in the sieve interval and if it is a smooth number then + it is stored in `smooth_relations`. Moreover, if we find two partial relations + with same large prime then they are combined to form a smooth relation. + First we iterate over sieve array and look for values which are greater + than accumulated_val, as these values have a high chance of being smooth + number. Then using these values we find smooth relations. + In general, let ``t**2 = u*p modN`` and ``r**2 = v*p modN`` be two partial relations + with the same large prime p. Then they can be combined ``(t*r/p)**2 = u*v modN`` + to form a smooth relation. + + Parameters + ========== + + N : Number to be factored + M : sieve interval + factor_base : factor_base primes + sieve_array : stores log_p values + sieve_poly : polynomial from which we find smooth relations + partial_relations : stores partial relations with one large prime + ERROR_TERM : error term for accumulated_val + """ + accumulated_val = (log(M) + log(N)/2 - ERROR_TERM) * 2**10 + smooth_relations = [] + proper_factor = set() + partial_relation_upper_bound = 128*factor_base[-1].prime + for x, val in enumerate(sieve_array, -M): + if val < accumulated_val: + continue + v = sieve_poly.eval_v(x) + vec, num = _check_smoothness(v, factor_base) + if num == 1: + smooth_relations.append((sieve_poly.eval_u(x), v, vec)) + elif num < partial_relation_upper_bound and isprime(num): + if N % num == 0: + proper_factor.add(num) + continue + u = sieve_poly.eval_u(x) + if num in partial_relations: + u_prev, v_prev, vec_prev = partial_relations.pop(num) + u = u*u_prev*invert(num, N) % N + v = v*v_prev // num**2 + vec ^= vec_prev + smooth_relations.append((u, v, vec)) + else: + partial_relations[num] = (u, v, vec) + return smooth_relations, proper_factor + + +def _find_factor(N, smooth_relations, col): + """ Finds proper factor of N using fast gaussian reduction for modulo 2 matrix. + + Parameters + ========== + + N : Number to be factored + smooth_relations : Smooth relations vectors matrix + col : Number of columns in the matrix + + Reference + ========== + + .. [1] A fast algorithm for gaussian elimination over GF(2) and + its implementation on the GAPP. Cetin K.Koc, Sarath N.Arachchige + """ + matrix = [s_relation[2] for s_relation in smooth_relations] + row = len(matrix) + mark = [False] * row + for pos in range(col): + m = 1 << pos + for i in range(row): + if p := matrix[i] & m: + add_col = p ^ matrix[i] + matrix[i] = m + mark[i] = True + for j in range(i + 1, row): + if matrix[j] & m: + matrix[j] ^= add_col + break + + for m, mat, rel in zip(mark, matrix, smooth_relations): + if m: + continue + u, v = rel[0], rel[1] + for m1, mat1, rel1 in zip(mark, matrix, smooth_relations): + if m1 and mat & mat1: + u *= rel1[0] + v *= rel1[1] + # assert is_square(v) + v = isqrt(v) + if 1 < (g := gcd(u - v, N)) < N: + yield g + + +def qs(N, prime_bound, M, ERROR_TERM=25, seed=1234): + """Performs factorization using Self-Initializing Quadratic Sieve. + In SIQS, let N be a number to be factored, and this N should not be a + perfect power. If we find two integers such that ``X**2 = Y**2 modN`` and + ``X != +-Y modN``, then `gcd(X + Y, N)` will reveal a proper factor of N. + In order to find these integers X and Y we try to find relations of form + t**2 = u modN where u is a product of small primes. If we have enough of + these relations then we can form ``(t1*t2...ti)**2 = u1*u2...ui modN`` such that + the right hand side is a square, thus we found a relation of ``X**2 = Y**2 modN``. + + Here, several optimizations are done like using multiple polynomials for + sieving, fast changing between polynomials and using partial relations. + The use of partial relations can speeds up the factoring by 2 times. + + Parameters + ========== + + N : Number to be Factored + prime_bound : upper bound for primes in the factor base + M : Sieve Interval + ERROR_TERM : Error term for checking smoothness + seed : seed of random number generator + + Returns + ======= + + set(int) : A set of factors of N without considering multiplicity. + Returns ``{N}`` if factorization fails. + + Examples + ======== + + >>> from sympy.ntheory import qs + >>> qs(25645121643901801, 2000, 10000) + {5394769, 4753701529} + >>> qs(9804659461513846513, 2000, 10000) + {4641991, 2112166839943} + + See Also + ======== + + qs_factor + + References + ========== + + .. [1] https://pdfs.semanticscholar.org/5c52/8a975c1405bd35c65993abf5a4edb667c1db.pdf + .. [2] https://www.rieselprime.de/ziki/Self-initializing_quadratic_sieve + """ + return set(qs_factor(N, prime_bound, M, ERROR_TERM, seed)) + + +def qs_factor(N, prime_bound, M, ERROR_TERM=25, seed=1234): + """ Performs factorization using Self-Initializing Quadratic Sieve. + + Parameters + ========== + + N : Number to be Factored + prime_bound : upper bound for primes in the factor base + M : Sieve Interval + ERROR_TERM : Error term for checking smoothness + seed : seed of random number generator + + Returns + ======= + + dict[int, int] : Factors of N. + Returns ``{N: 1}`` if factorization fails. + Note that the key is not always a prime number. + + Examples + ======== + + >>> from sympy.ntheory import qs_factor + >>> qs_factor(1009 * 100003, 2000, 10000) + {1009: 1, 100003: 1} + + See Also + ======== + + qs + + """ + if N < 2: + raise ValueError("N should be greater than 1") + factors = {} + smooth_relations = [] + partial_relations = {} + # Eliminate the possibility of even numbers, + # prime numbers, and perfect powers. + if N % 2 == 0: + e = 1 + N //= 2 + while N % 2 == 0: + N //= 2 + e += 1 + factors[2] = e + if isprime(N): + factors[N] = 1 + return factors + if result := _perfect_power(N, 3): + n, e = result + factors[n] = e + return factors + N_copy = N + randint = _randint(seed) + idx_1000, idx_5000, factor_base = _generate_factor_base(prime_bound, N) + threshold = len(factor_base) * 105//100 + for g in _generate_polynomial(N, M, factor_base, idx_1000, idx_5000, randint): + sieve_array = _gen_sieve_array(M, factor_base) + s_rel, p_f = _trial_division_stage(N, M, factor_base, sieve_array, g, partial_relations, ERROR_TERM) + smooth_relations += s_rel + for p in p_f: + if N_copy % p: + continue + e = 1 + N_copy //= p + while N_copy % p == 0: + N_copy //= p + e += 1 + factors[p] = e + if threshold <= len(smooth_relations): + break + + for factor in _find_factor(N, smooth_relations, len(factor_base) + 1): + if N_copy % factor == 0: + e = 1 + N_copy //= factor + while N_copy % factor == 0: + N_copy //= factor + e += 1 + factors[factor] = e + if N_copy == 1 or isprime(N_copy): + break + if N_copy != 1: + factors[N_copy] = 1 + return factors diff --git a/phivenv/Lib/site-packages/sympy/ntheory/residue_ntheory.py b/phivenv/Lib/site-packages/sympy/ntheory/residue_ntheory.py new file mode 100644 index 0000000000000000000000000000000000000000..eba024161194605aabebd10ee30bf09acb90270b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/residue_ntheory.py @@ -0,0 +1,1963 @@ +from __future__ import annotations + +from sympy.external.gmpy import (gcd, lcm, invert, sqrt, jacobi, + bit_scan1, remove) +from sympy.polys import Poly +from sympy.polys.domains import ZZ +from sympy.polys.galoistools import gf_crt1, gf_crt2, linear_congruence, gf_csolve +from .primetest import isprime +from .generate import primerange +from .factor_ import factorint, _perfect_power +from .modular import crt +from sympy.utilities.decorator import deprecated +from sympy.utilities.memoization import recurrence_memo +from sympy.utilities.misc import as_int +from sympy.utilities.iterables import iproduct +from sympy.core.random import _randint, randint + +from itertools import product + + +def n_order(a, n): + r""" Returns the order of ``a`` modulo ``n``. + + Explanation + =========== + + The order of ``a`` modulo ``n`` is the smallest integer + ``k`` such that `a^k` leaves a remainder of 1 with ``n``. + + Parameters + ========== + + a : integer + n : integer, n > 1. a and n should be relatively prime + + Returns + ======= + + int : the order of ``a`` modulo ``n`` + + Raises + ====== + + ValueError + If `n \le 1` or `\gcd(a, n) \neq 1`. + If ``a`` or ``n`` is not an integer. + + Examples + ======== + + >>> from sympy.ntheory import n_order + >>> n_order(3, 7) + 6 + >>> n_order(4, 7) + 3 + + See Also + ======== + + is_primitive_root + We say that ``a`` is a primitive root of ``n`` + when the order of ``a`` modulo ``n`` equals ``totient(n)`` + + """ + a, n = as_int(a), as_int(n) + if n <= 1: + raise ValueError("n should be an integer greater than 1") + a = a % n + # Trivial + if a == 1: + return 1 + if gcd(a, n) != 1: + raise ValueError("The two numbers should be relatively prime") + a_order = 1 + for p, e in factorint(n).items(): + pe = p**e + pe_order = (p - 1) * p**(e - 1) + factors = factorint(p - 1) + if e > 1: + factors[p] = e - 1 + order = 1 + for px, ex in factors.items(): + x = pow(a, pe_order // px**ex, pe) + while x != 1: + x = pow(x, px, pe) + order *= px + a_order = lcm(a_order, order) + return int(a_order) + + +def _primitive_root_prime_iter(p): + r""" Generates the primitive roots for a prime ``p``. + + Explanation + =========== + + The primitive roots generated are not necessarily sorted. + However, the first one is the smallest primitive root. + + Find the element whose order is ``p-1`` from the smaller one. + If we can find the first primitive root ``g``, we can use the following theorem. + + .. math :: + \operatorname{ord}(g^k) = \frac{\operatorname{ord}(g)}{\gcd(\operatorname{ord}(g), k)} + + From the assumption that `\operatorname{ord}(g)=p-1`, + it is a necessary and sufficient condition for + `\operatorname{ord}(g^k)=p-1` that `\gcd(p-1, k)=1`. + + Parameters + ========== + + p : odd prime + + Yields + ====== + + int + the primitive roots of ``p`` + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _primitive_root_prime_iter + >>> sorted(_primitive_root_prime_iter(19)) + [2, 3, 10, 13, 14, 15] + + References + ========== + + .. [1] W. Stein "Elementary Number Theory" (2011), page 44 + + """ + if p == 3: + yield 2 + return + # Let p = +-1 (mod 4a). Legendre symbol (a/p) = 1, so `a` is not the primitive root. + # Corollary : If p = +-1 (mod 8), then 2 is not the primitive root of p. + g_min = 3 if p % 8 in [1, 7] else 2 + if p < 41: + # small case + g = 5 if p == 23 else g_min + else: + v = [(p - 1) // i for i in factorint(p - 1).keys()] + for g in range(g_min, p): + if all(pow(g, pw, p) != 1 for pw in v): + break + yield g + # g**k is the primitive root of p iff gcd(p - 1, k) = 1 + for k in range(3, p, 2): + if gcd(p - 1, k) == 1: + yield pow(g, k, p) + + +def _primitive_root_prime_power_iter(p, e): + r""" Generates the primitive roots of `p^e`. + + Explanation + =========== + + Let ``g`` be the primitive root of ``p``. + If `g^{p-1} \not\equiv 1 \pmod{p^2}`, then ``g`` is primitive root of `p^e`. + Thus, if we find a primitive root ``g`` of ``p``, + then `g, g+p, g+2p, \ldots, g+(p-1)p` are primitive roots of `p^2` except one. + That one satisfies `\hat{g}^{p-1} \equiv 1 \pmod{p^2}`. + If ``h`` is the primitive root of `p^2`, + then `h, h+p^2, h+2p^2, \ldots, h+(p^{e-2}-1)p^e` are primitive roots of `p^e`. + + Parameters + ========== + + p : odd prime + e : positive integer + + Yields + ====== + + int + the primitive roots of `p^e` + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _primitive_root_prime_power_iter + >>> sorted(_primitive_root_prime_power_iter(5, 2)) + [2, 3, 8, 12, 13, 17, 22, 23] + + """ + if e == 1: + yield from _primitive_root_prime_iter(p) + else: + p2 = p**2 + for g in _primitive_root_prime_iter(p): + t = (g - pow(g, 2 - p, p2)) % p2 + for k in range(0, p2, p): + if k != t: + yield from (g + k + m for m in range(0, p**e, p2)) + + +def _primitive_root_prime_power2_iter(p, e): + r""" Generates the primitive roots of `2p^e`. + + Explanation + =========== + + If ``g`` is the primitive root of ``p**e``, + then the odd one of ``g`` and ``g+p**e`` is the primitive root of ``2*p**e``. + + Parameters + ========== + + p : odd prime + e : positive integer + + Yields + ====== + + int + the primitive roots of `2p^e` + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _primitive_root_prime_power2_iter + >>> sorted(_primitive_root_prime_power2_iter(5, 2)) + [3, 13, 17, 23, 27, 33, 37, 47] + + """ + for g in _primitive_root_prime_power_iter(p, e): + if g % 2 == 1: + yield g + else: + yield g + p**e + + +def primitive_root(p, smallest=True): + r""" Returns a primitive root of ``p`` or None. + + Explanation + =========== + + For the definition of primitive root, + see the explanation of ``is_primitive_root``. + + The primitive root of ``p`` exist only for + `p = 2, 4, q^e, 2q^e` (``q`` is an odd prime). + Now, if we know the primitive root of ``q``, + we can calculate the primitive root of `q^e`, + and if we know the primitive root of `q^e`, + we can calculate the primitive root of `2q^e`. + When there is no need to find the smallest primitive root, + this property can be used to obtain a fast primitive root. + On the other hand, when we want the smallest primitive root, + we naively determine whether it is a primitive root or not. + + Parameters + ========== + + p : integer, p > 1 + smallest : if True the smallest primitive root is returned or None + + Returns + ======= + + int | None : + If the primitive root exists, return the primitive root of ``p``. + If not, return None. + + Raises + ====== + + ValueError + If `p \le 1` or ``p`` is not an integer. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import primitive_root + >>> primitive_root(19) + 2 + >>> primitive_root(21) is None + True + >>> primitive_root(50, smallest=False) + 27 + + See Also + ======== + + is_primitive_root + + References + ========== + + .. [1] W. Stein "Elementary Number Theory" (2011), page 44 + .. [2] P. Hackman "Elementary Number Theory" (2009), Chapter C + + """ + p = as_int(p) + if p <= 1: + raise ValueError("p should be an integer greater than 1") + if p <= 4: + return p - 1 + p_even = p % 2 == 0 + if not p_even: + q = p # p is odd + elif p % 4: + q = p//2 # p had 1 factor of 2 + else: + return None # p had more than one factor of 2 + if isprime(q): + e = 1 + else: + m = _perfect_power(q, 3) + if not m: + return None + q, e = m + if not isprime(q): + return None + if not smallest: + if p_even: + return next(_primitive_root_prime_power2_iter(q, e)) + return next(_primitive_root_prime_power_iter(q, e)) + if p_even: + for i in range(3, p, 2): + if i % q and is_primitive_root(i, p): + return i + g = next(_primitive_root_prime_iter(q)) + if e == 1 or pow(g, q - 1, q**2) != 1: + return g + for i in range(g + 1, p): + if i % q and is_primitive_root(i, p): + return i + + +def is_primitive_root(a, p): + r""" Returns True if ``a`` is a primitive root of ``p``. + + Explanation + =========== + + ``a`` is said to be the primitive root of ``p`` if `\gcd(a, p) = 1` and + `\phi(p)` is the smallest positive number s.t. + + `a^{\phi(p)} \equiv 1 \pmod{p}`. + + where `\phi(p)` is Euler's totient function. + + The primitive root of ``p`` exist only for + `p = 2, 4, q^e, 2q^e` (``q`` is an odd prime). + Hence, if it is not such a ``p``, it returns False. + To determine the primitive root, we need to know + the prime factorization of ``q-1``. + The hardness of the determination depends on this complexity. + + Parameters + ========== + + a : integer + p : integer, ``p`` > 1. ``a`` and ``p`` should be relatively prime + + Returns + ======= + + bool : If True, ``a`` is the primitive root of ``p``. + + Raises + ====== + + ValueError + If `p \le 1` or `\gcd(a, p) \neq 1`. + If ``a`` or ``p`` is not an integer. + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import totient + >>> from sympy.ntheory import is_primitive_root, n_order + >>> is_primitive_root(3, 10) + True + >>> is_primitive_root(9, 10) + False + >>> n_order(3, 10) == totient(10) + True + >>> n_order(9, 10) == totient(10) + False + + See Also + ======== + + primitive_root + + """ + a, p = as_int(a), as_int(p) + if p <= 1: + raise ValueError("p should be an integer greater than 1") + a = a % p + if gcd(a, p) != 1: + raise ValueError("The two numbers should be relatively prime") + # Primitive root of p exist only for + # p = 2, 4, q**e, 2*q**e (q is odd prime) + if p <= 4: + # The primitive root is only p-1. + return a == p - 1 + if p % 2: + q = p # p is odd + elif p % 4: + q = p//2 # p had 1 factor of 2 + else: + return False # p had more than one factor of 2 + if isprime(q): + group_order = q - 1 + factors = factorint(q - 1).keys() + else: + m = _perfect_power(q, 3) + if not m: + return False + q, e = m + if not isprime(q): + return False + group_order = q**(e - 1)*(q - 1) + factors = set(factorint(q - 1).keys()) + factors.add(q) + return all(pow(a, group_order // prime, p) != 1 for prime in factors) + + +def _sqrt_mod_tonelli_shanks(a, p): + """ + Returns the square root in the case of ``p`` prime with ``p == 1 (mod 8)`` + + Assume that the root exists. + + Parameters + ========== + + a : int + p : int + prime number. should be ``p % 8 == 1`` + + Returns + ======= + + int : Generally, there are two roots, but only one is returned. + Which one is returned is random. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _sqrt_mod_tonelli_shanks + >>> _sqrt_mod_tonelli_shanks(2, 17) in [6, 11] + True + + References + ========== + + .. [1] Carl Pomerance, Richard Crandall, Prime Numbers: A Computational Perspective, + 2nd Edition (2005), page 101, ISBN:978-0387252827 + + """ + s = bit_scan1(p - 1) + t = p >> s + # find a non-quadratic residue + if p % 12 == 5: + # Legendre symbol (3/p) == -1 if p % 12 in [5, 7] + d = 3 + elif p % 5 in [2, 3]: + # Legendre symbol (5/p) == -1 if p % 5 in [2, 3] + d = 5 + else: + while 1: + d = randint(6, p - 1) + if jacobi(d, p) == -1: + break + #assert legendre_symbol(d, p) == -1 + A = pow(a, t, p) + D = pow(d, t, p) + m = 0 + for i in range(s): + adm = A*pow(D, m, p) % p + adm = pow(adm, 2**(s - 1 - i), p) + if adm % p == p - 1: + m += 2**i + #assert A*pow(D, m, p) % p == 1 + x = pow(a, (t + 1)//2, p)*pow(D, m//2, p) % p + return x + + +def sqrt_mod(a, p, all_roots=False): + """ + Find a root of ``x**2 = a mod p``. + + Parameters + ========== + + a : integer + p : positive integer + all_roots : if True the list of roots is returned or None + + Notes + ===== + + If there is no root it is returned None; else the returned root + is less or equal to ``p // 2``; in general is not the smallest one. + It is returned ``p // 2`` only if it is the only root. + + Use ``all_roots`` only when it is expected that all the roots fit + in memory; otherwise use ``sqrt_mod_iter``. + + Examples + ======== + + >>> from sympy.ntheory import sqrt_mod + >>> sqrt_mod(11, 43) + 21 + >>> sqrt_mod(17, 32, True) + [7, 9, 23, 25] + """ + if all_roots: + return sorted(sqrt_mod_iter(a, p)) + p = abs(as_int(p)) + halfp = p // 2 + x = None + for r in sqrt_mod_iter(a, p): + if r < halfp: + return r + elif r > halfp: + return p - r + else: + x = r + return x + + +def sqrt_mod_iter(a, p, domain=int): + """ + Iterate over solutions to ``x**2 = a mod p``. + + Parameters + ========== + + a : integer + p : positive integer + domain : integer domain, ``int``, ``ZZ`` or ``Integer`` + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import sqrt_mod_iter + >>> list(sqrt_mod_iter(11, 43)) + [21, 22] + + See Also + ======== + + sqrt_mod : Same functionality, but you want a sorted list or only one solution. + + """ + a, p = as_int(a), abs(as_int(p)) + v = [] + pv = [] + _product = product + for px, ex in factorint(p).items(): + if a % px: + # `len(rx)` is at most 4 + rx = _sqrt_mod_prime_power(a, px, ex) + else: + # `len(list(rx))` can be assumed to be large. + # The `itertools.product` is disadvantageous in terms of memory usage. + # It is also inferior to iproduct in speed if not all Cartesian products are needed. + rx = _sqrt_mod1(a, px, ex) + _product = iproduct + if not rx: + return + v.append(rx) + pv.append(px**ex) + if len(v) == 1: + yield from map(domain, v[0]) + else: + mm, e, s = gf_crt1(pv, ZZ) + for vx in _product(*v): + yield domain(gf_crt2(vx, pv, mm, e, s, ZZ)) + + +def _sqrt_mod_prime_power(a, p, k): + """ + Find the solutions to ``x**2 = a mod p**k`` when ``a % p != 0``. + If no solution exists, return ``None``. + Solutions are returned in an ascending list. + + Parameters + ========== + + a : integer + p : prime number + k : positive integer + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _sqrt_mod_prime_power + >>> _sqrt_mod_prime_power(11, 43, 1) + [21, 22] + + References + ========== + + .. [1] P. Hackman "Elementary Number Theory" (2009), page 160 + .. [2] http://www.numbertheory.org/php/squareroot.html + .. [3] [Gathen99]_ + """ + pk = p**k + a = a % pk + + if p == 2: + # see Ref.[2] + if a % 8 != 1: + return None + # Trivial + if k <= 3: + return list(range(1, pk, 2)) + r = 1 + # r is one of the solutions to x**2 - a = 0 (mod 2**3). + # Hensel lift them to solutions of x**2 - a = 0 (mod 2**k) + # if r**2 - a = 0 mod 2**nx but not mod 2**(nx+1) + # then r + 2**(nx - 1) is a root mod 2**(nx+1) + for nx in range(3, k): + if ((r**2 - a) >> nx) % 2: + r += 1 << (nx - 1) + # r is a solution of x**2 - a = 0 (mod 2**k), and + # there exist other solutions -r, r+h, -(r+h), and these are all solutions. + h = 1 << (k - 1) + return sorted([r, pk - r, (r + h) % pk, -(r + h) % pk]) + + # If the Legendre symbol (a/p) is not 1, no solution exists. + if jacobi(a, p) != 1: + return None + if p % 4 == 3: + res = pow(a, (p + 1) // 4, p) + elif p % 8 == 5: + res = pow(a, (p + 3) // 8, p) + if pow(res, 2, p) != a % p: + res = res * pow(2, (p - 1) // 4, p) % p + else: + res = _sqrt_mod_tonelli_shanks(a, p) + if k > 1: + # Hensel lifting with Newton iteration, see Ref.[3] chapter 9 + # with f(x) = x**2 - a; one has f'(a) != 0 (mod p) for p != 2 + px = p + for _ in range(k.bit_length() - 1): + px = px**2 + frinv = invert(2*res, px) + res = (res - (res**2 - a)*frinv) % px + if k & (k - 1): # If k is not a power of 2 + frinv = invert(2*res, pk) + res = (res - (res**2 - a)*frinv) % pk + return sorted([res, pk - res]) + + +def _sqrt_mod1(a, p, n): + """ + Find solution to ``x**2 == a mod p**n`` when ``a % p == 0``. + If no solution exists, return ``None``. + + Parameters + ========== + + a : integer + p : prime number, p must divide a + n : positive integer + + References + ========== + + .. [1] http://www.numbertheory.org/php/squareroot.html + """ + pn = p**n + a = a % pn + if a == 0: + # case gcd(a, p**k) = p**n + return range(0, pn, p**((n + 1) // 2)) + # case gcd(a, p**k) = p**r, r < n + a, r = remove(a, p) + if r % 2 == 1: + return None + res = _sqrt_mod_prime_power(a, p, n - r) + if res is None: + return None + m = r // 2 + return (x for rx in res for x in range(rx*p**m, pn, p**(n - m))) + + +def is_quad_residue(a, p): + """ + Returns True if ``a`` (mod ``p``) is in the set of squares mod ``p``, + i.e a % p in set([i**2 % p for i in range(p)]). + + Parameters + ========== + + a : integer + p : positive integer + + Returns + ======= + + bool : If True, ``x**2 == a (mod p)`` has solution. + + Raises + ====== + + ValueError + If ``a``, ``p`` is not integer. + If ``p`` is not positive. + + Examples + ======== + + >>> from sympy.ntheory import is_quad_residue + >>> is_quad_residue(21, 100) + True + + Indeed, ``pow(39, 2, 100)`` would be 21. + + >>> is_quad_residue(21, 120) + False + + That is, for any integer ``x``, ``pow(x, 2, 120)`` is not 21. + + If ``p`` is an odd + prime, an iterative method is used to make the determination: + + >>> from sympy.ntheory import is_quad_residue + >>> sorted(set([i**2 % 7 for i in range(7)])) + [0, 1, 2, 4] + >>> [j for j in range(7) if is_quad_residue(j, 7)] + [0, 1, 2, 4] + + See Also + ======== + + legendre_symbol, jacobi_symbol, sqrt_mod + """ + a, p = as_int(a), as_int(p) + if p < 1: + raise ValueError('p must be > 0') + a %= p + if a < 2 or p < 3: + return True + # Since we want to compute the Jacobi symbol, + # we separate p into the odd part and the rest. + t = bit_scan1(p) + if t: + # The existence of a solution to a power of 2 is determined + # using the logic of `p==2` in `_sqrt_mod_prime_power` and `_sqrt_mod1`. + a_ = a % (1 << t) + if a_: + r = bit_scan1(a_) + if r % 2 or (a_ >> r) & 6: + return False + p >>= t + a %= p + if a < 2 or p < 3: + return True + # If Jacobi symbol is -1 or p is prime, can be determined by Jacobi symbol only + j = jacobi(a, p) + if j == -1 or isprime(p): + return j == 1 + # Checks if `x**2 = a (mod p)` has a solution + for px, ex in factorint(p).items(): + if a % px: + if jacobi(a, px) != 1: + return False + else: + a_ = a % px**ex + if a_ == 0: + continue + a_, r = remove(a_, px) + if r % 2 or jacobi(a_, px) != 1: + return False + return True + + +def is_nthpow_residue(a, n, m): + """ + Returns True if ``x**n == a (mod m)`` has solutions. + + References + ========== + + .. [1] P. Hackman "Elementary Number Theory" (2009), page 76 + + """ + a = a % m + a, n, m = as_int(a), as_int(n), as_int(m) + if m <= 0: + raise ValueError('m must be > 0') + if n < 0: + raise ValueError('n must be >= 0') + if n == 0: + if m == 1: + return False + return a == 1 + if a == 0: + return True + if n == 1: + return True + if n == 2: + return is_quad_residue(a, m) + return all(_is_nthpow_residue_bign_prime_power(a, n, p, e) + for p, e in factorint(m).items()) + + +def _is_nthpow_residue_bign_prime_power(a, n, p, k): + r""" + Returns True if `x^n = a \pmod{p^k}` has solutions for `n > 2`. + + Parameters + ========== + + a : positive integer + n : integer, n > 2 + p : prime number + k : positive integer + + """ + while a % p == 0: + a %= pow(p, k) + if not a: + return True + a, mu = remove(a, p) + if mu % n: + return False + k -= mu + if p != 2: + f = p**(k - 1)*(p - 1) # f = totient(p**k) + return pow(a, f // gcd(f, n), pow(p, k)) == 1 + if n & 1: + return True + c = min(bit_scan1(n) + 2, k) + return a % pow(2, c) == 1 + + +def _nthroot_mod1(s, q, p, all_roots): + """ + Root of ``x**q = s mod p``, ``p`` prime and ``q`` divides ``p - 1``. + Assume that the root exists. + + Parameters + ========== + + s : integer + q : integer, n > 2. ``q`` divides ``p - 1``. + p : prime number + all_roots : if False returns the smallest root, else the list of roots + + Returns + ======= + + list[int] | int : + Root of ``x**q = s mod p``. If ``all_roots == True``, + returned ascending list. otherwise, returned an int. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _nthroot_mod1 + >>> _nthroot_mod1(5, 3, 13, False) + 7 + >>> _nthroot_mod1(13, 4, 17, True) + [3, 5, 12, 14] + + References + ========== + + .. [1] A. M. Johnston, A Generalized qth Root Algorithm, + ACM-SIAM Symposium on Discrete Algorithms (1999), pp. 929-930 + + """ + g = next(_primitive_root_prime_iter(p)) + r = s + for qx, ex in factorint(q).items(): + f = (p - 1) // qx**ex + while f % qx == 0: + f //= qx + z = f*invert(-f, qx) + x = (1 + z) // qx + t = discrete_log(p, pow(r, f, p), pow(g, f*qx, p)) + for _ in range(ex): + # assert t == discrete_log(p, pow(r, f, p), pow(g, f*qx, p)) + r = pow(r, x, p)*pow(g, -z*t % (p - 1), p) % p + t //= qx + res = [r] + h = pow(g, (p - 1) // q, p) + #assert pow(h, q, p) == 1 + hx = r + for _ in range(q - 1): + hx = (hx*h) % p + res.append(hx) + if all_roots: + res.sort() + return res + return min(res) + + +def _nthroot_mod_prime_power(a, n, p, k): + """ Root of ``x**n = a mod p**k``. + + Parameters + ========== + + a : integer + n : integer, n > 2 + p : prime number + k : positive integer + + Returns + ======= + + list[int] : + Ascending list of roots of ``x**n = a mod p**k``. + If no solution exists, return ``[]``. + + """ + if not _is_nthpow_residue_bign_prime_power(a, n, p, k): + return [] + a_mod_p = a % p + if a_mod_p == 0: + base_roots = [0] + elif (p - 1) % n == 0: + base_roots = _nthroot_mod1(a_mod_p, n, p, all_roots=True) + else: + # The roots of ``x**n - a = 0 (mod p)`` are roots of + # ``gcd(x**n - a, x**(p - 1) - 1) = 0 (mod p)`` + pa = n + pb = p - 1 + b = 1 + if pa < pb: + a_mod_p, pa, b, pb = b, pb, a_mod_p, pa + # gcd(x**pa - a, x**pb - b) = gcd(x**pb - b, x**pc - c) + # where pc = pa % pb; c = b**-q * a mod p + while pb: + q, pc = divmod(pa, pb) + c = pow(b, -q, p) * a_mod_p % p + pa, pb = pb, pc + a_mod_p, b = b, c + if pa == 1: + base_roots = [a_mod_p] + elif pa == 2: + base_roots = sqrt_mod(a_mod_p, p, all_roots=True) + else: + base_roots = _nthroot_mod1(a_mod_p, pa, p, all_roots=True) + if k == 1: + return base_roots + a %= p**k + tot_roots = set() + for root in base_roots: + diff = pow(root, n - 1, p)*n % p + new_base = p + if diff != 0: + m_inv = invert(diff, p) + for _ in range(k - 1): + new_base *= p + tmp = pow(root, n, new_base) - a + tmp *= m_inv + root = (root - tmp) % new_base + tot_roots.add(root) + else: + roots_in_base = {root} + for _ in range(k - 1): + new_base *= p + new_roots = set() + for k_ in roots_in_base: + if pow(k_, n, new_base) != a % new_base: + continue + while k_ not in new_roots: + new_roots.add(k_) + k_ = (k_ + (new_base // p)) % new_base + roots_in_base = new_roots + tot_roots = tot_roots | roots_in_base + return sorted(tot_roots) + + +def nthroot_mod(a, n, p, all_roots=False): + """ + Find the solutions to ``x**n = a mod p``. + + Parameters + ========== + + a : integer + n : positive integer + p : positive integer + all_roots : if False returns the smallest root, else the list of roots + + Returns + ======= + + list[int] | int | None : + solutions to ``x**n = a mod p``. + The table of the output type is: + + ========== ========== ========== + all_roots has roots Returns + ========== ========== ========== + True Yes list[int] + True No [] + False Yes int + False No None + ========== ========== ========== + + Raises + ====== + + ValueError + If ``a``, ``n`` or ``p`` is not integer. + If ``n`` or ``p`` is not positive. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import nthroot_mod + >>> nthroot_mod(11, 4, 19) + 8 + >>> nthroot_mod(11, 4, 19, True) + [8, 11] + >>> nthroot_mod(68, 3, 109) + 23 + + References + ========== + + .. [1] P. Hackman "Elementary Number Theory" (2009), page 76 + + """ + a = a % p + a, n, p = as_int(a), as_int(n), as_int(p) + + if n < 1: + raise ValueError("n should be positive") + if p < 1: + raise ValueError("p should be positive") + if n == 1: + return [a] if all_roots else a + if n == 2: + return sqrt_mod(a, p, all_roots) + base = [] + prime_power = [] + for q, e in factorint(p).items(): + tot_roots = _nthroot_mod_prime_power(a, n, q, e) + if not tot_roots: + return [] if all_roots else None + prime_power.append(q**e) + base.append(sorted(tot_roots)) + P, E, S = gf_crt1(prime_power, ZZ) + ret = sorted(map(int, {gf_crt2(c, prime_power, P, E, S, ZZ) + for c in product(*base)})) + if all_roots: + return ret + if ret: + return ret[0] + + +def quadratic_residues(p) -> list[int]: + """ + Returns the list of quadratic residues. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import quadratic_residues + >>> quadratic_residues(7) + [0, 1, 2, 4] + """ + p = as_int(p) + r = {pow(i, 2, p) for i in range(p // 2 + 1)} + return sorted(r) + + +@deprecated("""\ +The `sympy.ntheory.residue_ntheory.legendre_symbol` has been moved to `sympy.functions.combinatorial.numbers.legendre_symbol`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def legendre_symbol(a, p): + r""" + Returns the Legendre symbol `(a / p)`. + + .. deprecated:: 1.13 + + The ``legendre_symbol`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.legendre_symbol` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + For an integer ``a`` and an odd prime ``p``, the Legendre symbol is + defined as + + .. math :: + \genfrac(){}{}{a}{p} = \begin{cases} + 0 & \text{if } p \text{ divides } a\\ + 1 & \text{if } a \text{ is a quadratic residue modulo } p\\ + -1 & \text{if } a \text{ is a quadratic nonresidue modulo } p + \end{cases} + + Parameters + ========== + + a : integer + p : odd prime + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import legendre_symbol + >>> [legendre_symbol(i, 7) for i in range(7)] + [0, 1, 1, -1, 1, -1, -1] + >>> sorted(set([i**2 % 7 for i in range(7)])) + [0, 1, 2, 4] + + See Also + ======== + + is_quad_residue, jacobi_symbol + + """ + from sympy.functions.combinatorial.numbers import legendre_symbol as _legendre_symbol + return _legendre_symbol(a, p) + + +@deprecated("""\ +The `sympy.ntheory.residue_ntheory.jacobi_symbol` has been moved to `sympy.functions.combinatorial.numbers.jacobi_symbol`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def jacobi_symbol(m, n): + r""" + Returns the Jacobi symbol `(m / n)`. + + .. deprecated:: 1.13 + + The ``jacobi_symbol`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.jacobi_symbol` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + For any integer ``m`` and any positive odd integer ``n`` the Jacobi symbol + is defined as the product of the Legendre symbols corresponding to the + prime factors of ``n``: + + .. math :: + \genfrac(){}{}{m}{n} = + \genfrac(){}{}{m}{p^{1}}^{\alpha_1} + \genfrac(){}{}{m}{p^{2}}^{\alpha_2} + ... + \genfrac(){}{}{m}{p^{k}}^{\alpha_k} + \text{ where } n = + p_1^{\alpha_1} + p_2^{\alpha_2} + ... + p_k^{\alpha_k} + + Like the Legendre symbol, if the Jacobi symbol `\genfrac(){}{}{m}{n} = -1` + then ``m`` is a quadratic nonresidue modulo ``n``. + + But, unlike the Legendre symbol, if the Jacobi symbol + `\genfrac(){}{}{m}{n} = 1` then ``m`` may or may not be a quadratic residue + modulo ``n``. + + Parameters + ========== + + m : integer + n : odd positive integer + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import jacobi_symbol, legendre_symbol + >>> from sympy import S + >>> jacobi_symbol(45, 77) + -1 + >>> jacobi_symbol(60, 121) + 1 + + The relationship between the ``jacobi_symbol`` and ``legendre_symbol`` can + be demonstrated as follows: + + >>> L = legendre_symbol + >>> S(45).factors() + {3: 2, 5: 1} + >>> jacobi_symbol(7, 45) == L(7, 3)**2 * L(7, 5)**1 + True + + See Also + ======== + + is_quad_residue, legendre_symbol + """ + from sympy.functions.combinatorial.numbers import jacobi_symbol as _jacobi_symbol + return _jacobi_symbol(m, n) + + +@deprecated("""\ +The `sympy.ntheory.residue_ntheory.mobius` has been moved to `sympy.functions.combinatorial.numbers.mobius`.""", +deprecated_since_version="1.13", +active_deprecations_target='deprecated-ntheory-symbolic-functions') +def mobius(n): + """ + Mobius function maps natural number to {-1, 0, 1} + + .. deprecated:: 1.13 + + The ``mobius`` function is deprecated. Use :class:`sympy.functions.combinatorial.numbers.mobius` + instead. See its documentation for more information. See + :ref:`deprecated-ntheory-symbolic-functions` for details. + + It is defined as follows: + 1) `1` if `n = 1`. + 2) `0` if `n` has a squared prime factor. + 3) `(-1)^k` if `n` is a square-free positive integer with `k` + number of prime factors. + + It is an important multiplicative function in number theory + and combinatorics. It has applications in mathematical series, + algebraic number theory and also physics (Fermion operator has very + concrete realization with Mobius Function model). + + Parameters + ========== + + n : positive integer + + Examples + ======== + + >>> from sympy.functions.combinatorial.numbers import mobius + >>> mobius(13*7) + 1 + >>> mobius(1) + 1 + >>> mobius(13*7*5) + -1 + >>> mobius(13**2) + 0 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/M%C3%B6bius_function + .. [2] Thomas Koshy "Elementary Number Theory with Applications" + + """ + from sympy.functions.combinatorial.numbers import mobius as _mobius + return _mobius(n) + + +def _discrete_log_trial_mul(n, a, b, order=None): + """ + Trial multiplication algorithm for computing the discrete logarithm of + ``a`` to the base ``b`` modulo ``n``. + + The algorithm finds the discrete logarithm using exhaustive search. This + naive method is used as fallback algorithm of ``discrete_log`` when the + group order is very small. The value ``n`` must be greater than 1. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _discrete_log_trial_mul + >>> _discrete_log_trial_mul(41, 15, 7) + 3 + + See Also + ======== + + discrete_log + + References + ========== + + .. [1] "Handbook of applied cryptography", Menezes, A. J., Van, O. P. C., & + Vanstone, S. A. (1997). + """ + a %= n + b %= n + if order is None: + order = n + x = 1 + for i in range(order): + if x == a: + return i + x = x * b % n + raise ValueError("Log does not exist") + + +def _discrete_log_shanks_steps(n, a, b, order=None): + """ + Baby-step giant-step algorithm for computing the discrete logarithm of + ``a`` to the base ``b`` modulo ``n``. + + The algorithm is a time-memory trade-off of the method of exhaustive + search. It uses `O(sqrt(m))` memory, where `m` is the group order. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _discrete_log_shanks_steps + >>> _discrete_log_shanks_steps(41, 15, 7) + 3 + + See Also + ======== + + discrete_log + + References + ========== + + .. [1] "Handbook of applied cryptography", Menezes, A. J., Van, O. P. C., & + Vanstone, S. A. (1997). + """ + a %= n + b %= n + if order is None: + order = n_order(b, n) + m = sqrt(order) + 1 + T = {} + x = 1 + for i in range(m): + T[x] = i + x = x * b % n + z = pow(b, -m, n) + x = a + for i in range(m): + if x in T: + return i * m + T[x] + x = x * z % n + raise ValueError("Log does not exist") + + +def _discrete_log_pollard_rho(n, a, b, order=None, retries=10, rseed=None): + """ + Pollard's Rho algorithm for computing the discrete logarithm of ``a`` to + the base ``b`` modulo ``n``. + + It is a randomized algorithm with the same expected running time as + ``_discrete_log_shanks_steps``, but requires a negligible amount of memory. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _discrete_log_pollard_rho + >>> _discrete_log_pollard_rho(227, 3**7, 3) + 7 + + See Also + ======== + + discrete_log + + References + ========== + + .. [1] "Handbook of applied cryptography", Menezes, A. J., Van, O. P. C., & + Vanstone, S. A. (1997). + """ + a %= n + b %= n + + if order is None: + order = n_order(b, n) + randint = _randint(rseed) + + for i in range(retries): + aa = randint(1, order - 1) + ba = randint(1, order - 1) + xa = pow(b, aa, n) * pow(a, ba, n) % n + + c = xa % 3 + if c == 0: + xb = a * xa % n + ab = aa + bb = (ba + 1) % order + elif c == 1: + xb = xa * xa % n + ab = (aa + aa) % order + bb = (ba + ba) % order + else: + xb = b * xa % n + ab = (aa + 1) % order + bb = ba + + for j in range(order): + c = xa % 3 + if c == 0: + xa = a * xa % n + ba = (ba + 1) % order + elif c == 1: + xa = xa * xa % n + aa = (aa + aa) % order + ba = (ba + ba) % order + else: + xa = b * xa % n + aa = (aa + 1) % order + + c = xb % 3 + if c == 0: + xb = a * xb % n + bb = (bb + 1) % order + elif c == 1: + xb = xb * xb % n + ab = (ab + ab) % order + bb = (bb + bb) % order + else: + xb = b * xb % n + ab = (ab + 1) % order + + c = xb % 3 + if c == 0: + xb = a * xb % n + bb = (bb + 1) % order + elif c == 1: + xb = xb * xb % n + ab = (ab + ab) % order + bb = (bb + bb) % order + else: + xb = b * xb % n + ab = (ab + 1) % order + + if xa == xb: + r = (ba - bb) % order + try: + e = invert(r, order) * (ab - aa) % order + if (pow(b, e, n) - a) % n == 0: + return e + except ZeroDivisionError: + pass + break + raise ValueError("Pollard's Rho failed to find logarithm") + + +def _discrete_log_is_smooth(n: int, factorbase: list): + """Try to factor n with respect to a given factorbase. + Upon success a list of exponents with respect to the factorbase is returned. + Otherwise None.""" + factors = [0]*len(factorbase) + for i, p in enumerate(factorbase): + while n % p == 0: # divide by p as many times as possible + factors[i] += 1 + n = n // p + if n != 1: + return None # the number factors if at the end nothing is left + return factors + + +def _discrete_log_index_calculus(n, a, b, order, rseed=None): + """ + Index Calculus algorithm for computing the discrete logarithm of ``a`` to + the base ``b`` modulo ``n``. + + The group order must be given and prime. It is not suitable for small orders + and the algorithm might fail to find a solution in such situations. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _discrete_log_index_calculus + >>> _discrete_log_index_calculus(24570203447, 23859756228, 2, 12285101723) + 4519867240 + + See Also + ======== + + discrete_log + + References + ========== + + .. [1] "Handbook of applied cryptography", Menezes, A. J., Van, O. P. C., & + Vanstone, S. A. (1997). + """ + randint = _randint(rseed) + from math import sqrt, exp, log + a %= n + b %= n + # assert isprime(order), "The order of the base must be prime." + # First choose a heuristic the bound B for the factorbase. + # We have added an extra term to the asymptotic value which + # is closer to the theoretical optimum for n up to 2^70. + B = int(exp(0.5 * sqrt( log(n) * log(log(n)) )*( 1 + 1/log(log(n)) ))) + max = 5 * B * B # expected number of tries to find a relation + factorbase = list(primerange(B)) # compute the factorbase + lf = len(factorbase) # length of the factorbase + ordermo = order-1 + abx = a + for x in range(order): + if abx == 1: + return (order - x) % order + relationa = _discrete_log_is_smooth(abx, factorbase) + if relationa: + relationa = [r % order for r in relationa] + [x] + break + abx = abx * b % n # abx = a*pow(b, x, n) % n + + else: + raise ValueError("Index Calculus failed") + + relations = [None] * lf + k = 1 # number of relations found + kk = 0 + while k < 3 * lf and kk < max: # find relations for all primes in our factor base + x = randint(1,ordermo) + relation = _discrete_log_is_smooth(pow(b,x,n), factorbase) + if relation is None: + kk += 1 + continue + k += 1 + kk = 0 + relation += [ x ] + index = lf # determine the index of the first nonzero entry + for i in range(lf): + ri = relation[i] % order + if ri> 0 and relations[i] is not None: # make this entry zero if we can + for j in range(lf+1): + relation[j] = (relation[j] - ri*relations[i][j]) % order + else: + relation[i] = ri + if relation[i] > 0 and index == lf: # is this the index of the first nonzero entry? + index = i + if index == lf or relations[index] is not None: # the relation contains no new information + continue + # the relation contains new information + rinv = pow(relation[index],-1,order) # normalize the first nonzero entry + for j in range(index,lf+1): + relation[j] = rinv * relation[j] % order + relations[index] = relation + for i in range(lf): # subtract the new relation from the one for a + if relationa[i] > 0 and relations[i] is not None: + rbi = relationa[i] + for j in range(lf+1): + relationa[j] = (relationa[j] - rbi*relations[i][j]) % order + if relationa[i] > 0: # the index of the first nonzero entry + break # we do not need to reduce further at this point + else: # all unknowns are gone + #print(f"Success after {k} relations out of {lf}") + x = (order -relationa[lf]) % order + if pow(b,x,n) == a: + return x + raise ValueError("Index Calculus failed") + raise ValueError("Index Calculus failed") + + +def _discrete_log_pohlig_hellman(n, a, b, order=None, order_factors=None): + """ + Pohlig-Hellman algorithm for computing the discrete logarithm of ``a`` to + the base ``b`` modulo ``n``. + + In order to compute the discrete logarithm, the algorithm takes advantage + of the factorization of the group order. It is more efficient when the + group order factors into many small primes. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _discrete_log_pohlig_hellman + >>> _discrete_log_pohlig_hellman(251, 210, 71) + 197 + + See Also + ======== + + discrete_log + + References + ========== + + .. [1] "Handbook of applied cryptography", Menezes, A. J., Van, O. P. C., & + Vanstone, S. A. (1997). + """ + from .modular import crt + a %= n + b %= n + + if order is None: + order = n_order(b, n) + if order_factors is None: + order_factors = factorint(order) + l = [0] * len(order_factors) + + for i, (pi, ri) in enumerate(order_factors.items()): + for j in range(ri): + aj = pow(a * pow(b, -l[i], n), order // pi**(j + 1), n) + bj = pow(b, order // pi, n) + cj = discrete_log(n, aj, bj, pi, True) + l[i] += cj * pi**j + + d, _ = crt([pi**ri for pi, ri in order_factors.items()], l) + return d + + +def discrete_log(n, a, b, order=None, prime_order=None): + """ + Compute the discrete logarithm of ``a`` to the base ``b`` modulo ``n``. + + This is a recursive function to reduce the discrete logarithm problem in + cyclic groups of composite order to the problem in cyclic groups of prime + order. + + It employs different algorithms depending on the problem (subgroup order + size, prime order or not): + + * Trial multiplication + * Baby-step giant-step + * Pollard's Rho + * Index Calculus + * Pohlig-Hellman + + Examples + ======== + + >>> from sympy.ntheory import discrete_log + >>> discrete_log(41, 15, 7) + 3 + + References + ========== + + .. [1] https://mathworld.wolfram.com/DiscreteLogarithm.html + .. [2] "Handbook of applied cryptography", Menezes, A. J., Van, O. P. C., & + Vanstone, S. A. (1997). + + """ + from math import sqrt, log + n, a, b = as_int(n), as_int(a), as_int(b) + + if n < 1: + raise ValueError("n should be positive") + if n == 1: + return 0 + + if order is None: + # Compute the order and its factoring in one pass + # order = totient(n), factors = factorint(order) + factors = {} + for px, kx in factorint(n).items(): + if kx > 1: + if px in factors: + factors[px] += kx - 1 + else: + factors[px] = kx - 1 + for py, ky in factorint(px - 1).items(): + if py in factors: + factors[py] += ky + else: + factors[py] = ky + order = 1 + for px, kx in factors.items(): + order *= px**kx + # Now the `order` is the order of the group and factors = factorint(order) + # The order of `b` divides the order of the group. + order_factors = {} + for p, e in factors.items(): + i = 0 + for _ in range(e): + if pow(b, order // p, n) == 1: + order //= p + i += 1 + else: + break + if i < e: + order_factors[p] = e - i + + if prime_order is None: + prime_order = isprime(order) + + if order < 1000: + return _discrete_log_trial_mul(n, a, b, order) + elif prime_order: + # Shanks and Pollard rho are O(sqrt(order)) while index calculus is O(exp(2*sqrt(log(n)log(log(n))))) + # we compare the expected running times to determine the algorithm which is expected to be faster + if 4*sqrt(log(n)*log(log(n))) < log(order) - 10: # the number 10 was determined experimental + return _discrete_log_index_calculus(n, a, b, order) + elif order < 1000000000000: + # Shanks seems typically faster, but uses O(sqrt(order)) memory + return _discrete_log_shanks_steps(n, a, b, order) + return _discrete_log_pollard_rho(n, a, b, order) + + return _discrete_log_pohlig_hellman(n, a, b, order, order_factors) + + + +def quadratic_congruence(a, b, c, n): + r""" + Find the solutions to `a x^2 + b x + c \equiv 0 \pmod{n}`. + + Parameters + ========== + + a : int + b : int + c : int + n : int + A positive integer. + + Returns + ======= + + list[int] : + A sorted list of solutions. If no solution exists, ``[]``. + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import quadratic_congruence + >>> quadratic_congruence(2, 5, 3, 7) # 2x^2 + 5x + 3 = 0 (mod 7) + [2, 6] + >>> quadratic_congruence(8, 6, 4, 15) # No solution + [] + + See Also + ======== + + polynomial_congruence : Solve the polynomial congruence + + """ + a = as_int(a) + b = as_int(b) + c = as_int(c) + n = as_int(n) + if n <= 1: + raise ValueError("n should be an integer greater than 1") + a %= n + b %= n + c %= n + + if a == 0: + return linear_congruence(b, -c, n) + if n == 2: + # assert a == 1 + roots = [] + if c == 0: + roots.append(0) + if (b + c) % 2: + roots.append(1) + return roots + if gcd(2*a, n) == 1: + inv_a = invert(a, n) + b *= inv_a + c *= inv_a + if b % 2: + b += n + b >>= 1 + return sorted((i - b) % n for i in sqrt_mod_iter(b**2 - c, n)) + res = set() + for i in sqrt_mod_iter(b**2 - 4*a*c, 4*a*n): + q, rem = divmod(i - b, 2*a) + if rem == 0: + res.add(q % n) + + return sorted(res) + + +def _valid_expr(expr): + """ + return coefficients of expr if it is a univariate polynomial + with integer coefficients else raise a ValueError. + """ + + if not expr.is_polynomial(): + raise ValueError("The expression should be a polynomial") + polynomial = Poly(expr) + if not polynomial.is_univariate: + raise ValueError("The expression should be univariate") + if not polynomial.domain == ZZ: + raise ValueError("The expression should should have integer coefficients") + return polynomial.all_coeffs() + + +def polynomial_congruence(expr, m): + """ + Find the solutions to a polynomial congruence equation modulo m. + + Parameters + ========== + + expr : integer coefficient polynomial + m : positive integer + + Examples + ======== + + >>> from sympy.ntheory import polynomial_congruence + >>> from sympy.abc import x + >>> expr = x**6 - 2*x**5 -35 + >>> polynomial_congruence(expr, 6125) + [3257] + + See Also + ======== + + sympy.polys.galoistools.gf_csolve : low level solving routine used by this routine + + """ + coefficients = _valid_expr(expr) + coefficients = [num % m for num in coefficients] + rank = len(coefficients) + if rank == 3: + return quadratic_congruence(*coefficients, m) + if rank == 2: + return quadratic_congruence(0, *coefficients, m) + if coefficients[0] == 1 and 1 + coefficients[-1] == sum(coefficients): + return nthroot_mod(-coefficients[-1], rank - 1, m, True) + return gf_csolve(coefficients, m) + + +def binomial_mod(n, m, k): + """Compute ``binomial(n, m) % k``. + + Explanation + =========== + + Returns ``binomial(n, m) % k`` using a generalization of Lucas' + Theorem for prime powers given by Granville [1]_, in conjunction with + the Chinese Remainder Theorem. The residue for each prime power + is calculated in time O(log^2(n) + q^4*log(n)log(p) + q^4*p*log^3(p)). + + Parameters + ========== + + n : an integer + m : an integer + k : a positive integer + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import binomial_mod + >>> binomial_mod(10, 2, 6) # binomial(10, 2) = 45 + 3 + >>> binomial_mod(17, 9, 10) # binomial(17, 9) = 24310 + 0 + + References + ========== + + .. [1] Binomial coefficients modulo prime powers, Andrew Granville, + Available: https://web.archive.org/web/20170202003812/http://www.dms.umontreal.ca/~andrew/PDF/BinCoeff.pdf + """ + if k < 1: raise ValueError('k is required to be positive') + # We decompose q into a product of prime powers and apply + # the generalization of Lucas' Theorem given by Granville + # to obtain binomial(n, k) mod p^e, and then use the Chinese + # Remainder Theorem to obtain the result mod q + if n < 0 or m < 0 or m > n: return 0 + factorisation = factorint(k) + residues = [_binomial_mod_prime_power(n, m, p, e) for p, e in factorisation.items()] + return crt([p**pw for p, pw in factorisation.items()], residues, check=False)[0] + + +def _binomial_mod_prime_power(n, m, p, q): + """Compute ``binomial(n, m) % p**q`` for a prime ``p``. + + Parameters + ========== + + n : positive integer + m : a nonnegative integer + p : a prime + q : a positive integer (the prime exponent) + + Examples + ======== + + >>> from sympy.ntheory.residue_ntheory import _binomial_mod_prime_power + >>> _binomial_mod_prime_power(10, 2, 3, 2) # binomial(10, 2) = 45 + 0 + >>> _binomial_mod_prime_power(17, 9, 2, 4) # binomial(17, 9) = 24310 + 6 + + References + ========== + + .. [1] Binomial coefficients modulo prime powers, Andrew Granville, + Available: https://web.archive.org/web/20170202003812/http://www.dms.umontreal.ca/~andrew/PDF/BinCoeff.pdf + """ + # Function/variable naming within this function follows Ref.[1] + # n!_p will be used to denote the product of integers <= n not divisible by + # p, with binomial(n, m)_p the same as binomial(n, m), but defined using + # n!_p in place of n! + modulo = pow(p, q) + + def up_factorial(u): + """Compute (u*p)!_p modulo p^q.""" + r = q // 2 + fac = prod = 1 + if r == 1 and p == 2 or 2*r + 1 in (p, p*p): + if q % 2 == 1: r += 1 + modulo, div = pow(p, 2*r), pow(p, 2*r - q) + else: + modulo, div = pow(p, 2*r + 1), pow(p, (2*r + 1) - q) + for j in range(1, r + 1): + for mul in range((j - 1)*p + 1, j*p): # ignore jp itself + fac *= mul + fac %= modulo + bj_ = bj(u, j, r) + prod *= pow(fac, bj_, modulo) + prod %= modulo + if p == 2: + sm = u // 2 + for j in range(1, r + 1): sm += j//2 * bj(u, j, r) + if sm % 2 == 1: prod *= -1 + prod %= modulo//div + return prod % modulo + + def bj(u, j, r): + """Compute the exponent of (j*p)!_p in the calculation of (u*p)!_p.""" + prod = u + for i in range(1, r + 1): + if i != j: prod *= u*u - i*i + for i in range(1, r + 1): + if i != j: prod //= j*j - i*i + return prod // j + + def up_plus_v_binom(u, v): + """Compute binomial(u*p + v, v)_p modulo p^q.""" + prod = 1 + div = invert(factorial(v), modulo) + for j in range(1, q): + b = div + for v_ in range(j*p + 1, j*p + v + 1): + b *= v_ + b %= modulo + aj = u + for i in range(1, q): + if i != j: aj *= u - i + for i in range(1, q): + if i != j: aj //= j - i + aj //= j + prod *= pow(b, aj, modulo) + prod %= modulo + return prod + + @recurrence_memo([1]) + def factorial(v, prev): + """Compute v! modulo p^q.""" + return v*prev[-1] % modulo + + def factorial_p(n): + """Compute n!_p modulo p^q.""" + u, v = divmod(n, p) + return (factorial(v) * up_factorial(u) * up_plus_v_binom(u, v)) % modulo + + prod = 1 + Nj, Mj, Rj = n, m, n - m + # e0 will be the p-adic valuation of binomial(n, m) at p + e0 = carry = eq_1 = j = 0 + while Nj: + numerator = factorial_p(Nj % modulo) + denominator = factorial_p(Mj % modulo) * factorial_p(Rj % modulo) % modulo + Nj, (Mj, mj), (Rj, rj) = Nj//p, divmod(Mj, p), divmod(Rj, p) + carry = (mj + rj + carry) // p + e0 += carry + if j >= q - 1: eq_1 += carry + prod *= numerator * invert(denominator, modulo) + prod %= modulo + j += 1 + + mul = pow(1 if p == 2 and q >= 3 else -1, eq_1, modulo) + return (pow(p, e0, modulo) * mul * prod) % modulo diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__init__.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2be6b622e7196566b1ef360d6269abb4e91f2b2e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_bbp_pi.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_bbp_pi.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b8e323cd78e3b48c81b127c53688fded124e3b4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_bbp_pi.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_continued_fraction.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_continued_fraction.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97beecec391bfe7439568aeac92a920dcd4c8f7c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_continued_fraction.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_digits.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_digits.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6ac15006d4b834622b4cf1e70fb0b61da80b1cd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_digits.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_ecm.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_ecm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7904ec6864fef17b1ef17a7d1f1a28015d534f4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_ecm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_egyptian_fraction.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_egyptian_fraction.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6150628985a0387f7f9554e3ac468794de25e1f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_egyptian_fraction.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_elliptic_curve.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_elliptic_curve.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8937bf5721acb197372b2fd71f53c7fef24e2b8c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_elliptic_curve.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_factor_.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_factor_.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f698dc883fded3c2f7327bb78188951fdf372f37 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_factor_.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_generate.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_generate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db7d70ea4d51295b0ed087936d47e2f3e16a6ad0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_generate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_hypothesis.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_hypothesis.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b801d8b63469ac1972cc6d74a89888177b83d1e4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_hypothesis.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_modular.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_modular.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..778f2e5a32dfcd416bdd33a2471dbcef02159614 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_modular.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_multinomial.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_multinomial.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27965684ec38032e47698eefb019fc5fb7043ba5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_multinomial.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_partitions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_partitions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0430ddae4c42e645473375a571bd145532d1976 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_partitions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_primetest.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_primetest.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..466eb11356faa42ce92becc0c24edfaeada18bca Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_primetest.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_qs.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_qs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93737ab05579ed30aeefc16e394935e2b35af53c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_qs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_residue.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_residue.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5474ae3a350d05c0f7fbc8788ab8de6c9459fcb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/ntheory/tests/__pycache__/test_residue.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_bbp_pi.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_bbp_pi.py new file mode 100644 index 0000000000000000000000000000000000000000..69c24970239cc45eef4140bf19dfd7d4f6a7e150 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_bbp_pi.py @@ -0,0 +1,134 @@ +from sympy.core.random import randint + +from sympy.ntheory.bbp_pi import pi_hex_digits +from sympy.testing.pytest import raises + + +# http://www.herongyang.com/Cryptography/Blowfish-First-8366-Hex-Digits-of-PI.html +# There are actually 8336 listed there; with the prepended 3 there are 8337 +# below +dig=''.join(''' +3243f6a8885a308d313198a2e03707344a4093822299f31d0082efa98ec4e6c89452821e638d013 +77be5466cf34e90c6cc0ac29b7c97c50dd3f84d5b5b54709179216d5d98979fb1bd1310ba698dfb5 +ac2ffd72dbd01adfb7b8e1afed6a267e96ba7c9045f12c7f9924a19947b3916cf70801f2e2858efc +16636920d871574e69a458fea3f4933d7e0d95748f728eb658718bcd5882154aee7b54a41dc25a59 +b59c30d5392af26013c5d1b023286085f0ca417918b8db38ef8e79dcb0603a180e6c9e0e8bb01e8a +3ed71577c1bd314b2778af2fda55605c60e65525f3aa55ab945748986263e8144055ca396a2aab10 +b6b4cc5c341141e8cea15486af7c72e993b3ee1411636fbc2a2ba9c55d741831f6ce5c3e169b8793 +1eafd6ba336c24cf5c7a325381289586773b8f48986b4bb9afc4bfe81b6628219361d809ccfb21a9 +91487cac605dec8032ef845d5de98575b1dc262302eb651b8823893e81d396acc50f6d6ff383f442 +392e0b4482a484200469c8f04a9e1f9b5e21c66842f6e96c9a670c9c61abd388f06a51a0d2d8542f +68960fa728ab5133a36eef0b6c137a3be4ba3bf0507efb2a98a1f1651d39af017666ca593e82430e +888cee8619456f9fb47d84a5c33b8b5ebee06f75d885c12073401a449f56c16aa64ed3aa62363f77 +061bfedf72429b023d37d0d724d00a1248db0fead349f1c09b075372c980991b7b25d479d8f6e8de +f7e3fe501ab6794c3b976ce0bd04c006bac1a94fb6409f60c45e5c9ec2196a246368fb6faf3e6c53 +b51339b2eb3b52ec6f6dfc511f9b30952ccc814544af5ebd09bee3d004de334afd660f2807192e4b +b3c0cba85745c8740fd20b5f39b9d3fbdb5579c0bd1a60320ad6a100c6402c7279679f25fefb1fa3 +cc8ea5e9f8db3222f83c7516dffd616b152f501ec8ad0552ab323db5fafd23876053317b483e00df +829e5c57bbca6f8ca01a87562edf1769dbd542a8f6287effc3ac6732c68c4f5573695b27b0bbca58 +c8e1ffa35db8f011a010fa3d98fd2183b84afcb56c2dd1d35b9a53e479b6f84565d28e49bc4bfb97 +90e1ddf2daa4cb7e3362fb1341cee4c6e8ef20cada36774c01d07e9efe2bf11fb495dbda4dae9091 +98eaad8e716b93d5a0d08ed1d0afc725e08e3c5b2f8e7594b78ff6e2fbf2122b648888b812900df0 +1c4fad5ea0688fc31cd1cff191b3a8c1ad2f2f2218be0e1777ea752dfe8b021fa1e5a0cc0fb56f74 +e818acf3d6ce89e299b4a84fe0fd13e0b77cc43b81d2ada8d9165fa2668095770593cc7314211a14 +77e6ad206577b5fa86c75442f5fb9d35cfebcdaf0c7b3e89a0d6411bd3ae1e7e4900250e2d2071b3 +5e226800bb57b8e0af2464369bf009b91e5563911d59dfa6aa78c14389d95a537f207d5ba202e5b9 +c5832603766295cfa911c819684e734a41b3472dca7b14a94a1b5100529a532915d60f573fbc9bc6 +e42b60a47681e6740008ba6fb5571be91ff296ec6b2a0dd915b6636521e7b9f9b6ff34052ec58556 +6453b02d5da99f8fa108ba47996e85076a4b7a70e9b5b32944db75092ec4192623ad6ea6b049a7df +7d9cee60b88fedb266ecaa8c71699a17ff5664526cc2b19ee1193602a575094c29a0591340e4183a +3e3f54989a5b429d656b8fe4d699f73fd6a1d29c07efe830f54d2d38e6f0255dc14cdd20868470eb +266382e9c6021ecc5e09686b3f3ebaefc93c9718146b6a70a1687f358452a0e286b79c5305aa5007 +373e07841c7fdeae5c8e7d44ec5716f2b8b03ada37f0500c0df01c1f040200b3ffae0cf51a3cb574 +b225837a58dc0921bdd19113f97ca92ff69432477322f547013ae5e58137c2dadcc8b576349af3dd +a7a94461460fd0030eecc8c73ea4751e41e238cd993bea0e2f3280bba1183eb3314e548b384f6db9 +086f420d03f60a04bf2cb8129024977c795679b072bcaf89afde9a771fd9930810b38bae12dccf3f +2e5512721f2e6b7124501adde69f84cd877a5847187408da17bc9f9abce94b7d8cec7aec3adb851d +fa63094366c464c3d2ef1c18473215d908dd433b3724c2ba1612a14d432a65c45150940002133ae4 +dd71dff89e10314e5581ac77d65f11199b043556f1d7a3c76b3c11183b5924a509f28fe6ed97f1fb +fa9ebabf2c1e153c6e86e34570eae96fb1860e5e0a5a3e2ab3771fe71c4e3d06fa2965dcb999e71d +0f803e89d65266c8252e4cc9789c10b36ac6150eba94e2ea78a5fc3c531e0a2df4f2f74ea7361d2b +3d1939260f19c279605223a708f71312b6ebadfe6eeac31f66e3bc4595a67bc883b17f37d1018cff +28c332ddefbe6c5aa56558218568ab9802eecea50fdb2f953b2aef7dad5b6e2f841521b628290761 +70ecdd4775619f151013cca830eb61bd960334fe1eaa0363cfb5735c904c70a239d59e9e0bcbaade +14eecc86bc60622ca79cab5cabb2f3846e648b1eaf19bdf0caa02369b9655abb5040685a323c2ab4 +b3319ee9d5c021b8f79b540b19875fa09995f7997e623d7da8f837889a97e32d7711ed935f166812 +810e358829c7e61fd696dedfa17858ba9957f584a51b2272639b83c3ff1ac24696cdb30aeb532e30 +548fd948e46dbc312858ebf2ef34c6ffeafe28ed61ee7c3c735d4a14d9e864b7e342105d14203e13 +e045eee2b6a3aaabeadb6c4f15facb4fd0c742f442ef6abbb5654f3b1d41cd2105d81e799e86854d +c7e44b476a3d816250cf62a1f25b8d2646fc8883a0c1c7b6a37f1524c369cb749247848a0b5692b2 +85095bbf00ad19489d1462b17423820e0058428d2a0c55f5ea1dadf43e233f70613372f0928d937e +41d65fecf16c223bdb7cde3759cbee74604085f2a7ce77326ea607808419f8509ee8efd85561d997 +35a969a7aac50c06c25a04abfc800bcadc9e447a2ec3453484fdd567050e1e9ec9db73dbd3105588 +cd675fda79e3674340c5c43465713e38d83d28f89ef16dff20153e21e78fb03d4ae6e39f2bdb83ad +f7e93d5a68948140f7f64c261c94692934411520f77602d4f7bcf46b2ed4a20068d40824713320f4 +6a43b7d4b7500061af1e39f62e9724454614214f74bf8b88404d95fc1d96b591af70f4ddd366a02f +45bfbc09ec03bd97857fac6dd031cb850496eb27b355fd3941da2547e6abca0a9a28507825530429 +f40a2c86dae9b66dfb68dc1462d7486900680ec0a427a18dee4f3ffea2e887ad8cb58ce0067af4d6 +b6aace1e7cd3375fecce78a399406b2a4220fe9e35d9f385b9ee39d7ab3b124e8b1dc9faf74b6d18 +5626a36631eae397b23a6efa74dd5b43326841e7f7ca7820fbfb0af54ed8feb397454056acba4895 +2755533a3a20838d87fe6ba9b7d096954b55a867bca1159a58cca9296399e1db33a62a4a563f3125 +f95ef47e1c9029317cfdf8e80204272f7080bb155c05282ce395c11548e4c66d2248c1133fc70f86 +dc07f9c9ee41041f0f404779a45d886e17325f51ebd59bc0d1f2bcc18f41113564257b7834602a9c +60dff8e8a31f636c1b0e12b4c202e1329eaf664fd1cad181156b2395e0333e92e13b240b62eebeb9 +2285b2a20ee6ba0d99de720c8c2da2f728d012784595b794fd647d0862e7ccf5f05449a36f877d48 +fac39dfd27f33e8d1e0a476341992eff743a6f6eabf4f8fd37a812dc60a1ebddf8991be14cdb6e6b +0dc67b55106d672c372765d43bdcd0e804f1290dc7cc00ffa3b5390f92690fed0b667b9ffbcedb7d +9ca091cf0bd9155ea3bb132f88515bad247b9479bf763bd6eb37392eb3cc1159798026e297f42e31 +2d6842ada7c66a2b3b12754ccc782ef11c6a124237b79251e706a1bbe64bfb63501a6b101811caed +fa3d25bdd8e2e1c3c9444216590a121386d90cec6ed5abea2a64af674eda86a85fbebfe98864e4c3 +fe9dbc8057f0f7c08660787bf86003604dd1fd8346f6381fb07745ae04d736fccc83426b33f01eab +71b08041873c005e5f77a057bebde8ae2455464299bf582e614e58f48ff2ddfda2f474ef388789bd +c25366f9c3c8b38e74b475f25546fcd9b97aeb26618b1ddf84846a0e79915f95e2466e598e20b457 +708cd55591c902de4cb90bace1bb8205d011a862487574a99eb77f19b6e0a9dc09662d09a1c43246 +33e85a1f0209f0be8c4a99a0251d6efe101ab93d1d0ba5a4dfa186f20f2868f169dcb7da83573906 +fea1e2ce9b4fcd7f5250115e01a70683faa002b5c40de6d0279af88c27773f8641c3604c0661a806 +b5f0177a28c0f586e0006058aa30dc7d6211e69ed72338ea6353c2dd94c2c21634bbcbee5690bcb6 +deebfc7da1ce591d766f05e4094b7c018839720a3d7c927c2486e3725f724d9db91ac15bb4d39eb8 +fced54557808fca5b5d83d7cd34dad0fc41e50ef5eb161e6f8a28514d96c51133c6fd5c7e756e14e +c4362abfceddc6c837d79a323492638212670efa8e406000e03a39ce37d3faf5cfabc277375ac52d +1b5cb0679e4fa33742d382274099bc9bbed5118e9dbf0f7315d62d1c7ec700c47bb78c1b6b21a190 +45b26eb1be6a366eb45748ab2fbc946e79c6a376d26549c2c8530ff8ee468dde7dd5730a1d4cd04d +c62939bbdba9ba4650ac9526e8be5ee304a1fad5f06a2d519a63ef8ce29a86ee22c089c2b843242e +f6a51e03aa9cf2d0a483c061ba9be96a4d8fe51550ba645bd62826a2f9a73a3ae14ba99586ef5562 +e9c72fefd3f752f7da3f046f6977fa0a5980e4a91587b086019b09e6ad3b3ee593e990fd5a9e34d7 +972cf0b7d9022b8b5196d5ac3a017da67dd1cf3ed67c7d2d281f9f25cfadf2b89b5ad6b4725a88f5 +4ce029ac71e019a5e647b0acfded93fa9be8d3c48d283b57ccf8d5662979132e28785f0191ed7560 +55f7960e44e3d35e8c15056dd488f46dba03a161250564f0bdc3eb9e153c9057a297271aeca93a07 +2a1b3f6d9b1e6321f5f59c66fb26dcf3197533d928b155fdf5035634828aba3cbb28517711c20ad9 +f8abcc5167ccad925f4de817513830dc8e379d58629320f991ea7a90c2fb3e7bce5121ce64774fbe +32a8b6e37ec3293d4648de53696413e680a2ae0810dd6db22469852dfd09072166b39a460a6445c0 +dd586cdecf1c20c8ae5bbef7dd1b588d40ccd2017f6bb4e3bbdda26a7e3a59ff453e350a44bcb4cd +d572eacea8fa6484bb8d6612aebf3c6f47d29be463542f5d9eaec2771bf64e6370740e0d8de75b13 +57f8721671af537d5d4040cb084eb4e2cc34d2466a0115af84e1b0042895983a1d06b89fb4ce6ea0 +486f3f3b823520ab82011a1d4b277227f8611560b1e7933fdcbb3a792b344525bda08839e151ce79 +4b2f32c9b7a01fbac9e01cc87ebcc7d1f6cf0111c3a1e8aac71a908749d44fbd9ad0dadecbd50ada +380339c32ac69136678df9317ce0b12b4ff79e59b743f5bb3af2d519ff27d9459cbf97222c15e6fc +2a0f91fc719b941525fae59361ceb69cebc2a8645912baa8d1b6c1075ee3056a0c10d25065cb03a4 +42e0ec6e0e1698db3b4c98a0be3278e9649f1f9532e0d392dfd3a0342b8971f21e1b0a74414ba334 +8cc5be7120c37632d8df359f8d9b992f2ee60b6f470fe3f11de54cda541edad891ce6279cfcd3e7e +6f1618b166fd2c1d05848fd2c5f6fb2299f523f357a632762393a8353156cccd02acf081625a75eb +b56e16369788d273ccde96629281b949d04c50901b71c65614e6c6c7bd327a140a45e1d006c3f27b +9ac9aa53fd62a80f00bb25bfe235bdd2f671126905b2040222b6cbcf7ccd769c2b53113ec01640e3 +d338abbd602547adf0ba38209cf746ce7677afa1c52075606085cbfe4e8ae88dd87aaaf9b04cf9aa +7e1948c25c02fb8a8c01c36ae4d6ebe1f990d4f869a65cdea03f09252dc208e69fb74e6132ce77e2 +5b578fdfe33ac372e6'''.split()) + + +def test_hex_pi_nth_digits(): + assert pi_hex_digits(0) == '3243f6a8885a30' + assert pi_hex_digits(1) == '243f6a8885a308' + assert pi_hex_digits(10000) == '68ac8fcfb8016c' + assert pi_hex_digits(13) == '08d313198a2e03' + assert pi_hex_digits(0, 3) == '324' + assert pi_hex_digits(0, 0) == '' + raises(ValueError, lambda: pi_hex_digits(-1)) + raises(ValueError, lambda: pi_hex_digits(0, -1)) + raises(ValueError, lambda: pi_hex_digits(3.14)) + + # this will pick a random segment to compute every time + # it is run. If it ever fails, there is an error in the + # computation. + n = randint(0, len(dig)) + prec = randint(0, len(dig) - n) + assert pi_hex_digits(n, prec) == dig[n: n + prec] diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_continued_fraction.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_continued_fraction.py new file mode 100644 index 0000000000000000000000000000000000000000..8ca6088507f1d112e9146cd5249b1143f375c2cf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_continued_fraction.py @@ -0,0 +1,77 @@ +import itertools +from sympy.core import GoldenRatio as phi +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.ntheory.continued_fraction import \ + (continued_fraction_periodic as cf_p, + continued_fraction_iterator as cf_i, + continued_fraction_convergents as cf_c, + continued_fraction_reduce as cf_r, + continued_fraction as cf) +from sympy.testing.pytest import raises + + +def test_continued_fraction(): + assert cf_p(1, 1, 10, 0) == cf_p(1, 1, 0, 1) + assert cf_p(1, -1, 10, 1) == cf_p(-1, 1, 10, -1) + t = sqrt(2) + assert cf((1 + t)*(1 - t)) == cf(-1) + for n in [0, 2, Rational(2, 3), sqrt(2), 3*sqrt(2), 1 + 2*sqrt(3)/5, + (2 - 3*sqrt(5))/7, 1 + sqrt(2), (-5 + sqrt(17))/4]: + assert (cf_r(cf(n)) - n).expand() == 0 + assert (cf_r(cf(-n)) + n).expand() == 0 + raises(ValueError, lambda: cf(sqrt(2 + sqrt(3)))) + raises(ValueError, lambda: cf(sqrt(2) + sqrt(3))) + raises(ValueError, lambda: cf(pi)) + raises(ValueError, lambda: cf(.1)) + + raises(ValueError, lambda: cf_p(1, 0, 0)) + raises(ValueError, lambda: cf_p(1, 1, -1)) + assert cf_p(4, 3, 0) == [1, 3] + assert cf_p(0, 3, 5) == [0, 1, [2, 1, 12, 1, 2, 2]] + assert cf_p(1, 1, 0) == [1] + assert cf_p(3, 4, 0) == [0, 1, 3] + assert cf_p(4, 5, 0) == [0, 1, 4] + assert cf_p(5, 6, 0) == [0, 1, 5] + assert cf_p(11, 13, 0) == [0, 1, 5, 2] + assert cf_p(16, 19, 0) == [0, 1, 5, 3] + assert cf_p(27, 32, 0) == [0, 1, 5, 2, 2] + assert cf_p(1, 2, 5) == [[1]] + assert cf_p(0, 1, 2) == [1, [2]] + assert cf_p(6, 7, 49) == [1, 1, 6] + assert cf_p(3796, 1387, 0) == [2, 1, 2, 1, 4] + assert cf_p(3245, 10000) == [0, 3, 12, 4, 13] + assert cf_p(1932, 2568) == [0, 1, 3, 26, 2] + assert cf_p(6589, 2569) == [2, 1, 1, 3, 2, 1, 3, 1, 23] + + def take(iterator, n=7): + return list(itertools.islice(iterator, n)) + + assert take(cf_i(phi)) == [1, 1, 1, 1, 1, 1, 1] + assert take(cf_i(pi)) == [3, 7, 15, 1, 292, 1, 1] + + assert list(cf_i(Rational(17, 12))) == [1, 2, 2, 2] + assert list(cf_i(Rational(-17, 12))) == [-2, 1, 1, 2, 2] + + assert list(cf_c([1, 6, 1, 8])) == [S.One, Rational(7, 6), Rational(8, 7), Rational(71, 62)] + assert list(cf_c([2])) == [S(2)] + assert list(cf_c([1, 1, 1, 1, 1, 1, 1])) == [S.One, S(2), Rational(3, 2), Rational(5, 3), + Rational(8, 5), Rational(13, 8), Rational(21, 13)] + assert list(cf_c([1, 6, Rational(-1, 2), 4])) == [S.One, Rational(7, 6), Rational(5, 4), Rational(3, 2)] + assert take(cf_c([[1]])) == [S.One, S(2), Rational(3, 2), Rational(5, 3), Rational(8, 5), + Rational(13, 8), Rational(21, 13)] + assert take(cf_c([1, [1, 2]])) == [S.One, S(2), Rational(5, 3), Rational(7, 4), Rational(19, 11), + Rational(26, 15), Rational(71, 41)] + + cf_iter_e = (2 if i == 1 else i // 3 * 2 if i % 3 == 0 else 1 for i in itertools.count(1)) + assert take(cf_c(cf_iter_e)) == [S(2), S(3), Rational(8, 3), Rational(11, 4), Rational(19, 7), + Rational(87, 32), Rational(106, 39)] + + assert cf_r([1, 6, 1, 8]) == Rational(71, 62) + assert cf_r([3]) == S(3) + assert cf_r([-1, 5, 1, 4]) == Rational(-24, 29) + assert (cf_r([0, 1, 1, 7, [24, 8]]) - (sqrt(3) + 2)/7).expand() == 0 + assert cf_r([1, 5, 9]) == Rational(55, 46) + assert (cf_r([[1]]) - (sqrt(5) + 1)/2).expand() == 0 + assert cf_r([-3, 1, 1, [2]]) == -1 - sqrt(2) diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_digits.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_digits.py new file mode 100644 index 0000000000000000000000000000000000000000..4284805f4ffe5b9095eacb2e83f2cd8076db3ee4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_digits.py @@ -0,0 +1,55 @@ +from sympy.ntheory import count_digits, digits, is_palindromic +from sympy.core.intfunc import num_digits + +from sympy.testing.pytest import raises + + +def test_num_digits(): + # depending on whether one rounds up or down or uses log or log10, + # one or more of these will fail if you don't check for the off-by + # one condition + assert num_digits(2, 2) == 2 + assert num_digits(2**48 - 1, 2) == 48 + assert num_digits(1000, 10) == 4 + assert num_digits(125, 5) == 4 + assert num_digits(100, 16) == 2 + assert num_digits(-1000, 10) == 4 + # if changes are made to the function, this structured test over + # this range will expose problems + for base in range(2, 100): + for e in range(1, 100): + n = base**e + assert num_digits(n, base) == e + 1 + assert num_digits(n + 1, base) == e + 1 + assert num_digits(n - 1, base) == e + + +def test_digits(): + assert all(digits(n, 2)[1:] == [int(d) for d in format(n, 'b')] + for n in range(20)) + assert all(digits(n, 8)[1:] == [int(d) for d in format(n, 'o')] + for n in range(20)) + assert all(digits(n, 16)[1:] == [int(d, 16) for d in format(n, 'x')] + for n in range(20)) + assert digits(2345, 34) == [34, 2, 0, 33] + assert digits(384753, 71) == [71, 1, 5, 23, 4] + assert digits(93409, 10) == [10, 9, 3, 4, 0, 9] + assert digits(-92838, 11) == [-11, 6, 3, 8, 2, 9] + assert digits(35, 10) == [10, 3, 5] + assert digits(35, 10, 3) == [10, 0, 3, 5] + assert digits(-35, 10, 4) == [-10, 0, 0, 3, 5] + raises(ValueError, lambda: digits(2, 2, 1)) + + +def test_count_digits(): + assert count_digits(55, 2) == {1: 5, 0: 1} + assert count_digits(55, 10) == {5: 2} + n = count_digits(123) + assert n[4] == 0 and type(n[4]) is int + + +def test_is_palindromic(): + assert is_palindromic(-11) + assert is_palindromic(11) + assert is_palindromic(0o121, 8) + assert not is_palindromic(123) diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_ecm.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_ecm.py new file mode 100644 index 0000000000000000000000000000000000000000..7f134e4e1cf68231e9f89242d2b8476b9edeabb8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_ecm.py @@ -0,0 +1,63 @@ +from sympy.external.gmpy import invert +from sympy.ntheory.ecm import ecm, Point +from sympy.testing.pytest import slow + +@slow +def test_ecm(): + assert ecm(3146531246531241245132451321) == {3, 100327907731, 10454157497791297} + assert ecm(46167045131415113) == {43, 2634823, 407485517} + assert ecm(631211032315670776841) == {9312934919, 67777885039} + assert ecm(398883434337287) == {99476569, 4009823} + assert ecm(64211816600515193) == {281719, 359641, 633767} + assert ecm(4269021180054189416198169786894227) == {184039, 241603, 333331, 477973, 618619, 974123} + assert ecm(4516511326451341281684513) == {3, 39869, 131743543, 95542348571} + assert ecm(4132846513818654136451) == {47, 160343, 2802377, 195692803} + assert ecm(168541512131094651323) == {79, 113, 11011069, 1714635721} + #This takes ~10secs while factorint is not able to factorize this even in ~10mins + assert ecm(7060005655815754299976961394452809, B1=100000, B2=1000000) == {6988699669998001, 1010203040506070809} + + +def test_Point(): + #The curve is of the form y**2 = x**3 + a*x**2 + x + mod = 101 + a = 10 + a_24 = (a + 2)*invert(4, mod) + p1 = Point(10, 17, a_24, mod) + p2 = p1.double() + assert p2 == Point(68, 56, a_24, mod) + p4 = p2.double() + assert p4 == Point(22, 64, a_24, mod) + p8 = p4.double() + assert p8 == Point(71, 95, a_24, mod) + p16 = p8.double() + assert p16 == Point(5, 16, a_24, mod) + p32 = p16.double() + assert p32 == Point(33, 96, a_24, mod) + + # p3 = p2 + p1 + p3 = p2.add(p1, p1) + assert p3 == Point(1, 61, a_24, mod) + # p5 = p3 + p2 or p4 + p1 + p5 = p3.add(p2, p1) + assert p5 == Point(49, 90, a_24, mod) + assert p5 == p4.add(p1, p3) + # p6 = 2*p3 + p6 = p3.double() + assert p6 == Point(87, 43, a_24, mod) + assert p6 == p4.add(p2, p2) + # p7 = p5 + p2 + p7 = p5.add(p2, p3) + assert p7 == Point(69, 23, a_24, mod) + assert p7 == p4.add(p3, p1) + assert p7 == p6.add(p1, p5) + # p9 = p5 + p4 + p9 = p5.add(p4, p1) + assert p9 == Point(56, 99, a_24, mod) + assert p9 == p6.add(p3, p3) + assert p9 == p7.add(p2, p5) + assert p9 == p8.add(p1, p7) + + assert p5 == p1.mont_ladder(5) + assert p9 == p1.mont_ladder(9) + assert p16 == p1.mont_ladder(16) + assert p9 == p3.mont_ladder(3) diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_egyptian_fraction.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_egyptian_fraction.py new file mode 100644 index 0000000000000000000000000000000000000000..a9a9fac578d93a88a648bdcf8dc34550cf4a7573 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_egyptian_fraction.py @@ -0,0 +1,49 @@ +from sympy.core.numbers import Rational +from sympy.ntheory.egyptian_fraction import egyptian_fraction +from sympy.core.add import Add +from sympy.testing.pytest import raises +from sympy.core.random import random_complex_number + + +def test_egyptian_fraction(): + def test_equality(r, alg="Greedy"): + return r == Add(*[Rational(1, i) for i in egyptian_fraction(r, alg)]) + + r = random_complex_number(a=0, c=1, b=0, d=0, rational=True) + assert test_equality(r) + + assert egyptian_fraction(Rational(4, 17)) == [5, 29, 1233, 3039345] + assert egyptian_fraction(Rational(7, 13), "Greedy") == [2, 26] + assert egyptian_fraction(Rational(23, 101), "Greedy") == \ + [5, 37, 1438, 2985448, 40108045937720] + assert egyptian_fraction(Rational(18, 23), "Takenouchi") == \ + [2, 6, 12, 35, 276, 2415] + assert egyptian_fraction(Rational(5, 6), "Graham Jewett") == \ + [6, 7, 8, 9, 10, 42, 43, 44, 45, 56, 57, 58, 72, 73, 90, 1806, 1807, + 1808, 1892, 1893, 1980, 3192, 3193, 3306, 5256, 3263442, 3263443, + 3267056, 3581556, 10192056, 10650056950806] + assert egyptian_fraction(Rational(5, 6), "Golomb") == [2, 6, 12, 20, 30] + assert egyptian_fraction(Rational(5, 121), "Golomb") == [25, 1225, 3577, 7081, 11737] + raises(ValueError, lambda: egyptian_fraction(Rational(-4, 9))) + assert egyptian_fraction(Rational(8, 3), "Golomb") == [1, 2, 3, 4, 5, 6, 7, + 14, 574, 2788, 6460, + 11590, 33062, 113820] + assert egyptian_fraction(Rational(355, 113)) == [1, 2, 3, 4, 5, 6, 7, 8, 9, + 10, 11, 12, 27, 744, 893588, + 1251493536607, + 20361068938197002344405230] + + +def test_input(): + r = (2,3), Rational(2, 3), (Rational(2), Rational(3)) + for m in ["Greedy", "Graham Jewett", "Takenouchi", "Golomb"]: + for i in r: + d = egyptian_fraction(i, m) + assert all(i.is_Integer for i in d) + if m == "Graham Jewett": + assert d == [3, 4, 12] + else: + assert d == [2, 6] + # check prefix + d = egyptian_fraction(Rational(5, 3)) + assert d == [1, 2, 6] and all(i.is_Integer for i in d) diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_elliptic_curve.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_elliptic_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..7d49d8eac72cc622fb92dfca8c54e5cc6c8dfb8f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_elliptic_curve.py @@ -0,0 +1,20 @@ +from sympy.ntheory.elliptic_curve import EllipticCurve + + +def test_elliptic_curve(): + # Point addition and multiplication + e3 = EllipticCurve(-1, 9) + p = e3(0, 3) + q = e3(-1, 3) + r = p + q + assert r.x == 1 and r.y == -3 + r = 2*p + q + assert r.x == 35 and r.y == 207 + r = -p + q + assert r.x == 37 and r.y == 225 + # Verify result in http://www.lmfdb.org/EllipticCurve/Q + # Discriminant + assert EllipticCurve(-1, 9).discriminant == -34928 + assert EllipticCurve(-2731, -55146, 1, 0, 1).discriminant == 25088 + # Torsion points + assert len(EllipticCurve(0, 1).torsion_points()) == 6 diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_factor_.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_factor_.py new file mode 100644 index 0000000000000000000000000000000000000000..5174b842c49ef0e14c1ad38d2d9ad550c2a2a388 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_factor_.py @@ -0,0 +1,702 @@ +from sympy.core.containers import Dict +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.functions.combinatorial.factorials import factorial as fac +from sympy.core.numbers import Integer, Rational +from sympy.external.gmpy import gcd + +from sympy.ntheory import (totient, + factorint, primefactors, divisors, nextprime, + pollard_rho, perfect_power, multiplicity, multiplicity_in_factorial, + divisor_count, primorial, pollard_pm1, divisor_sigma, + factorrat, reduced_totient) +from sympy.ntheory.factor_ import (smoothness, smoothness_p, proper_divisors, + antidivisors, antidivisor_count, _divisor_sigma, core, udivisors, udivisor_sigma, + udivisor_count, proper_divisor_count, primenu, primeomega, + mersenne_prime_exponent, is_perfect, is_abundant, + is_deficient, is_amicable, is_carmichael, find_carmichael_numbers_in_range, + find_first_n_carmichaels, dra, drm, _perfect_power, factor_cache) + +from sympy.testing.pytest import raises, slow + +from sympy.utilities.iterables import capture + + +def fac_multiplicity(n, p): + """Return the power of the prime number p in the + factorization of n!""" + if p > n: + return 0 + if p > n//2: + return 1 + q, m = n, 0 + while q >= p: + q //= p + m += q + return m + + +def multiproduct(seq=(), start=1): + """ + Return the product of a sequence of factors with multiplicities, + times the value of the parameter ``start``. The input may be a + sequence of (factor, exponent) pairs or a dict of such pairs. + + >>> multiproduct({3:7, 2:5}, 4) # = 3**7 * 2**5 * 4 + 279936 + + """ + if not seq: + return start + if isinstance(seq, dict): + seq = iter(seq.items()) + units = start + multi = [] + for base, exp in seq: + if not exp: + continue + elif exp == 1: + units *= base + else: + if exp % 2: + units *= base + multi.append((base, exp//2)) + return units * multiproduct(multi)**2 + + +def test_multiplicity(): + for b in range(2, 20): + for i in range(100): + assert multiplicity(b, b**i) == i + assert multiplicity(b, (b**i) * 23) == i + assert multiplicity(b, (b**i) * 1000249) == i + # Should be fast + assert multiplicity(10, 10**10023) == 10023 + # Should exit quickly + assert multiplicity(10**10, 10**10) == 1 + # Should raise errors for bad input + raises(ValueError, lambda: multiplicity(1, 1)) + raises(ValueError, lambda: multiplicity(1, 2)) + raises(ValueError, lambda: multiplicity(1.3, 2)) + raises(ValueError, lambda: multiplicity(2, 0)) + raises(ValueError, lambda: multiplicity(1.3, 0)) + + # handles Rationals + assert multiplicity(10, Rational(30, 7)) == 1 + assert multiplicity(Rational(2, 7), Rational(4, 7)) == 1 + assert multiplicity(Rational(1, 7), Rational(3, 49)) == 2 + assert multiplicity(Rational(2, 7), Rational(7, 2)) == -1 + assert multiplicity(3, Rational(1, 9)) == -2 + + +def test_multiplicity_in_factorial(): + n = fac(1000) + for i in (2, 4, 6, 12, 30, 36, 48, 60, 72, 96): + assert multiplicity(i, n) == multiplicity_in_factorial(i, 1000) + + +def test_private_perfect_power(): + assert _perfect_power(0) is False + assert _perfect_power(1) is False + assert _perfect_power(2) is False + assert _perfect_power(3) is False + for x in [2, 3, 5, 6, 7, 12, 15, 105, 100003]: + for y in range(2, 100): + assert _perfect_power(x**y) == (x, y) + if x & 1: + assert _perfect_power(x**y, next_p=3) == (x, y) + if x == 100003: + assert _perfect_power(x**y, next_p=100003) == (x, y) + assert _perfect_power(101*x**y) == False + # Catalan's conjecture + if x**y not in [8, 9]: + assert _perfect_power(x**y + 1) == False + assert _perfect_power(x**y - 1) == False + for x in range(1, 10): + for y in range(1, 10): + g = gcd(x, y) + if g == 1: + assert _perfect_power(5**x * 101**y) == False + else: + assert _perfect_power(5**x * 101**y) == (5**(x//g) * 101**(y//g), g) + + +def test_perfect_power(): + raises(ValueError, lambda: perfect_power(0.1)) + assert perfect_power(0) is False + assert perfect_power(1) is False + assert perfect_power(2) is False + assert perfect_power(3) is False + assert perfect_power(4) == (2, 2) + assert perfect_power(14) is False + assert perfect_power(25) == (5, 2) + assert perfect_power(22) is False + assert perfect_power(22, [2]) is False + assert perfect_power(137**(3*5*13)) == (137, 3*5*13) + assert perfect_power(137**(3*5*13) + 1) is False + assert perfect_power(137**(3*5*13) - 1) is False + assert perfect_power(103005006004**7) == (103005006004, 7) + assert perfect_power(103005006004**7 + 1) is False + assert perfect_power(103005006004**7 - 1) is False + assert perfect_power(103005006004**12) == (103005006004, 12) + assert perfect_power(103005006004**12 + 1) is False + assert perfect_power(103005006004**12 - 1) is False + assert perfect_power(2**10007) == (2, 10007) + assert perfect_power(2**10007 + 1) is False + assert perfect_power(2**10007 - 1) is False + assert perfect_power((9**99 + 1)**60) == (9**99 + 1, 60) + assert perfect_power((9**99 + 1)**60 + 1) is False + assert perfect_power((9**99 + 1)**60 - 1) is False + assert perfect_power((10**40000)**2, big=False) == (10**40000, 2) + assert perfect_power(10**100000) == (10, 100000) + assert perfect_power(10**100001) == (10, 100001) + assert perfect_power(13**4, [3, 5]) is False + assert perfect_power(3**4, [3, 10], factor=0) is False + assert perfect_power(3**3*5**3) == (15, 3) + assert perfect_power(2**3*5**5) is False + assert perfect_power(2*13**4) is False + assert perfect_power(2**5*3**3) is False + t = 2**24 + for d in divisors(24): + m = perfect_power(t*3**d) + assert m and m[1] == d or d == 1 + m = perfect_power(t*3**d, big=False) + assert m and m[1] == 2 or d == 1 or d == 3, (d, m) + + # negatives and non-integer rationals + assert perfect_power(-4) is False + assert perfect_power(-8) == (-2, 3) + assert perfect_power(-S(1)/8) == (-S(1)/2, 3) + assert perfect_power(S(1)/3) == False + assert perfect_power(-5**15) == (-5, 15) + assert perfect_power(-5**15, big=False) == (-3125, 3) + assert perfect_power(-5**15, [15]) == (-5, 15) + + n = -3 ** 60 + assert perfect_power(n) == (-81, 15) + assert perfect_power(n, big=False) == (-3486784401, 3) + assert perfect_power(n, [3, 5], big=True) == (-531441, 5) + assert perfect_power(n, [3, 5], big=False) == (-3486784401, 3) + assert perfect_power(n, [2]) == False + assert perfect_power(n, [2, 15]) == (-81, 15) + assert perfect_power(n, [2, 13]) == False + assert perfect_power(n, [17]) == False + assert perfect_power(n, [3]) == (-3486784401, 3) + assert perfect_power(n + 1) == False + + r = S(2) ** (2 * 5 * 7) / S(3) ** (2 * 7) + assert perfect_power(r) == (S(32) / 3, 14) + assert perfect_power(-r) == (-S(1024) / 9, 7) + assert perfect_power(r, big=False) == (S(34359738368) / 2187, 2) + assert perfect_power(r, [2, 5]) == (S(34359738368) / 2187, 2) + assert perfect_power(r, [5, 7]) == (S(1024) / 9, 7) + assert perfect_power(r, [5, 7], big=False) == (S(1024) / 9, 7) + assert perfect_power(r, [2, 5, 7], big=False) == (S(34359738368) / 2187, 2) + assert perfect_power(-r, [5, 7], big=False) == (-S(1024) / 9, 7) + + assert perfect_power(-S(1) / 8) == (-S(1) / 2, 3) + + assert perfect_power((-3)**60) == (3, 60) + assert perfect_power((-3)**61) == (-3, 61) + + assert perfect_power(S(2 ** 9) / 3 ** 12) == (S(8)/81, 3) + assert perfect_power(Rational(1, 2)**3) == (S.Half, 3) + assert perfect_power(Rational(-3, 2)**3) == (-3*S.Half, 3) + + +def test_factor_cache(): + factor_cache.cache_clear() + raises(ValueError, lambda: factor_cache.__setitem__(1, 5)) + raises(ValueError, lambda: factor_cache.__setitem__(10, 1)) + raises(ValueError, lambda: factor_cache.__setitem__(10, 10)) + raises(ValueError, lambda: factor_cache.__setitem__(10, 3)) + raises(ValueError, lambda: factor_cache.__setitem__(20, 4)) + factor_cache.maxsize = 3 + for i in range(2, 10): + factor_cache[5*i] = 5 + assert len(factor_cache) == 3 + factor_cache.maxsize = 5 + for i in range(2, 10): + factor_cache[5*i] = 5 + assert len(factor_cache) == 5 + factor_cache.maxsize = 2 + assert len(factor_cache) == 2 + factor_cache.maxsize =1000 + + factor_cache.cache_clear() + factor_cache[40] = 5 + assert factor_cache.get(40) == 5 + assert factor_cache.get(20) is None + assert factor_cache[40] == 5 + raises(KeyError, lambda: factor_cache[10]) + del factor_cache[40] + assert len(factor_cache) == 0 + raises(KeyError, lambda: factor_cache.__delitem__(40)) + factor_cache.add(100, [5, 2]) + assert len(factor_cache) == 2 + assert factor_cache[100] == 5 + + for n in [1000000007, 10000019*20000003]: + factorint(n) + assert n in factor_cache + + # Restore the initial state + factor_cache.cache_clear() + factor_cache.maxsize = 1000 + + +@slow +def test_factorint(): + assert primefactors(123456) == [2, 3, 643] + assert factorint(0) == {0: 1} + assert factorint(1) == {} + assert factorint(-1) == {-1: 1} + assert factorint(-2) == {-1: 1, 2: 1} + assert factorint(-16) == {-1: 1, 2: 4} + assert factorint(2) == {2: 1} + assert factorint(126) == {2: 1, 3: 2, 7: 1} + assert factorint(123456) == {2: 6, 3: 1, 643: 1} + assert factorint(5951757) == {3: 1, 7: 1, 29: 2, 337: 1} + assert factorint(64015937) == {7993: 1, 8009: 1} + assert factorint(2**(2**6) + 1) == {274177: 1, 67280421310721: 1} + #issue 19683 + assert factorint(10**38 - 1) == {3: 2, 11: 1, 909090909090909091: 1, 1111111111111111111: 1} + #issue 17676 + assert factorint(28300421052393658575) == {3: 1, 5: 2, 11: 2, 43: 1, 2063: 2, 4127: 1, 4129: 1} + assert factorint(2063**2 * 4127**1 * 4129**1) == {2063: 2, 4127: 1, 4129: 1} + assert factorint(2347**2 * 7039**1 * 7043**1) == {2347: 2, 7039: 1, 7043: 1} + + assert factorint(0, multiple=True) == [0] + assert factorint(1, multiple=True) == [] + assert factorint(-1, multiple=True) == [-1] + assert factorint(-2, multiple=True) == [-1, 2] + assert factorint(-16, multiple=True) == [-1, 2, 2, 2, 2] + assert factorint(2, multiple=True) == [2] + assert factorint(24, multiple=True) == [2, 2, 2, 3] + assert factorint(126, multiple=True) == [2, 3, 3, 7] + assert factorint(123456, multiple=True) == [2, 2, 2, 2, 2, 2, 3, 643] + assert factorint(5951757, multiple=True) == [3, 7, 29, 29, 337] + assert factorint(64015937, multiple=True) == [7993, 8009] + assert factorint(2**(2**6) + 1, multiple=True) == [274177, 67280421310721] + + assert factorint(fac(1, evaluate=False)) == {} + assert factorint(fac(7, evaluate=False)) == {2: 4, 3: 2, 5: 1, 7: 1} + assert factorint(fac(15, evaluate=False)) == \ + {2: 11, 3: 6, 5: 3, 7: 2, 11: 1, 13: 1} + assert factorint(fac(20, evaluate=False)) == \ + {2: 18, 3: 8, 5: 4, 7: 2, 11: 1, 13: 1, 17: 1, 19: 1} + assert factorint(fac(23, evaluate=False)) == \ + {2: 19, 3: 9, 5: 4, 7: 3, 11: 2, 13: 1, 17: 1, 19: 1, 23: 1} + + assert multiproduct(factorint(fac(200))) == fac(200) + assert multiproduct(factorint(fac(200, evaluate=False))) == fac(200) + for b, e in factorint(fac(150)).items(): + assert e == fac_multiplicity(150, b) + for b, e in factorint(fac(150, evaluate=False)).items(): + assert e == fac_multiplicity(150, b) + assert factorint(103005006059**7) == {103005006059: 7} + assert factorint(31337**191) == {31337: 191} + assert factorint(2**1000 * 3**500 * 257**127 * 383**60) == \ + {2: 1000, 3: 500, 257: 127, 383: 60} + assert len(factorint(fac(10000))) == 1229 + assert len(factorint(fac(10000, evaluate=False))) == 1229 + assert factorint(12932983746293756928584532764589230) == \ + {2: 1, 5: 1, 73: 1, 727719592270351: 1, 63564265087747: 1, 383: 1} + assert factorint(727719592270351) == {727719592270351: 1} + assert factorint(2**64 + 1, use_trial=False) == factorint(2**64 + 1) + for n in range(60000): + assert multiproduct(factorint(n)) == n + assert pollard_rho(2**64 + 1, seed=1) == 274177 + assert pollard_rho(19, seed=1) is None + assert factorint(3, limit=2) == {3: 1} + assert factorint(12345) == {3: 1, 5: 1, 823: 1} + assert factorint( + 12345, limit=3) == {4115: 1, 3: 1} # the 5 is greater than the limit + assert factorint(1, limit=1) == {} + assert factorint(0, 3) == {0: 1} + assert factorint(12, limit=1) == {12: 1} + assert factorint(30, limit=2) == {2: 1, 15: 1} + assert factorint(16, limit=2) == {2: 4} + assert factorint(124, limit=3) == {2: 2, 31: 1} + assert factorint(4*31**2, limit=3) == {2: 2, 31: 2} + p1 = nextprime(2**32) + p2 = nextprime(2**16) + p3 = nextprime(p2) + assert factorint(p1*p2*p3) == {p1: 1, p2: 1, p3: 1} + assert factorint(13*17*19, limit=15) == {13: 1, 17*19: 1} + assert factorint(1951*15013*15053, limit=2000) == {225990689: 1, 1951: 1} + assert factorint(primorial(17) + 1, use_pm1=0) == \ + {int(19026377261): 1, 3467: 1, 277: 1, 105229: 1} + # when prime b is closer than approx sqrt(8*p) to prime p then they are + # "close" and have a trivial factorization + a = nextprime(2**2**8) # 78 digits + b = nextprime(a + 2**2**4) + assert 'Fermat' in capture(lambda: factorint(a*b, verbose=1)) + + raises(ValueError, lambda: pollard_rho(4)) + raises(ValueError, lambda: pollard_pm1(3)) + raises(ValueError, lambda: pollard_pm1(10, B=2)) + # verbose coverage + n = nextprime(2**16)*nextprime(2**17)*nextprime(1901) + assert 'with primes' in capture(lambda: factorint(n, verbose=1)) + capture(lambda: factorint(nextprime(2**16)*1012, verbose=1)) + + n = nextprime(2**17) + capture(lambda: factorint(n**3, verbose=1)) # perfect power termination + capture(lambda: factorint(2*n, verbose=1)) # factoring complete msg + + # exceed 1st + n = nextprime(2**17) + n *= nextprime(n) + assert '1000' in capture(lambda: factorint(n, limit=1000, verbose=1)) + n *= nextprime(n) + assert len(factorint(n)) == 3 + assert len(factorint(n, limit=p1)) == 3 + n *= nextprime(2*n) + # exceed 2nd + assert '2001' in capture(lambda: factorint(n, limit=2000, verbose=1)) + assert capture( + lambda: factorint(n, limit=4000, verbose=1)).count('Pollard') == 2 + # non-prime pm1 result + n = nextprime(8069) + n *= nextprime(2*n)*nextprime(2*n, 2) + capture(lambda: factorint(n, verbose=1)) # non-prime pm1 result + # factor fermat composite + p1 = nextprime(2**17) + p2 = nextprime(2*p1) + assert factorint((p1*p2**2)**3) == {p1: 3, p2: 6} + # Test for non integer input + raises(ValueError, lambda: factorint(4.5)) + # test dict/Dict input + sans = '2**10*3**3' + n = {4: 2, 12: 3} + assert str(factorint(n)) == sans + assert str(factorint(Dict(n))) == sans + + +def test_divisors_and_divisor_count(): + assert divisors(-1) == [1] + assert divisors(0) == [] + assert divisors(1) == [1] + assert divisors(2) == [1, 2] + assert divisors(3) == [1, 3] + assert divisors(17) == [1, 17] + assert divisors(10) == [1, 2, 5, 10] + assert divisors(100) == [1, 2, 4, 5, 10, 20, 25, 50, 100] + assert divisors(101) == [1, 101] + assert type(divisors(2, generator=True)) is not list + + assert divisor_count(0) == 0 + assert divisor_count(-1) == 1 + assert divisor_count(1) == 1 + assert divisor_count(6) == 4 + assert divisor_count(12) == 6 + + assert divisor_count(180, 3) == divisor_count(180//3) + assert divisor_count(2*3*5, 7) == 0 + + +def test_proper_divisors_and_proper_divisor_count(): + assert proper_divisors(-1) == [] + assert proper_divisors(0) == [] + assert proper_divisors(1) == [] + assert proper_divisors(2) == [1] + assert proper_divisors(3) == [1] + assert proper_divisors(17) == [1] + assert proper_divisors(10) == [1, 2, 5] + assert proper_divisors(100) == [1, 2, 4, 5, 10, 20, 25, 50] + assert proper_divisors(1000000007) == [1] + assert type(proper_divisors(2, generator=True)) is not list + + assert proper_divisor_count(0) == 0 + assert proper_divisor_count(-1) == 0 + assert proper_divisor_count(1) == 0 + assert proper_divisor_count(36) == 8 + assert proper_divisor_count(2*3*5) == 7 + + +def test_udivisors_and_udivisor_count(): + assert udivisors(-1) == [1] + assert udivisors(0) == [] + assert udivisors(1) == [1] + assert udivisors(2) == [1, 2] + assert udivisors(3) == [1, 3] + assert udivisors(17) == [1, 17] + assert udivisors(10) == [1, 2, 5, 10] + assert udivisors(100) == [1, 4, 25, 100] + assert udivisors(101) == [1, 101] + assert udivisors(1000) == [1, 8, 125, 1000] + assert type(udivisors(2, generator=True)) is not list + + assert udivisor_count(0) == 0 + assert udivisor_count(-1) == 1 + assert udivisor_count(1) == 1 + assert udivisor_count(6) == 4 + assert udivisor_count(12) == 4 + + assert udivisor_count(180) == 8 + assert udivisor_count(2*3*5*7) == 16 + + +def test_issue_6981(): + S = set(divisors(4)).union(set(divisors(Integer(2)))) + assert S == {1,2,4} + + +def test_issue_4356(): + assert factorint(1030903) == {53: 2, 367: 1} + + +def test_divisors(): + assert divisors(28) == [1, 2, 4, 7, 14, 28] + assert list(divisors(3*5*7, 1)) == [1, 3, 5, 15, 7, 21, 35, 105] + assert divisors(0) == [] + + +def test_divisor_count(): + assert divisor_count(0) == 0 + assert divisor_count(6) == 4 + + +def test_proper_divisors(): + assert proper_divisors(-1) == [] + assert proper_divisors(28) == [1, 2, 4, 7, 14] + assert list(proper_divisors(3*5*7, True)) == [1, 3, 5, 15, 7, 21, 35] + + +def test_proper_divisor_count(): + assert proper_divisor_count(6) == 3 + assert proper_divisor_count(108) == 11 + + +def test_antidivisors(): + assert antidivisors(-1) == [] + assert antidivisors(-3) == [2] + assert antidivisors(14) == [3, 4, 9] + assert antidivisors(237) == [2, 5, 6, 11, 19, 25, 43, 95, 158] + assert antidivisors(12345) == [2, 6, 7, 10, 30, 1646, 3527, 4938, 8230] + assert antidivisors(393216) == [262144] + assert sorted(x for x in antidivisors(3*5*7, 1)) == \ + [2, 6, 10, 11, 14, 19, 30, 42, 70] + assert antidivisors(1) == [] + assert type(antidivisors(2, generator=True)) is not list + +def test_antidivisor_count(): + assert antidivisor_count(0) == 0 + assert antidivisor_count(-1) == 0 + assert antidivisor_count(-4) == 1 + assert antidivisor_count(20) == 3 + assert antidivisor_count(25) == 5 + assert antidivisor_count(38) == 7 + assert antidivisor_count(180) == 6 + assert antidivisor_count(2*3*5) == 3 + + +def test_smoothness_and_smoothness_p(): + assert smoothness(1) == (1, 1) + assert smoothness(2**4*3**2) == (3, 16) + + assert smoothness_p(10431, m=1) == \ + (1, [(3, (2, 2, 4)), (19, (1, 5, 5)), (61, (1, 31, 31))]) + assert smoothness_p(10431) == \ + (-1, [(3, (2, 2, 2)), (19, (1, 3, 9)), (61, (1, 5, 5))]) + assert smoothness_p(10431, power=1) == \ + (-1, [(3, (2, 2, 2)), (61, (1, 5, 5)), (19, (1, 3, 9))]) + assert smoothness_p(21477639576571, visual=1) == \ + 'p**i=4410317**1 has p-1 B=1787, B-pow=1787\n' + \ + 'p**i=4869863**1 has p-1 B=2434931, B-pow=2434931' + + +def test_visual_factorint(): + assert factorint(1, visual=1) == 1 + forty2 = factorint(42, visual=True) + assert type(forty2) == Mul + assert str(forty2) == '2**1*3**1*7**1' + assert factorint(1, visual=True) is S.One + no = {"evaluate": False} + assert factorint(42**2, visual=True) == Mul(Pow(2, 2, **no), + Pow(3, 2, **no), + Pow(7, 2, **no), **no) + assert -1 in factorint(-42, visual=True).args + + +def test_factorrat(): + assert str(factorrat(S(12)/1, visual=True)) == '2**2*3**1' + assert str(factorrat(Rational(1, 1), visual=True)) == '1' + assert str(factorrat(S(25)/14, visual=True)) == '5**2/(2*7)' + assert str(factorrat(Rational(25, 14), visual=True)) == '5**2/(2*7)' + assert str(factorrat(S(-25)/14/9, visual=True)) == '-1*5**2/(2*3**2*7)' + + assert factorrat(S(12)/1, multiple=True) == [2, 2, 3] + assert factorrat(Rational(1, 1), multiple=True) == [] + assert factorrat(S(25)/14, multiple=True) == [Rational(1, 7), S.Half, 5, 5] + assert factorrat(Rational(25, 14), multiple=True) == [Rational(1, 7), S.Half, 5, 5] + assert factorrat(Rational(12, 1), multiple=True) == [2, 2, 3] + assert factorrat(S(-25)/14/9, multiple=True) == \ + [-1, Rational(1, 7), Rational(1, 3), Rational(1, 3), S.Half, 5, 5] + + +def test_visual_io(): + sm = smoothness_p + fi = factorint + # with smoothness_p + n = 124 + d = fi(n) + m = fi(d, visual=True) + t = sm(n) + s = sm(t) + for th in [d, s, t, n, m]: + assert sm(th, visual=True) == s + assert sm(th, visual=1) == s + for th in [d, s, t, n, m]: + assert sm(th, visual=False) == t + assert [sm(th, visual=None) for th in [d, s, t, n, m]] == [s, d, s, t, t] + assert [sm(th, visual=2) for th in [d, s, t, n, m]] == [s, d, s, t, t] + + # with factorint + for th in [d, m, n]: + assert fi(th, visual=True) == m + assert fi(th, visual=1) == m + for th in [d, m, n]: + assert fi(th, visual=False) == d + assert [fi(th, visual=None) for th in [d, m, n]] == [m, d, d] + assert [fi(th, visual=0) for th in [d, m, n]] == [m, d, d] + + # test reevaluation + no = {"evaluate": False} + assert sm({4: 2}, visual=False) == sm(16) + assert sm(Mul(*[Pow(k, v, **no) for k, v in {4: 2, 2: 6}.items()], **no), + visual=False) == sm(2**10) + + assert fi({4: 2}, visual=False) == fi(16) + assert fi(Mul(*[Pow(k, v, **no) for k, v in {4: 2, 2: 6}.items()], **no), + visual=False) == fi(2**10) + + +def test_core(): + assert core(35**13, 10) == 42875 + assert core(210**2) == 1 + assert core(7776, 3) == 36 + assert core(10**27, 22) == 10**5 + assert core(537824) == 14 + assert core(1, 6) == 1 + + +def test__divisor_sigma(): + assert _divisor_sigma(23450) == 50592 + assert _divisor_sigma(23450, 0) == 24 + assert _divisor_sigma(23450, 1) == 50592 + assert _divisor_sigma(23450, 2) == 730747500 + assert _divisor_sigma(23450, 3) == 14666785333344 + A000005 = [1, 2, 2, 3, 2, 4, 2, 4, 3, 4, 2, 6, 2, 4, 4, 5, 2, 6, 2, 6, 4, + 4, 2, 8, 3, 4, 4, 6, 2, 8, 2, 6, 4, 4, 4, 9, 2, 4, 4, 8, 2, 8] + for n, val in enumerate(A000005, 1): + assert _divisor_sigma(n, 0) == val + A000203 = [1, 3, 4, 7, 6, 12, 8, 15, 13, 18, 12, 28, 14, 24, 24, 31, 18, + 39, 20, 42, 32, 36, 24, 60, 31, 42, 40, 56, 30, 72, 32, 63, 48] + for n, val in enumerate(A000203, 1): + assert _divisor_sigma(n, 1) == val + A001157 = [1, 5, 10, 21, 26, 50, 50, 85, 91, 130, 122, 210, 170, 250, 260, + 341, 290, 455, 362, 546, 500, 610, 530, 850, 651, 850, 820, 1050] + for n, val in enumerate(A001157, 1): + assert _divisor_sigma(n, 2) == val + + +def test_mersenne_prime_exponent(): + assert mersenne_prime_exponent(1) == 2 + assert mersenne_prime_exponent(4) == 7 + assert mersenne_prime_exponent(10) == 89 + assert mersenne_prime_exponent(25) == 21701 + raises(ValueError, lambda: mersenne_prime_exponent(52)) + raises(ValueError, lambda: mersenne_prime_exponent(0)) + + +def test_is_perfect(): + assert is_perfect(-6) is False + assert is_perfect(6) is True + assert is_perfect(15) is False + assert is_perfect(28) is True + assert is_perfect(400) is False + assert is_perfect(496) is True + assert is_perfect(8128) is True + assert is_perfect(10000) is False + + +def test_is_abundant(): + assert is_abundant(10) is False + assert is_abundant(12) is True + assert is_abundant(18) is True + assert is_abundant(21) is False + assert is_abundant(945) is True + + +def test_is_deficient(): + assert is_deficient(10) is True + assert is_deficient(22) is True + assert is_deficient(56) is False + assert is_deficient(20) is False + assert is_deficient(36) is False + + +def test_is_amicable(): + assert is_amicable(173, 129) is False + assert is_amicable(220, 284) is True + assert is_amicable(8756, 8756) is False + + +def test_is_carmichael(): + A002997 = [561, 1105, 1729, 2465, 2821, 6601, 8911, 10585, 15841, + 29341, 41041, 46657, 52633, 62745, 63973, 75361, 101101] + for n in range(1, 5000): + assert is_carmichael(n) == (n in A002997) + for n in A002997: + assert is_carmichael(n) + + +def test_find_carmichael_numbers_in_range(): + assert find_carmichael_numbers_in_range(0, 561) == [] + assert find_carmichael_numbers_in_range(561, 562) == [561] + assert find_carmichael_numbers_in_range(561, 1105) == find_carmichael_numbers_in_range(561, 562) + raises(ValueError, lambda: find_carmichael_numbers_in_range(-2, 2)) + raises(ValueError, lambda: find_carmichael_numbers_in_range(22, 2)) + + +def test_find_first_n_carmichaels(): + assert find_first_n_carmichaels(0) == [] + assert find_first_n_carmichaels(1) == [561] + assert find_first_n_carmichaels(2) == [561, 1105] + + +def test_dra(): + assert dra(19, 12) == 8 + assert dra(2718, 10) == 9 + assert dra(0, 22) == 0 + assert dra(23456789, 10) == 8 + raises(ValueError, lambda: dra(24, -2)) + raises(ValueError, lambda: dra(24.2, 5)) + +def test_drm(): + assert drm(19, 12) == 7 + assert drm(2718, 10) == 2 + assert drm(0, 15) == 0 + assert drm(234161, 10) == 6 + raises(ValueError, lambda: drm(24, -2)) + raises(ValueError, lambda: drm(11.6, 9)) + + +def test_deprecated_ntheory_symbolic_functions(): + from sympy.testing.pytest import warns_deprecated_sympy + + with warns_deprecated_sympy(): + assert primenu(3) == 1 + with warns_deprecated_sympy(): + assert primeomega(3) == 1 + with warns_deprecated_sympy(): + assert totient(3) == 2 + with warns_deprecated_sympy(): + assert reduced_totient(3) == 2 + with warns_deprecated_sympy(): + assert divisor_sigma(3) == 4 + with warns_deprecated_sympy(): + assert udivisor_sigma(3) == 4 diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_generate.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_generate.py new file mode 100644 index 0000000000000000000000000000000000000000..b0e5918ffefede2e86f3be2b07d6c3a01c02e6e0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_generate.py @@ -0,0 +1,285 @@ +from bisect import bisect, bisect_left + +from sympy.functions.combinatorial.numbers import mobius, totient +from sympy.ntheory.generate import (sieve, Sieve) + +from sympy.ntheory import isprime, randprime, nextprime, prevprime, \ + primerange, primepi, prime, primorial, composite, compositepi +from sympy.ntheory.generate import cycle_length, _primepi +from sympy.ntheory.primetest import mr +from sympy.testing.pytest import raises + +def test_prime(): + assert prime(1) == 2 + assert prime(2) == 3 + assert prime(5) == 11 + assert prime(11) == 31 + assert prime(57) == 269 + assert prime(296) == 1949 + assert prime(559) == 4051 + assert prime(3000) == 27449 + assert prime(4096) == 38873 + assert prime(9096) == 94321 + assert prime(25023) == 287341 + assert prime(10000000) == 179424673 # issue #20951 + assert prime(99999999) == 2038074739 + raises(ValueError, lambda: prime(0)) + sieve.extend(3000) + assert prime(401) == 2749 + raises(ValueError, lambda: prime(-1)) + + +def test__primepi(): + assert _primepi(-1) == 0 + assert _primepi(1) == 0 + assert _primepi(2) == 1 + assert _primepi(5) == 3 + assert _primepi(11) == 5 + assert _primepi(57) == 16 + assert _primepi(296) == 62 + assert _primepi(559) == 102 + assert _primepi(3000) == 430 + assert _primepi(4096) == 564 + assert _primepi(9096) == 1128 + assert _primepi(25023) == 2763 + assert _primepi(10**8) == 5761455 + assert _primepi(253425253) == 13856396 + assert _primepi(8769575643) == 401464322 + sieve.extend(3000) + assert _primepi(2000) == 303 + + +def test_composite(): + from sympy.ntheory.generate import sieve + sieve._reset() + assert composite(1) == 4 + assert composite(2) == 6 + assert composite(5) == 10 + assert composite(11) == 20 + assert composite(41) == 58 + assert composite(57) == 80 + assert composite(296) == 370 + assert composite(559) == 684 + assert composite(3000) == 3488 + assert composite(4096) == 4736 + assert composite(9096) == 10368 + assert composite(25023) == 28088 + sieve.extend(3000) + assert composite(1957) == 2300 + assert composite(2568) == 2998 + raises(ValueError, lambda: composite(0)) + + +def test_compositepi(): + assert compositepi(1) == 0 + assert compositepi(2) == 0 + assert compositepi(5) == 1 + assert compositepi(11) == 5 + assert compositepi(57) == 40 + assert compositepi(296) == 233 + assert compositepi(559) == 456 + assert compositepi(3000) == 2569 + assert compositepi(4096) == 3531 + assert compositepi(9096) == 7967 + assert compositepi(25023) == 22259 + assert compositepi(10**8) == 94238544 + assert compositepi(253425253) == 239568856 + assert compositepi(8769575643) == 8368111320 + sieve.extend(3000) + assert compositepi(2321) == 1976 + + +def test_generate(): + from sympy.ntheory.generate import sieve + sieve._reset() + assert nextprime(-4) == 2 + assert nextprime(2) == 3 + assert nextprime(5) == 7 + assert nextprime(12) == 13 + assert prevprime(3) == 2 + assert prevprime(7) == 5 + assert prevprime(13) == 11 + assert prevprime(19) == 17 + assert prevprime(20) == 19 + + sieve.extend_to_no(9) + assert sieve._list[-1] == 23 + + assert sieve._list[-1] < 31 + assert 31 in sieve + + assert nextprime(90) == 97 + assert nextprime(10**40) == (10**40 + 121) + primelist = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, + 37, 41, 43, 47, 53, 59, 61, 67, 71, 73, + 79, 83, 89, 97, 101, 103, 107, 109, 113, + 127, 131, 137, 139, 149, 151, 157, 163, + 167, 173, 179, 181, 191, 193, 197, 199, + 211, 223, 227, 229, 233, 239, 241, 251, + 257, 263, 269, 271, 277, 281, 283, 293] + for i in range(len(primelist) - 2): + for j in range(2, len(primelist) - i): + assert nextprime(primelist[i], j) == primelist[i + j] + if 3 < i: + assert nextprime(primelist[i] - 1, j) == primelist[i + j - 1] + raises(ValueError, lambda: nextprime(2, 0)) + raises(ValueError, lambda: nextprime(2, -1)) + assert prevprime(97) == 89 + assert prevprime(10**40) == (10**40 - 17) + + raises(ValueError, lambda: Sieve(0)) + raises(ValueError, lambda: Sieve(-1)) + for sieve_interval in [1, 10, 11, 1_000_000]: + s = Sieve(sieve_interval=sieve_interval) + for head in range(s._list[-1] + 1, (s._list[-1] + 1)**2, 2): + for tail in range(head + 1, (s._list[-1] + 1)**2): + A = list(s._primerange(head, tail)) + B = primelist[bisect(primelist, head):bisect_left(primelist, tail)] + assert A == B + for k in range(s._list[-1], primelist[-1] - 1, 2): + s = Sieve(sieve_interval=sieve_interval) + s.extend(k) + assert list(s._list) == primelist[:bisect(primelist, k)] + s.extend(primelist[-1]) + assert list(s._list) == primelist + + assert list(sieve.primerange(10, 1)) == [] + assert list(sieve.primerange(5, 9)) == [5, 7] + sieve._reset(prime=True) + assert list(sieve.primerange(2, 13)) == [2, 3, 5, 7, 11] + assert list(sieve.primerange(13)) == [2, 3, 5, 7, 11] + assert list(sieve.primerange(8)) == [2, 3, 5, 7] + assert list(sieve.primerange(-2)) == [] + assert list(sieve.primerange(29)) == [2, 3, 5, 7, 11, 13, 17, 19, 23] + assert list(sieve.primerange(34)) == [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31] + + assert list(sieve.totientrange(5, 15)) == [4, 2, 6, 4, 6, 4, 10, 4, 12, 6] + sieve._reset(totient=True) + assert list(sieve.totientrange(3, 13)) == [2, 2, 4, 2, 6, 4, 6, 4, 10, 4] + assert list(sieve.totientrange(900, 1000)) == [totient(x) for x in range(900, 1000)] + assert list(sieve.totientrange(0, 1)) == [] + assert list(sieve.totientrange(1, 2)) == [1] + + assert list(sieve.mobiusrange(5, 15)) == [-1, 1, -1, 0, 0, 1, -1, 0, -1, 1] + sieve._reset(mobius=True) + assert list(sieve.mobiusrange(3, 13)) == [-1, 0, -1, 1, -1, 0, 0, 1, -1, 0] + assert list(sieve.mobiusrange(1050, 1100)) == [mobius(x) for x in range(1050, 1100)] + assert list(sieve.mobiusrange(0, 1)) == [] + assert list(sieve.mobiusrange(1, 2)) == [1] + + assert list(primerange(10, 1)) == [] + assert list(primerange(2, 7)) == [2, 3, 5] + assert list(primerange(2, 10)) == [2, 3, 5, 7] + assert list(primerange(1050, 1100)) == [1051, 1061, + 1063, 1069, 1087, 1091, 1093, 1097] + s = Sieve() + for i in range(30, 2350, 376): + for j in range(2, 5096, 1139): + A = list(s.primerange(i, i + j)) + B = list(primerange(i, i + j)) + assert A == B + s = Sieve() + sieve._reset(prime=True) + sieve.extend(13) + for i in range(200): + for j in range(i, 200): + A = list(s.primerange(i, j)) + B = list(primerange(i, j)) + assert A == B + sieve.extend(1000) + for a, b in [(901, 1103), # a < 1000 < b < 1000**2 + (806, 1002007), # a < 1000 < 1000**2 < b + (2000, 30001), # 1000 < a < b < 1000**2 + (100005, 1010001), # 1000 < a < 1000**2 < b + (1003003, 1005000), # 1000**2 < a < b + ]: + assert list(primerange(a, b)) == list(s.primerange(a, b)) + sieve._reset(prime=True) + sieve.extend(100000) + assert len(sieve._list) == len(set(sieve._list)) + s = Sieve() + assert s[10] == 29 + + assert nextprime(2, 2) == 5 + + raises(ValueError, lambda: totient(0)) + + raises(ValueError, lambda: primorial(0)) + + assert mr(1, [2]) is False + + func = lambda i: (i**2 + 1) % 51 + assert next(cycle_length(func, 4)) == (6, 3) + assert list(cycle_length(func, 4, values=True)) == \ + [4, 17, 35, 2, 5, 26, 14, 44, 50, 2, 5, 26, 14] + assert next(cycle_length(func, 4, nmax=5)) == (5, None) + assert list(cycle_length(func, 4, nmax=5, values=True)) == \ + [4, 17, 35, 2, 5] + sieve.extend(3000) + assert nextprime(2968) == 2969 + assert prevprime(2930) == 2927 + raises(ValueError, lambda: prevprime(1)) + raises(ValueError, lambda: prevprime(-4)) + + +def test_randprime(): + assert randprime(10, 1) is None + assert randprime(3, -3) is None + assert randprime(2, 3) == 2 + assert randprime(1, 3) == 2 + assert randprime(3, 5) == 3 + raises(ValueError, lambda: randprime(-12, -2)) + raises(ValueError, lambda: randprime(-10, 0)) + raises(ValueError, lambda: randprime(20, 22)) + raises(ValueError, lambda: randprime(0, 2)) + raises(ValueError, lambda: randprime(1, 2)) + for a in [100, 300, 500, 250000]: + for b in [100, 300, 500, 250000]: + p = randprime(a, a + b) + assert a <= p < (a + b) and isprime(p) + + +def test_primorial(): + assert primorial(1) == 2 + assert primorial(1, nth=0) == 1 + assert primorial(2) == 6 + assert primorial(2, nth=0) == 2 + assert primorial(4, nth=0) == 6 + + +def test_search(): + assert 2 in sieve + assert 2.1 not in sieve + assert 1 not in sieve + assert 2**1000 not in sieve + raises(ValueError, lambda: sieve.search(1)) + + +def test_sieve_slice(): + assert sieve[5] == 11 + assert list(sieve[5:10]) == [sieve[x] for x in range(5, 10)] + assert list(sieve[5:10:2]) == [sieve[x] for x in range(5, 10, 2)] + assert list(sieve[1:5]) == [2, 3, 5, 7] + raises(IndexError, lambda: sieve[:5]) + raises(IndexError, lambda: sieve[0]) + raises(IndexError, lambda: sieve[0:5]) + +def test_sieve_iter(): + values = [] + for value in sieve: + if value > 7: + break + values.append(value) + assert values == list(sieve[1:5]) + + +def test_sieve_repr(): + assert "sieve" in repr(sieve) + assert "prime" in repr(sieve) + + +def test_deprecated_ntheory_symbolic_functions(): + from sympy.testing.pytest import warns_deprecated_sympy + + with warns_deprecated_sympy(): + assert primepi(0) == 0 diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_hypothesis.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_hypothesis.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f4cbecdbb7a6b15b0e323700cda11039c968fb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_hypothesis.py @@ -0,0 +1,24 @@ +from hypothesis import given +from hypothesis import strategies as st +from sympy import divisors +from sympy.functions.combinatorial.numbers import divisor_sigma, totient +from sympy.ntheory.primetest import is_square + + +@given(n=st.integers(1, 10**10)) +def test_tau_hypothesis(n): + div = divisors(n) + tau_n = len(div) + assert is_square(n) == (tau_n % 2 == 1) + sigmas = [divisor_sigma(i) for i in div] + totients = [totient(n // i) for i in div] + mul = [a * b for a, b in zip(sigmas, totients)] + assert n * tau_n == sum(mul) + + +@given(n=st.integers(1, 10**10)) +def test_totient_hypothesis(n): + assert totient(n) <= n + div = divisors(n) + totients = [totient(i) for i in div] + assert n == sum(totients) diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_modular.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_modular.py new file mode 100644 index 0000000000000000000000000000000000000000..10ebb1d3d3bdf5f736a6229579ae4c42a805745e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_modular.py @@ -0,0 +1,34 @@ +from sympy.ntheory.modular import crt, crt1, crt2, solve_congruence +from sympy.testing.pytest import raises + + +def test_crt(): + def mcrt(m, v, r, symmetric=False): + assert crt(m, v, symmetric)[0] == r + mm, e, s = crt1(m) + assert crt2(m, v, mm, e, s, symmetric) == (r, mm) + + mcrt([2, 3, 5], [0, 0, 0], 0) + mcrt([2, 3, 5], [1, 1, 1], 1) + + mcrt([2, 3, 5], [-1, -1, -1], -1, True) + mcrt([2, 3, 5], [-1, -1, -1], 2*3*5 - 1, False) + + assert crt([656, 350], [811, 133], symmetric=True) == (-56917, 114800) + + +def test_modular(): + assert solve_congruence(*list(zip([3, 4, 2], [12, 35, 17]))) == (1719, 7140) + assert solve_congruence(*list(zip([3, 4, 2], [12, 6, 17]))) is None + assert solve_congruence(*list(zip([3, 4, 2], [13, 7, 17]))) == (172, 1547) + assert solve_congruence(*list(zip([-10, -3, -15], [13, 7, 17]))) == (172, 1547) + assert solve_congruence(*list(zip([-10, -3, 1, -15], [13, 7, 7, 17]))) is None + assert solve_congruence( + *list(zip([-10, -5, 2, -15], [13, 7, 7, 17]))) == (835, 1547) + assert solve_congruence( + *list(zip([-10, -5, 2, -15], [13, 7, 14, 17]))) == (2382, 3094) + assert solve_congruence( + *list(zip([-10, 2, 2, -15], [13, 7, 14, 17]))) == (2382, 3094) + assert solve_congruence(*list(zip((1, 1, 2), (3, 2, 4)))) is None + raises( + ValueError, lambda: solve_congruence(*list(zip([3, 4, 2], [12.1, 35, 17])))) diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_multinomial.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_multinomial.py new file mode 100644 index 0000000000000000000000000000000000000000..b455c5cc979b9ba9756c9da88c1694471115cd5d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_multinomial.py @@ -0,0 +1,48 @@ +from sympy.ntheory.multinomial import (binomial_coefficients, binomial_coefficients_list, multinomial_coefficients) +from sympy.ntheory.multinomial import multinomial_coefficients_iterator + + +def test_binomial_coefficients_list(): + assert binomial_coefficients_list(0) == [1] + assert binomial_coefficients_list(1) == [1, 1] + assert binomial_coefficients_list(2) == [1, 2, 1] + assert binomial_coefficients_list(3) == [1, 3, 3, 1] + assert binomial_coefficients_list(4) == [1, 4, 6, 4, 1] + assert binomial_coefficients_list(5) == [1, 5, 10, 10, 5, 1] + assert binomial_coefficients_list(6) == [1, 6, 15, 20, 15, 6, 1] + + +def test_binomial_coefficients(): + for n in range(15): + c = binomial_coefficients(n) + l = [c[k] for k in sorted(c)] + assert l == binomial_coefficients_list(n) + + +def test_multinomial_coefficients(): + assert multinomial_coefficients(1, 1) == {(1,): 1} + assert multinomial_coefficients(1, 2) == {(2,): 1} + assert multinomial_coefficients(1, 3) == {(3,): 1} + assert multinomial_coefficients(2, 0) == {(0, 0): 1} + assert multinomial_coefficients(2, 1) == {(0, 1): 1, (1, 0): 1} + assert multinomial_coefficients(2, 2) == {(2, 0): 1, (0, 2): 1, (1, 1): 2} + assert multinomial_coefficients(2, 3) == {(3, 0): 1, (1, 2): 3, (0, 3): 1, + (2, 1): 3} + assert multinomial_coefficients(3, 1) == {(1, 0, 0): 1, (0, 1, 0): 1, + (0, 0, 1): 1} + assert multinomial_coefficients(3, 2) == {(0, 1, 1): 2, (0, 0, 2): 1, + (1, 1, 0): 2, (0, 2, 0): 1, (1, 0, 1): 2, (2, 0, 0): 1} + mc = multinomial_coefficients(3, 3) + assert mc == {(2, 1, 0): 3, (0, 3, 0): 1, + (1, 0, 2): 3, (0, 2, 1): 3, (0, 1, 2): 3, (3, 0, 0): 1, + (2, 0, 1): 3, (1, 2, 0): 3, (1, 1, 1): 6, (0, 0, 3): 1} + assert dict(multinomial_coefficients_iterator(2, 0)) == {(0, 0): 1} + assert dict( + multinomial_coefficients_iterator(2, 1)) == {(0, 1): 1, (1, 0): 1} + assert dict(multinomial_coefficients_iterator(2, 2)) == \ + {(2, 0): 1, (0, 2): 1, (1, 1): 2} + assert dict(multinomial_coefficients_iterator(3, 3)) == mc + it = multinomial_coefficients_iterator(7, 2) + assert [next(it) for i in range(4)] == \ + [((2, 0, 0, 0, 0, 0, 0), 1), ((1, 1, 0, 0, 0, 0, 0), 2), + ((0, 2, 0, 0, 0, 0, 0), 1), ((1, 0, 1, 0, 0, 0, 0), 2)] diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_partitions.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_partitions.py new file mode 100644 index 0000000000000000000000000000000000000000..8eb7fad3445068ae7ae4033c76c808e3c87347b6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_partitions.py @@ -0,0 +1,28 @@ +from sympy.ntheory.partitions_ import npartitions, _partition_rec, _partition + + +def test__partition_rec(): + A000041 = [1, 1, 2, 3, 5, 7, 11, 15, 22, 30, 42, 56, 77, 101, 135, + 176, 231, 297, 385, 490, 627, 792, 1002, 1255, 1575] + for n, val in enumerate(A000041): + assert _partition_rec(n) == val + + +def test__partition(): + assert [_partition(k) for k in range(13)] == \ + [1, 1, 2, 3, 5, 7, 11, 15, 22, 30, 42, 56, 77] + assert _partition(100) == 190569292 + assert _partition(200) == 3972999029388 + assert _partition(1000) == 24061467864032622473692149727991 + assert _partition(1001) == 25032297938763929621013218349796 + assert _partition(2000) == 4720819175619413888601432406799959512200344166 + assert _partition(10000) % 10**10 == 6916435144 + assert _partition(100000) % 10**10 == 9421098519 + assert _partition(10000000) % 10**10 == 7677288980 + + +def test_deprecated_ntheory_symbolic_functions(): + from sympy.testing.pytest import warns_deprecated_sympy + + with warns_deprecated_sympy(): + assert npartitions(0) == 1 diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_primetest.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_primetest.py new file mode 100644 index 0000000000000000000000000000000000000000..8a56332941d9421bda4d6acc1e4b3406617cee2b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_primetest.py @@ -0,0 +1,235 @@ +from math import gcd + +from sympy.ntheory.generate import Sieve, sieve +from sympy.ntheory.primetest import (mr, _lucas_extrastrong_params, is_lucas_prp, is_square, + is_strong_lucas_prp, is_extra_strong_lucas_prp, + proth_test, isprime, is_euler_pseudoprime, + is_gaussian_prime, is_fermat_pseudoprime, is_euler_jacobi_pseudoprime, + MERSENNE_PRIME_EXPONENTS, _lucas_lehmer_primality_test, + is_mersenne_prime) + +from sympy.testing.pytest import slow, raises +from sympy.core.numbers import I, Float + + +def test_is_fermat_pseudoprime(): + assert is_fermat_pseudoprime(5, 1) + assert is_fermat_pseudoprime(9, 1) + + +def test_euler_pseudoprimes(): + assert is_euler_pseudoprime(13, 1) + assert is_euler_pseudoprime(15, 1) + assert is_euler_pseudoprime(17, 6) + assert is_euler_pseudoprime(101, 7) + assert is_euler_pseudoprime(1009, 10) + assert is_euler_pseudoprime(11287, 41) + + raises(ValueError, lambda: is_euler_pseudoprime(0, 4)) + raises(ValueError, lambda: is_euler_pseudoprime(3, 0)) + raises(ValueError, lambda: is_euler_pseudoprime(15, 6)) + + # A006970 + euler_prp = [341, 561, 1105, 1729, 1905, 2047, 2465, 3277, + 4033, 4681, 5461, 6601, 8321, 8481, 10261, 10585] + for p in euler_prp: + assert is_euler_pseudoprime(p, 2) + + # A048950 + euler_prp = [121, 703, 1729, 1891, 2821, 3281, 7381, 8401, 8911, 10585, + 12403, 15457, 15841, 16531, 18721, 19345, 23521, 24661, 28009] + for p in euler_prp: + assert is_euler_pseudoprime(p, 3) + + # A033181 + absolute_euler_prp = [1729, 2465, 15841, 41041, 46657, 75361, + 162401, 172081, 399001, 449065, 488881] + for p in absolute_euler_prp: + for a in range(2, p): + if gcd(a, p) != 1: + continue + assert is_euler_pseudoprime(p, a) + + +def test_is_euler_jacobi_pseudoprime(): + assert is_euler_jacobi_pseudoprime(11, 1) + assert is_euler_jacobi_pseudoprime(15, 1) + + +def test_lucas_extrastrong_params(): + assert _lucas_extrastrong_params(3) == (5, 3, 1) + assert _lucas_extrastrong_params(5) == (12, 4, 1) + assert _lucas_extrastrong_params(7) == (5, 3, 1) + assert _lucas_extrastrong_params(9) == (0, 0, 0) + assert _lucas_extrastrong_params(11) == (21, 5, 1) + assert _lucas_extrastrong_params(59) == (32, 6, 1) + assert _lucas_extrastrong_params(479) == (117, 11, 1) + + +def test_is_extra_strong_lucas_prp(): + assert is_extra_strong_lucas_prp(4) == False + assert is_extra_strong_lucas_prp(989) == True + assert is_extra_strong_lucas_prp(10877) == True + assert is_extra_strong_lucas_prp(9) == False + assert is_extra_strong_lucas_prp(16) == False + assert is_extra_strong_lucas_prp(169) == False + +@slow +def test_prps(): + oddcomposites = [n for n in range(1, 10**5) if + n % 2 and not isprime(n)] + # A checksum would be better. + assert sum(oddcomposites) == 2045603465 + assert [n for n in oddcomposites if mr(n, [2])] == [ + 2047, 3277, 4033, 4681, 8321, 15841, 29341, 42799, 49141, + 52633, 65281, 74665, 80581, 85489, 88357, 90751] + assert [n for n in oddcomposites if mr(n, [3])] == [ + 121, 703, 1891, 3281, 8401, 8911, 10585, 12403, 16531, + 18721, 19345, 23521, 31621, 44287, 47197, 55969, 63139, + 74593, 79003, 82513, 87913, 88573, 97567] + assert [n for n in oddcomposites if mr(n, [325])] == [ + 9, 25, 27, 49, 65, 81, 325, 341, 343, 697, 1141, 2059, + 2149, 3097, 3537, 4033, 4681, 4941, 5833, 6517, 7987, 8911, + 12403, 12913, 15043, 16021, 20017, 22261, 23221, 24649, + 24929, 31841, 35371, 38503, 43213, 44173, 47197, 50041, + 55909, 56033, 58969, 59089, 61337, 65441, 68823, 72641, + 76793, 78409, 85879] + assert not any(mr(n, [9345883071009581737]) for n in oddcomposites) + assert [n for n in oddcomposites if is_lucas_prp(n)] == [ + 323, 377, 1159, 1829, 3827, 5459, 5777, 9071, 9179, 10877, + 11419, 11663, 13919, 14839, 16109, 16211, 18407, 18971, + 19043, 22499, 23407, 24569, 25199, 25877, 26069, 27323, + 32759, 34943, 35207, 39059, 39203, 39689, 40309, 44099, + 46979, 47879, 50183, 51983, 53663, 56279, 58519, 60377, + 63881, 69509, 72389, 73919, 75077, 77219, 79547, 79799, + 82983, 84419, 86063, 90287, 94667, 97019, 97439] + assert [n for n in oddcomposites if is_strong_lucas_prp(n)] == [ + 5459, 5777, 10877, 16109, 18971, 22499, 24569, 25199, 40309, + 58519, 75077, 97439] + assert [n for n in oddcomposites if is_extra_strong_lucas_prp(n) + ] == [ + 989, 3239, 5777, 10877, 27971, 29681, 30739, 31631, 39059, + 72389, 73919, 75077] + + +def test_proth_test(): + # Proth number + A080075 = [3, 5, 9, 13, 17, 25, 33, 41, 49, 57, 65, + 81, 97, 113, 129, 145, 161, 177, 193] + # Proth prime + A080076 = [3, 5, 13, 17, 41, 97, 113, 193] + + for n in range(200): + if n in A080075: + assert proth_test(n) == (n in A080076) + else: + raises(ValueError, lambda: proth_test(n)) + + +def test_lucas_lehmer_primality_test(): + for p in sieve.primerange(3, 100): + assert _lucas_lehmer_primality_test(p) == (p in MERSENNE_PRIME_EXPONENTS) + + +def test_is_mersenne_prime(): + assert is_mersenne_prime(-3) is False + assert is_mersenne_prime(3) is True + assert is_mersenne_prime(10) is False + assert is_mersenne_prime(127) is True + assert is_mersenne_prime(511) is False + assert is_mersenne_prime(131071) is True + assert is_mersenne_prime(2147483647) is True + + +def test_isprime(): + s = Sieve() + s.extend(100000) + ps = set(s.primerange(2, 100001)) + for n in range(100001): + # if (n in ps) != isprime(n): print n + assert (n in ps) == isprime(n) + assert isprime(179424673) + assert isprime(20678048681) + assert isprime(1968188556461) + assert isprime(2614941710599) + assert isprime(65635624165761929287) + assert isprime(1162566711635022452267983) + assert isprime(77123077103005189615466924501) + assert isprime(3991617775553178702574451996736229) + assert isprime(273952953553395851092382714516720001799) + assert isprime(int(''' +531137992816767098689588206552468627329593117727031923199444138200403\ +559860852242739162502265229285668889329486246501015346579337652707239\ +409519978766587351943831270835393219031728127''')) + + # Some Mersenne primes + assert isprime(2**61 - 1) + assert isprime(2**89 - 1) + assert isprime(2**607 - 1) + # (but not all Mersenne's are primes + assert not isprime(2**601 - 1) + + # pseudoprimes + #------------- + # to some small bases + assert not isprime(2152302898747) + assert not isprime(3474749660383) + assert not isprime(341550071728321) + assert not isprime(3825123056546413051) + # passes the base set [2, 3, 7, 61, 24251] + assert not isprime(9188353522314541) + # large examples + assert not isprime(877777777777777777777777) + # conjectured psi_12 given at http://mathworld.wolfram.com/StrongPseudoprime.html + assert not isprime(318665857834031151167461) + # conjectured psi_17 given at http://mathworld.wolfram.com/StrongPseudoprime.html + assert not isprime(564132928021909221014087501701) + # Arnault's 1993 number; a factor of it is + # 400958216639499605418306452084546853005188166041132508774506\ + # 204738003217070119624271622319159721973358216316508535816696\ + # 9145233813917169287527980445796800452592031836601 + assert not isprime(int(''' +803837457453639491257079614341942108138837688287558145837488917522297\ +427376533365218650233616396004545791504202360320876656996676098728404\ +396540823292873879185086916685732826776177102938969773947016708230428\ +687109997439976544144845341155872450633409279022275296229414984230688\ +1685404326457534018329786111298960644845216191652872597534901''')) + # Arnault's 1995 number; can be factored as + # p1*(313*(p1 - 1) + 1)*(353*(p1 - 1) + 1) where p1 is + # 296744956686855105501541746429053327307719917998530433509950\ + # 755312768387531717701995942385964281211880336647542183455624\ + # 93168782883 + assert not isprime(int(''' +288714823805077121267142959713039399197760945927972270092651602419743\ +230379915273311632898314463922594197780311092934965557841894944174093\ +380561511397999942154241693397290542371100275104208013496673175515285\ +922696291677532547504444585610194940420003990443211677661994962953925\ +045269871932907037356403227370127845389912612030924484149472897688540\ +6024976768122077071687938121709811322297802059565867''')) + sieve.extend(3000) + assert isprime(2819) + assert not isprime(2931) + raises(ValueError, lambda: isprime(2.0)) + raises(ValueError, lambda: isprime(Float(2))) + + +def test_is_square(): + assert [i for i in range(25) if is_square(i)] == [0, 1, 4, 9, 16] + + # issue #17044 + assert not is_square(60 ** 3) + assert not is_square(60 ** 5) + assert not is_square(84 ** 7) + assert not is_square(105 ** 9) + assert not is_square(120 ** 3) + +def test_is_gaussianprime(): + assert is_gaussian_prime(7*I) + assert is_gaussian_prime(7) + assert is_gaussian_prime(2 + 3*I) + assert not is_gaussian_prime(2 + 2*I) + + +def test_issue_27145(): + #https://github.com/sympy/sympy/issues/27145 + assert [mr(i,[2,3,5,7]) for i in (1, 2, 6)] == [False, True, False] diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_qs.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_qs.py new file mode 100644 index 0000000000000000000000000000000000000000..16932dd61badf4a467e67fa52e0f473594fd057b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_qs.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import math +from sympy.core.random import _randint +from sympy.ntheory import qs, qs_factor +from sympy.ntheory.qs import SievePolynomial, _generate_factor_base, \ + _generate_polynomial, \ + _gen_sieve_array, _check_smoothness, _trial_division_stage, _find_factor +from sympy.testing.pytest import slow + + +@slow +def test_qs_1(): + assert qs(10009202107, 100, 10000) == {100043, 100049} + assert qs(211107295182713951054568361, 1000, 10000) == \ + {13791315212531, 15307263442931} + assert qs(980835832582657*990377764891511, 2000, 10000) == \ + {980835832582657, 990377764891511} + assert qs(18640889198609*20991129234731, 1000, 50000) == \ + {18640889198609, 20991129234731} + + +def test_qs_2() -> None: + n = 10009202107 + M = 50 + sieve_poly = SievePolynomial(10, 80, n) + assert sieve_poly.eval_v(10) == sieve_poly.eval_u(10)**2 - n == -10009169707 + assert sieve_poly.eval_v(5) == sieve_poly.eval_u(5)**2 - n == -10009185207 + + idx_1000, idx_5000, factor_base = _generate_factor_base(2000, n) + assert idx_1000 == 82 + assert [factor_base[i].prime for i in range(15)] == \ + [2, 3, 7, 11, 17, 19, 29, 31, 43, 59, 61, 67, 71, 73, 79] + assert [factor_base[i].tmem_p for i in range(15)] == \ + [1, 1, 3, 5, 3, 6, 6, 14, 1, 16, 24, 22, 18, 22, 15] + assert [factor_base[i].log_p for i in range(5)] == \ + [710, 1125, 1993, 2455, 2901] + + it = _generate_polynomial( + n, M, factor_base, idx_1000, idx_5000, _randint(0)) + g = next(it) + assert g.a == 1133107 + assert g.b == 682543 + assert [factor_base[i].soln1 for i in range(15)] == \ + [0, 0, 3, 7, 13, 0, 8, 19, 9, 43, 27, 25, 63, 29, 19] + assert [factor_base[i].soln2 for i in range(15)] == \ + [0, 1, 1, 3, 12, 16, 15, 6, 15, 1, 56, 55, 61, 58, 16] + assert [factor_base[i].b_ainv for i in range(5)] == \ + [[0, 0], [0, 2], [3, 0], [3, 9], [13, 13]] + + g_1 = next(it) + assert g_1.a == 1133107 + assert g_1.b == 136765 + + sieve_array = _gen_sieve_array(M, factor_base) + assert sieve_array[0:5] == [8424, 13603, 1835, 5335, 710] + + assert _check_smoothness(9645, factor_base) == (36028797018963972, 5) + assert _check_smoothness(210313, factor_base) == (20992, 1) + + partial_relations: dict[int, tuple[int, int]] = {} + smooth_relation, proper_factor = _trial_division_stage( + n, M, factor_base, sieve_array, sieve_poly, partial_relations, + ERROR_TERM=25*2**10) + + assert partial_relations == { + 8699: (440, -10009008507, 75557863761098695507973), + 166741: (490, -10008962007, 524341), + 131449: (530, -10008921207, 664613997892457936451903530140172325), + 6653: (550, -10008899607, 19342813113834066795307021) + } + assert [smooth_relation[i][0] for i in range(5)] == [ + -250, 1064469, 72819, 231957, 44167] + assert [smooth_relation[i][1] for i in range(5)] == [ + -10009139607, 1133094251961, 5302606761, 53804049849, 1950723889] + assert smooth_relation[0][2] == 89213869829863962596973701078031812362502145 + assert proper_factor == set() + + +def test_qs_3(): + N = 1817 + smooth_relations = [ + (2455024, 637, 8), + (-27993000, 81536, 10), + (11461840, 12544, 0), + (149, 20384, 10), + (-31138074, 19208, 2) + ] + assert next(_find_factor(N, smooth_relations, 4)) == 23 + + +def test_qs_4(): + N = 10007**2 * 10009 * 10037**3 * 10039 + for factor in qs(N, 1000, 2000): + assert N % factor == 0 + N //= factor + + +def test_qs_factor(): + assert qs_factor(1009 * 100003, 2000, 10000) == {1009: 1, 100003: 1} + n = 1009**2 * 2003**2*30011*400009 + factors = qs_factor(n, 2000, 10000) + assert len(factors) > 1 + assert math.prod(p**e for p, e in factors.items()) == n + + +def test_issue_27616(): + #https://github.com/sympy/sympy/issues/27616 + N = 9804659461513846513 + 1 + assert qs(N, 5000, 20000) is not None diff --git a/phivenv/Lib/site-packages/sympy/ntheory/tests/test_residue.py b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_residue.py new file mode 100644 index 0000000000000000000000000000000000000000..4d530905f39b88d8d7cc0e861ac6eadb2fa6f98a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/ntheory/tests/test_residue.py @@ -0,0 +1,349 @@ +from collections import defaultdict +from sympy.core.containers import Tuple +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.combinatorial.numbers import totient +from sympy.ntheory import n_order, is_primitive_root, is_quad_residue, \ + legendre_symbol, jacobi_symbol, primerange, sqrt_mod, \ + primitive_root, quadratic_residues, is_nthpow_residue, nthroot_mod, \ + sqrt_mod_iter, mobius, discrete_log, quadratic_congruence, \ + polynomial_congruence, sieve +from sympy.ntheory.residue_ntheory import _primitive_root_prime_iter, \ + _primitive_root_prime_power_iter, _primitive_root_prime_power2_iter, \ + _nthroot_mod_prime_power, _discrete_log_trial_mul, _discrete_log_shanks_steps, \ + _discrete_log_pollard_rho, _discrete_log_index_calculus, _discrete_log_pohlig_hellman, \ + _binomial_mod_prime_power, binomial_mod +from sympy.polys.domains import ZZ +from sympy.testing.pytest import raises +from sympy.core.random import randint, choice + + +def test_residue(): + assert n_order(2, 13) == 12 + assert [n_order(a, 7) for a in range(1, 7)] == \ + [1, 3, 6, 3, 6, 2] + assert n_order(5, 17) == 16 + assert n_order(17, 11) == n_order(6, 11) + assert n_order(101, 119) == 6 + assert n_order(11, (10**50 + 151)**2) == 10000000000000000000000000000000000000000000000030100000000000000000000000000000000000000000000022650 + raises(ValueError, lambda: n_order(6, 9)) + + assert is_primitive_root(2, 7) is False + assert is_primitive_root(3, 8) is False + assert is_primitive_root(11, 14) is False + assert is_primitive_root(12, 17) == is_primitive_root(29, 17) + raises(ValueError, lambda: is_primitive_root(3, 6)) + + for p in primerange(3, 100): + li = list(_primitive_root_prime_iter(p)) + assert li[0] == min(li) + for g in li: + assert n_order(g, p) == p - 1 + assert len(li) == totient(totient(p)) + for e in range(1, 4): + li_power = list(_primitive_root_prime_power_iter(p, e)) + li_power2 = list(_primitive_root_prime_power2_iter(p, e)) + assert len(li_power) == len(li_power2) == totient(totient(p**e)) + assert primitive_root(97) == 5 + assert n_order(primitive_root(97, False), 97) == totient(97) + assert primitive_root(97**2) == 5 + assert n_order(primitive_root(97**2, False), 97**2) == totient(97**2) + assert primitive_root(40487) == 5 + assert n_order(primitive_root(40487, False), 40487) == totient(40487) + # note that primitive_root(40487) + 40487 = 40492 is a primitive root + # of 40487**2, but it is not the smallest + assert primitive_root(40487**2) == 10 + assert n_order(primitive_root(40487**2, False), 40487**2) == totient(40487**2) + assert primitive_root(82) == 7 + assert n_order(primitive_root(82, False), 82) == totient(82) + p = 10**50 + 151 + assert primitive_root(p) == 11 + assert n_order(primitive_root(p, False), p) == totient(p) + assert primitive_root(2*p) == 11 + assert n_order(primitive_root(2*p, False), 2*p) == totient(2*p) + assert primitive_root(p**2) == 11 + assert n_order(primitive_root(p**2, False), p**2) == totient(p**2) + assert primitive_root(4 * 11) is None and primitive_root(4 * 11, False) is None + assert primitive_root(15) is None and primitive_root(15, False) is None + raises(ValueError, lambda: primitive_root(-3)) + + assert is_quad_residue(3, 7) is False + assert is_quad_residue(10, 13) is True + assert is_quad_residue(12364, 139) == is_quad_residue(12364 % 139, 139) + assert is_quad_residue(207, 251) is True + assert is_quad_residue(0, 1) is True + assert is_quad_residue(1, 1) is True + assert is_quad_residue(0, 2) == is_quad_residue(1, 2) is True + assert is_quad_residue(1, 4) is True + assert is_quad_residue(2, 27) is False + assert is_quad_residue(13122380800, 13604889600) is True + assert [j for j in range(14) if is_quad_residue(j, 14)] == \ + [0, 1, 2, 4, 7, 8, 9, 11] + raises(ValueError, lambda: is_quad_residue(1.1, 2)) + raises(ValueError, lambda: is_quad_residue(2, 0)) + + assert quadratic_residues(S.One) == [0] + assert quadratic_residues(1) == [0] + assert quadratic_residues(12) == [0, 1, 4, 9] + assert quadratic_residues(13) == [0, 1, 3, 4, 9, 10, 12] + assert [len(quadratic_residues(i)) for i in range(1, 20)] == \ + [1, 2, 2, 2, 3, 4, 4, 3, 4, 6, 6, 4, 7, 8, 6, 4, 9, 8, 10] + + assert list(sqrt_mod_iter(6, 2)) == [0] + assert sqrt_mod(3, 13) == 4 + assert sqrt_mod(3, -13) == 4 + assert sqrt_mod(6, 23) == 11 + assert sqrt_mod(345, 690) == 345 + assert sqrt_mod(67, 101) == None + assert sqrt_mod(1020, 104729) == None + + for p in range(3, 100): + d = defaultdict(list) + for i in range(p): + d[pow(i, 2, p)].append(i) + for i in range(1, p): + it = sqrt_mod_iter(i, p) + v = sqrt_mod(i, p, True) + if v: + v = sorted(v) + assert d[i] == v + else: + assert not d[i] + + assert sqrt_mod(9, 27, True) == [3, 6, 12, 15, 21, 24] + assert sqrt_mod(9, 81, True) == [3, 24, 30, 51, 57, 78] + assert sqrt_mod(9, 3**5, True) == [3, 78, 84, 159, 165, 240] + assert sqrt_mod(81, 3**4, True) == [0, 9, 18, 27, 36, 45, 54, 63, 72] + assert sqrt_mod(81, 3**5, True) == [9, 18, 36, 45, 63, 72, 90, 99, 117,\ + 126, 144, 153, 171, 180, 198, 207, 225, 234] + assert sqrt_mod(81, 3**6, True) == [9, 72, 90, 153, 171, 234, 252, 315,\ + 333, 396, 414, 477, 495, 558, 576, 639, 657, 720] + assert sqrt_mod(81, 3**7, True) == [9, 234, 252, 477, 495, 720, 738, 963,\ + 981, 1206, 1224, 1449, 1467, 1692, 1710, 1935, 1953, 2178] + + for a, p in [(26214400, 32768000000), (26214400, 16384000000), + (262144, 1048576), (87169610025, 163443018796875), + (22315420166400, 167365651248000000)]: + assert pow(sqrt_mod(a, p), 2, p) == a + + n = 70 + a, p = 5**2*3**n*2**n, 5**6*3**(n+1)*2**(n+2) + it = sqrt_mod_iter(a, p) + for i in range(10): + assert pow(next(it), 2, p) == a + a, p = 5**2*3**n*2**n, 5**6*3**(n+1)*2**(n+3) + it = sqrt_mod_iter(a, p) + for i in range(2): + assert pow(next(it), 2, p) == a + n = 100 + a, p = 5**2*3**n*2**n, 5**6*3**(n+1)*2**(n+1) + it = sqrt_mod_iter(a, p) + for i in range(2): + assert pow(next(it), 2, p) == a + + assert type(next(sqrt_mod_iter(9, 27))) is int + assert type(next(sqrt_mod_iter(9, 27, ZZ))) is type(ZZ(1)) + assert type(next(sqrt_mod_iter(1, 7, ZZ))) is type(ZZ(1)) + + assert is_nthpow_residue(2, 1, 5) + + #issue 10816 + assert is_nthpow_residue(1, 0, 1) is False + assert is_nthpow_residue(1, 0, 2) is True + assert is_nthpow_residue(3, 0, 2) is True + assert is_nthpow_residue(0, 1, 8) is True + assert is_nthpow_residue(2, 3, 2) is True + assert is_nthpow_residue(2, 3, 9) is False + assert is_nthpow_residue(3, 5, 30) is True + assert is_nthpow_residue(21, 11, 20) is True + assert is_nthpow_residue(7, 10, 20) is False + assert is_nthpow_residue(5, 10, 20) is True + assert is_nthpow_residue(3, 10, 48) is False + assert is_nthpow_residue(1, 10, 40) is True + assert is_nthpow_residue(3, 10, 24) is False + assert is_nthpow_residue(1, 10, 24) is True + assert is_nthpow_residue(3, 10, 24) is False + assert is_nthpow_residue(2, 10, 48) is False + assert is_nthpow_residue(81, 3, 972) is False + assert is_nthpow_residue(243, 5, 5103) is True + assert is_nthpow_residue(243, 3, 1240029) is False + assert is_nthpow_residue(36010, 8, 87382) is True + assert is_nthpow_residue(28552, 6, 2218) is True + assert is_nthpow_residue(92712, 9, 50026) is True + x = {pow(i, 56, 1024) for i in range(1024)} + assert {a for a in range(1024) if is_nthpow_residue(a, 56, 1024)} == x + x = { pow(i, 256, 2048) for i in range(2048)} + assert {a for a in range(2048) if is_nthpow_residue(a, 256, 2048)} == x + x = { pow(i, 11, 324000) for i in range(1000)} + assert [ is_nthpow_residue(a, 11, 324000) for a in x] + x = { pow(i, 17, 22217575536) for i in range(1000)} + assert [ is_nthpow_residue(a, 17, 22217575536) for a in x] + assert is_nthpow_residue(676, 3, 5364) + assert is_nthpow_residue(9, 12, 36) + assert is_nthpow_residue(32, 10, 41) + assert is_nthpow_residue(4, 2, 64) + assert is_nthpow_residue(31, 4, 41) + assert not is_nthpow_residue(2, 2, 5) + assert is_nthpow_residue(8547, 12, 10007) + assert is_nthpow_residue(Dummy(even=True) + 3, 3, 2) == True + # _nthroot_mod_prime_power + for p in primerange(2, 10): + for a in range(3): + for n in range(3, 5): + ans = _nthroot_mod_prime_power(a, n, p, 1) + assert isinstance(ans, list) + if len(ans) == 0: + for b in range(p): + assert pow(b, n, p) != a % p + for k in range(2, 10): + assert _nthroot_mod_prime_power(a, n, p, k) == [] + else: + for b in range(p): + pred = pow(b, n, p) == a % p + assert not(pred ^ (b in ans)) + for k in range(2, 10): + ans = _nthroot_mod_prime_power(a, n, p, k) + if not ans: + break + for b in ans: + assert pow(b, n , p**k) == a + + assert nthroot_mod(Dummy(odd=True), 3, 2) == 1 + assert nthroot_mod(29, 31, 74) == 45 + assert nthroot_mod(1801, 11, 2663) == 44 + for a, q, p in [(51922, 2, 203017), (43, 3, 109), (1801, 11, 2663), + (26118163, 1303, 33333347), (1499, 7, 2663), (595, 6, 2663), + (1714, 12, 2663), (28477, 9, 33343)]: + r = nthroot_mod(a, q, p) + assert pow(r, q, p) == a + assert nthroot_mod(11, 3, 109) is None + assert nthroot_mod(16, 5, 36, True) == [4, 22] + assert nthroot_mod(9, 16, 36, True) == [3, 9, 15, 21, 27, 33] + assert nthroot_mod(4, 3, 3249000) is None + assert nthroot_mod(36010, 8, 87382, True) == [40208, 47174] + assert nthroot_mod(0, 12, 37, True) == [0] + assert nthroot_mod(0, 7, 100, True) == [0, 10, 20, 30, 40, 50, 60, 70, 80, 90] + assert nthroot_mod(4, 4, 27, True) == [5, 22] + assert nthroot_mod(4, 4, 121, True) == [19, 102] + assert nthroot_mod(2, 3, 7, True) == [] + for p in range(1, 20): + for a in range(p): + for n in range(1, p): + ans = nthroot_mod(a, n, p, True) + assert isinstance(ans, list) + for b in range(p): + pred = pow(b, n, p) == a + assert not(pred ^ (b in ans)) + ans2 = nthroot_mod(a, n, p, False) + if ans2 is None: + assert ans == [] + else: + assert ans2 in ans + + x = Symbol('x', positive=True) + i = Symbol('i', integer=True) + assert _discrete_log_trial_mul(587, 2**7, 2) == 7 + assert _discrete_log_trial_mul(941, 7**18, 7) == 18 + assert _discrete_log_trial_mul(389, 3**81, 3) == 81 + assert _discrete_log_trial_mul(191, 19**123, 19) == 123 + assert _discrete_log_shanks_steps(442879, 7**2, 7) == 2 + assert _discrete_log_shanks_steps(874323, 5**19, 5) == 19 + assert _discrete_log_shanks_steps(6876342, 7**71, 7) == 71 + assert _discrete_log_shanks_steps(2456747, 3**321, 3) == 321 + assert _discrete_log_pollard_rho(6013199, 2**6, 2, rseed=0) == 6 + assert _discrete_log_pollard_rho(6138719, 2**19, 2, rseed=0) == 19 + assert _discrete_log_pollard_rho(36721943, 2**40, 2, rseed=0) == 40 + assert _discrete_log_pollard_rho(24567899, 3**333, 3, rseed=0) == 333 + raises(ValueError, lambda: _discrete_log_pollard_rho(11, 7, 31, rseed=0)) + raises(ValueError, lambda: _discrete_log_pollard_rho(227, 3**7, 5, rseed=0)) + assert _discrete_log_index_calculus(983, 948, 2, 491) == 183 + assert _discrete_log_index_calculus(633383, 21794, 2, 316691) == 68048 + assert _discrete_log_index_calculus(941762639, 68822582, 2, 470881319) == 338029275 + assert _discrete_log_index_calculus(999231337607, 888188918786, 2, 499615668803) == 142811376514 + assert _discrete_log_index_calculus(47747730623, 19410045286, 43425105668, 645239603) == 590504662 + assert _discrete_log_pohlig_hellman(98376431, 11**9, 11) == 9 + assert _discrete_log_pohlig_hellman(78723213, 11**31, 11) == 31 + assert _discrete_log_pohlig_hellman(32942478, 11**98, 11) == 98 + assert _discrete_log_pohlig_hellman(14789363, 11**444, 11) == 444 + assert discrete_log(1, 0, 2) == 0 + raises(ValueError, lambda: discrete_log(-4, 1, 3)) + raises(ValueError, lambda: discrete_log(10, 3, 2)) + assert discrete_log(587, 2**9, 2) == 9 + assert discrete_log(2456747, 3**51, 3) == 51 + assert discrete_log(32942478, 11**127, 11) == 127 + assert discrete_log(432751500361, 7**324, 7) == 324 + assert discrete_log(265390227570863,184500076053622, 2) == 17835221372061 + assert discrete_log(22708823198678103974314518195029102158525052496759285596453269189798311427475159776411276642277139650833937, + 17463946429475485293747680247507700244427944625055089103624311227422110546803452417458985046168310373075327, + 123456) == 2068031853682195777930683306640554533145512201725884603914601918777510185469769997054750835368413389728895 + args = 5779, 3528, 6215 + assert discrete_log(*args) == 687 + assert discrete_log(*Tuple(*args)) == 687 + assert quadratic_congruence(400, 85, 125, 1600) == [295, 615, 935, 1255, 1575] + assert quadratic_congruence(3, 6, 5, 25) == [3, 20] + assert quadratic_congruence(120, 80, 175, 500) == [] + assert quadratic_congruence(15, 14, 7, 2) == [1] + assert quadratic_congruence(8, 15, 7, 29) == [10, 28] + assert quadratic_congruence(160, 200, 300, 461) == [144, 431] + assert quadratic_congruence(100000, 123456, 7415263, 48112959837082048697) == [30417843635344493501, 36001135160550533083] + assert quadratic_congruence(65, 121, 72, 277) == [249, 252] + assert quadratic_congruence(5, 10, 14, 2) == [0] + assert quadratic_congruence(10, 17, 19, 2) == [1] + assert quadratic_congruence(10, 14, 20, 2) == [0, 1] + assert quadratic_congruence(2**48-7, 2**48-1, 4, 2**48) == [8249717183797, 31960993774868] + assert polynomial_congruence(6*x**5 + 10*x**4 + 5*x**3 + x**2 + x + 1, + 972000) == [220999, 242999, 463999, 485999, 706999, 728999, 949999, 971999] + + assert polynomial_congruence(x**3 - 10*x**2 + 12*x - 82, 33075) == [30287] + assert polynomial_congruence(x**2 + x + 47, 2401) == [785, 1615] + assert polynomial_congruence(10*x**2 + 14*x + 20, 2) == [0, 1] + assert polynomial_congruence(x**3 + 3, 16) == [5] + assert polynomial_congruence(65*x**2 + 121*x + 72, 277) == [249, 252] + assert polynomial_congruence(x**4 - 4, 27) == [5, 22] + assert polynomial_congruence(35*x**3 - 6*x**2 - 567*x + 2308, 148225) == [86957, + 111157, 122531, 146731] + assert polynomial_congruence(x**16 - 9, 36) == [3, 9, 15, 21, 27, 33] + assert polynomial_congruence(x**6 - 2*x**5 - 35, 6125) == [3257] + raises(ValueError, lambda: polynomial_congruence(x**x, 6125)) + raises(ValueError, lambda: polynomial_congruence(x**i, 6125)) + raises(ValueError, lambda: polynomial_congruence(0.1*x**2 + 6, 100)) + + assert binomial_mod(-1, 1, 10) == 0 + assert binomial_mod(1, -1, 10) == 0 + raises(ValueError, lambda: binomial_mod(2, 1, -1)) + assert binomial_mod(51, 10, 10) == 0 + assert binomial_mod(10**3, 500, 3**6) == 567 + assert binomial_mod(10**18 - 1, 123456789, 4) == 0 + assert binomial_mod(10**18, 10**12, (10**5 + 3)**2) == 3744312326 + + +def test_binomial_p_pow(): + n, binomials, binomial = 1000, [1], 1 + for i in range(1, n + 1): + binomial *= n - i + 1 + binomial //= i + binomials.append(binomial) + + # Test powers of two, which the algorithm treats slightly differently + trials_2 = 100 + for _ in range(trials_2): + m, power = randint(0, n), randint(1, 20) + assert _binomial_mod_prime_power(n, m, 2, power) == binomials[m] % 2**power + + # Test against other prime powers + primes = list(sieve.primerange(2*n)) + trials = 1000 + for _ in range(trials): + m, prime, power = randint(0, n), choice(primes), randint(1, 10) + assert _binomial_mod_prime_power(n, m, prime, power) == binomials[m] % prime**power + + +def test_deprecated_ntheory_symbolic_functions(): + from sympy.testing.pytest import warns_deprecated_sympy + + with warns_deprecated_sympy(): + assert mobius(3) == -1 + with warns_deprecated_sympy(): + assert legendre_symbol(2, 3) == -1 + with warns_deprecated_sympy(): + assert jacobi_symbol(2, 3) == -1 diff --git a/phivenv/Lib/site-packages/sympy/parsing/__init__.py b/phivenv/Lib/site-packages/sympy/parsing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b39d031bca26bc599eb9eb0e12dfe48f7e6db174 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/__init__.py @@ -0,0 +1,4 @@ +"""Used for translating a string into a SymPy expression. """ +__all__ = ['parse_expr'] + +from .sympy_parser import parse_expr diff --git a/phivenv/Lib/site-packages/sympy/parsing/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9e98f15ce76a162b1ce819760fc2b3cb04cfe188 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/__pycache__/ast_parser.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/__pycache__/ast_parser.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c71ced93b3e6f3156a40f3bbe5f07e1b4196d28b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/__pycache__/ast_parser.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/__pycache__/mathematica.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/__pycache__/mathematica.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c364e63532a5d912496c095f672a61ebfb74191e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/__pycache__/mathematica.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/__pycache__/maxima.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/__pycache__/maxima.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e2d0c0d1951845e8f52afbfc9eb3dc67d46e099 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/__pycache__/maxima.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/__pycache__/sym_expr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/__pycache__/sym_expr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1878e06a5b54d20baa28ad272325f2155fc899ec Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/__pycache__/sym_expr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/__pycache__/sympy_parser.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/__pycache__/sympy_parser.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08d9c348dfdf4fcc35538ed6274814a421a02b13 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/__pycache__/sympy_parser.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/ast_parser.py b/phivenv/Lib/site-packages/sympy/parsing/ast_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..95a773d5bec6e130810b7b7925fdff57270aec17 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/ast_parser.py @@ -0,0 +1,79 @@ +""" +This module implements the functionality to take any Python expression as a +string and fix all numbers and other things before evaluating it, +thus + +1/2 + +returns + +Integer(1)/Integer(2) + +We use the ast module for this. It is well documented at docs.python.org. + +Some tips to understand how this works: use dump() to get a nice +representation of any node. Then write a string of what you want to get, +e.g. "Integer(1)", parse it, dump it and you'll see that you need to do +"Call(Name('Integer', Load()), [node], [], None, None)". You do not need +to bother with lineno and col_offset, just call fix_missing_locations() +before returning the node. +""" + +from sympy.core.basic import Basic +from sympy.core.sympify import SympifyError + +from ast import parse, NodeTransformer, Call, Name, Load, \ + fix_missing_locations, Constant, Tuple + +class Transform(NodeTransformer): + + def __init__(self, local_dict, global_dict): + NodeTransformer.__init__(self) + self.local_dict = local_dict + self.global_dict = global_dict + + def visit_Constant(self, node): + if isinstance(node.value, int): + return fix_missing_locations(Call(func=Name('Integer', Load()), + args=[node], keywords=[])) + elif isinstance(node.value, float): + return fix_missing_locations(Call(func=Name('Float', Load()), + args=[node], keywords=[])) + return node + + def visit_Name(self, node): + if node.id in self.local_dict: + return node + elif node.id in self.global_dict: + name_obj = self.global_dict[node.id] + + if isinstance(name_obj, (Basic, type)) or callable(name_obj): + return node + elif node.id in ['True', 'False']: + return node + return fix_missing_locations(Call(func=Name('Symbol', Load()), + args=[Constant(node.id)], keywords=[])) + + def visit_Lambda(self, node): + args = [self.visit(arg) for arg in node.args.args] + body = self.visit(node.body) + n = Call(func=Name('Lambda', Load()), + args=[Tuple(args, Load()), body], keywords=[]) + return fix_missing_locations(n) + +def parse_expr(s, local_dict): + """ + Converts the string "s" to a SymPy expression, in local_dict. + + It converts all numbers to Integers before feeding it to Python and + automatically creates Symbols. + """ + global_dict = {} + exec('from sympy import *', global_dict) + try: + a = parse(s.strip(), mode="eval") + except SyntaxError: + raise SympifyError("Cannot parse %s." % repr(s)) + a = Transform(local_dict, global_dict).visit(a) + e = compile(a, "", "eval") + return eval(e, global_dict, local_dict) diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/Autolev.g4 b/phivenv/Lib/site-packages/sympy/parsing/autolev/Autolev.g4 new file mode 100644 index 0000000000000000000000000000000000000000..94feea5fa4f49e9d1054eca2cd60c996aebff7c2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/Autolev.g4 @@ -0,0 +1,118 @@ +grammar Autolev; + +options { + language = Python3; +} + +prog: stat+; + +stat: varDecl + | functionCall + | codeCommands + | massDecl + | inertiaDecl + | assignment + | settings + ; + +assignment: vec equals expr #vecAssign + | ID '[' index ']' equals expr #indexAssign + | ID diff? equals expr #regularAssign; + +equals: ('='|'+='|'-='|':='|'*='|'/='|'^='); + +index: expr (',' expr)* ; + +diff: ('\'')+; + +functionCall: ID '(' (expr (',' expr)*)? ')' + | (Mass|Inertia) '(' (ID (',' ID)*)? ')'; + +varDecl: varType varDecl2 (',' varDecl2)*; + +varType: Newtonian|Frames|Bodies|Particles|Points|Constants + | Specifieds|Imaginary|Variables ('\'')*|MotionVariables ('\'')*; + +varDecl2: ID ('{' INT ',' INT '}')? (('{' INT ':' INT (',' INT ':' INT)* '}'))? ('{' INT '}')? ('+'|'-')? ('\'')* ('=' expr)?; + +ranges: ('{' INT ':' INT (',' INT ':' INT)* '}'); + +massDecl: Mass massDecl2 (',' massDecl2)*; + +massDecl2: ID '=' expr; + +inertiaDecl: Inertia ID ('(' ID ')')? (',' expr)+; + +matrix: '[' expr ((','|';') expr)* ']'; +matrixInOutput: (ID (ID '=' (FLOAT|INT)?))|FLOAT|INT; + +codeCommands: units + | inputs + | outputs + | codegen + | commands; + +settings: ID (EXP|ID|FLOAT|INT)?; + +units: UnitSystem ID (',' ID)*; +inputs: Input inputs2 (',' inputs2)*; +id_diff: ID diff?; +inputs2: id_diff '=' expr expr?; +outputs: Output outputs2 (',' outputs2)*; +outputs2: expr expr?; +codegen: ID functionCall ('['matrixInOutput (',' matrixInOutput)*']')? ID'.'ID; + +commands: Save ID'.'ID + | Encode ID (',' ID)*; + +vec: ID ('>')+ + | '0>' + | '1>>'; + +expr: expr '^' expr # Exponent + | expr ('*'|'/') expr # MulDiv + | expr ('+'|'-') expr # AddSub + | EXP # exp + | '-' expr # negativeOne + | FLOAT # float + | INT # int + | ID('\'')* # id + | vec # VectorOrDyadic + | ID '['expr (',' expr)* ']' # Indexing + | functionCall # function + | matrix # matrices + | '(' expr ')' # parens + | expr '=' expr # idEqualsExpr + | expr ':' expr # colon + | ID? ranges ('\'')* # rangess + ; + +// These are to take care of the case insensitivity of Autolev. +Mass: ('M'|'m')('A'|'a')('S'|'s')('S'|'s'); +Inertia: ('I'|'i')('N'|'n')('E'|'e')('R'|'r')('T'|'t')('I'|'i')('A'|'a'); +Input: ('I'|'i')('N'|'n')('P'|'p')('U'|'u')('T'|'t')('S'|'s')?; +Output: ('O'|'o')('U'|'u')('T'|'t')('P'|'p')('U'|'u')('T'|'t'); +Save: ('S'|'s')('A'|'a')('V'|'v')('E'|'e'); +UnitSystem: ('U'|'u')('N'|'n')('I'|'i')('T'|'t')('S'|'s')('Y'|'y')('S'|'s')('T'|'t')('E'|'e')('M'|'m'); +Encode: ('E'|'e')('N'|'n')('C'|'c')('O'|'o')('D'|'d')('E'|'e'); +Newtonian: ('N'|'n')('E'|'e')('W'|'w')('T'|'t')('O'|'o')('N'|'n')('I'|'i')('A'|'a')('N'|'n'); +Frames: ('F'|'f')('R'|'r')('A'|'a')('M'|'m')('E'|'e')('S'|'s')?; +Bodies: ('B'|'b')('O'|'o')('D'|'d')('I'|'i')('E'|'e')('S'|'s')?; +Particles: ('P'|'p')('A'|'a')('R'|'r')('T'|'t')('I'|'i')('C'|'c')('L'|'l')('E'|'e')('S'|'s')?; +Points: ('P'|'p')('O'|'o')('I'|'i')('N'|'n')('T'|'t')('S'|'s')?; +Constants: ('C'|'c')('O'|'o')('N'|'n')('S'|'s')('T'|'t')('A'|'a')('N'|'n')('T'|'t')('S'|'s')?; +Specifieds: ('S'|'s')('P'|'p')('E'|'e')('C'|'c')('I'|'i')('F'|'f')('I'|'i')('E'|'e')('D'|'d')('S'|'s')?; +Imaginary: ('I'|'i')('M'|'m')('A'|'a')('G'|'g')('I'|'i')('N'|'n')('A'|'a')('R'|'r')('Y'|'y'); +Variables: ('V'|'v')('A'|'a')('R'|'r')('I'|'i')('A'|'a')('B'|'b')('L'|'l')('E'|'e')('S'|'s')?; +MotionVariables: ('M'|'m')('O'|'o')('T'|'t')('I'|'i')('O'|'o')('N'|'n')('V'|'v')('A'|'a')('R'|'r')('I'|'i')('A'|'a')('B'|'b')('L'|'l')('E'|'e')('S'|'s')?; + +fragment DIFF: ('\'')*; +fragment DIGIT: [0-9]; +INT: [0-9]+ ; // match integers +FLOAT: DIGIT+ '.' DIGIT* + | '.' DIGIT+; +EXP: FLOAT 'E' INT +| FLOAT 'E' '-' INT; +LINE_COMMENT : '%' .*? '\r'? '\n' -> skip ; +ID: [a-zA-Z][a-zA-Z0-9_]*; +WS: [ \t\r\n&]+ -> skip ; // toss out whitespace diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/__init__.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ec81bb83325d68e1c11b43a1df5ec56846367e9f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/__init__.py @@ -0,0 +1,97 @@ +from sympy.external import import_module +from sympy.utilities.decorator import doctest_depends_on + +@doctest_depends_on(modules=('antlr4',)) +def parse_autolev(autolev_code, include_numeric=False): + """Parses Autolev code (version 4.1) to SymPy code. + + Parameters + ========= + autolev_code : Can be an str or any object with a readlines() method (such as a file handle or StringIO). + include_numeric : boolean, optional + If True NumPy, PyDy, or other numeric code is included for numeric evaluation lines in the Autolev code. + + Returns + ======= + sympy_code : str + Equivalent SymPy and/or numpy/pydy code as the input code. + + + Example (Double Pendulum) + ========================= + >>> my_al_text = ("MOTIONVARIABLES' Q{2}', U{2}'", + ... "CONSTANTS L,M,G", + ... "NEWTONIAN N", + ... "FRAMES A,B", + ... "SIMPROT(N, A, 3, Q1)", + ... "SIMPROT(N, B, 3, Q2)", + ... "W_A_N>=U1*N3>", + ... "W_B_N>=U2*N3>", + ... "POINT O", + ... "PARTICLES P,R", + ... "P_O_P> = L*A1>", + ... "P_P_R> = L*B1>", + ... "V_O_N> = 0>", + ... "V2PTS(N, A, O, P)", + ... "V2PTS(N, B, P, R)", + ... "MASS P=M, R=M", + ... "Q1' = U1", + ... "Q2' = U2", + ... "GRAVITY(G*N1>)", + ... "ZERO = FR() + FRSTAR()", + ... "KANE()", + ... "INPUT M=1,G=9.81,L=1", + ... "INPUT Q1=.1,Q2=.2,U1=0,U2=0", + ... "INPUT TFINAL=10, INTEGSTP=.01", + ... "CODE DYNAMICS() some_filename.c") + >>> my_al_text = '\\n'.join(my_al_text) + >>> from sympy.parsing.autolev import parse_autolev + >>> print(parse_autolev(my_al_text, include_numeric=True)) + import sympy.physics.mechanics as _me + import sympy as _sm + import math as m + import numpy as _np + + q1, q2, u1, u2 = _me.dynamicsymbols('q1 q2 u1 u2') + q1_d, q2_d, u1_d, u2_d = _me.dynamicsymbols('q1_ q2_ u1_ u2_', 1) + l, m, g = _sm.symbols('l m g', real=True) + frame_n = _me.ReferenceFrame('n') + frame_a = _me.ReferenceFrame('a') + frame_b = _me.ReferenceFrame('b') + frame_a.orient(frame_n, 'Axis', [q1, frame_n.z]) + frame_b.orient(frame_n, 'Axis', [q2, frame_n.z]) + frame_a.set_ang_vel(frame_n, u1*frame_n.z) + frame_b.set_ang_vel(frame_n, u2*frame_n.z) + point_o = _me.Point('o') + particle_p = _me.Particle('p', _me.Point('p_pt'), _sm.Symbol('m')) + particle_r = _me.Particle('r', _me.Point('r_pt'), _sm.Symbol('m')) + particle_p.point.set_pos(point_o, l*frame_a.x) + particle_r.point.set_pos(particle_p.point, l*frame_b.x) + point_o.set_vel(frame_n, 0) + particle_p.point.v2pt_theory(point_o,frame_n,frame_a) + particle_r.point.v2pt_theory(particle_p.point,frame_n,frame_b) + particle_p.mass = m + particle_r.mass = m + force_p = particle_p.mass*(g*frame_n.x) + force_r = particle_r.mass*(g*frame_n.x) + kd_eqs = [q1_d - u1, q2_d - u2] + forceList = [(particle_p.point,particle_p.mass*(g*frame_n.x)), (particle_r.point,particle_r.mass*(g*frame_n.x))] + kane = _me.KanesMethod(frame_n, q_ind=[q1,q2], u_ind=[u1, u2], kd_eqs = kd_eqs) + fr, frstar = kane.kanes_equations([particle_p, particle_r], forceList) + zero = fr+frstar + from pydy.system import System + sys = System(kane, constants = {l:1, m:1, g:9.81}, + specifieds={}, + initial_conditions={q1:.1, q2:.2, u1:0, u2:0}, + times = _np.linspace(0.0, 10, 10/.01)) + + y=sys.integrate() + + """ + + _autolev = import_module( + 'sympy.parsing.autolev._parse_autolev_antlr', + import_kwargs={'fromlist': ['X']}) + + if _autolev is not None: + return _autolev.parse_autolev(autolev_code, include_numeric) diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72fa1eee0ed224599467ce4e45e774f4ce3c293b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/__pycache__/_build_autolev_antlr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/__pycache__/_build_autolev_antlr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f961f5308940e62fba2053226bb693986a2e49fa Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/__pycache__/_build_autolev_antlr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/__pycache__/_listener_autolev_antlr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/__pycache__/_listener_autolev_antlr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1dd5635deebf31c56d67eb1a2d0b50280637ef3f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/__pycache__/_listener_autolev_antlr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/__pycache__/_parse_autolev_antlr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/__pycache__/_parse_autolev_antlr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..963099d09387b49dad8cb0d6135a0072a2bc7ce7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/__pycache__/_parse_autolev_antlr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/__init__.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9b71e9f51fd455558a9eb42dc840604c6c96e4b3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/__init__.py @@ -0,0 +1,5 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a43beba9a14d05885d80d1577aaf3c74e7ad3b0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/__pycache__/autolevlexer.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/__pycache__/autolevlexer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d696cabe1676286601036d3d2ff7b5c669bad7d1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/__pycache__/autolevlexer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/__pycache__/autolevlistener.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/__pycache__/autolevlistener.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92fbbf2e2d199709c66f58fd4cd3f9090648f51a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/__pycache__/autolevlistener.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/autolevlexer.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/autolevlexer.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b3b1d27ade809a63d9fd328a1572c17625443e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/autolevlexer.py @@ -0,0 +1,253 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +from io import StringIO +import sys +if sys.version_info[1] > 5: + from typing import TextIO +else: + from typing.io import TextIO + + +def serializedATN(): + return [ + 4,0,49,393,6,-1,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5, + 2,6,7,6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2, + 13,7,13,2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7, + 19,2,20,7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2, + 26,7,26,2,27,7,27,2,28,7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7, + 32,2,33,7,33,2,34,7,34,2,35,7,35,2,36,7,36,2,37,7,37,2,38,7,38,2, + 39,7,39,2,40,7,40,2,41,7,41,2,42,7,42,2,43,7,43,2,44,7,44,2,45,7, + 45,2,46,7,46,2,47,7,47,2,48,7,48,2,49,7,49,2,50,7,50,1,0,1,0,1,1, + 1,1,1,2,1,2,1,3,1,3,1,3,1,4,1,4,1,4,1,5,1,5,1,5,1,6,1,6,1,6,1,7, + 1,7,1,7,1,8,1,8,1,8,1,9,1,9,1,10,1,10,1,11,1,11,1,12,1,12,1,13,1, + 13,1,14,1,14,1,15,1,15,1,16,1,16,1,17,1,17,1,18,1,18,1,19,1,19,1, + 20,1,20,1,21,1,21,1,21,1,22,1,22,1,22,1,22,1,23,1,23,1,24,1,24,1, + 25,1,25,1,26,1,26,1,26,1,26,1,26,1,27,1,27,1,27,1,27,1,27,1,27,1, + 27,1,27,1,28,1,28,1,28,1,28,1,28,1,28,3,28,184,8,28,1,29,1,29,1, + 29,1,29,1,29,1,29,1,29,1,30,1,30,1,30,1,30,1,30,1,31,1,31,1,31,1, + 31,1,31,1,31,1,31,1,31,1,31,1,31,1,31,1,32,1,32,1,32,1,32,1,32,1, + 32,1,32,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,34,1, + 34,1,34,1,34,1,34,1,34,3,34,232,8,34,1,35,1,35,1,35,1,35,1,35,1, + 35,3,35,240,8,35,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,1,36,3, + 36,251,8,36,1,37,1,37,1,37,1,37,1,37,1,37,3,37,259,8,37,1,38,1,38, + 1,38,1,38,1,38,1,38,1,38,1,38,1,38,3,38,270,8,38,1,39,1,39,1,39, + 1,39,1,39,1,39,1,39,1,39,1,39,1,39,3,39,282,8,39,1,40,1,40,1,40, + 1,40,1,40,1,40,1,40,1,40,1,40,1,40,1,41,1,41,1,41,1,41,1,41,1,41, + 1,41,1,41,1,41,3,41,303,8,41,1,42,1,42,1,42,1,42,1,42,1,42,1,42, + 1,42,1,42,1,42,1,42,1,42,1,42,1,42,1,42,3,42,320,8,42,1,43,5,43, + 323,8,43,10,43,12,43,326,9,43,1,44,1,44,1,45,4,45,331,8,45,11,45, + 12,45,332,1,46,4,46,336,8,46,11,46,12,46,337,1,46,1,46,5,46,342, + 8,46,10,46,12,46,345,9,46,1,46,1,46,4,46,349,8,46,11,46,12,46,350, + 3,46,353,8,46,1,47,1,47,1,47,1,47,1,47,1,47,1,47,1,47,1,47,3,47, + 364,8,47,1,48,1,48,5,48,368,8,48,10,48,12,48,371,9,48,1,48,3,48, + 374,8,48,1,48,1,48,1,48,1,48,1,49,1,49,5,49,382,8,49,10,49,12,49, + 385,9,49,1,50,4,50,388,8,50,11,50,12,50,389,1,50,1,50,1,369,0,51, + 1,1,3,2,5,3,7,4,9,5,11,6,13,7,15,8,17,9,19,10,21,11,23,12,25,13, + 27,14,29,15,31,16,33,17,35,18,37,19,39,20,41,21,43,22,45,23,47,24, + 49,25,51,26,53,27,55,28,57,29,59,30,61,31,63,32,65,33,67,34,69,35, + 71,36,73,37,75,38,77,39,79,40,81,41,83,42,85,43,87,0,89,0,91,44, + 93,45,95,46,97,47,99,48,101,49,1,0,24,2,0,77,77,109,109,2,0,65,65, + 97,97,2,0,83,83,115,115,2,0,73,73,105,105,2,0,78,78,110,110,2,0, + 69,69,101,101,2,0,82,82,114,114,2,0,84,84,116,116,2,0,80,80,112, + 112,2,0,85,85,117,117,2,0,79,79,111,111,2,0,86,86,118,118,2,0,89, + 89,121,121,2,0,67,67,99,99,2,0,68,68,100,100,2,0,87,87,119,119,2, + 0,70,70,102,102,2,0,66,66,98,98,2,0,76,76,108,108,2,0,71,71,103, + 103,1,0,48,57,2,0,65,90,97,122,4,0,48,57,65,90,95,95,97,122,4,0, + 9,10,13,13,32,32,38,38,410,0,1,1,0,0,0,0,3,1,0,0,0,0,5,1,0,0,0,0, + 7,1,0,0,0,0,9,1,0,0,0,0,11,1,0,0,0,0,13,1,0,0,0,0,15,1,0,0,0,0,17, + 1,0,0,0,0,19,1,0,0,0,0,21,1,0,0,0,0,23,1,0,0,0,0,25,1,0,0,0,0,27, + 1,0,0,0,0,29,1,0,0,0,0,31,1,0,0,0,0,33,1,0,0,0,0,35,1,0,0,0,0,37, + 1,0,0,0,0,39,1,0,0,0,0,41,1,0,0,0,0,43,1,0,0,0,0,45,1,0,0,0,0,47, + 1,0,0,0,0,49,1,0,0,0,0,51,1,0,0,0,0,53,1,0,0,0,0,55,1,0,0,0,0,57, + 1,0,0,0,0,59,1,0,0,0,0,61,1,0,0,0,0,63,1,0,0,0,0,65,1,0,0,0,0,67, + 1,0,0,0,0,69,1,0,0,0,0,71,1,0,0,0,0,73,1,0,0,0,0,75,1,0,0,0,0,77, + 1,0,0,0,0,79,1,0,0,0,0,81,1,0,0,0,0,83,1,0,0,0,0,85,1,0,0,0,0,91, + 1,0,0,0,0,93,1,0,0,0,0,95,1,0,0,0,0,97,1,0,0,0,0,99,1,0,0,0,0,101, + 1,0,0,0,1,103,1,0,0,0,3,105,1,0,0,0,5,107,1,0,0,0,7,109,1,0,0,0, + 9,112,1,0,0,0,11,115,1,0,0,0,13,118,1,0,0,0,15,121,1,0,0,0,17,124, + 1,0,0,0,19,127,1,0,0,0,21,129,1,0,0,0,23,131,1,0,0,0,25,133,1,0, + 0,0,27,135,1,0,0,0,29,137,1,0,0,0,31,139,1,0,0,0,33,141,1,0,0,0, + 35,143,1,0,0,0,37,145,1,0,0,0,39,147,1,0,0,0,41,149,1,0,0,0,43,151, + 1,0,0,0,45,154,1,0,0,0,47,158,1,0,0,0,49,160,1,0,0,0,51,162,1,0, + 0,0,53,164,1,0,0,0,55,169,1,0,0,0,57,177,1,0,0,0,59,185,1,0,0,0, + 61,192,1,0,0,0,63,197,1,0,0,0,65,208,1,0,0,0,67,215,1,0,0,0,69,225, + 1,0,0,0,71,233,1,0,0,0,73,241,1,0,0,0,75,252,1,0,0,0,77,260,1,0, + 0,0,79,271,1,0,0,0,81,283,1,0,0,0,83,293,1,0,0,0,85,304,1,0,0,0, + 87,324,1,0,0,0,89,327,1,0,0,0,91,330,1,0,0,0,93,352,1,0,0,0,95,363, + 1,0,0,0,97,365,1,0,0,0,99,379,1,0,0,0,101,387,1,0,0,0,103,104,5, + 91,0,0,104,2,1,0,0,0,105,106,5,93,0,0,106,4,1,0,0,0,107,108,5,61, + 0,0,108,6,1,0,0,0,109,110,5,43,0,0,110,111,5,61,0,0,111,8,1,0,0, + 0,112,113,5,45,0,0,113,114,5,61,0,0,114,10,1,0,0,0,115,116,5,58, + 0,0,116,117,5,61,0,0,117,12,1,0,0,0,118,119,5,42,0,0,119,120,5,61, + 0,0,120,14,1,0,0,0,121,122,5,47,0,0,122,123,5,61,0,0,123,16,1,0, + 0,0,124,125,5,94,0,0,125,126,5,61,0,0,126,18,1,0,0,0,127,128,5,44, + 0,0,128,20,1,0,0,0,129,130,5,39,0,0,130,22,1,0,0,0,131,132,5,40, + 0,0,132,24,1,0,0,0,133,134,5,41,0,0,134,26,1,0,0,0,135,136,5,123, + 0,0,136,28,1,0,0,0,137,138,5,125,0,0,138,30,1,0,0,0,139,140,5,58, + 0,0,140,32,1,0,0,0,141,142,5,43,0,0,142,34,1,0,0,0,143,144,5,45, + 0,0,144,36,1,0,0,0,145,146,5,59,0,0,146,38,1,0,0,0,147,148,5,46, + 0,0,148,40,1,0,0,0,149,150,5,62,0,0,150,42,1,0,0,0,151,152,5,48, + 0,0,152,153,5,62,0,0,153,44,1,0,0,0,154,155,5,49,0,0,155,156,5,62, + 0,0,156,157,5,62,0,0,157,46,1,0,0,0,158,159,5,94,0,0,159,48,1,0, + 0,0,160,161,5,42,0,0,161,50,1,0,0,0,162,163,5,47,0,0,163,52,1,0, + 0,0,164,165,7,0,0,0,165,166,7,1,0,0,166,167,7,2,0,0,167,168,7,2, + 0,0,168,54,1,0,0,0,169,170,7,3,0,0,170,171,7,4,0,0,171,172,7,5,0, + 0,172,173,7,6,0,0,173,174,7,7,0,0,174,175,7,3,0,0,175,176,7,1,0, + 0,176,56,1,0,0,0,177,178,7,3,0,0,178,179,7,4,0,0,179,180,7,8,0,0, + 180,181,7,9,0,0,181,183,7,7,0,0,182,184,7,2,0,0,183,182,1,0,0,0, + 183,184,1,0,0,0,184,58,1,0,0,0,185,186,7,10,0,0,186,187,7,9,0,0, + 187,188,7,7,0,0,188,189,7,8,0,0,189,190,7,9,0,0,190,191,7,7,0,0, + 191,60,1,0,0,0,192,193,7,2,0,0,193,194,7,1,0,0,194,195,7,11,0,0, + 195,196,7,5,0,0,196,62,1,0,0,0,197,198,7,9,0,0,198,199,7,4,0,0,199, + 200,7,3,0,0,200,201,7,7,0,0,201,202,7,2,0,0,202,203,7,12,0,0,203, + 204,7,2,0,0,204,205,7,7,0,0,205,206,7,5,0,0,206,207,7,0,0,0,207, + 64,1,0,0,0,208,209,7,5,0,0,209,210,7,4,0,0,210,211,7,13,0,0,211, + 212,7,10,0,0,212,213,7,14,0,0,213,214,7,5,0,0,214,66,1,0,0,0,215, + 216,7,4,0,0,216,217,7,5,0,0,217,218,7,15,0,0,218,219,7,7,0,0,219, + 220,7,10,0,0,220,221,7,4,0,0,221,222,7,3,0,0,222,223,7,1,0,0,223, + 224,7,4,0,0,224,68,1,0,0,0,225,226,7,16,0,0,226,227,7,6,0,0,227, + 228,7,1,0,0,228,229,7,0,0,0,229,231,7,5,0,0,230,232,7,2,0,0,231, + 230,1,0,0,0,231,232,1,0,0,0,232,70,1,0,0,0,233,234,7,17,0,0,234, + 235,7,10,0,0,235,236,7,14,0,0,236,237,7,3,0,0,237,239,7,5,0,0,238, + 240,7,2,0,0,239,238,1,0,0,0,239,240,1,0,0,0,240,72,1,0,0,0,241,242, + 7,8,0,0,242,243,7,1,0,0,243,244,7,6,0,0,244,245,7,7,0,0,245,246, + 7,3,0,0,246,247,7,13,0,0,247,248,7,18,0,0,248,250,7,5,0,0,249,251, + 7,2,0,0,250,249,1,0,0,0,250,251,1,0,0,0,251,74,1,0,0,0,252,253,7, + 8,0,0,253,254,7,10,0,0,254,255,7,3,0,0,255,256,7,4,0,0,256,258,7, + 7,0,0,257,259,7,2,0,0,258,257,1,0,0,0,258,259,1,0,0,0,259,76,1,0, + 0,0,260,261,7,13,0,0,261,262,7,10,0,0,262,263,7,4,0,0,263,264,7, + 2,0,0,264,265,7,7,0,0,265,266,7,1,0,0,266,267,7,4,0,0,267,269,7, + 7,0,0,268,270,7,2,0,0,269,268,1,0,0,0,269,270,1,0,0,0,270,78,1,0, + 0,0,271,272,7,2,0,0,272,273,7,8,0,0,273,274,7,5,0,0,274,275,7,13, + 0,0,275,276,7,3,0,0,276,277,7,16,0,0,277,278,7,3,0,0,278,279,7,5, + 0,0,279,281,7,14,0,0,280,282,7,2,0,0,281,280,1,0,0,0,281,282,1,0, + 0,0,282,80,1,0,0,0,283,284,7,3,0,0,284,285,7,0,0,0,285,286,7,1,0, + 0,286,287,7,19,0,0,287,288,7,3,0,0,288,289,7,4,0,0,289,290,7,1,0, + 0,290,291,7,6,0,0,291,292,7,12,0,0,292,82,1,0,0,0,293,294,7,11,0, + 0,294,295,7,1,0,0,295,296,7,6,0,0,296,297,7,3,0,0,297,298,7,1,0, + 0,298,299,7,17,0,0,299,300,7,18,0,0,300,302,7,5,0,0,301,303,7,2, + 0,0,302,301,1,0,0,0,302,303,1,0,0,0,303,84,1,0,0,0,304,305,7,0,0, + 0,305,306,7,10,0,0,306,307,7,7,0,0,307,308,7,3,0,0,308,309,7,10, + 0,0,309,310,7,4,0,0,310,311,7,11,0,0,311,312,7,1,0,0,312,313,7,6, + 0,0,313,314,7,3,0,0,314,315,7,1,0,0,315,316,7,17,0,0,316,317,7,18, + 0,0,317,319,7,5,0,0,318,320,7,2,0,0,319,318,1,0,0,0,319,320,1,0, + 0,0,320,86,1,0,0,0,321,323,5,39,0,0,322,321,1,0,0,0,323,326,1,0, + 0,0,324,322,1,0,0,0,324,325,1,0,0,0,325,88,1,0,0,0,326,324,1,0,0, + 0,327,328,7,20,0,0,328,90,1,0,0,0,329,331,7,20,0,0,330,329,1,0,0, + 0,331,332,1,0,0,0,332,330,1,0,0,0,332,333,1,0,0,0,333,92,1,0,0,0, + 334,336,3,89,44,0,335,334,1,0,0,0,336,337,1,0,0,0,337,335,1,0,0, + 0,337,338,1,0,0,0,338,339,1,0,0,0,339,343,5,46,0,0,340,342,3,89, + 44,0,341,340,1,0,0,0,342,345,1,0,0,0,343,341,1,0,0,0,343,344,1,0, + 0,0,344,353,1,0,0,0,345,343,1,0,0,0,346,348,5,46,0,0,347,349,3,89, + 44,0,348,347,1,0,0,0,349,350,1,0,0,0,350,348,1,0,0,0,350,351,1,0, + 0,0,351,353,1,0,0,0,352,335,1,0,0,0,352,346,1,0,0,0,353,94,1,0,0, + 0,354,355,3,93,46,0,355,356,5,69,0,0,356,357,3,91,45,0,357,364,1, + 0,0,0,358,359,3,93,46,0,359,360,5,69,0,0,360,361,5,45,0,0,361,362, + 3,91,45,0,362,364,1,0,0,0,363,354,1,0,0,0,363,358,1,0,0,0,364,96, + 1,0,0,0,365,369,5,37,0,0,366,368,9,0,0,0,367,366,1,0,0,0,368,371, + 1,0,0,0,369,370,1,0,0,0,369,367,1,0,0,0,370,373,1,0,0,0,371,369, + 1,0,0,0,372,374,5,13,0,0,373,372,1,0,0,0,373,374,1,0,0,0,374,375, + 1,0,0,0,375,376,5,10,0,0,376,377,1,0,0,0,377,378,6,48,0,0,378,98, + 1,0,0,0,379,383,7,21,0,0,380,382,7,22,0,0,381,380,1,0,0,0,382,385, + 1,0,0,0,383,381,1,0,0,0,383,384,1,0,0,0,384,100,1,0,0,0,385,383, + 1,0,0,0,386,388,7,23,0,0,387,386,1,0,0,0,388,389,1,0,0,0,389,387, + 1,0,0,0,389,390,1,0,0,0,390,391,1,0,0,0,391,392,6,50,0,0,392,102, + 1,0,0,0,21,0,183,231,239,250,258,269,281,302,319,324,332,337,343, + 350,352,363,369,373,383,389,1,6,0,0 + ] + +class AutolevLexer(Lexer): + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + T__0 = 1 + T__1 = 2 + T__2 = 3 + T__3 = 4 + T__4 = 5 + T__5 = 6 + T__6 = 7 + T__7 = 8 + T__8 = 9 + T__9 = 10 + T__10 = 11 + T__11 = 12 + T__12 = 13 + T__13 = 14 + T__14 = 15 + T__15 = 16 + T__16 = 17 + T__17 = 18 + T__18 = 19 + T__19 = 20 + T__20 = 21 + T__21 = 22 + T__22 = 23 + T__23 = 24 + T__24 = 25 + T__25 = 26 + Mass = 27 + Inertia = 28 + Input = 29 + Output = 30 + Save = 31 + UnitSystem = 32 + Encode = 33 + Newtonian = 34 + Frames = 35 + Bodies = 36 + Particles = 37 + Points = 38 + Constants = 39 + Specifieds = 40 + Imaginary = 41 + Variables = 42 + MotionVariables = 43 + INT = 44 + FLOAT = 45 + EXP = 46 + LINE_COMMENT = 47 + ID = 48 + WS = 49 + + channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] + + modeNames = [ "DEFAULT_MODE" ] + + literalNames = [ "", + "'['", "']'", "'='", "'+='", "'-='", "':='", "'*='", "'/='", + "'^='", "','", "'''", "'('", "')'", "'{'", "'}'", "':'", "'+'", + "'-'", "';'", "'.'", "'>'", "'0>'", "'1>>'", "'^'", "'*'", "'/'" ] + + symbolicNames = [ "", + "Mass", "Inertia", "Input", "Output", "Save", "UnitSystem", + "Encode", "Newtonian", "Frames", "Bodies", "Particles", "Points", + "Constants", "Specifieds", "Imaginary", "Variables", "MotionVariables", + "INT", "FLOAT", "EXP", "LINE_COMMENT", "ID", "WS" ] + + ruleNames = [ "T__0", "T__1", "T__2", "T__3", "T__4", "T__5", "T__6", + "T__7", "T__8", "T__9", "T__10", "T__11", "T__12", "T__13", + "T__14", "T__15", "T__16", "T__17", "T__18", "T__19", + "T__20", "T__21", "T__22", "T__23", "T__24", "T__25", + "Mass", "Inertia", "Input", "Output", "Save", "UnitSystem", + "Encode", "Newtonian", "Frames", "Bodies", "Particles", + "Points", "Constants", "Specifieds", "Imaginary", "Variables", + "MotionVariables", "DIFF", "DIGIT", "INT", "FLOAT", "EXP", + "LINE_COMMENT", "ID", "WS" ] + + grammarFileName = "Autolev.g4" + + def __init__(self, input=None, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.11.1") + self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) + self._actions = None + self._predicates = None + + diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/autolevlistener.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/autolevlistener.py new file mode 100644 index 0000000000000000000000000000000000000000..6f391a298a71ecf2d04cf921a919cbb68b181fab --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/autolevlistener.py @@ -0,0 +1,421 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +if __name__ is not None and "." in __name__: + from .autolevparser import AutolevParser +else: + from autolevparser import AutolevParser + +# This class defines a complete listener for a parse tree produced by AutolevParser. +class AutolevListener(ParseTreeListener): + + # Enter a parse tree produced by AutolevParser#prog. + def enterProg(self, ctx:AutolevParser.ProgContext): + pass + + # Exit a parse tree produced by AutolevParser#prog. + def exitProg(self, ctx:AutolevParser.ProgContext): + pass + + + # Enter a parse tree produced by AutolevParser#stat. + def enterStat(self, ctx:AutolevParser.StatContext): + pass + + # Exit a parse tree produced by AutolevParser#stat. + def exitStat(self, ctx:AutolevParser.StatContext): + pass + + + # Enter a parse tree produced by AutolevParser#vecAssign. + def enterVecAssign(self, ctx:AutolevParser.VecAssignContext): + pass + + # Exit a parse tree produced by AutolevParser#vecAssign. + def exitVecAssign(self, ctx:AutolevParser.VecAssignContext): + pass + + + # Enter a parse tree produced by AutolevParser#indexAssign. + def enterIndexAssign(self, ctx:AutolevParser.IndexAssignContext): + pass + + # Exit a parse tree produced by AutolevParser#indexAssign. + def exitIndexAssign(self, ctx:AutolevParser.IndexAssignContext): + pass + + + # Enter a parse tree produced by AutolevParser#regularAssign. + def enterRegularAssign(self, ctx:AutolevParser.RegularAssignContext): + pass + + # Exit a parse tree produced by AutolevParser#regularAssign. + def exitRegularAssign(self, ctx:AutolevParser.RegularAssignContext): + pass + + + # Enter a parse tree produced by AutolevParser#equals. + def enterEquals(self, ctx:AutolevParser.EqualsContext): + pass + + # Exit a parse tree produced by AutolevParser#equals. + def exitEquals(self, ctx:AutolevParser.EqualsContext): + pass + + + # Enter a parse tree produced by AutolevParser#index. + def enterIndex(self, ctx:AutolevParser.IndexContext): + pass + + # Exit a parse tree produced by AutolevParser#index. + def exitIndex(self, ctx:AutolevParser.IndexContext): + pass + + + # Enter a parse tree produced by AutolevParser#diff. + def enterDiff(self, ctx:AutolevParser.DiffContext): + pass + + # Exit a parse tree produced by AutolevParser#diff. + def exitDiff(self, ctx:AutolevParser.DiffContext): + pass + + + # Enter a parse tree produced by AutolevParser#functionCall. + def enterFunctionCall(self, ctx:AutolevParser.FunctionCallContext): + pass + + # Exit a parse tree produced by AutolevParser#functionCall. + def exitFunctionCall(self, ctx:AutolevParser.FunctionCallContext): + pass + + + # Enter a parse tree produced by AutolevParser#varDecl. + def enterVarDecl(self, ctx:AutolevParser.VarDeclContext): + pass + + # Exit a parse tree produced by AutolevParser#varDecl. + def exitVarDecl(self, ctx:AutolevParser.VarDeclContext): + pass + + + # Enter a parse tree produced by AutolevParser#varType. + def enterVarType(self, ctx:AutolevParser.VarTypeContext): + pass + + # Exit a parse tree produced by AutolevParser#varType. + def exitVarType(self, ctx:AutolevParser.VarTypeContext): + pass + + + # Enter a parse tree produced by AutolevParser#varDecl2. + def enterVarDecl2(self, ctx:AutolevParser.VarDecl2Context): + pass + + # Exit a parse tree produced by AutolevParser#varDecl2. + def exitVarDecl2(self, ctx:AutolevParser.VarDecl2Context): + pass + + + # Enter a parse tree produced by AutolevParser#ranges. + def enterRanges(self, ctx:AutolevParser.RangesContext): + pass + + # Exit a parse tree produced by AutolevParser#ranges. + def exitRanges(self, ctx:AutolevParser.RangesContext): + pass + + + # Enter a parse tree produced by AutolevParser#massDecl. + def enterMassDecl(self, ctx:AutolevParser.MassDeclContext): + pass + + # Exit a parse tree produced by AutolevParser#massDecl. + def exitMassDecl(self, ctx:AutolevParser.MassDeclContext): + pass + + + # Enter a parse tree produced by AutolevParser#massDecl2. + def enterMassDecl2(self, ctx:AutolevParser.MassDecl2Context): + pass + + # Exit a parse tree produced by AutolevParser#massDecl2. + def exitMassDecl2(self, ctx:AutolevParser.MassDecl2Context): + pass + + + # Enter a parse tree produced by AutolevParser#inertiaDecl. + def enterInertiaDecl(self, ctx:AutolevParser.InertiaDeclContext): + pass + + # Exit a parse tree produced by AutolevParser#inertiaDecl. + def exitInertiaDecl(self, ctx:AutolevParser.InertiaDeclContext): + pass + + + # Enter a parse tree produced by AutolevParser#matrix. + def enterMatrix(self, ctx:AutolevParser.MatrixContext): + pass + + # Exit a parse tree produced by AutolevParser#matrix. + def exitMatrix(self, ctx:AutolevParser.MatrixContext): + pass + + + # Enter a parse tree produced by AutolevParser#matrixInOutput. + def enterMatrixInOutput(self, ctx:AutolevParser.MatrixInOutputContext): + pass + + # Exit a parse tree produced by AutolevParser#matrixInOutput. + def exitMatrixInOutput(self, ctx:AutolevParser.MatrixInOutputContext): + pass + + + # Enter a parse tree produced by AutolevParser#codeCommands. + def enterCodeCommands(self, ctx:AutolevParser.CodeCommandsContext): + pass + + # Exit a parse tree produced by AutolevParser#codeCommands. + def exitCodeCommands(self, ctx:AutolevParser.CodeCommandsContext): + pass + + + # Enter a parse tree produced by AutolevParser#settings. + def enterSettings(self, ctx:AutolevParser.SettingsContext): + pass + + # Exit a parse tree produced by AutolevParser#settings. + def exitSettings(self, ctx:AutolevParser.SettingsContext): + pass + + + # Enter a parse tree produced by AutolevParser#units. + def enterUnits(self, ctx:AutolevParser.UnitsContext): + pass + + # Exit a parse tree produced by AutolevParser#units. + def exitUnits(self, ctx:AutolevParser.UnitsContext): + pass + + + # Enter a parse tree produced by AutolevParser#inputs. + def enterInputs(self, ctx:AutolevParser.InputsContext): + pass + + # Exit a parse tree produced by AutolevParser#inputs. + def exitInputs(self, ctx:AutolevParser.InputsContext): + pass + + + # Enter a parse tree produced by AutolevParser#id_diff. + def enterId_diff(self, ctx:AutolevParser.Id_diffContext): + pass + + # Exit a parse tree produced by AutolevParser#id_diff. + def exitId_diff(self, ctx:AutolevParser.Id_diffContext): + pass + + + # Enter a parse tree produced by AutolevParser#inputs2. + def enterInputs2(self, ctx:AutolevParser.Inputs2Context): + pass + + # Exit a parse tree produced by AutolevParser#inputs2. + def exitInputs2(self, ctx:AutolevParser.Inputs2Context): + pass + + + # Enter a parse tree produced by AutolevParser#outputs. + def enterOutputs(self, ctx:AutolevParser.OutputsContext): + pass + + # Exit a parse tree produced by AutolevParser#outputs. + def exitOutputs(self, ctx:AutolevParser.OutputsContext): + pass + + + # Enter a parse tree produced by AutolevParser#outputs2. + def enterOutputs2(self, ctx:AutolevParser.Outputs2Context): + pass + + # Exit a parse tree produced by AutolevParser#outputs2. + def exitOutputs2(self, ctx:AutolevParser.Outputs2Context): + pass + + + # Enter a parse tree produced by AutolevParser#codegen. + def enterCodegen(self, ctx:AutolevParser.CodegenContext): + pass + + # Exit a parse tree produced by AutolevParser#codegen. + def exitCodegen(self, ctx:AutolevParser.CodegenContext): + pass + + + # Enter a parse tree produced by AutolevParser#commands. + def enterCommands(self, ctx:AutolevParser.CommandsContext): + pass + + # Exit a parse tree produced by AutolevParser#commands. + def exitCommands(self, ctx:AutolevParser.CommandsContext): + pass + + + # Enter a parse tree produced by AutolevParser#vec. + def enterVec(self, ctx:AutolevParser.VecContext): + pass + + # Exit a parse tree produced by AutolevParser#vec. + def exitVec(self, ctx:AutolevParser.VecContext): + pass + + + # Enter a parse tree produced by AutolevParser#parens. + def enterParens(self, ctx:AutolevParser.ParensContext): + pass + + # Exit a parse tree produced by AutolevParser#parens. + def exitParens(self, ctx:AutolevParser.ParensContext): + pass + + + # Enter a parse tree produced by AutolevParser#VectorOrDyadic. + def enterVectorOrDyadic(self, ctx:AutolevParser.VectorOrDyadicContext): + pass + + # Exit a parse tree produced by AutolevParser#VectorOrDyadic. + def exitVectorOrDyadic(self, ctx:AutolevParser.VectorOrDyadicContext): + pass + + + # Enter a parse tree produced by AutolevParser#Exponent. + def enterExponent(self, ctx:AutolevParser.ExponentContext): + pass + + # Exit a parse tree produced by AutolevParser#Exponent. + def exitExponent(self, ctx:AutolevParser.ExponentContext): + pass + + + # Enter a parse tree produced by AutolevParser#MulDiv. + def enterMulDiv(self, ctx:AutolevParser.MulDivContext): + pass + + # Exit a parse tree produced by AutolevParser#MulDiv. + def exitMulDiv(self, ctx:AutolevParser.MulDivContext): + pass + + + # Enter a parse tree produced by AutolevParser#AddSub. + def enterAddSub(self, ctx:AutolevParser.AddSubContext): + pass + + # Exit a parse tree produced by AutolevParser#AddSub. + def exitAddSub(self, ctx:AutolevParser.AddSubContext): + pass + + + # Enter a parse tree produced by AutolevParser#float. + def enterFloat(self, ctx:AutolevParser.FloatContext): + pass + + # Exit a parse tree produced by AutolevParser#float. + def exitFloat(self, ctx:AutolevParser.FloatContext): + pass + + + # Enter a parse tree produced by AutolevParser#int. + def enterInt(self, ctx:AutolevParser.IntContext): + pass + + # Exit a parse tree produced by AutolevParser#int. + def exitInt(self, ctx:AutolevParser.IntContext): + pass + + + # Enter a parse tree produced by AutolevParser#idEqualsExpr. + def enterIdEqualsExpr(self, ctx:AutolevParser.IdEqualsExprContext): + pass + + # Exit a parse tree produced by AutolevParser#idEqualsExpr. + def exitIdEqualsExpr(self, ctx:AutolevParser.IdEqualsExprContext): + pass + + + # Enter a parse tree produced by AutolevParser#negativeOne. + def enterNegativeOne(self, ctx:AutolevParser.NegativeOneContext): + pass + + # Exit a parse tree produced by AutolevParser#negativeOne. + def exitNegativeOne(self, ctx:AutolevParser.NegativeOneContext): + pass + + + # Enter a parse tree produced by AutolevParser#function. + def enterFunction(self, ctx:AutolevParser.FunctionContext): + pass + + # Exit a parse tree produced by AutolevParser#function. + def exitFunction(self, ctx:AutolevParser.FunctionContext): + pass + + + # Enter a parse tree produced by AutolevParser#rangess. + def enterRangess(self, ctx:AutolevParser.RangessContext): + pass + + # Exit a parse tree produced by AutolevParser#rangess. + def exitRangess(self, ctx:AutolevParser.RangessContext): + pass + + + # Enter a parse tree produced by AutolevParser#colon. + def enterColon(self, ctx:AutolevParser.ColonContext): + pass + + # Exit a parse tree produced by AutolevParser#colon. + def exitColon(self, ctx:AutolevParser.ColonContext): + pass + + + # Enter a parse tree produced by AutolevParser#id. + def enterId(self, ctx:AutolevParser.IdContext): + pass + + # Exit a parse tree produced by AutolevParser#id. + def exitId(self, ctx:AutolevParser.IdContext): + pass + + + # Enter a parse tree produced by AutolevParser#exp. + def enterExp(self, ctx:AutolevParser.ExpContext): + pass + + # Exit a parse tree produced by AutolevParser#exp. + def exitExp(self, ctx:AutolevParser.ExpContext): + pass + + + # Enter a parse tree produced by AutolevParser#matrices. + def enterMatrices(self, ctx:AutolevParser.MatricesContext): + pass + + # Exit a parse tree produced by AutolevParser#matrices. + def exitMatrices(self, ctx:AutolevParser.MatricesContext): + pass + + + # Enter a parse tree produced by AutolevParser#Indexing. + def enterIndexing(self, ctx:AutolevParser.IndexingContext): + pass + + # Exit a parse tree produced by AutolevParser#Indexing. + def exitIndexing(self, ctx:AutolevParser.IndexingContext): + pass + + + +del AutolevParser diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/autolevparser.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/autolevparser.py new file mode 100644 index 0000000000000000000000000000000000000000..e63ef1c110812580d06291ee7c7ec40b6a076cea --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/_antlr/autolevparser.py @@ -0,0 +1,3063 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +from io import StringIO +import sys +if sys.version_info[1] > 5: + from typing import TextIO +else: + from typing.io import TextIO + +def serializedATN(): + return [ + 4,1,49,431,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6,7, + 6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2,13,7,13, + 2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7,19,2,20, + 7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2,26,7,26, + 2,27,7,27,1,0,4,0,58,8,0,11,0,12,0,59,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,3,1,69,8,1,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2,1,2, + 3,2,84,8,2,1,2,1,2,1,2,3,2,89,8,2,1,3,1,3,1,4,1,4,1,4,5,4,96,8,4, + 10,4,12,4,99,9,4,1,5,4,5,102,8,5,11,5,12,5,103,1,6,1,6,1,6,1,6,1, + 6,5,6,111,8,6,10,6,12,6,114,9,6,3,6,116,8,6,1,6,1,6,1,6,1,6,1,6, + 1,6,5,6,124,8,6,10,6,12,6,127,9,6,3,6,129,8,6,1,6,3,6,132,8,6,1, + 7,1,7,1,7,1,7,5,7,138,8,7,10,7,12,7,141,9,7,1,8,1,8,1,8,1,8,1,8, + 1,8,1,8,1,8,1,8,1,8,5,8,153,8,8,10,8,12,8,156,9,8,1,8,1,8,5,8,160, + 8,8,10,8,12,8,163,9,8,3,8,165,8,8,1,9,1,9,1,9,1,9,1,9,1,9,3,9,173, + 8,9,1,9,1,9,1,9,1,9,1,9,1,9,1,9,1,9,5,9,183,8,9,10,9,12,9,186,9, + 9,1,9,3,9,189,8,9,1,9,1,9,1,9,3,9,194,8,9,1,9,3,9,197,8,9,1,9,5, + 9,200,8,9,10,9,12,9,203,9,9,1,9,1,9,3,9,207,8,9,1,10,1,10,1,10,1, + 10,1,10,1,10,1,10,1,10,5,10,217,8,10,10,10,12,10,220,9,10,1,10,1, + 10,1,11,1,11,1,11,1,11,5,11,228,8,11,10,11,12,11,231,9,11,1,12,1, + 12,1,12,1,12,1,13,1,13,1,13,1,13,1,13,3,13,242,8,13,1,13,1,13,4, + 13,246,8,13,11,13,12,13,247,1,14,1,14,1,14,1,14,5,14,254,8,14,10, + 14,12,14,257,9,14,1,14,1,14,1,15,1,15,1,15,1,15,3,15,265,8,15,1, + 15,1,15,3,15,269,8,15,1,16,1,16,1,16,1,16,1,16,3,16,276,8,16,1,17, + 1,17,3,17,280,8,17,1,18,1,18,1,18,1,18,5,18,286,8,18,10,18,12,18, + 289,9,18,1,19,1,19,1,19,1,19,5,19,295,8,19,10,19,12,19,298,9,19, + 1,20,1,20,3,20,302,8,20,1,21,1,21,1,21,1,21,3,21,308,8,21,1,22,1, + 22,1,22,1,22,5,22,314,8,22,10,22,12,22,317,9,22,1,23,1,23,3,23,321, + 8,23,1,24,1,24,1,24,1,24,1,24,1,24,5,24,329,8,24,10,24,12,24,332, + 9,24,1,24,1,24,3,24,336,8,24,1,24,1,24,1,24,1,24,1,25,1,25,1,25, + 1,25,1,25,1,25,1,25,1,25,5,25,350,8,25,10,25,12,25,353,9,25,3,25, + 355,8,25,1,26,1,26,4,26,359,8,26,11,26,12,26,360,1,26,1,26,3,26, + 365,8,26,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,5,27,375,8,27,10, + 27,12,27,378,9,27,1,27,1,27,1,27,1,27,1,27,1,27,5,27,386,8,27,10, + 27,12,27,389,9,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,3, + 27,400,8,27,1,27,1,27,5,27,404,8,27,10,27,12,27,407,9,27,3,27,409, + 8,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27,1,27, + 1,27,1,27,1,27,5,27,426,8,27,10,27,12,27,429,9,27,1,27,0,1,54,28, + 0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36,38,40,42,44, + 46,48,50,52,54,0,7,1,0,3,9,1,0,27,28,1,0,17,18,2,0,10,10,19,19,1, + 0,44,45,2,0,44,46,48,48,1,0,25,26,483,0,57,1,0,0,0,2,68,1,0,0,0, + 4,88,1,0,0,0,6,90,1,0,0,0,8,92,1,0,0,0,10,101,1,0,0,0,12,131,1,0, + 0,0,14,133,1,0,0,0,16,164,1,0,0,0,18,166,1,0,0,0,20,208,1,0,0,0, + 22,223,1,0,0,0,24,232,1,0,0,0,26,236,1,0,0,0,28,249,1,0,0,0,30,268, + 1,0,0,0,32,275,1,0,0,0,34,277,1,0,0,0,36,281,1,0,0,0,38,290,1,0, + 0,0,40,299,1,0,0,0,42,303,1,0,0,0,44,309,1,0,0,0,46,318,1,0,0,0, + 48,322,1,0,0,0,50,354,1,0,0,0,52,364,1,0,0,0,54,408,1,0,0,0,56,58, + 3,2,1,0,57,56,1,0,0,0,58,59,1,0,0,0,59,57,1,0,0,0,59,60,1,0,0,0, + 60,1,1,0,0,0,61,69,3,14,7,0,62,69,3,12,6,0,63,69,3,32,16,0,64,69, + 3,22,11,0,65,69,3,26,13,0,66,69,3,4,2,0,67,69,3,34,17,0,68,61,1, + 0,0,0,68,62,1,0,0,0,68,63,1,0,0,0,68,64,1,0,0,0,68,65,1,0,0,0,68, + 66,1,0,0,0,68,67,1,0,0,0,69,3,1,0,0,0,70,71,3,52,26,0,71,72,3,6, + 3,0,72,73,3,54,27,0,73,89,1,0,0,0,74,75,5,48,0,0,75,76,5,1,0,0,76, + 77,3,8,4,0,77,78,5,2,0,0,78,79,3,6,3,0,79,80,3,54,27,0,80,89,1,0, + 0,0,81,83,5,48,0,0,82,84,3,10,5,0,83,82,1,0,0,0,83,84,1,0,0,0,84, + 85,1,0,0,0,85,86,3,6,3,0,86,87,3,54,27,0,87,89,1,0,0,0,88,70,1,0, + 0,0,88,74,1,0,0,0,88,81,1,0,0,0,89,5,1,0,0,0,90,91,7,0,0,0,91,7, + 1,0,0,0,92,97,3,54,27,0,93,94,5,10,0,0,94,96,3,54,27,0,95,93,1,0, + 0,0,96,99,1,0,0,0,97,95,1,0,0,0,97,98,1,0,0,0,98,9,1,0,0,0,99,97, + 1,0,0,0,100,102,5,11,0,0,101,100,1,0,0,0,102,103,1,0,0,0,103,101, + 1,0,0,0,103,104,1,0,0,0,104,11,1,0,0,0,105,106,5,48,0,0,106,115, + 5,12,0,0,107,112,3,54,27,0,108,109,5,10,0,0,109,111,3,54,27,0,110, + 108,1,0,0,0,111,114,1,0,0,0,112,110,1,0,0,0,112,113,1,0,0,0,113, + 116,1,0,0,0,114,112,1,0,0,0,115,107,1,0,0,0,115,116,1,0,0,0,116, + 117,1,0,0,0,117,132,5,13,0,0,118,119,7,1,0,0,119,128,5,12,0,0,120, + 125,5,48,0,0,121,122,5,10,0,0,122,124,5,48,0,0,123,121,1,0,0,0,124, + 127,1,0,0,0,125,123,1,0,0,0,125,126,1,0,0,0,126,129,1,0,0,0,127, + 125,1,0,0,0,128,120,1,0,0,0,128,129,1,0,0,0,129,130,1,0,0,0,130, + 132,5,13,0,0,131,105,1,0,0,0,131,118,1,0,0,0,132,13,1,0,0,0,133, + 134,3,16,8,0,134,139,3,18,9,0,135,136,5,10,0,0,136,138,3,18,9,0, + 137,135,1,0,0,0,138,141,1,0,0,0,139,137,1,0,0,0,139,140,1,0,0,0, + 140,15,1,0,0,0,141,139,1,0,0,0,142,165,5,34,0,0,143,165,5,35,0,0, + 144,165,5,36,0,0,145,165,5,37,0,0,146,165,5,38,0,0,147,165,5,39, + 0,0,148,165,5,40,0,0,149,165,5,41,0,0,150,154,5,42,0,0,151,153,5, + 11,0,0,152,151,1,0,0,0,153,156,1,0,0,0,154,152,1,0,0,0,154,155,1, + 0,0,0,155,165,1,0,0,0,156,154,1,0,0,0,157,161,5,43,0,0,158,160,5, + 11,0,0,159,158,1,0,0,0,160,163,1,0,0,0,161,159,1,0,0,0,161,162,1, + 0,0,0,162,165,1,0,0,0,163,161,1,0,0,0,164,142,1,0,0,0,164,143,1, + 0,0,0,164,144,1,0,0,0,164,145,1,0,0,0,164,146,1,0,0,0,164,147,1, + 0,0,0,164,148,1,0,0,0,164,149,1,0,0,0,164,150,1,0,0,0,164,157,1, + 0,0,0,165,17,1,0,0,0,166,172,5,48,0,0,167,168,5,14,0,0,168,169,5, + 44,0,0,169,170,5,10,0,0,170,171,5,44,0,0,171,173,5,15,0,0,172,167, + 1,0,0,0,172,173,1,0,0,0,173,188,1,0,0,0,174,175,5,14,0,0,175,176, + 5,44,0,0,176,177,5,16,0,0,177,184,5,44,0,0,178,179,5,10,0,0,179, + 180,5,44,0,0,180,181,5,16,0,0,181,183,5,44,0,0,182,178,1,0,0,0,183, + 186,1,0,0,0,184,182,1,0,0,0,184,185,1,0,0,0,185,187,1,0,0,0,186, + 184,1,0,0,0,187,189,5,15,0,0,188,174,1,0,0,0,188,189,1,0,0,0,189, + 193,1,0,0,0,190,191,5,14,0,0,191,192,5,44,0,0,192,194,5,15,0,0,193, + 190,1,0,0,0,193,194,1,0,0,0,194,196,1,0,0,0,195,197,7,2,0,0,196, + 195,1,0,0,0,196,197,1,0,0,0,197,201,1,0,0,0,198,200,5,11,0,0,199, + 198,1,0,0,0,200,203,1,0,0,0,201,199,1,0,0,0,201,202,1,0,0,0,202, + 206,1,0,0,0,203,201,1,0,0,0,204,205,5,3,0,0,205,207,3,54,27,0,206, + 204,1,0,0,0,206,207,1,0,0,0,207,19,1,0,0,0,208,209,5,14,0,0,209, + 210,5,44,0,0,210,211,5,16,0,0,211,218,5,44,0,0,212,213,5,10,0,0, + 213,214,5,44,0,0,214,215,5,16,0,0,215,217,5,44,0,0,216,212,1,0,0, + 0,217,220,1,0,0,0,218,216,1,0,0,0,218,219,1,0,0,0,219,221,1,0,0, + 0,220,218,1,0,0,0,221,222,5,15,0,0,222,21,1,0,0,0,223,224,5,27,0, + 0,224,229,3,24,12,0,225,226,5,10,0,0,226,228,3,24,12,0,227,225,1, + 0,0,0,228,231,1,0,0,0,229,227,1,0,0,0,229,230,1,0,0,0,230,23,1,0, + 0,0,231,229,1,0,0,0,232,233,5,48,0,0,233,234,5,3,0,0,234,235,3,54, + 27,0,235,25,1,0,0,0,236,237,5,28,0,0,237,241,5,48,0,0,238,239,5, + 12,0,0,239,240,5,48,0,0,240,242,5,13,0,0,241,238,1,0,0,0,241,242, + 1,0,0,0,242,245,1,0,0,0,243,244,5,10,0,0,244,246,3,54,27,0,245,243, + 1,0,0,0,246,247,1,0,0,0,247,245,1,0,0,0,247,248,1,0,0,0,248,27,1, + 0,0,0,249,250,5,1,0,0,250,255,3,54,27,0,251,252,7,3,0,0,252,254, + 3,54,27,0,253,251,1,0,0,0,254,257,1,0,0,0,255,253,1,0,0,0,255,256, + 1,0,0,0,256,258,1,0,0,0,257,255,1,0,0,0,258,259,5,2,0,0,259,29,1, + 0,0,0,260,261,5,48,0,0,261,262,5,48,0,0,262,264,5,3,0,0,263,265, + 7,4,0,0,264,263,1,0,0,0,264,265,1,0,0,0,265,269,1,0,0,0,266,269, + 5,45,0,0,267,269,5,44,0,0,268,260,1,0,0,0,268,266,1,0,0,0,268,267, + 1,0,0,0,269,31,1,0,0,0,270,276,3,36,18,0,271,276,3,38,19,0,272,276, + 3,44,22,0,273,276,3,48,24,0,274,276,3,50,25,0,275,270,1,0,0,0,275, + 271,1,0,0,0,275,272,1,0,0,0,275,273,1,0,0,0,275,274,1,0,0,0,276, + 33,1,0,0,0,277,279,5,48,0,0,278,280,7,5,0,0,279,278,1,0,0,0,279, + 280,1,0,0,0,280,35,1,0,0,0,281,282,5,32,0,0,282,287,5,48,0,0,283, + 284,5,10,0,0,284,286,5,48,0,0,285,283,1,0,0,0,286,289,1,0,0,0,287, + 285,1,0,0,0,287,288,1,0,0,0,288,37,1,0,0,0,289,287,1,0,0,0,290,291, + 5,29,0,0,291,296,3,42,21,0,292,293,5,10,0,0,293,295,3,42,21,0,294, + 292,1,0,0,0,295,298,1,0,0,0,296,294,1,0,0,0,296,297,1,0,0,0,297, + 39,1,0,0,0,298,296,1,0,0,0,299,301,5,48,0,0,300,302,3,10,5,0,301, + 300,1,0,0,0,301,302,1,0,0,0,302,41,1,0,0,0,303,304,3,40,20,0,304, + 305,5,3,0,0,305,307,3,54,27,0,306,308,3,54,27,0,307,306,1,0,0,0, + 307,308,1,0,0,0,308,43,1,0,0,0,309,310,5,30,0,0,310,315,3,46,23, + 0,311,312,5,10,0,0,312,314,3,46,23,0,313,311,1,0,0,0,314,317,1,0, + 0,0,315,313,1,0,0,0,315,316,1,0,0,0,316,45,1,0,0,0,317,315,1,0,0, + 0,318,320,3,54,27,0,319,321,3,54,27,0,320,319,1,0,0,0,320,321,1, + 0,0,0,321,47,1,0,0,0,322,323,5,48,0,0,323,335,3,12,6,0,324,325,5, + 1,0,0,325,330,3,30,15,0,326,327,5,10,0,0,327,329,3,30,15,0,328,326, + 1,0,0,0,329,332,1,0,0,0,330,328,1,0,0,0,330,331,1,0,0,0,331,333, + 1,0,0,0,332,330,1,0,0,0,333,334,5,2,0,0,334,336,1,0,0,0,335,324, + 1,0,0,0,335,336,1,0,0,0,336,337,1,0,0,0,337,338,5,48,0,0,338,339, + 5,20,0,0,339,340,5,48,0,0,340,49,1,0,0,0,341,342,5,31,0,0,342,343, + 5,48,0,0,343,344,5,20,0,0,344,355,5,48,0,0,345,346,5,33,0,0,346, + 351,5,48,0,0,347,348,5,10,0,0,348,350,5,48,0,0,349,347,1,0,0,0,350, + 353,1,0,0,0,351,349,1,0,0,0,351,352,1,0,0,0,352,355,1,0,0,0,353, + 351,1,0,0,0,354,341,1,0,0,0,354,345,1,0,0,0,355,51,1,0,0,0,356,358, + 5,48,0,0,357,359,5,21,0,0,358,357,1,0,0,0,359,360,1,0,0,0,360,358, + 1,0,0,0,360,361,1,0,0,0,361,365,1,0,0,0,362,365,5,22,0,0,363,365, + 5,23,0,0,364,356,1,0,0,0,364,362,1,0,0,0,364,363,1,0,0,0,365,53, + 1,0,0,0,366,367,6,27,-1,0,367,409,5,46,0,0,368,369,5,18,0,0,369, + 409,3,54,27,12,370,409,5,45,0,0,371,409,5,44,0,0,372,376,5,48,0, + 0,373,375,5,11,0,0,374,373,1,0,0,0,375,378,1,0,0,0,376,374,1,0,0, + 0,376,377,1,0,0,0,377,409,1,0,0,0,378,376,1,0,0,0,379,409,3,52,26, + 0,380,381,5,48,0,0,381,382,5,1,0,0,382,387,3,54,27,0,383,384,5,10, + 0,0,384,386,3,54,27,0,385,383,1,0,0,0,386,389,1,0,0,0,387,385,1, + 0,0,0,387,388,1,0,0,0,388,390,1,0,0,0,389,387,1,0,0,0,390,391,5, + 2,0,0,391,409,1,0,0,0,392,409,3,12,6,0,393,409,3,28,14,0,394,395, + 5,12,0,0,395,396,3,54,27,0,396,397,5,13,0,0,397,409,1,0,0,0,398, + 400,5,48,0,0,399,398,1,0,0,0,399,400,1,0,0,0,400,401,1,0,0,0,401, + 405,3,20,10,0,402,404,5,11,0,0,403,402,1,0,0,0,404,407,1,0,0,0,405, + 403,1,0,0,0,405,406,1,0,0,0,406,409,1,0,0,0,407,405,1,0,0,0,408, + 366,1,0,0,0,408,368,1,0,0,0,408,370,1,0,0,0,408,371,1,0,0,0,408, + 372,1,0,0,0,408,379,1,0,0,0,408,380,1,0,0,0,408,392,1,0,0,0,408, + 393,1,0,0,0,408,394,1,0,0,0,408,399,1,0,0,0,409,427,1,0,0,0,410, + 411,10,16,0,0,411,412,5,24,0,0,412,426,3,54,27,17,413,414,10,15, + 0,0,414,415,7,6,0,0,415,426,3,54,27,16,416,417,10,14,0,0,417,418, + 7,2,0,0,418,426,3,54,27,15,419,420,10,3,0,0,420,421,5,3,0,0,421, + 426,3,54,27,4,422,423,10,2,0,0,423,424,5,16,0,0,424,426,3,54,27, + 3,425,410,1,0,0,0,425,413,1,0,0,0,425,416,1,0,0,0,425,419,1,0,0, + 0,425,422,1,0,0,0,426,429,1,0,0,0,427,425,1,0,0,0,427,428,1,0,0, + 0,428,55,1,0,0,0,429,427,1,0,0,0,50,59,68,83,88,97,103,112,115,125, + 128,131,139,154,161,164,172,184,188,193,196,201,206,218,229,241, + 247,255,264,268,275,279,287,296,301,307,315,320,330,335,351,354, + 360,364,376,387,399,405,408,425,427 + ] + +class AutolevParser ( Parser ): + + grammarFileName = "Autolev.g4" + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + sharedContextCache = PredictionContextCache() + + literalNames = [ "", "'['", "']'", "'='", "'+='", "'-='", "':='", + "'*='", "'/='", "'^='", "','", "'''", "'('", "')'", + "'{'", "'}'", "':'", "'+'", "'-'", "';'", "'.'", "'>'", + "'0>'", "'1>>'", "'^'", "'*'", "'/'" ] + + symbolicNames = [ "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "", + "", "", "", "Mass", "Inertia", + "Input", "Output", "Save", "UnitSystem", "Encode", + "Newtonian", "Frames", "Bodies", "Particles", "Points", + "Constants", "Specifieds", "Imaginary", "Variables", + "MotionVariables", "INT", "FLOAT", "EXP", "LINE_COMMENT", + "ID", "WS" ] + + RULE_prog = 0 + RULE_stat = 1 + RULE_assignment = 2 + RULE_equals = 3 + RULE_index = 4 + RULE_diff = 5 + RULE_functionCall = 6 + RULE_varDecl = 7 + RULE_varType = 8 + RULE_varDecl2 = 9 + RULE_ranges = 10 + RULE_massDecl = 11 + RULE_massDecl2 = 12 + RULE_inertiaDecl = 13 + RULE_matrix = 14 + RULE_matrixInOutput = 15 + RULE_codeCommands = 16 + RULE_settings = 17 + RULE_units = 18 + RULE_inputs = 19 + RULE_id_diff = 20 + RULE_inputs2 = 21 + RULE_outputs = 22 + RULE_outputs2 = 23 + RULE_codegen = 24 + RULE_commands = 25 + RULE_vec = 26 + RULE_expr = 27 + + ruleNames = [ "prog", "stat", "assignment", "equals", "index", "diff", + "functionCall", "varDecl", "varType", "varDecl2", "ranges", + "massDecl", "massDecl2", "inertiaDecl", "matrix", "matrixInOutput", + "codeCommands", "settings", "units", "inputs", "id_diff", + "inputs2", "outputs", "outputs2", "codegen", "commands", + "vec", "expr" ] + + EOF = Token.EOF + T__0=1 + T__1=2 + T__2=3 + T__3=4 + T__4=5 + T__5=6 + T__6=7 + T__7=8 + T__8=9 + T__9=10 + T__10=11 + T__11=12 + T__12=13 + T__13=14 + T__14=15 + T__15=16 + T__16=17 + T__17=18 + T__18=19 + T__19=20 + T__20=21 + T__21=22 + T__22=23 + T__23=24 + T__24=25 + T__25=26 + Mass=27 + Inertia=28 + Input=29 + Output=30 + Save=31 + UnitSystem=32 + Encode=33 + Newtonian=34 + Frames=35 + Bodies=36 + Particles=37 + Points=38 + Constants=39 + Specifieds=40 + Imaginary=41 + Variables=42 + MotionVariables=43 + INT=44 + FLOAT=45 + EXP=46 + LINE_COMMENT=47 + ID=48 + WS=49 + + def __init__(self, input:TokenStream, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.11.1") + self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) + self._predicates = None + + + + + class ProgContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def stat(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.StatContext) + else: + return self.getTypedRuleContext(AutolevParser.StatContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_prog + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterProg" ): + listener.enterProg(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitProg" ): + listener.exitProg(self) + + + + + def prog(self): + + localctx = AutolevParser.ProgContext(self, self._ctx, self.state) + self.enterRule(localctx, 0, self.RULE_prog) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 57 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 56 + self.stat() + self.state = 59 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (((_la) & ~0x3f) == 0 and ((1 << _la) & 299067041120256) != 0): + break + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class StatContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def varDecl(self): + return self.getTypedRuleContext(AutolevParser.VarDeclContext,0) + + + def functionCall(self): + return self.getTypedRuleContext(AutolevParser.FunctionCallContext,0) + + + def codeCommands(self): + return self.getTypedRuleContext(AutolevParser.CodeCommandsContext,0) + + + def massDecl(self): + return self.getTypedRuleContext(AutolevParser.MassDeclContext,0) + + + def inertiaDecl(self): + return self.getTypedRuleContext(AutolevParser.InertiaDeclContext,0) + + + def assignment(self): + return self.getTypedRuleContext(AutolevParser.AssignmentContext,0) + + + def settings(self): + return self.getTypedRuleContext(AutolevParser.SettingsContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_stat + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterStat" ): + listener.enterStat(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitStat" ): + listener.exitStat(self) + + + + + def stat(self): + + localctx = AutolevParser.StatContext(self, self._ctx, self.state) + self.enterRule(localctx, 2, self.RULE_stat) + try: + self.state = 68 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,1,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 61 + self.varDecl() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 62 + self.functionCall() + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 63 + self.codeCommands() + pass + + elif la_ == 4: + self.enterOuterAlt(localctx, 4) + self.state = 64 + self.massDecl() + pass + + elif la_ == 5: + self.enterOuterAlt(localctx, 5) + self.state = 65 + self.inertiaDecl() + pass + + elif la_ == 6: + self.enterOuterAlt(localctx, 6) + self.state = 66 + self.assignment() + pass + + elif la_ == 7: + self.enterOuterAlt(localctx, 7) + self.state = 67 + self.settings() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AssignmentContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return AutolevParser.RULE_assignment + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + + class VecAssignContext(AssignmentContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.AssignmentContext + super().__init__(parser) + self.copyFrom(ctx) + + def vec(self): + return self.getTypedRuleContext(AutolevParser.VecContext,0) + + def equals(self): + return self.getTypedRuleContext(AutolevParser.EqualsContext,0) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVecAssign" ): + listener.enterVecAssign(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVecAssign" ): + listener.exitVecAssign(self) + + + class RegularAssignContext(AssignmentContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.AssignmentContext + super().__init__(parser) + self.copyFrom(ctx) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + def equals(self): + return self.getTypedRuleContext(AutolevParser.EqualsContext,0) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + def diff(self): + return self.getTypedRuleContext(AutolevParser.DiffContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterRegularAssign" ): + listener.enterRegularAssign(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitRegularAssign" ): + listener.exitRegularAssign(self) + + + class IndexAssignContext(AssignmentContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.AssignmentContext + super().__init__(parser) + self.copyFrom(ctx) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + def index(self): + return self.getTypedRuleContext(AutolevParser.IndexContext,0) + + def equals(self): + return self.getTypedRuleContext(AutolevParser.EqualsContext,0) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterIndexAssign" ): + listener.enterIndexAssign(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitIndexAssign" ): + listener.exitIndexAssign(self) + + + + def assignment(self): + + localctx = AutolevParser.AssignmentContext(self, self._ctx, self.state) + self.enterRule(localctx, 4, self.RULE_assignment) + self._la = 0 # Token type + try: + self.state = 88 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,3,self._ctx) + if la_ == 1: + localctx = AutolevParser.VecAssignContext(self, localctx) + self.enterOuterAlt(localctx, 1) + self.state = 70 + self.vec() + self.state = 71 + self.equals() + self.state = 72 + self.expr(0) + pass + + elif la_ == 2: + localctx = AutolevParser.IndexAssignContext(self, localctx) + self.enterOuterAlt(localctx, 2) + self.state = 74 + self.match(AutolevParser.ID) + self.state = 75 + self.match(AutolevParser.T__0) + self.state = 76 + self.index() + self.state = 77 + self.match(AutolevParser.T__1) + self.state = 78 + self.equals() + self.state = 79 + self.expr(0) + pass + + elif la_ == 3: + localctx = AutolevParser.RegularAssignContext(self, localctx) + self.enterOuterAlt(localctx, 3) + self.state = 81 + self.match(AutolevParser.ID) + self.state = 83 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==11: + self.state = 82 + self.diff() + + + self.state = 85 + self.equals() + self.state = 86 + self.expr(0) + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class EqualsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return AutolevParser.RULE_equals + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterEquals" ): + listener.enterEquals(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitEquals" ): + listener.exitEquals(self) + + + + + def equals(self): + + localctx = AutolevParser.EqualsContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_equals) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 90 + _la = self._input.LA(1) + if not(((_la) & ~0x3f) == 0 and ((1 << _la) & 1016) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class IndexContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_index + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterIndex" ): + listener.enterIndex(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitIndex" ): + listener.exitIndex(self) + + + + + def index(self): + + localctx = AutolevParser.IndexContext(self, self._ctx, self.state) + self.enterRule(localctx, 8, self.RULE_index) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 92 + self.expr(0) + self.state = 97 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 93 + self.match(AutolevParser.T__9) + self.state = 94 + self.expr(0) + self.state = 99 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class DiffContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return AutolevParser.RULE_diff + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterDiff" ): + listener.enterDiff(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitDiff" ): + listener.exitDiff(self) + + + + + def diff(self): + + localctx = AutolevParser.DiffContext(self, self._ctx, self.state) + self.enterRule(localctx, 10, self.RULE_diff) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 101 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 100 + self.match(AutolevParser.T__10) + self.state = 103 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==11): + break + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class FunctionCallContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def Mass(self): + return self.getToken(AutolevParser.Mass, 0) + + def Inertia(self): + return self.getToken(AutolevParser.Inertia, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_functionCall + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterFunctionCall" ): + listener.enterFunctionCall(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitFunctionCall" ): + listener.exitFunctionCall(self) + + + + + def functionCall(self): + + localctx = AutolevParser.FunctionCallContext(self, self._ctx, self.state) + self.enterRule(localctx, 12, self.RULE_functionCall) + self._la = 0 # Token type + try: + self.state = 131 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [48]: + self.enterOuterAlt(localctx, 1) + self.state = 105 + self.match(AutolevParser.ID) + self.state = 106 + self.match(AutolevParser.T__11) + self.state = 115 + self._errHandler.sync(self) + _la = self._input.LA(1) + if ((_la) & ~0x3f) == 0 and ((1 << _la) & 404620694540290) != 0: + self.state = 107 + self.expr(0) + self.state = 112 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 108 + self.match(AutolevParser.T__9) + self.state = 109 + self.expr(0) + self.state = 114 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 117 + self.match(AutolevParser.T__12) + pass + elif token in [27, 28]: + self.enterOuterAlt(localctx, 2) + self.state = 118 + _la = self._input.LA(1) + if not(_la==27 or _la==28): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 119 + self.match(AutolevParser.T__11) + self.state = 128 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==48: + self.state = 120 + self.match(AutolevParser.ID) + self.state = 125 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 121 + self.match(AutolevParser.T__9) + self.state = 122 + self.match(AutolevParser.ID) + self.state = 127 + self._errHandler.sync(self) + _la = self._input.LA(1) + + + + self.state = 130 + self.match(AutolevParser.T__12) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarDeclContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def varType(self): + return self.getTypedRuleContext(AutolevParser.VarTypeContext,0) + + + def varDecl2(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.VarDecl2Context) + else: + return self.getTypedRuleContext(AutolevParser.VarDecl2Context,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_varDecl + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVarDecl" ): + listener.enterVarDecl(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVarDecl" ): + listener.exitVarDecl(self) + + + + + def varDecl(self): + + localctx = AutolevParser.VarDeclContext(self, self._ctx, self.state) + self.enterRule(localctx, 14, self.RULE_varDecl) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 133 + self.varType() + self.state = 134 + self.varDecl2() + self.state = 139 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 135 + self.match(AutolevParser.T__9) + self.state = 136 + self.varDecl2() + self.state = 141 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarTypeContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Newtonian(self): + return self.getToken(AutolevParser.Newtonian, 0) + + def Frames(self): + return self.getToken(AutolevParser.Frames, 0) + + def Bodies(self): + return self.getToken(AutolevParser.Bodies, 0) + + def Particles(self): + return self.getToken(AutolevParser.Particles, 0) + + def Points(self): + return self.getToken(AutolevParser.Points, 0) + + def Constants(self): + return self.getToken(AutolevParser.Constants, 0) + + def Specifieds(self): + return self.getToken(AutolevParser.Specifieds, 0) + + def Imaginary(self): + return self.getToken(AutolevParser.Imaginary, 0) + + def Variables(self): + return self.getToken(AutolevParser.Variables, 0) + + def MotionVariables(self): + return self.getToken(AutolevParser.MotionVariables, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_varType + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVarType" ): + listener.enterVarType(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVarType" ): + listener.exitVarType(self) + + + + + def varType(self): + + localctx = AutolevParser.VarTypeContext(self, self._ctx, self.state) + self.enterRule(localctx, 16, self.RULE_varType) + self._la = 0 # Token type + try: + self.state = 164 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [34]: + self.enterOuterAlt(localctx, 1) + self.state = 142 + self.match(AutolevParser.Newtonian) + pass + elif token in [35]: + self.enterOuterAlt(localctx, 2) + self.state = 143 + self.match(AutolevParser.Frames) + pass + elif token in [36]: + self.enterOuterAlt(localctx, 3) + self.state = 144 + self.match(AutolevParser.Bodies) + pass + elif token in [37]: + self.enterOuterAlt(localctx, 4) + self.state = 145 + self.match(AutolevParser.Particles) + pass + elif token in [38]: + self.enterOuterAlt(localctx, 5) + self.state = 146 + self.match(AutolevParser.Points) + pass + elif token in [39]: + self.enterOuterAlt(localctx, 6) + self.state = 147 + self.match(AutolevParser.Constants) + pass + elif token in [40]: + self.enterOuterAlt(localctx, 7) + self.state = 148 + self.match(AutolevParser.Specifieds) + pass + elif token in [41]: + self.enterOuterAlt(localctx, 8) + self.state = 149 + self.match(AutolevParser.Imaginary) + pass + elif token in [42]: + self.enterOuterAlt(localctx, 9) + self.state = 150 + self.match(AutolevParser.Variables) + self.state = 154 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==11: + self.state = 151 + self.match(AutolevParser.T__10) + self.state = 156 + self._errHandler.sync(self) + _la = self._input.LA(1) + + pass + elif token in [43]: + self.enterOuterAlt(localctx, 10) + self.state = 157 + self.match(AutolevParser.MotionVariables) + self.state = 161 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==11: + self.state = 158 + self.match(AutolevParser.T__10) + self.state = 163 + self._errHandler.sync(self) + _la = self._input.LA(1) + + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VarDecl2Context(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def INT(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.INT) + else: + return self.getToken(AutolevParser.INT, i) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_varDecl2 + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVarDecl2" ): + listener.enterVarDecl2(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVarDecl2" ): + listener.exitVarDecl2(self) + + + + + def varDecl2(self): + + localctx = AutolevParser.VarDecl2Context(self, self._ctx, self.state) + self.enterRule(localctx, 18, self.RULE_varDecl2) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 166 + self.match(AutolevParser.ID) + self.state = 172 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,15,self._ctx) + if la_ == 1: + self.state = 167 + self.match(AutolevParser.T__13) + self.state = 168 + self.match(AutolevParser.INT) + self.state = 169 + self.match(AutolevParser.T__9) + self.state = 170 + self.match(AutolevParser.INT) + self.state = 171 + self.match(AutolevParser.T__14) + + + self.state = 188 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,17,self._ctx) + if la_ == 1: + self.state = 174 + self.match(AutolevParser.T__13) + self.state = 175 + self.match(AutolevParser.INT) + self.state = 176 + self.match(AutolevParser.T__15) + self.state = 177 + self.match(AutolevParser.INT) + self.state = 184 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 178 + self.match(AutolevParser.T__9) + self.state = 179 + self.match(AutolevParser.INT) + self.state = 180 + self.match(AutolevParser.T__15) + self.state = 181 + self.match(AutolevParser.INT) + self.state = 186 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 187 + self.match(AutolevParser.T__14) + + + self.state = 193 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==14: + self.state = 190 + self.match(AutolevParser.T__13) + self.state = 191 + self.match(AutolevParser.INT) + self.state = 192 + self.match(AutolevParser.T__14) + + + self.state = 196 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==17 or _la==18: + self.state = 195 + _la = self._input.LA(1) + if not(_la==17 or _la==18): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + + + self.state = 201 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==11: + self.state = 198 + self.match(AutolevParser.T__10) + self.state = 203 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 206 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==3: + self.state = 204 + self.match(AutolevParser.T__2) + self.state = 205 + self.expr(0) + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class RangesContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def INT(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.INT) + else: + return self.getToken(AutolevParser.INT, i) + + def getRuleIndex(self): + return AutolevParser.RULE_ranges + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterRanges" ): + listener.enterRanges(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitRanges" ): + listener.exitRanges(self) + + + + + def ranges(self): + + localctx = AutolevParser.RangesContext(self, self._ctx, self.state) + self.enterRule(localctx, 20, self.RULE_ranges) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 208 + self.match(AutolevParser.T__13) + self.state = 209 + self.match(AutolevParser.INT) + self.state = 210 + self.match(AutolevParser.T__15) + self.state = 211 + self.match(AutolevParser.INT) + self.state = 218 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 212 + self.match(AutolevParser.T__9) + self.state = 213 + self.match(AutolevParser.INT) + self.state = 214 + self.match(AutolevParser.T__15) + self.state = 215 + self.match(AutolevParser.INT) + self.state = 220 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 221 + self.match(AutolevParser.T__14) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MassDeclContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Mass(self): + return self.getToken(AutolevParser.Mass, 0) + + def massDecl2(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.MassDecl2Context) + else: + return self.getTypedRuleContext(AutolevParser.MassDecl2Context,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_massDecl + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMassDecl" ): + listener.enterMassDecl(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMassDecl" ): + listener.exitMassDecl(self) + + + + + def massDecl(self): + + localctx = AutolevParser.MassDeclContext(self, self._ctx, self.state) + self.enterRule(localctx, 22, self.RULE_massDecl) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 223 + self.match(AutolevParser.Mass) + self.state = 224 + self.massDecl2() + self.state = 229 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 225 + self.match(AutolevParser.T__9) + self.state = 226 + self.massDecl2() + self.state = 231 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MassDecl2Context(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_massDecl2 + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMassDecl2" ): + listener.enterMassDecl2(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMassDecl2" ): + listener.exitMassDecl2(self) + + + + + def massDecl2(self): + + localctx = AutolevParser.MassDecl2Context(self, self._ctx, self.state) + self.enterRule(localctx, 24, self.RULE_massDecl2) + try: + self.enterOuterAlt(localctx, 1) + self.state = 232 + self.match(AutolevParser.ID) + self.state = 233 + self.match(AutolevParser.T__2) + self.state = 234 + self.expr(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class InertiaDeclContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Inertia(self): + return self.getToken(AutolevParser.Inertia, 0) + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_inertiaDecl + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterInertiaDecl" ): + listener.enterInertiaDecl(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitInertiaDecl" ): + listener.exitInertiaDecl(self) + + + + + def inertiaDecl(self): + + localctx = AutolevParser.InertiaDeclContext(self, self._ctx, self.state) + self.enterRule(localctx, 26, self.RULE_inertiaDecl) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 236 + self.match(AutolevParser.Inertia) + self.state = 237 + self.match(AutolevParser.ID) + self.state = 241 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==12: + self.state = 238 + self.match(AutolevParser.T__11) + self.state = 239 + self.match(AutolevParser.ID) + self.state = 240 + self.match(AutolevParser.T__12) + + + self.state = 245 + self._errHandler.sync(self) + _la = self._input.LA(1) + while True: + self.state = 243 + self.match(AutolevParser.T__9) + self.state = 244 + self.expr(0) + self.state = 247 + self._errHandler.sync(self) + _la = self._input.LA(1) + if not (_la==10): + break + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MatrixContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_matrix + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMatrix" ): + listener.enterMatrix(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMatrix" ): + listener.exitMatrix(self) + + + + + def matrix(self): + + localctx = AutolevParser.MatrixContext(self, self._ctx, self.state) + self.enterRule(localctx, 28, self.RULE_matrix) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 249 + self.match(AutolevParser.T__0) + self.state = 250 + self.expr(0) + self.state = 255 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10 or _la==19: + self.state = 251 + _la = self._input.LA(1) + if not(_la==10 or _la==19): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 252 + self.expr(0) + self.state = 257 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 258 + self.match(AutolevParser.T__1) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MatrixInOutputContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def FLOAT(self): + return self.getToken(AutolevParser.FLOAT, 0) + + def INT(self): + return self.getToken(AutolevParser.INT, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_matrixInOutput + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMatrixInOutput" ): + listener.enterMatrixInOutput(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMatrixInOutput" ): + listener.exitMatrixInOutput(self) + + + + + def matrixInOutput(self): + + localctx = AutolevParser.MatrixInOutputContext(self, self._ctx, self.state) + self.enterRule(localctx, 30, self.RULE_matrixInOutput) + self._la = 0 # Token type + try: + self.state = 268 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [48]: + self.enterOuterAlt(localctx, 1) + self.state = 260 + self.match(AutolevParser.ID) + + self.state = 261 + self.match(AutolevParser.ID) + self.state = 262 + self.match(AutolevParser.T__2) + self.state = 264 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==44 or _la==45: + self.state = 263 + _la = self._input.LA(1) + if not(_la==44 or _la==45): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + + + pass + elif token in [45]: + self.enterOuterAlt(localctx, 2) + self.state = 266 + self.match(AutolevParser.FLOAT) + pass + elif token in [44]: + self.enterOuterAlt(localctx, 3) + self.state = 267 + self.match(AutolevParser.INT) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class CodeCommandsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def units(self): + return self.getTypedRuleContext(AutolevParser.UnitsContext,0) + + + def inputs(self): + return self.getTypedRuleContext(AutolevParser.InputsContext,0) + + + def outputs(self): + return self.getTypedRuleContext(AutolevParser.OutputsContext,0) + + + def codegen(self): + return self.getTypedRuleContext(AutolevParser.CodegenContext,0) + + + def commands(self): + return self.getTypedRuleContext(AutolevParser.CommandsContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_codeCommands + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterCodeCommands" ): + listener.enterCodeCommands(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitCodeCommands" ): + listener.exitCodeCommands(self) + + + + + def codeCommands(self): + + localctx = AutolevParser.CodeCommandsContext(self, self._ctx, self.state) + self.enterRule(localctx, 32, self.RULE_codeCommands) + try: + self.state = 275 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [32]: + self.enterOuterAlt(localctx, 1) + self.state = 270 + self.units() + pass + elif token in [29]: + self.enterOuterAlt(localctx, 2) + self.state = 271 + self.inputs() + pass + elif token in [30]: + self.enterOuterAlt(localctx, 3) + self.state = 272 + self.outputs() + pass + elif token in [48]: + self.enterOuterAlt(localctx, 4) + self.state = 273 + self.codegen() + pass + elif token in [31, 33]: + self.enterOuterAlt(localctx, 5) + self.state = 274 + self.commands() + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SettingsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def EXP(self): + return self.getToken(AutolevParser.EXP, 0) + + def FLOAT(self): + return self.getToken(AutolevParser.FLOAT, 0) + + def INT(self): + return self.getToken(AutolevParser.INT, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_settings + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterSettings" ): + listener.enterSettings(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitSettings" ): + listener.exitSettings(self) + + + + + def settings(self): + + localctx = AutolevParser.SettingsContext(self, self._ctx, self.state) + self.enterRule(localctx, 34, self.RULE_settings) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 277 + self.match(AutolevParser.ID) + self.state = 279 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,30,self._ctx) + if la_ == 1: + self.state = 278 + _la = self._input.LA(1) + if not(((_la) & ~0x3f) == 0 and ((1 << _la) & 404620279021568) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class UnitsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UnitSystem(self): + return self.getToken(AutolevParser.UnitSystem, 0) + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def getRuleIndex(self): + return AutolevParser.RULE_units + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterUnits" ): + listener.enterUnits(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitUnits" ): + listener.exitUnits(self) + + + + + def units(self): + + localctx = AutolevParser.UnitsContext(self, self._ctx, self.state) + self.enterRule(localctx, 36, self.RULE_units) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 281 + self.match(AutolevParser.UnitSystem) + self.state = 282 + self.match(AutolevParser.ID) + self.state = 287 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 283 + self.match(AutolevParser.T__9) + self.state = 284 + self.match(AutolevParser.ID) + self.state = 289 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class InputsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Input(self): + return self.getToken(AutolevParser.Input, 0) + + def inputs2(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.Inputs2Context) + else: + return self.getTypedRuleContext(AutolevParser.Inputs2Context,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_inputs + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterInputs" ): + listener.enterInputs(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitInputs" ): + listener.exitInputs(self) + + + + + def inputs(self): + + localctx = AutolevParser.InputsContext(self, self._ctx, self.state) + self.enterRule(localctx, 38, self.RULE_inputs) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 290 + self.match(AutolevParser.Input) + self.state = 291 + self.inputs2() + self.state = 296 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 292 + self.match(AutolevParser.T__9) + self.state = 293 + self.inputs2() + self.state = 298 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Id_diffContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def diff(self): + return self.getTypedRuleContext(AutolevParser.DiffContext,0) + + + def getRuleIndex(self): + return AutolevParser.RULE_id_diff + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterId_diff" ): + listener.enterId_diff(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitId_diff" ): + listener.exitId_diff(self) + + + + + def id_diff(self): + + localctx = AutolevParser.Id_diffContext(self, self._ctx, self.state) + self.enterRule(localctx, 40, self.RULE_id_diff) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 299 + self.match(AutolevParser.ID) + self.state = 301 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==11: + self.state = 300 + self.diff() + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Inputs2Context(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def id_diff(self): + return self.getTypedRuleContext(AutolevParser.Id_diffContext,0) + + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_inputs2 + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterInputs2" ): + listener.enterInputs2(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitInputs2" ): + listener.exitInputs2(self) + + + + + def inputs2(self): + + localctx = AutolevParser.Inputs2Context(self, self._ctx, self.state) + self.enterRule(localctx, 42, self.RULE_inputs2) + try: + self.enterOuterAlt(localctx, 1) + self.state = 303 + self.id_diff() + self.state = 304 + self.match(AutolevParser.T__2) + self.state = 305 + self.expr(0) + self.state = 307 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,34,self._ctx) + if la_ == 1: + self.state = 306 + self.expr(0) + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class OutputsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Output(self): + return self.getToken(AutolevParser.Output, 0) + + def outputs2(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.Outputs2Context) + else: + return self.getTypedRuleContext(AutolevParser.Outputs2Context,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_outputs + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterOutputs" ): + listener.enterOutputs(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitOutputs" ): + listener.exitOutputs(self) + + + + + def outputs(self): + + localctx = AutolevParser.OutputsContext(self, self._ctx, self.state) + self.enterRule(localctx, 44, self.RULE_outputs) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 309 + self.match(AutolevParser.Output) + self.state = 310 + self.outputs2() + self.state = 315 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 311 + self.match(AutolevParser.T__9) + self.state = 312 + self.outputs2() + self.state = 317 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Outputs2Context(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_outputs2 + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterOutputs2" ): + listener.enterOutputs2(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitOutputs2" ): + listener.exitOutputs2(self) + + + + + def outputs2(self): + + localctx = AutolevParser.Outputs2Context(self, self._ctx, self.state) + self.enterRule(localctx, 46, self.RULE_outputs2) + try: + self.enterOuterAlt(localctx, 1) + self.state = 318 + self.expr(0) + self.state = 320 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,36,self._ctx) + if la_ == 1: + self.state = 319 + self.expr(0) + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class CodegenContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def functionCall(self): + return self.getTypedRuleContext(AutolevParser.FunctionCallContext,0) + + + def matrixInOutput(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.MatrixInOutputContext) + else: + return self.getTypedRuleContext(AutolevParser.MatrixInOutputContext,i) + + + def getRuleIndex(self): + return AutolevParser.RULE_codegen + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterCodegen" ): + listener.enterCodegen(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitCodegen" ): + listener.exitCodegen(self) + + + + + def codegen(self): + + localctx = AutolevParser.CodegenContext(self, self._ctx, self.state) + self.enterRule(localctx, 48, self.RULE_codegen) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 322 + self.match(AutolevParser.ID) + self.state = 323 + self.functionCall() + self.state = 335 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==1: + self.state = 324 + self.match(AutolevParser.T__0) + self.state = 325 + self.matrixInOutput() + self.state = 330 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 326 + self.match(AutolevParser.T__9) + self.state = 327 + self.matrixInOutput() + self.state = 332 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 333 + self.match(AutolevParser.T__1) + + + self.state = 337 + self.match(AutolevParser.ID) + self.state = 338 + self.match(AutolevParser.T__19) + self.state = 339 + self.match(AutolevParser.ID) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class CommandsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def Save(self): + return self.getToken(AutolevParser.Save, 0) + + def ID(self, i:int=None): + if i is None: + return self.getTokens(AutolevParser.ID) + else: + return self.getToken(AutolevParser.ID, i) + + def Encode(self): + return self.getToken(AutolevParser.Encode, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_commands + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterCommands" ): + listener.enterCommands(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitCommands" ): + listener.exitCommands(self) + + + + + def commands(self): + + localctx = AutolevParser.CommandsContext(self, self._ctx, self.state) + self.enterRule(localctx, 50, self.RULE_commands) + self._la = 0 # Token type + try: + self.state = 354 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [31]: + self.enterOuterAlt(localctx, 1) + self.state = 341 + self.match(AutolevParser.Save) + self.state = 342 + self.match(AutolevParser.ID) + self.state = 343 + self.match(AutolevParser.T__19) + self.state = 344 + self.match(AutolevParser.ID) + pass + elif token in [33]: + self.enterOuterAlt(localctx, 2) + self.state = 345 + self.match(AutolevParser.Encode) + self.state = 346 + self.match(AutolevParser.ID) + self.state = 351 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 347 + self.match(AutolevParser.T__9) + self.state = 348 + self.match(AutolevParser.ID) + self.state = 353 + self._errHandler.sync(self) + _la = self._input.LA(1) + + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class VecContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def getRuleIndex(self): + return AutolevParser.RULE_vec + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVec" ): + listener.enterVec(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVec" ): + listener.exitVec(self) + + + + + def vec(self): + + localctx = AutolevParser.VecContext(self, self._ctx, self.state) + self.enterRule(localctx, 52, self.RULE_vec) + try: + self.state = 364 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [48]: + self.enterOuterAlt(localctx, 1) + self.state = 356 + self.match(AutolevParser.ID) + self.state = 358 + self._errHandler.sync(self) + _alt = 1 + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt == 1: + self.state = 357 + self.match(AutolevParser.T__20) + + else: + raise NoViableAltException(self) + self.state = 360 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,41,self._ctx) + + pass + elif token in [22]: + self.enterOuterAlt(localctx, 2) + self.state = 362 + self.match(AutolevParser.T__21) + pass + elif token in [23]: + self.enterOuterAlt(localctx, 3) + self.state = 363 + self.match(AutolevParser.T__22) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ExprContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + + def getRuleIndex(self): + return AutolevParser.RULE_expr + + + def copyFrom(self, ctx:ParserRuleContext): + super().copyFrom(ctx) + + + class ParensContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterParens" ): + listener.enterParens(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitParens" ): + listener.exitParens(self) + + + class VectorOrDyadicContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def vec(self): + return self.getTypedRuleContext(AutolevParser.VecContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterVectorOrDyadic" ): + listener.enterVectorOrDyadic(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitVectorOrDyadic" ): + listener.exitVectorOrDyadic(self) + + + class ExponentContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterExponent" ): + listener.enterExponent(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitExponent" ): + listener.exitExponent(self) + + + class MulDivContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMulDiv" ): + listener.enterMulDiv(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMulDiv" ): + listener.exitMulDiv(self) + + + class AddSubContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterAddSub" ): + listener.enterAddSub(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitAddSub" ): + listener.exitAddSub(self) + + + class FloatContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def FLOAT(self): + return self.getToken(AutolevParser.FLOAT, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterFloat" ): + listener.enterFloat(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitFloat" ): + listener.exitFloat(self) + + + class IntContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def INT(self): + return self.getToken(AutolevParser.INT, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterInt" ): + listener.enterInt(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitInt" ): + listener.exitInt(self) + + + class IdEqualsExprContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterIdEqualsExpr" ): + listener.enterIdEqualsExpr(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitIdEqualsExpr" ): + listener.exitIdEqualsExpr(self) + + + class NegativeOneContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self): + return self.getTypedRuleContext(AutolevParser.ExprContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterNegativeOne" ): + listener.enterNegativeOne(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitNegativeOne" ): + listener.exitNegativeOne(self) + + + class FunctionContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def functionCall(self): + return self.getTypedRuleContext(AutolevParser.FunctionCallContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterFunction" ): + listener.enterFunction(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitFunction" ): + listener.exitFunction(self) + + + class RangessContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def ranges(self): + return self.getTypedRuleContext(AutolevParser.RangesContext,0) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterRangess" ): + listener.enterRangess(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitRangess" ): + listener.exitRangess(self) + + + class ColonContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterColon" ): + listener.enterColon(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitColon" ): + listener.exitColon(self) + + + class IdContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterId" ): + listener.enterId(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitId" ): + listener.exitId(self) + + + class ExpContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def EXP(self): + return self.getToken(AutolevParser.EXP, 0) + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterExp" ): + listener.enterExp(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitExp" ): + listener.exitExp(self) + + + class MatricesContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def matrix(self): + return self.getTypedRuleContext(AutolevParser.MatrixContext,0) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterMatrices" ): + listener.enterMatrices(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitMatrices" ): + listener.exitMatrices(self) + + + class IndexingContext(ExprContext): + + def __init__(self, parser, ctx:ParserRuleContext): # actually a AutolevParser.ExprContext + super().__init__(parser) + self.copyFrom(ctx) + + def ID(self): + return self.getToken(AutolevParser.ID, 0) + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(AutolevParser.ExprContext) + else: + return self.getTypedRuleContext(AutolevParser.ExprContext,i) + + + def enterRule(self, listener:ParseTreeListener): + if hasattr( listener, "enterIndexing" ): + listener.enterIndexing(self) + + def exitRule(self, listener:ParseTreeListener): + if hasattr( listener, "exitIndexing" ): + listener.exitIndexing(self) + + + + def expr(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = AutolevParser.ExprContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 54 + self.enterRecursionRule(localctx, 54, self.RULE_expr, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 408 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,47,self._ctx) + if la_ == 1: + localctx = AutolevParser.ExpContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + + self.state = 367 + self.match(AutolevParser.EXP) + pass + + elif la_ == 2: + localctx = AutolevParser.NegativeOneContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 368 + self.match(AutolevParser.T__17) + self.state = 369 + self.expr(12) + pass + + elif la_ == 3: + localctx = AutolevParser.FloatContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 370 + self.match(AutolevParser.FLOAT) + pass + + elif la_ == 4: + localctx = AutolevParser.IntContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 371 + self.match(AutolevParser.INT) + pass + + elif la_ == 5: + localctx = AutolevParser.IdContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 372 + self.match(AutolevParser.ID) + self.state = 376 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,43,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 373 + self.match(AutolevParser.T__10) + self.state = 378 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,43,self._ctx) + + pass + + elif la_ == 6: + localctx = AutolevParser.VectorOrDyadicContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 379 + self.vec() + pass + + elif la_ == 7: + localctx = AutolevParser.IndexingContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 380 + self.match(AutolevParser.ID) + self.state = 381 + self.match(AutolevParser.T__0) + self.state = 382 + self.expr(0) + self.state = 387 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==10: + self.state = 383 + self.match(AutolevParser.T__9) + self.state = 384 + self.expr(0) + self.state = 389 + self._errHandler.sync(self) + _la = self._input.LA(1) + + self.state = 390 + self.match(AutolevParser.T__1) + pass + + elif la_ == 8: + localctx = AutolevParser.FunctionContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 392 + self.functionCall() + pass + + elif la_ == 9: + localctx = AutolevParser.MatricesContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 393 + self.matrix() + pass + + elif la_ == 10: + localctx = AutolevParser.ParensContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 394 + self.match(AutolevParser.T__11) + self.state = 395 + self.expr(0) + self.state = 396 + self.match(AutolevParser.T__12) + pass + + elif la_ == 11: + localctx = AutolevParser.RangessContext(self, localctx) + self._ctx = localctx + _prevctx = localctx + self.state = 399 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==48: + self.state = 398 + self.match(AutolevParser.ID) + + + self.state = 401 + self.ranges() + self.state = 405 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,46,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 402 + self.match(AutolevParser.T__10) + self.state = 407 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,46,self._ctx) + + pass + + + self._ctx.stop = self._input.LT(-1) + self.state = 427 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,49,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + self.state = 425 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,48,self._ctx) + if la_ == 1: + localctx = AutolevParser.ExponentContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 410 + if not self.precpred(self._ctx, 16): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 16)") + self.state = 411 + self.match(AutolevParser.T__23) + self.state = 412 + self.expr(17) + pass + + elif la_ == 2: + localctx = AutolevParser.MulDivContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 413 + if not self.precpred(self._ctx, 15): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 15)") + self.state = 414 + _la = self._input.LA(1) + if not(_la==25 or _la==26): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 415 + self.expr(16) + pass + + elif la_ == 3: + localctx = AutolevParser.AddSubContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 416 + if not self.precpred(self._ctx, 14): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 14)") + self.state = 417 + _la = self._input.LA(1) + if not(_la==17 or _la==18): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 418 + self.expr(15) + pass + + elif la_ == 4: + localctx = AutolevParser.IdEqualsExprContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 419 + if not self.precpred(self._ctx, 3): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 3)") + self.state = 420 + self.match(AutolevParser.T__2) + self.state = 421 + self.expr(4) + pass + + elif la_ == 5: + localctx = AutolevParser.ColonContext(self, AutolevParser.ExprContext(self, _parentctx, _parentState)) + self.pushNewRecursionContext(localctx, _startState, self.RULE_expr) + self.state = 422 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 423 + self.match(AutolevParser.T__15) + self.state = 424 + self.expr(3) + pass + + + self.state = 429 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,49,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + + def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): + if self._predicates == None: + self._predicates = dict() + self._predicates[27] = self.expr_sempred + pred = self._predicates.get(ruleIndex, None) + if pred is None: + raise Exception("No predicate with index:" + str(ruleIndex)) + else: + return pred(localctx, predIndex) + + def expr_sempred(self, localctx:ExprContext, predIndex:int): + if predIndex == 0: + return self.precpred(self._ctx, 16) + + + if predIndex == 1: + return self.precpred(self._ctx, 15) + + + if predIndex == 2: + return self.precpred(self._ctx, 14) + + + if predIndex == 3: + return self.precpred(self._ctx, 3) + + + if predIndex == 4: + return self.precpred(self._ctx, 2) + + + + + diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/_build_autolev_antlr.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/_build_autolev_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..8314b2f546c0a18a8e281768b60d66556c852e3b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/_build_autolev_antlr.py @@ -0,0 +1,86 @@ +import os +import subprocess +import glob + +from sympy.utilities.misc import debug + +here = os.path.dirname(__file__) +grammar_file = os.path.abspath(os.path.join(here, "Autolev.g4")) +dir_autolev_antlr = os.path.join(here, "_antlr") + +header = '''\ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +''' + + +def check_antlr_version(): + debug("Checking antlr4 version...") + + try: + debug(subprocess.check_output(["antlr4"]) + .decode('utf-8').split("\n")[0]) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + debug("The 'antlr4' command line tool is not installed, " + "or not on your PATH.\n" + "> Please refer to the README.md file for more information.") + return False + + +def build_parser(output_dir=dir_autolev_antlr): + check_antlr_version() + + debug("Updating ANTLR-generated code in {}".format(output_dir)) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + with open(os.path.join(output_dir, "__init__.py"), "w+") as fp: + fp.write(header) + + args = [ + "antlr4", + grammar_file, + "-o", output_dir, + "-no-visitor", + ] + + debug("Running code generation...\n\t$ {}".format(" ".join(args))) + subprocess.check_output(args, cwd=output_dir) + + debug("Applying headers, removing unnecessary files and renaming...") + # Handle case insensitive file systems. If the files are already + # generated, they will be written to autolev* but Autolev*.* won't match them. + for path in (glob.glob(os.path.join(output_dir, "Autolev*.*")) or + glob.glob(os.path.join(output_dir, "autolev*.*"))): + + # Remove files ending in .interp or .tokens as they are not needed. + if not path.endswith(".py"): + os.unlink(path) + continue + + new_path = os.path.join(output_dir, os.path.basename(path).lower()) + with open(path, 'r') as f: + lines = [line.rstrip().replace('AutolevParser import', 'autolevparser import') +'\n' + for line in f] + + os.unlink(path) + + with open(new_path, "w") as out_file: + offset = 0 + while lines[offset].startswith('#'): + offset += 1 + out_file.write(header) + out_file.writelines(lines[offset:]) + + debug("\t{}".format(new_path)) + + return True + + +if __name__ == "__main__": + build_parser() diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/_listener_autolev_antlr.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/_listener_autolev_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..9ca2f8af88de18036b90788fd29d02707f098213 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/_listener_autolev_antlr.py @@ -0,0 +1,2083 @@ +import collections +import warnings + +from sympy.external import import_module + +autolevparser = import_module('sympy.parsing.autolev._antlr.autolevparser', + import_kwargs={'fromlist': ['AutolevParser']}) +autolevlexer = import_module('sympy.parsing.autolev._antlr.autolevlexer', + import_kwargs={'fromlist': ['AutolevLexer']}) +autolevlistener = import_module('sympy.parsing.autolev._antlr.autolevlistener', + import_kwargs={'fromlist': ['AutolevListener']}) + +AutolevParser = getattr(autolevparser, 'AutolevParser', None) +AutolevLexer = getattr(autolevlexer, 'AutolevLexer', None) +AutolevListener = getattr(autolevlistener, 'AutolevListener', None) + + +def strfunc(z): + if z == 0: + return "" + elif z == 1: + return "_d" + else: + return "_" + "d" * z + +def declare_phy_entities(self, ctx, phy_type, i, j=None): + if phy_type in ("frame", "newtonian"): + declare_frames(self, ctx, i, j) + elif phy_type == "particle": + declare_particles(self, ctx, i, j) + elif phy_type == "point": + declare_points(self, ctx, i, j) + elif phy_type == "bodies": + declare_bodies(self, ctx, i, j) + +def declare_frames(self, ctx, i, j=None): + if "{" in ctx.getText(): + if j: + name1 = ctx.ID().getText().lower() + str(i) + str(j) + else: + name1 = ctx.ID().getText().lower() + str(i) + else: + name1 = ctx.ID().getText().lower() + name2 = "frame_" + name1 + if self.getValue(ctx.parentCtx.varType()) == "newtonian": + self.newtonian = name2 + + self.symbol_table2.update({name1: name2}) + + self.symbol_table.update({name1 + "1>": name2 + ".x"}) + self.symbol_table.update({name1 + "2>": name2 + ".y"}) + self.symbol_table.update({name1 + "3>": name2 + ".z"}) + + self.type2.update({name1: "frame"}) + self.write(name2 + " = " + "_me.ReferenceFrame('" + name1 + "')\n") + +def declare_points(self, ctx, i, j=None): + if "{" in ctx.getText(): + if j: + name1 = ctx.ID().getText().lower() + str(i) + str(j) + else: + name1 = ctx.ID().getText().lower() + str(i) + else: + name1 = ctx.ID().getText().lower() + + name2 = "point_" + name1 + + self.symbol_table2.update({name1: name2}) + self.type2.update({name1: "point"}) + self.write(name2 + " = " + "_me.Point('" + name1 + "')\n") + +def declare_particles(self, ctx, i, j=None): + if "{" in ctx.getText(): + if j: + name1 = ctx.ID().getText().lower() + str(i) + str(j) + else: + name1 = ctx.ID().getText().lower() + str(i) + else: + name1 = ctx.ID().getText().lower() + + name2 = "particle_" + name1 + + self.symbol_table2.update({name1: name2}) + self.type2.update({name1: "particle"}) + self.bodies.update({name1: name2}) + self.write(name2 + " = " + "_me.Particle('" + name1 + "', " + "_me.Point('" + + name1 + "_pt" + "'), " + "_sm.Symbol('m'))\n") + +def declare_bodies(self, ctx, i, j=None): + if "{" in ctx.getText(): + if j: + name1 = ctx.ID().getText().lower() + str(i) + str(j) + else: + name1 = ctx.ID().getText().lower() + str(i) + else: + name1 = ctx.ID().getText().lower() + + name2 = "body_" + name1 + self.bodies.update({name1: name2}) + masscenter = name2 + "_cm" + refFrame = name2 + "_f" + + self.symbol_table2.update({name1: name2}) + self.symbol_table2.update({name1 + "o": masscenter}) + self.symbol_table.update({name1 + "1>": refFrame+".x"}) + self.symbol_table.update({name1 + "2>": refFrame+".y"}) + self.symbol_table.update({name1 + "3>": refFrame+".z"}) + + self.type2.update({name1: "bodies"}) + self.type2.update({name1+"o": "point"}) + + self.write(masscenter + " = " + "_me.Point('" + name1 + "_cm" + "')\n") + if self.newtonian: + self.write(masscenter + ".set_vel(" + self.newtonian + ", " + "0)\n") + self.write(refFrame + " = " + "_me.ReferenceFrame('" + name1 + "_f" + "')\n") + # We set a dummy mass and inertia here. + # They will be reset using the setters later in the code anyway. + self.write(name2 + " = " + "_me.RigidBody('" + name1 + "', " + masscenter + ", " + + refFrame + ", " + "_sm.symbols('m'), (_me.outer(" + refFrame + + ".x," + refFrame + ".x)," + masscenter + "))\n") + +def inertia_func(self, v1, v2, l, frame): + + if self.type2[v1] == "particle": + l.append("_me.inertia_of_point_mass(" + self.bodies[v1] + ".mass, " + self.bodies[v1] + + ".point.pos_from(" + self.symbol_table2[v2] + "), " + frame + ")") + + elif self.type2[v1] == "bodies": + # Inertia has been defined about center of mass. + if self.inertia_point[v1] == v1 + "o": + # Asking point is cm as well + if v2 == self.inertia_point[v1]: + l.append(self.symbol_table2[v1] + ".inertia[0]") + + # Asking point is not cm + else: + l.append(self.bodies[v1] + ".inertia[0]" + " + " + + "_me.inertia_of_point_mass(" + self.bodies[v1] + + ".mass, " + self.bodies[v1] + ".masscenter" + + ".pos_from(" + self.symbol_table2[v2] + + "), " + frame + ")") + + # Inertia has been defined about another point + else: + # Asking point is the defined point + if v2 == self.inertia_point[v1]: + l.append(self.symbol_table2[v1] + ".inertia[0]") + # Asking point is cm + elif v2 == v1 + "o": + l.append(self.bodies[v1] + ".inertia[0]" + " - " + + "_me.inertia_of_point_mass(" + self.bodies[v1] + + ".mass, " + self.bodies[v1] + ".masscenter" + + ".pos_from(" + self.symbol_table2[self.inertia_point[v1]] + + "), " + frame + ")") + # Asking point is some other point + else: + l.append(self.bodies[v1] + ".inertia[0]" + " - " + + "_me.inertia_of_point_mass(" + self.bodies[v1] + + ".mass, " + self.bodies[v1] + ".masscenter" + + ".pos_from(" + self.symbol_table2[self.inertia_point[v1]] + + "), " + frame + ")" + " + " + + "_me.inertia_of_point_mass(" + self.bodies[v1] + + ".mass, " + self.bodies[v1] + ".masscenter" + + ".pos_from(" + self.symbol_table2[v2] + + "), " + frame + ")") + + +def processConstants(self, ctx): + # Process constant declarations of the type: Constants F = 3, g = 9.81 + name = ctx.ID().getText().lower() + if "=" in ctx.getText(): + self.symbol_table.update({name: name}) + # self.inputs.update({self.symbol_table[name]: self.getValue(ctx.getChild(2))}) + self.write(self.symbol_table[name] + " = " + "_sm.S(" + self.getValue(ctx.getChild(2)) + ")\n") + self.type.update({name: "constants"}) + return + + # Constants declarations of the type: Constants A, B + else: + if "{" not in ctx.getText(): + self.symbol_table[name] = name + self.type[name] = "constants" + + # Process constant declarations of the type: Constants C+, D- + if ctx.getChildCount() == 2: + # This is set for declaring nonpositive=True and nonnegative=True + if ctx.getChild(1).getText() == "+": + self.sign[name] = "+" + elif ctx.getChild(1).getText() == "-": + self.sign[name] = "-" + else: + if "{" not in ctx.getText(): + self.sign[name] = "o" + + # Process constant declarations of the type: Constants K{4}, a{1:2, 1:2}, b{1:2} + if "{" in ctx.getText(): + if ":" in ctx.getText(): + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + else: + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + + if ":" in ctx.getText(): + if "," in ctx.getText(): + num3 = int(ctx.INT(2).getText()) + num4 = int(ctx.INT(3).getText()) + 1 + for i in range(num1, num2): + for j in range(num3, num4): + self.symbol_table[name + str(i) + str(j)] = name + str(i) + str(j) + self.type[name + str(i) + str(j)] = "constants" + self.var_list.append(name + str(i) + str(j)) + self.sign[name + str(i) + str(j)] = "o" + else: + for i in range(num1, num2): + self.symbol_table[name + str(i)] = name + str(i) + self.type[name + str(i)] = "constants" + self.var_list.append(name + str(i)) + self.sign[name + str(i)] = "o" + + elif "," in ctx.getText(): + for i in range(1, int(ctx.INT(0).getText()) + 1): + for j in range(1, int(ctx.INT(1).getText()) + 1): + self.symbol_table[name] = name + str(i) + str(j) + self.type[name + str(i) + str(j)] = "constants" + self.var_list.append(name + str(i) + str(j)) + self.sign[name + str(i) + str(j)] = "o" + + else: + for i in range(num1, num2): + self.symbol_table[name + str(i)] = name + str(i) + self.type[name + str(i)] = "constants" + self.var_list.append(name + str(i)) + self.sign[name + str(i)] = "o" + + if "{" not in ctx.getText(): + self.var_list.append(name) + + +def writeConstants(self, ctx): + l1 = list(filter(lambda x: self.sign[x] == "o", self.var_list)) + l2 = list(filter(lambda x: self.sign[x] == "+", self.var_list)) + l3 = list(filter(lambda x: self.sign[x] == "-", self.var_list)) + try: + if self.settings["complex"] == "on": + real = ", real=True" + elif self.settings["complex"] == "off": + real = "" + except Exception: + real = ", real=True" + + if l1: + a = ", ".join(l1) + " = " + "_sm.symbols(" + "'" +\ + " ".join(l1) + "'" + real + ")\n" + self.write(a) + if l2: + a = ", ".join(l2) + " = " + "_sm.symbols(" + "'" +\ + " ".join(l2) + "'" + real + ", nonnegative=True)\n" + self.write(a) + if l3: + a = ", ".join(l3) + " = " + "_sm.symbols(" + "'" + \ + " ".join(l3) + "'" + real + ", nonpositive=True)\n" + self.write(a) + self.var_list = [] + + +def processVariables(self, ctx): + # Specified F = x*N1> + y*N2> + name = ctx.ID().getText().lower() + if "=" in ctx.getText(): + text = name + "'"*(ctx.getChildCount()-3) + self.write(text + " = " + self.getValue(ctx.expr()) + "\n") + return + + # Process variables of the type: Variables qA, qB + if ctx.getChildCount() == 1: + self.symbol_table[name] = name + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name: self.getValue(ctx.parentCtx.getChild(0))}) + + self.var_list.append(name) + self.sign[name] = 0 + + # Process variables of the type: Variables x', y'' + elif "'" in ctx.getText() and "{" not in ctx.getText(): + if ctx.getText().count("'") > self.maxDegree: + self.maxDegree = ctx.getText().count("'") + for i in range(ctx.getChildCount()): + self.sign[name + strfunc(i)] = i + self.symbol_table[name + "'"*i] = name + strfunc(i) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + "'"*i: self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + strfunc(i)) + + elif "{" in ctx.getText(): + # Process variables of the type: Variables x{3}, y{2} + + if "'" in ctx.getText(): + dash_count = ctx.getText().count("'") + if dash_count > self.maxDegree: + self.maxDegree = dash_count + + if ":" in ctx.getText(): + # Variables C{1:2, 1:2} + if "," in ctx.getText(): + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + num3 = int(ctx.INT(2).getText()) + num4 = int(ctx.INT(3).getText()) + 1 + # Variables C{1:2} + else: + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + + # Variables C{1,3} + elif "," in ctx.getText(): + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + num3 = 1 + num4 = int(ctx.INT(1).getText()) + 1 + else: + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + + for i in range(num1, num2): + try: + for j in range(num3, num4): + try: + for z in range(dash_count+1): + self.symbol_table.update({name + str(i) + str(j) + "'"*z: name + str(i) + str(j) + strfunc(z)}) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + str(i) + str(j) + "'"*z: self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + str(i) + str(j) + strfunc(z)) + self.sign.update({name + str(i) + str(j) + strfunc(z): z}) + if dash_count > self.maxDegree: + self.maxDegree = dash_count + except Exception: + self.symbol_table.update({name + str(i) + str(j): name + str(i) + str(j)}) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + str(i) + str(j): self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + str(i) + str(j)) + self.sign.update({name + str(i) + str(j): 0}) + except Exception: + try: + for z in range(dash_count+1): + self.symbol_table.update({name + str(i) + "'"*z: name + str(i) + strfunc(z)}) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + str(i) + "'"*z: self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + str(i) + strfunc(z)) + self.sign.update({name + str(i) + strfunc(z): z}) + if dash_count > self.maxDegree: + self.maxDegree = dash_count + except Exception: + self.symbol_table.update({name + str(i): name + str(i)}) + if self.getValue(ctx.parentCtx.getChild(0)) in ("variable", "specified", "motionvariable", "motionvariable'"): + self.type.update({name + str(i): self.getValue(ctx.parentCtx.getChild(0))}) + self.var_list.append(name + str(i)) + self.sign.update({name + str(i): 0}) + +def writeVariables(self, ctx): + #print(self.sign) + #print(self.symbol_table) + if self.var_list: + for i in range(self.maxDegree+1): + if i == 0: + j = "" + t = "" + else: + j = str(i) + t = ", " + l = [] + for k in list(filter(lambda x: self.sign[x] == i, self.var_list)): + if i == 0: + l.append(k) + if i == 1: + l.append(k[:-1]) + if i > 1: + l.append(k[:-2]) + a = ", ".join(list(filter(lambda x: self.sign[x] == i, self.var_list))) + " = " +\ + "_me.dynamicsymbols(" + "'" + " ".join(l) + "'" + t + j + ")\n" + l = [] + self.write(a) + self.maxDegree = 0 + self.var_list = [] + +def processImaginary(self, ctx): + name = ctx.ID().getText().lower() + self.symbol_table[name] = name + self.type[name] = "imaginary" + self.var_list.append(name) + + +def writeImaginary(self, ctx): + a = ", ".join(self.var_list) + " = " + "_sm.symbols(" + "'" + \ + " ".join(self.var_list) + "')\n" + b = ", ".join(self.var_list) + " = " + "_sm.I\n" + self.write(a) + self.write(b) + self.var_list = [] + +if AutolevListener: + class MyListener(AutolevListener): # type: ignore + def __init__(self, include_numeric=False): + # Stores data in tree nodes(tree annotation). Especially useful for expr reconstruction. + self.tree_property = {} + + # Stores the declared variables, constants etc as they are declared in Autolev and SymPy + # {"": ""}. + self.symbol_table = collections.OrderedDict() + + # Similar to symbol_table. Used for storing Physical entities like Frames, Points, + # Particles, Bodies etc + self.symbol_table2 = collections.OrderedDict() + + # Used to store nonpositive, nonnegative etc for constants and number of "'"s (order of diff) + # in variables. + self.sign = {} + + # Simple list used as a store to pass around variables between the 'process' and 'write' + # methods. + self.var_list = [] + + # Stores the type of a declared variable (constants, variables, specifieds etc) + self.type = collections.OrderedDict() + + # Similar to self.type. Used for storing the type of Physical entities like Frames, Points, + # Particles, Bodies etc + self.type2 = collections.OrderedDict() + + # These lists are used to distinguish matrix, numeric and vector expressions. + self.matrix_expr = [] + self.numeric_expr = [] + self.vector_expr = [] + self.fr_expr = [] + + self.output_code = [] + + # Stores the variables and their rhs for substituting upon the Autolev command EXPLICIT. + self.explicit = collections.OrderedDict() + + # Write code to import common dependencies. + self.output_code.append("import sympy.physics.mechanics as _me\n") + self.output_code.append("import sympy as _sm\n") + self.output_code.append("import math as m\n") + self.output_code.append("import numpy as _np\n") + self.output_code.append("\n") + + # Just a store for the max degree variable in a line. + self.maxDegree = 0 + + # Stores the input parameters which are then used for codegen and numerical analysis. + self.inputs = collections.OrderedDict() + # Stores the variables which appear in Output Autolev commands. + self.outputs = [] + # Stores the settings specified by the user. Ex: Complex on/off, Degrees on/off + self.settings = {} + # Boolean which changes the behaviour of some expression reconstruction + # when parsing Input Autolev commands. + self.in_inputs = False + self.in_outputs = False + + # Stores for the physical entities. + self.newtonian = None + self.bodies = collections.OrderedDict() + self.constants = [] + self.forces = collections.OrderedDict() + self.q_ind = [] + self.q_dep = [] + self.u_ind = [] + self.u_dep = [] + self.kd_eqs = [] + self.dependent_variables = [] + self.kd_equivalents = collections.OrderedDict() + self.kd_equivalents2 = collections.OrderedDict() + self.kd_eqs_supplied = None + self.kane_type = "no_args" + self.inertia_point = collections.OrderedDict() + self.kane_parsed = False + self.t = False + + # PyDy ode code will be included only if this flag is set to True. + self.include_numeric = include_numeric + + def write(self, string): + self.output_code.append(string) + + def getValue(self, node): + return self.tree_property[node] + + def setValue(self, node, value): + self.tree_property[node] = value + + def getSymbolTable(self): + return self.symbol_table + + def getType(self): + return self.type + + def exitVarDecl(self, ctx): + # This event method handles variable declarations. The parse tree node varDecl contains + # one or more varDecl2 nodes. Eg varDecl for 'Constants a{1:2, 1:2}, b{1:2}' has two varDecl2 + # nodes(one for a{1:2, 1:2} and one for b{1:2}). + + # Variable declarations are processed and stored in the event method exitVarDecl2. + # This stored information is used to write the final SymPy output code in the exitVarDecl event method. + + # determine the type of declaration + if self.getValue(ctx.varType()) == "constant": + writeConstants(self, ctx) + elif self.getValue(ctx.varType()) in\ + ("variable", "motionvariable", "motionvariable'", "specified"): + writeVariables(self, ctx) + elif self.getValue(ctx.varType()) == "imaginary": + writeImaginary(self, ctx) + + def exitVarType(self, ctx): + # Annotate the varType tree node with the type of the variable declaration. + name = ctx.getChild(0).getText().lower() + if name[-1] == "s" and name != "bodies": + self.setValue(ctx, name[:-1]) + else: + self.setValue(ctx, name) + + def exitVarDecl2(self, ctx): + # Variable declarations are processed and stored in the event method exitVarDecl2. + # This stored information is used to write the final SymPy output code in the exitVarDecl event method. + # This is the case for constants, variables, specifieds etc. + + # This isn't the case for all types of declarations though. For instance + # Frames A, B, C, N cannot be defined on one line in SymPy. So we do not append A, B, C, N + # to a var_list or use exitVarDecl. exitVarDecl2 directly writes out to the file. + + # determine the type of declaration + if self.getValue(ctx.parentCtx.varType()) == "constant": + processConstants(self, ctx) + + elif self.getValue(ctx.parentCtx.varType()) in \ + ("variable", "motionvariable", "motionvariable'", "specified"): + processVariables(self, ctx) + + elif self.getValue(ctx.parentCtx.varType()) == "imaginary": + processImaginary(self, ctx) + + elif self.getValue(ctx.parentCtx.varType()) in ("frame", "newtonian", "point", "particle", "bodies"): + if "{" in ctx.getText(): + if ":" in ctx.getText() and "," not in ctx.getText(): + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + elif ":" not in ctx.getText() and "," in ctx.getText(): + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + num3 = 1 + num4 = int(ctx.INT(1).getText()) + 1 + elif ":" in ctx.getText() and "," in ctx.getText(): + num1 = int(ctx.INT(0).getText()) + num2 = int(ctx.INT(1).getText()) + 1 + num3 = int(ctx.INT(2).getText()) + num4 = int(ctx.INT(3).getText()) + 1 + else: + num1 = 1 + num2 = int(ctx.INT(0).getText()) + 1 + else: + num1 = 1 + num2 = 2 + for i in range(num1, num2): + try: + for j in range(num3, num4): + declare_phy_entities(self, ctx, self.getValue(ctx.parentCtx.varType()), i, j) + except Exception: + declare_phy_entities(self, ctx, self.getValue(ctx.parentCtx.varType()), i) + # ================== Subrules of parser rule expr (Start) ====================== # + + def exitId(self, ctx): + # Tree annotation for ID which is a labeled subrule of the parser rule expr. + # A_C + python_keywords = ["and", "as", "assert", "break", "class", "continue", "def", "del", "elif", "else", "except",\ + "exec", "finally", "for", "from", "global", "if", "import", "in", "is", "lambda", "not", "or", "pass", "print",\ + "raise", "return", "try", "while", "with", "yield"] + + if ctx.ID().getText().lower() in python_keywords: + warnings.warn("Python keywords must not be used as identifiers. Please refer to the list of keywords at https://docs.python.org/2.5/ref/keywords.html", + SyntaxWarning) + + if "_" in ctx.ID().getText() and ctx.ID().getText().count('_') == 1: + e1, e2 = ctx.ID().getText().lower().split('_') + try: + if self.type2[e1] == "frame": + e1 = self.symbol_table2[e1] + elif self.type2[e1] == "bodies": + e1 = self.symbol_table2[e1] + "_f" + if self.type2[e2] == "frame": + e2 = self.symbol_table2[e2] + elif self.type2[e2] == "bodies": + e2 = self.symbol_table2[e2] + "_f" + + self.setValue(ctx, e1 + ".dcm(" + e2 + ")") + except Exception: + self.setValue(ctx, ctx.ID().getText().lower()) + else: + # Reserved constant Pi + if ctx.ID().getText().lower() == "pi": + self.setValue(ctx, "_sm.pi") + self.numeric_expr.append(ctx) + + # Reserved variable T (for time) + elif ctx.ID().getText().lower() == "t": + self.setValue(ctx, "_me.dynamicsymbols._t") + if not self.in_inputs and not self.in_outputs: + self.t = True + + else: + idText = ctx.ID().getText().lower() + "'"*(ctx.getChildCount() - 1) + if idText in self.type.keys() and self.type[idText] == "matrix": + self.matrix_expr.append(ctx) + if self.in_inputs: + try: + self.setValue(ctx, self.symbol_table[idText]) + except Exception: + self.setValue(ctx, idText.lower()) + else: + try: + self.setValue(ctx, self.symbol_table[idText]) + except Exception: + pass + + def exitInt(self, ctx): + # Tree annotation for int which is a labeled subrule of the parser rule expr. + int_text = ctx.INT().getText() + self.setValue(ctx, int_text) + self.numeric_expr.append(ctx) + + def exitFloat(self, ctx): + # Tree annotation for float which is a labeled subrule of the parser rule expr. + floatText = ctx.FLOAT().getText() + self.setValue(ctx, floatText) + self.numeric_expr.append(ctx) + + def exitAddSub(self, ctx): + # Tree annotation for AddSub which is a labeled subrule of the parser rule expr. + # The subrule is expr = expr (+|-) expr + if ctx.expr(0) in self.matrix_expr or ctx.expr(1) in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.expr(0) in self.vector_expr or ctx.expr(1) in self.vector_expr: + self.vector_expr.append(ctx) + if ctx.expr(0) in self.numeric_expr and ctx.expr(1) in self.numeric_expr: + self.numeric_expr.append(ctx) + self.setValue(ctx, self.getValue(ctx.expr(0)) + ctx.getChild(1).getText() + + self.getValue(ctx.expr(1))) + + def exitMulDiv(self, ctx): + # Tree annotation for MulDiv which is a labeled subrule of the parser rule expr. + # The subrule is expr = expr (*|/) expr + try: + if ctx.expr(0) in self.vector_expr and ctx.expr(1) in self.vector_expr: + self.setValue(ctx, "_me.outer(" + self.getValue(ctx.expr(0)) + ", " + + self.getValue(ctx.expr(1)) + ")") + else: + if ctx.expr(0) in self.matrix_expr or ctx.expr(1) in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.expr(0) in self.vector_expr or ctx.expr(1) in self.vector_expr: + self.vector_expr.append(ctx) + if ctx.expr(0) in self.numeric_expr and ctx.expr(1) in self.numeric_expr: + self.numeric_expr.append(ctx) + self.setValue(ctx, self.getValue(ctx.expr(0)) + ctx.getChild(1).getText() + + self.getValue(ctx.expr(1))) + except Exception: + pass + + def exitNegativeOne(self, ctx): + # Tree annotation for negativeOne which is a labeled subrule of the parser rule expr. + self.setValue(ctx, "-1*" + self.getValue(ctx.getChild(1))) + if ctx.getChild(1) in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.getChild(1) in self.numeric_expr: + self.numeric_expr.append(ctx) + + def exitParens(self, ctx): + # Tree annotation for parens which is a labeled subrule of the parser rule expr. + # The subrule is expr = '(' expr ')' + if ctx.expr() in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.expr() in self.vector_expr: + self.vector_expr.append(ctx) + if ctx.expr() in self.numeric_expr: + self.numeric_expr.append(ctx) + self.setValue(ctx, "(" + self.getValue(ctx.expr()) + ")") + + def exitExponent(self, ctx): + # Tree annotation for Exponent which is a labeled subrule of the parser rule expr. + # The subrule is expr = expr ^ expr + if ctx.expr(0) in self.matrix_expr or ctx.expr(1) in self.matrix_expr: + self.matrix_expr.append(ctx) + if ctx.expr(0) in self.vector_expr or ctx.expr(1) in self.vector_expr: + self.vector_expr.append(ctx) + if ctx.expr(0) in self.numeric_expr and ctx.expr(1) in self.numeric_expr: + self.numeric_expr.append(ctx) + self.setValue(ctx, self.getValue(ctx.expr(0)) + "**" + self.getValue(ctx.expr(1))) + + def exitExp(self, ctx): + s = ctx.EXP().getText()[ctx.EXP().getText().index('E')+1:] + if "-" in s: + s = s[0] + s[1:].lstrip("0") + else: + s = s.lstrip("0") + self.setValue(ctx, ctx.EXP().getText()[:ctx.EXP().getText().index('E')] + + "*10**(" + s + ")") + + def exitFunction(self, ctx): + # Tree annotation for function which is a labeled subrule of the parser rule expr. + + # The difference between this and FunctionCall is that this is used for non standalone functions + # appearing in expressions and assignments. + # Eg: + # When we come across a standalone function say Expand(E, n:m) then it is categorized as FunctionCall + # which is a parser rule in itself under rule stat. exitFunctionCall() takes care of it and writes to the file. + # + # On the other hand, while we come across E_diff = D(E, y), we annotate the tree node + # of the function D(E, y) with the SymPy equivalent in exitFunction(). + # In this case it is the method exitAssignment() that writes the code to the file and not exitFunction(). + + ch = ctx.getChild(0) + func_name = ch.getChild(0).getText().lower() + + # Expand(y, n:m) * + if func_name == "expand": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + # _sm.Matrix([i.expand() for i in z]).reshape(z.shape[0], z.shape[1]) + self.setValue(ctx, "_sm.Matrix([i.expand() for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "(" + expr + ")" + "." + "expand()") + + # Factor(y, x) * + elif func_name == "factor": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([_sm.factor(i, " + self.getValue(ch.expr(1)) + ") for i in " + + expr + "])" + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "_sm.factor(" + "(" + expr + ")" + + ", " + self.getValue(ch.expr(1)) + ")") + + # D(y, x) + elif func_name == "d": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.diff(" + self.getValue(ch.expr(1)) + ") for i in " + + expr + "])" + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + if ch.getChildCount() == 8: + frame = self.symbol_table2[ch.expr(2).getText().lower()] + self.setValue(ctx, "(" + expr + ")" + "." + "diff(" + self.getValue(ch.expr(1)) + + ", " + frame + ")") + else: + self.setValue(ctx, "(" + expr + ")" + "." + "diff(" + + self.getValue(ch.expr(1)) + ")") + + # Dt(y) + elif func_name == "dt": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.vector_expr: + text = "dt(" + else: + text = "diff(_sm.Symbol('t')" + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i." + text + + ") for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + if ch.getChildCount() == 6: + frame = self.symbol_table2[ch.expr(1).getText().lower()] + self.setValue(ctx, "(" + expr + ")" + "." + "dt(" + + frame + ")") + else: + self.setValue(ctx, "(" + expr + ")" + "." + text + ")") + + # Explicit(EXPRESS(IMPLICIT>,C)) + elif func_name == "explicit": + if ch.expr(0) in self.vector_expr: + self.vector_expr.append(ctx) + expr = self.getValue(ch.expr(0)) + if self.explicit.keys(): + explicit_list = [] + for i in self.explicit.keys(): + explicit_list.append(i + ":" + self.explicit[i]) + self.setValue(ctx, "(" + expr + ")" + ".subs({" + ", ".join(explicit_list) + "})") + else: + self.setValue(ctx, expr) + + # Taylor(y, 0:2, w=a, x=0) + # TODO: Currently only works with symbols. Make it work for dynamicsymbols. + elif func_name == "taylor": + exp = self.getValue(ch.expr(0)) + order = self.getValue(ch.expr(1).expr(1)) + x = (ch.getChildCount()-6)//2 + l = [] + for i in range(x): + index = 2 + i + child = ch.expr(index) + l.append(".series(" + self.getValue(child.getChild(0)) + + ", " + self.getValue(child.getChild(2)) + + ", " + order + ").removeO()") + self.setValue(ctx, "(" + exp + ")" + "".join(l)) + + # Evaluate(y, a=x, b=2) + elif func_name == "evaluate": + expr = self.getValue(ch.expr(0)) + l = [] + x = (ch.getChildCount()-4)//2 + for i in range(x): + index = 1 + i + child = ch.expr(index) + l.append(self.getValue(child.getChild(0)) + ":" + + self.getValue(child.getChild(2))) + + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.subs({" + ",".join(l) + "}) for i in " + + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + if self.explicit: + explicit_list = [] + for i in self.explicit.keys(): + explicit_list.append(i + ":" + self.explicit[i]) + self.setValue(ctx, "(" + expr + ")" + ".subs({" + ",".join(explicit_list) + + "}).subs({" + ",".join(l) + "})") + else: + self.setValue(ctx, "(" + expr + ")" + ".subs({" + ",".join(l) + "})") + + # Polynomial([a, b, c], x) + elif func_name == "polynomial": + self.setValue(ctx, "_sm.Poly(" + self.getValue(ch.expr(0)) + ", " + + self.getValue(ch.expr(1)) + ")") + + # Roots(Poly, x, 2) + # Roots([1; 2; 3; 4]) + elif func_name == "roots": + self.matrix_expr.append(ctx) + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.setValue(ctx, "[i.evalf() for i in " + "_sm.solve(" + + "_sm.Poly(" + expr + ", " + "x),x)]") + else: + self.setValue(ctx, "[i.evalf() for i in " + "_sm.solve(" + + expr + ", " + self.getValue(ch.expr(1)) + ")]") + + # Transpose(A), Inv(A) + elif func_name in ("transpose", "inv", "inverse"): + self.matrix_expr.append(ctx) + if func_name == "transpose": + e = ".T" + elif func_name in ("inv", "inverse"): + e = "**(-1)" + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + e) + + # Eig(A) + elif func_name == "eig": + # "_sm.Matrix([i.evalf() for i in " + + self.setValue(ctx, "_sm.Matrix([i.evalf() for i in (" + + self.getValue(ch.expr(0)) + ").eigenvals().keys()])") + + # Diagmat(n, m, x) + # Diagmat(3, 1) + elif func_name == "diagmat": + self.matrix_expr.append(ctx) + if ch.getChildCount() == 6: + l = [] + for i in range(int(self.getValue(ch.expr(0)))): + l.append(self.getValue(ch.expr(1)) + ",") + + self.setValue(ctx, "_sm.diag(" + ("".join(l))[:-1] + ")") + + elif ch.getChildCount() == 8: + # _sm.Matrix([x if i==j else 0 for i in range(n) for j in range(m)]).reshape(n, m) + n = self.getValue(ch.expr(0)) + m = self.getValue(ch.expr(1)) + x = self.getValue(ch.expr(2)) + self.setValue(ctx, "_sm.Matrix([" + x + " if i==j else 0 for i in range(" + + n + ") for j in range(" + m + ")]).reshape(" + n + ", " + m + ")") + + # Cols(A) + # Cols(A, 1) + # Cols(A, 1, 2:4, 3) + elif func_name in ("cols", "rows"): + self.matrix_expr.append(ctx) + if func_name == "cols": + e1 = ".cols" + e2 = ".T." + else: + e1 = ".rows" + e2 = "." + if ch.getChildCount() == 4: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + e1) + elif ch.getChildCount() == 6: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + + e1[:-1] + "(" + str(int(self.getValue(ch.expr(1))) - 1) + ")") + else: + l = [] + for i in range(4, ch.getChildCount()): + try: + if ch.getChild(i).getChildCount() > 1 and ch.getChild(i).getChild(1).getText() == ":": + for j in range(int(ch.getChild(i).getChild(0).getText()), + int(ch.getChild(i).getChild(2).getText())+1): + l.append("(" + self.getValue(ch.getChild(2)) + ")" + e2 + + "row(" + str(j-1) + ")") + else: + l.append("(" + self.getValue(ch.getChild(2)) + ")" + e2 + + "row(" + str(int(ch.getChild(i).getText())-1) + ")") + except Exception: + pass + self.setValue(ctx, "_sm.Matrix([" + ",".join(l) + "])") + + # Det(A) Trace(A) + elif func_name in ["det", "trace"]: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + "." + + func_name + "()") + + # Element(A, 2, 3) + elif func_name == "element": + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + "[" + + str(int(self.getValue(ch.expr(1)))-1) + "," + + str(int(self.getValue(ch.expr(2)))-1) + "]") + + elif func_name in \ + ["cos", "sin", "tan", "cosh", "sinh", "tanh", "acos", "asin", "atan", + "log", "exp", "sqrt", "factorial", "floor", "sign"]: + self.setValue(ctx, "_sm." + func_name + "(" + self.getValue(ch.expr(0)) + ")") + + elif func_name == "ceil": + self.setValue(ctx, "_sm.ceiling" + "(" + self.getValue(ch.expr(0)) + ")") + + elif func_name == "sqr": + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + + ")" + "**2") + + elif func_name == "log10": + self.setValue(ctx, "_sm.log" + + "(" + self.getValue(ch.expr(0)) + ", 10)") + + elif func_name == "atan2": + self.setValue(ctx, "_sm.atan2" + "(" + self.getValue(ch.expr(0)) + ", " + + self.getValue(ch.expr(1)) + ")") + + elif func_name in ["int", "round"]: + self.setValue(ctx, func_name + + "(" + self.getValue(ch.expr(0)) + ")") + + elif func_name == "abs": + self.setValue(ctx, "_sm.Abs(" + self.getValue(ch.expr(0)) + ")") + + elif func_name in ["max", "min"]: + # max(x, y, z) + l = [] + for i in range(1, ch.getChildCount()): + if ch.getChild(i) in self.tree_property.keys(): + l.append(self.getValue(ch.getChild(i))) + elif ch.getChild(i).getText() in [",", "(", ")"]: + l.append(ch.getChild(i).getText()) + self.setValue(ctx, "_sm." + ch.getChild(0).getText().capitalize() + "".join(l)) + + # Coef(y, x) + elif func_name == "coef": + #A41_A53=COEF([RHS(U4);RHS(U5)],[U1,U2,U3]) + if ch.expr(0) in self.matrix_expr and ch.expr(1) in self.matrix_expr: + icount = jcount = 0 + for i in range(ch.expr(0).getChild(0).getChildCount()): + try: + ch.expr(0).getChild(0).getChild(i).getRuleIndex() + icount+=1 + except Exception: + pass + for j in range(ch.expr(1).getChild(0).getChildCount()): + try: + ch.expr(1).getChild(0).getChild(j).getRuleIndex() + jcount+=1 + except Exception: + pass + l = [] + for i in range(icount): + for j in range(jcount): + # a41_a53[i,j] = u4.expand().coeff(u1) + l.append(self.getValue(ch.expr(0).getChild(0).expr(i)) + ".expand().coeff(" + + self.getValue(ch.expr(1).getChild(0).expr(j)) + ")") + self.setValue(ctx, "_sm.Matrix([" + ", ".join(l) + "]).reshape(" + str(icount) + ", " + str(jcount) + ")") + else: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + + ")" + ".expand().coeff(" + self.getValue(ch.expr(1)) + ")") + + # Exclude(y, x) Include(y, x) + elif func_name in ("exclude", "include"): + if func_name == "exclude": + e = "0" + else: + e = "1" + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.collect(" + self.getValue(ch.expr(1)) + "])" + + ".coeff(" + self.getValue(ch.expr(1)) + "," + e + ")" + "for i in " + expr + ")" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "(" + expr + + ")" + ".collect(" + self.getValue(ch.expr(1)) + ")" + + ".coeff(" + self.getValue(ch.expr(1)) + "," + e + ")") + + # RHS(y) + elif func_name == "rhs": + self.setValue(ctx, self.explicit[self.getValue(ch.expr(0))]) + + # Arrange(y, n, x) * + elif func_name == "arrange": + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.collect(" + self.getValue(ch.expr(2)) + + ")" + "for i in " + expr + "])"+ + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "(" + expr + + ")" + ".collect(" + self.getValue(ch.expr(2)) + ")") + + # Replace(y, sin(x)=3) + elif func_name == "replace": + l = [] + for i in range(1, ch.getChildCount()): + try: + if ch.getChild(i).getChild(1).getText() == "=": + l.append(self.getValue(ch.getChild(i).getChild(0)) + + ":" + self.getValue(ch.getChild(i).getChild(2))) + except Exception: + pass + expr = self.getValue(ch.expr(0)) + if ch.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.matrix_expr.append(ctx) + self.setValue(ctx, "_sm.Matrix([i.subs({" + ",".join(l) + "}) for i in " + + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])") + else: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + + ".subs({" + ",".join(l) + "})") + + # Dot(Loop>, N1>) + elif func_name == "dot": + l = [] + num = (ch.expr(1).getChild(0).getChildCount()-1)//2 + if ch.expr(1) in self.matrix_expr: + for i in range(num): + l.append("_me.dot(" + self.getValue(ch.expr(0)) + ", " + self.getValue(ch.expr(1).getChild(0).expr(i)) + ")") + self.setValue(ctx, "_sm.Matrix([" + ",".join(l) + "]).reshape(" + str(num) + ", " + "1)") + else: + self.setValue(ctx, "_me.dot(" + self.getValue(ch.expr(0)) + ", " + self.getValue(ch.expr(1)) + ")") + # Cross(w_A_N>, P_NA_AB>) + elif func_name == "cross": + self.vector_expr.append(ctx) + self.setValue(ctx, "_me.cross(" + self.getValue(ch.expr(0)) + ", " + self.getValue(ch.expr(1)) + ")") + + # Mag(P_O_Q>) + elif func_name == "mag": + self.setValue(ctx, self.getValue(ch.expr(0)) + "." + "magnitude()") + + # MATRIX(A, I_R>>) + elif func_name == "matrix": + if self.type2[ch.expr(0).getText().lower()] == "frame": + text = "" + elif self.type2[ch.expr(0).getText().lower()] == "bodies": + text = "_f" + self.setValue(ctx, "(" + self.getValue(ch.expr(1)) + ")" + ".to_matrix(" + + self.symbol_table2[ch.expr(0).getText().lower()] + text + ")") + + # VECTOR(A, ROWS(EIGVECS,1)) + elif func_name == "vector": + if self.type2[ch.expr(0).getText().lower()] == "frame": + text = "" + elif self.type2[ch.expr(0).getText().lower()] == "bodies": + text = "_f" + v = self.getValue(ch.expr(1)) + f = self.symbol_table2[ch.expr(0).getText().lower()] + text + self.setValue(ctx, v + "[0]*" + f + ".x +" + v + "[1]*" + f + ".y +" + + v + "[2]*" + f + ".z") + + # Express(A2>, B) + # Here I am dealing with all the Inertia commands as I expect the users to use Inertia + # commands only with Express because SymPy needs the Reference frame to be specified unlike Autolev. + elif func_name == "express": + self.vector_expr.append(ctx) + if self.type2[ch.expr(1).getText().lower()] == "frame": + frame = self.symbol_table2[ch.expr(1).getText().lower()] + else: + frame = self.symbol_table2[ch.expr(1).getText().lower()] + "_f" + if ch.expr(0).getText().lower() == "1>>": + self.setValue(ctx, "_me.inertia(" + frame + ", 1, 1, 1)") + + elif '_' in ch.expr(0).getText().lower() and ch.expr(0).getText().lower().count('_') == 2\ + and ch.expr(0).getText().lower()[0] == "i" and ch.expr(0).getText().lower()[-2:] == ">>": + v1 = ch.expr(0).getText().lower()[:-2].split('_')[1] + v2 = ch.expr(0).getText().lower()[:-2].split('_')[2] + l = [] + inertia_func(self, v1, v2, l, frame) + self.setValue(ctx, " + ".join(l)) + + elif ch.expr(0).getChild(0).getChild(0).getText().lower() == "inertia": + if ch.expr(0).getChild(0).getChildCount() == 4: + l = [] + v2 = ch.expr(0).getChild(0).ID(0).getText().lower() + for v1 in self.bodies: + inertia_func(self, v1, v2, l, frame) + self.setValue(ctx, " + ".join(l)) + + else: + l = [] + l2 = [] + v2 = ch.expr(0).getChild(0).ID(0).getText().lower() + for i in range(1, (ch.expr(0).getChild(0).getChildCount()-2)//2): + l2.append(ch.expr(0).getChild(0).ID(i).getText().lower()) + for v1 in l2: + inertia_func(self, v1, v2, l, frame) + self.setValue(ctx, " + ".join(l)) + + else: + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + ".express(" + + self.symbol_table2[ch.expr(1).getText().lower()] + ")") + # CM(P) + elif func_name == "cm": + if self.type2[ch.expr(0).getText().lower()] == "point": + text = "" + else: + text = ".point" + if ch.getChildCount() == 4: + self.setValue(ctx, "_me.functions.center_of_mass(" + self.symbol_table2[ch.expr(0).getText().lower()] + + text + "," + ", ".join(self.bodies.values()) + ")") + else: + bodies = [] + for i in range(1, (ch.getChildCount()-1)//2): + bodies.append(self.symbol_table2[ch.expr(i).getText().lower()]) + self.setValue(ctx, "_me.functions.center_of_mass(" + self.symbol_table2[ch.expr(0).getText().lower()] + + text + "," + ", ".join(bodies) + ")") + + # PARTIALS(V_P1_E>,U1) + elif func_name == "partials": + speeds = [] + for i in range(1, (ch.getChildCount()-1)//2): + if self.kd_equivalents2: + speeds.append(self.kd_equivalents2[self.symbol_table[ch.expr(i).getText().lower()]]) + else: + speeds.append(self.symbol_table[ch.expr(i).getText().lower()]) + v1, v2, v3 = ch.expr(0).getText().lower().replace(">","").split('_') + if self.type2[v2] == "point": + point = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + point = self.symbol_table2[v2] + ".point" + frame = self.symbol_table2[v3] + self.setValue(ctx, point + ".partial_velocity(" + frame + ", " + ",".join(speeds) + ")") + + # UnitVec(A1>+A2>+A3>) + elif func_name == "unitvec": + self.setValue(ctx, "(" + self.getValue(ch.expr(0)) + ")" + ".normalize()") + + # Units(deg, rad) + elif func_name == "units": + if ch.expr(0).getText().lower() == "deg" and ch.expr(1).getText().lower() == "rad": + factor = 0.0174533 + elif ch.expr(0).getText().lower() == "rad" and ch.expr(1).getText().lower() == "deg": + factor = 57.2958 + self.setValue(ctx, str(factor)) + # Mass(A) + elif func_name == "mass": + l = [] + try: + ch.ID(0).getText().lower() + for i in range((ch.getChildCount()-1)//2): + l.append(self.symbol_table2[ch.ID(i).getText().lower()] + ".mass") + self.setValue(ctx, "+".join(l)) + except Exception: + for i in self.bodies.keys(): + l.append(self.bodies[i] + ".mass") + self.setValue(ctx, "+".join(l)) + + # Fr() FrStar() + # _me.KanesMethod(n, q_ind, u_ind, kd, velocity_constraints).kanes_equations(pl, fl)[0] + elif func_name in ["fr", "frstar"]: + if not self.kane_parsed: + if self.kd_eqs: + for i in self.kd_eqs: + self.q_ind.append(self.symbol_table[i.strip().split('-')[0].replace("'","")]) + self.u_ind.append(self.symbol_table[i.strip().split('-')[1].replace("'","")]) + + for i in range(len(self.kd_eqs)): + self.kd_eqs[i] = self.symbol_table[self.kd_eqs[i].strip().split('-')[0]] + " - " +\ + self.symbol_table[self.kd_eqs[i].strip().split('-')[1]] + + # Do all of this if kd_eqs are not specified + if not self.kd_eqs: + self.kd_eqs_supplied = False + self.matrix_expr.append(ctx) + for i in self.type.keys(): + if self.type[i] == "motionvariable": + if self.sign[self.symbol_table[i.lower()]] == 0: + self.q_ind.append(self.symbol_table[i.lower()]) + elif self.sign[self.symbol_table[i.lower()]] == 1: + name = "u_" + self.symbol_table[i.lower()] + self.symbol_table.update({name: name}) + self.write(name + " = " + "_me.dynamicsymbols('" + name + "')\n") + if self.symbol_table[i.lower()] not in self.dependent_variables: + self.u_ind.append(name) + self.kd_equivalents.update({name: self.symbol_table[i.lower()]}) + else: + self.u_dep.append(name) + self.kd_equivalents.update({name: self.symbol_table[i.lower()]}) + + for i in self.kd_equivalents.keys(): + self.kd_eqs.append(self.kd_equivalents[i] + "-" + i) + + if not self.u_ind and not self.kd_eqs: + self.u_ind = self.q_ind.copy() + self.q_ind = [] + + # deal with velocity constraints + if self.dependent_variables: + for i in self.dependent_variables: + self.u_dep.append(i) + if i in self.u_ind: + self.u_ind.remove(i) + + + self.u_dep[:] = [i for i in self.u_dep if i not in self.kd_equivalents.values()] + + force_list = [] + for i in self.forces.keys(): + force_list.append("(" + i + "," + self.forces[i] + ")") + if self.u_dep: + u_dep_text = ", u_dependent=[" + ", ".join(self.u_dep) + "]" + else: + u_dep_text = "" + if self.dependent_variables: + velocity_constraints_text = ", velocity_constraints = velocity_constraints" + else: + velocity_constraints_text = "" + if ctx.parentCtx not in self.fr_expr: + self.write("kd_eqs = [" + ", ".join(self.kd_eqs) + "]\n") + self.write("forceList = " + "[" + ", ".join(force_list) + "]\n") + self.write("kane = _me.KanesMethod(" + self.newtonian + ", " + "q_ind=[" + + ",".join(self.q_ind) + "], " + "u_ind=[" + + ", ".join(self.u_ind) + "]" + u_dep_text + ", " + + "kd_eqs = kd_eqs" + velocity_constraints_text + ")\n") + self.write("fr, frstar = kane." + "kanes_equations([" + + ", ".join(self.bodies.values()) + "], forceList)\n") + self.fr_expr.append(ctx.parentCtx) + self.kane_parsed = True + self.setValue(ctx, func_name) + + def exitMatrices(self, ctx): + # Tree annotation for Matrices which is a labeled subrule of the parser rule expr. + + # MO = [a, b; c, d] + # we generate _sm.Matrix([a, b, c, d]).reshape(2, 2) + # The reshape values are determined by counting the "," and ";" in the Autolev matrix + + # Eg: + # [1, 2, 3; 4, 5, 6; 7, 8, 9; 10, 11, 12] + # semicolon_count = 3 and rows = 3+1 = 4 + # comma_count = 8 and cols = 8/rows + 1 = 8/4 + 1 = 3 + + # TODO** Parse block matrices + self.matrix_expr.append(ctx) + l = [] + semicolon_count = 0 + comma_count = 0 + for i in range(ctx.matrix().getChildCount()): + child = ctx.matrix().getChild(i) + if child == AutolevParser.ExprContext: + l.append(self.getValue(child)) + elif child.getText() == ";": + semicolon_count += 1 + l.append(",") + elif child.getText() == ",": + comma_count += 1 + l.append(",") + else: + try: + try: + l.append(self.getValue(child)) + except Exception: + l.append(self.symbol_table[child.getText().lower()]) + except Exception: + l.append(child.getText().lower()) + num_of_rows = semicolon_count + 1 + num_of_cols = (comma_count//num_of_rows) + 1 + + self.setValue(ctx, "_sm.Matrix(" + "".join(l) + ")" + ".reshape(" + + str(num_of_rows) + ", " + str(num_of_cols) + ")") + + def exitVectorOrDyadic(self, ctx): + self.vector_expr.append(ctx) + ch = ctx.vec() + + if ch.getChild(0).getText() == "0>": + self.setValue(ctx, "0") + + elif ch.getChild(0).getText() == "1>>": + self.setValue(ctx, "1>>") + + elif "_" in ch.ID().getText() and ch.ID().getText().count('_') == 2: + vec_text = ch.getText().lower() + v1, v2, v3 = ch.ID().getText().lower().split('_') + + if v1 == "p": + if self.type2[v2] == "point": + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + if self.type2[v3] == "point": + e3 = self.symbol_table2[v3] + elif self.type2[v3] == "particle": + e3 = self.symbol_table2[v3] + ".point" + get_vec = e3 + ".pos_from(" + e2 + ")" + self.setValue(ctx, get_vec) + + elif v1 in ("w", "alf"): + if v1 == "w": + text = ".ang_vel_in(" + elif v1 == "alf": + text = ".ang_acc_in(" + if self.type2[v2] == "bodies": + e2 = self.symbol_table2[v2] + "_f" + elif self.type2[v2] == "frame": + e2 = self.symbol_table2[v2] + if self.type2[v3] == "bodies": + e3 = self.symbol_table2[v3] + "_f" + elif self.type2[v3] == "frame": + e3 = self.symbol_table2[v3] + get_vec = e2 + text + e3 + ")" + self.setValue(ctx, get_vec) + + elif v1 in ("v", "a"): + if v1 == "v": + text = ".vel(" + elif v1 == "a": + text = ".acc(" + if self.type2[v2] == "point": + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + get_vec = e2 + text + self.symbol_table2[v3] + ")" + self.setValue(ctx, get_vec) + + else: + self.setValue(ctx, vec_text.replace(">", "")) + + else: + vec_text = ch.getText().lower() + name = self.symbol_table[vec_text] + self.setValue(ctx, name) + + def exitIndexing(self, ctx): + if ctx.getChildCount() == 4: + try: + int_text = str(int(self.getValue(ctx.getChild(2))) - 1) + except Exception: + int_text = self.getValue(ctx.getChild(2)) + " - 1" + self.setValue(ctx, ctx.ID().getText().lower() + "[" + int_text + "]") + elif ctx.getChildCount() == 6: + try: + int_text1 = str(int(self.getValue(ctx.getChild(2))) - 1) + except Exception: + int_text1 = self.getValue(ctx.getChild(2)) + " - 1" + try: + int_text2 = str(int(self.getValue(ctx.getChild(4))) - 1) + except Exception: + int_text2 = self.getValue(ctx.getChild(2)) + " - 1" + self.setValue(ctx, ctx.ID().getText().lower() + "[" + int_text1 + ", " + int_text2 + "]") + + + # ================== Subrules of parser rule expr (End) ====================== # + + def exitRegularAssign(self, ctx): + # Handle assignments of type ID = expr + if ctx.equals().getText() in ["=", "+=", "-=", "*=", "/="]: + equals = ctx.equals().getText() + elif ctx.equals().getText() == ":=": + equals = " = " + elif ctx.equals().getText() == "^=": + equals = "**=" + + try: + a = ctx.ID().getText().lower() + "'"*ctx.diff().getText().count("'") + except Exception: + a = ctx.ID().getText().lower() + + if a in self.type.keys() and self.type[a] in ("motionvariable", "motionvariable'") and\ + self.type[ctx.expr().getText().lower()] in ("motionvariable", "motionvariable'"): + b = ctx.expr().getText().lower() + if "'" in b and "'" not in a: + a, b = b, a + if not self.kane_parsed: + self.kd_eqs.append(a + "-" + b) + self.kd_equivalents.update({self.symbol_table[a]: + self.symbol_table[b]}) + self.kd_equivalents2.update({self.symbol_table[b]: + self.symbol_table[a]}) + + if a in self.symbol_table.keys() and a in self.type.keys() and self.type[a] in ("variable", "motionvariable"): + self.explicit.update({self.symbol_table[a]: self.getValue(ctx.expr())}) + + else: + if ctx.expr() in self.matrix_expr: + self.type.update({a: "matrix"}) + + try: + b = self.symbol_table[a] + except KeyError: + self.symbol_table[a] = a + + if "_" in a and a.count("_") == 1: + e1, e2 = a.split('_') + if e1 in self.type2.keys() and self.type2[e1] in ("frame", "bodies")\ + and e2 in self.type2.keys() and self.type2[e2] in ("frame", "bodies"): + if self.type2[e1] == "bodies": + t1 = "_f" + else: + t1 = "" + if self.type2[e2] == "bodies": + t2 = "_f" + else: + t2 = "" + + self.write(self.symbol_table2[e2] + t2 + ".orient(" + self.symbol_table2[e1] + + t1 + ", 'DCM', " + self.getValue(ctx.expr()) + ")\n") + else: + self.write(self.symbol_table[a] + " " + equals + " " + + self.getValue(ctx.expr()) + "\n") + else: + self.write(self.symbol_table[a] + " " + equals + " " + + self.getValue(ctx.expr()) + "\n") + + def exitIndexAssign(self, ctx): + # Handle assignments of type ID[index] = expr + if ctx.equals().getText() in ["=", "+=", "-=", "*=", "/="]: + equals = ctx.equals().getText() + elif ctx.equals().getText() == ":=": + equals = " = " + elif ctx.equals().getText() == "^=": + equals = "**=" + + text = ctx.ID().getText().lower() + self.type.update({text: "matrix"}) + # Handle assignments of type ID[2] = expr + if ctx.index().getChildCount() == 1: + if ctx.index().getChild(0).getText() == "1": + self.type.update({text: "matrix"}) + self.symbol_table.update({text: text}) + self.write(text + " = " + "_sm.Matrix([[0]])\n") + self.write(text + "[0] = " + self.getValue(ctx.expr()) + "\n") + else: + # m = m.row_insert(m.shape[0], _sm.Matrix([[0]])) + self.write(text + " = " + text + + ".row_insert(" + text + ".shape[0]" + ", " + "_sm.Matrix([[0]])" + ")\n") + self.write(text + "[" + text + ".shape[0]-1" + "] = " + self.getValue(ctx.expr()) + "\n") + + # Handle assignments of type ID[2, 2] = expr + elif ctx.index().getChildCount() == 3: + l = [] + try: + l.append(str(int(self.getValue(ctx.index().getChild(0)))-1)) + except Exception: + l.append(self.getValue(ctx.index().getChild(0)) + "-1") + l.append(",") + try: + l.append(str(int(self.getValue(ctx.index().getChild(2)))-1)) + except Exception: + l.append(self.getValue(ctx.index().getChild(2)) + "-1") + self.write(self.symbol_table[ctx.ID().getText().lower()] + + "[" + "".join(l) + "]" + " " + equals + " " + self.getValue(ctx.expr()) + "\n") + + def exitVecAssign(self, ctx): + # Handle assignments of the type vec = expr + ch = ctx.vec() + vec_text = ch.getText().lower() + + if "_" in ch.ID().getText(): + num = ch.ID().getText().count('_') + + if num == 2: + v1, v2, v3 = ch.ID().getText().lower().split('_') + + if v1 == "p": + if self.type2[v2] == "point": + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + if self.type2[v3] == "point": + e3 = self.symbol_table2[v3] + elif self.type2[v3] == "particle": + e3 = self.symbol_table2[v3] + ".point" + # ab.set_pos(na, la*a.x) + self.write(e3 + ".set_pos(" + e2 + ", " + self.getValue(ctx.expr()) + ")\n") + + elif v1 in ("w", "alf"): + if v1 == "w": + text = ".set_ang_vel(" + elif v1 == "alf": + text = ".set_ang_acc(" + # a.set_ang_vel(n, qad*a.z) + if self.type2[v2] == "bodies": + e2 = self.symbol_table2[v2] + "_f" + else: + e2 = self.symbol_table2[v2] + if self.type2[v3] == "bodies": + e3 = self.symbol_table2[v3] + "_f" + else: + e3 = self.symbol_table2[v3] + self.write(e2 + text + e3 + ", " + self.getValue(ctx.expr()) + ")\n") + + elif v1 in ("v", "a"): + if v1 == "v": + text = ".set_vel(" + elif v1 == "a": + text = ".set_acc(" + if self.type2[v2] == "point": + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + self.write(e2 + text + self.symbol_table2[v3] + + ", " + self.getValue(ctx.expr()) + ")\n") + elif v1 == "i": + if v2 in self.type2.keys() and self.type2[v2] == "bodies": + self.write(self.symbol_table2[v2] + ".inertia = (" + self.getValue(ctx.expr()) + + ", " + self.symbol_table2[v3] + ")\n") + self.inertia_point.update({v2: v3}) + elif v2 in self.type2.keys() and self.type2[v2] == "particle": + self.write(ch.ID().getText().lower() + " = " + self.getValue(ctx.expr()) + "\n") + else: + self.write(ch.ID().getText().lower() + " = " + self.getValue(ctx.expr()) + "\n") + else: + self.write(ch.ID().getText().lower() + " = " + self.getValue(ctx.expr()) + "\n") + + elif num == 1: + v1, v2 = ch.ID().getText().lower().split('_') + + if v1 in ("force", "torque"): + if self.type2[v2] in ("point", "frame"): + e2 = self.symbol_table2[v2] + elif self.type2[v2] == "particle": + e2 = self.symbol_table2[v2] + ".point" + self.symbol_table.update({vec_text: ch.ID().getText().lower()}) + + if e2 in self.forces.keys(): + self.forces[e2] = self.forces[e2] + " + " + self.getValue(ctx.expr()) + else: + self.forces.update({e2: self.getValue(ctx.expr())}) + self.write(ch.ID().getText().lower() + " = " + self.forces[e2] + "\n") + + else: + name = ch.ID().getText().lower() + self.symbol_table.update({vec_text: name}) + self.write(ch.ID().getText().lower() + " = " + self.getValue(ctx.expr()) + "\n") + else: + name = ch.ID().getText().lower() + self.symbol_table.update({vec_text: name}) + self.write(name + " " + ctx.getChild(1).getText() + " " + self.getValue(ctx.expr()) + "\n") + else: + name = ch.ID().getText().lower() + self.symbol_table.update({vec_text: name}) + self.write(name + " " + ctx.getChild(1).getText() + " " + self.getValue(ctx.expr()) + "\n") + + def enterInputs2(self, ctx): + self.in_inputs = True + + # Inputs + def exitInputs2(self, ctx): + # Stores numerical values given by the input command which + # are used for codegen and numerical analysis. + if ctx.getChildCount() == 3: + try: + self.inputs.update({self.symbol_table[ctx.id_diff().getText().lower()]: self.getValue(ctx.expr(0))}) + except Exception: + self.inputs.update({ctx.id_diff().getText().lower(): self.getValue(ctx.expr(0))}) + elif ctx.getChildCount() == 4: + try: + self.inputs.update({self.symbol_table[ctx.id_diff().getText().lower()]: + (self.getValue(ctx.expr(0)), self.getValue(ctx.expr(1)))}) + except Exception: + self.inputs.update({ctx.id_diff().getText().lower(): + (self.getValue(ctx.expr(0)), self.getValue(ctx.expr(1)))}) + + self.in_inputs = False + + def enterOutputs(self, ctx): + self.in_outputs = True + def exitOutputs(self, ctx): + self.in_outputs = False + + def exitOutputs2(self, ctx): + try: + if "[" in ctx.expr(1).getText(): + self.outputs.append(self.symbol_table[ctx.expr(0).getText().lower()] + + ctx.expr(1).getText().lower()) + else: + self.outputs.append(self.symbol_table[ctx.expr(0).getText().lower()]) + + except Exception: + pass + + # Code commands + def exitCodegen(self, ctx): + # Handles the CODE() command ie the solvers and the codgen part. + # Uses linsolve for the algebraic solvers and nsolve for non linear solvers. + + if ctx.functionCall().getChild(0).getText().lower() == "algebraic": + matrix_name = self.getValue(ctx.functionCall().expr(0)) + e = [] + d = [] + for i in range(1, (ctx.functionCall().getChildCount()-2)//2): + a = self.getValue(ctx.functionCall().expr(i)) + e.append(a) + + for i in self.inputs.keys(): + d.append(i + ":" + self.inputs[i]) + self.write(matrix_name + "_list" + " = " + "[]\n") + self.write("for i in " + matrix_name + ": " + matrix_name + + "_list" + ".append(i.subs({" + ", ".join(d) + "}))\n") + self.write("print(_sm.linsolve(" + matrix_name + "_list" + ", " + ",".join(e) + "))\n") + + elif ctx.functionCall().getChild(0).getText().lower() == "nonlinear": + e = [] + d = [] + guess = [] + for i in range(1, (ctx.functionCall().getChildCount()-2)//2): + a = self.getValue(ctx.functionCall().expr(i)) + e.append(a) + #print(self.inputs) + for i in self.inputs.keys(): + if i in self.symbol_table.keys(): + if type(self.inputs[i]) is tuple: + j, z = self.inputs[i] + else: + j = self.inputs[i] + z = "" + if i not in e: + if z == "deg": + d.append(i + ":" + "_np.deg2rad(" + j + ")") + else: + d.append(i + ":" + j) + else: + if z == "deg": + guess.append("_np.deg2rad(" + j + ")") + else: + guess.append(j) + + self.write("matrix_list" + " = " + "[]\n") + self.write("for i in " + self.getValue(ctx.functionCall().expr(0)) + ":") + self.write("matrix_list" + ".append(i.subs({" + ", ".join(d) + "}))\n") + self.write("print(_sm.nsolve(matrix_list," + "(" + ",".join(e) + ")" + + ",(" + ",".join(guess) + ")" + "))\n") + + elif ctx.functionCall().getChild(0).getText().lower() in ["ode", "dynamics"] and self.include_numeric: + if self.kane_type == "no_args": + for i in self.symbol_table.keys(): + try: + if self.type[i] == "constants" or self.type[self.symbol_table[i]] == "constants": + self.constants.append(self.symbol_table[i]) + except Exception: + pass + q_add_u = self.q_ind + self.q_dep + self.u_ind + self.u_dep + x0 = [] + for i in q_add_u: + try: + if i in self.inputs.keys(): + if type(self.inputs[i]) is tuple: + if self.inputs[i][1] == "deg": + x0.append(i + ":" + "_np.deg2rad(" + self.inputs[i][0] + ")") + else: + x0.append(i + ":" + self.inputs[i][0]) + else: + x0.append(i + ":" + self.inputs[i]) + elif self.kd_equivalents[i] in self.inputs.keys(): + if type(self.inputs[self.kd_equivalents[i]]) is tuple: + x0.append(i + ":" + self.inputs[self.kd_equivalents[i]][0]) + else: + x0.append(i + ":" + self.inputs[self.kd_equivalents[i]]) + except Exception: + pass + + # numerical constants + numerical_constants = [] + for i in self.constants: + if i in self.inputs.keys(): + if type(self.inputs[i]) is tuple: + numerical_constants.append(self.inputs[i][0]) + else: + numerical_constants.append(self.inputs[i]) + + # t = linspace + t_final = self.inputs["tfinal"] + integ_stp = self.inputs["integstp"] + + self.write("from pydy.system import System\n") + const_list = [] + if numerical_constants: + for i in range(len(self.constants)): + const_list.append(self.constants[i] + ":" + numerical_constants[i]) + specifieds = [] + if self.t: + specifieds.append("_me.dynamicsymbols('t')" + ":" + "lambda x, t: t") + + for i in self.inputs: + if i in self.symbol_table.keys() and self.symbol_table[i] not in\ + self.constants + self.q_ind + self.q_dep + self.u_ind + self.u_dep: + specifieds.append(self.symbol_table[i] + ":" + self.inputs[i]) + + self.write("sys = System(kane, constants = {" + ", ".join(const_list) + "},\n" + + "specifieds={" + ", ".join(specifieds) + "},\n" + + "initial_conditions={" + ", ".join(x0) + "},\n" + + "times = _np.linspace(0.0, " + str(t_final) + ", " + str(t_final) + + "/" + str(integ_stp) + "))\n\ny=sys.integrate()\n") + + # For outputs other than qs and us. + other_outputs = [] + for i in self.outputs: + if i not in q_add_u: + if "[" in i: + other_outputs.append((i[:-3] + i[-2], i[:-3] + "[" + str(int(i[-2])-1) + "]")) + else: + other_outputs.append((i, i)) + + for i in other_outputs: + self.write(i[0] + "_out" + " = " + "[]\n") + if other_outputs: + self.write("for i in y:\n") + self.write(" q_u_dict = dict(zip(sys.coordinates+sys.speeds, i))\n") + for i in other_outputs: + self.write(" "*4 + i[0] + "_out" + ".append(" + i[1] + ".subs(q_u_dict)" + + ".subs(sys.constants).evalf())\n") + + # Standalone function calls (used for dual functions) + def exitFunctionCall(self, ctx): + # Basically deals with standalone function calls ie functions which are not a part of + # expressions and assignments. Autolev Dual functions can both appear in standalone + # function calls and also on the right hand side as part of expr or assignment. + + # Dual functions are indicated by a * in the comments below + + # Checks if the function is a statement on its own + if ctx.parentCtx.getRuleIndex() == AutolevParser.RULE_stat: + func_name = ctx.getChild(0).getText().lower() + # Expand(E, n:m) * + if func_name == "expand": + # If the first argument is a pre declared variable. + expr = self.getValue(ctx.expr(0)) + symbol = self.symbol_table[ctx.expr(0).getText().lower()] + if ctx.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.write(symbol + " = " + "_sm.Matrix([i.expand() for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])\n") + else: + self.write(symbol + " = " + symbol + "." + "expand()\n") + + # Factor(E, x) * + elif func_name == "factor": + expr = self.getValue(ctx.expr(0)) + symbol = self.symbol_table[ctx.expr(0).getText().lower()] + if ctx.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.write(symbol + " = " + "_sm.Matrix([_sm.factor(i," + self.getValue(ctx.expr(1)) + + ") for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])\n") + else: + self.write(expr + " = " + "_sm.factor(" + expr + ", " + + self.getValue(ctx.expr(1)) + ")\n") + + # Solve(Zero, x, y) + elif func_name == "solve": + l = [] + l2 = [] + num = 0 + for i in range(1, ctx.getChildCount()): + if ctx.getChild(i).getText() == ",": + num+=1 + try: + l.append(self.getValue(ctx.getChild(i))) + except Exception: + l.append(ctx.getChild(i).getText()) + + if i != 2: + try: + l2.append(self.getValue(ctx.getChild(i))) + except Exception: + pass + + for i in l2: + self.explicit.update({i: "_sm.solve" + "".join(l) + "[" + i + "]"}) + + self.write("print(_sm.solve" + "".join(l) + ")\n") + + # Arrange(y, n, x) * + elif func_name == "arrange": + expr = self.getValue(ctx.expr(0)) + symbol = self.symbol_table[ctx.expr(0).getText().lower()] + + if ctx.expr(0) in self.matrix_expr or (expr in self.type.keys() and self.type[expr] == "matrix"): + self.write(symbol + " = " + "_sm.Matrix([i.collect(" + self.getValue(ctx.expr(2)) + + ")" + "for i in " + expr + "])" + + ".reshape((" + expr + ").shape[0], " + "(" + expr + ").shape[1])\n") + else: + self.write(self.getValue(ctx.expr(0)) + ".collect(" + + self.getValue(ctx.expr(2)) + ")\n") + + # Eig(M, EigenValue, EigenVec) + elif func_name == "eig": + self.symbol_table.update({ctx.expr(1).getText().lower(): ctx.expr(1).getText().lower()}) + self.symbol_table.update({ctx.expr(2).getText().lower(): ctx.expr(2).getText().lower()}) + # _sm.Matrix([i.evalf() for i in (i_s_so).eigenvals().keys()]) + self.write(ctx.expr(1).getText().lower() + " = " + + "_sm.Matrix([i.evalf() for i in " + + "(" + self.getValue(ctx.expr(0)) + ")" + ".eigenvals().keys()])\n") + # _sm.Matrix([i[2][0].evalf() for i in (i_s_o).eigenvects()]).reshape(i_s_o.shape[0], i_s_o.shape[1]) + self.write(ctx.expr(2).getText().lower() + " = " + + "_sm.Matrix([i[2][0].evalf() for i in " + "(" + self.getValue(ctx.expr(0)) + ")" + + ".eigenvects()]).reshape(" + self.getValue(ctx.expr(0)) + ".shape[0], " + + self.getValue(ctx.expr(0)) + ".shape[1])\n") + + # Simprot(N, A, 3, qA) + elif func_name == "simprot": + # A.orient(N, 'Axis', qA, N.z) + if self.type2[ctx.expr(0).getText().lower()] == "frame": + frame1 = self.symbol_table2[ctx.expr(0).getText().lower()] + elif self.type2[ctx.expr(0).getText().lower()] == "bodies": + frame1 = self.symbol_table2[ctx.expr(0).getText().lower()] + "_f" + if self.type2[ctx.expr(1).getText().lower()] == "frame": + frame2 = self.symbol_table2[ctx.expr(1).getText().lower()] + elif self.type2[ctx.expr(1).getText().lower()] == "bodies": + frame2 = self.symbol_table2[ctx.expr(1).getText().lower()] + "_f" + e2 = "" + if ctx.expr(2).getText()[0] == "-": + e2 = "-1*" + if ctx.expr(2).getText() in ("1", "-1"): + e = frame1 + ".x" + elif ctx.expr(2).getText() in ("2", "-2"): + e = frame1 + ".y" + elif ctx.expr(2).getText() in ("3", "-3"): + e = frame1 + ".z" + else: + e = self.getValue(ctx.expr(2)) + e2 = "" + + if "degrees" in self.settings.keys() and self.settings["degrees"] == "off": + value = self.getValue(ctx.expr(3)) + else: + if ctx.expr(3) in self.numeric_expr: + value = "_np.deg2rad(" + self.getValue(ctx.expr(3)) + ")" + else: + value = self.getValue(ctx.expr(3)) + self.write(frame2 + ".orient(" + frame1 + + ", " + "'Axis'" + ", " + "[" + value + + ", " + e2 + e + "]" + ")\n") + + # Express(A2>, B) * + elif func_name == "express": + if self.type2[ctx.expr(1).getText().lower()] == "bodies": + f = "_f" + else: + f = "" + + if '_' in ctx.expr(0).getText().lower() and ctx.expr(0).getText().count('_') == 2: + vec = ctx.expr(0).getText().lower().replace(">", "").split('_') + v1 = self.symbol_table2[vec[1]] + v2 = self.symbol_table2[vec[2]] + if vec[0] == "p": + self.write(v2 + ".set_pos(" + v1 + ", " + "(" + self.getValue(ctx.expr(0)) + + ")" + ".express(" + self.symbol_table2[ctx.expr(1).getText().lower()] + f + "))\n") + elif vec[0] == "v": + self.write(v1 + ".set_vel(" + v2 + ", " + "(" + self.getValue(ctx.expr(0)) + + ")" + ".express(" + self.symbol_table2[ctx.expr(1).getText().lower()] + f + "))\n") + elif vec[0] == "a": + self.write(v1 + ".set_acc(" + v2 + ", " + "(" + self.getValue(ctx.expr(0)) + + ")" + ".express(" + self.symbol_table2[ctx.expr(1).getText().lower()] + f + "))\n") + else: + self.write(self.getValue(ctx.expr(0)) + " = " + "(" + self.getValue(ctx.expr(0)) + ")" + ".express(" + + self.symbol_table2[ctx.expr(1).getText().lower()] + f + ")\n") + else: + self.write(self.getValue(ctx.expr(0)) + " = " + "(" + self.getValue(ctx.expr(0)) + ")" + ".express(" + + self.symbol_table2[ctx.expr(1).getText().lower()] + f + ")\n") + + # Angvel(A, B) + elif func_name == "angvel": + self.write("print(" + self.symbol_table2[ctx.expr(1).getText().lower()] + + ".ang_vel_in(" + self.symbol_table2[ctx.expr(0).getText().lower()] + "))\n") + + # v2pts(N, A, O, P) + elif func_name in ("v2pts", "a2pts", "v2pt", "a1pt"): + if func_name == "v2pts": + text = ".v2pt_theory(" + elif func_name == "a2pts": + text = ".a2pt_theory(" + elif func_name == "v1pt": + text = ".v1pt_theory(" + elif func_name == "a1pt": + text = ".a1pt_theory(" + if self.type2[ctx.expr(1).getText().lower()] == "frame": + frame = self.symbol_table2[ctx.expr(1).getText().lower()] + elif self.type2[ctx.expr(1).getText().lower()] == "bodies": + frame = self.symbol_table2[ctx.expr(1).getText().lower()] + "_f" + expr_list = [] + for i in range(2, 4): + if self.type2[ctx.expr(i).getText().lower()] == "point": + expr_list.append(self.symbol_table2[ctx.expr(i).getText().lower()]) + elif self.type2[ctx.expr(i).getText().lower()] == "particle": + expr_list.append(self.symbol_table2[ctx.expr(i).getText().lower()] + ".point") + + self.write(expr_list[1] + text + expr_list[0] + + "," + self.symbol_table2[ctx.expr(0).getText().lower()] + "," + + frame + ")\n") + + # Gravity(g*N1>) + elif func_name == "gravity": + for i in self.bodies.keys(): + if self.type2[i] == "bodies": + e = self.symbol_table2[i] + ".masscenter" + elif self.type2[i] == "particle": + e = self.symbol_table2[i] + ".point" + if e in self.forces.keys(): + self.forces[e] = self.forces[e] + self.symbol_table2[i] +\ + ".mass*(" + self.getValue(ctx.expr(0)) + ")" + else: + self.forces.update({e: self.symbol_table2[i] + + ".mass*(" + self.getValue(ctx.expr(0)) + ")"}) + self.write("force_" + i + " = " + self.forces[e] + "\n") + + # Explicit(EXPRESS(IMPLICIT>,C)) + elif func_name == "explicit": + if ctx.expr(0) in self.vector_expr: + self.vector_expr.append(ctx) + expr = self.getValue(ctx.expr(0)) + if self.explicit.keys(): + explicit_list = [] + for i in self.explicit.keys(): + explicit_list.append(i + ":" + self.explicit[i]) + if '_' in ctx.expr(0).getText().lower() and ctx.expr(0).getText().count('_') == 2: + vec = ctx.expr(0).getText().lower().replace(">", "").split('_') + v1 = self.symbol_table2[vec[1]] + v2 = self.symbol_table2[vec[2]] + if vec[0] == "p": + self.write(v2 + ".set_pos(" + v1 + ", " + "(" + expr + + ")" + ".subs({" + ", ".join(explicit_list) + "}))\n") + elif vec[0] == "v": + self.write(v2 + ".set_vel(" + v1 + ", " + "(" + expr + + ")" + ".subs({" + ", ".join(explicit_list) + "}))\n") + elif vec[0] == "a": + self.write(v2 + ".set_acc(" + v1 + ", " + "(" + expr + + ")" + ".subs({" + ", ".join(explicit_list) + "}))\n") + else: + self.write(expr + " = " + "(" + expr + ")" + ".subs({" + ", ".join(explicit_list) + "})\n") + else: + self.write(expr + " = " + "(" + expr + ")" + ".subs({" + ", ".join(explicit_list) + "})\n") + + # Force(O/Q, -k*Stretch*Uvec>) + elif func_name in ("force", "torque"): + + if "/" in ctx.expr(0).getText().lower(): + p1 = ctx.expr(0).getText().lower().split('/')[0] + p2 = ctx.expr(0).getText().lower().split('/')[1] + if self.type2[p1] in ("point", "frame"): + pt1 = self.symbol_table2[p1] + elif self.type2[p1] == "particle": + pt1 = self.symbol_table2[p1] + ".point" + if self.type2[p2] in ("point", "frame"): + pt2 = self.symbol_table2[p2] + elif self.type2[p2] == "particle": + pt2 = self.symbol_table2[p2] + ".point" + if pt1 in self.forces.keys(): + self.forces[pt1] = self.forces[pt1] + " + -1*("+self.getValue(ctx.expr(1)) + ")" + self.write("force_" + p1 + " = " + self.forces[pt1] + "\n") + else: + self.forces.update({pt1: "-1*("+self.getValue(ctx.expr(1)) + ")"}) + self.write("force_" + p1 + " = " + self.forces[pt1] + "\n") + if pt2 in self.forces.keys(): + self.forces[pt2] = self.forces[pt2] + "+ " + self.getValue(ctx.expr(1)) + self.write("force_" + p2 + " = " + self.forces[pt2] + "\n") + else: + self.forces.update({pt2: self.getValue(ctx.expr(1))}) + self.write("force_" + p2 + " = " + self.forces[pt2] + "\n") + + elif ctx.expr(0).getChildCount() == 1: + p1 = ctx.expr(0).getText().lower() + if self.type2[p1] in ("point", "frame"): + pt1 = self.symbol_table2[p1] + elif self.type2[p1] == "particle": + pt1 = self.symbol_table2[p1] + ".point" + if pt1 in self.forces.keys(): + self.forces[pt1] = self.forces[pt1] + "+ -1*(" + self.getValue(ctx.expr(1)) + ")" + else: + self.forces.update({pt1: "-1*(" + self.getValue(ctx.expr(1)) + ")"}) + + # Constrain(Dependent[qB]) + elif func_name == "constrain": + if ctx.getChild(2).getChild(0).getText().lower() == "dependent": + self.write("velocity_constraints = [i for i in dependent]\n") + x = (ctx.expr(0).getChildCount()-2)//2 + for i in range(x): + self.dependent_variables.append(self.getValue(ctx.expr(0).expr(i))) + + # Kane() + elif func_name == "kane": + if ctx.getChildCount() == 3: + self.kane_type = "no_args" + + # Settings + def exitSettings(self, ctx): + # Stores settings like Complex on/off, Degrees on/off etc in self.settings. + try: + self.settings.update({ctx.getChild(0).getText().lower(): + ctx.getChild(1).getText().lower()}) + except Exception: + pass + + def exitMassDecl2(self, ctx): + # Used for declaring the masses of particles and rigidbodies. + particle = self.symbol_table2[ctx.getChild(0).getText().lower()] + if ctx.getText().count("=") == 2: + if ctx.expr().expr(1) in self.numeric_expr: + e = "_sm.S(" + self.getValue(ctx.expr().expr(1)) + ")" + else: + e = self.getValue(ctx.expr().expr(1)) + self.symbol_table.update({ctx.expr().expr(0).getText().lower(): ctx.expr().expr(0).getText().lower()}) + self.write(ctx.expr().expr(0).getText().lower() + " = " + e + "\n") + mass = ctx.expr().expr(0).getText().lower() + else: + try: + if ctx.expr() in self.numeric_expr: + mass = "_sm.S(" + self.getValue(ctx.expr()) + ")" + else: + mass = self.getValue(ctx.expr()) + except Exception: + a_text = ctx.expr().getText().lower() + self.symbol_table.update({a_text: a_text}) + self.type.update({a_text: "constants"}) + self.write(a_text + " = " + "_sm.symbols('" + a_text + "')\n") + mass = a_text + + self.write(particle + ".mass = " + mass + "\n") + + def exitInertiaDecl(self, ctx): + inertia_list = [] + try: + ctx.ID(1).getText() + num = 5 + except Exception: + num = 2 + for i in range((ctx.getChildCount()-num)//2): + try: + if ctx.expr(i) in self.numeric_expr: + inertia_list.append("_sm.S(" + self.getValue(ctx.expr(i)) + ")") + else: + inertia_list.append(self.getValue(ctx.expr(i))) + except Exception: + a_text = ctx.expr(i).getText().lower() + self.symbol_table.update({a_text: a_text}) + self.type.update({a_text: "constants"}) + self.write(a_text + " = " + "_sm.symbols('" + a_text + "')\n") + inertia_list.append(a_text) + + if len(inertia_list) < 6: + for i in range(6-len(inertia_list)): + inertia_list.append("0") + # body_a.inertia = (_me.inertia(body_a, I1, I2, I3, 0, 0, 0), body_a_cm) + try: + frame = self.symbol_table2[ctx.ID(1).getText().lower()] + point = self.symbol_table2[ctx.ID(0).getText().lower().split('_')[1]] + body = self.symbol_table2[ctx.ID(0).getText().lower().split('_')[0]] + self.inertia_point.update({ctx.ID(0).getText().lower().split('_')[0] + : ctx.ID(0).getText().lower().split('_')[1]}) + self.write(body + ".inertia" + " = " + "(_me.inertia(" + frame + ", " + + ", ".join(inertia_list) + "), " + point + ")\n") + + except Exception: + body_name = self.symbol_table2[ctx.ID(0).getText().lower()] + body_name_cm = body_name + "_cm" + self.inertia_point.update({ctx.ID(0).getText().lower(): ctx.ID(0).getText().lower() + "o"}) + self.write(body_name + ".inertia" + " = " + "(_me.inertia(" + body_name + "_f" + ", " + + ", ".join(inertia_list) + "), " + body_name_cm + ")\n") diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/_parse_autolev_antlr.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/_parse_autolev_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..e43924aac30903ade996b31921d3960afae90284 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/_parse_autolev_antlr.py @@ -0,0 +1,38 @@ +from importlib.metadata import version +from sympy.external import import_module + + +autolevparser = import_module('sympy.parsing.autolev._antlr.autolevparser', + import_kwargs={'fromlist': ['AutolevParser']}) +autolevlexer = import_module('sympy.parsing.autolev._antlr.autolevlexer', + import_kwargs={'fromlist': ['AutolevLexer']}) +autolevlistener = import_module('sympy.parsing.autolev._antlr.autolevlistener', + import_kwargs={'fromlist': ['AutolevListener']}) + +AutolevParser = getattr(autolevparser, 'AutolevParser', None) +AutolevLexer = getattr(autolevlexer, 'AutolevLexer', None) +AutolevListener = getattr(autolevlistener, 'AutolevListener', None) + + +def parse_autolev(autolev_code, include_numeric): + antlr4 = import_module('antlr4') + if not antlr4 or not version('antlr4-python3-runtime').startswith('4.11'): + raise ImportError("Autolev parsing requires the antlr4 Python package," + " provided by pip (antlr4-python3-runtime)" + " conda (antlr-python-runtime), version 4.11") + try: + l = autolev_code.readlines() + input_stream = antlr4.InputStream("".join(l)) + except Exception: + input_stream = antlr4.InputStream(autolev_code) + + if AutolevListener: + from ._listener_autolev_antlr import MyListener + lexer = AutolevLexer(input_stream) + token_stream = antlr4.CommonTokenStream(lexer) + parser = AutolevParser(token_stream) + tree = parser.prog() + my_listener = MyListener(include_numeric) + walker = antlr4.ParseTreeWalker() + walker.walk(my_listener, tree) + return "".join(my_listener.output_code) diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/README.txt b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/README.txt new file mode 100644 index 0000000000000000000000000000000000000000..946b006bac33544fadd2dc6d24c22240c8fbc8e4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/README.txt @@ -0,0 +1,9 @@ +# parsing/tests/test_autolev.py uses the .al files in this directory as inputs and checks +# the equivalence of the parser generated codes and the respective .py files. + +# By default, this directory contains tests for all rules of the parser. + +# Additional tests consisting of full physics examples shall be made available soon in +# the form of another repository. One shall be able to copy the contents of that repo +# to this folder and use those tests after uncommenting the respective code in +# parsing/tests/test_autolev.py. diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest1.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest1.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ab7dd97b953ab730e1ac26c79ed345bdeca735f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest1.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest10.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest10.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe6593e9ccc425797c1ee36f63998d4b17994def Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest10.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest11.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest11.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c0ec1dca46744255e487da38d80a9ea20dca22c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest11.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest12.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest12.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6659b0c276e86a1f9be9bb45a9e6850279e3a04d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest12.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest2.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f70b9dcecbd1590f750cffa6b4823839b83e8c6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest2.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest3.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ee6315e951d6d04e8a39e1c71199d409902a484 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest3.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest4.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest4.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f64a6bd22ea7351ff94c88d6bb21ec822f585a5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest4.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest5.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest5.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2f32143e25f3b48e083b8882a78b8ca95a72c75 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest5.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest6.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest6.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5630ed2e5be7738dface534e3274c6e18bf04a13 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest6.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest7.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest7.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67d1f09aa1300d154a2d0733d9c10d010cba055d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest7.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest8.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest8.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe97e59522029fdecb8e63390422519e516bd8e0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest8.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest9.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest9.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..106d0144fc5ad6457c4eebea20b1c8638387618b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/__pycache__/ruletest9.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/chaos_pendulum.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/chaos_pendulum.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c0a946bdb7af4a76d73329d3ce4a632ed9a3d98 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/chaos_pendulum.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/double_pendulum.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/double_pendulum.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20a875e99219309a73e570cff93d99e77cf0dd06 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/double_pendulum.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/mass_spring_damper.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/mass_spring_damper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..208c1f041dbdcb0fd22877a3efcbfe0f329ddebf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/mass_spring_damper.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/non_min_pendulum.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/non_min_pendulum.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..139f5ba80767bc36088d67790b97d4bbc6ff0c2d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/__pycache__/non_min_pendulum.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.al new file mode 100644 index 0000000000000000000000000000000000000000..3bbb4d51b853bfd759df38d666a42adc1cbea190 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.al @@ -0,0 +1,33 @@ +CONSTANTS G,LB,W,H +MOTIONVARIABLES' THETA'',PHI'',OMEGA',ALPHA' +NEWTONIAN N +BODIES A,B +SIMPROT(N,A,2,THETA) +SIMPROT(A,B,3,PHI) +POINT O +LA = (LB-H/2)/2 +P_O_AO> = LA*A3> +P_O_BO> = LB*A3> +OMEGA = THETA' +ALPHA = PHI' +W_A_N> = OMEGA*N2> +W_B_A> = ALPHA*A3> +V_O_N> = 0> +V2PTS(N, A, O, AO) +V2PTS(N, A, O, BO) +MASS A=MA, B=MB +IAXX = 1/12*MA*(2*LA)^2 +IAYY = IAXX +IAZZ = 0 +IBXX = 1/12*MB*H^2 +IBYY = 1/12*MB*(W^2+H^2) +IBZZ = 1/12*MB*W^2 +INERTIA A, IAXX, IAYY, IAZZ +INERTIA B, IBXX, IBYY, IBZZ +GRAVITY(G*N3>) +ZERO = FR() + FRSTAR() +KANE() +INPUT LB=0.2,H=0.1,W=0.2,MA=0.01,MB=0.1,G=9.81 +INPUT THETA = 90 DEG, PHI = 0.5 DEG, OMEGA=0, ALPHA=0 +INPUT TFINAL=10, INTEGSTP=0.02 +CODE DYNAMICS() some_filename.c diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.py new file mode 100644 index 0000000000000000000000000000000000000000..4435635720bb38f40366f55bb3ace0f6f6899284 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/chaos_pendulum.py @@ -0,0 +1,55 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +g, lb, w, h = _sm.symbols('g lb w h', real=True) +theta, phi, omega, alpha = _me.dynamicsymbols('theta phi omega alpha') +theta_d, phi_d, omega_d, alpha_d = _me.dynamicsymbols('theta_ phi_ omega_ alpha_', 1) +theta_dd, phi_dd = _me.dynamicsymbols('theta_ phi_', 2) +frame_n = _me.ReferenceFrame('n') +body_a_cm = _me.Point('a_cm') +body_a_cm.set_vel(frame_n, 0) +body_a_f = _me.ReferenceFrame('a_f') +body_a = _me.RigidBody('a', body_a_cm, body_a_f, _sm.symbols('m'), (_me.outer(body_a_f.x,body_a_f.x),body_a_cm)) +body_b_cm = _me.Point('b_cm') +body_b_cm.set_vel(frame_n, 0) +body_b_f = _me.ReferenceFrame('b_f') +body_b = _me.RigidBody('b', body_b_cm, body_b_f, _sm.symbols('m'), (_me.outer(body_b_f.x,body_b_f.x),body_b_cm)) +body_a_f.orient(frame_n, 'Axis', [theta, frame_n.y]) +body_b_f.orient(body_a_f, 'Axis', [phi, body_a_f.z]) +point_o = _me.Point('o') +la = (lb-h/2)/2 +body_a_cm.set_pos(point_o, la*body_a_f.z) +body_b_cm.set_pos(point_o, lb*body_a_f.z) +body_a_f.set_ang_vel(frame_n, omega*frame_n.y) +body_b_f.set_ang_vel(body_a_f, alpha*body_a_f.z) +point_o.set_vel(frame_n, 0) +body_a_cm.v2pt_theory(point_o,frame_n,body_a_f) +body_b_cm.v2pt_theory(point_o,frame_n,body_a_f) +ma = _sm.symbols('ma') +body_a.mass = ma +mb = _sm.symbols('mb') +body_b.mass = mb +iaxx = 1/12*ma*(2*la)**2 +iayy = iaxx +iazz = 0 +ibxx = 1/12*mb*h**2 +ibyy = 1/12*mb*(w**2+h**2) +ibzz = 1/12*mb*w**2 +body_a.inertia = (_me.inertia(body_a_f, iaxx, iayy, iazz, 0, 0, 0), body_a_cm) +body_b.inertia = (_me.inertia(body_b_f, ibxx, ibyy, ibzz, 0, 0, 0), body_b_cm) +force_a = body_a.mass*(g*frame_n.z) +force_b = body_b.mass*(g*frame_n.z) +kd_eqs = [theta_d - omega, phi_d - alpha] +forceList = [(body_a.masscenter,body_a.mass*(g*frame_n.z)), (body_b.masscenter,body_b.mass*(g*frame_n.z))] +kane = _me.KanesMethod(frame_n, q_ind=[theta,phi], u_ind=[omega, alpha], kd_eqs = kd_eqs) +fr, frstar = kane.kanes_equations([body_a, body_b], forceList) +zero = fr+frstar +from pydy.system import System +sys = System(kane, constants = {g:9.81, lb:0.2, w:0.2, h:0.1, ma:0.01, mb:0.1}, +specifieds={}, +initial_conditions={theta:_np.deg2rad(90), phi:_np.deg2rad(0.5), omega:0, alpha:0}, +times = _np.linspace(0.0, 10, 10/0.02)) + +y=sys.integrate() diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.al new file mode 100644 index 0000000000000000000000000000000000000000..0b6d72a072e093a6cb048a0b7976041ee9c2f4f3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.al @@ -0,0 +1,25 @@ +MOTIONVARIABLES' Q{2}', U{2}' +CONSTANTS L,M,G +NEWTONIAN N +FRAMES A,B +SIMPROT(N, A, 3, Q1) +SIMPROT(N, B, 3, Q2) +W_A_N>=U1*N3> +W_B_N>=U2*N3> +POINT O +PARTICLES P,R +P_O_P> = L*A1> +P_P_R> = L*B1> +V_O_N> = 0> +V2PTS(N, A, O, P) +V2PTS(N, B, P, R) +MASS P=M, R=M +Q1' = U1 +Q2' = U2 +GRAVITY(G*N1>) +ZERO = FR() + FRSTAR() +KANE() +INPUT M=1,G=9.81,L=1 +INPUT Q1=.1,Q2=.2,U1=0,U2=0 +INPUT TFINAL=10, INTEGSTP=.01 +CODE DYNAMICS() some_filename.c diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.py new file mode 100644 index 0000000000000000000000000000000000000000..12c73c3b4b198399f4c45f5e00d556c859caff74 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/double_pendulum.py @@ -0,0 +1,39 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +q1, q2, u1, u2 = _me.dynamicsymbols('q1 q2 u1 u2') +q1_d, q2_d, u1_d, u2_d = _me.dynamicsymbols('q1_ q2_ u1_ u2_', 1) +l, m, g = _sm.symbols('l m g', real=True) +frame_n = _me.ReferenceFrame('n') +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +frame_a.orient(frame_n, 'Axis', [q1, frame_n.z]) +frame_b.orient(frame_n, 'Axis', [q2, frame_n.z]) +frame_a.set_ang_vel(frame_n, u1*frame_n.z) +frame_b.set_ang_vel(frame_n, u2*frame_n.z) +point_o = _me.Point('o') +particle_p = _me.Particle('p', _me.Point('p_pt'), _sm.Symbol('m')) +particle_r = _me.Particle('r', _me.Point('r_pt'), _sm.Symbol('m')) +particle_p.point.set_pos(point_o, l*frame_a.x) +particle_r.point.set_pos(particle_p.point, l*frame_b.x) +point_o.set_vel(frame_n, 0) +particle_p.point.v2pt_theory(point_o,frame_n,frame_a) +particle_r.point.v2pt_theory(particle_p.point,frame_n,frame_b) +particle_p.mass = m +particle_r.mass = m +force_p = particle_p.mass*(g*frame_n.x) +force_r = particle_r.mass*(g*frame_n.x) +kd_eqs = [q1_d - u1, q2_d - u2] +forceList = [(particle_p.point,particle_p.mass*(g*frame_n.x)), (particle_r.point,particle_r.mass*(g*frame_n.x))] +kane = _me.KanesMethod(frame_n, q_ind=[q1,q2], u_ind=[u1, u2], kd_eqs = kd_eqs) +fr, frstar = kane.kanes_equations([particle_p, particle_r], forceList) +zero = fr+frstar +from pydy.system import System +sys = System(kane, constants = {l:1, m:1, g:9.81}, +specifieds={}, +initial_conditions={q1:.1, q2:.2, u1:0, u2:0}, +times = _np.linspace(0.0, 10, 10/.01)) + +y=sys.integrate() diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.al new file mode 100644 index 0000000000000000000000000000000000000000..4892e5ca8cb18cad6b14a2a37cbdc1f7fb8217ac --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.al @@ -0,0 +1,19 @@ +CONSTANTS M,K,B,G +MOTIONVARIABLES' POSITION',SPEED' +VARIABLES O +FORCE = O*SIN(T) +NEWTONIAN CEILING +POINTS ORIGIN +V_ORIGIN_CEILING> = 0> +PARTICLES BLOCK +P_ORIGIN_BLOCK> = POSITION*CEILING1> +MASS BLOCK=M +V_BLOCK_CEILING>=SPEED*CEILING1> +POSITION' = SPEED +FORCE_MAGNITUDE = M*G-K*POSITION-B*SPEED+FORCE +FORCE_BLOCK>=EXPLICIT(FORCE_MAGNITUDE*CEILING1>) +ZERO = FR() + FRSTAR() +KANE() +INPUT TFINAL=10.0, INTEGSTP=0.01 +INPUT M=1.0, K=1.0, B=0.2, G=9.8, POSITION=0.1, SPEED=-1.0, O=2 +CODE DYNAMICS() dummy_file.c diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5baab9642ff140e0ee81027a1e8f9152d7050c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/mass_spring_damper.py @@ -0,0 +1,31 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +m, k, b, g = _sm.symbols('m k b g', real=True) +position, speed = _me.dynamicsymbols('position speed') +position_d, speed_d = _me.dynamicsymbols('position_ speed_', 1) +o = _me.dynamicsymbols('o') +force = o*_sm.sin(_me.dynamicsymbols._t) +frame_ceiling = _me.ReferenceFrame('ceiling') +point_origin = _me.Point('origin') +point_origin.set_vel(frame_ceiling, 0) +particle_block = _me.Particle('block', _me.Point('block_pt'), _sm.Symbol('m')) +particle_block.point.set_pos(point_origin, position*frame_ceiling.x) +particle_block.mass = m +particle_block.point.set_vel(frame_ceiling, speed*frame_ceiling.x) +force_magnitude = m*g-k*position-b*speed+force +force_block = (force_magnitude*frame_ceiling.x).subs({position_d:speed}) +kd_eqs = [position_d - speed] +forceList = [(particle_block.point,(force_magnitude*frame_ceiling.x).subs({position_d:speed}))] +kane = _me.KanesMethod(frame_ceiling, q_ind=[position], u_ind=[speed], kd_eqs = kd_eqs) +fr, frstar = kane.kanes_equations([particle_block], forceList) +zero = fr+frstar +from pydy.system import System +sys = System(kane, constants = {m:1.0, k:1.0, b:0.2, g:9.8}, +specifieds={_me.dynamicsymbols('t'):lambda x, t: t, o:2}, +initial_conditions={position:0.1, speed:-1*1.0}, +times = _np.linspace(0.0, 10.0, 10.0/0.01)) + +y=sys.integrate() diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.al new file mode 100644 index 0000000000000000000000000000000000000000..74f5062d80926db7acd634a04759abce857087e5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.al @@ -0,0 +1,20 @@ +MOTIONVARIABLES' Q{2}'' +CONSTANTS L,M,G +NEWTONIAN N +POINT PN +V_PN_N> = 0> +THETA1 = ATAN(Q2/Q1) +FRAMES A +SIMPROT(N, A, 3, THETA1) +PARTICLES P +P_PN_P> = Q1*N1>+Q2*N2> +MASS P=M +V_P_N>=DT(P_P_PN>, N) +F_V = DOT(EXPRESS(V_P_N>,A), A1>) +GRAVITY(G*N1>) +DEPENDENT[1] = F_V +CONSTRAIN(DEPENDENT[Q1']) +ZERO=FR()+FRSTAR() +F_C = MAG(P_P_PN>)-L +CONFIG[1]=F_C +ZERO[2]=CONFIG[1] diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.py new file mode 100644 index 0000000000000000000000000000000000000000..fc972ebd518e77da5e1902c149f2699979865e7f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/pydy-example-repo/non_min_pendulum.py @@ -0,0 +1,36 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +q1, q2 = _me.dynamicsymbols('q1 q2') +q1_d, q2_d = _me.dynamicsymbols('q1_ q2_', 1) +q1_dd, q2_dd = _me.dynamicsymbols('q1_ q2_', 2) +l, m, g = _sm.symbols('l m g', real=True) +frame_n = _me.ReferenceFrame('n') +point_pn = _me.Point('pn') +point_pn.set_vel(frame_n, 0) +theta1 = _sm.atan(q2/q1) +frame_a = _me.ReferenceFrame('a') +frame_a.orient(frame_n, 'Axis', [theta1, frame_n.z]) +particle_p = _me.Particle('p', _me.Point('p_pt'), _sm.Symbol('m')) +particle_p.point.set_pos(point_pn, q1*frame_n.x+q2*frame_n.y) +particle_p.mass = m +particle_p.point.set_vel(frame_n, (point_pn.pos_from(particle_p.point)).dt(frame_n)) +f_v = _me.dot((particle_p.point.vel(frame_n)).express(frame_a), frame_a.x) +force_p = particle_p.mass*(g*frame_n.x) +dependent = _sm.Matrix([[0]]) +dependent[0] = f_v +velocity_constraints = [i for i in dependent] +u_q1_d = _me.dynamicsymbols('u_q1_d') +u_q2_d = _me.dynamicsymbols('u_q2_d') +kd_eqs = [q1_d-u_q1_d, q2_d-u_q2_d] +forceList = [(particle_p.point,particle_p.mass*(g*frame_n.x))] +kane = _me.KanesMethod(frame_n, q_ind=[q1,q2], u_ind=[u_q2_d], u_dependent=[u_q1_d], kd_eqs = kd_eqs, velocity_constraints = velocity_constraints) +fr, frstar = kane.kanes_equations([particle_p], forceList) +zero = fr+frstar +f_c = point_pn.pos_from(particle_p.point).magnitude()-l +config = _sm.Matrix([[0]]) +config[0] = f_c +zero = zero.row_insert(zero.shape[0], _sm.Matrix([[0]])) +zero[zero.shape[0]-1] = config[0] diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest1.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest1.al new file mode 100644 index 0000000000000000000000000000000000000000..457e79fd646677c0decdc69f921bc05e9e0dcf51 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest1.al @@ -0,0 +1,8 @@ +% ruletest1.al +CONSTANTS F = 3, G = 9.81 +CONSTANTS A, B +CONSTANTS S, S1, S2+, S3+, S4- +CONSTANTS K{4}, L{1:3}, P{1:2,1:3} +CONSTANTS C{2,3} +E1 = A*F + S2 - G +E2 = F^2 + K3*K2*G diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest1.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest1.py new file mode 100644 index 0000000000000000000000000000000000000000..8466392ac930f13f2419c9c04eef9dcc2884e9bd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest1.py @@ -0,0 +1,15 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +f = _sm.S(3) +g = _sm.S(9.81) +a, b = _sm.symbols('a b', real=True) +s, s1 = _sm.symbols('s s1', real=True) +s2, s3 = _sm.symbols('s2 s3', real=True, nonnegative=True) +s4 = _sm.symbols('s4', real=True, nonpositive=True) +k1, k2, k3, k4, l1, l2, l3, p11, p12, p13, p21, p22, p23 = _sm.symbols('k1 k2 k3 k4 l1 l2 l3 p11 p12 p13 p21 p22 p23', real=True) +c11, c12, c13, c21, c22, c23 = _sm.symbols('c11 c12 c13 c21 c22 c23', real=True) +e1 = a*f+s2-g +e2 = f**2+k3*k2*g diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest10.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest10.al new file mode 100644 index 0000000000000000000000000000000000000000..9d5f76f063c43bcb5e2a8d4f29619a6952abf9e5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest10.al @@ -0,0 +1,58 @@ +% ruletest10.al + +VARIABLES X,Y +COMPLEX ON +CONSTANTS A,B +E = A*(B*X+Y)^2 +M = [E;E] +EXPAND(E) +EXPAND(M) +FACTOR(E,X) +FACTOR(M,X) + +EQN[1] = A*X + B*Y +EQN[2] = 2*A*X - 3*B*Y +SOLVE(EQN, X, Y) +RHS_Y = RHS(Y) +E = (X+Y)^2 + 2*X^2 +ARRANGE(E, 2, X) + +CONSTANTS A,B,C +M = [A,B;C,0] +M2 = EVALUATE(M,A=1,B=2,C=3) +EIG(M2, EIGVALUE, EIGVEC) + +NEWTONIAN N +FRAMES A +SIMPROT(N, A, N1>, X) +DEGREES OFF +SIMPROT(N, A, N1>, PI/2) + +CONSTANTS C{3} +V> = C1*A1> + C2*A2> + C3*A3> +POINTS O, P +P_P_O> = C1*A1> +EXPRESS(V>,N) +EXPRESS(P_P_O>,N) +W_A_N> = C3*A3> +ANGVEL(A,N) + +V2PTS(N,A,O,P) +PARTICLES P{2} +V2PTS(N,A,P1,P2) +A2PTS(N,A,P1,P) + +BODIES B{2} +CONSTANT G +GRAVITY(G*N1>) + +VARIABLE Z +V> = X*A1> + Y*A3> +P_P_O> = X*A1> + Y*A2> +X = 2*Z +Y = Z +EXPLICIT(V>) +EXPLICIT(P_P_O>) + +FORCE(O/P1, X*Y*A1>) +FORCE(P2, X*Y*A1>) diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest10.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest10.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9674e47d5f6132c5a79a33b9d8d55a131942d6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest10.py @@ -0,0 +1,64 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +a, b = _sm.symbols('a b', real=True) +e = a*(b*x+y)**2 +m = _sm.Matrix([e,e]).reshape(2, 1) +e = e.expand() +m = _sm.Matrix([i.expand() for i in m]).reshape((m).shape[0], (m).shape[1]) +e = _sm.factor(e, x) +m = _sm.Matrix([_sm.factor(i,x) for i in m]).reshape((m).shape[0], (m).shape[1]) +eqn = _sm.Matrix([[0]]) +eqn[0] = a*x+b*y +eqn = eqn.row_insert(eqn.shape[0], _sm.Matrix([[0]])) +eqn[eqn.shape[0]-1] = 2*a*x-3*b*y +print(_sm.solve(eqn,x,y)) +rhs_y = _sm.solve(eqn,x,y)[y] +e = (x+y)**2+2*x**2 +e.collect(x) +a, b, c = _sm.symbols('a b c', real=True) +m = _sm.Matrix([a,b,c,0]).reshape(2, 2) +m2 = _sm.Matrix([i.subs({a:1,b:2,c:3}) for i in m]).reshape((m).shape[0], (m).shape[1]) +eigvalue = _sm.Matrix([i.evalf() for i in (m2).eigenvals().keys()]) +eigvec = _sm.Matrix([i[2][0].evalf() for i in (m2).eigenvects()]).reshape(m2.shape[0], m2.shape[1]) +frame_n = _me.ReferenceFrame('n') +frame_a = _me.ReferenceFrame('a') +frame_a.orient(frame_n, 'Axis', [x, frame_n.x]) +frame_a.orient(frame_n, 'Axis', [_sm.pi/2, frame_n.x]) +c1, c2, c3 = _sm.symbols('c1 c2 c3', real=True) +v = c1*frame_a.x+c2*frame_a.y+c3*frame_a.z +point_o = _me.Point('o') +point_p = _me.Point('p') +point_o.set_pos(point_p, c1*frame_a.x) +v = (v).express(frame_n) +point_o.set_pos(point_p, (point_o.pos_from(point_p)).express(frame_n)) +frame_a.set_ang_vel(frame_n, c3*frame_a.z) +print(frame_n.ang_vel_in(frame_a)) +point_p.v2pt_theory(point_o,frame_n,frame_a) +particle_p1 = _me.Particle('p1', _me.Point('p1_pt'), _sm.Symbol('m')) +particle_p2 = _me.Particle('p2', _me.Point('p2_pt'), _sm.Symbol('m')) +particle_p2.point.v2pt_theory(particle_p1.point,frame_n,frame_a) +point_p.a2pt_theory(particle_p1.point,frame_n,frame_a) +body_b1_cm = _me.Point('b1_cm') +body_b1_cm.set_vel(frame_n, 0) +body_b1_f = _me.ReferenceFrame('b1_f') +body_b1 = _me.RigidBody('b1', body_b1_cm, body_b1_f, _sm.symbols('m'), (_me.outer(body_b1_f.x,body_b1_f.x),body_b1_cm)) +body_b2_cm = _me.Point('b2_cm') +body_b2_cm.set_vel(frame_n, 0) +body_b2_f = _me.ReferenceFrame('b2_f') +body_b2 = _me.RigidBody('b2', body_b2_cm, body_b2_f, _sm.symbols('m'), (_me.outer(body_b2_f.x,body_b2_f.x),body_b2_cm)) +g = _sm.symbols('g', real=True) +force_p1 = particle_p1.mass*(g*frame_n.x) +force_p2 = particle_p2.mass*(g*frame_n.x) +force_b1 = body_b1.mass*(g*frame_n.x) +force_b2 = body_b2.mass*(g*frame_n.x) +z = _me.dynamicsymbols('z') +v = x*frame_a.x+y*frame_a.z +point_o.set_pos(point_p, x*frame_a.x+y*frame_a.y) +v = (v).subs({x:2*z, y:z}) +point_o.set_pos(point_p, (point_o.pos_from(point_p)).subs({x:2*z, y:z})) +force_o = -1*(x*y*frame_a.x) +force_p1 = particle_p1.mass*(g*frame_n.x)+ x*y*frame_a.x diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest11.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest11.al new file mode 100644 index 0000000000000000000000000000000000000000..60934c1ca563024828110bfe984a90d5686b89e4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest11.al @@ -0,0 +1,6 @@ +VARIABLES X, Y +CONSTANTS A{1:2, 1:2}, B{1:2} +EQN[1] = A11*x + A12*y - B1 +EQN[2] = A21*x + A22*y - B2 +INPUT A11=2, A12=5, A21=3, A22=4, B1=7, B2=6 +CODE ALGEBRAIC(EQN, X, Y) some_filename.c diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest11.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest11.py new file mode 100644 index 0000000000000000000000000000000000000000..4ec2397ea96261d7b582d1f699e3897caae88f20 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest11.py @@ -0,0 +1,14 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +a11, a12, a21, a22, b1, b2 = _sm.symbols('a11 a12 a21 a22 b1 b2', real=True) +eqn = _sm.Matrix([[0]]) +eqn[0] = a11*x+a12*y-b1 +eqn = eqn.row_insert(eqn.shape[0], _sm.Matrix([[0]])) +eqn[eqn.shape[0]-1] = a21*x+a22*y-b2 +eqn_list = [] +for i in eqn: eqn_list.append(i.subs({a11:2, a12:5, a21:3, a22:4, b1:7, b2:6})) +print(_sm.linsolve(eqn_list, x,y)) diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest12.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest12.al new file mode 100644 index 0000000000000000000000000000000000000000..f147f55afd1438436767960e0487d5d9e7161c8f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest12.al @@ -0,0 +1,7 @@ +VARIABLES X,Y +CONSTANTS A,B,R +EQN[1] = A*X^3+B*Y^2-R +EQN[2] = A*SIN(X)^2 + B*COS(2*Y) - R^2 +INPUT A=2.0, B=3.0, R=1.0 +INPUT X = 30 DEG, Y = 3.14 +CODE NONLINEAR(EQN,X,Y) some_filename.c diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest12.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest12.py new file mode 100644 index 0000000000000000000000000000000000000000..3d7d996fa649f796a536dba20c1a36554acd8046 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest12.py @@ -0,0 +1,14 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +a, b, r = _sm.symbols('a b r', real=True) +eqn = _sm.Matrix([[0]]) +eqn[0] = a*x**3+b*y**2-r +eqn = eqn.row_insert(eqn.shape[0], _sm.Matrix([[0]])) +eqn[eqn.shape[0]-1] = a*_sm.sin(x)**2+b*_sm.cos(2*y)-r**2 +matrix_list = [] +for i in eqn:matrix_list.append(i.subs({a:2.0, b:3.0, r:1.0})) +print(_sm.nsolve(matrix_list,(x,y),(_np.deg2rad(30),3.14))) diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest2.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest2.al new file mode 100644 index 0000000000000000000000000000000000000000..17937e58bd20a9fb82f44ccd05f0c081a1aa6c9b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest2.al @@ -0,0 +1,12 @@ +% ruletest2.al +VARIABLES X1,X2 +SPECIFIED F1 = X1*X2 + 3*X1^2 +SPECIFIED F2=X1*T+X2*T^2 +VARIABLE X', Y'' +MOTIONVARIABLES Q{3}, U{2} +VARIABLES P{2}' +VARIABLE W{3}', R{2}'' +VARIABLES C{1:2, 1:2} +VARIABLES D{1,3} +VARIABLES J{1:2} +IMAGINARY N diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest2.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest2.py new file mode 100644 index 0000000000000000000000000000000000000000..31c1d9974c2292466b805b91f8254bffaa94e2ac --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest2.py @@ -0,0 +1,22 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x1, x2 = _me.dynamicsymbols('x1 x2') +f1 = x1*x2+3*x1**2 +f2 = x1*_me.dynamicsymbols._t+x2*_me.dynamicsymbols._t**2 +x, y = _me.dynamicsymbols('x y') +x_d, y_d = _me.dynamicsymbols('x_ y_', 1) +y_dd = _me.dynamicsymbols('y_', 2) +q1, q2, q3, u1, u2 = _me.dynamicsymbols('q1 q2 q3 u1 u2') +p1, p2 = _me.dynamicsymbols('p1 p2') +p1_d, p2_d = _me.dynamicsymbols('p1_ p2_', 1) +w1, w2, w3, r1, r2 = _me.dynamicsymbols('w1 w2 w3 r1 r2') +w1_d, w2_d, w3_d, r1_d, r2_d = _me.dynamicsymbols('w1_ w2_ w3_ r1_ r2_', 1) +r1_dd, r2_dd = _me.dynamicsymbols('r1_ r2_', 2) +c11, c12, c21, c22 = _me.dynamicsymbols('c11 c12 c21 c22') +d11, d12, d13 = _me.dynamicsymbols('d11 d12 d13') +j1, j2 = _me.dynamicsymbols('j1 j2') +n = _sm.symbols('n') +n = _sm.I diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest3.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest3.al new file mode 100644 index 0000000000000000000000000000000000000000..f263f1802ebca2725481dd5fdd3540bf8e9f11bf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest3.al @@ -0,0 +1,25 @@ +% ruletest3.al +FRAMES A, B +NEWTONIAN N + +VARIABLES X{3} +CONSTANTS L + +V1> = X1*A1> + X2*A2> + X3*A3> +V2> = X1*B1> + X2*B2> + X3*B3> +V3> = X1*N1> + X2*N2> + X3*N3> + +V> = V1> + V2> + V3> + +POINTS C, D +POINTS PO{3} + +PARTICLES L +PARTICLES P{3} + +BODIES S +BODIES R{2} + +V4> = X1*S1> + X2*S2> + X3*S3> + +P_C_SO> = L*N1> diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest3.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest3.py new file mode 100644 index 0000000000000000000000000000000000000000..23f79aa571337f200b3ff4d56b5747f7704985c0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest3.py @@ -0,0 +1,37 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +frame_n = _me.ReferenceFrame('n') +x1, x2, x3 = _me.dynamicsymbols('x1 x2 x3') +l = _sm.symbols('l', real=True) +v1 = x1*frame_a.x+x2*frame_a.y+x3*frame_a.z +v2 = x1*frame_b.x+x2*frame_b.y+x3*frame_b.z +v3 = x1*frame_n.x+x2*frame_n.y+x3*frame_n.z +v = v1+v2+v3 +point_c = _me.Point('c') +point_d = _me.Point('d') +point_po1 = _me.Point('po1') +point_po2 = _me.Point('po2') +point_po3 = _me.Point('po3') +particle_l = _me.Particle('l', _me.Point('l_pt'), _sm.Symbol('m')) +particle_p1 = _me.Particle('p1', _me.Point('p1_pt'), _sm.Symbol('m')) +particle_p2 = _me.Particle('p2', _me.Point('p2_pt'), _sm.Symbol('m')) +particle_p3 = _me.Particle('p3', _me.Point('p3_pt'), _sm.Symbol('m')) +body_s_cm = _me.Point('s_cm') +body_s_cm.set_vel(frame_n, 0) +body_s_f = _me.ReferenceFrame('s_f') +body_s = _me.RigidBody('s', body_s_cm, body_s_f, _sm.symbols('m'), (_me.outer(body_s_f.x,body_s_f.x),body_s_cm)) +body_r1_cm = _me.Point('r1_cm') +body_r1_cm.set_vel(frame_n, 0) +body_r1_f = _me.ReferenceFrame('r1_f') +body_r1 = _me.RigidBody('r1', body_r1_cm, body_r1_f, _sm.symbols('m'), (_me.outer(body_r1_f.x,body_r1_f.x),body_r1_cm)) +body_r2_cm = _me.Point('r2_cm') +body_r2_cm.set_vel(frame_n, 0) +body_r2_f = _me.ReferenceFrame('r2_f') +body_r2 = _me.RigidBody('r2', body_r2_cm, body_r2_f, _sm.symbols('m'), (_me.outer(body_r2_f.x,body_r2_f.x),body_r2_cm)) +v4 = x1*body_s_f.x+x2*body_s_f.y+x3*body_s_f.z +body_s_cm.set_pos(point_c, l*frame_n.x) diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest4.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest4.al new file mode 100644 index 0000000000000000000000000000000000000000..7302bd7724bad9b763c75fe4230faa42b5070408 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest4.al @@ -0,0 +1,20 @@ +% ruletest4.al + +FRAMES A, B +MOTIONVARIABLES Q{3} +SIMPROT(A, B, 1, Q3) +DCM = A_B +M = DCM*3 - A_B + +VARIABLES R +CIRCLE_AREA = PI*R^2 + +VARIABLES U, A +VARIABLES X, Y +S = U*T - 1/2*A*T^2 + +EXPR1 = 2*A*0.5 - 1.25 + 0.25 +EXPR2 = -X^2 + Y^2 + 0.25*(X+Y)^2 +EXPR3 = 0.5E-10 + +DYADIC>> = A1>*A1> + A2>*A2> + A3>*A3> diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest4.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest4.py new file mode 100644 index 0000000000000000000000000000000000000000..74b18543e04d6c9e42dd569d2152040c13ae0899 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest4.py @@ -0,0 +1,20 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +q1, q2, q3 = _me.dynamicsymbols('q1 q2 q3') +frame_b.orient(frame_a, 'Axis', [q3, frame_a.x]) +dcm = frame_a.dcm(frame_b) +m = dcm*3-frame_a.dcm(frame_b) +r = _me.dynamicsymbols('r') +circle_area = _sm.pi*r**2 +u, a = _me.dynamicsymbols('u a') +x, y = _me.dynamicsymbols('x y') +s = u*_me.dynamicsymbols._t-1/2*a*_me.dynamicsymbols._t**2 +expr1 = 2*a*0.5-1.25+0.25 +expr2 = -1*x**2+y**2+0.25*(x+y)**2 +expr3 = 0.5*10**(-10) +dyadic = _me.outer(frame_a.x, frame_a.x)+_me.outer(frame_a.y, frame_a.y)+_me.outer(frame_a.z, frame_a.z) diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest5.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest5.al new file mode 100644 index 0000000000000000000000000000000000000000..a859dc8bb1f0251af14809681d995c59b31377ba --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest5.al @@ -0,0 +1,32 @@ +% ruletest5.al +VARIABLES X', Y' + +E1 = (X+Y)^2 + (X-Y)^3 +E2 = (X-Y)^2 +E3 = X^2 + Y^2 + 2*X*Y + +M1 = [E1;E2] +M2 = [(X+Y)^2,(X-Y)^2] +M3 = M1 + [X;Y] + +AM = EXPAND(M1) +CM = EXPAND([(X+Y)^2,(X-Y)^2]) +EM = EXPAND(M1 + [X;Y]) +F = EXPAND(E1) +G = EXPAND(E2) + +A = FACTOR(E3, X) +BM = FACTOR(M1, X) +CM = FACTOR(M1 + [X;Y], X) + +A = D(E3, X) +B = D(E3, Y) +CM = D(M2, X) +DM = D(M1 + [X;Y], X) +FRAMES A, B +A_B = [1,0,0;1,0,0;1,0,0] +V1> = X*A1> + Y*A2> + X*Y*A3> +E> = D(V1>, X, B) +FM = DT(M1) +GM = DT([(X+Y)^2,(X-Y)^2]) +H> = DT(V1>, B) diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest5.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest5.py new file mode 100644 index 0000000000000000000000000000000000000000..93684435b402f5b56e2f4a5c3c81500208556423 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest5.py @@ -0,0 +1,33 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +x_d, y_d = _me.dynamicsymbols('x_ y_', 1) +e1 = (x+y)**2+(x-y)**3 +e2 = (x-y)**2 +e3 = x**2+y**2+2*x*y +m1 = _sm.Matrix([e1,e2]).reshape(2, 1) +m2 = _sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2) +m3 = m1+_sm.Matrix([x,y]).reshape(2, 1) +am = _sm.Matrix([i.expand() for i in m1]).reshape((m1).shape[0], (m1).shape[1]) +cm = _sm.Matrix([i.expand() for i in _sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)]).reshape((_sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)).shape[0], (_sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)).shape[1]) +em = _sm.Matrix([i.expand() for i in m1+_sm.Matrix([x,y]).reshape(2, 1)]).reshape((m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[0], (m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[1]) +f = (e1).expand() +g = (e2).expand() +a = _sm.factor((e3), x) +bm = _sm.Matrix([_sm.factor(i, x) for i in m1]).reshape((m1).shape[0], (m1).shape[1]) +cm = _sm.Matrix([_sm.factor(i, x) for i in m1+_sm.Matrix([x,y]).reshape(2, 1)]).reshape((m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[0], (m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[1]) +a = (e3).diff(x) +b = (e3).diff(y) +cm = _sm.Matrix([i.diff(x) for i in m2]).reshape((m2).shape[0], (m2).shape[1]) +dm = _sm.Matrix([i.diff(x) for i in m1+_sm.Matrix([x,y]).reshape(2, 1)]).reshape((m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[0], (m1+_sm.Matrix([x,y]).reshape(2, 1)).shape[1]) +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +frame_b.orient(frame_a, 'DCM', _sm.Matrix([1,0,0,1,0,0,1,0,0]).reshape(3, 3)) +v1 = x*frame_a.x+y*frame_a.y+x*y*frame_a.z +e = (v1).diff(x, frame_b) +fm = _sm.Matrix([i.diff(_sm.Symbol('t')) for i in m1]).reshape((m1).shape[0], (m1).shape[1]) +gm = _sm.Matrix([i.diff(_sm.Symbol('t')) for i in _sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)]).reshape((_sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)).shape[0], (_sm.Matrix([(x+y)**2,(x-y)**2]).reshape(1, 2)).shape[1]) +h = (v1).dt(frame_b) diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest6.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest6.al new file mode 100644 index 0000000000000000000000000000000000000000..7ec3ba61590e77772ae631237df048b932fe778c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest6.al @@ -0,0 +1,41 @@ +% ruletest6.al +VARIABLES Q{2} +VARIABLES X,Y,Z +Q1 = X^2 + Y^2 +Q2 = X-Y +E = Q1 + Q2 +A = EXPLICIT(E) +E2 = COS(X) +E3 = COS(X*Y) +A = TAYLOR(E2, 0:2, X=0) +B = TAYLOR(E3, 0:2, X=0, Y=0) + +E = EXPAND((X+Y)^2) +A = EVALUATE(E, X=1, Y=Z) +BM = EVALUATE([E;2*E], X=1, Y=Z) + +E = Q1 + Q2 +A = EVALUATE(E, X=2, Y=Z^2) + +CONSTANTS J,K,L +P1 = POLYNOMIAL([J,K,L],X) +P2 = POLYNOMIAL(J*X+K,X,1) + +ROOT1 = ROOTS(P1, X, 2) +ROOT2 = ROOTS([1;2;3]) + +M = [1,2,3,4;5,6,7,8;9,10,11,12;13,14,15,16] + +AM = TRANSPOSE(M) + M +BM = EIG(M) +C1 = DIAGMAT(4, 1) +C2 = DIAGMAT(3, 4, 2) +DM = INV(M+C1) +E = DET(M+C1) + TRACE([1,0;0,1]) +F = ELEMENT(M, 2, 3) + +A = COLS(M) +BM = COLS(M, 1) +CM = COLS(M, 1, 2:4, 3) +DM = ROWS(M, 1) +EM = ROWS(M, 1, 2:4, 3) diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest6.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest6.py new file mode 100644 index 0000000000000000000000000000000000000000..85f1a0b49518bb0ae5766cbe91b9c24a1b8e9c20 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest6.py @@ -0,0 +1,36 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +q1, q2 = _me.dynamicsymbols('q1 q2') +x, y, z = _me.dynamicsymbols('x y z') +e = q1+q2 +a = (e).subs({q1:x**2+y**2, q2:x-y}) +e2 = _sm.cos(x) +e3 = _sm.cos(x*y) +a = (e2).series(x, 0, 2).removeO() +b = (e3).series(x, 0, 2).removeO().series(y, 0, 2).removeO() +e = ((x+y)**2).expand() +a = (e).subs({q1:x**2+y**2,q2:x-y}).subs({x:1,y:z}) +bm = _sm.Matrix([i.subs({x:1,y:z}) for i in _sm.Matrix([e,2*e]).reshape(2, 1)]).reshape((_sm.Matrix([e,2*e]).reshape(2, 1)).shape[0], (_sm.Matrix([e,2*e]).reshape(2, 1)).shape[1]) +e = q1+q2 +a = (e).subs({q1:x**2+y**2,q2:x-y}).subs({x:2,y:z**2}) +j, k, l = _sm.symbols('j k l', real=True) +p1 = _sm.Poly(_sm.Matrix([j,k,l]).reshape(1, 3), x) +p2 = _sm.Poly(j*x+k, x) +root1 = [i.evalf() for i in _sm.solve(p1, x)] +root2 = [i.evalf() for i in _sm.solve(_sm.Poly(_sm.Matrix([1,2,3]).reshape(3, 1), x),x)] +m = _sm.Matrix([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]).reshape(4, 4) +am = (m).T+m +bm = _sm.Matrix([i.evalf() for i in (m).eigenvals().keys()]) +c1 = _sm.diag(1,1,1,1) +c2 = _sm.Matrix([2 if i==j else 0 for i in range(3) for j in range(4)]).reshape(3, 4) +dm = (m+c1)**(-1) +e = (m+c1).det()+(_sm.Matrix([1,0,0,1]).reshape(2, 2)).trace() +f = (m)[1,2] +a = (m).cols +bm = (m).col(0) +cm = _sm.Matrix([(m).T.row(0),(m).T.row(1),(m).T.row(2),(m).T.row(3),(m).T.row(2)]) +dm = (m).row(0) +em = _sm.Matrix([(m).row(0),(m).row(1),(m).row(2),(m).row(3),(m).row(2)]) diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest7.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest7.al new file mode 100644 index 0000000000000000000000000000000000000000..2904a602f589645d22e1d3d378d077dd6a1ec27e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest7.al @@ -0,0 +1,39 @@ +% ruletest7.al +VARIABLES X', Y' +E = COS(X) + SIN(X) + TAN(X)& ++ COSH(X) + SINH(X) + TANH(X)& ++ ACOS(X) + ASIN(X) + ATAN(X)& ++ LOG(X) + EXP(X) + SQRT(X)& ++ FACTORIAL(X) + CEIL(X) +& +FLOOR(X) + SIGN(X) + +E = SQR(X) + LOG10(X) + +A = ABS(-1) + INT(1.5) + ROUND(1.9) + +E1 = 2*X + 3*Y +E2 = X + Y + +AM = COEF([E1;E2], [X,Y]) +B = COEF(E1, X) +C = COEF(E2, Y) +D1 = EXCLUDE(E1, X) +D2 = INCLUDE(E1, X) +FM = ARRANGE([E1,E2],2,X) +F = ARRANGE(E1, 2, Y) +G = REPLACE(E1, X=2*X) +GM = REPLACE([E1;E2], X=3) + +FRAMES A, B +VARIABLES THETA +SIMPROT(A,B,3,THETA) +V1> = 2*A1> - 3*A2> + A3> +V2> = B1> + B2> + B3> +A = DOT(V1>, V2>) +BM = DOT(V1>, [V2>;2*V2>]) +C> = CROSS(V1>,V2>) +D = MAG(2*V1>) + MAG(3*V1>) +DYADIC>> = 3*A1>*A1> + A2>*A2> + 2*A3>*A3> +AM = MATRIX(B, DYADIC>>) +M = [1;2;3] +V> = VECTOR(A, M) diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest7.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest7.py new file mode 100644 index 0000000000000000000000000000000000000000..19147856dc3b0d451184a6bb539c1c331f61a6d2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest7.py @@ -0,0 +1,35 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +x, y = _me.dynamicsymbols('x y') +x_d, y_d = _me.dynamicsymbols('x_ y_', 1) +e = _sm.cos(x)+_sm.sin(x)+_sm.tan(x)+_sm.cosh(x)+_sm.sinh(x)+_sm.tanh(x)+_sm.acos(x)+_sm.asin(x)+_sm.atan(x)+_sm.log(x)+_sm.exp(x)+_sm.sqrt(x)+_sm.factorial(x)+_sm.ceiling(x)+_sm.floor(x)+_sm.sign(x) +e = (x)**2+_sm.log(x, 10) +a = _sm.Abs(-1*1)+int(1.5)+round(1.9) +e1 = 2*x+3*y +e2 = x+y +am = _sm.Matrix([e1.expand().coeff(x), e1.expand().coeff(y), e2.expand().coeff(x), e2.expand().coeff(y)]).reshape(2, 2) +b = (e1).expand().coeff(x) +c = (e2).expand().coeff(y) +d1 = (e1).collect(x).coeff(x,0) +d2 = (e1).collect(x).coeff(x,1) +fm = _sm.Matrix([i.collect(x)for i in _sm.Matrix([e1,e2]).reshape(1, 2)]).reshape((_sm.Matrix([e1,e2]).reshape(1, 2)).shape[0], (_sm.Matrix([e1,e2]).reshape(1, 2)).shape[1]) +f = (e1).collect(y) +g = (e1).subs({x:2*x}) +gm = _sm.Matrix([i.subs({x:3}) for i in _sm.Matrix([e1,e2]).reshape(2, 1)]).reshape((_sm.Matrix([e1,e2]).reshape(2, 1)).shape[0], (_sm.Matrix([e1,e2]).reshape(2, 1)).shape[1]) +frame_a = _me.ReferenceFrame('a') +frame_b = _me.ReferenceFrame('b') +theta = _me.dynamicsymbols('theta') +frame_b.orient(frame_a, 'Axis', [theta, frame_a.z]) +v1 = 2*frame_a.x-3*frame_a.y+frame_a.z +v2 = frame_b.x+frame_b.y+frame_b.z +a = _me.dot(v1, v2) +bm = _sm.Matrix([_me.dot(v1, v2),_me.dot(v1, 2*v2)]).reshape(2, 1) +c = _me.cross(v1, v2) +d = 2*v1.magnitude()+3*v1.magnitude() +dyadic = _me.outer(3*frame_a.x, frame_a.x)+_me.outer(frame_a.y, frame_a.y)+_me.outer(2*frame_a.z, frame_a.z) +am = (dyadic).to_matrix(frame_b) +m = _sm.Matrix([1,2,3]).reshape(3, 1) +v = m[0]*frame_a.x +m[1]*frame_a.y +m[2]*frame_a.z diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest8.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest8.al new file mode 100644 index 0000000000000000000000000000000000000000..4b2462c51e6730f46bf60b4b21ab6cfbf1993640 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest8.al @@ -0,0 +1,38 @@ +% ruletest8.al +FRAMES A +CONSTANTS C{3} +A>> = EXPRESS(1>>,A) +PARTICLES P1, P2 +BODIES R +R_A = [1,1,1;1,1,0;0,0,1] +POINT O +MASS P1=M1, P2=M2, R=MR +INERTIA R, I1, I2, I3 +P_P1_O> = C1*A1> +P_P2_O> = C2*A2> +P_RO_O> = C3*A3> +A>> = EXPRESS(I_P1_O>>, A) +A>> = EXPRESS(I_P2_O>>, A) +A>> = EXPRESS(I_R_O>>, A) +A>> = EXPRESS(INERTIA(O), A) +A>> = EXPRESS(INERTIA(O, P1, R), A) +A>> = EXPRESS(I_R_O>>, A) +A>> = EXPRESS(I_R_RO>>, A) + +P_P1_P2> = C1*A1> + C2*A2> +P_P1_RO> = C3*A1> +P_P2_RO> = C3*A2> + +B> = CM(O) +B> = CM(O, P1, R) +B> = CM(P1) + +MOTIONVARIABLES U{3} +V> = U1*A1> + U2*A2> + U3*A3> +U> = UNITVEC(V> + C1*A1>) +V_P1_A> = U1*A1> +A> = PARTIALS(V_P1_A>, U1) + +M = MASS(P1,R) +M = MASS(P2) +M = MASS() \ No newline at end of file diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest8.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest8.py new file mode 100644 index 0000000000000000000000000000000000000000..6809c47138e40027c700536e807ca7cfa5f468d7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest8.py @@ -0,0 +1,49 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +frame_a = _me.ReferenceFrame('a') +c1, c2, c3 = _sm.symbols('c1 c2 c3', real=True) +a = _me.inertia(frame_a, 1, 1, 1) +particle_p1 = _me.Particle('p1', _me.Point('p1_pt'), _sm.Symbol('m')) +particle_p2 = _me.Particle('p2', _me.Point('p2_pt'), _sm.Symbol('m')) +body_r_cm = _me.Point('r_cm') +body_r_f = _me.ReferenceFrame('r_f') +body_r = _me.RigidBody('r', body_r_cm, body_r_f, _sm.symbols('m'), (_me.outer(body_r_f.x,body_r_f.x),body_r_cm)) +frame_a.orient(body_r_f, 'DCM', _sm.Matrix([1,1,1,1,1,0,0,0,1]).reshape(3, 3)) +point_o = _me.Point('o') +m1 = _sm.symbols('m1') +particle_p1.mass = m1 +m2 = _sm.symbols('m2') +particle_p2.mass = m2 +mr = _sm.symbols('mr') +body_r.mass = mr +i1 = _sm.symbols('i1') +i2 = _sm.symbols('i2') +i3 = _sm.symbols('i3') +body_r.inertia = (_me.inertia(body_r_f, i1, i2, i3, 0, 0, 0), body_r_cm) +point_o.set_pos(particle_p1.point, c1*frame_a.x) +point_o.set_pos(particle_p2.point, c2*frame_a.y) +point_o.set_pos(body_r_cm, c3*frame_a.z) +a = _me.inertia_of_point_mass(particle_p1.mass, particle_p1.point.pos_from(point_o), frame_a) +a = _me.inertia_of_point_mass(particle_p2.mass, particle_p2.point.pos_from(point_o), frame_a) +a = body_r.inertia[0] + _me.inertia_of_point_mass(body_r.mass, body_r.masscenter.pos_from(point_o), frame_a) +a = _me.inertia_of_point_mass(particle_p1.mass, particle_p1.point.pos_from(point_o), frame_a) + _me.inertia_of_point_mass(particle_p2.mass, particle_p2.point.pos_from(point_o), frame_a) + body_r.inertia[0] + _me.inertia_of_point_mass(body_r.mass, body_r.masscenter.pos_from(point_o), frame_a) +a = _me.inertia_of_point_mass(particle_p1.mass, particle_p1.point.pos_from(point_o), frame_a) + body_r.inertia[0] + _me.inertia_of_point_mass(body_r.mass, body_r.masscenter.pos_from(point_o), frame_a) +a = body_r.inertia[0] + _me.inertia_of_point_mass(body_r.mass, body_r.masscenter.pos_from(point_o), frame_a) +a = body_r.inertia[0] +particle_p2.point.set_pos(particle_p1.point, c1*frame_a.x+c2*frame_a.y) +body_r_cm.set_pos(particle_p1.point, c3*frame_a.x) +body_r_cm.set_pos(particle_p2.point, c3*frame_a.y) +b = _me.functions.center_of_mass(point_o,particle_p1, particle_p2, body_r) +b = _me.functions.center_of_mass(point_o,particle_p1, body_r) +b = _me.functions.center_of_mass(particle_p1.point,particle_p1, particle_p2, body_r) +u1, u2, u3 = _me.dynamicsymbols('u1 u2 u3') +v = u1*frame_a.x+u2*frame_a.y+u3*frame_a.z +u = (v+c1*frame_a.x).normalize() +particle_p1.point.set_vel(frame_a, u1*frame_a.x) +a = particle_p1.point.partial_velocity(frame_a, u1) +m = particle_p1.mass+body_r.mass +m = particle_p2.mass +m = particle_p1.mass+particle_p2.mass+body_r.mass diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest9.al b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest9.al new file mode 100644 index 0000000000000000000000000000000000000000..df5c70f05b76fc215f829672e281491b0c96c6a6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest9.al @@ -0,0 +1,54 @@ +% ruletest9.al +NEWTONIAN N +FRAMES A +A> = 0> +D>> = EXPRESS(1>>, A) + +POINTS PO{2} +PARTICLES P{2} +MOTIONVARIABLES' C{3}' +BODIES R +P_P1_PO2> = C1*A1> +V> = 2*P_P1_PO2> + C2*A2> + +W_A_N> = C3*A3> +V> = 2*W_A_N> + C2*A2> +W_R_N> = C3*A3> +V> = 2*W_R_N> + C2*A2> + +ALF_A_N> = DT(W_A_N>, A) +V> = 2*ALF_A_N> + C2*A2> + +V_P1_A> = C1*A1> + C3*A2> +A_RO_N> = C2*A2> +V_A> = CROSS(A_RO_N>, V_P1_A>) + +X_B_C> = V_A> +X_B_D> = 2*X_B_C> +A_B_C_D_E> = X_B_D>*2 + +A_B_C = 2*C1*C2*C3 +A_B_C += 2*C1 +A_B_C := 3*C1 + +MOTIONVARIABLES' Q{2}', U{2}' +Q1' = U1 +Q2' = U2 + +VARIABLES X'', Y'' +SPECIFIED YY +Y'' = X*X'^2 + 1 +YY = X*X'^2 + 1 + +M[1] = 2*X +M[2] = 2*Y +A = 2*M[1] + +M = [1,2,3;4,5,6;7,8,9] +M[1, 2] = 5 +A = M[1, 2]*2 + +FORCE_RO> = Q1*N1> +TORQUE_A> = Q2*N3> +FORCE_RO> = Q2*N2> +F> = FORCE_RO>*2 diff --git a/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest9.py b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest9.py new file mode 100644 index 0000000000000000000000000000000000000000..09d8ae4ee8385bde5c38b946458a43c8ffdaa9b8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/autolev/test-examples/ruletest9.py @@ -0,0 +1,55 @@ +import sympy.physics.mechanics as _me +import sympy as _sm +import math as m +import numpy as _np + +frame_n = _me.ReferenceFrame('n') +frame_a = _me.ReferenceFrame('a') +a = 0 +d = _me.inertia(frame_a, 1, 1, 1) +point_po1 = _me.Point('po1') +point_po2 = _me.Point('po2') +particle_p1 = _me.Particle('p1', _me.Point('p1_pt'), _sm.Symbol('m')) +particle_p2 = _me.Particle('p2', _me.Point('p2_pt'), _sm.Symbol('m')) +c1, c2, c3 = _me.dynamicsymbols('c1 c2 c3') +c1_d, c2_d, c3_d = _me.dynamicsymbols('c1_ c2_ c3_', 1) +body_r_cm = _me.Point('r_cm') +body_r_cm.set_vel(frame_n, 0) +body_r_f = _me.ReferenceFrame('r_f') +body_r = _me.RigidBody('r', body_r_cm, body_r_f, _sm.symbols('m'), (_me.outer(body_r_f.x,body_r_f.x),body_r_cm)) +point_po2.set_pos(particle_p1.point, c1*frame_a.x) +v = 2*point_po2.pos_from(particle_p1.point)+c2*frame_a.y +frame_a.set_ang_vel(frame_n, c3*frame_a.z) +v = 2*frame_a.ang_vel_in(frame_n)+c2*frame_a.y +body_r_f.set_ang_vel(frame_n, c3*frame_a.z) +v = 2*body_r_f.ang_vel_in(frame_n)+c2*frame_a.y +frame_a.set_ang_acc(frame_n, (frame_a.ang_vel_in(frame_n)).dt(frame_a)) +v = 2*frame_a.ang_acc_in(frame_n)+c2*frame_a.y +particle_p1.point.set_vel(frame_a, c1*frame_a.x+c3*frame_a.y) +body_r_cm.set_acc(frame_n, c2*frame_a.y) +v_a = _me.cross(body_r_cm.acc(frame_n), particle_p1.point.vel(frame_a)) +x_b_c = v_a +x_b_d = 2*x_b_c +a_b_c_d_e = x_b_d*2 +a_b_c = 2*c1*c2*c3 +a_b_c += 2*c1 +a_b_c = 3*c1 +q1, q2, u1, u2 = _me.dynamicsymbols('q1 q2 u1 u2') +q1_d, q2_d, u1_d, u2_d = _me.dynamicsymbols('q1_ q2_ u1_ u2_', 1) +x, y = _me.dynamicsymbols('x y') +x_d, y_d = _me.dynamicsymbols('x_ y_', 1) +x_dd, y_dd = _me.dynamicsymbols('x_ y_', 2) +yy = _me.dynamicsymbols('yy') +yy = x*x_d**2+1 +m = _sm.Matrix([[0]]) +m[0] = 2*x +m = m.row_insert(m.shape[0], _sm.Matrix([[0]])) +m[m.shape[0]-1] = 2*y +a = 2*m[0] +m = _sm.Matrix([1,2,3,4,5,6,7,8,9]).reshape(3, 3) +m[0,1] = 5 +a = m[0, 1]*2 +force_ro = q1*frame_n.x +torque_a = q2*frame_n.z +force_ro = q1*frame_n.x + q2*frame_n.y +f = force_ro*2 diff --git a/phivenv/Lib/site-packages/sympy/parsing/c/__init__.py b/phivenv/Lib/site-packages/sympy/parsing/c/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18d3d5301cb001c78fc4a9bc04b25aa36f282a93 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/c/__init__.py @@ -0,0 +1 @@ +"""Used for translating C source code into a SymPy expression""" diff --git a/phivenv/Lib/site-packages/sympy/parsing/c/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/c/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..383e59d66ff389f8d017eb0e1b44cfd855dedfa9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/c/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/c/__pycache__/c_parser.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/c/__pycache__/c_parser.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..112957495fcbc9c71d127b97c420f32022935274 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/c/__pycache__/c_parser.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/c/c_parser.py b/phivenv/Lib/site-packages/sympy/parsing/c/c_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9e7223f8351205272e803773589649fcf1902f15 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/c/c_parser.py @@ -0,0 +1,1059 @@ +from sympy.external import import_module +import os + +cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']}) + +""" +This module contains all the necessary Classes and Function used to Parse C and +C++ code into SymPy expression +The module serves as a backend for SymPyExpression to parse C code +It is also dependent on Clang's AST and SymPy's Codegen AST. +The module only supports the features currently supported by the Clang and +codegen AST which will be updated as the development of codegen AST and this +module progresses. +You might find unexpected bugs and exceptions while using the module, feel free +to report them to the SymPy Issue Tracker + +Features Supported +================== + +- Variable Declarations (integers and reals) +- Assignment (using integer & floating literal and function calls) +- Function Definitions and Declaration +- Function Calls +- Compound statements, Return statements + +Notes +===== + +The module is dependent on an external dependency which needs to be installed +to use the features of this module. + +Clang: The C and C++ compiler which is used to extract an AST from the provided +C source code. + +References +========== + +.. [1] https://github.com/sympy/sympy/issues +.. [2] https://clang.llvm.org/docs/ +.. [3] https://clang.llvm.org/docs/IntroductionToTheClangAST.html + +""" + +if cin: + from sympy.codegen.ast import (Variable, Integer, Float, + FunctionPrototype, FunctionDefinition, FunctionCall, + none, Return, Assignment, intc, int8, int16, int64, + uint8, uint16, uint32, uint64, float32, float64, float80, + aug_assign, bool_, While, CodeBlock) + from sympy.codegen.cnodes import (PreDecrement, PostDecrement, + PreIncrement, PostIncrement) + from sympy.core import Add, Mod, Mul, Pow, Rel + from sympy.logic.boolalg import And, as_Boolean, Not, Or + from sympy.core.symbol import Symbol + from sympy.core.sympify import sympify + from sympy.logic.boolalg import (false, true) + import sys + import tempfile + + class BaseParser: + """Base Class for the C parser""" + + def __init__(self): + """Initializes the Base parser creating a Clang AST index""" + self.index = cin.Index.create() + + def diagnostics(self, out): + """Diagostics function for the Clang AST""" + for diag in self.tu.diagnostics: + # tu = translation unit + print('%s %s (line %s, col %s) %s' % ( + { + 4: 'FATAL', + 3: 'ERROR', + 2: 'WARNING', + 1: 'NOTE', + 0: 'IGNORED', + }[diag.severity], + diag.location.file, + diag.location.line, + diag.location.column, + diag.spelling + ), file=out) + + class CCodeConverter(BaseParser): + """The Code Convereter for Clang AST + + The converter object takes the C source code or file as input and + converts them to SymPy Expressions. + """ + + def __init__(self): + """Initializes the code converter""" + super().__init__() + self._py_nodes = [] + self._data_types = { + "void": { + cin.TypeKind.VOID: none + }, + "bool": { + cin.TypeKind.BOOL: bool_ + }, + "int": { + cin.TypeKind.SCHAR: int8, + cin.TypeKind.SHORT: int16, + cin.TypeKind.INT: intc, + cin.TypeKind.LONG: int64, + cin.TypeKind.UCHAR: uint8, + cin.TypeKind.USHORT: uint16, + cin.TypeKind.UINT: uint32, + cin.TypeKind.ULONG: uint64 + }, + "float": { + cin.TypeKind.FLOAT: float32, + cin.TypeKind.DOUBLE: float64, + cin.TypeKind.LONGDOUBLE: float80 + } + } + + def parse(self, filename, flags): + """Function to parse a file with C source code + + It takes the filename as an attribute and creates a Clang AST + Translation Unit parsing the file. + Then the transformation function is called on the translation unit, + whose results are collected into a list which is returned by the + function. + + Parameters + ========== + + filename : string + Path to the C file to be parsed + + flags: list + Arguments to be passed to Clang while parsing the C code + + Returns + ======= + + py_nodes: list + A list of SymPy AST nodes + + """ + filepath = os.path.abspath(filename) + self.tu = self.index.parse( + filepath, + args=flags, + options=cin.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD + ) + for child in self.tu.cursor.get_children(): + if child.kind == cin.CursorKind.VAR_DECL or child.kind == cin.CursorKind.FUNCTION_DECL: + self._py_nodes.append(self.transform(child)) + return self._py_nodes + + def parse_str(self, source, flags): + """Function to parse a string with C source code + + It takes the source code as an attribute, stores it in a temporary + file and creates a Clang AST Translation Unit parsing the file. + Then the transformation function is called on the translation unit, + whose results are collected into a list which is returned by the + function. + + Parameters + ========== + + source : string + A string containing the C source code to be parsed + + flags: list + Arguments to be passed to Clang while parsing the C code + + Returns + ======= + + py_nodes: list + A list of SymPy AST nodes + + """ + file = tempfile.NamedTemporaryFile(mode = 'w+', suffix = '.cpp') + file.write(source) + file.flush() + file.seek(0) + self.tu = self.index.parse( + file.name, + args=flags, + options=cin.TranslationUnit.PARSE_DETAILED_PROCESSING_RECORD + ) + file.close() + for child in self.tu.cursor.get_children(): + if child.kind == cin.CursorKind.VAR_DECL or child.kind == cin.CursorKind.FUNCTION_DECL: + self._py_nodes.append(self.transform(child)) + return self._py_nodes + + def transform(self, node): + """Transformation Function for Clang AST nodes + + It determines the kind of node and calls the respective + transformation function for that node. + + Raises + ====== + + NotImplementedError : if the transformation for the provided node + is not implemented + + """ + handler = getattr(self, 'transform_%s' % node.kind.name.lower(), None) + + if handler is None: + print( + "Ignoring node of type %s (%s)" % ( + node.kind, + ' '.join( + t.spelling for t in node.get_tokens()) + ), + file=sys.stderr + ) + + return handler(node) + + def transform_var_decl(self, node): + """Transformation Function for Variable Declaration + + Used to create nodes for variable declarations and assignments with + values or function call for the respective nodes in the clang AST + + Returns + ======= + + A variable node as Declaration, with the initial value if given + + Raises + ====== + + NotImplementedError : if called for data types not currently + implemented + + Notes + ===== + + The function currently supports following data types: + + Boolean: + bool, _Bool + + Integer: + 8-bit: signed char and unsigned char + 16-bit: short, short int, signed short, + signed short int, unsigned short, unsigned short int + 32-bit: int, signed int, unsigned int + 64-bit: long, long int, signed long, + signed long int, unsigned long, unsigned long int + + Floating point: + Single Precision: float + Double Precision: double + Extended Precision: long double + + """ + if node.type.kind in self._data_types["int"]: + type = self._data_types["int"][node.type.kind] + elif node.type.kind in self._data_types["float"]: + type = self._data_types["float"][node.type.kind] + elif node.type.kind in self._data_types["bool"]: + type = self._data_types["bool"][node.type.kind] + else: + raise NotImplementedError("Only bool, int " + "and float are supported") + try: + children = node.get_children() + child = next(children) + + #ignoring namespace and type details for the variable + while child.kind == cin.CursorKind.NAMESPACE_REF or child.kind == cin.CursorKind.TYPE_REF: + child = next(children) + + val = self.transform(child) + + supported_rhs = [ + cin.CursorKind.INTEGER_LITERAL, + cin.CursorKind.FLOATING_LITERAL, + cin.CursorKind.UNEXPOSED_EXPR, + cin.CursorKind.BINARY_OPERATOR, + cin.CursorKind.PAREN_EXPR, + cin.CursorKind.UNARY_OPERATOR, + cin.CursorKind.CXX_BOOL_LITERAL_EXPR + ] + + if child.kind in supported_rhs: + if isinstance(val, str): + value = Symbol(val) + elif isinstance(val, bool): + if node.type.kind in self._data_types["int"]: + value = Integer(0) if val == False else Integer(1) + elif node.type.kind in self._data_types["float"]: + value = Float(0.0) if val == False else Float(1.0) + elif node.type.kind in self._data_types["bool"]: + value = sympify(val) + elif isinstance(val, (Integer, int, Float, float)): + if node.type.kind in self._data_types["int"]: + value = Integer(val) + elif node.type.kind in self._data_types["float"]: + value = Float(val) + elif node.type.kind in self._data_types["bool"]: + value = sympify(bool(val)) + else: + value = val + + return Variable( + node.spelling + ).as_Declaration( + type = type, + value = value + ) + + elif child.kind == cin.CursorKind.CALL_EXPR: + return Variable( + node.spelling + ).as_Declaration( + value = val + ) + + else: + raise NotImplementedError("Given " + "variable declaration \"{}\" " + "is not possible to parse yet!" + .format(" ".join( + t.spelling for t in node.get_tokens() + ) + )) + + except StopIteration: + return Variable( + node.spelling + ).as_Declaration( + type = type + ) + + def transform_function_decl(self, node): + """Transformation Function For Function Declaration + + Used to create nodes for function declarations and definitions for + the respective nodes in the clang AST + + Returns + ======= + + function : Codegen AST node + - FunctionPrototype node if function body is not present + - FunctionDefinition node if the function body is present + + + """ + + if node.result_type.kind in self._data_types["int"]: + ret_type = self._data_types["int"][node.result_type.kind] + elif node.result_type.kind in self._data_types["float"]: + ret_type = self._data_types["float"][node.result_type.kind] + elif node.result_type.kind in self._data_types["bool"]: + ret_type = self._data_types["bool"][node.result_type.kind] + elif node.result_type.kind in self._data_types["void"]: + ret_type = self._data_types["void"][node.result_type.kind] + else: + raise NotImplementedError("Only void, bool, int " + "and float are supported") + body = [] + param = [] + + # Subsequent nodes will be the parameters for the function. + for child in node.get_children(): + decl = self.transform(child) + if child.kind == cin.CursorKind.PARM_DECL: + param.append(decl) + elif child.kind == cin.CursorKind.COMPOUND_STMT: + for val in decl: + body.append(val) + else: + body.append(decl) + + if body == []: + function = FunctionPrototype( + return_type = ret_type, + name = node.spelling, + parameters = param + ) + else: + function = FunctionDefinition( + return_type = ret_type, + name = node.spelling, + parameters = param, + body = body + ) + return function + + def transform_parm_decl(self, node): + """Transformation function for Parameter Declaration + + Used to create parameter nodes for the required functions for the + respective nodes in the clang AST + + Returns + ======= + + param : Codegen AST Node + Variable node with the value and type of the variable + + Raises + ====== + + ValueError if multiple children encountered in the parameter node + + """ + if node.type.kind in self._data_types["int"]: + type = self._data_types["int"][node.type.kind] + elif node.type.kind in self._data_types["float"]: + type = self._data_types["float"][node.type.kind] + elif node.type.kind in self._data_types["bool"]: + type = self._data_types["bool"][node.type.kind] + else: + raise NotImplementedError("Only bool, int " + "and float are supported") + try: + children = node.get_children() + child = next(children) + + # Any namespace nodes can be ignored + while child.kind in [cin.CursorKind.NAMESPACE_REF, + cin.CursorKind.TYPE_REF, + cin.CursorKind.TEMPLATE_REF]: + child = next(children) + + # If there is a child, it is the default value of the parameter. + lit = self.transform(child) + if node.type.kind in self._data_types["int"]: + val = Integer(lit) + elif node.type.kind in self._data_types["float"]: + val = Float(lit) + elif node.type.kind in self._data_types["bool"]: + val = sympify(bool(lit)) + else: + raise NotImplementedError("Only bool, int " + "and float are supported") + + param = Variable( + node.spelling + ).as_Declaration( + type = type, + value = val + ) + except StopIteration: + param = Variable( + node.spelling + ).as_Declaration( + type = type + ) + + try: + self.transform(next(children)) + raise ValueError("Can't handle multiple children on parameter") + except StopIteration: + pass + + return param + + def transform_integer_literal(self, node): + """Transformation function for integer literal + + Used to get the value and type of the given integer literal. + + Returns + ======= + + val : list + List with two arguments type and Value + type contains the type of the integer + value contains the value stored in the variable + + Notes + ===== + + Only Base Integer type supported for now + + """ + try: + value = next(node.get_tokens()).spelling + except StopIteration: + # No tokens + value = node.literal + return int(value) + + def transform_floating_literal(self, node): + """Transformation function for floating literal + + Used to get the value and type of the given floating literal. + + Returns + ======= + + val : list + List with two arguments type and Value + type contains the type of float + value contains the value stored in the variable + + Notes + ===== + + Only Base Float type supported for now + + """ + try: + value = next(node.get_tokens()).spelling + except (StopIteration, ValueError): + # No tokens + value = node.literal + return float(value) + + def transform_string_literal(self, node): + #TODO: No string type in AST + #type = + #try: + # value = next(node.get_tokens()).spelling + #except (StopIteration, ValueError): + # No tokens + # value = node.literal + #val = [type, value] + #return val + pass + + def transform_character_literal(self, node): + """Transformation function for character literal + + Used to get the value of the given character literal. + + Returns + ======= + + val : int + val contains the ascii value of the character literal + + Notes + ===== + + Only for cases where character is assigned to a integer value, + since character literal is not in SymPy AST + + """ + try: + value = next(node.get_tokens()).spelling + except (StopIteration, ValueError): + # No tokens + value = node.literal + return ord(str(value[1])) + + def transform_cxx_bool_literal_expr(self, node): + """Transformation function for boolean literal + + Used to get the value of the given boolean literal. + + Returns + ======= + + value : bool + value contains the boolean value of the variable + + """ + try: + value = next(node.get_tokens()).spelling + except (StopIteration, ValueError): + value = node.literal + return True if value == 'true' else False + + def transform_unexposed_decl(self,node): + """Transformation function for unexposed declarations""" + pass + + def transform_unexposed_expr(self, node): + """Transformation function for unexposed expression + + Unexposed expressions are used to wrap float, double literals and + expressions + + Returns + ======= + + expr : Codegen AST Node + the result from the wrapped expression + + None : NoneType + No children are found for the node + + Raises + ====== + + ValueError if the expression contains multiple children + + """ + # Ignore unexposed nodes; pass whatever is the first + # (and should be only) child unaltered. + try: + children = node.get_children() + expr = self.transform(next(children)) + except StopIteration: + return None + + try: + next(children) + raise ValueError("Unexposed expression has > 1 children.") + except StopIteration: + pass + + return expr + + def transform_decl_ref_expr(self, node): + """Returns the name of the declaration reference""" + return node.spelling + + def transform_call_expr(self, node): + """Transformation function for a call expression + + Used to create function call nodes for the function calls present + in the C code + + Returns + ======= + + FunctionCall : Codegen AST Node + FunctionCall node with parameters if any parameters are present + + """ + param = [] + children = node.get_children() + child = next(children) + + while child.kind == cin.CursorKind.NAMESPACE_REF: + child = next(children) + while child.kind == cin.CursorKind.TYPE_REF: + child = next(children) + + first_child = self.transform(child) + try: + for child in children: + arg = self.transform(child) + if child.kind == cin.CursorKind.INTEGER_LITERAL: + param.append(Integer(arg)) + elif child.kind == cin.CursorKind.FLOATING_LITERAL: + param.append(Float(arg)) + else: + param.append(arg) + return FunctionCall(first_child, param) + + except StopIteration: + return FunctionCall(first_child) + + def transform_return_stmt(self, node): + """Returns the Return Node for a return statement""" + return Return(next(node.get_children()).spelling) + + def transform_compound_stmt(self, node): + """Transformation function for compound statements + + Returns + ======= + + expr : list + list of Nodes for the expressions present in the statement + + None : NoneType + if the compound statement is empty + + """ + expr = [] + children = node.get_children() + + for child in children: + expr.append(self.transform(child)) + return expr + + def transform_decl_stmt(self, node): + """Transformation function for declaration statements + + These statements are used to wrap different kinds of declararions + like variable or function declaration + The function calls the transformer function for the child of the + given node + + Returns + ======= + + statement : Codegen AST Node + contains the node returned by the children node for the type of + declaration + + Raises + ====== + + ValueError if multiple children present + + """ + try: + children = node.get_children() + statement = self.transform(next(children)) + except StopIteration: + pass + + try: + self.transform(next(children)) + raise ValueError("Don't know how to handle multiple statements") + except StopIteration: + pass + + return statement + + def transform_paren_expr(self, node): + """Transformation function for Parenthesized expressions + + Returns the result from its children nodes + + """ + return self.transform(next(node.get_children())) + + def transform_compound_assignment_operator(self, node): + """Transformation function for handling shorthand operators + + Returns + ======= + + augmented_assignment_expression: Codegen AST node + shorthand assignment expression represented as Codegen AST + + Raises + ====== + + NotImplementedError + If the shorthand operator for bitwise operators + (~=, ^=, &=, |=, <<=, >>=) is encountered + + """ + return self.transform_binary_operator(node) + + def transform_unary_operator(self, node): + """Transformation function for handling unary operators + + Returns + ======= + + unary_expression: Codegen AST node + simplified unary expression represented as Codegen AST + + Raises + ====== + + NotImplementedError + If dereferencing operator(*), address operator(&) or + bitwise NOT operator(~) is encountered + + """ + # supported operators list + operators_list = ['+', '-', '++', '--', '!'] + tokens = list(node.get_tokens()) + + # it can be either pre increment/decrement or any other operator from the list + if tokens[0].spelling in operators_list: + child = self.transform(next(node.get_children())) + # (decl_ref) e.g.; int a = ++b; or simply ++b; + if isinstance(child, str): + if tokens[0].spelling == '+': + return Symbol(child) + if tokens[0].spelling == '-': + return Mul(Symbol(child), -1) + if tokens[0].spelling == '++': + return PreIncrement(Symbol(child)) + if tokens[0].spelling == '--': + return PreDecrement(Symbol(child)) + if tokens[0].spelling == '!': + return Not(Symbol(child)) + # e.g.; int a = -1; or int b = -(1 + 2); + else: + if tokens[0].spelling == '+': + return child + if tokens[0].spelling == '-': + return Mul(child, -1) + if tokens[0].spelling == '!': + return Not(sympify(bool(child))) + + # it can be either post increment/decrement + # since variable name is obtained in token[0].spelling + elif tokens[1].spelling in ['++', '--']: + child = self.transform(next(node.get_children())) + if tokens[1].spelling == '++': + return PostIncrement(Symbol(child)) + if tokens[1].spelling == '--': + return PostDecrement(Symbol(child)) + else: + raise NotImplementedError("Dereferencing operator, " + "Address operator and bitwise NOT operator " + "have not been implemented yet!") + + def transform_binary_operator(self, node): + """Transformation function for handling binary operators + + Returns + ======= + + binary_expression: Codegen AST node + simplified binary expression represented as Codegen AST + + Raises + ====== + + NotImplementedError + If a bitwise operator or + unary operator(which is a child of any binary + operator in Clang AST) is encountered + + """ + # get all the tokens of assignment + # and store it in the tokens list + tokens = list(node.get_tokens()) + + # supported operators list + operators_list = ['+', '-', '*', '/', '%','=', + '>', '>=', '<', '<=', '==', '!=', '&&', '||', '+=', '-=', + '*=', '/=', '%='] + + # this stack will contain variable content + # and type of variable in the rhs + combined_variables_stack = [] + + # this stack will contain operators + # to be processed in the rhs + operators_stack = [] + + # iterate through every token + for token in tokens: + # token is either '(', ')' or + # any of the supported operators from the operator list + if token.kind == cin.TokenKind.PUNCTUATION: + + # push '(' to the operators stack + if token.spelling == '(': + operators_stack.append('(') + + elif token.spelling == ')': + # keep adding the expression to the + # combined variables stack unless + # '(' is found + while (operators_stack + and operators_stack[-1] != '('): + if len(combined_variables_stack) < 2: + raise NotImplementedError( + "Unary operators as a part of " + "binary operators is not " + "supported yet!") + rhs = combined_variables_stack.pop() + lhs = combined_variables_stack.pop() + operator = operators_stack.pop() + combined_variables_stack.append( + self.perform_operation( + lhs, rhs, operator)) + + # pop '(' + operators_stack.pop() + + # token is an operator (supported) + elif token.spelling in operators_list: + while (operators_stack + and self.priority_of(token.spelling) + <= self.priority_of( + operators_stack[-1])): + if len(combined_variables_stack) < 2: + raise NotImplementedError( + "Unary operators as a part of " + "binary operators is not " + "supported yet!") + rhs = combined_variables_stack.pop() + lhs = combined_variables_stack.pop() + operator = operators_stack.pop() + combined_variables_stack.append( + self.perform_operation( + lhs, rhs, operator)) + + # push current operator + operators_stack.append(token.spelling) + + # token is a bitwise operator + elif token.spelling in ['&', '|', '^', '<<', '>>']: + raise NotImplementedError( + "Bitwise operator has not been " + "implemented yet!") + + # token is a shorthand bitwise operator + elif token.spelling in ['&=', '|=', '^=', '<<=', + '>>=']: + raise NotImplementedError( + "Shorthand bitwise operator has not been " + "implemented yet!") + else: + raise NotImplementedError( + "Given token {} is not implemented yet!" + .format(token.spelling)) + + # token is an identifier(variable) + elif token.kind == cin.TokenKind.IDENTIFIER: + combined_variables_stack.append( + [token.spelling, 'identifier']) + + # token is a literal + elif token.kind == cin.TokenKind.LITERAL: + combined_variables_stack.append( + [token.spelling, 'literal']) + + # token is a keyword, either true or false + elif (token.kind == cin.TokenKind.KEYWORD + and token.spelling in ['true', 'false']): + combined_variables_stack.append( + [token.spelling, 'boolean']) + else: + raise NotImplementedError( + "Given token {} is not implemented yet!" + .format(token.spelling)) + + # process remaining operators + while operators_stack: + if len(combined_variables_stack) < 2: + raise NotImplementedError( + "Unary operators as a part of " + "binary operators is not " + "supported yet!") + rhs = combined_variables_stack.pop() + lhs = combined_variables_stack.pop() + operator = operators_stack.pop() + combined_variables_stack.append( + self.perform_operation(lhs, rhs, operator)) + + return combined_variables_stack[-1][0] + + def priority_of(self, op): + """To get the priority of given operator""" + if op in ['=', '+=', '-=', '*=', '/=', '%=']: + return 1 + if op in ['&&', '||']: + return 2 + if op in ['<', '<=', '>', '>=', '==', '!=']: + return 3 + if op in ['+', '-']: + return 4 + if op in ['*', '/', '%']: + return 5 + return 0 + + def perform_operation(self, lhs, rhs, op): + """Performs operation supported by the SymPy core + + Returns + ======= + + combined_variable: list + contains variable content and type of variable + + """ + lhs_value = self.get_expr_for_operand(lhs) + rhs_value = self.get_expr_for_operand(rhs) + if op == '+': + return [Add(lhs_value, rhs_value), 'expr'] + if op == '-': + return [Add(lhs_value, -rhs_value), 'expr'] + if op == '*': + return [Mul(lhs_value, rhs_value), 'expr'] + if op == '/': + return [Mul(lhs_value, Pow(rhs_value, Integer(-1))), 'expr'] + if op == '%': + return [Mod(lhs_value, rhs_value), 'expr'] + if op in ['<', '<=', '>', '>=', '==', '!=']: + return [Rel(lhs_value, rhs_value, op), 'expr'] + if op == '&&': + return [And(as_Boolean(lhs_value), as_Boolean(rhs_value)), 'expr'] + if op == '||': + return [Or(as_Boolean(lhs_value), as_Boolean(rhs_value)), 'expr'] + if op == '=': + return [Assignment(Variable(lhs_value), rhs_value), 'expr'] + if op in ['+=', '-=', '*=', '/=', '%=']: + return [aug_assign(Variable(lhs_value), op[0], rhs_value), 'expr'] + + def get_expr_for_operand(self, combined_variable): + """Gives out SymPy Codegen AST node + + AST node returned is corresponding to + combined variable passed.Combined variable contains + variable content and type of variable + + """ + if combined_variable[1] == 'identifier': + return Symbol(combined_variable[0]) + if combined_variable[1] == 'literal': + if '.' in combined_variable[0]: + return Float(float(combined_variable[0])) + else: + return Integer(int(combined_variable[0])) + if combined_variable[1] == 'expr': + return combined_variable[0] + if combined_variable[1] == 'boolean': + return true if combined_variable[0] == 'true' else false + + def transform_null_stmt(self, node): + """Handles Null Statement and returns None""" + return none + + def transform_while_stmt(self, node): + """Transformation function for handling while statement + + Returns + ======= + + while statement : Codegen AST Node + contains the while statement node having condition and + statement block + + """ + children = node.get_children() + + condition = self.transform(next(children)) + statements = self.transform(next(children)) + + if isinstance(statements, list): + statement_block = CodeBlock(*statements) + else: + statement_block = CodeBlock(statements) + + return While(condition, statement_block) + + + +else: + class CCodeConverter(): # type: ignore + def __init__(self, *args, **kwargs): + raise ImportError("Module not Installed") + + +def parse_c(source): + """Function for converting a C source code + + The function reads the source code present in the given file and parses it + to give out SymPy Expressions + + Returns + ======= + + src : list + List of Python expression strings + + """ + converter = CCodeConverter() + if os.path.exists(source): + src = converter.parse(source, flags = []) + else: + src = converter.parse_str(source, flags = []) + return src diff --git a/phivenv/Lib/site-packages/sympy/parsing/fortran/__init__.py b/phivenv/Lib/site-packages/sympy/parsing/fortran/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c65e37cf3de2dddbcee0fa5c7eeac2fdc9f685db --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/fortran/__init__.py @@ -0,0 +1 @@ +"""Used for translating Fortran source code into a SymPy expression. """ diff --git a/phivenv/Lib/site-packages/sympy/parsing/fortran/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/fortran/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f6ee948e5ce06873aa97c1c61469a2fe7207928 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/fortran/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/fortran/__pycache__/fortran_parser.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/fortran/__pycache__/fortran_parser.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d77e41b2f7127cb764e6fe471d2beacefc5a2c7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/fortran/__pycache__/fortran_parser.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/fortran/fortran_parser.py b/phivenv/Lib/site-packages/sympy/parsing/fortran/fortran_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..504249f6119a59a90d91c5e989f893cffe20e643 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/fortran/fortran_parser.py @@ -0,0 +1,347 @@ +from sympy.external import import_module + +lfortran = import_module('lfortran') + +if lfortran: + from sympy.codegen.ast import (Variable, IntBaseType, FloatBaseType, String, + Return, FunctionDefinition, Assignment) + from sympy.core import Add, Mul, Integer, Float + from sympy.core.symbol import Symbol + + asr_mod = lfortran.asr + asr = lfortran.asr.asr + src_to_ast = lfortran.ast.src_to_ast + ast_to_asr = lfortran.semantic.ast_to_asr.ast_to_asr + + """ + This module contains all the necessary Classes and Function used to Parse + Fortran code into SymPy expression + + The module and its API are currently under development and experimental. + It is also dependent on LFortran for the ASR that is converted to SymPy syntax + which is also under development. + The module only supports the features currently supported by the LFortran ASR + which will be updated as the development of LFortran and this module progresses + + You might find unexpected bugs and exceptions while using the module, feel free + to report them to the SymPy Issue Tracker + + The API for the module might also change while in development if better and + more effective ways are discovered for the process + + Features Supported + ================== + + - Variable Declarations (integers and reals) + - Function Definitions + - Assignments and Basic Binary Operations + + + Notes + ===== + + The module depends on an external dependency + + LFortran : Required to parse Fortran source code into ASR + + + References + ========== + + .. [1] https://github.com/sympy/sympy/issues + .. [2] https://gitlab.com/lfortran/lfortran + .. [3] https://docs.lfortran.org/ + + """ + + + class ASR2PyVisitor(asr.ASTVisitor): # type: ignore + """ + Visitor Class for LFortran ASR + + It is a Visitor class derived from asr.ASRVisitor which visits all the + nodes of the LFortran ASR and creates corresponding AST node for each + ASR node + + """ + + def __init__(self): + """Initialize the Parser""" + self._py_ast = [] + + def visit_TranslationUnit(self, node): + """ + Function to visit all the elements of the Translation Unit + created by LFortran ASR + """ + for s in node.global_scope.symbols: + sym = node.global_scope.symbols[s] + self.visit(sym) + for item in node.items: + self.visit(item) + + def visit_Assignment(self, node): + """Visitor Function for Assignment + + Visits each Assignment is the LFortran ASR and creates corresponding + assignment for SymPy. + + Notes + ===== + + The function currently only supports variable assignment and binary + operation assignments of varying multitudes. Any type of numberS or + array is not supported. + + Raises + ====== + + NotImplementedError() when called for Numeric assignments or Arrays + + """ + # TODO: Arithmetic Assignment + if isinstance(node.target, asr.Variable): + target = node.target + value = node.value + if isinstance(value, asr.Variable): + new_node = Assignment( + Variable( + target.name + ), + Variable( + value.name + ) + ) + elif (type(value) == asr.BinOp): + exp_ast = call_visitor(value) + for expr in exp_ast: + new_node = Assignment( + Variable(target.name), + expr + ) + else: + raise NotImplementedError("Numeric assignments not supported") + else: + raise NotImplementedError("Arrays not supported") + self._py_ast.append(new_node) + + def visit_BinOp(self, node): + """Visitor Function for Binary Operations + + Visits each binary operation present in the LFortran ASR like addition, + subtraction, multiplication, division and creates the corresponding + operation node in SymPy's AST + + In case of more than one binary operations, the function calls the + call_visitor() function on the child nodes of the binary operations + recursively until all the operations have been processed. + + Notes + ===== + + The function currently only supports binary operations with Variables + or other binary operations. Numerics are not supported as of yet. + + Raises + ====== + + NotImplementedError() when called for Numeric assignments + + """ + # TODO: Integer Binary Operations + op = node.op + lhs = node.left + rhs = node.right + + if (type(lhs) == asr.Variable): + left_value = Symbol(lhs.name) + elif(type(lhs) == asr.BinOp): + l_exp_ast = call_visitor(lhs) + for exp in l_exp_ast: + left_value = exp + else: + raise NotImplementedError("Numbers Currently not supported") + + if (type(rhs) == asr.Variable): + right_value = Symbol(rhs.name) + elif(type(rhs) == asr.BinOp): + r_exp_ast = call_visitor(rhs) + for exp in r_exp_ast: + right_value = exp + else: + raise NotImplementedError("Numbers Currently not supported") + + if isinstance(op, asr.Add): + new_node = Add(left_value, right_value) + elif isinstance(op, asr.Sub): + new_node = Add(left_value, -right_value) + elif isinstance(op, asr.Div): + new_node = Mul(left_value, 1/right_value) + elif isinstance(op, asr.Mul): + new_node = Mul(left_value, right_value) + + self._py_ast.append(new_node) + + def visit_Variable(self, node): + """Visitor Function for Variable Declaration + + Visits each variable declaration present in the ASR and creates a + Symbol declaration for each variable + + Notes + ===== + + The functions currently only support declaration of integer and + real variables. Other data types are still under development. + + Raises + ====== + + NotImplementedError() when called for unsupported data types + + """ + if isinstance(node.type, asr.Integer): + var_type = IntBaseType(String('integer')) + value = Integer(0) + elif isinstance(node.type, asr.Real): + var_type = FloatBaseType(String('real')) + value = Float(0.0) + else: + raise NotImplementedError("Data type not supported") + + if not (node.intent == 'in'): + new_node = Variable( + node.name + ).as_Declaration( + type = var_type, + value = value + ) + self._py_ast.append(new_node) + + def visit_Sequence(self, seq): + """Visitor Function for code sequence + + Visits a code sequence/ block and calls the visitor function on all the + children of the code block to create corresponding code in python + + """ + if seq is not None: + for node in seq: + self._py_ast.append(call_visitor(node)) + + def visit_Num(self, node): + """Visitor Function for Numbers in ASR + + This function is currently under development and will be updated + with improvements in the LFortran ASR + + """ + # TODO:Numbers when the LFortran ASR is updated + # self._py_ast.append(Integer(node.n)) + pass + + def visit_Function(self, node): + """Visitor Function for function Definitions + + Visits each function definition present in the ASR and creates a + function definition node in the Python AST with all the elements of the + given function + + The functions declare all the variables required as SymPy symbols in + the function before the function definition + + This function also the call_visior_function to parse the contents of + the function body + + """ + # TODO: Return statement, variable declaration + fn_args = [Variable(arg_iter.name) for arg_iter in node.args] + fn_body = [] + fn_name = node.name + for i in node.body: + fn_ast = call_visitor(i) + try: + fn_body_expr = fn_ast + except UnboundLocalError: + fn_body_expr = [] + for sym in node.symtab.symbols: + decl = call_visitor(node.symtab.symbols[sym]) + for symbols in decl: + fn_body.append(symbols) + for elem in fn_body_expr: + fn_body.append(elem) + fn_body.append( + Return( + Variable( + node.return_var.name + ) + ) + ) + if isinstance(node.return_var.type, asr.Integer): + ret_type = IntBaseType(String('integer')) + elif isinstance(node.return_var.type, asr.Real): + ret_type = FloatBaseType(String('real')) + else: + raise NotImplementedError("Data type not supported") + new_node = FunctionDefinition( + return_type = ret_type, + name = fn_name, + parameters = fn_args, + body = fn_body + ) + self._py_ast.append(new_node) + + def ret_ast(self): + """Returns the AST nodes""" + return self._py_ast +else: + class ASR2PyVisitor(): # type: ignore + def __init__(self, *args, **kwargs): + raise ImportError('lfortran not available') + +def call_visitor(fort_node): + """Calls the AST Visitor on the Module + + This function is used to call the AST visitor for a program or module + It imports all the required modules and calls the visit() function + on the given node + + Parameters + ========== + + fort_node : LFortran ASR object + Node for the operation for which the NodeVisitor is called + + Returns + ======= + + res_ast : list + list of SymPy AST Nodes + + """ + v = ASR2PyVisitor() + v.visit(fort_node) + res_ast = v.ret_ast() + return res_ast + + +def src_to_sympy(src): + """Wrapper function to convert the given Fortran source code to SymPy Expressions + + Parameters + ========== + + src : string + A string with the Fortran source code + + Returns + ======= + + py_src : string + A string with the Python source code compatible with SymPy + + """ + a_ast = src_to_ast(src, translation_unit=False) + a = ast_to_asr(a_ast) + py_src = call_visitor(a) + return py_src diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/LICENSE.txt b/phivenv/Lib/site-packages/sympy/parsing/latex/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..6bbfda911b2afada41a568218e31a6502dc68f44 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/LICENSE.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright 2016, latex2sympy + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/LaTeX.g4 b/phivenv/Lib/site-packages/sympy/parsing/latex/LaTeX.g4 new file mode 100644 index 0000000000000000000000000000000000000000..fc2c30f9817931e2060b549a39f98a6a4f9cb1f7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/LaTeX.g4 @@ -0,0 +1,312 @@ +/* + ANTLR4 LaTeX Math Grammar + + Ported from latex2sympy by @augustt198 https://github.com/augustt198/latex2sympy See license in + LICENSE.txt + */ + +/* + After changing this file, it is necessary to run `python setup.py antlr` in the root directory of + the repository. This will regenerate the code in `sympy/parsing/latex/_antlr/*.py`. + */ + +grammar LaTeX; + +options { + language = Python3; +} + +WS: [ \t\r\n]+ -> skip; +THINSPACE: ('\\,' | '\\thinspace') -> skip; +MEDSPACE: ('\\:' | '\\medspace') -> skip; +THICKSPACE: ('\\;' | '\\thickspace') -> skip; +QUAD: '\\quad' -> skip; +QQUAD: '\\qquad' -> skip; +NEGTHINSPACE: ('\\!' | '\\negthinspace') -> skip; +NEGMEDSPACE: '\\negmedspace' -> skip; +NEGTHICKSPACE: '\\negthickspace' -> skip; +CMD_LEFT: '\\left' -> skip; +CMD_RIGHT: '\\right' -> skip; + +IGNORE: + ( + '\\vrule' + | '\\vcenter' + | '\\vbox' + | '\\vskip' + | '\\vspace' + | '\\hfil' + | '\\*' + | '\\-' + | '\\.' + | '\\/' + | '\\"' + | '\\(' + | '\\=' + ) -> skip; + +ADD: '+'; +SUB: '-'; +MUL: '*'; +DIV: '/'; + +L_PAREN: '('; +R_PAREN: ')'; +L_BRACE: '{'; +R_BRACE: '}'; +L_BRACE_LITERAL: '\\{'; +R_BRACE_LITERAL: '\\}'; +L_BRACKET: '['; +R_BRACKET: ']'; + +BAR: '|'; + +R_BAR: '\\right|'; +L_BAR: '\\left|'; + +L_ANGLE: '\\langle'; +R_ANGLE: '\\rangle'; +FUNC_LIM: '\\lim'; +LIM_APPROACH_SYM: + '\\to' + | '\\rightarrow' + | '\\Rightarrow' + | '\\longrightarrow' + | '\\Longrightarrow'; +FUNC_INT: + '\\int' + | '\\int\\limits'; +FUNC_SUM: '\\sum'; +FUNC_PROD: '\\prod'; + +FUNC_EXP: '\\exp'; +FUNC_LOG: '\\log'; +FUNC_LG: '\\lg'; +FUNC_LN: '\\ln'; +FUNC_SIN: '\\sin'; +FUNC_COS: '\\cos'; +FUNC_TAN: '\\tan'; +FUNC_CSC: '\\csc'; +FUNC_SEC: '\\sec'; +FUNC_COT: '\\cot'; + +FUNC_ARCSIN: '\\arcsin'; +FUNC_ARCCOS: '\\arccos'; +FUNC_ARCTAN: '\\arctan'; +FUNC_ARCCSC: '\\arccsc'; +FUNC_ARCSEC: '\\arcsec'; +FUNC_ARCCOT: '\\arccot'; + +FUNC_SINH: '\\sinh'; +FUNC_COSH: '\\cosh'; +FUNC_TANH: '\\tanh'; +FUNC_ARSINH: '\\arsinh'; +FUNC_ARCOSH: '\\arcosh'; +FUNC_ARTANH: '\\artanh'; + +L_FLOOR: '\\lfloor'; +R_FLOOR: '\\rfloor'; +L_CEIL: '\\lceil'; +R_CEIL: '\\rceil'; + +FUNC_SQRT: '\\sqrt'; +FUNC_OVERLINE: '\\overline'; + +CMD_TIMES: '\\times'; +CMD_CDOT: '\\cdot'; +CMD_DIV: '\\div'; +CMD_FRAC: + '\\frac' + | '\\dfrac' + | '\\tfrac'; +CMD_BINOM: '\\binom'; +CMD_DBINOM: '\\dbinom'; +CMD_TBINOM: '\\tbinom'; + +CMD_MATHIT: '\\mathit'; + +UNDERSCORE: '_'; +CARET: '^'; +COLON: ':'; + +fragment WS_CHAR: [ \t\r\n]; +DIFFERENTIAL: 'd' WS_CHAR*? ([a-zA-Z] | '\\' [a-zA-Z]+); + +LETTER: [a-zA-Z]; +DIGIT: [0-9]; + +EQUAL: (('&' WS_CHAR*?)? '=') | ('=' (WS_CHAR*? '&')?); +NEQ: '\\neq'; + +LT: '<'; +LTE: ('\\leq' | '\\le' | LTE_Q | LTE_S); +LTE_Q: '\\leqq'; +LTE_S: '\\leqslant'; + +GT: '>'; +GTE: ('\\geq' | '\\ge' | GTE_Q | GTE_S); +GTE_Q: '\\geqq'; +GTE_S: '\\geqslant'; + +BANG: '!'; + +SINGLE_QUOTES: '\''+; + +SYMBOL: '\\' [a-zA-Z]+; + +math: relation; + +relation: + relation (EQUAL | LT | LTE | GT | GTE | NEQ) relation + | expr; + +equality: expr EQUAL expr; + +expr: additive; + +additive: additive (ADD | SUB) additive | mp; + +// mult part +mp: + mp (MUL | CMD_TIMES | CMD_CDOT | DIV | CMD_DIV | COLON) mp + | unary; + +mp_nofunc: + mp_nofunc ( + MUL + | CMD_TIMES + | CMD_CDOT + | DIV + | CMD_DIV + | COLON + ) mp_nofunc + | unary_nofunc; + +unary: (ADD | SUB) unary | postfix+; + +unary_nofunc: + (ADD | SUB) unary_nofunc + | postfix postfix_nofunc*; + +postfix: exp postfix_op*; +postfix_nofunc: exp_nofunc postfix_op*; +postfix_op: BANG | eval_at; + +eval_at: + BAR (eval_at_sup | eval_at_sub | eval_at_sup eval_at_sub); + +eval_at_sub: UNDERSCORE L_BRACE (expr | equality) R_BRACE; + +eval_at_sup: CARET L_BRACE (expr | equality) R_BRACE; + +exp: exp CARET (atom | L_BRACE expr R_BRACE) subexpr? | comp; + +exp_nofunc: + exp_nofunc CARET (atom | L_BRACE expr R_BRACE) subexpr? + | comp_nofunc; + +comp: + group + | abs_group + | func + | atom + | floor + | ceil; + +comp_nofunc: + group + | abs_group + | atom + | floor + | ceil; + +group: + L_PAREN expr R_PAREN + | L_BRACKET expr R_BRACKET + | L_BRACE expr R_BRACE + | L_BRACE_LITERAL expr R_BRACE_LITERAL; + +abs_group: BAR expr BAR; + +number: DIGIT+ (',' DIGIT DIGIT DIGIT)* ('.' DIGIT+)?; + +atom: (LETTER | SYMBOL) (subexpr? SINGLE_QUOTES? | SINGLE_QUOTES? subexpr?) + | number + | DIFFERENTIAL + | mathit + | frac + | binom + | bra + | ket; + +bra: L_ANGLE expr (R_BAR | BAR); +ket: (L_BAR | BAR) expr R_ANGLE; + +mathit: CMD_MATHIT L_BRACE mathit_text R_BRACE; +mathit_text: LETTER*; + +frac: CMD_FRAC (upperd = DIGIT | L_BRACE upper = expr R_BRACE) + (lowerd = DIGIT | L_BRACE lower = expr R_BRACE); + +binom: + (CMD_BINOM | CMD_DBINOM | CMD_TBINOM) L_BRACE n = expr R_BRACE L_BRACE k = expr R_BRACE; + +floor: L_FLOOR val = expr R_FLOOR; +ceil: L_CEIL val = expr R_CEIL; + +func_normal: + FUNC_EXP + | FUNC_LOG + | FUNC_LG + | FUNC_LN + | FUNC_SIN + | FUNC_COS + | FUNC_TAN + | FUNC_CSC + | FUNC_SEC + | FUNC_COT + | FUNC_ARCSIN + | FUNC_ARCCOS + | FUNC_ARCTAN + | FUNC_ARCCSC + | FUNC_ARCSEC + | FUNC_ARCCOT + | FUNC_SINH + | FUNC_COSH + | FUNC_TANH + | FUNC_ARSINH + | FUNC_ARCOSH + | FUNC_ARTANH; + +func: + func_normal (subexpr? supexpr? | supexpr? subexpr?) ( + L_PAREN func_arg R_PAREN + | func_arg_noparens + ) + | (LETTER | SYMBOL) (subexpr? SINGLE_QUOTES? | SINGLE_QUOTES? subexpr?) // e.g. f(x), f_1'(x) + L_PAREN args R_PAREN + | FUNC_INT (subexpr supexpr | supexpr subexpr)? ( + additive? DIFFERENTIAL + | frac + | additive + ) + | FUNC_SQRT (L_BRACKET root = expr R_BRACKET)? L_BRACE base = expr R_BRACE + | FUNC_OVERLINE L_BRACE base = expr R_BRACE + | (FUNC_SUM | FUNC_PROD) (subeq supexpr | supexpr subeq) mp + | FUNC_LIM limit_sub mp; + +args: (expr ',' args) | expr; + +limit_sub: + UNDERSCORE L_BRACE (LETTER | SYMBOL) LIM_APPROACH_SYM expr ( + CARET ((L_BRACE (ADD | SUB) R_BRACE) | ADD | SUB) + )? R_BRACE; + +func_arg: expr | (expr ',' func_arg); +func_arg_noparens: mp_nofunc; + +subexpr: UNDERSCORE (atom | L_BRACE expr R_BRACE); +supexpr: CARET (atom | L_BRACE expr R_BRACE); + +subeq: UNDERSCORE L_BRACE equality R_BRACE; +supeq: UNDERSCORE L_BRACE equality R_BRACE; diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/__init__.py b/phivenv/Lib/site-packages/sympy/parsing/latex/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9466d37b8b06f1f292c73f975e44d21c96da10d1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/__init__.py @@ -0,0 +1,204 @@ +from sympy.external import import_module +from sympy.utilities.decorator import doctest_depends_on +from re import compile as rcompile + +from sympy.parsing.latex.lark import LarkLaTeXParser, TransformToSymPyExpr, parse_latex_lark # noqa + +from .errors import LaTeXParsingError # noqa + + +IGNORE_L = r"\s*[{]*\s*" +IGNORE_R = r"\s*[}]*\s*" +NO_LEFT = r"(? len(latex_str): + e = len(latex_str) + eellipsis = "" + + if x[3] in END_DELIM_REPR: + err = (f"Extra '{x[2]}' at index {x[0]} or " + "missing corresponding " + f"'{BEGIN_DELIM_REPR[MATRIX_DELIMS_INV[x[3]]]}' " + f"in LaTeX string: {sellipsis}{latex_str[s:e]}" + f"{eellipsis}") + raise LaTeXParsingError(err) + + if x[7] is None: + err = (f"Extra '{x[2]}' at index {x[0]} or " + "missing corresponding " + f"'{END_DELIM_REPR[MATRIX_DELIMS[x[3]]]}' " + f"in LaTeX string: {sellipsis}{latex_str[s:e]}" + f"{eellipsis}") + raise LaTeXParsingError(err) + + correct_end_regex = MATRIX_DELIMS[x[3]] + sellipsis = "..." if x[0] > 0 else "" + eellipsis = "..." if x[5] < len(latex_str) else "" + if x[7] != correct_end_regex: + err = ("Expected " + f"'{END_DELIM_REPR[correct_end_regex]}' " + f"to close the '{x[2]}' at index {x[0]} but " + f"found '{x[6]}' at index {x[4]} of LaTeX " + f"string instead: {sellipsis}{latex_str[x[0]:x[5]]}" + f"{eellipsis}") + raise LaTeXParsingError(err) + +__doctest_requires__ = {('parse_latex',): ['antlr4', 'lark']} + + +@doctest_depends_on(modules=('antlr4', 'lark')) +def parse_latex(s, strict=False, backend="antlr"): + r"""Converts the input LaTeX string ``s`` to a SymPy ``Expr``. + + Parameters + ========== + + s : str + The LaTeX string to parse. In Python source containing LaTeX, + *raw strings* (denoted with ``r"``, like this one) are preferred, + as LaTeX makes liberal use of the ``\`` character, which would + trigger escaping in normal Python strings. + backend : str, optional + Currently, there are two backends supported: ANTLR, and Lark. + The default setting is to use the ANTLR backend, which can be + changed to Lark if preferred. + + Use ``backend="antlr"`` for the ANTLR-based parser, and + ``backend="lark"`` for the Lark-based parser. + + The ``backend`` option is case-sensitive, and must be in + all lowercase. + strict : bool, optional + This option is only available with the ANTLR backend. + + If True, raise an exception if the string cannot be parsed as + valid LaTeX. If False, try to recover gracefully from common + mistakes. + + Examples + ======== + + >>> from sympy.parsing.latex import parse_latex + >>> expr = parse_latex(r"\frac {1 + \sqrt {\a}} {\b}") + >>> expr + (sqrt(a) + 1)/b + >>> expr.evalf(4, subs=dict(a=5, b=2)) + 1.618 + >>> func = parse_latex(r"\int_1^\alpha \dfrac{\mathrm{d}t}{t}", backend="lark") + >>> func.evalf(subs={"alpha": 2}) + 0.693147180559945 + """ + + check_matrix_delimiters(s) + + if backend == "antlr": + _latex = import_module( + 'sympy.parsing.latex._parse_latex_antlr', + import_kwargs={'fromlist': ['X']}) + + if _latex is not None: + return _latex.parse_latex(s, strict) + elif backend == "lark": + return parse_latex_lark(s) + else: + raise NotImplementedError(f"Using the '{backend}' backend in the LaTeX" \ + " parser is not supported.") diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/latex/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a2b96e2c11620c4ab791d7740ffa1cdd97fd11b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/latex/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/__pycache__/_build_latex_antlr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/latex/__pycache__/_build_latex_antlr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86c6453aca37adb6d074ea7b7b03da1dbc16e9dc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/latex/__pycache__/_build_latex_antlr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/__pycache__/_parse_latex_antlr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/latex/__pycache__/_parse_latex_antlr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddde41a337c6efadc8b6c59ebfaab048350653b8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/latex/__pycache__/_parse_latex_antlr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/__pycache__/errors.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/latex/__pycache__/errors.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a818893b359e8456f25426847d3bb3b82f2d9d7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/latex/__pycache__/errors.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/__init__.py b/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2d690e1eb8631ee7731fc1875769d3a4704a1743 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/__init__.py @@ -0,0 +1,9 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated from ../LaTeX.g4, derived from latex2sympy +# latex2sympy is licensed under the MIT license +# https://github.com/augustt198/latex2sympy/blob/master/LICENSE.txt +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ef9a78e35acc65a50b129556e05b4ade4f86d9a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/__pycache__/latexlexer.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/__pycache__/latexlexer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3cbf33d64113e5a24f32bb23219cc34df87ad69 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/__pycache__/latexlexer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/latexlexer.py b/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/latexlexer.py new file mode 100644 index 0000000000000000000000000000000000000000..46ca959736c967782eef360b9b3268ccd0be0979 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/latexlexer.py @@ -0,0 +1,512 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated from ../LaTeX.g4, derived from latex2sympy +# latex2sympy is licensed under the MIT license +# https://github.com/augustt198/latex2sympy/blob/master/LICENSE.txt +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +from io import StringIO +import sys +if sys.version_info[1] > 5: + from typing import TextIO +else: + from typing.io import TextIO + + +def serializedATN(): + return [ + 4,0,91,911,6,-1,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5, + 2,6,7,6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2, + 13,7,13,2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7, + 19,2,20,7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2, + 26,7,26,2,27,7,27,2,28,7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7, + 32,2,33,7,33,2,34,7,34,2,35,7,35,2,36,7,36,2,37,7,37,2,38,7,38,2, + 39,7,39,2,40,7,40,2,41,7,41,2,42,7,42,2,43,7,43,2,44,7,44,2,45,7, + 45,2,46,7,46,2,47,7,47,2,48,7,48,2,49,7,49,2,50,7,50,2,51,7,51,2, + 52,7,52,2,53,7,53,2,54,7,54,2,55,7,55,2,56,7,56,2,57,7,57,2,58,7, + 58,2,59,7,59,2,60,7,60,2,61,7,61,2,62,7,62,2,63,7,63,2,64,7,64,2, + 65,7,65,2,66,7,66,2,67,7,67,2,68,7,68,2,69,7,69,2,70,7,70,2,71,7, + 71,2,72,7,72,2,73,7,73,2,74,7,74,2,75,7,75,2,76,7,76,2,77,7,77,2, + 78,7,78,2,79,7,79,2,80,7,80,2,81,7,81,2,82,7,82,2,83,7,83,2,84,7, + 84,2,85,7,85,2,86,7,86,2,87,7,87,2,88,7,88,2,89,7,89,2,90,7,90,2, + 91,7,91,1,0,1,0,1,1,1,1,1,2,4,2,191,8,2,11,2,12,2,192,1,2,1,2,1, + 3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,1,3,3,3,209,8,3,1,3,1, + 3,1,4,1,4,1,4,1,4,1,4,1,4,1,4,1,4,1,4,1,4,1,4,3,4,224,8,4,1,4,1, + 4,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,1,5,3,5,241,8, + 5,1,5,1,5,1,6,1,6,1,6,1,6,1,6,1,6,1,6,1,6,1,7,1,7,1,7,1,7,1,7,1, + 7,1,7,1,7,1,7,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1,8,1, + 8,1,8,1,8,3,8,277,8,8,1,8,1,8,1,9,1,9,1,9,1,9,1,9,1,9,1,9,1,9,1, + 9,1,9,1,9,1,9,1,9,1,9,1,9,1,10,1,10,1,10,1,10,1,10,1,10,1,10,1,10, + 1,10,1,10,1,10,1,10,1,10,1,10,1,10,1,10,1,10,1,11,1,11,1,11,1,11, + 1,11,1,11,1,11,1,11,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12,1,12, + 1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13, + 1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13, + 1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13, + 1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,1,13,3,13, + 381,8,13,1,13,1,13,1,14,1,14,1,15,1,15,1,16,1,16,1,17,1,17,1,18, + 1,18,1,19,1,19,1,20,1,20,1,21,1,21,1,22,1,22,1,22,1,23,1,23,1,23, + 1,24,1,24,1,25,1,25,1,26,1,26,1,27,1,27,1,27,1,27,1,27,1,27,1,27, + 1,27,1,28,1,28,1,28,1,28,1,28,1,28,1,28,1,29,1,29,1,29,1,29,1,29, + 1,29,1,29,1,29,1,30,1,30,1,30,1,30,1,30,1,30,1,30,1,30,1,31,1,31, + 1,31,1,31,1,31,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32, + 1,32,1,32,1,32,1,32,1,32,1,32,3,32,504,8,32,1,33,1,33,1,33,1,33, + 1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,1,33,3,33,521, + 8,33,1,34,1,34,1,34,1,34,1,34,1,35,1,35,1,35,1,35,1,35,1,35,1,36, + 1,36,1,36,1,36,1,36,1,37,1,37,1,37,1,37,1,37,1,38,1,38,1,38,1,38, + 1,39,1,39,1,39,1,39,1,40,1,40,1,40,1,40,1,40,1,41,1,41,1,41,1,41, + 1,41,1,42,1,42,1,42,1,42,1,42,1,43,1,43,1,43,1,43,1,43,1,44,1,44, + 1,44,1,44,1,44,1,45,1,45,1,45,1,45,1,45,1,46,1,46,1,46,1,46,1,46, + 1,46,1,46,1,46,1,47,1,47,1,47,1,47,1,47,1,47,1,47,1,47,1,48,1,48, + 1,48,1,48,1,48,1,48,1,48,1,48,1,49,1,49,1,49,1,49,1,49,1,49,1,49, + 1,49,1,50,1,50,1,50,1,50,1,50,1,50,1,50,1,50,1,51,1,51,1,51,1,51, + 1,51,1,51,1,51,1,51,1,52,1,52,1,52,1,52,1,52,1,52,1,53,1,53,1,53, + 1,53,1,53,1,53,1,54,1,54,1,54,1,54,1,54,1,54,1,55,1,55,1,55,1,55, + 1,55,1,55,1,55,1,55,1,56,1,56,1,56,1,56,1,56,1,56,1,56,1,56,1,57, + 1,57,1,57,1,57,1,57,1,57,1,57,1,57,1,58,1,58,1,58,1,58,1,58,1,58, + 1,58,1,58,1,59,1,59,1,59,1,59,1,59,1,59,1,59,1,59,1,60,1,60,1,60, + 1,60,1,60,1,60,1,60,1,61,1,61,1,61,1,61,1,61,1,61,1,61,1,62,1,62, + 1,62,1,62,1,62,1,62,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63,1,63, + 1,63,1,64,1,64,1,64,1,64,1,64,1,64,1,64,1,65,1,65,1,65,1,65,1,65, + 1,65,1,66,1,66,1,66,1,66,1,66,1,67,1,67,1,67,1,67,1,67,1,67,1,67, + 1,67,1,67,1,67,1,67,1,67,1,67,1,67,1,67,1,67,1,67,3,67,753,8,67, + 1,68,1,68,1,68,1,68,1,68,1,68,1,68,1,69,1,69,1,69,1,69,1,69,1,69, + 1,69,1,69,1,70,1,70,1,70,1,70,1,70,1,70,1,70,1,70,1,71,1,71,1,71, + 1,71,1,71,1,71,1,71,1,71,1,72,1,72,1,73,1,73,1,74,1,74,1,75,1,75, + 1,76,1,76,5,76,796,8,76,10,76,12,76,799,9,76,1,76,1,76,1,76,4,76, + 804,8,76,11,76,12,76,805,3,76,808,8,76,1,77,1,77,1,78,1,78,1,79, + 1,79,5,79,816,8,79,10,79,12,79,819,9,79,3,79,821,8,79,1,79,1,79, + 1,79,5,79,826,8,79,10,79,12,79,829,9,79,1,79,3,79,832,8,79,3,79, + 834,8,79,1,80,1,80,1,80,1,80,1,80,1,81,1,81,1,82,1,82,1,82,1,82, + 1,82,1,82,1,82,1,82,1,82,3,82,852,8,82,1,83,1,83,1,83,1,83,1,83, + 1,83,1,84,1,84,1,84,1,84,1,84,1,84,1,84,1,84,1,84,1,84,1,85,1,85, + 1,86,1,86,1,86,1,86,1,86,1,86,1,86,1,86,1,86,3,86,881,8,86,1,87, + 1,87,1,87,1,87,1,87,1,87,1,88,1,88,1,88,1,88,1,88,1,88,1,88,1,88, + 1,88,1,88,1,89,1,89,1,90,4,90,902,8,90,11,90,12,90,903,1,91,1,91, + 4,91,908,8,91,11,91,12,91,909,3,797,817,827,0,92,1,1,3,2,5,3,7,4, + 9,5,11,6,13,7,15,8,17,9,19,10,21,11,23,12,25,13,27,14,29,15,31,16, + 33,17,35,18,37,19,39,20,41,21,43,22,45,23,47,24,49,25,51,26,53,27, + 55,28,57,29,59,30,61,31,63,32,65,33,67,34,69,35,71,36,73,37,75,38, + 77,39,79,40,81,41,83,42,85,43,87,44,89,45,91,46,93,47,95,48,97,49, + 99,50,101,51,103,52,105,53,107,54,109,55,111,56,113,57,115,58,117, + 59,119,60,121,61,123,62,125,63,127,64,129,65,131,66,133,67,135,68, + 137,69,139,70,141,71,143,72,145,73,147,74,149,75,151,0,153,76,155, + 77,157,78,159,79,161,80,163,81,165,82,167,83,169,84,171,85,173,86, + 175,87,177,88,179,89,181,90,183,91,1,0,3,3,0,9,10,13,13,32,32,2, + 0,65,90,97,122,1,0,48,57,949,0,1,1,0,0,0,0,3,1,0,0,0,0,5,1,0,0,0, + 0,7,1,0,0,0,0,9,1,0,0,0,0,11,1,0,0,0,0,13,1,0,0,0,0,15,1,0,0,0,0, + 17,1,0,0,0,0,19,1,0,0,0,0,21,1,0,0,0,0,23,1,0,0,0,0,25,1,0,0,0,0, + 27,1,0,0,0,0,29,1,0,0,0,0,31,1,0,0,0,0,33,1,0,0,0,0,35,1,0,0,0,0, + 37,1,0,0,0,0,39,1,0,0,0,0,41,1,0,0,0,0,43,1,0,0,0,0,45,1,0,0,0,0, + 47,1,0,0,0,0,49,1,0,0,0,0,51,1,0,0,0,0,53,1,0,0,0,0,55,1,0,0,0,0, + 57,1,0,0,0,0,59,1,0,0,0,0,61,1,0,0,0,0,63,1,0,0,0,0,65,1,0,0,0,0, + 67,1,0,0,0,0,69,1,0,0,0,0,71,1,0,0,0,0,73,1,0,0,0,0,75,1,0,0,0,0, + 77,1,0,0,0,0,79,1,0,0,0,0,81,1,0,0,0,0,83,1,0,0,0,0,85,1,0,0,0,0, + 87,1,0,0,0,0,89,1,0,0,0,0,91,1,0,0,0,0,93,1,0,0,0,0,95,1,0,0,0,0, + 97,1,0,0,0,0,99,1,0,0,0,0,101,1,0,0,0,0,103,1,0,0,0,0,105,1,0,0, + 0,0,107,1,0,0,0,0,109,1,0,0,0,0,111,1,0,0,0,0,113,1,0,0,0,0,115, + 1,0,0,0,0,117,1,0,0,0,0,119,1,0,0,0,0,121,1,0,0,0,0,123,1,0,0,0, + 0,125,1,0,0,0,0,127,1,0,0,0,0,129,1,0,0,0,0,131,1,0,0,0,0,133,1, + 0,0,0,0,135,1,0,0,0,0,137,1,0,0,0,0,139,1,0,0,0,0,141,1,0,0,0,0, + 143,1,0,0,0,0,145,1,0,0,0,0,147,1,0,0,0,0,149,1,0,0,0,0,153,1,0, + 0,0,0,155,1,0,0,0,0,157,1,0,0,0,0,159,1,0,0,0,0,161,1,0,0,0,0,163, + 1,0,0,0,0,165,1,0,0,0,0,167,1,0,0,0,0,169,1,0,0,0,0,171,1,0,0,0, + 0,173,1,0,0,0,0,175,1,0,0,0,0,177,1,0,0,0,0,179,1,0,0,0,0,181,1, + 0,0,0,0,183,1,0,0,0,1,185,1,0,0,0,3,187,1,0,0,0,5,190,1,0,0,0,7, + 208,1,0,0,0,9,223,1,0,0,0,11,240,1,0,0,0,13,244,1,0,0,0,15,252,1, + 0,0,0,17,276,1,0,0,0,19,280,1,0,0,0,21,295,1,0,0,0,23,312,1,0,0, + 0,25,320,1,0,0,0,27,380,1,0,0,0,29,384,1,0,0,0,31,386,1,0,0,0,33, + 388,1,0,0,0,35,390,1,0,0,0,37,392,1,0,0,0,39,394,1,0,0,0,41,396, + 1,0,0,0,43,398,1,0,0,0,45,400,1,0,0,0,47,403,1,0,0,0,49,406,1,0, + 0,0,51,408,1,0,0,0,53,410,1,0,0,0,55,412,1,0,0,0,57,420,1,0,0,0, + 59,427,1,0,0,0,61,435,1,0,0,0,63,443,1,0,0,0,65,503,1,0,0,0,67,520, + 1,0,0,0,69,522,1,0,0,0,71,527,1,0,0,0,73,533,1,0,0,0,75,538,1,0, + 0,0,77,543,1,0,0,0,79,547,1,0,0,0,81,551,1,0,0,0,83,556,1,0,0,0, + 85,561,1,0,0,0,87,566,1,0,0,0,89,571,1,0,0,0,91,576,1,0,0,0,93,581, + 1,0,0,0,95,589,1,0,0,0,97,597,1,0,0,0,99,605,1,0,0,0,101,613,1,0, + 0,0,103,621,1,0,0,0,105,629,1,0,0,0,107,635,1,0,0,0,109,641,1,0, + 0,0,111,647,1,0,0,0,113,655,1,0,0,0,115,663,1,0,0,0,117,671,1,0, + 0,0,119,679,1,0,0,0,121,687,1,0,0,0,123,694,1,0,0,0,125,701,1,0, + 0,0,127,707,1,0,0,0,129,717,1,0,0,0,131,724,1,0,0,0,133,730,1,0, + 0,0,135,752,1,0,0,0,137,754,1,0,0,0,139,761,1,0,0,0,141,769,1,0, + 0,0,143,777,1,0,0,0,145,785,1,0,0,0,147,787,1,0,0,0,149,789,1,0, + 0,0,151,791,1,0,0,0,153,793,1,0,0,0,155,809,1,0,0,0,157,811,1,0, + 0,0,159,833,1,0,0,0,161,835,1,0,0,0,163,840,1,0,0,0,165,851,1,0, + 0,0,167,853,1,0,0,0,169,859,1,0,0,0,171,869,1,0,0,0,173,880,1,0, + 0,0,175,882,1,0,0,0,177,888,1,0,0,0,179,898,1,0,0,0,181,901,1,0, + 0,0,183,905,1,0,0,0,185,186,5,44,0,0,186,2,1,0,0,0,187,188,5,46, + 0,0,188,4,1,0,0,0,189,191,7,0,0,0,190,189,1,0,0,0,191,192,1,0,0, + 0,192,190,1,0,0,0,192,193,1,0,0,0,193,194,1,0,0,0,194,195,6,2,0, + 0,195,6,1,0,0,0,196,197,5,92,0,0,197,209,5,44,0,0,198,199,5,92,0, + 0,199,200,5,116,0,0,200,201,5,104,0,0,201,202,5,105,0,0,202,203, + 5,110,0,0,203,204,5,115,0,0,204,205,5,112,0,0,205,206,5,97,0,0,206, + 207,5,99,0,0,207,209,5,101,0,0,208,196,1,0,0,0,208,198,1,0,0,0,209, + 210,1,0,0,0,210,211,6,3,0,0,211,8,1,0,0,0,212,213,5,92,0,0,213,224, + 5,58,0,0,214,215,5,92,0,0,215,216,5,109,0,0,216,217,5,101,0,0,217, + 218,5,100,0,0,218,219,5,115,0,0,219,220,5,112,0,0,220,221,5,97,0, + 0,221,222,5,99,0,0,222,224,5,101,0,0,223,212,1,0,0,0,223,214,1,0, + 0,0,224,225,1,0,0,0,225,226,6,4,0,0,226,10,1,0,0,0,227,228,5,92, + 0,0,228,241,5,59,0,0,229,230,5,92,0,0,230,231,5,116,0,0,231,232, + 5,104,0,0,232,233,5,105,0,0,233,234,5,99,0,0,234,235,5,107,0,0,235, + 236,5,115,0,0,236,237,5,112,0,0,237,238,5,97,0,0,238,239,5,99,0, + 0,239,241,5,101,0,0,240,227,1,0,0,0,240,229,1,0,0,0,241,242,1,0, + 0,0,242,243,6,5,0,0,243,12,1,0,0,0,244,245,5,92,0,0,245,246,5,113, + 0,0,246,247,5,117,0,0,247,248,5,97,0,0,248,249,5,100,0,0,249,250, + 1,0,0,0,250,251,6,6,0,0,251,14,1,0,0,0,252,253,5,92,0,0,253,254, + 5,113,0,0,254,255,5,113,0,0,255,256,5,117,0,0,256,257,5,97,0,0,257, + 258,5,100,0,0,258,259,1,0,0,0,259,260,6,7,0,0,260,16,1,0,0,0,261, + 262,5,92,0,0,262,277,5,33,0,0,263,264,5,92,0,0,264,265,5,110,0,0, + 265,266,5,101,0,0,266,267,5,103,0,0,267,268,5,116,0,0,268,269,5, + 104,0,0,269,270,5,105,0,0,270,271,5,110,0,0,271,272,5,115,0,0,272, + 273,5,112,0,0,273,274,5,97,0,0,274,275,5,99,0,0,275,277,5,101,0, + 0,276,261,1,0,0,0,276,263,1,0,0,0,277,278,1,0,0,0,278,279,6,8,0, + 0,279,18,1,0,0,0,280,281,5,92,0,0,281,282,5,110,0,0,282,283,5,101, + 0,0,283,284,5,103,0,0,284,285,5,109,0,0,285,286,5,101,0,0,286,287, + 5,100,0,0,287,288,5,115,0,0,288,289,5,112,0,0,289,290,5,97,0,0,290, + 291,5,99,0,0,291,292,5,101,0,0,292,293,1,0,0,0,293,294,6,9,0,0,294, + 20,1,0,0,0,295,296,5,92,0,0,296,297,5,110,0,0,297,298,5,101,0,0, + 298,299,5,103,0,0,299,300,5,116,0,0,300,301,5,104,0,0,301,302,5, + 105,0,0,302,303,5,99,0,0,303,304,5,107,0,0,304,305,5,115,0,0,305, + 306,5,112,0,0,306,307,5,97,0,0,307,308,5,99,0,0,308,309,5,101,0, + 0,309,310,1,0,0,0,310,311,6,10,0,0,311,22,1,0,0,0,312,313,5,92,0, + 0,313,314,5,108,0,0,314,315,5,101,0,0,315,316,5,102,0,0,316,317, + 5,116,0,0,317,318,1,0,0,0,318,319,6,11,0,0,319,24,1,0,0,0,320,321, + 5,92,0,0,321,322,5,114,0,0,322,323,5,105,0,0,323,324,5,103,0,0,324, + 325,5,104,0,0,325,326,5,116,0,0,326,327,1,0,0,0,327,328,6,12,0,0, + 328,26,1,0,0,0,329,330,5,92,0,0,330,331,5,118,0,0,331,332,5,114, + 0,0,332,333,5,117,0,0,333,334,5,108,0,0,334,381,5,101,0,0,335,336, + 5,92,0,0,336,337,5,118,0,0,337,338,5,99,0,0,338,339,5,101,0,0,339, + 340,5,110,0,0,340,341,5,116,0,0,341,342,5,101,0,0,342,381,5,114, + 0,0,343,344,5,92,0,0,344,345,5,118,0,0,345,346,5,98,0,0,346,347, + 5,111,0,0,347,381,5,120,0,0,348,349,5,92,0,0,349,350,5,118,0,0,350, + 351,5,115,0,0,351,352,5,107,0,0,352,353,5,105,0,0,353,381,5,112, + 0,0,354,355,5,92,0,0,355,356,5,118,0,0,356,357,5,115,0,0,357,358, + 5,112,0,0,358,359,5,97,0,0,359,360,5,99,0,0,360,381,5,101,0,0,361, + 362,5,92,0,0,362,363,5,104,0,0,363,364,5,102,0,0,364,365,5,105,0, + 0,365,381,5,108,0,0,366,367,5,92,0,0,367,381,5,42,0,0,368,369,5, + 92,0,0,369,381,5,45,0,0,370,371,5,92,0,0,371,381,5,46,0,0,372,373, + 5,92,0,0,373,381,5,47,0,0,374,375,5,92,0,0,375,381,5,34,0,0,376, + 377,5,92,0,0,377,381,5,40,0,0,378,379,5,92,0,0,379,381,5,61,0,0, + 380,329,1,0,0,0,380,335,1,0,0,0,380,343,1,0,0,0,380,348,1,0,0,0, + 380,354,1,0,0,0,380,361,1,0,0,0,380,366,1,0,0,0,380,368,1,0,0,0, + 380,370,1,0,0,0,380,372,1,0,0,0,380,374,1,0,0,0,380,376,1,0,0,0, + 380,378,1,0,0,0,381,382,1,0,0,0,382,383,6,13,0,0,383,28,1,0,0,0, + 384,385,5,43,0,0,385,30,1,0,0,0,386,387,5,45,0,0,387,32,1,0,0,0, + 388,389,5,42,0,0,389,34,1,0,0,0,390,391,5,47,0,0,391,36,1,0,0,0, + 392,393,5,40,0,0,393,38,1,0,0,0,394,395,5,41,0,0,395,40,1,0,0,0, + 396,397,5,123,0,0,397,42,1,0,0,0,398,399,5,125,0,0,399,44,1,0,0, + 0,400,401,5,92,0,0,401,402,5,123,0,0,402,46,1,0,0,0,403,404,5,92, + 0,0,404,405,5,125,0,0,405,48,1,0,0,0,406,407,5,91,0,0,407,50,1,0, + 0,0,408,409,5,93,0,0,409,52,1,0,0,0,410,411,5,124,0,0,411,54,1,0, + 0,0,412,413,5,92,0,0,413,414,5,114,0,0,414,415,5,105,0,0,415,416, + 5,103,0,0,416,417,5,104,0,0,417,418,5,116,0,0,418,419,5,124,0,0, + 419,56,1,0,0,0,420,421,5,92,0,0,421,422,5,108,0,0,422,423,5,101, + 0,0,423,424,5,102,0,0,424,425,5,116,0,0,425,426,5,124,0,0,426,58, + 1,0,0,0,427,428,5,92,0,0,428,429,5,108,0,0,429,430,5,97,0,0,430, + 431,5,110,0,0,431,432,5,103,0,0,432,433,5,108,0,0,433,434,5,101, + 0,0,434,60,1,0,0,0,435,436,5,92,0,0,436,437,5,114,0,0,437,438,5, + 97,0,0,438,439,5,110,0,0,439,440,5,103,0,0,440,441,5,108,0,0,441, + 442,5,101,0,0,442,62,1,0,0,0,443,444,5,92,0,0,444,445,5,108,0,0, + 445,446,5,105,0,0,446,447,5,109,0,0,447,64,1,0,0,0,448,449,5,92, + 0,0,449,450,5,116,0,0,450,504,5,111,0,0,451,452,5,92,0,0,452,453, + 5,114,0,0,453,454,5,105,0,0,454,455,5,103,0,0,455,456,5,104,0,0, + 456,457,5,116,0,0,457,458,5,97,0,0,458,459,5,114,0,0,459,460,5,114, + 0,0,460,461,5,111,0,0,461,504,5,119,0,0,462,463,5,92,0,0,463,464, + 5,82,0,0,464,465,5,105,0,0,465,466,5,103,0,0,466,467,5,104,0,0,467, + 468,5,116,0,0,468,469,5,97,0,0,469,470,5,114,0,0,470,471,5,114,0, + 0,471,472,5,111,0,0,472,504,5,119,0,0,473,474,5,92,0,0,474,475,5, + 108,0,0,475,476,5,111,0,0,476,477,5,110,0,0,477,478,5,103,0,0,478, + 479,5,114,0,0,479,480,5,105,0,0,480,481,5,103,0,0,481,482,5,104, + 0,0,482,483,5,116,0,0,483,484,5,97,0,0,484,485,5,114,0,0,485,486, + 5,114,0,0,486,487,5,111,0,0,487,504,5,119,0,0,488,489,5,92,0,0,489, + 490,5,76,0,0,490,491,5,111,0,0,491,492,5,110,0,0,492,493,5,103,0, + 0,493,494,5,114,0,0,494,495,5,105,0,0,495,496,5,103,0,0,496,497, + 5,104,0,0,497,498,5,116,0,0,498,499,5,97,0,0,499,500,5,114,0,0,500, + 501,5,114,0,0,501,502,5,111,0,0,502,504,5,119,0,0,503,448,1,0,0, + 0,503,451,1,0,0,0,503,462,1,0,0,0,503,473,1,0,0,0,503,488,1,0,0, + 0,504,66,1,0,0,0,505,506,5,92,0,0,506,507,5,105,0,0,507,508,5,110, + 0,0,508,521,5,116,0,0,509,510,5,92,0,0,510,511,5,105,0,0,511,512, + 5,110,0,0,512,513,5,116,0,0,513,514,5,92,0,0,514,515,5,108,0,0,515, + 516,5,105,0,0,516,517,5,109,0,0,517,518,5,105,0,0,518,519,5,116, + 0,0,519,521,5,115,0,0,520,505,1,0,0,0,520,509,1,0,0,0,521,68,1,0, + 0,0,522,523,5,92,0,0,523,524,5,115,0,0,524,525,5,117,0,0,525,526, + 5,109,0,0,526,70,1,0,0,0,527,528,5,92,0,0,528,529,5,112,0,0,529, + 530,5,114,0,0,530,531,5,111,0,0,531,532,5,100,0,0,532,72,1,0,0,0, + 533,534,5,92,0,0,534,535,5,101,0,0,535,536,5,120,0,0,536,537,5,112, + 0,0,537,74,1,0,0,0,538,539,5,92,0,0,539,540,5,108,0,0,540,541,5, + 111,0,0,541,542,5,103,0,0,542,76,1,0,0,0,543,544,5,92,0,0,544,545, + 5,108,0,0,545,546,5,103,0,0,546,78,1,0,0,0,547,548,5,92,0,0,548, + 549,5,108,0,0,549,550,5,110,0,0,550,80,1,0,0,0,551,552,5,92,0,0, + 552,553,5,115,0,0,553,554,5,105,0,0,554,555,5,110,0,0,555,82,1,0, + 0,0,556,557,5,92,0,0,557,558,5,99,0,0,558,559,5,111,0,0,559,560, + 5,115,0,0,560,84,1,0,0,0,561,562,5,92,0,0,562,563,5,116,0,0,563, + 564,5,97,0,0,564,565,5,110,0,0,565,86,1,0,0,0,566,567,5,92,0,0,567, + 568,5,99,0,0,568,569,5,115,0,0,569,570,5,99,0,0,570,88,1,0,0,0,571, + 572,5,92,0,0,572,573,5,115,0,0,573,574,5,101,0,0,574,575,5,99,0, + 0,575,90,1,0,0,0,576,577,5,92,0,0,577,578,5,99,0,0,578,579,5,111, + 0,0,579,580,5,116,0,0,580,92,1,0,0,0,581,582,5,92,0,0,582,583,5, + 97,0,0,583,584,5,114,0,0,584,585,5,99,0,0,585,586,5,115,0,0,586, + 587,5,105,0,0,587,588,5,110,0,0,588,94,1,0,0,0,589,590,5,92,0,0, + 590,591,5,97,0,0,591,592,5,114,0,0,592,593,5,99,0,0,593,594,5,99, + 0,0,594,595,5,111,0,0,595,596,5,115,0,0,596,96,1,0,0,0,597,598,5, + 92,0,0,598,599,5,97,0,0,599,600,5,114,0,0,600,601,5,99,0,0,601,602, + 5,116,0,0,602,603,5,97,0,0,603,604,5,110,0,0,604,98,1,0,0,0,605, + 606,5,92,0,0,606,607,5,97,0,0,607,608,5,114,0,0,608,609,5,99,0,0, + 609,610,5,99,0,0,610,611,5,115,0,0,611,612,5,99,0,0,612,100,1,0, + 0,0,613,614,5,92,0,0,614,615,5,97,0,0,615,616,5,114,0,0,616,617, + 5,99,0,0,617,618,5,115,0,0,618,619,5,101,0,0,619,620,5,99,0,0,620, + 102,1,0,0,0,621,622,5,92,0,0,622,623,5,97,0,0,623,624,5,114,0,0, + 624,625,5,99,0,0,625,626,5,99,0,0,626,627,5,111,0,0,627,628,5,116, + 0,0,628,104,1,0,0,0,629,630,5,92,0,0,630,631,5,115,0,0,631,632,5, + 105,0,0,632,633,5,110,0,0,633,634,5,104,0,0,634,106,1,0,0,0,635, + 636,5,92,0,0,636,637,5,99,0,0,637,638,5,111,0,0,638,639,5,115,0, + 0,639,640,5,104,0,0,640,108,1,0,0,0,641,642,5,92,0,0,642,643,5,116, + 0,0,643,644,5,97,0,0,644,645,5,110,0,0,645,646,5,104,0,0,646,110, + 1,0,0,0,647,648,5,92,0,0,648,649,5,97,0,0,649,650,5,114,0,0,650, + 651,5,115,0,0,651,652,5,105,0,0,652,653,5,110,0,0,653,654,5,104, + 0,0,654,112,1,0,0,0,655,656,5,92,0,0,656,657,5,97,0,0,657,658,5, + 114,0,0,658,659,5,99,0,0,659,660,5,111,0,0,660,661,5,115,0,0,661, + 662,5,104,0,0,662,114,1,0,0,0,663,664,5,92,0,0,664,665,5,97,0,0, + 665,666,5,114,0,0,666,667,5,116,0,0,667,668,5,97,0,0,668,669,5,110, + 0,0,669,670,5,104,0,0,670,116,1,0,0,0,671,672,5,92,0,0,672,673,5, + 108,0,0,673,674,5,102,0,0,674,675,5,108,0,0,675,676,5,111,0,0,676, + 677,5,111,0,0,677,678,5,114,0,0,678,118,1,0,0,0,679,680,5,92,0,0, + 680,681,5,114,0,0,681,682,5,102,0,0,682,683,5,108,0,0,683,684,5, + 111,0,0,684,685,5,111,0,0,685,686,5,114,0,0,686,120,1,0,0,0,687, + 688,5,92,0,0,688,689,5,108,0,0,689,690,5,99,0,0,690,691,5,101,0, + 0,691,692,5,105,0,0,692,693,5,108,0,0,693,122,1,0,0,0,694,695,5, + 92,0,0,695,696,5,114,0,0,696,697,5,99,0,0,697,698,5,101,0,0,698, + 699,5,105,0,0,699,700,5,108,0,0,700,124,1,0,0,0,701,702,5,92,0,0, + 702,703,5,115,0,0,703,704,5,113,0,0,704,705,5,114,0,0,705,706,5, + 116,0,0,706,126,1,0,0,0,707,708,5,92,0,0,708,709,5,111,0,0,709,710, + 5,118,0,0,710,711,5,101,0,0,711,712,5,114,0,0,712,713,5,108,0,0, + 713,714,5,105,0,0,714,715,5,110,0,0,715,716,5,101,0,0,716,128,1, + 0,0,0,717,718,5,92,0,0,718,719,5,116,0,0,719,720,5,105,0,0,720,721, + 5,109,0,0,721,722,5,101,0,0,722,723,5,115,0,0,723,130,1,0,0,0,724, + 725,5,92,0,0,725,726,5,99,0,0,726,727,5,100,0,0,727,728,5,111,0, + 0,728,729,5,116,0,0,729,132,1,0,0,0,730,731,5,92,0,0,731,732,5,100, + 0,0,732,733,5,105,0,0,733,734,5,118,0,0,734,134,1,0,0,0,735,736, + 5,92,0,0,736,737,5,102,0,0,737,738,5,114,0,0,738,739,5,97,0,0,739, + 753,5,99,0,0,740,741,5,92,0,0,741,742,5,100,0,0,742,743,5,102,0, + 0,743,744,5,114,0,0,744,745,5,97,0,0,745,753,5,99,0,0,746,747,5, + 92,0,0,747,748,5,116,0,0,748,749,5,102,0,0,749,750,5,114,0,0,750, + 751,5,97,0,0,751,753,5,99,0,0,752,735,1,0,0,0,752,740,1,0,0,0,752, + 746,1,0,0,0,753,136,1,0,0,0,754,755,5,92,0,0,755,756,5,98,0,0,756, + 757,5,105,0,0,757,758,5,110,0,0,758,759,5,111,0,0,759,760,5,109, + 0,0,760,138,1,0,0,0,761,762,5,92,0,0,762,763,5,100,0,0,763,764,5, + 98,0,0,764,765,5,105,0,0,765,766,5,110,0,0,766,767,5,111,0,0,767, + 768,5,109,0,0,768,140,1,0,0,0,769,770,5,92,0,0,770,771,5,116,0,0, + 771,772,5,98,0,0,772,773,5,105,0,0,773,774,5,110,0,0,774,775,5,111, + 0,0,775,776,5,109,0,0,776,142,1,0,0,0,777,778,5,92,0,0,778,779,5, + 109,0,0,779,780,5,97,0,0,780,781,5,116,0,0,781,782,5,104,0,0,782, + 783,5,105,0,0,783,784,5,116,0,0,784,144,1,0,0,0,785,786,5,95,0,0, + 786,146,1,0,0,0,787,788,5,94,0,0,788,148,1,0,0,0,789,790,5,58,0, + 0,790,150,1,0,0,0,791,792,7,0,0,0,792,152,1,0,0,0,793,797,5,100, + 0,0,794,796,3,151,75,0,795,794,1,0,0,0,796,799,1,0,0,0,797,798,1, + 0,0,0,797,795,1,0,0,0,798,807,1,0,0,0,799,797,1,0,0,0,800,808,7, + 1,0,0,801,803,5,92,0,0,802,804,7,1,0,0,803,802,1,0,0,0,804,805,1, + 0,0,0,805,803,1,0,0,0,805,806,1,0,0,0,806,808,1,0,0,0,807,800,1, + 0,0,0,807,801,1,0,0,0,808,154,1,0,0,0,809,810,7,1,0,0,810,156,1, + 0,0,0,811,812,7,2,0,0,812,158,1,0,0,0,813,817,5,38,0,0,814,816,3, + 151,75,0,815,814,1,0,0,0,816,819,1,0,0,0,817,818,1,0,0,0,817,815, + 1,0,0,0,818,821,1,0,0,0,819,817,1,0,0,0,820,813,1,0,0,0,820,821, + 1,0,0,0,821,822,1,0,0,0,822,834,5,61,0,0,823,831,5,61,0,0,824,826, + 3,151,75,0,825,824,1,0,0,0,826,829,1,0,0,0,827,828,1,0,0,0,827,825, + 1,0,0,0,828,830,1,0,0,0,829,827,1,0,0,0,830,832,5,38,0,0,831,827, + 1,0,0,0,831,832,1,0,0,0,832,834,1,0,0,0,833,820,1,0,0,0,833,823, + 1,0,0,0,834,160,1,0,0,0,835,836,5,92,0,0,836,837,5,110,0,0,837,838, + 5,101,0,0,838,839,5,113,0,0,839,162,1,0,0,0,840,841,5,60,0,0,841, + 164,1,0,0,0,842,843,5,92,0,0,843,844,5,108,0,0,844,845,5,101,0,0, + 845,852,5,113,0,0,846,847,5,92,0,0,847,848,5,108,0,0,848,852,5,101, + 0,0,849,852,3,167,83,0,850,852,3,169,84,0,851,842,1,0,0,0,851,846, + 1,0,0,0,851,849,1,0,0,0,851,850,1,0,0,0,852,166,1,0,0,0,853,854, + 5,92,0,0,854,855,5,108,0,0,855,856,5,101,0,0,856,857,5,113,0,0,857, + 858,5,113,0,0,858,168,1,0,0,0,859,860,5,92,0,0,860,861,5,108,0,0, + 861,862,5,101,0,0,862,863,5,113,0,0,863,864,5,115,0,0,864,865,5, + 108,0,0,865,866,5,97,0,0,866,867,5,110,0,0,867,868,5,116,0,0,868, + 170,1,0,0,0,869,870,5,62,0,0,870,172,1,0,0,0,871,872,5,92,0,0,872, + 873,5,103,0,0,873,874,5,101,0,0,874,881,5,113,0,0,875,876,5,92,0, + 0,876,877,5,103,0,0,877,881,5,101,0,0,878,881,3,175,87,0,879,881, + 3,177,88,0,880,871,1,0,0,0,880,875,1,0,0,0,880,878,1,0,0,0,880,879, + 1,0,0,0,881,174,1,0,0,0,882,883,5,92,0,0,883,884,5,103,0,0,884,885, + 5,101,0,0,885,886,5,113,0,0,886,887,5,113,0,0,887,176,1,0,0,0,888, + 889,5,92,0,0,889,890,5,103,0,0,890,891,5,101,0,0,891,892,5,113,0, + 0,892,893,5,115,0,0,893,894,5,108,0,0,894,895,5,97,0,0,895,896,5, + 110,0,0,896,897,5,116,0,0,897,178,1,0,0,0,898,899,5,33,0,0,899,180, + 1,0,0,0,900,902,5,39,0,0,901,900,1,0,0,0,902,903,1,0,0,0,903,901, + 1,0,0,0,903,904,1,0,0,0,904,182,1,0,0,0,905,907,5,92,0,0,906,908, + 7,1,0,0,907,906,1,0,0,0,908,909,1,0,0,0,909,907,1,0,0,0,909,910, + 1,0,0,0,910,184,1,0,0,0,22,0,192,208,223,240,276,380,503,520,752, + 797,805,807,817,820,827,831,833,851,880,903,909,1,6,0,0 + ] + +class LaTeXLexer(Lexer): + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + T__0 = 1 + T__1 = 2 + WS = 3 + THINSPACE = 4 + MEDSPACE = 5 + THICKSPACE = 6 + QUAD = 7 + QQUAD = 8 + NEGTHINSPACE = 9 + NEGMEDSPACE = 10 + NEGTHICKSPACE = 11 + CMD_LEFT = 12 + CMD_RIGHT = 13 + IGNORE = 14 + ADD = 15 + SUB = 16 + MUL = 17 + DIV = 18 + L_PAREN = 19 + R_PAREN = 20 + L_BRACE = 21 + R_BRACE = 22 + L_BRACE_LITERAL = 23 + R_BRACE_LITERAL = 24 + L_BRACKET = 25 + R_BRACKET = 26 + BAR = 27 + R_BAR = 28 + L_BAR = 29 + L_ANGLE = 30 + R_ANGLE = 31 + FUNC_LIM = 32 + LIM_APPROACH_SYM = 33 + FUNC_INT = 34 + FUNC_SUM = 35 + FUNC_PROD = 36 + FUNC_EXP = 37 + FUNC_LOG = 38 + FUNC_LG = 39 + FUNC_LN = 40 + FUNC_SIN = 41 + FUNC_COS = 42 + FUNC_TAN = 43 + FUNC_CSC = 44 + FUNC_SEC = 45 + FUNC_COT = 46 + FUNC_ARCSIN = 47 + FUNC_ARCCOS = 48 + FUNC_ARCTAN = 49 + FUNC_ARCCSC = 50 + FUNC_ARCSEC = 51 + FUNC_ARCCOT = 52 + FUNC_SINH = 53 + FUNC_COSH = 54 + FUNC_TANH = 55 + FUNC_ARSINH = 56 + FUNC_ARCOSH = 57 + FUNC_ARTANH = 58 + L_FLOOR = 59 + R_FLOOR = 60 + L_CEIL = 61 + R_CEIL = 62 + FUNC_SQRT = 63 + FUNC_OVERLINE = 64 + CMD_TIMES = 65 + CMD_CDOT = 66 + CMD_DIV = 67 + CMD_FRAC = 68 + CMD_BINOM = 69 + CMD_DBINOM = 70 + CMD_TBINOM = 71 + CMD_MATHIT = 72 + UNDERSCORE = 73 + CARET = 74 + COLON = 75 + DIFFERENTIAL = 76 + LETTER = 77 + DIGIT = 78 + EQUAL = 79 + NEQ = 80 + LT = 81 + LTE = 82 + LTE_Q = 83 + LTE_S = 84 + GT = 85 + GTE = 86 + GTE_Q = 87 + GTE_S = 88 + BANG = 89 + SINGLE_QUOTES = 90 + SYMBOL = 91 + + channelNames = [ u"DEFAULT_TOKEN_CHANNEL", u"HIDDEN" ] + + modeNames = [ "DEFAULT_MODE" ] + + literalNames = [ "", + "','", "'.'", "'\\quad'", "'\\qquad'", "'\\negmedspace'", "'\\negthickspace'", + "'\\left'", "'\\right'", "'+'", "'-'", "'*'", "'/'", "'('", + "')'", "'{'", "'}'", "'\\{'", "'\\}'", "'['", "']'", "'|'", + "'\\right|'", "'\\left|'", "'\\langle'", "'\\rangle'", "'\\lim'", + "'\\sum'", "'\\prod'", "'\\exp'", "'\\log'", "'\\lg'", "'\\ln'", + "'\\sin'", "'\\cos'", "'\\tan'", "'\\csc'", "'\\sec'", "'\\cot'", + "'\\arcsin'", "'\\arccos'", "'\\arctan'", "'\\arccsc'", "'\\arcsec'", + "'\\arccot'", "'\\sinh'", "'\\cosh'", "'\\tanh'", "'\\arsinh'", + "'\\arcosh'", "'\\artanh'", "'\\lfloor'", "'\\rfloor'", "'\\lceil'", + "'\\rceil'", "'\\sqrt'", "'\\overline'", "'\\times'", "'\\cdot'", + "'\\div'", "'\\binom'", "'\\dbinom'", "'\\tbinom'", "'\\mathit'", + "'_'", "'^'", "':'", "'\\neq'", "'<'", "'\\leqq'", "'\\leqslant'", + "'>'", "'\\geqq'", "'\\geqslant'", "'!'" ] + + symbolicNames = [ "", + "WS", "THINSPACE", "MEDSPACE", "THICKSPACE", "QUAD", "QQUAD", + "NEGTHINSPACE", "NEGMEDSPACE", "NEGTHICKSPACE", "CMD_LEFT", + "CMD_RIGHT", "IGNORE", "ADD", "SUB", "MUL", "DIV", "L_PAREN", + "R_PAREN", "L_BRACE", "R_BRACE", "L_BRACE_LITERAL", "R_BRACE_LITERAL", + "L_BRACKET", "R_BRACKET", "BAR", "R_BAR", "L_BAR", "L_ANGLE", + "R_ANGLE", "FUNC_LIM", "LIM_APPROACH_SYM", "FUNC_INT", "FUNC_SUM", + "FUNC_PROD", "FUNC_EXP", "FUNC_LOG", "FUNC_LG", "FUNC_LN", "FUNC_SIN", + "FUNC_COS", "FUNC_TAN", "FUNC_CSC", "FUNC_SEC", "FUNC_COT", + "FUNC_ARCSIN", "FUNC_ARCCOS", "FUNC_ARCTAN", "FUNC_ARCCSC", + "FUNC_ARCSEC", "FUNC_ARCCOT", "FUNC_SINH", "FUNC_COSH", "FUNC_TANH", + "FUNC_ARSINH", "FUNC_ARCOSH", "FUNC_ARTANH", "L_FLOOR", "R_FLOOR", + "L_CEIL", "R_CEIL", "FUNC_SQRT", "FUNC_OVERLINE", "CMD_TIMES", + "CMD_CDOT", "CMD_DIV", "CMD_FRAC", "CMD_BINOM", "CMD_DBINOM", + "CMD_TBINOM", "CMD_MATHIT", "UNDERSCORE", "CARET", "COLON", + "DIFFERENTIAL", "LETTER", "DIGIT", "EQUAL", "NEQ", "LT", "LTE", + "LTE_Q", "LTE_S", "GT", "GTE", "GTE_Q", "GTE_S", "BANG", "SINGLE_QUOTES", + "SYMBOL" ] + + ruleNames = [ "T__0", "T__1", "WS", "THINSPACE", "MEDSPACE", "THICKSPACE", + "QUAD", "QQUAD", "NEGTHINSPACE", "NEGMEDSPACE", "NEGTHICKSPACE", + "CMD_LEFT", "CMD_RIGHT", "IGNORE", "ADD", "SUB", "MUL", + "DIV", "L_PAREN", "R_PAREN", "L_BRACE", "R_BRACE", "L_BRACE_LITERAL", + "R_BRACE_LITERAL", "L_BRACKET", "R_BRACKET", "BAR", "R_BAR", + "L_BAR", "L_ANGLE", "R_ANGLE", "FUNC_LIM", "LIM_APPROACH_SYM", + "FUNC_INT", "FUNC_SUM", "FUNC_PROD", "FUNC_EXP", "FUNC_LOG", + "FUNC_LG", "FUNC_LN", "FUNC_SIN", "FUNC_COS", "FUNC_TAN", + "FUNC_CSC", "FUNC_SEC", "FUNC_COT", "FUNC_ARCSIN", "FUNC_ARCCOS", + "FUNC_ARCTAN", "FUNC_ARCCSC", "FUNC_ARCSEC", "FUNC_ARCCOT", + "FUNC_SINH", "FUNC_COSH", "FUNC_TANH", "FUNC_ARSINH", + "FUNC_ARCOSH", "FUNC_ARTANH", "L_FLOOR", "R_FLOOR", "L_CEIL", + "R_CEIL", "FUNC_SQRT", "FUNC_OVERLINE", "CMD_TIMES", "CMD_CDOT", + "CMD_DIV", "CMD_FRAC", "CMD_BINOM", "CMD_DBINOM", "CMD_TBINOM", + "CMD_MATHIT", "UNDERSCORE", "CARET", "COLON", "WS_CHAR", + "DIFFERENTIAL", "LETTER", "DIGIT", "EQUAL", "NEQ", "LT", + "LTE", "LTE_Q", "LTE_S", "GT", "GTE", "GTE_Q", "GTE_S", + "BANG", "SINGLE_QUOTES", "SYMBOL" ] + + grammarFileName = "LaTeX.g4" + + def __init__(self, input=None, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.11.1") + self._interp = LexerATNSimulator(self, self.atn, self.decisionsToDFA, PredictionContextCache()) + self._actions = None + self._predicates = None + + diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/latexparser.py b/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/latexparser.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f58119055ded8f77380bbef52c77ddd6a01cfe --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/_antlr/latexparser.py @@ -0,0 +1,3652 @@ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated from ../LaTeX.g4, derived from latex2sympy +# latex2sympy is licensed under the MIT license +# https://github.com/augustt198/latex2sympy/blob/master/LICENSE.txt +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +from antlr4 import * +from io import StringIO +import sys +if sys.version_info[1] > 5: + from typing import TextIO +else: + from typing.io import TextIO + +def serializedATN(): + return [ + 4,1,91,522,2,0,7,0,2,1,7,1,2,2,7,2,2,3,7,3,2,4,7,4,2,5,7,5,2,6,7, + 6,2,7,7,7,2,8,7,8,2,9,7,9,2,10,7,10,2,11,7,11,2,12,7,12,2,13,7,13, + 2,14,7,14,2,15,7,15,2,16,7,16,2,17,7,17,2,18,7,18,2,19,7,19,2,20, + 7,20,2,21,7,21,2,22,7,22,2,23,7,23,2,24,7,24,2,25,7,25,2,26,7,26, + 2,27,7,27,2,28,7,28,2,29,7,29,2,30,7,30,2,31,7,31,2,32,7,32,2,33, + 7,33,2,34,7,34,2,35,7,35,2,36,7,36,2,37,7,37,2,38,7,38,2,39,7,39, + 2,40,7,40,1,0,1,0,1,1,1,1,1,1,1,1,1,1,1,1,5,1,91,8,1,10,1,12,1,94, + 9,1,1,2,1,2,1,2,1,2,1,3,1,3,1,4,1,4,1,4,1,4,1,4,1,4,5,4,108,8,4, + 10,4,12,4,111,9,4,1,5,1,5,1,5,1,5,1,5,1,5,5,5,119,8,5,10,5,12,5, + 122,9,5,1,6,1,6,1,6,1,6,1,6,1,6,5,6,130,8,6,10,6,12,6,133,9,6,1, + 7,1,7,1,7,4,7,138,8,7,11,7,12,7,139,3,7,142,8,7,1,8,1,8,1,8,1,8, + 5,8,148,8,8,10,8,12,8,151,9,8,3,8,153,8,8,1,9,1,9,5,9,157,8,9,10, + 9,12,9,160,9,9,1,10,1,10,5,10,164,8,10,10,10,12,10,167,9,10,1,11, + 1,11,3,11,171,8,11,1,12,1,12,1,12,1,12,1,12,1,12,3,12,179,8,12,1, + 13,1,13,1,13,1,13,3,13,185,8,13,1,13,1,13,1,14,1,14,1,14,1,14,3, + 14,193,8,14,1,14,1,14,1,15,1,15,1,15,1,15,1,15,1,15,1,15,1,15,1, + 15,1,15,3,15,207,8,15,1,15,3,15,210,8,15,5,15,212,8,15,10,15,12, + 15,215,9,15,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,1,16,3, + 16,227,8,16,1,16,3,16,230,8,16,5,16,232,8,16,10,16,12,16,235,9,16, + 1,17,1,17,1,17,1,17,1,17,1,17,3,17,243,8,17,1,18,1,18,1,18,1,18, + 1,18,3,18,250,8,18,1,19,1,19,1,19,1,19,1,19,1,19,1,19,1,19,1,19, + 1,19,1,19,1,19,1,19,1,19,1,19,1,19,3,19,268,8,19,1,20,1,20,1,20, + 1,20,1,21,4,21,275,8,21,11,21,12,21,276,1,21,1,21,1,21,1,21,5,21, + 283,8,21,10,21,12,21,286,9,21,1,21,1,21,4,21,290,8,21,11,21,12,21, + 291,3,21,294,8,21,1,22,1,22,3,22,298,8,22,1,22,3,22,301,8,22,1,22, + 3,22,304,8,22,1,22,3,22,307,8,22,3,22,309,8,22,1,22,1,22,1,22,1, + 22,1,22,1,22,1,22,3,22,318,8,22,1,23,1,23,1,23,1,23,1,24,1,24,1, + 24,1,24,1,25,1,25,1,25,1,25,1,25,1,26,5,26,334,8,26,10,26,12,26, + 337,9,26,1,27,1,27,1,27,1,27,1,27,1,27,3,27,345,8,27,1,27,1,27,1, + 27,1,27,1,27,3,27,352,8,27,1,28,1,28,1,28,1,28,1,28,1,28,1,28,1, + 28,1,29,1,29,1,29,1,29,1,30,1,30,1,30,1,30,1,31,1,31,1,32,1,32,3, + 32,374,8,32,1,32,3,32,377,8,32,1,32,3,32,380,8,32,1,32,3,32,383, + 8,32,3,32,385,8,32,1,32,1,32,1,32,1,32,1,32,3,32,392,8,32,1,32,1, + 32,3,32,396,8,32,1,32,3,32,399,8,32,1,32,3,32,402,8,32,1,32,3,32, + 405,8,32,3,32,407,8,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1, + 32,1,32,1,32,3,32,420,8,32,1,32,3,32,423,8,32,1,32,1,32,1,32,3,32, + 428,8,32,1,32,1,32,1,32,1,32,1,32,3,32,435,8,32,1,32,1,32,1,32,1, + 32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,1,32,3, + 32,453,8,32,1,32,1,32,1,32,1,32,1,32,1,32,3,32,461,8,32,1,33,1,33, + 1,33,1,33,1,33,3,33,468,8,33,1,34,1,34,1,34,1,34,1,34,1,34,1,34, + 1,34,1,34,1,34,1,34,3,34,481,8,34,3,34,483,8,34,1,34,1,34,1,35,1, + 35,1,35,1,35,1,35,3,35,492,8,35,1,36,1,36,1,37,1,37,1,37,1,37,1, + 37,1,37,3,37,502,8,37,1,38,1,38,1,38,1,38,1,38,1,38,3,38,510,8,38, + 1,39,1,39,1,39,1,39,1,39,1,40,1,40,1,40,1,40,1,40,1,40,0,6,2,8,10, + 12,30,32,41,0,2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36, + 38,40,42,44,46,48,50,52,54,56,58,60,62,64,66,68,70,72,74,76,78,80, + 0,9,2,0,79,82,85,86,1,0,15,16,3,0,17,18,65,67,75,75,2,0,77,77,91, + 91,1,0,27,28,2,0,27,27,29,29,1,0,69,71,1,0,37,58,1,0,35,36,563,0, + 82,1,0,0,0,2,84,1,0,0,0,4,95,1,0,0,0,6,99,1,0,0,0,8,101,1,0,0,0, + 10,112,1,0,0,0,12,123,1,0,0,0,14,141,1,0,0,0,16,152,1,0,0,0,18,154, + 1,0,0,0,20,161,1,0,0,0,22,170,1,0,0,0,24,172,1,0,0,0,26,180,1,0, + 0,0,28,188,1,0,0,0,30,196,1,0,0,0,32,216,1,0,0,0,34,242,1,0,0,0, + 36,249,1,0,0,0,38,267,1,0,0,0,40,269,1,0,0,0,42,274,1,0,0,0,44,317, + 1,0,0,0,46,319,1,0,0,0,48,323,1,0,0,0,50,327,1,0,0,0,52,335,1,0, + 0,0,54,338,1,0,0,0,56,353,1,0,0,0,58,361,1,0,0,0,60,365,1,0,0,0, + 62,369,1,0,0,0,64,460,1,0,0,0,66,467,1,0,0,0,68,469,1,0,0,0,70,491, + 1,0,0,0,72,493,1,0,0,0,74,495,1,0,0,0,76,503,1,0,0,0,78,511,1,0, + 0,0,80,516,1,0,0,0,82,83,3,2,1,0,83,1,1,0,0,0,84,85,6,1,-1,0,85, + 86,3,6,3,0,86,92,1,0,0,0,87,88,10,2,0,0,88,89,7,0,0,0,89,91,3,2, + 1,3,90,87,1,0,0,0,91,94,1,0,0,0,92,90,1,0,0,0,92,93,1,0,0,0,93,3, + 1,0,0,0,94,92,1,0,0,0,95,96,3,6,3,0,96,97,5,79,0,0,97,98,3,6,3,0, + 98,5,1,0,0,0,99,100,3,8,4,0,100,7,1,0,0,0,101,102,6,4,-1,0,102,103, + 3,10,5,0,103,109,1,0,0,0,104,105,10,2,0,0,105,106,7,1,0,0,106,108, + 3,8,4,3,107,104,1,0,0,0,108,111,1,0,0,0,109,107,1,0,0,0,109,110, + 1,0,0,0,110,9,1,0,0,0,111,109,1,0,0,0,112,113,6,5,-1,0,113,114,3, + 14,7,0,114,120,1,0,0,0,115,116,10,2,0,0,116,117,7,2,0,0,117,119, + 3,10,5,3,118,115,1,0,0,0,119,122,1,0,0,0,120,118,1,0,0,0,120,121, + 1,0,0,0,121,11,1,0,0,0,122,120,1,0,0,0,123,124,6,6,-1,0,124,125, + 3,16,8,0,125,131,1,0,0,0,126,127,10,2,0,0,127,128,7,2,0,0,128,130, + 3,12,6,3,129,126,1,0,0,0,130,133,1,0,0,0,131,129,1,0,0,0,131,132, + 1,0,0,0,132,13,1,0,0,0,133,131,1,0,0,0,134,135,7,1,0,0,135,142,3, + 14,7,0,136,138,3,18,9,0,137,136,1,0,0,0,138,139,1,0,0,0,139,137, + 1,0,0,0,139,140,1,0,0,0,140,142,1,0,0,0,141,134,1,0,0,0,141,137, + 1,0,0,0,142,15,1,0,0,0,143,144,7,1,0,0,144,153,3,16,8,0,145,149, + 3,18,9,0,146,148,3,20,10,0,147,146,1,0,0,0,148,151,1,0,0,0,149,147, + 1,0,0,0,149,150,1,0,0,0,150,153,1,0,0,0,151,149,1,0,0,0,152,143, + 1,0,0,0,152,145,1,0,0,0,153,17,1,0,0,0,154,158,3,30,15,0,155,157, + 3,22,11,0,156,155,1,0,0,0,157,160,1,0,0,0,158,156,1,0,0,0,158,159, + 1,0,0,0,159,19,1,0,0,0,160,158,1,0,0,0,161,165,3,32,16,0,162,164, + 3,22,11,0,163,162,1,0,0,0,164,167,1,0,0,0,165,163,1,0,0,0,165,166, + 1,0,0,0,166,21,1,0,0,0,167,165,1,0,0,0,168,171,5,89,0,0,169,171, + 3,24,12,0,170,168,1,0,0,0,170,169,1,0,0,0,171,23,1,0,0,0,172,178, + 5,27,0,0,173,179,3,28,14,0,174,179,3,26,13,0,175,176,3,28,14,0,176, + 177,3,26,13,0,177,179,1,0,0,0,178,173,1,0,0,0,178,174,1,0,0,0,178, + 175,1,0,0,0,179,25,1,0,0,0,180,181,5,73,0,0,181,184,5,21,0,0,182, + 185,3,6,3,0,183,185,3,4,2,0,184,182,1,0,0,0,184,183,1,0,0,0,185, + 186,1,0,0,0,186,187,5,22,0,0,187,27,1,0,0,0,188,189,5,74,0,0,189, + 192,5,21,0,0,190,193,3,6,3,0,191,193,3,4,2,0,192,190,1,0,0,0,192, + 191,1,0,0,0,193,194,1,0,0,0,194,195,5,22,0,0,195,29,1,0,0,0,196, + 197,6,15,-1,0,197,198,3,34,17,0,198,213,1,0,0,0,199,200,10,2,0,0, + 200,206,5,74,0,0,201,207,3,44,22,0,202,203,5,21,0,0,203,204,3,6, + 3,0,204,205,5,22,0,0,205,207,1,0,0,0,206,201,1,0,0,0,206,202,1,0, + 0,0,207,209,1,0,0,0,208,210,3,74,37,0,209,208,1,0,0,0,209,210,1, + 0,0,0,210,212,1,0,0,0,211,199,1,0,0,0,212,215,1,0,0,0,213,211,1, + 0,0,0,213,214,1,0,0,0,214,31,1,0,0,0,215,213,1,0,0,0,216,217,6,16, + -1,0,217,218,3,36,18,0,218,233,1,0,0,0,219,220,10,2,0,0,220,226, + 5,74,0,0,221,227,3,44,22,0,222,223,5,21,0,0,223,224,3,6,3,0,224, + 225,5,22,0,0,225,227,1,0,0,0,226,221,1,0,0,0,226,222,1,0,0,0,227, + 229,1,0,0,0,228,230,3,74,37,0,229,228,1,0,0,0,229,230,1,0,0,0,230, + 232,1,0,0,0,231,219,1,0,0,0,232,235,1,0,0,0,233,231,1,0,0,0,233, + 234,1,0,0,0,234,33,1,0,0,0,235,233,1,0,0,0,236,243,3,38,19,0,237, + 243,3,40,20,0,238,243,3,64,32,0,239,243,3,44,22,0,240,243,3,58,29, + 0,241,243,3,60,30,0,242,236,1,0,0,0,242,237,1,0,0,0,242,238,1,0, + 0,0,242,239,1,0,0,0,242,240,1,0,0,0,242,241,1,0,0,0,243,35,1,0,0, + 0,244,250,3,38,19,0,245,250,3,40,20,0,246,250,3,44,22,0,247,250, + 3,58,29,0,248,250,3,60,30,0,249,244,1,0,0,0,249,245,1,0,0,0,249, + 246,1,0,0,0,249,247,1,0,0,0,249,248,1,0,0,0,250,37,1,0,0,0,251,252, + 5,19,0,0,252,253,3,6,3,0,253,254,5,20,0,0,254,268,1,0,0,0,255,256, + 5,25,0,0,256,257,3,6,3,0,257,258,5,26,0,0,258,268,1,0,0,0,259,260, + 5,21,0,0,260,261,3,6,3,0,261,262,5,22,0,0,262,268,1,0,0,0,263,264, + 5,23,0,0,264,265,3,6,3,0,265,266,5,24,0,0,266,268,1,0,0,0,267,251, + 1,0,0,0,267,255,1,0,0,0,267,259,1,0,0,0,267,263,1,0,0,0,268,39,1, + 0,0,0,269,270,5,27,0,0,270,271,3,6,3,0,271,272,5,27,0,0,272,41,1, + 0,0,0,273,275,5,78,0,0,274,273,1,0,0,0,275,276,1,0,0,0,276,274,1, + 0,0,0,276,277,1,0,0,0,277,284,1,0,0,0,278,279,5,1,0,0,279,280,5, + 78,0,0,280,281,5,78,0,0,281,283,5,78,0,0,282,278,1,0,0,0,283,286, + 1,0,0,0,284,282,1,0,0,0,284,285,1,0,0,0,285,293,1,0,0,0,286,284, + 1,0,0,0,287,289,5,2,0,0,288,290,5,78,0,0,289,288,1,0,0,0,290,291, + 1,0,0,0,291,289,1,0,0,0,291,292,1,0,0,0,292,294,1,0,0,0,293,287, + 1,0,0,0,293,294,1,0,0,0,294,43,1,0,0,0,295,308,7,3,0,0,296,298,3, + 74,37,0,297,296,1,0,0,0,297,298,1,0,0,0,298,300,1,0,0,0,299,301, + 5,90,0,0,300,299,1,0,0,0,300,301,1,0,0,0,301,309,1,0,0,0,302,304, + 5,90,0,0,303,302,1,0,0,0,303,304,1,0,0,0,304,306,1,0,0,0,305,307, + 3,74,37,0,306,305,1,0,0,0,306,307,1,0,0,0,307,309,1,0,0,0,308,297, + 1,0,0,0,308,303,1,0,0,0,309,318,1,0,0,0,310,318,3,42,21,0,311,318, + 5,76,0,0,312,318,3,50,25,0,313,318,3,54,27,0,314,318,3,56,28,0,315, + 318,3,46,23,0,316,318,3,48,24,0,317,295,1,0,0,0,317,310,1,0,0,0, + 317,311,1,0,0,0,317,312,1,0,0,0,317,313,1,0,0,0,317,314,1,0,0,0, + 317,315,1,0,0,0,317,316,1,0,0,0,318,45,1,0,0,0,319,320,5,30,0,0, + 320,321,3,6,3,0,321,322,7,4,0,0,322,47,1,0,0,0,323,324,7,5,0,0,324, + 325,3,6,3,0,325,326,5,31,0,0,326,49,1,0,0,0,327,328,5,72,0,0,328, + 329,5,21,0,0,329,330,3,52,26,0,330,331,5,22,0,0,331,51,1,0,0,0,332, + 334,5,77,0,0,333,332,1,0,0,0,334,337,1,0,0,0,335,333,1,0,0,0,335, + 336,1,0,0,0,336,53,1,0,0,0,337,335,1,0,0,0,338,344,5,68,0,0,339, + 345,5,78,0,0,340,341,5,21,0,0,341,342,3,6,3,0,342,343,5,22,0,0,343, + 345,1,0,0,0,344,339,1,0,0,0,344,340,1,0,0,0,345,351,1,0,0,0,346, + 352,5,78,0,0,347,348,5,21,0,0,348,349,3,6,3,0,349,350,5,22,0,0,350, + 352,1,0,0,0,351,346,1,0,0,0,351,347,1,0,0,0,352,55,1,0,0,0,353,354, + 7,6,0,0,354,355,5,21,0,0,355,356,3,6,3,0,356,357,5,22,0,0,357,358, + 5,21,0,0,358,359,3,6,3,0,359,360,5,22,0,0,360,57,1,0,0,0,361,362, + 5,59,0,0,362,363,3,6,3,0,363,364,5,60,0,0,364,59,1,0,0,0,365,366, + 5,61,0,0,366,367,3,6,3,0,367,368,5,62,0,0,368,61,1,0,0,0,369,370, + 7,7,0,0,370,63,1,0,0,0,371,384,3,62,31,0,372,374,3,74,37,0,373,372, + 1,0,0,0,373,374,1,0,0,0,374,376,1,0,0,0,375,377,3,76,38,0,376,375, + 1,0,0,0,376,377,1,0,0,0,377,385,1,0,0,0,378,380,3,76,38,0,379,378, + 1,0,0,0,379,380,1,0,0,0,380,382,1,0,0,0,381,383,3,74,37,0,382,381, + 1,0,0,0,382,383,1,0,0,0,383,385,1,0,0,0,384,373,1,0,0,0,384,379, + 1,0,0,0,385,391,1,0,0,0,386,387,5,19,0,0,387,388,3,70,35,0,388,389, + 5,20,0,0,389,392,1,0,0,0,390,392,3,72,36,0,391,386,1,0,0,0,391,390, + 1,0,0,0,392,461,1,0,0,0,393,406,7,3,0,0,394,396,3,74,37,0,395,394, + 1,0,0,0,395,396,1,0,0,0,396,398,1,0,0,0,397,399,5,90,0,0,398,397, + 1,0,0,0,398,399,1,0,0,0,399,407,1,0,0,0,400,402,5,90,0,0,401,400, + 1,0,0,0,401,402,1,0,0,0,402,404,1,0,0,0,403,405,3,74,37,0,404,403, + 1,0,0,0,404,405,1,0,0,0,405,407,1,0,0,0,406,395,1,0,0,0,406,401, + 1,0,0,0,407,408,1,0,0,0,408,409,5,19,0,0,409,410,3,66,33,0,410,411, + 5,20,0,0,411,461,1,0,0,0,412,419,5,34,0,0,413,414,3,74,37,0,414, + 415,3,76,38,0,415,420,1,0,0,0,416,417,3,76,38,0,417,418,3,74,37, + 0,418,420,1,0,0,0,419,413,1,0,0,0,419,416,1,0,0,0,419,420,1,0,0, + 0,420,427,1,0,0,0,421,423,3,8,4,0,422,421,1,0,0,0,422,423,1,0,0, + 0,423,424,1,0,0,0,424,428,5,76,0,0,425,428,3,54,27,0,426,428,3,8, + 4,0,427,422,1,0,0,0,427,425,1,0,0,0,427,426,1,0,0,0,428,461,1,0, + 0,0,429,434,5,63,0,0,430,431,5,25,0,0,431,432,3,6,3,0,432,433,5, + 26,0,0,433,435,1,0,0,0,434,430,1,0,0,0,434,435,1,0,0,0,435,436,1, + 0,0,0,436,437,5,21,0,0,437,438,3,6,3,0,438,439,5,22,0,0,439,461, + 1,0,0,0,440,441,5,64,0,0,441,442,5,21,0,0,442,443,3,6,3,0,443,444, + 5,22,0,0,444,461,1,0,0,0,445,452,7,8,0,0,446,447,3,78,39,0,447,448, + 3,76,38,0,448,453,1,0,0,0,449,450,3,76,38,0,450,451,3,78,39,0,451, + 453,1,0,0,0,452,446,1,0,0,0,452,449,1,0,0,0,453,454,1,0,0,0,454, + 455,3,10,5,0,455,461,1,0,0,0,456,457,5,32,0,0,457,458,3,68,34,0, + 458,459,3,10,5,0,459,461,1,0,0,0,460,371,1,0,0,0,460,393,1,0,0,0, + 460,412,1,0,0,0,460,429,1,0,0,0,460,440,1,0,0,0,460,445,1,0,0,0, + 460,456,1,0,0,0,461,65,1,0,0,0,462,463,3,6,3,0,463,464,5,1,0,0,464, + 465,3,66,33,0,465,468,1,0,0,0,466,468,3,6,3,0,467,462,1,0,0,0,467, + 466,1,0,0,0,468,67,1,0,0,0,469,470,5,73,0,0,470,471,5,21,0,0,471, + 472,7,3,0,0,472,473,5,33,0,0,473,482,3,6,3,0,474,480,5,74,0,0,475, + 476,5,21,0,0,476,477,7,1,0,0,477,481,5,22,0,0,478,481,5,15,0,0,479, + 481,5,16,0,0,480,475,1,0,0,0,480,478,1,0,0,0,480,479,1,0,0,0,481, + 483,1,0,0,0,482,474,1,0,0,0,482,483,1,0,0,0,483,484,1,0,0,0,484, + 485,5,22,0,0,485,69,1,0,0,0,486,492,3,6,3,0,487,488,3,6,3,0,488, + 489,5,1,0,0,489,490,3,70,35,0,490,492,1,0,0,0,491,486,1,0,0,0,491, + 487,1,0,0,0,492,71,1,0,0,0,493,494,3,12,6,0,494,73,1,0,0,0,495,501, + 5,73,0,0,496,502,3,44,22,0,497,498,5,21,0,0,498,499,3,6,3,0,499, + 500,5,22,0,0,500,502,1,0,0,0,501,496,1,0,0,0,501,497,1,0,0,0,502, + 75,1,0,0,0,503,509,5,74,0,0,504,510,3,44,22,0,505,506,5,21,0,0,506, + 507,3,6,3,0,507,508,5,22,0,0,508,510,1,0,0,0,509,504,1,0,0,0,509, + 505,1,0,0,0,510,77,1,0,0,0,511,512,5,73,0,0,512,513,5,21,0,0,513, + 514,3,4,2,0,514,515,5,22,0,0,515,79,1,0,0,0,516,517,5,73,0,0,517, + 518,5,21,0,0,518,519,3,4,2,0,519,520,5,22,0,0,520,81,1,0,0,0,59, + 92,109,120,131,139,141,149,152,158,165,170,178,184,192,206,209,213, + 226,229,233,242,249,267,276,284,291,293,297,300,303,306,308,317, + 335,344,351,373,376,379,382,384,391,395,398,401,404,406,419,422, + 427,434,452,460,467,480,482,491,501,509 + ] + +class LaTeXParser ( Parser ): + + grammarFileName = "LaTeX.g4" + + atn = ATNDeserializer().deserialize(serializedATN()) + + decisionsToDFA = [ DFA(ds, i) for i, ds in enumerate(atn.decisionToState) ] + + sharedContextCache = PredictionContextCache() + + literalNames = [ "", "','", "'.'", "", "", + "", "", "'\\quad'", "'\\qquad'", + "", "'\\negmedspace'", "'\\negthickspace'", + "'\\left'", "'\\right'", "", "'+'", "'-'", + "'*'", "'/'", "'('", "')'", "'{'", "'}'", "'\\{'", + "'\\}'", "'['", "']'", "'|'", "'\\right|'", "'\\left|'", + "'\\langle'", "'\\rangle'", "'\\lim'", "", + "", "'\\sum'", "'\\prod'", "'\\exp'", "'\\log'", + "'\\lg'", "'\\ln'", "'\\sin'", "'\\cos'", "'\\tan'", + "'\\csc'", "'\\sec'", "'\\cot'", "'\\arcsin'", "'\\arccos'", + "'\\arctan'", "'\\arccsc'", "'\\arcsec'", "'\\arccot'", + "'\\sinh'", "'\\cosh'", "'\\tanh'", "'\\arsinh'", "'\\arcosh'", + "'\\artanh'", "'\\lfloor'", "'\\rfloor'", "'\\lceil'", + "'\\rceil'", "'\\sqrt'", "'\\overline'", "'\\times'", + "'\\cdot'", "'\\div'", "", "'\\binom'", "'\\dbinom'", + "'\\tbinom'", "'\\mathit'", "'_'", "'^'", "':'", "", + "", "", "", "'\\neq'", "'<'", + "", "'\\leqq'", "'\\leqslant'", "'>'", "", + "'\\geqq'", "'\\geqslant'", "'!'" ] + + symbolicNames = [ "", "", "", "WS", "THINSPACE", + "MEDSPACE", "THICKSPACE", "QUAD", "QQUAD", "NEGTHINSPACE", + "NEGMEDSPACE", "NEGTHICKSPACE", "CMD_LEFT", "CMD_RIGHT", + "IGNORE", "ADD", "SUB", "MUL", "DIV", "L_PAREN", "R_PAREN", + "L_BRACE", "R_BRACE", "L_BRACE_LITERAL", "R_BRACE_LITERAL", + "L_BRACKET", "R_BRACKET", "BAR", "R_BAR", "L_BAR", + "L_ANGLE", "R_ANGLE", "FUNC_LIM", "LIM_APPROACH_SYM", + "FUNC_INT", "FUNC_SUM", "FUNC_PROD", "FUNC_EXP", "FUNC_LOG", + "FUNC_LG", "FUNC_LN", "FUNC_SIN", "FUNC_COS", "FUNC_TAN", + "FUNC_CSC", "FUNC_SEC", "FUNC_COT", "FUNC_ARCSIN", + "FUNC_ARCCOS", "FUNC_ARCTAN", "FUNC_ARCCSC", "FUNC_ARCSEC", + "FUNC_ARCCOT", "FUNC_SINH", "FUNC_COSH", "FUNC_TANH", + "FUNC_ARSINH", "FUNC_ARCOSH", "FUNC_ARTANH", "L_FLOOR", + "R_FLOOR", "L_CEIL", "R_CEIL", "FUNC_SQRT", "FUNC_OVERLINE", + "CMD_TIMES", "CMD_CDOT", "CMD_DIV", "CMD_FRAC", "CMD_BINOM", + "CMD_DBINOM", "CMD_TBINOM", "CMD_MATHIT", "UNDERSCORE", + "CARET", "COLON", "DIFFERENTIAL", "LETTER", "DIGIT", + "EQUAL", "NEQ", "LT", "LTE", "LTE_Q", "LTE_S", "GT", + "GTE", "GTE_Q", "GTE_S", "BANG", "SINGLE_QUOTES", + "SYMBOL" ] + + RULE_math = 0 + RULE_relation = 1 + RULE_equality = 2 + RULE_expr = 3 + RULE_additive = 4 + RULE_mp = 5 + RULE_mp_nofunc = 6 + RULE_unary = 7 + RULE_unary_nofunc = 8 + RULE_postfix = 9 + RULE_postfix_nofunc = 10 + RULE_postfix_op = 11 + RULE_eval_at = 12 + RULE_eval_at_sub = 13 + RULE_eval_at_sup = 14 + RULE_exp = 15 + RULE_exp_nofunc = 16 + RULE_comp = 17 + RULE_comp_nofunc = 18 + RULE_group = 19 + RULE_abs_group = 20 + RULE_number = 21 + RULE_atom = 22 + RULE_bra = 23 + RULE_ket = 24 + RULE_mathit = 25 + RULE_mathit_text = 26 + RULE_frac = 27 + RULE_binom = 28 + RULE_floor = 29 + RULE_ceil = 30 + RULE_func_normal = 31 + RULE_func = 32 + RULE_args = 33 + RULE_limit_sub = 34 + RULE_func_arg = 35 + RULE_func_arg_noparens = 36 + RULE_subexpr = 37 + RULE_supexpr = 38 + RULE_subeq = 39 + RULE_supeq = 40 + + ruleNames = [ "math", "relation", "equality", "expr", "additive", "mp", + "mp_nofunc", "unary", "unary_nofunc", "postfix", "postfix_nofunc", + "postfix_op", "eval_at", "eval_at_sub", "eval_at_sup", + "exp", "exp_nofunc", "comp", "comp_nofunc", "group", + "abs_group", "number", "atom", "bra", "ket", "mathit", + "mathit_text", "frac", "binom", "floor", "ceil", "func_normal", + "func", "args", "limit_sub", "func_arg", "func_arg_noparens", + "subexpr", "supexpr", "subeq", "supeq" ] + + EOF = Token.EOF + T__0=1 + T__1=2 + WS=3 + THINSPACE=4 + MEDSPACE=5 + THICKSPACE=6 + QUAD=7 + QQUAD=8 + NEGTHINSPACE=9 + NEGMEDSPACE=10 + NEGTHICKSPACE=11 + CMD_LEFT=12 + CMD_RIGHT=13 + IGNORE=14 + ADD=15 + SUB=16 + MUL=17 + DIV=18 + L_PAREN=19 + R_PAREN=20 + L_BRACE=21 + R_BRACE=22 + L_BRACE_LITERAL=23 + R_BRACE_LITERAL=24 + L_BRACKET=25 + R_BRACKET=26 + BAR=27 + R_BAR=28 + L_BAR=29 + L_ANGLE=30 + R_ANGLE=31 + FUNC_LIM=32 + LIM_APPROACH_SYM=33 + FUNC_INT=34 + FUNC_SUM=35 + FUNC_PROD=36 + FUNC_EXP=37 + FUNC_LOG=38 + FUNC_LG=39 + FUNC_LN=40 + FUNC_SIN=41 + FUNC_COS=42 + FUNC_TAN=43 + FUNC_CSC=44 + FUNC_SEC=45 + FUNC_COT=46 + FUNC_ARCSIN=47 + FUNC_ARCCOS=48 + FUNC_ARCTAN=49 + FUNC_ARCCSC=50 + FUNC_ARCSEC=51 + FUNC_ARCCOT=52 + FUNC_SINH=53 + FUNC_COSH=54 + FUNC_TANH=55 + FUNC_ARSINH=56 + FUNC_ARCOSH=57 + FUNC_ARTANH=58 + L_FLOOR=59 + R_FLOOR=60 + L_CEIL=61 + R_CEIL=62 + FUNC_SQRT=63 + FUNC_OVERLINE=64 + CMD_TIMES=65 + CMD_CDOT=66 + CMD_DIV=67 + CMD_FRAC=68 + CMD_BINOM=69 + CMD_DBINOM=70 + CMD_TBINOM=71 + CMD_MATHIT=72 + UNDERSCORE=73 + CARET=74 + COLON=75 + DIFFERENTIAL=76 + LETTER=77 + DIGIT=78 + EQUAL=79 + NEQ=80 + LT=81 + LTE=82 + LTE_Q=83 + LTE_S=84 + GT=85 + GTE=86 + GTE_Q=87 + GTE_S=88 + BANG=89 + SINGLE_QUOTES=90 + SYMBOL=91 + + def __init__(self, input:TokenStream, output:TextIO = sys.stdout): + super().__init__(input, output) + self.checkVersion("4.11.1") + self._interp = ParserATNSimulator(self, self.atn, self.decisionsToDFA, self.sharedContextCache) + self._predicates = None + + + + + class MathContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def relation(self): + return self.getTypedRuleContext(LaTeXParser.RelationContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_math + + + + + def math(self): + + localctx = LaTeXParser.MathContext(self, self._ctx, self.state) + self.enterRule(localctx, 0, self.RULE_math) + try: + self.enterOuterAlt(localctx, 1) + self.state = 82 + self.relation(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class RelationContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def relation(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.RelationContext) + else: + return self.getTypedRuleContext(LaTeXParser.RelationContext,i) + + + def EQUAL(self): + return self.getToken(LaTeXParser.EQUAL, 0) + + def LT(self): + return self.getToken(LaTeXParser.LT, 0) + + def LTE(self): + return self.getToken(LaTeXParser.LTE, 0) + + def GT(self): + return self.getToken(LaTeXParser.GT, 0) + + def GTE(self): + return self.getToken(LaTeXParser.GTE, 0) + + def NEQ(self): + return self.getToken(LaTeXParser.NEQ, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_relation + + + + def relation(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.RelationContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 2 + self.enterRecursionRule(localctx, 2, self.RULE_relation, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 85 + self.expr() + self._ctx.stop = self._input.LT(-1) + self.state = 92 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,0,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.RelationContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_relation) + self.state = 87 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 88 + _la = self._input.LA(1) + if not((((_la - 79)) & ~0x3f) == 0 and ((1 << (_la - 79)) & 207) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 89 + self.relation(3) + self.state = 94 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,0,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class EqualityContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.ExprContext) + else: + return self.getTypedRuleContext(LaTeXParser.ExprContext,i) + + + def EQUAL(self): + return self.getToken(LaTeXParser.EQUAL, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_equality + + + + + def equality(self): + + localctx = LaTeXParser.EqualityContext(self, self._ctx, self.state) + self.enterRule(localctx, 4, self.RULE_equality) + try: + self.enterOuterAlt(localctx, 1) + self.state = 95 + self.expr() + self.state = 96 + self.match(LaTeXParser.EQUAL) + self.state = 97 + self.expr() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ExprContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def additive(self): + return self.getTypedRuleContext(LaTeXParser.AdditiveContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_expr + + + + + def expr(self): + + localctx = LaTeXParser.ExprContext(self, self._ctx, self.state) + self.enterRule(localctx, 6, self.RULE_expr) + try: + self.enterOuterAlt(localctx, 1) + self.state = 99 + self.additive(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AdditiveContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def mp(self): + return self.getTypedRuleContext(LaTeXParser.MpContext,0) + + + def additive(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.AdditiveContext) + else: + return self.getTypedRuleContext(LaTeXParser.AdditiveContext,i) + + + def ADD(self): + return self.getToken(LaTeXParser.ADD, 0) + + def SUB(self): + return self.getToken(LaTeXParser.SUB, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_additive + + + + def additive(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.AdditiveContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 8 + self.enterRecursionRule(localctx, 8, self.RULE_additive, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 102 + self.mp(0) + self._ctx.stop = self._input.LT(-1) + self.state = 109 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,1,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.AdditiveContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_additive) + self.state = 104 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 105 + _la = self._input.LA(1) + if not(_la==15 or _la==16): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 106 + self.additive(3) + self.state = 111 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,1,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class MpContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def unary(self): + return self.getTypedRuleContext(LaTeXParser.UnaryContext,0) + + + def mp(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.MpContext) + else: + return self.getTypedRuleContext(LaTeXParser.MpContext,i) + + + def MUL(self): + return self.getToken(LaTeXParser.MUL, 0) + + def CMD_TIMES(self): + return self.getToken(LaTeXParser.CMD_TIMES, 0) + + def CMD_CDOT(self): + return self.getToken(LaTeXParser.CMD_CDOT, 0) + + def DIV(self): + return self.getToken(LaTeXParser.DIV, 0) + + def CMD_DIV(self): + return self.getToken(LaTeXParser.CMD_DIV, 0) + + def COLON(self): + return self.getToken(LaTeXParser.COLON, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_mp + + + + def mp(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.MpContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 10 + self.enterRecursionRule(localctx, 10, self.RULE_mp, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 113 + self.unary() + self._ctx.stop = self._input.LT(-1) + self.state = 120 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,2,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.MpContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_mp) + self.state = 115 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 116 + _la = self._input.LA(1) + if not((((_la - 17)) & ~0x3f) == 0 and ((1 << (_la - 17)) & 290200700988686339) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 117 + self.mp(3) + self.state = 122 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,2,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class Mp_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def unary_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Unary_nofuncContext,0) + + + def mp_nofunc(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.Mp_nofuncContext) + else: + return self.getTypedRuleContext(LaTeXParser.Mp_nofuncContext,i) + + + def MUL(self): + return self.getToken(LaTeXParser.MUL, 0) + + def CMD_TIMES(self): + return self.getToken(LaTeXParser.CMD_TIMES, 0) + + def CMD_CDOT(self): + return self.getToken(LaTeXParser.CMD_CDOT, 0) + + def DIV(self): + return self.getToken(LaTeXParser.DIV, 0) + + def CMD_DIV(self): + return self.getToken(LaTeXParser.CMD_DIV, 0) + + def COLON(self): + return self.getToken(LaTeXParser.COLON, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_mp_nofunc + + + + def mp_nofunc(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.Mp_nofuncContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 12 + self.enterRecursionRule(localctx, 12, self.RULE_mp_nofunc, _p) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 124 + self.unary_nofunc() + self._ctx.stop = self._input.LT(-1) + self.state = 131 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,3,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.Mp_nofuncContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_mp_nofunc) + self.state = 126 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 127 + _la = self._input.LA(1) + if not((((_la - 17)) & ~0x3f) == 0 and ((1 << (_la - 17)) & 290200700988686339) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 128 + self.mp_nofunc(3) + self.state = 133 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,3,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class UnaryContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def unary(self): + return self.getTypedRuleContext(LaTeXParser.UnaryContext,0) + + + def ADD(self): + return self.getToken(LaTeXParser.ADD, 0) + + def SUB(self): + return self.getToken(LaTeXParser.SUB, 0) + + def postfix(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.PostfixContext) + else: + return self.getTypedRuleContext(LaTeXParser.PostfixContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_unary + + + + + def unary(self): + + localctx = LaTeXParser.UnaryContext(self, self._ctx, self.state) + self.enterRule(localctx, 14, self.RULE_unary) + self._la = 0 # Token type + try: + self.state = 141 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [15, 16]: + self.enterOuterAlt(localctx, 1) + self.state = 134 + _la = self._input.LA(1) + if not(_la==15 or _la==16): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 135 + self.unary() + pass + elif token in [19, 21, 23, 25, 27, 29, 30, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 63, 64, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.enterOuterAlt(localctx, 2) + self.state = 137 + self._errHandler.sync(self) + _alt = 1 + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt == 1: + self.state = 136 + self.postfix() + + else: + raise NoViableAltException(self) + self.state = 139 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,4,self._ctx) + + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Unary_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def unary_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Unary_nofuncContext,0) + + + def ADD(self): + return self.getToken(LaTeXParser.ADD, 0) + + def SUB(self): + return self.getToken(LaTeXParser.SUB, 0) + + def postfix(self): + return self.getTypedRuleContext(LaTeXParser.PostfixContext,0) + + + def postfix_nofunc(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.Postfix_nofuncContext) + else: + return self.getTypedRuleContext(LaTeXParser.Postfix_nofuncContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_unary_nofunc + + + + + def unary_nofunc(self): + + localctx = LaTeXParser.Unary_nofuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 16, self.RULE_unary_nofunc) + self._la = 0 # Token type + try: + self.state = 152 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [15, 16]: + self.enterOuterAlt(localctx, 1) + self.state = 143 + _la = self._input.LA(1) + if not(_la==15 or _la==16): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 144 + self.unary_nofunc() + pass + elif token in [19, 21, 23, 25, 27, 29, 30, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 63, 64, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.enterOuterAlt(localctx, 2) + self.state = 145 + self.postfix() + self.state = 149 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,6,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 146 + self.postfix_nofunc() + self.state = 151 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,6,self._ctx) + + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class PostfixContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def exp(self): + return self.getTypedRuleContext(LaTeXParser.ExpContext,0) + + + def postfix_op(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.Postfix_opContext) + else: + return self.getTypedRuleContext(LaTeXParser.Postfix_opContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_postfix + + + + + def postfix(self): + + localctx = LaTeXParser.PostfixContext(self, self._ctx, self.state) + self.enterRule(localctx, 18, self.RULE_postfix) + try: + self.enterOuterAlt(localctx, 1) + self.state = 154 + self.exp(0) + self.state = 158 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,8,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 155 + self.postfix_op() + self.state = 160 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,8,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Postfix_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def exp_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Exp_nofuncContext,0) + + + def postfix_op(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.Postfix_opContext) + else: + return self.getTypedRuleContext(LaTeXParser.Postfix_opContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_postfix_nofunc + + + + + def postfix_nofunc(self): + + localctx = LaTeXParser.Postfix_nofuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 20, self.RULE_postfix_nofunc) + try: + self.enterOuterAlt(localctx, 1) + self.state = 161 + self.exp_nofunc(0) + self.state = 165 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,9,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 162 + self.postfix_op() + self.state = 167 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,9,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Postfix_opContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def BANG(self): + return self.getToken(LaTeXParser.BANG, 0) + + def eval_at(self): + return self.getTypedRuleContext(LaTeXParser.Eval_atContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_postfix_op + + + + + def postfix_op(self): + + localctx = LaTeXParser.Postfix_opContext(self, self._ctx, self.state) + self.enterRule(localctx, 22, self.RULE_postfix_op) + try: + self.state = 170 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [89]: + self.enterOuterAlt(localctx, 1) + self.state = 168 + self.match(LaTeXParser.BANG) + pass + elif token in [27]: + self.enterOuterAlt(localctx, 2) + self.state = 169 + self.eval_at() + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Eval_atContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def BAR(self): + return self.getToken(LaTeXParser.BAR, 0) + + def eval_at_sup(self): + return self.getTypedRuleContext(LaTeXParser.Eval_at_supContext,0) + + + def eval_at_sub(self): + return self.getTypedRuleContext(LaTeXParser.Eval_at_subContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_eval_at + + + + + def eval_at(self): + + localctx = LaTeXParser.Eval_atContext(self, self._ctx, self.state) + self.enterRule(localctx, 24, self.RULE_eval_at) + try: + self.enterOuterAlt(localctx, 1) + self.state = 172 + self.match(LaTeXParser.BAR) + self.state = 178 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,11,self._ctx) + if la_ == 1: + self.state = 173 + self.eval_at_sup() + pass + + elif la_ == 2: + self.state = 174 + self.eval_at_sub() + pass + + elif la_ == 3: + self.state = 175 + self.eval_at_sup() + self.state = 176 + self.eval_at_sub() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Eval_at_subContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def equality(self): + return self.getTypedRuleContext(LaTeXParser.EqualityContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_eval_at_sub + + + + + def eval_at_sub(self): + + localctx = LaTeXParser.Eval_at_subContext(self, self._ctx, self.state) + self.enterRule(localctx, 26, self.RULE_eval_at_sub) + try: + self.enterOuterAlt(localctx, 1) + self.state = 180 + self.match(LaTeXParser.UNDERSCORE) + self.state = 181 + self.match(LaTeXParser.L_BRACE) + self.state = 184 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,12,self._ctx) + if la_ == 1: + self.state = 182 + self.expr() + pass + + elif la_ == 2: + self.state = 183 + self.equality() + pass + + + self.state = 186 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Eval_at_supContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def equality(self): + return self.getTypedRuleContext(LaTeXParser.EqualityContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_eval_at_sup + + + + + def eval_at_sup(self): + + localctx = LaTeXParser.Eval_at_supContext(self, self._ctx, self.state) + self.enterRule(localctx, 28, self.RULE_eval_at_sup) + try: + self.enterOuterAlt(localctx, 1) + self.state = 188 + self.match(LaTeXParser.CARET) + self.state = 189 + self.match(LaTeXParser.L_BRACE) + self.state = 192 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,13,self._ctx) + if la_ == 1: + self.state = 190 + self.expr() + pass + + elif la_ == 2: + self.state = 191 + self.equality() + pass + + + self.state = 194 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ExpContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def comp(self): + return self.getTypedRuleContext(LaTeXParser.CompContext,0) + + + def exp(self): + return self.getTypedRuleContext(LaTeXParser.ExpContext,0) + + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def subexpr(self): + return self.getTypedRuleContext(LaTeXParser.SubexprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_exp + + + + def exp(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.ExpContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 30 + self.enterRecursionRule(localctx, 30, self.RULE_exp, _p) + try: + self.enterOuterAlt(localctx, 1) + self.state = 197 + self.comp() + self._ctx.stop = self._input.LT(-1) + self.state = 213 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,16,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.ExpContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_exp) + self.state = 199 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 200 + self.match(LaTeXParser.CARET) + self.state = 206 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [27, 29, 30, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.state = 201 + self.atom() + pass + elif token in [21]: + self.state = 202 + self.match(LaTeXParser.L_BRACE) + self.state = 203 + self.expr() + self.state = 204 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + self.state = 209 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,15,self._ctx) + if la_ == 1: + self.state = 208 + self.subexpr() + + + self.state = 215 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,16,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class Exp_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def comp_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Comp_nofuncContext,0) + + + def exp_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Exp_nofuncContext,0) + + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def subexpr(self): + return self.getTypedRuleContext(LaTeXParser.SubexprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_exp_nofunc + + + + def exp_nofunc(self, _p:int=0): + _parentctx = self._ctx + _parentState = self.state + localctx = LaTeXParser.Exp_nofuncContext(self, self._ctx, _parentState) + _prevctx = localctx + _startState = 32 + self.enterRecursionRule(localctx, 32, self.RULE_exp_nofunc, _p) + try: + self.enterOuterAlt(localctx, 1) + self.state = 217 + self.comp_nofunc() + self._ctx.stop = self._input.LT(-1) + self.state = 233 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,19,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + if self._parseListeners is not None: + self.triggerExitRuleEvent() + _prevctx = localctx + localctx = LaTeXParser.Exp_nofuncContext(self, _parentctx, _parentState) + self.pushNewRecursionContext(localctx, _startState, self.RULE_exp_nofunc) + self.state = 219 + if not self.precpred(self._ctx, 2): + from antlr4.error.Errors import FailedPredicateException + raise FailedPredicateException(self, "self.precpred(self._ctx, 2)") + self.state = 220 + self.match(LaTeXParser.CARET) + self.state = 226 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [27, 29, 30, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.state = 221 + self.atom() + pass + elif token in [21]: + self.state = 222 + self.match(LaTeXParser.L_BRACE) + self.state = 223 + self.expr() + self.state = 224 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + self.state = 229 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,18,self._ctx) + if la_ == 1: + self.state = 228 + self.subexpr() + + + self.state = 235 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,19,self._ctx) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.unrollRecursionContexts(_parentctx) + return localctx + + + class CompContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def group(self): + return self.getTypedRuleContext(LaTeXParser.GroupContext,0) + + + def abs_group(self): + return self.getTypedRuleContext(LaTeXParser.Abs_groupContext,0) + + + def func(self): + return self.getTypedRuleContext(LaTeXParser.FuncContext,0) + + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def floor(self): + return self.getTypedRuleContext(LaTeXParser.FloorContext,0) + + + def ceil(self): + return self.getTypedRuleContext(LaTeXParser.CeilContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_comp + + + + + def comp(self): + + localctx = LaTeXParser.CompContext(self, self._ctx, self.state) + self.enterRule(localctx, 34, self.RULE_comp) + try: + self.state = 242 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,20,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 236 + self.group() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 237 + self.abs_group() + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 238 + self.func() + pass + + elif la_ == 4: + self.enterOuterAlt(localctx, 4) + self.state = 239 + self.atom() + pass + + elif la_ == 5: + self.enterOuterAlt(localctx, 5) + self.state = 240 + self.floor() + pass + + elif la_ == 6: + self.enterOuterAlt(localctx, 6) + self.state = 241 + self.ceil() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Comp_nofuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def group(self): + return self.getTypedRuleContext(LaTeXParser.GroupContext,0) + + + def abs_group(self): + return self.getTypedRuleContext(LaTeXParser.Abs_groupContext,0) + + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def floor(self): + return self.getTypedRuleContext(LaTeXParser.FloorContext,0) + + + def ceil(self): + return self.getTypedRuleContext(LaTeXParser.CeilContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_comp_nofunc + + + + + def comp_nofunc(self): + + localctx = LaTeXParser.Comp_nofuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 36, self.RULE_comp_nofunc) + try: + self.state = 249 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,21,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 244 + self.group() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 245 + self.abs_group() + pass + + elif la_ == 3: + self.enterOuterAlt(localctx, 3) + self.state = 246 + self.atom() + pass + + elif la_ == 4: + self.enterOuterAlt(localctx, 4) + self.state = 247 + self.floor() + pass + + elif la_ == 5: + self.enterOuterAlt(localctx, 5) + self.state = 248 + self.ceil() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class GroupContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def L_PAREN(self): + return self.getToken(LaTeXParser.L_PAREN, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_PAREN(self): + return self.getToken(LaTeXParser.R_PAREN, 0) + + def L_BRACKET(self): + return self.getToken(LaTeXParser.L_BRACKET, 0) + + def R_BRACKET(self): + return self.getToken(LaTeXParser.R_BRACKET, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def L_BRACE_LITERAL(self): + return self.getToken(LaTeXParser.L_BRACE_LITERAL, 0) + + def R_BRACE_LITERAL(self): + return self.getToken(LaTeXParser.R_BRACE_LITERAL, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_group + + + + + def group(self): + + localctx = LaTeXParser.GroupContext(self, self._ctx, self.state) + self.enterRule(localctx, 38, self.RULE_group) + try: + self.state = 267 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [19]: + self.enterOuterAlt(localctx, 1) + self.state = 251 + self.match(LaTeXParser.L_PAREN) + self.state = 252 + self.expr() + self.state = 253 + self.match(LaTeXParser.R_PAREN) + pass + elif token in [25]: + self.enterOuterAlt(localctx, 2) + self.state = 255 + self.match(LaTeXParser.L_BRACKET) + self.state = 256 + self.expr() + self.state = 257 + self.match(LaTeXParser.R_BRACKET) + pass + elif token in [21]: + self.enterOuterAlt(localctx, 3) + self.state = 259 + self.match(LaTeXParser.L_BRACE) + self.state = 260 + self.expr() + self.state = 261 + self.match(LaTeXParser.R_BRACE) + pass + elif token in [23]: + self.enterOuterAlt(localctx, 4) + self.state = 263 + self.match(LaTeXParser.L_BRACE_LITERAL) + self.state = 264 + self.expr() + self.state = 265 + self.match(LaTeXParser.R_BRACE_LITERAL) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Abs_groupContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def BAR(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.BAR) + else: + return self.getToken(LaTeXParser.BAR, i) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_abs_group + + + + + def abs_group(self): + + localctx = LaTeXParser.Abs_groupContext(self, self._ctx, self.state) + self.enterRule(localctx, 40, self.RULE_abs_group) + try: + self.enterOuterAlt(localctx, 1) + self.state = 269 + self.match(LaTeXParser.BAR) + self.state = 270 + self.expr() + self.state = 271 + self.match(LaTeXParser.BAR) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class NumberContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def DIGIT(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.DIGIT) + else: + return self.getToken(LaTeXParser.DIGIT, i) + + def getRuleIndex(self): + return LaTeXParser.RULE_number + + + + + def number(self): + + localctx = LaTeXParser.NumberContext(self, self._ctx, self.state) + self.enterRule(localctx, 42, self.RULE_number) + try: + self.enterOuterAlt(localctx, 1) + self.state = 274 + self._errHandler.sync(self) + _alt = 1 + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt == 1: + self.state = 273 + self.match(LaTeXParser.DIGIT) + + else: + raise NoViableAltException(self) + self.state = 276 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,23,self._ctx) + + self.state = 284 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,24,self._ctx) + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt==1: + self.state = 278 + self.match(LaTeXParser.T__0) + self.state = 279 + self.match(LaTeXParser.DIGIT) + self.state = 280 + self.match(LaTeXParser.DIGIT) + self.state = 281 + self.match(LaTeXParser.DIGIT) + self.state = 286 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,24,self._ctx) + + self.state = 293 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,26,self._ctx) + if la_ == 1: + self.state = 287 + self.match(LaTeXParser.T__1) + self.state = 289 + self._errHandler.sync(self) + _alt = 1 + while _alt!=2 and _alt!=ATN.INVALID_ALT_NUMBER: + if _alt == 1: + self.state = 288 + self.match(LaTeXParser.DIGIT) + + else: + raise NoViableAltException(self) + self.state = 291 + self._errHandler.sync(self) + _alt = self._interp.adaptivePredict(self._input,25,self._ctx) + + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class AtomContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def LETTER(self): + return self.getToken(LaTeXParser.LETTER, 0) + + def SYMBOL(self): + return self.getToken(LaTeXParser.SYMBOL, 0) + + def subexpr(self): + return self.getTypedRuleContext(LaTeXParser.SubexprContext,0) + + + def SINGLE_QUOTES(self): + return self.getToken(LaTeXParser.SINGLE_QUOTES, 0) + + def number(self): + return self.getTypedRuleContext(LaTeXParser.NumberContext,0) + + + def DIFFERENTIAL(self): + return self.getToken(LaTeXParser.DIFFERENTIAL, 0) + + def mathit(self): + return self.getTypedRuleContext(LaTeXParser.MathitContext,0) + + + def frac(self): + return self.getTypedRuleContext(LaTeXParser.FracContext,0) + + + def binom(self): + return self.getTypedRuleContext(LaTeXParser.BinomContext,0) + + + def bra(self): + return self.getTypedRuleContext(LaTeXParser.BraContext,0) + + + def ket(self): + return self.getTypedRuleContext(LaTeXParser.KetContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_atom + + + + + def atom(self): + + localctx = LaTeXParser.AtomContext(self, self._ctx, self.state) + self.enterRule(localctx, 44, self.RULE_atom) + self._la = 0 # Token type + try: + self.state = 317 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [77, 91]: + self.enterOuterAlt(localctx, 1) + self.state = 295 + _la = self._input.LA(1) + if not(_la==77 or _la==91): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 308 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,31,self._ctx) + if la_ == 1: + self.state = 297 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,27,self._ctx) + if la_ == 1: + self.state = 296 + self.subexpr() + + + self.state = 300 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,28,self._ctx) + if la_ == 1: + self.state = 299 + self.match(LaTeXParser.SINGLE_QUOTES) + + + pass + + elif la_ == 2: + self.state = 303 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,29,self._ctx) + if la_ == 1: + self.state = 302 + self.match(LaTeXParser.SINGLE_QUOTES) + + + self.state = 306 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,30,self._ctx) + if la_ == 1: + self.state = 305 + self.subexpr() + + + pass + + + pass + elif token in [78]: + self.enterOuterAlt(localctx, 2) + self.state = 310 + self.number() + pass + elif token in [76]: + self.enterOuterAlt(localctx, 3) + self.state = 311 + self.match(LaTeXParser.DIFFERENTIAL) + pass + elif token in [72]: + self.enterOuterAlt(localctx, 4) + self.state = 312 + self.mathit() + pass + elif token in [68]: + self.enterOuterAlt(localctx, 5) + self.state = 313 + self.frac() + pass + elif token in [69, 70, 71]: + self.enterOuterAlt(localctx, 6) + self.state = 314 + self.binom() + pass + elif token in [30]: + self.enterOuterAlt(localctx, 7) + self.state = 315 + self.bra() + pass + elif token in [27, 29]: + self.enterOuterAlt(localctx, 8) + self.state = 316 + self.ket() + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class BraContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def L_ANGLE(self): + return self.getToken(LaTeXParser.L_ANGLE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BAR(self): + return self.getToken(LaTeXParser.R_BAR, 0) + + def BAR(self): + return self.getToken(LaTeXParser.BAR, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_bra + + + + + def bra(self): + + localctx = LaTeXParser.BraContext(self, self._ctx, self.state) + self.enterRule(localctx, 46, self.RULE_bra) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 319 + self.match(LaTeXParser.L_ANGLE) + self.state = 320 + self.expr() + self.state = 321 + _la = self._input.LA(1) + if not(_la==27 or _la==28): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class KetContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_ANGLE(self): + return self.getToken(LaTeXParser.R_ANGLE, 0) + + def L_BAR(self): + return self.getToken(LaTeXParser.L_BAR, 0) + + def BAR(self): + return self.getToken(LaTeXParser.BAR, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_ket + + + + + def ket(self): + + localctx = LaTeXParser.KetContext(self, self._ctx, self.state) + self.enterRule(localctx, 48, self.RULE_ket) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 323 + _la = self._input.LA(1) + if not(_la==27 or _la==29): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 324 + self.expr() + self.state = 325 + self.match(LaTeXParser.R_ANGLE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class MathitContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CMD_MATHIT(self): + return self.getToken(LaTeXParser.CMD_MATHIT, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def mathit_text(self): + return self.getTypedRuleContext(LaTeXParser.Mathit_textContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_mathit + + + + + def mathit(self): + + localctx = LaTeXParser.MathitContext(self, self._ctx, self.state) + self.enterRule(localctx, 50, self.RULE_mathit) + try: + self.enterOuterAlt(localctx, 1) + self.state = 327 + self.match(LaTeXParser.CMD_MATHIT) + self.state = 328 + self.match(LaTeXParser.L_BRACE) + self.state = 329 + self.mathit_text() + self.state = 330 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Mathit_textContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def LETTER(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.LETTER) + else: + return self.getToken(LaTeXParser.LETTER, i) + + def getRuleIndex(self): + return LaTeXParser.RULE_mathit_text + + + + + def mathit_text(self): + + localctx = LaTeXParser.Mathit_textContext(self, self._ctx, self.state) + self.enterRule(localctx, 52, self.RULE_mathit_text) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 335 + self._errHandler.sync(self) + _la = self._input.LA(1) + while _la==77: + self.state = 332 + self.match(LaTeXParser.LETTER) + self.state = 337 + self._errHandler.sync(self) + _la = self._input.LA(1) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class FracContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.upperd = None # Token + self.upper = None # ExprContext + self.lowerd = None # Token + self.lower = None # ExprContext + + def CMD_FRAC(self): + return self.getToken(LaTeXParser.CMD_FRAC, 0) + + def L_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.L_BRACE) + else: + return self.getToken(LaTeXParser.L_BRACE, i) + + def R_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.R_BRACE) + else: + return self.getToken(LaTeXParser.R_BRACE, i) + + def DIGIT(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.DIGIT) + else: + return self.getToken(LaTeXParser.DIGIT, i) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.ExprContext) + else: + return self.getTypedRuleContext(LaTeXParser.ExprContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_frac + + + + + def frac(self): + + localctx = LaTeXParser.FracContext(self, self._ctx, self.state) + self.enterRule(localctx, 54, self.RULE_frac) + try: + self.enterOuterAlt(localctx, 1) + self.state = 338 + self.match(LaTeXParser.CMD_FRAC) + self.state = 344 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [78]: + self.state = 339 + localctx.upperd = self.match(LaTeXParser.DIGIT) + pass + elif token in [21]: + self.state = 340 + self.match(LaTeXParser.L_BRACE) + self.state = 341 + localctx.upper = self.expr() + self.state = 342 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + self.state = 351 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [78]: + self.state = 346 + localctx.lowerd = self.match(LaTeXParser.DIGIT) + pass + elif token in [21]: + self.state = 347 + self.match(LaTeXParser.L_BRACE) + self.state = 348 + localctx.lower = self.expr() + self.state = 349 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class BinomContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.n = None # ExprContext + self.k = None # ExprContext + + def L_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.L_BRACE) + else: + return self.getToken(LaTeXParser.L_BRACE, i) + + def R_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.R_BRACE) + else: + return self.getToken(LaTeXParser.R_BRACE, i) + + def CMD_BINOM(self): + return self.getToken(LaTeXParser.CMD_BINOM, 0) + + def CMD_DBINOM(self): + return self.getToken(LaTeXParser.CMD_DBINOM, 0) + + def CMD_TBINOM(self): + return self.getToken(LaTeXParser.CMD_TBINOM, 0) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.ExprContext) + else: + return self.getTypedRuleContext(LaTeXParser.ExprContext,i) + + + def getRuleIndex(self): + return LaTeXParser.RULE_binom + + + + + def binom(self): + + localctx = LaTeXParser.BinomContext(self, self._ctx, self.state) + self.enterRule(localctx, 56, self.RULE_binom) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 353 + _la = self._input.LA(1) + if not((((_la - 69)) & ~0x3f) == 0 and ((1 << (_la - 69)) & 7) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 354 + self.match(LaTeXParser.L_BRACE) + self.state = 355 + localctx.n = self.expr() + self.state = 356 + self.match(LaTeXParser.R_BRACE) + self.state = 357 + self.match(LaTeXParser.L_BRACE) + self.state = 358 + localctx.k = self.expr() + self.state = 359 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class FloorContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.val = None # ExprContext + + def L_FLOOR(self): + return self.getToken(LaTeXParser.L_FLOOR, 0) + + def R_FLOOR(self): + return self.getToken(LaTeXParser.R_FLOOR, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_floor + + + + + def floor(self): + + localctx = LaTeXParser.FloorContext(self, self._ctx, self.state) + self.enterRule(localctx, 58, self.RULE_floor) + try: + self.enterOuterAlt(localctx, 1) + self.state = 361 + self.match(LaTeXParser.L_FLOOR) + self.state = 362 + localctx.val = self.expr() + self.state = 363 + self.match(LaTeXParser.R_FLOOR) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class CeilContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.val = None # ExprContext + + def L_CEIL(self): + return self.getToken(LaTeXParser.L_CEIL, 0) + + def R_CEIL(self): + return self.getToken(LaTeXParser.R_CEIL, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_ceil + + + + + def ceil(self): + + localctx = LaTeXParser.CeilContext(self, self._ctx, self.state) + self.enterRule(localctx, 60, self.RULE_ceil) + try: + self.enterOuterAlt(localctx, 1) + self.state = 365 + self.match(LaTeXParser.L_CEIL) + self.state = 366 + localctx.val = self.expr() + self.state = 367 + self.match(LaTeXParser.R_CEIL) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Func_normalContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def FUNC_EXP(self): + return self.getToken(LaTeXParser.FUNC_EXP, 0) + + def FUNC_LOG(self): + return self.getToken(LaTeXParser.FUNC_LOG, 0) + + def FUNC_LG(self): + return self.getToken(LaTeXParser.FUNC_LG, 0) + + def FUNC_LN(self): + return self.getToken(LaTeXParser.FUNC_LN, 0) + + def FUNC_SIN(self): + return self.getToken(LaTeXParser.FUNC_SIN, 0) + + def FUNC_COS(self): + return self.getToken(LaTeXParser.FUNC_COS, 0) + + def FUNC_TAN(self): + return self.getToken(LaTeXParser.FUNC_TAN, 0) + + def FUNC_CSC(self): + return self.getToken(LaTeXParser.FUNC_CSC, 0) + + def FUNC_SEC(self): + return self.getToken(LaTeXParser.FUNC_SEC, 0) + + def FUNC_COT(self): + return self.getToken(LaTeXParser.FUNC_COT, 0) + + def FUNC_ARCSIN(self): + return self.getToken(LaTeXParser.FUNC_ARCSIN, 0) + + def FUNC_ARCCOS(self): + return self.getToken(LaTeXParser.FUNC_ARCCOS, 0) + + def FUNC_ARCTAN(self): + return self.getToken(LaTeXParser.FUNC_ARCTAN, 0) + + def FUNC_ARCCSC(self): + return self.getToken(LaTeXParser.FUNC_ARCCSC, 0) + + def FUNC_ARCSEC(self): + return self.getToken(LaTeXParser.FUNC_ARCSEC, 0) + + def FUNC_ARCCOT(self): + return self.getToken(LaTeXParser.FUNC_ARCCOT, 0) + + def FUNC_SINH(self): + return self.getToken(LaTeXParser.FUNC_SINH, 0) + + def FUNC_COSH(self): + return self.getToken(LaTeXParser.FUNC_COSH, 0) + + def FUNC_TANH(self): + return self.getToken(LaTeXParser.FUNC_TANH, 0) + + def FUNC_ARSINH(self): + return self.getToken(LaTeXParser.FUNC_ARSINH, 0) + + def FUNC_ARCOSH(self): + return self.getToken(LaTeXParser.FUNC_ARCOSH, 0) + + def FUNC_ARTANH(self): + return self.getToken(LaTeXParser.FUNC_ARTANH, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_func_normal + + + + + def func_normal(self): + + localctx = LaTeXParser.Func_normalContext(self, self._ctx, self.state) + self.enterRule(localctx, 62, self.RULE_func_normal) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 369 + _la = self._input.LA(1) + if not(((_la) & ~0x3f) == 0 and ((1 << _la) & 576460614864470016) != 0): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class FuncContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + self.root = None # ExprContext + self.base = None # ExprContext + + def func_normal(self): + return self.getTypedRuleContext(LaTeXParser.Func_normalContext,0) + + + def L_PAREN(self): + return self.getToken(LaTeXParser.L_PAREN, 0) + + def func_arg(self): + return self.getTypedRuleContext(LaTeXParser.Func_argContext,0) + + + def R_PAREN(self): + return self.getToken(LaTeXParser.R_PAREN, 0) + + def func_arg_noparens(self): + return self.getTypedRuleContext(LaTeXParser.Func_arg_noparensContext,0) + + + def subexpr(self): + return self.getTypedRuleContext(LaTeXParser.SubexprContext,0) + + + def supexpr(self): + return self.getTypedRuleContext(LaTeXParser.SupexprContext,0) + + + def args(self): + return self.getTypedRuleContext(LaTeXParser.ArgsContext,0) + + + def LETTER(self): + return self.getToken(LaTeXParser.LETTER, 0) + + def SYMBOL(self): + return self.getToken(LaTeXParser.SYMBOL, 0) + + def SINGLE_QUOTES(self): + return self.getToken(LaTeXParser.SINGLE_QUOTES, 0) + + def FUNC_INT(self): + return self.getToken(LaTeXParser.FUNC_INT, 0) + + def DIFFERENTIAL(self): + return self.getToken(LaTeXParser.DIFFERENTIAL, 0) + + def frac(self): + return self.getTypedRuleContext(LaTeXParser.FracContext,0) + + + def additive(self): + return self.getTypedRuleContext(LaTeXParser.AdditiveContext,0) + + + def FUNC_SQRT(self): + return self.getToken(LaTeXParser.FUNC_SQRT, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def expr(self, i:int=None): + if i is None: + return self.getTypedRuleContexts(LaTeXParser.ExprContext) + else: + return self.getTypedRuleContext(LaTeXParser.ExprContext,i) + + + def L_BRACKET(self): + return self.getToken(LaTeXParser.L_BRACKET, 0) + + def R_BRACKET(self): + return self.getToken(LaTeXParser.R_BRACKET, 0) + + def FUNC_OVERLINE(self): + return self.getToken(LaTeXParser.FUNC_OVERLINE, 0) + + def mp(self): + return self.getTypedRuleContext(LaTeXParser.MpContext,0) + + + def FUNC_SUM(self): + return self.getToken(LaTeXParser.FUNC_SUM, 0) + + def FUNC_PROD(self): + return self.getToken(LaTeXParser.FUNC_PROD, 0) + + def subeq(self): + return self.getTypedRuleContext(LaTeXParser.SubeqContext,0) + + + def FUNC_LIM(self): + return self.getToken(LaTeXParser.FUNC_LIM, 0) + + def limit_sub(self): + return self.getTypedRuleContext(LaTeXParser.Limit_subContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_func + + + + + def func(self): + + localctx = LaTeXParser.FuncContext(self, self._ctx, self.state) + self.enterRule(localctx, 64, self.RULE_func) + self._la = 0 # Token type + try: + self.state = 460 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58]: + self.enterOuterAlt(localctx, 1) + self.state = 371 + self.func_normal() + self.state = 384 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,40,self._ctx) + if la_ == 1: + self.state = 373 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==73: + self.state = 372 + self.subexpr() + + + self.state = 376 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==74: + self.state = 375 + self.supexpr() + + + pass + + elif la_ == 2: + self.state = 379 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==74: + self.state = 378 + self.supexpr() + + + self.state = 382 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==73: + self.state = 381 + self.subexpr() + + + pass + + + self.state = 391 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,41,self._ctx) + if la_ == 1: + self.state = 386 + self.match(LaTeXParser.L_PAREN) + self.state = 387 + self.func_arg() + self.state = 388 + self.match(LaTeXParser.R_PAREN) + pass + + elif la_ == 2: + self.state = 390 + self.func_arg_noparens() + pass + + + pass + elif token in [77, 91]: + self.enterOuterAlt(localctx, 2) + self.state = 393 + _la = self._input.LA(1) + if not(_la==77 or _la==91): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 406 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,46,self._ctx) + if la_ == 1: + self.state = 395 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==73: + self.state = 394 + self.subexpr() + + + self.state = 398 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==90: + self.state = 397 + self.match(LaTeXParser.SINGLE_QUOTES) + + + pass + + elif la_ == 2: + self.state = 401 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==90: + self.state = 400 + self.match(LaTeXParser.SINGLE_QUOTES) + + + self.state = 404 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==73: + self.state = 403 + self.subexpr() + + + pass + + + self.state = 408 + self.match(LaTeXParser.L_PAREN) + self.state = 409 + self.args() + self.state = 410 + self.match(LaTeXParser.R_PAREN) + pass + elif token in [34]: + self.enterOuterAlt(localctx, 3) + self.state = 412 + self.match(LaTeXParser.FUNC_INT) + self.state = 419 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [73]: + self.state = 413 + self.subexpr() + self.state = 414 + self.supexpr() + pass + elif token in [74]: + self.state = 416 + self.supexpr() + self.state = 417 + self.subexpr() + pass + elif token in [15, 16, 19, 21, 23, 25, 27, 29, 30, 32, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 61, 63, 64, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + pass + else: + pass + self.state = 427 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,49,self._ctx) + if la_ == 1: + self.state = 422 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,48,self._ctx) + if la_ == 1: + self.state = 421 + self.additive(0) + + + self.state = 424 + self.match(LaTeXParser.DIFFERENTIAL) + pass + + elif la_ == 2: + self.state = 425 + self.frac() + pass + + elif la_ == 3: + self.state = 426 + self.additive(0) + pass + + + pass + elif token in [63]: + self.enterOuterAlt(localctx, 4) + self.state = 429 + self.match(LaTeXParser.FUNC_SQRT) + self.state = 434 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==25: + self.state = 430 + self.match(LaTeXParser.L_BRACKET) + self.state = 431 + localctx.root = self.expr() + self.state = 432 + self.match(LaTeXParser.R_BRACKET) + + + self.state = 436 + self.match(LaTeXParser.L_BRACE) + self.state = 437 + localctx.base = self.expr() + self.state = 438 + self.match(LaTeXParser.R_BRACE) + pass + elif token in [64]: + self.enterOuterAlt(localctx, 5) + self.state = 440 + self.match(LaTeXParser.FUNC_OVERLINE) + self.state = 441 + self.match(LaTeXParser.L_BRACE) + self.state = 442 + localctx.base = self.expr() + self.state = 443 + self.match(LaTeXParser.R_BRACE) + pass + elif token in [35, 36]: + self.enterOuterAlt(localctx, 6) + self.state = 445 + _la = self._input.LA(1) + if not(_la==35 or _la==36): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 452 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [73]: + self.state = 446 + self.subeq() + self.state = 447 + self.supexpr() + pass + elif token in [74]: + self.state = 449 + self.supexpr() + self.state = 450 + self.subeq() + pass + else: + raise NoViableAltException(self) + + self.state = 454 + self.mp(0) + pass + elif token in [32]: + self.enterOuterAlt(localctx, 7) + self.state = 456 + self.match(LaTeXParser.FUNC_LIM) + self.state = 457 + self.limit_sub() + self.state = 458 + self.mp(0) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class ArgsContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def args(self): + return self.getTypedRuleContext(LaTeXParser.ArgsContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_args + + + + + def args(self): + + localctx = LaTeXParser.ArgsContext(self, self._ctx, self.state) + self.enterRule(localctx, 66, self.RULE_args) + try: + self.state = 467 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,53,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 462 + self.expr() + self.state = 463 + self.match(LaTeXParser.T__0) + self.state = 464 + self.args() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 466 + self.expr() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Limit_subContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def L_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.L_BRACE) + else: + return self.getToken(LaTeXParser.L_BRACE, i) + + def LIM_APPROACH_SYM(self): + return self.getToken(LaTeXParser.LIM_APPROACH_SYM, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self, i:int=None): + if i is None: + return self.getTokens(LaTeXParser.R_BRACE) + else: + return self.getToken(LaTeXParser.R_BRACE, i) + + def LETTER(self): + return self.getToken(LaTeXParser.LETTER, 0) + + def SYMBOL(self): + return self.getToken(LaTeXParser.SYMBOL, 0) + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def ADD(self): + return self.getToken(LaTeXParser.ADD, 0) + + def SUB(self): + return self.getToken(LaTeXParser.SUB, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_limit_sub + + + + + def limit_sub(self): + + localctx = LaTeXParser.Limit_subContext(self, self._ctx, self.state) + self.enterRule(localctx, 68, self.RULE_limit_sub) + self._la = 0 # Token type + try: + self.enterOuterAlt(localctx, 1) + self.state = 469 + self.match(LaTeXParser.UNDERSCORE) + self.state = 470 + self.match(LaTeXParser.L_BRACE) + self.state = 471 + _la = self._input.LA(1) + if not(_la==77 or _la==91): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 472 + self.match(LaTeXParser.LIM_APPROACH_SYM) + self.state = 473 + self.expr() + self.state = 482 + self._errHandler.sync(self) + _la = self._input.LA(1) + if _la==74: + self.state = 474 + self.match(LaTeXParser.CARET) + self.state = 480 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [21]: + self.state = 475 + self.match(LaTeXParser.L_BRACE) + self.state = 476 + _la = self._input.LA(1) + if not(_la==15 or _la==16): + self._errHandler.recoverInline(self) + else: + self._errHandler.reportMatch(self) + self.consume() + self.state = 477 + self.match(LaTeXParser.R_BRACE) + pass + elif token in [15]: + self.state = 478 + self.match(LaTeXParser.ADD) + pass + elif token in [16]: + self.state = 479 + self.match(LaTeXParser.SUB) + pass + else: + raise NoViableAltException(self) + + + + self.state = 484 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Func_argContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def func_arg(self): + return self.getTypedRuleContext(LaTeXParser.Func_argContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_func_arg + + + + + def func_arg(self): + + localctx = LaTeXParser.Func_argContext(self, self._ctx, self.state) + self.enterRule(localctx, 70, self.RULE_func_arg) + try: + self.state = 491 + self._errHandler.sync(self) + la_ = self._interp.adaptivePredict(self._input,56,self._ctx) + if la_ == 1: + self.enterOuterAlt(localctx, 1) + self.state = 486 + self.expr() + pass + + elif la_ == 2: + self.enterOuterAlt(localctx, 2) + self.state = 487 + self.expr() + self.state = 488 + self.match(LaTeXParser.T__0) + self.state = 489 + self.func_arg() + pass + + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class Func_arg_noparensContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def mp_nofunc(self): + return self.getTypedRuleContext(LaTeXParser.Mp_nofuncContext,0) + + + def getRuleIndex(self): + return LaTeXParser.RULE_func_arg_noparens + + + + + def func_arg_noparens(self): + + localctx = LaTeXParser.Func_arg_noparensContext(self, self._ctx, self.state) + self.enterRule(localctx, 72, self.RULE_func_arg_noparens) + try: + self.enterOuterAlt(localctx, 1) + self.state = 493 + self.mp_nofunc(0) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SubexprContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_subexpr + + + + + def subexpr(self): + + localctx = LaTeXParser.SubexprContext(self, self._ctx, self.state) + self.enterRule(localctx, 74, self.RULE_subexpr) + try: + self.enterOuterAlt(localctx, 1) + self.state = 495 + self.match(LaTeXParser.UNDERSCORE) + self.state = 501 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [27, 29, 30, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.state = 496 + self.atom() + pass + elif token in [21]: + self.state = 497 + self.match(LaTeXParser.L_BRACE) + self.state = 498 + self.expr() + self.state = 499 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SupexprContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def CARET(self): + return self.getToken(LaTeXParser.CARET, 0) + + def atom(self): + return self.getTypedRuleContext(LaTeXParser.AtomContext,0) + + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def expr(self): + return self.getTypedRuleContext(LaTeXParser.ExprContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_supexpr + + + + + def supexpr(self): + + localctx = LaTeXParser.SupexprContext(self, self._ctx, self.state) + self.enterRule(localctx, 76, self.RULE_supexpr) + try: + self.enterOuterAlt(localctx, 1) + self.state = 503 + self.match(LaTeXParser.CARET) + self.state = 509 + self._errHandler.sync(self) + token = self._input.LA(1) + if token in [27, 29, 30, 68, 69, 70, 71, 72, 76, 77, 78, 91]: + self.state = 504 + self.atom() + pass + elif token in [21]: + self.state = 505 + self.match(LaTeXParser.L_BRACE) + self.state = 506 + self.expr() + self.state = 507 + self.match(LaTeXParser.R_BRACE) + pass + else: + raise NoViableAltException(self) + + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SubeqContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def equality(self): + return self.getTypedRuleContext(LaTeXParser.EqualityContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_subeq + + + + + def subeq(self): + + localctx = LaTeXParser.SubeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 78, self.RULE_subeq) + try: + self.enterOuterAlt(localctx, 1) + self.state = 511 + self.match(LaTeXParser.UNDERSCORE) + self.state = 512 + self.match(LaTeXParser.L_BRACE) + self.state = 513 + self.equality() + self.state = 514 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + class SupeqContext(ParserRuleContext): + __slots__ = 'parser' + + def __init__(self, parser, parent:ParserRuleContext=None, invokingState:int=-1): + super().__init__(parent, invokingState) + self.parser = parser + + def UNDERSCORE(self): + return self.getToken(LaTeXParser.UNDERSCORE, 0) + + def L_BRACE(self): + return self.getToken(LaTeXParser.L_BRACE, 0) + + def equality(self): + return self.getTypedRuleContext(LaTeXParser.EqualityContext,0) + + + def R_BRACE(self): + return self.getToken(LaTeXParser.R_BRACE, 0) + + def getRuleIndex(self): + return LaTeXParser.RULE_supeq + + + + + def supeq(self): + + localctx = LaTeXParser.SupeqContext(self, self._ctx, self.state) + self.enterRule(localctx, 80, self.RULE_supeq) + try: + self.enterOuterAlt(localctx, 1) + self.state = 516 + self.match(LaTeXParser.UNDERSCORE) + self.state = 517 + self.match(LaTeXParser.L_BRACE) + self.state = 518 + self.equality() + self.state = 519 + self.match(LaTeXParser.R_BRACE) + except RecognitionException as re: + localctx.exception = re + self._errHandler.reportError(self, re) + self._errHandler.recover(self, re) + finally: + self.exitRule() + return localctx + + + + def sempred(self, localctx:RuleContext, ruleIndex:int, predIndex:int): + if self._predicates == None: + self._predicates = dict() + self._predicates[1] = self.relation_sempred + self._predicates[4] = self.additive_sempred + self._predicates[5] = self.mp_sempred + self._predicates[6] = self.mp_nofunc_sempred + self._predicates[15] = self.exp_sempred + self._predicates[16] = self.exp_nofunc_sempred + pred = self._predicates.get(ruleIndex, None) + if pred is None: + raise Exception("No predicate with index:" + str(ruleIndex)) + else: + return pred(localctx, predIndex) + + def relation_sempred(self, localctx:RelationContext, predIndex:int): + if predIndex == 0: + return self.precpred(self._ctx, 2) + + + def additive_sempred(self, localctx:AdditiveContext, predIndex:int): + if predIndex == 1: + return self.precpred(self._ctx, 2) + + + def mp_sempred(self, localctx:MpContext, predIndex:int): + if predIndex == 2: + return self.precpred(self._ctx, 2) + + + def mp_nofunc_sempred(self, localctx:Mp_nofuncContext, predIndex:int): + if predIndex == 3: + return self.precpred(self._ctx, 2) + + + def exp_sempred(self, localctx:ExpContext, predIndex:int): + if predIndex == 4: + return self.precpred(self._ctx, 2) + + + def exp_nofunc_sempred(self, localctx:Exp_nofuncContext, predIndex:int): + if predIndex == 5: + return self.precpred(self._ctx, 2) + + + + + diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/_build_latex_antlr.py b/phivenv/Lib/site-packages/sympy/parsing/latex/_build_latex_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..ee50da5b7861154823812c7773360b53dfd29ff6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/_build_latex_antlr.py @@ -0,0 +1,91 @@ +import os +import subprocess +import glob + +from sympy.utilities.misc import debug + +here = os.path.dirname(__file__) +grammar_file = os.path.abspath(os.path.join(here, "LaTeX.g4")) +dir_latex_antlr = os.path.join(here, "_antlr") + +header = '''\ +# *** GENERATED BY `setup.py antlr`, DO NOT EDIT BY HAND *** +# +# Generated from ../LaTeX.g4, derived from latex2sympy +# latex2sympy is licensed under the MIT license +# https://github.com/augustt198/latex2sympy/blob/master/LICENSE.txt +# +# Generated with antlr4 +# antlr4 is licensed under the BSD-3-Clause License +# https://github.com/antlr/antlr4/blob/master/LICENSE.txt +''' + + +def check_antlr_version(): + debug("Checking antlr4 version...") + + try: + debug(subprocess.check_output(["antlr4"]) + .decode('utf-8').split("\n")[0]) + return True + except (subprocess.CalledProcessError, FileNotFoundError): + debug("The 'antlr4' command line tool is not installed, " + "or not on your PATH.\n" + "> Please refer to the README.md file for more information.") + return False + + +def build_parser(output_dir=dir_latex_antlr): + check_antlr_version() + + debug("Updating ANTLR-generated code in {}".format(output_dir)) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + with open(os.path.join(output_dir, "__init__.py"), "w+") as fp: + fp.write(header) + + args = [ + "antlr4", + grammar_file, + "-o", output_dir, + # for now, not generating these as latex2sympy did not use them + "-no-visitor", + "-no-listener", + ] + + debug("Running code generation...\n\t$ {}".format(" ".join(args))) + subprocess.check_output(args, cwd=output_dir) + + debug("Applying headers, removing unnecessary files and renaming...") + # Handle case insensitive file systems. If the files are already + # generated, they will be written to latex* but LaTeX*.* won't match them. + for path in (glob.glob(os.path.join(output_dir, "LaTeX*.*")) or + glob.glob(os.path.join(output_dir, "latex*.*"))): + + # Remove files ending in .interp or .tokens as they are not needed. + if not path.endswith(".py"): + os.unlink(path) + continue + + new_path = os.path.join(output_dir, os.path.basename(path).lower()) + with open(path, 'r') as f: + lines = [line.rstrip() + '\n' for line in f] + + os.unlink(path) + + with open(new_path, "w") as out_file: + offset = 0 + while lines[offset].startswith('#'): + offset += 1 + out_file.write(header) + out_file.writelines(lines[offset:]) + + debug("\t{}".format(new_path)) + + return True + + +if __name__ == "__main__": + build_parser() diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/_parse_latex_antlr.py b/phivenv/Lib/site-packages/sympy/parsing/latex/_parse_latex_antlr.py new file mode 100644 index 0000000000000000000000000000000000000000..26604375b3a9622f8c1dacdb1d678d09c2c3ad41 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/_parse_latex_antlr.py @@ -0,0 +1,607 @@ +# Ported from latex2sympy by @augustt198 +# https://github.com/augustt198/latex2sympy +# See license in LICENSE.txt +from importlib.metadata import version +import sympy +from sympy.external import import_module +from sympy.printing.str import StrPrinter +from sympy.physics.quantum.state import Bra, Ket + +from .errors import LaTeXParsingError + + +LaTeXParser = LaTeXLexer = MathErrorListener = None + +try: + LaTeXParser = import_module('sympy.parsing.latex._antlr.latexparser', + import_kwargs={'fromlist': ['LaTeXParser']}).LaTeXParser + LaTeXLexer = import_module('sympy.parsing.latex._antlr.latexlexer', + import_kwargs={'fromlist': ['LaTeXLexer']}).LaTeXLexer +except Exception: + pass + +ErrorListener = import_module('antlr4.error.ErrorListener', + warn_not_installed=True, + import_kwargs={'fromlist': ['ErrorListener']} + ) + + + +if ErrorListener: + class MathErrorListener(ErrorListener.ErrorListener): # type:ignore # noqa:F811 + def __init__(self, src): + super(ErrorListener.ErrorListener, self).__init__() + self.src = src + + def syntaxError(self, recog, symbol, line, col, msg, e): + fmt = "%s\n%s\n%s" + marker = "~" * col + "^" + + if msg.startswith("missing"): + err = fmt % (msg, self.src, marker) + elif msg.startswith("no viable"): + err = fmt % ("I expected something else here", self.src, marker) + elif msg.startswith("mismatched"): + names = LaTeXParser.literalNames + expected = [ + names[i] for i in e.getExpectedTokens() if i < len(names) + ] + if len(expected) < 10: + expected = " ".join(expected) + err = (fmt % ("I expected one of these: " + expected, self.src, + marker)) + else: + err = (fmt % ("I expected something else here", self.src, + marker)) + else: + err = fmt % ("I don't understand this", self.src, marker) + raise LaTeXParsingError(err) + + +def parse_latex(sympy, strict=False): + antlr4 = import_module('antlr4') + + if None in [antlr4, MathErrorListener] or \ + not version('antlr4-python3-runtime').startswith('4.11'): + raise ImportError("LaTeX parsing requires the antlr4 Python package," + " provided by pip (antlr4-python3-runtime) or" + " conda (antlr-python-runtime), version 4.11") + + sympy = sympy.strip() + matherror = MathErrorListener(sympy) + + stream = antlr4.InputStream(sympy) + lex = LaTeXLexer(stream) + lex.removeErrorListeners() + lex.addErrorListener(matherror) + + tokens = antlr4.CommonTokenStream(lex) + parser = LaTeXParser(tokens) + + # remove default console error listener + parser.removeErrorListeners() + parser.addErrorListener(matherror) + + relation = parser.math().relation() + if strict and (relation.start.start != 0 or relation.stop.stop != len(sympy) - 1): + raise LaTeXParsingError("Invalid LaTeX") + expr = convert_relation(relation) + + return expr + + +def convert_relation(rel): + if rel.expr(): + return convert_expr(rel.expr()) + + lh = convert_relation(rel.relation(0)) + rh = convert_relation(rel.relation(1)) + if rel.LT(): + return sympy.StrictLessThan(lh, rh) + elif rel.LTE(): + return sympy.LessThan(lh, rh) + elif rel.GT(): + return sympy.StrictGreaterThan(lh, rh) + elif rel.GTE(): + return sympy.GreaterThan(lh, rh) + elif rel.EQUAL(): + return sympy.Eq(lh, rh) + elif rel.NEQ(): + return sympy.Ne(lh, rh) + + +def convert_expr(expr): + return convert_add(expr.additive()) + + +def convert_add(add): + if add.ADD(): + lh = convert_add(add.additive(0)) + rh = convert_add(add.additive(1)) + return sympy.Add(lh, rh, evaluate=False) + elif add.SUB(): + lh = convert_add(add.additive(0)) + rh = convert_add(add.additive(1)) + if hasattr(rh, "is_Atom") and rh.is_Atom: + return sympy.Add(lh, -1 * rh, evaluate=False) + return sympy.Add(lh, sympy.Mul(-1, rh, evaluate=False), evaluate=False) + else: + return convert_mp(add.mp()) + + +def convert_mp(mp): + if hasattr(mp, 'mp'): + mp_left = mp.mp(0) + mp_right = mp.mp(1) + else: + mp_left = mp.mp_nofunc(0) + mp_right = mp.mp_nofunc(1) + + if mp.MUL() or mp.CMD_TIMES() or mp.CMD_CDOT(): + lh = convert_mp(mp_left) + rh = convert_mp(mp_right) + return sympy.Mul(lh, rh, evaluate=False) + elif mp.DIV() or mp.CMD_DIV() or mp.COLON(): + lh = convert_mp(mp_left) + rh = convert_mp(mp_right) + return sympy.Mul(lh, sympy.Pow(rh, -1, evaluate=False), evaluate=False) + else: + if hasattr(mp, 'unary'): + return convert_unary(mp.unary()) + else: + return convert_unary(mp.unary_nofunc()) + + +def convert_unary(unary): + if hasattr(unary, 'unary'): + nested_unary = unary.unary() + else: + nested_unary = unary.unary_nofunc() + if hasattr(unary, 'postfix_nofunc'): + first = unary.postfix() + tail = unary.postfix_nofunc() + postfix = [first] + tail + else: + postfix = unary.postfix() + + if unary.ADD(): + return convert_unary(nested_unary) + elif unary.SUB(): + numabs = convert_unary(nested_unary) + # Use Integer(-n) instead of Mul(-1, n) + return -numabs + elif postfix: + return convert_postfix_list(postfix) + + +def convert_postfix_list(arr, i=0): + if i >= len(arr): + raise LaTeXParsingError("Index out of bounds") + + res = convert_postfix(arr[i]) + if isinstance(res, sympy.Expr): + if i == len(arr) - 1: + return res # nothing to multiply by + else: + if i > 0: + left = convert_postfix(arr[i - 1]) + right = convert_postfix(arr[i + 1]) + if isinstance(left, sympy.Expr) and isinstance( + right, sympy.Expr): + left_syms = convert_postfix(arr[i - 1]).atoms(sympy.Symbol) + right_syms = convert_postfix(arr[i + 1]).atoms( + sympy.Symbol) + # if the left and right sides contain no variables and the + # symbol in between is 'x', treat as multiplication. + if not (left_syms or right_syms) and str(res) == 'x': + return convert_postfix_list(arr, i + 1) + # multiply by next + return sympy.Mul( + res, convert_postfix_list(arr, i + 1), evaluate=False) + else: # must be derivative + wrt = res[0] + if i == len(arr) - 1: + raise LaTeXParsingError("Expected expression for derivative") + else: + expr = convert_postfix_list(arr, i + 1) + return sympy.Derivative(expr, wrt) + + +def do_subs(expr, at): + if at.expr(): + at_expr = convert_expr(at.expr()) + syms = at_expr.atoms(sympy.Symbol) + if len(syms) == 0: + return expr + elif len(syms) > 0: + sym = next(iter(syms)) + return expr.subs(sym, at_expr) + elif at.equality(): + lh = convert_expr(at.equality().expr(0)) + rh = convert_expr(at.equality().expr(1)) + return expr.subs(lh, rh) + + +def convert_postfix(postfix): + if hasattr(postfix, 'exp'): + exp_nested = postfix.exp() + else: + exp_nested = postfix.exp_nofunc() + + exp = convert_exp(exp_nested) + for op in postfix.postfix_op(): + if op.BANG(): + if isinstance(exp, list): + raise LaTeXParsingError("Cannot apply postfix to derivative") + exp = sympy.factorial(exp, evaluate=False) + elif op.eval_at(): + ev = op.eval_at() + at_b = None + at_a = None + if ev.eval_at_sup(): + at_b = do_subs(exp, ev.eval_at_sup()) + if ev.eval_at_sub(): + at_a = do_subs(exp, ev.eval_at_sub()) + if at_b is not None and at_a is not None: + exp = sympy.Add(at_b, -1 * at_a, evaluate=False) + elif at_b is not None: + exp = at_b + elif at_a is not None: + exp = at_a + + return exp + + +def convert_exp(exp): + if hasattr(exp, 'exp'): + exp_nested = exp.exp() + else: + exp_nested = exp.exp_nofunc() + + if exp_nested: + base = convert_exp(exp_nested) + if isinstance(base, list): + raise LaTeXParsingError("Cannot raise derivative to power") + if exp.atom(): + exponent = convert_atom(exp.atom()) + elif exp.expr(): + exponent = convert_expr(exp.expr()) + return sympy.Pow(base, exponent, evaluate=False) + else: + if hasattr(exp, 'comp'): + return convert_comp(exp.comp()) + else: + return convert_comp(exp.comp_nofunc()) + + +def convert_comp(comp): + if comp.group(): + return convert_expr(comp.group().expr()) + elif comp.abs_group(): + return sympy.Abs(convert_expr(comp.abs_group().expr()), evaluate=False) + elif comp.atom(): + return convert_atom(comp.atom()) + elif comp.floor(): + return convert_floor(comp.floor()) + elif comp.ceil(): + return convert_ceil(comp.ceil()) + elif comp.func(): + return convert_func(comp.func()) + + +def convert_atom(atom): + if atom.LETTER(): + sname = atom.LETTER().getText() + if atom.subexpr(): + if atom.subexpr().expr(): # subscript is expr + subscript = convert_expr(atom.subexpr().expr()) + else: # subscript is atom + subscript = convert_atom(atom.subexpr().atom()) + sname += '_{' + StrPrinter().doprint(subscript) + '}' + if atom.SINGLE_QUOTES(): + sname += atom.SINGLE_QUOTES().getText() # put after subscript for easy identify + return sympy.Symbol(sname) + elif atom.SYMBOL(): + s = atom.SYMBOL().getText()[1:] + if s == "infty": + return sympy.oo + else: + if atom.subexpr(): + subscript = None + if atom.subexpr().expr(): # subscript is expr + subscript = convert_expr(atom.subexpr().expr()) + else: # subscript is atom + subscript = convert_atom(atom.subexpr().atom()) + subscriptName = StrPrinter().doprint(subscript) + s += '_{' + subscriptName + '}' + return sympy.Symbol(s) + elif atom.number(): + s = atom.number().getText().replace(",", "") + return sympy.Number(s) + elif atom.DIFFERENTIAL(): + var = get_differential_var(atom.DIFFERENTIAL()) + return sympy.Symbol('d' + var.name) + elif atom.mathit(): + text = rule2text(atom.mathit().mathit_text()) + return sympy.Symbol(text) + elif atom.frac(): + return convert_frac(atom.frac()) + elif atom.binom(): + return convert_binom(atom.binom()) + elif atom.bra(): + val = convert_expr(atom.bra().expr()) + return Bra(val) + elif atom.ket(): + val = convert_expr(atom.ket().expr()) + return Ket(val) + + +def rule2text(ctx): + stream = ctx.start.getInputStream() + # starting index of starting token + startIdx = ctx.start.start + # stopping index of stopping token + stopIdx = ctx.stop.stop + + return stream.getText(startIdx, stopIdx) + + +def convert_frac(frac): + diff_op = False + partial_op = False + if frac.lower and frac.upper: + lower_itv = frac.lower.getSourceInterval() + lower_itv_len = lower_itv[1] - lower_itv[0] + 1 + if (frac.lower.start == frac.lower.stop + and frac.lower.start.type == LaTeXLexer.DIFFERENTIAL): + wrt = get_differential_var_str(frac.lower.start.text) + diff_op = True + elif (lower_itv_len == 2 and frac.lower.start.type == LaTeXLexer.SYMBOL + and frac.lower.start.text == '\\partial' + and (frac.lower.stop.type == LaTeXLexer.LETTER + or frac.lower.stop.type == LaTeXLexer.SYMBOL)): + partial_op = True + wrt = frac.lower.stop.text + if frac.lower.stop.type == LaTeXLexer.SYMBOL: + wrt = wrt[1:] + + if diff_op or partial_op: + wrt = sympy.Symbol(wrt) + if (diff_op and frac.upper.start == frac.upper.stop + and frac.upper.start.type == LaTeXLexer.LETTER + and frac.upper.start.text == 'd'): + return [wrt] + elif (partial_op and frac.upper.start == frac.upper.stop + and frac.upper.start.type == LaTeXLexer.SYMBOL + and frac.upper.start.text == '\\partial'): + return [wrt] + upper_text = rule2text(frac.upper) + + expr_top = None + if diff_op and upper_text.startswith('d'): + expr_top = parse_latex(upper_text[1:]) + elif partial_op and frac.upper.start.text == '\\partial': + expr_top = parse_latex(upper_text[len('\\partial'):]) + if expr_top: + return sympy.Derivative(expr_top, wrt) + if frac.upper: + expr_top = convert_expr(frac.upper) + else: + expr_top = sympy.Number(frac.upperd.text) + if frac.lower: + expr_bot = convert_expr(frac.lower) + else: + expr_bot = sympy.Number(frac.lowerd.text) + inverse_denom = sympy.Pow(expr_bot, -1, evaluate=False) + if expr_top == 1: + return inverse_denom + else: + return sympy.Mul(expr_top, inverse_denom, evaluate=False) + +def convert_binom(binom): + expr_n = convert_expr(binom.n) + expr_k = convert_expr(binom.k) + return sympy.binomial(expr_n, expr_k, evaluate=False) + +def convert_floor(floor): + val = convert_expr(floor.val) + return sympy.floor(val, evaluate=False) + +def convert_ceil(ceil): + val = convert_expr(ceil.val) + return sympy.ceiling(val, evaluate=False) + +def convert_func(func): + if func.func_normal(): + if func.L_PAREN(): # function called with parenthesis + arg = convert_func_arg(func.func_arg()) + else: + arg = convert_func_arg(func.func_arg_noparens()) + + name = func.func_normal().start.text[1:] + + # change arc -> a + if name in [ + "arcsin", "arccos", "arctan", "arccsc", "arcsec", "arccot" + ]: + name = "a" + name[3:] + expr = getattr(sympy.functions, name)(arg, evaluate=False) + if name in ["arsinh", "arcosh", "artanh"]: + name = "a" + name[2:] + expr = getattr(sympy.functions, name)(arg, evaluate=False) + + if name == "exp": + expr = sympy.exp(arg, evaluate=False) + + if name in ("log", "lg", "ln"): + if func.subexpr(): + if func.subexpr().expr(): + base = convert_expr(func.subexpr().expr()) + else: + base = convert_atom(func.subexpr().atom()) + elif name == "lg": # ISO 80000-2:2019 + base = 10 + elif name in ("ln", "log"): # SymPy's latex printer prints ln as log by default + base = sympy.E + expr = sympy.log(arg, base, evaluate=False) + + func_pow = None + should_pow = True + if func.supexpr(): + if func.supexpr().expr(): + func_pow = convert_expr(func.supexpr().expr()) + else: + func_pow = convert_atom(func.supexpr().atom()) + + if name in [ + "sin", "cos", "tan", "csc", "sec", "cot", "sinh", "cosh", + "tanh" + ]: + if func_pow == -1: + name = "a" + name + should_pow = False + expr = getattr(sympy.functions, name)(arg, evaluate=False) + + if func_pow and should_pow: + expr = sympy.Pow(expr, func_pow, evaluate=False) + + return expr + elif func.LETTER() or func.SYMBOL(): + if func.LETTER(): + fname = func.LETTER().getText() + elif func.SYMBOL(): + fname = func.SYMBOL().getText()[1:] + fname = str(fname) # can't be unicode + if func.subexpr(): + if func.subexpr().expr(): # subscript is expr + subscript = convert_expr(func.subexpr().expr()) + else: # subscript is atom + subscript = convert_atom(func.subexpr().atom()) + subscriptName = StrPrinter().doprint(subscript) + fname += '_{' + subscriptName + '}' + if func.SINGLE_QUOTES(): + fname += func.SINGLE_QUOTES().getText() + input_args = func.args() + output_args = [] + while input_args.args(): # handle multiple arguments to function + output_args.append(convert_expr(input_args.expr())) + input_args = input_args.args() + output_args.append(convert_expr(input_args.expr())) + return sympy.Function(fname)(*output_args) + elif func.FUNC_INT(): + return handle_integral(func) + elif func.FUNC_SQRT(): + expr = convert_expr(func.base) + if func.root: + r = convert_expr(func.root) + return sympy.root(expr, r, evaluate=False) + else: + return sympy.sqrt(expr, evaluate=False) + elif func.FUNC_OVERLINE(): + expr = convert_expr(func.base) + return sympy.conjugate(expr, evaluate=False) + elif func.FUNC_SUM(): + return handle_sum_or_prod(func, "summation") + elif func.FUNC_PROD(): + return handle_sum_or_prod(func, "product") + elif func.FUNC_LIM(): + return handle_limit(func) + + +def convert_func_arg(arg): + if hasattr(arg, 'expr'): + return convert_expr(arg.expr()) + else: + return convert_mp(arg.mp_nofunc()) + + +def handle_integral(func): + if func.additive(): + integrand = convert_add(func.additive()) + elif func.frac(): + integrand = convert_frac(func.frac()) + else: + integrand = 1 + + int_var = None + if func.DIFFERENTIAL(): + int_var = get_differential_var(func.DIFFERENTIAL()) + else: + for sym in integrand.atoms(sympy.Symbol): + s = str(sym) + if len(s) > 1 and s[0] == 'd': + if s[1] == '\\': + int_var = sympy.Symbol(s[2:]) + else: + int_var = sympy.Symbol(s[1:]) + int_sym = sym + if int_var: + integrand = integrand.subs(int_sym, 1) + else: + # Assume dx by default + int_var = sympy.Symbol('x') + + if func.subexpr(): + if func.subexpr().atom(): + lower = convert_atom(func.subexpr().atom()) + else: + lower = convert_expr(func.subexpr().expr()) + if func.supexpr().atom(): + upper = convert_atom(func.supexpr().atom()) + else: + upper = convert_expr(func.supexpr().expr()) + return sympy.Integral(integrand, (int_var, lower, upper)) + else: + return sympy.Integral(integrand, int_var) + + +def handle_sum_or_prod(func, name): + val = convert_mp(func.mp()) + iter_var = convert_expr(func.subeq().equality().expr(0)) + start = convert_expr(func.subeq().equality().expr(1)) + if func.supexpr().expr(): # ^{expr} + end = convert_expr(func.supexpr().expr()) + else: # ^atom + end = convert_atom(func.supexpr().atom()) + + if name == "summation": + return sympy.Sum(val, (iter_var, start, end)) + elif name == "product": + return sympy.Product(val, (iter_var, start, end)) + + +def handle_limit(func): + sub = func.limit_sub() + if sub.LETTER(): + var = sympy.Symbol(sub.LETTER().getText()) + elif sub.SYMBOL(): + var = sympy.Symbol(sub.SYMBOL().getText()[1:]) + else: + var = sympy.Symbol('x') + if sub.SUB(): + direction = "-" + elif sub.ADD(): + direction = "+" + else: + direction = "+-" + approaching = convert_expr(sub.expr()) + content = convert_mp(func.mp()) + + return sympy.Limit(content, var, approaching, direction) + + +def get_differential_var(d): + text = get_differential_var_str(d.getText()) + return sympy.Symbol(text) + + +def get_differential_var_str(text): + for i in range(1, len(text)): + c = text[i] + if not (c == " " or c == "\r" or c == "\n" or c == "\t"): + idx = i + break + text = text[idx:] + if text[0] == "\\": + text = text[1:] + return text diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/errors.py b/phivenv/Lib/site-packages/sympy/parsing/latex/errors.py new file mode 100644 index 0000000000000000000000000000000000000000..d8c3ef9f06279df42d4b2054acc4cfe39b6682a5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/errors.py @@ -0,0 +1,2 @@ +class LaTeXParsingError(Exception): + pass diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/lark/__init__.py b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92e58d3172e100cc376d0b416b3835d164bd5647 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/__init__.py @@ -0,0 +1,2 @@ +from .latex_parser import parse_latex_lark, LarkLaTeXParser # noqa +from .transformer import TransformToSymPyExpr # noqa diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/lark/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acd7e8dcbd8c104ed07addb45b93738f142b8ad0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/lark/__pycache__/latex_parser.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/__pycache__/latex_parser.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e883af6ac53f70c572c715de41063b39a01b52fa Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/__pycache__/latex_parser.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/lark/__pycache__/transformer.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/__pycache__/transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cad64843da571f2c5edfaa302eda7d55437676f0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/__pycache__/transformer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/lark/grammar/greek_symbols.lark b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/grammar/greek_symbols.lark new file mode 100644 index 0000000000000000000000000000000000000000..7439fab9dcac284dc3c9b5fbfa4fc6db8b29dfd2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/grammar/greek_symbols.lark @@ -0,0 +1,28 @@ +// Greek symbols +// TODO: Shouold we include the uppercase variants for the symbols where the uppercase variant doesn't have a separate meaning? +ALPHA: "\\alpha" +BETA: "\\beta" +GAMMA: "\\gamma" +DELTA: "\\delta" // TODO: Should this be included? Delta usually denotes other things. +EPSILON: "\\epsilon" | "\\varepsilon" +ZETA: "\\zeta" +ETA: "\\eta" +THETA: "\\theta" | "\\vartheta" +// TODO: Should I add iota to the list? +KAPPA: "\\kappa" +LAMBDA: "\\lambda" // TODO: What about the uppercase variant? +MU: "\\mu" +NU: "\\nu" +XI: "\\xi" +// TODO: Should there be a separate note for transforming \pi into sympy.pi? +RHO: "\\rho" | "\\varrho" +// TODO: What should we do about sigma? +TAU: "\\tau" +UPSILON: "\\upsilon" +PHI: "\\phi" | "\\varphi" +CHI: "\\chi" +PSI: "\\psi" +OMEGA: "\\omega" + +GREEK_SYMBOL: ALPHA | BETA | GAMMA | DELTA | EPSILON | ZETA | ETA | THETA | KAPPA + | LAMBDA | MU | NU | XI | RHO | TAU | UPSILON | PHI | CHI | PSI | OMEGA diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/lark/grammar/latex.lark b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/grammar/latex.lark new file mode 100644 index 0000000000000000000000000000000000000000..43e8d0e9105fa4da9bcdd2c0fa6111f6d523c9a9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/grammar/latex.lark @@ -0,0 +1,403 @@ +%ignore /[ \t\n\r]+/ + +%ignore "\\," | "\\thinspace" | "\\:" | "\\medspace" | "\\;" | "\\thickspace" +%ignore "\\quad" | "\\qquad" +%ignore "\\!" | "\\negthinspace" | "\\negmedspace" | "\\negthickspace" +%ignore "\\vrule" | "\\vcenter" | "\\vbox" | "\\vskip" | "\\vspace" | "\\hfill" +%ignore "\\*" | "\\-" | "\\." | "\\/" | "\\(" | "\\=" + +%ignore "\\left" | "\\right" +%ignore "\\limits" | "\\nolimits" +%ignore "\\displaystyle" + +///////////////////// tokens /////////////////////// + +// basic binary operators +ADD: "+" +SUB: "-" +MUL: "*" +DIV: "/" + +// tokens with distinct left and right symbols +L_BRACE: "{" +R_BRACE: "}" +L_BRACE_LITERAL: "\\{" +R_BRACE_LITERAL: "\\}" +L_BRACKET: "[" +R_BRACKET: "]" +L_CEIL: "\\lceil" +R_CEIL: "\\rceil" +L_FLOOR: "\\lfloor" +R_FLOOR: "\\rfloor" +L_PAREN: "(" +R_PAREN: ")" + +// limit, integral, sum, and product symbols +FUNC_LIM: "\\lim" +LIM_APPROACH_SYM: "\\to" | "\\rightarrow" | "\\Rightarrow" | "\\longrightarrow" | "\\Longrightarrow" +FUNC_INT: "\\int" | "\\intop" +FUNC_SUM: "\\sum" +FUNC_PROD: "\\prod" + +// common functions +FUNC_EXP: "\\exp" +FUNC_LOG: "\\log" +FUNC_LN: "\\ln" +FUNC_LG: "\\lg" +FUNC_MIN: "\\min" +FUNC_MAX: "\\max" + +// trigonometric functions +FUNC_SIN: "\\sin" +FUNC_COS: "\\cos" +FUNC_TAN: "\\tan" +FUNC_CSC: "\\csc" +FUNC_SEC: "\\sec" +FUNC_COT: "\\cot" + +// inverse trigonometric functions +FUNC_ARCSIN: "\\arcsin" +FUNC_ARCCOS: "\\arccos" +FUNC_ARCTAN: "\\arctan" +FUNC_ARCCSC: "\\arccsc" +FUNC_ARCSEC: "\\arcsec" +FUNC_ARCCOT: "\\arccot" + +// hyperbolic trigonometric functions +FUNC_SINH: "\\sinh" +FUNC_COSH: "\\cosh" +FUNC_TANH: "\\tanh" +FUNC_ARSINH: "\\arsinh" +FUNC_ARCOSH: "\\arcosh" +FUNC_ARTANH: "\\artanh" + +FUNC_SQRT: "\\sqrt" + +// miscellaneous symbols +CMD_TIMES: "\\times" +CMD_CDOT: "\\cdot" +CMD_DIV: "\\div" +CMD_FRAC: "\\frac" | "\\dfrac" | "\\tfrac" | "\\nicefrac" +CMD_BINOM: "\\binom" | "\\dbinom" | "\\tbinom" +CMD_OVERLINE: "\\overline" +CMD_LANGLE: "\\langle" +CMD_RANGLE: "\\rangle" + +CMD_MATHIT: "\\mathit" + +CMD_INFTY: "\\infty" + +BANG: "!" +BAR: "|" +CARET: "^" +COLON: ":" +UNDERSCORE: "_" + +// relational symbols +EQUAL: "=" +NOT_EQUAL: "\\neq" | "\\ne" +LT: "<" +LTE: "\\leq" | "\\le" | "\\leqslant" +GT: ">" +GTE: "\\geq" | "\\ge" | "\\geqslant" + +DIV_SYMBOL: CMD_DIV | DIV +MUL_SYMBOL: MUL | CMD_TIMES | CMD_CDOT + +%import .greek_symbols.GREEK_SYMBOL + +UPRIGHT_DIFFERENTIAL_SYMBOL: "\\text{d}" | "\\mathrm{d}" +DIFFERENTIAL_SYMBOL: "d" | UPRIGHT_DIFFERENTIAL_SYMBOL + +// disallow "d" as a variable name because we want to parse "d" as a differential symbol. +SYMBOL: /[a-zA-Z]'*/ +GREEK_SYMBOL_WITH_PRIMES: GREEK_SYMBOL "'"* +LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT: /([a-zA-Z]'*)_(([A-Za-z0-9]|[a-zA-Z]+)|\{([A-Za-z0-9]|[a-zA-Z]+'*)\})/ +LATIN_SYMBOL_WITH_GREEK_SUBSCRIPT: /([a-zA-Z]'*)_/ GREEK_SYMBOL | /([a-zA-Z]'*)_/ L_BRACE GREEK_SYMBOL_WITH_PRIMES R_BRACE +// best to define the variant with braces like that instead of shoving it all into one case like in +// /([a-zA-Z])_/ L_BRACE? GREEK_SYMBOL R_BRACE? because then we can easily error out on input like +// r"h_{\theta" +GREEK_SYMBOL_WITH_LATIN_SUBSCRIPT: GREEK_SYMBOL_WITH_PRIMES /_(([A-Za-z0-9]|[a-zA-Z]+)|\{([A-Za-z0-9]|[a-zA-Z]+'*)\})/ +GREEK_SYMBOL_WITH_GREEK_SUBSCRIPT: GREEK_SYMBOL_WITH_PRIMES /_/ (GREEK_SYMBOL | L_BRACE GREEK_SYMBOL_WITH_PRIMES R_BRACE) +MULTI_LETTER_SYMBOL: /[a-zA-Z]+(\s+[a-zA-Z]+)*'*/ + +%import common.DIGIT -> DIGIT + +CMD_PRIME: "\\prime" +CMD_ASTERISK: "\\ast" + +PRIMES: "'"+ +STARS: "*"+ +PRIMES_VIA_CMD: CMD_PRIME+ +STARS_VIA_CMD: CMD_ASTERISK+ + +CMD_IMAGINARY_UNIT: "\\imaginaryunit" + +CMD_BEGIN: "\\begin" +CMD_END: "\\end" + +// matrices +IGNORE_L: /[ \t\n\r]*/ L_BRACE* /[ \t\n\r]*/ +IGNORE_R: /[ \t\n\r]*/ R_BRACE* /[ \t\n\r]*/ +ARRAY_MATRIX_BEGIN: L_BRACE "array" R_BRACE L_BRACE /[^}]*/ R_BRACE +ARRAY_MATRIX_END: L_BRACE "array" R_BRACE +AMSMATH_MATRIX: L_BRACE "matrix" R_BRACE +AMSMATH_PMATRIX: L_BRACE "pmatrix" R_BRACE +AMSMATH_BMATRIX: L_BRACE "bmatrix" R_BRACE +// Without the (L|R)_PARENs and (L|R)_BRACKETs, a matrix defined using +// \begin{array}...\end{array} or \begin{matrix}...\end{matrix} must +// not qualify as a complete matrix expression; this is done so that +// if we have \begin{array}...\end{array} or \begin{matrix}...\end{matrix} +// between BAR pairs, then they should be interpreted as determinants as +// opposed to sympy.Abs (absolute value) applied to a matrix. +CMD_BEGIN_AMSPMATRIX_AMSBMATRIX: CMD_BEGIN (AMSMATH_PMATRIX | AMSMATH_BMATRIX) +CMD_BEGIN_ARRAY_AMSMATRIX: (L_PAREN | L_BRACKET) IGNORE_L CMD_BEGIN (ARRAY_MATRIX_BEGIN | AMSMATH_MATRIX) +CMD_MATRIX_BEGIN: CMD_BEGIN_AMSPMATRIX_AMSBMATRIX | CMD_BEGIN_ARRAY_AMSMATRIX +CMD_END_AMSPMATRIX_AMSBMATRIX: CMD_END (AMSMATH_PMATRIX | AMSMATH_BMATRIX) +CMD_END_ARRAY_AMSMATRIX: CMD_END (ARRAY_MATRIX_END | AMSMATH_MATRIX) IGNORE_R "\\right"? (R_PAREN | R_BRACKET) +CMD_MATRIX_END: CMD_END_AMSPMATRIX_AMSBMATRIX | CMD_END_ARRAY_AMSMATRIX +MATRIX_COL_DELIM: "&" +MATRIX_ROW_DELIM: "\\\\" +FUNC_MATRIX_TRACE: "\\trace" +FUNC_MATRIX_ADJUGATE: "\\adjugate" + +// determinants +AMSMATH_VMATRIX: L_BRACE "vmatrix" R_BRACE +CMD_DETERMINANT_BEGIN_SIMPLE: CMD_BEGIN AMSMATH_VMATRIX +CMD_DETERMINANT_BEGIN_VARIANT: BAR IGNORE_L CMD_BEGIN (ARRAY_MATRIX_BEGIN | AMSMATH_MATRIX) +CMD_DETERMINANT_BEGIN: CMD_DETERMINANT_BEGIN_SIMPLE | CMD_DETERMINANT_BEGIN_VARIANT +CMD_DETERMINANT_END_SIMPLE: CMD_END AMSMATH_VMATRIX +CMD_DETERMINANT_END_VARIANT: CMD_END (ARRAY_MATRIX_END | AMSMATH_MATRIX) IGNORE_R "\\right"? BAR +CMD_DETERMINANT_END: CMD_DETERMINANT_END_SIMPLE | CMD_DETERMINANT_END_VARIANT +FUNC_DETERMINANT: "\\det" + +//////////////////// grammar ////////////////////// + +latex_string: _relation | _expression + +_one_letter_symbol: SYMBOL + | LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT + | LATIN_SYMBOL_WITH_GREEK_SUBSCRIPT + | GREEK_SYMBOL_WITH_LATIN_SUBSCRIPT + | GREEK_SYMBOL_WITH_GREEK_SUBSCRIPT + | GREEK_SYMBOL_WITH_PRIMES +// LuaTeX-generated outputs of \mathit{foo'} and \mathit{foo}' +// seem to be the same on the surface. We allow both styles. +multi_letter_symbol: CMD_MATHIT L_BRACE MULTI_LETTER_SYMBOL R_BRACE + | CMD_MATHIT L_BRACE MULTI_LETTER_SYMBOL R_BRACE /'+/ +number: /\d+(\.\d*)?/ | CMD_IMAGINARY_UNIT + +_atomic_expr: _one_letter_symbol + | multi_letter_symbol + | number + | CMD_INFTY + +group_round_parentheses: L_PAREN _expression R_PAREN +group_square_brackets: L_BRACKET _expression R_BRACKET +group_curly_parentheses: L_BRACE _expression R_BRACE + +_relation: eq | ne | lt | lte | gt | gte + +eq: _expression EQUAL _expression +ne: _expression NOT_EQUAL _expression +lt: _expression LT _expression +lte: _expression LTE _expression +gt: _expression GT _expression +gte: _expression GTE _expression + +_expression_core: _atomic_expr | group_curly_parentheses + +add: _expression ADD _expression_mul + | ADD _expression_mul +sub: _expression SUB _expression_mul + | SUB _expression_mul +mul: _expression_mul MUL_SYMBOL _expression_power +div: _expression_mul DIV_SYMBOL _expression_power + +adjacent_expressions: (_one_letter_symbol | number) _expression_mul + | group_round_parentheses (group_round_parentheses | _one_letter_symbol) + | _function _function + | fraction _expression_mul + +_expression_func: _expression_core + | group_round_parentheses + | fraction + | binomial + | _function + | _integral// | derivative + | limit + | matrix + +_expression_power: _expression_func | superscript | matrix_prime | symbol_prime + +_expression_mul: _expression_power + | mul | div | adjacent_expressions + | summation | product + +_expression: _expression_mul | add | sub + +_limit_dir: "+" | "-" | L_BRACE ("+" | "-") R_BRACE + +limit_dir_expr: _expression CARET _limit_dir + +group_curly_parentheses_lim: L_BRACE _expression LIM_APPROACH_SYM (limit_dir_expr | _expression) R_BRACE + +limit: FUNC_LIM UNDERSCORE group_curly_parentheses_lim _expression + +differential: DIFFERENTIAL_SYMBOL _one_letter_symbol + +//_derivative_operator: CMD_FRAC L_BRACE DIFFERENTIAL_SYMBOL R_BRACE L_BRACE differential R_BRACE + +//derivative: _derivative_operator _expression + +_integral: normal_integral | integral_with_special_fraction + +normal_integral: FUNC_INT _expression DIFFERENTIAL_SYMBOL _one_letter_symbol + | FUNC_INT (CARET _expression_core UNDERSCORE _expression_core)? _expression? DIFFERENTIAL_SYMBOL _one_letter_symbol + | FUNC_INT (UNDERSCORE _expression_core CARET _expression_core)? _expression? DIFFERENTIAL_SYMBOL _one_letter_symbol + +group_curly_parentheses_int: L_BRACE _expression? differential R_BRACE + +special_fraction: CMD_FRAC group_curly_parentheses_int group_curly_parentheses + +integral_with_special_fraction: FUNC_INT special_fraction + | FUNC_INT (CARET _expression_core UNDERSCORE _expression_core)? special_fraction + | FUNC_INT (UNDERSCORE _expression_core CARET _expression_core)? special_fraction + +group_curly_parentheses_special: UNDERSCORE L_BRACE _atomic_expr EQUAL _atomic_expr R_BRACE CARET _expression_core + | CARET _expression_core UNDERSCORE L_BRACE _atomic_expr EQUAL _atomic_expr R_BRACE + +summation: FUNC_SUM group_curly_parentheses_special _expression + | FUNC_SUM group_curly_parentheses_special _expression + +product: FUNC_PROD group_curly_parentheses_special _expression + | FUNC_PROD group_curly_parentheses_special _expression + +superscript: _expression_func CARET (_expression_power | CMD_PRIME | CMD_ASTERISK) + | _expression_func CARET L_BRACE (PRIMES | STARS | PRIMES_VIA_CMD | STARS_VIA_CMD) R_BRACE + +matrix_prime: (matrix | group_round_parentheses) PRIMES + +symbol_prime: (LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT + | LATIN_SYMBOL_WITH_GREEK_SUBSCRIPT + | GREEK_SYMBOL_WITH_LATIN_SUBSCRIPT + | GREEK_SYMBOL_WITH_GREEK_SUBSCRIPT) PRIMES + +fraction: _basic_fraction + | _simple_fraction + | _general_fraction + +_basic_fraction: CMD_FRAC DIGIT (DIGIT | SYMBOL | GREEK_SYMBOL_WITH_PRIMES) + +_simple_fraction: CMD_FRAC DIGIT group_curly_parentheses + | CMD_FRAC group_curly_parentheses (DIGIT | SYMBOL | GREEK_SYMBOL_WITH_PRIMES) + +_general_fraction: CMD_FRAC group_curly_parentheses group_curly_parentheses + +binomial: _basic_binomial + | _simple_binomial + | _general_binomial + +_basic_binomial: CMD_BINOM DIGIT (DIGIT | SYMBOL | GREEK_SYMBOL_WITH_PRIMES) + +_simple_binomial: CMD_BINOM DIGIT group_curly_parentheses + | CMD_BINOM group_curly_parentheses (DIGIT | SYMBOL | GREEK_SYMBOL_WITH_PRIMES) + +_general_binomial: CMD_BINOM group_curly_parentheses group_curly_parentheses + +list_of_expressions: _expression ("," _expression)* + +function_applied: _one_letter_symbol L_PAREN list_of_expressions R_PAREN + +min: FUNC_MIN L_PAREN list_of_expressions R_PAREN + +max: FUNC_MAX L_PAREN list_of_expressions R_PAREN + +bra: CMD_LANGLE _expression BAR + +ket: BAR _expression CMD_RANGLE + +inner_product: CMD_LANGLE _expression BAR _expression CMD_RANGLE + +_function: function_applied + | abs | floor | ceil + | _trigonometric_function | _inverse_trigonometric_function + | _trigonometric_function_power + | _hyperbolic_trigonometric_function | _inverse_hyperbolic_trigonometric_function + | exponential + | log + | square_root + | factorial + | conjugate + | max | min + | bra | ket | inner_product + | determinant + | trace + | adjugate + +exponential: FUNC_EXP _expression + +log: FUNC_LOG _expression + | FUNC_LN _expression + | FUNC_LG _expression + | FUNC_LOG UNDERSCORE (DIGIT | _one_letter_symbol) _expression + | FUNC_LOG UNDERSCORE group_curly_parentheses _expression + +square_root: FUNC_SQRT group_curly_parentheses + | FUNC_SQRT group_square_brackets group_curly_parentheses + +factorial: _expression_func BANG + +conjugate: CMD_OVERLINE group_curly_parentheses + | CMD_OVERLINE DIGIT + +_trigonometric_function: sin | cos | tan | csc | sec | cot + +sin: FUNC_SIN _expression +cos: FUNC_COS _expression +tan: FUNC_TAN _expression +csc: FUNC_CSC _expression +sec: FUNC_SEC _expression +cot: FUNC_COT _expression + +_trigonometric_function_power: sin_power | cos_power | tan_power | csc_power | sec_power | cot_power + +sin_power: FUNC_SIN CARET _expression_core _expression +cos_power: FUNC_COS CARET _expression_core _expression +tan_power: FUNC_TAN CARET _expression_core _expression +csc_power: FUNC_CSC CARET _expression_core _expression +sec_power: FUNC_SEC CARET _expression_core _expression +cot_power: FUNC_COT CARET _expression_core _expression + +_hyperbolic_trigonometric_function: sinh | cosh | tanh + +sinh: FUNC_SINH _expression +cosh: FUNC_COSH _expression +tanh: FUNC_TANH _expression + +_inverse_trigonometric_function: arcsin | arccos | arctan | arccsc | arcsec | arccot + +arcsin: FUNC_ARCSIN _expression +arccos: FUNC_ARCCOS _expression +arctan: FUNC_ARCTAN _expression +arccsc: FUNC_ARCCSC _expression +arcsec: FUNC_ARCSEC _expression +arccot: FUNC_ARCCOT _expression + +_inverse_hyperbolic_trigonometric_function: asinh | acosh | atanh + +asinh: FUNC_ARSINH _expression +acosh: FUNC_ARCOSH _expression +atanh: FUNC_ARTANH _expression + +abs: BAR _expression BAR +floor: L_FLOOR _expression R_FLOOR +ceil: L_CEIL _expression R_CEIL + +matrix: CMD_MATRIX_BEGIN matrix_body CMD_MATRIX_END +matrix_body: matrix_row (MATRIX_ROW_DELIM matrix_row)* (MATRIX_ROW_DELIM)? +matrix_row: _expression (MATRIX_COL_DELIM _expression)* +determinant: (CMD_DETERMINANT_BEGIN matrix_body CMD_DETERMINANT_END) + | FUNC_DETERMINANT _expression +trace: FUNC_MATRIX_TRACE _expression +adjugate: FUNC_MATRIX_ADJUGATE _expression diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/lark/latex_parser.py b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/latex_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..29f594b0de4bfd4648df1554d5863a37afff035f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/latex_parser.py @@ -0,0 +1,145 @@ +import os +import logging +import re +from pathlib import Path + +from sympy.external import import_module +from sympy.parsing.latex.lark.transformer import TransformToSymPyExpr + +_lark = import_module("lark") + + +class LarkLaTeXParser: + r"""Class for converting input `\mathrm{\LaTeX}` strings into SymPy Expressions. + It holds all the necessary internal data for doing so, and exposes hooks for + customizing its behavior. + + Parameters + ========== + + print_debug_output : bool, optional + + If set to ``True``, prints debug output to the logger. Defaults to ``False``. + + transform : bool, optional + + If set to ``True``, the class runs the Transformer class on the parse tree + generated by running ``Lark.parse`` on the input string. Defaults to ``True``. + + Setting it to ``False`` can help with debugging the `\mathrm{\LaTeX}` grammar. + + grammar_file : str, optional + + The path to the grammar file that the parser should use. If set to ``None``, + it uses the default grammar, which is in ``grammar/latex.lark``, relative to + the ``sympy/parsing/latex/lark/`` directory. + + transformer : str, optional + + The name of the Transformer class to use. If set to ``None``, it uses the + default transformer class, which is :py:func:`TransformToSymPyExpr`. + + """ + def __init__(self, print_debug_output=False, transform=True, grammar_file=None, transformer=None): + grammar_dir_path = os.path.join(os.path.dirname(__file__), "grammar/") + + if grammar_file is None: + latex_grammar = Path(os.path.join(grammar_dir_path, "latex.lark")).read_text(encoding="utf-8") + else: + latex_grammar = Path(grammar_file).read_text(encoding="utf-8") + + self.parser = _lark.Lark( + latex_grammar, + source_path=grammar_dir_path, + parser="earley", + start="latex_string", + lexer="auto", + ambiguity="explicit", + propagate_positions=False, + maybe_placeholders=False, + keep_all_tokens=True) + + self.print_debug_output = print_debug_output + self.transform_expr = transform + + if transformer is None: + self.transformer = TransformToSymPyExpr() + else: + self.transformer = transformer() + + def doparse(self, s: str): + if self.print_debug_output: + _lark.logger.setLevel(logging.DEBUG) + + parse_tree = self.parser.parse(s) + + if not self.transform_expr: + # exit early and return the parse tree + _lark.logger.debug("expression = %s", s) + _lark.logger.debug(parse_tree) + _lark.logger.debug(parse_tree.pretty()) + return parse_tree + + if self.print_debug_output: + # print this stuff before attempting to run the transformer + _lark.logger.debug("expression = %s", s) + # print the `parse_tree` variable + _lark.logger.debug(parse_tree.pretty()) + + sympy_expression = self.transformer.transform(parse_tree) + + if self.print_debug_output: + _lark.logger.debug("SymPy expression = %s", sympy_expression) + + return sympy_expression + + +if _lark is not None: + _lark_latex_parser = LarkLaTeXParser() + + +def parse_latex_lark(s: str): + """ + Experimental LaTeX parser using Lark. + + This function is still under development and its API may change with the + next releases of SymPy. + """ + if _lark is None: + raise ImportError("Lark is probably not installed") + return _lark_latex_parser.doparse(s) + + +def _pretty_print_lark_trees(tree, indent=0, show_expr=True): + if isinstance(tree, _lark.Token): + return tree.value + + data = str(tree.data) + + is_expr = data.startswith("expression") + + if is_expr: + data = re.sub(r"^expression", "E", data) + + is_ambig = (data == "_ambig") + + if is_ambig: + new_indent = indent + 2 + else: + new_indent = indent + + output = "" + show_node = not is_expr or show_expr + + if show_node: + output += str(data) + "(" + + if is_ambig: + output += "\n" + "\n".join([" " * new_indent + _pretty_print_lark_trees(i, new_indent, show_expr) for i in tree.children]) + else: + output += ",".join([_pretty_print_lark_trees(i, new_indent, show_expr) for i in tree.children]) + + if show_node: + output += ")" + + return output diff --git a/phivenv/Lib/site-packages/sympy/parsing/latex/lark/transformer.py b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd514b6517336207a57de6d28bcce25858071dc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/latex/lark/transformer.py @@ -0,0 +1,730 @@ +import re + +import sympy +from sympy.external import import_module +from sympy.parsing.latex.errors import LaTeXParsingError + +lark = import_module("lark") + +if lark: + from lark import Transformer, Token, Tree # type: ignore +else: + class Transformer: # type: ignore + def transform(self, *args): + pass + + + class Token: # type: ignore + pass + + + class Tree: # type: ignore + pass + + +# noinspection PyPep8Naming,PyMethodMayBeStatic +class TransformToSymPyExpr(Transformer): + """Returns a SymPy expression that is generated by traversing the ``lark.Tree`` + passed to the ``.transform()`` function. + + Notes + ===== + + **This class is never supposed to be used directly.** + + In order to tweak the behavior of this class, it has to be subclassed and then after + the required modifications are made, the name of the new class should be passed to + the :py:class:`LarkLaTeXParser` class by using the ``transformer`` argument in the + constructor. + + Parameters + ========== + + visit_tokens : bool, optional + For information about what this option does, see `here + `_. + + Note that the option must be set to ``True`` for the default parser to work. + """ + + SYMBOL = sympy.Symbol + DIGIT = sympy.core.numbers.Integer + + def CMD_INFTY(self, tokens): + return sympy.oo + + def GREEK_SYMBOL_WITH_PRIMES(self, tokens): + # we omit the first character because it is a backslash. Also, if the variable name has "var" in it, + # like "varphi" or "varepsilon", we remove that too + variable_name = re.sub("var", "", tokens[1:]) + + return sympy.Symbol(variable_name) + + def LATIN_SYMBOL_WITH_LATIN_SUBSCRIPT(self, tokens): + base, sub = tokens.value.split("_") + if sub.startswith("{"): + return sympy.Symbol("%s_{%s}" % (base, sub[1:-1])) + else: + return sympy.Symbol("%s_{%s}" % (base, sub)) + + def GREEK_SYMBOL_WITH_LATIN_SUBSCRIPT(self, tokens): + base, sub = tokens.value.split("_") + greek_letter = re.sub("var", "", base[1:]) + + if sub.startswith("{"): + return sympy.Symbol("%s_{%s}" % (greek_letter, sub[1:-1])) + else: + return sympy.Symbol("%s_{%s}" % (greek_letter, sub)) + + def LATIN_SYMBOL_WITH_GREEK_SUBSCRIPT(self, tokens): + base, sub = tokens.value.split("_") + if sub.startswith("{"): + greek_letter = sub[2:-1] + else: + greek_letter = sub[1:] + + greek_letter = re.sub("var", "", greek_letter) + return sympy.Symbol("%s_{%s}" % (base, greek_letter)) + + + def GREEK_SYMBOL_WITH_GREEK_SUBSCRIPT(self, tokens): + base, sub = tokens.value.split("_") + greek_base = re.sub("var", "", base[1:]) + + if sub.startswith("{"): + greek_sub = sub[2:-1] + else: + greek_sub = sub[1:] + + greek_sub = re.sub("var", "", greek_sub) + return sympy.Symbol("%s_{%s}" % (greek_base, greek_sub)) + + def multi_letter_symbol(self, tokens): + if len(tokens) == 4: # no primes (single quotes) on symbol + return sympy.Symbol(tokens[2]) + if len(tokens) == 5: # there are primes on the symbol + return sympy.Symbol(tokens[2] + tokens[4]) + + def number(self, tokens): + if tokens[0].type == "CMD_IMAGINARY_UNIT": + return sympy.I + + if "." in tokens[0]: + return sympy.core.numbers.Float(tokens[0]) + else: + return sympy.core.numbers.Integer(tokens[0]) + + def latex_string(self, tokens): + return tokens[0] + + def group_round_parentheses(self, tokens): + return tokens[1] + + def group_square_brackets(self, tokens): + return tokens[1] + + def group_curly_parentheses(self, tokens): + return tokens[1] + + def eq(self, tokens): + return sympy.Eq(tokens[0], tokens[2]) + + def ne(self, tokens): + return sympy.Ne(tokens[0], tokens[2]) + + def lt(self, tokens): + return sympy.Lt(tokens[0], tokens[2]) + + def lte(self, tokens): + return sympy.Le(tokens[0], tokens[2]) + + def gt(self, tokens): + return sympy.Gt(tokens[0], tokens[2]) + + def gte(self, tokens): + return sympy.Ge(tokens[0], tokens[2]) + + def add(self, tokens): + if len(tokens) == 2: # +a + return tokens[1] + if len(tokens) == 3: # a + b + lh = tokens[0] + rh = tokens[2] + + if self._obj_is_sympy_Matrix(lh) or self._obj_is_sympy_Matrix(rh): + return sympy.MatAdd(lh, rh) + + return sympy.Add(lh, rh) + + def sub(self, tokens): + if len(tokens) == 2: # -a + x = tokens[1] + + if self._obj_is_sympy_Matrix(x): + return sympy.MatMul(-1, x) + + return -x + if len(tokens) == 3: # a - b + lh = tokens[0] + rh = tokens[2] + + if self._obj_is_sympy_Matrix(lh) or self._obj_is_sympy_Matrix(rh): + return sympy.MatAdd(lh, sympy.MatMul(-1, rh)) + + return sympy.Add(lh, -rh) + + def mul(self, tokens): + lh = tokens[0] + rh = tokens[2] + + if self._obj_is_sympy_Matrix(lh) or self._obj_is_sympy_Matrix(rh): + return sympy.MatMul(lh, rh) + + return sympy.Mul(lh, rh) + + def div(self, tokens): + return self._handle_division(tokens[0], tokens[2]) + + def adjacent_expressions(self, tokens): + # Most of the time, if two expressions are next to each other, it means implicit multiplication, + # but not always + from sympy.physics.quantum import Bra, Ket + if isinstance(tokens[0], Ket) and isinstance(tokens[1], Bra): + from sympy.physics.quantum import OuterProduct + return OuterProduct(tokens[0], tokens[1]) + elif tokens[0] == sympy.Symbol("d"): + # If the leftmost token is a "d", then it is highly likely that this is a differential + return tokens[0], tokens[1] + elif isinstance(tokens[0], tuple): + # then we have a derivative + return sympy.Derivative(tokens[1], tokens[0][1]) + else: + return sympy.Mul(tokens[0], tokens[1]) + + def superscript(self, tokens): + def isprime(x): + return isinstance(x, Token) and x.type == "PRIMES" + + def iscmdprime(x): + return isinstance(x, Token) and (x.type == "PRIMES_VIA_CMD" + or x.type == "CMD_PRIME") + + def isstar(x): + return isinstance(x, Token) and x.type == "STARS" + + def iscmdstar(x): + return isinstance(x, Token) and (x.type == "STARS_VIA_CMD" + or x.type == "CMD_ASTERISK") + + base = tokens[0] + if len(tokens) == 3: # a^b OR a^\prime OR a^\ast + sup = tokens[2] + if len(tokens) == 5: + # a^{'}, a^{''}, ... OR + # a^{*}, a^{**}, ... OR + # a^{\prime}, a^{\prime\prime}, ... OR + # a^{\ast}, a^{\ast\ast}, ... + sup = tokens[3] + + if self._obj_is_sympy_Matrix(base): + if sup == sympy.Symbol("T"): + return sympy.Transpose(base) + if sup == sympy.Symbol("H"): + return sympy.adjoint(base) + if isprime(sup): + sup = sup.value + if len(sup) % 2 == 0: + return base + return sympy.Transpose(base) + if iscmdprime(sup): + sup = sup.value + if (len(sup)/len(r"\prime")) % 2 == 0: + return base + return sympy.Transpose(base) + if isstar(sup): + sup = sup.value + # need .doit() in order to be consistent with + # sympy.adjoint() which returns the evaluated adjoint + # of a matrix + if len(sup) % 2 == 0: + return base.doit() + return sympy.adjoint(base) + if iscmdstar(sup): + sup = sup.value + # need .doit() for same reason as above + if (len(sup)/len(r"\ast")) % 2 == 0: + return base.doit() + return sympy.adjoint(base) + + if isprime(sup) or iscmdprime(sup) or isstar(sup) or iscmdstar(sup): + raise LaTeXParsingError(f"{base} with superscript {sup} is not understood.") + + return sympy.Pow(base, sup) + + def matrix_prime(self, tokens): + base = tokens[0] + primes = tokens[1].value + + if not self._obj_is_sympy_Matrix(base): + raise LaTeXParsingError(f"({base}){primes} is not understood.") + + if len(primes) % 2 == 0: + return base + + return sympy.Transpose(base) + + def symbol_prime(self, tokens): + base = tokens[0] + primes = tokens[1].value + + return sympy.Symbol(f"{base.name}{primes}") + + def fraction(self, tokens): + numerator = tokens[1] + if isinstance(tokens[2], tuple): + # we only need the variable w.r.t. which we are differentiating + _, variable = tokens[2] + + # we will pass this information upwards + return "derivative", variable + else: + denominator = tokens[2] + return self._handle_division(numerator, denominator) + + def binomial(self, tokens): + return sympy.binomial(tokens[1], tokens[2]) + + def normal_integral(self, tokens): + underscore_index = None + caret_index = None + + if "_" in tokens: + # we need to know the index because the next item in the list is the + # arguments for the lower bound of the integral + underscore_index = tokens.index("_") + + if "^" in tokens: + # we need to know the index because the next item in the list is the + # arguments for the upper bound of the integral + caret_index = tokens.index("^") + + lower_bound = tokens[underscore_index + 1] if underscore_index else None + upper_bound = tokens[caret_index + 1] if caret_index else None + + differential_symbol = self._extract_differential_symbol(tokens) + + if differential_symbol is None: + raise LaTeXParsingError("Differential symbol was not found in the expression." + "Valid differential symbols are \"d\", \"\\text{d}, and \"\\mathrm{d}\".") + + # else we can assume that a differential symbol was found + differential_variable_index = tokens.index(differential_symbol) + 1 + differential_variable = tokens[differential_variable_index] + + # we can't simply do something like `if (lower_bound and not upper_bound) ...` because this would + # evaluate to `True` if the `lower_bound` is 0 and upper bound is non-zero + if lower_bound is not None and upper_bound is None: + # then one was given and the other wasn't + raise LaTeXParsingError("Lower bound for the integral was found, but upper bound was not found.") + + if upper_bound is not None and lower_bound is None: + # then one was given and the other wasn't + raise LaTeXParsingError("Upper bound for the integral was found, but lower bound was not found.") + + # check if any expression was given or not. If it wasn't, then set the integrand to 1. + if underscore_index is not None and underscore_index == differential_variable_index - 3: + # The Token at differential_variable_index - 2 should be the integrand. However, if going one more step + # backwards after that gives us the underscore, then that means that there _was_ no integrand. + # Example: \int^7_0 dx + integrand = 1 + elif caret_index is not None and caret_index == differential_variable_index - 3: + # The Token at differential_variable_index - 2 should be the integrand. However, if going one more step + # backwards after that gives us the caret, then that means that there _was_ no integrand. + # Example: \int_0^7 dx + integrand = 1 + elif differential_variable_index == 2: + # this means we have something like "\int dx", because the "\int" symbol will always be + # at index 0 in `tokens` + integrand = 1 + else: + # The Token at differential_variable_index - 1 is the differential symbol itself, so we need to go one + # more step before that. + integrand = tokens[differential_variable_index - 2] + + if lower_bound is not None: + # then we have a definite integral + + # we can assume that either both the lower and upper bounds are given, or + # neither of them are + return sympy.Integral(integrand, (differential_variable, lower_bound, upper_bound)) + else: + # we have an indefinite integral + return sympy.Integral(integrand, differential_variable) + + def group_curly_parentheses_int(self, tokens): + # return signature is a tuple consisting of the expression in the numerator, along with the variable of + # integration + if len(tokens) == 3: + return 1, tokens[1] + elif len(tokens) == 4: + return tokens[1], tokens[2] + # there are no other possibilities + + def special_fraction(self, tokens): + numerator, variable = tokens[1] + denominator = tokens[2] + + # We pass the integrand, along with information about the variable of integration, upw + return sympy.Mul(numerator, sympy.Pow(denominator, -1)), variable + + def integral_with_special_fraction(self, tokens): + underscore_index = None + caret_index = None + + if "_" in tokens: + # we need to know the index because the next item in the list is the + # arguments for the lower bound of the integral + underscore_index = tokens.index("_") + + if "^" in tokens: + # we need to know the index because the next item in the list is the + # arguments for the upper bound of the integral + caret_index = tokens.index("^") + + lower_bound = tokens[underscore_index + 1] if underscore_index else None + upper_bound = tokens[caret_index + 1] if caret_index else None + + # we can't simply do something like `if (lower_bound and not upper_bound) ...` because this would + # evaluate to `True` if the `lower_bound` is 0 and upper bound is non-zero + if lower_bound is not None and upper_bound is None: + # then one was given and the other wasn't + raise LaTeXParsingError("Lower bound for the integral was found, but upper bound was not found.") + + if upper_bound is not None and lower_bound is None: + # then one was given and the other wasn't + raise LaTeXParsingError("Upper bound for the integral was found, but lower bound was not found.") + + integrand, differential_variable = tokens[-1] + + if lower_bound is not None: + # then we have a definite integral + + # we can assume that either both the lower and upper bounds are given, or + # neither of them are + return sympy.Integral(integrand, (differential_variable, lower_bound, upper_bound)) + else: + # we have an indefinite integral + return sympy.Integral(integrand, differential_variable) + + def group_curly_parentheses_special(self, tokens): + underscore_index = tokens.index("_") + caret_index = tokens.index("^") + + # given the type of expressions we are parsing, we can assume that the lower limit + # will always use braces around its arguments. This is because we don't support + # converting unconstrained sums into SymPy expressions. + + # first we isolate the bottom limit + left_brace_index = tokens.index("{", underscore_index) + right_brace_index = tokens.index("}", underscore_index) + + bottom_limit = tokens[left_brace_index + 1: right_brace_index] + + # next, we isolate the upper limit + top_limit = tokens[caret_index + 1:] + + # the code below will be useful for supporting things like `\sum_{n = 0}^{n = 5} n^2` + # if "{" in top_limit: + # left_brace_index = tokens.index("{", caret_index) + # if left_brace_index != -1: + # # then there's a left brace in the string, and we need to find the closing right brace + # right_brace_index = tokens.index("}", caret_index) + # top_limit = tokens[left_brace_index + 1: right_brace_index] + + # print(f"top limit = {top_limit}") + + index_variable = bottom_limit[0] + lower_limit = bottom_limit[-1] + upper_limit = top_limit[0] # for now, the index will always be 0 + + # print(f"return value = ({index_variable}, {lower_limit}, {upper_limit})") + + return index_variable, lower_limit, upper_limit + + def summation(self, tokens): + return sympy.Sum(tokens[2], tokens[1]) + + def product(self, tokens): + return sympy.Product(tokens[2], tokens[1]) + + def limit_dir_expr(self, tokens): + caret_index = tokens.index("^") + + if "{" in tokens: + left_curly_brace_index = tokens.index("{", caret_index) + direction = tokens[left_curly_brace_index + 1] + else: + direction = tokens[caret_index + 1] + + if direction == "+": + return tokens[0], "+" + elif direction == "-": + return tokens[0], "-" + else: + return tokens[0], "+-" + + def group_curly_parentheses_lim(self, tokens): + limit_variable = tokens[1] + if isinstance(tokens[3], tuple): + destination, direction = tokens[3] + else: + destination = tokens[3] + direction = "+-" + + return limit_variable, destination, direction + + def limit(self, tokens): + limit_variable, destination, direction = tokens[2] + + return sympy.Limit(tokens[-1], limit_variable, destination, direction) + + def differential(self, tokens): + return tokens[1] + + def derivative(self, tokens): + return sympy.Derivative(tokens[-1], tokens[5]) + + def list_of_expressions(self, tokens): + if len(tokens) == 1: + # we return it verbatim because the function_applied node expects + # a list + return tokens + else: + def remove_tokens(args): + if isinstance(args, Token): + if args.type != "COMMA": + # An unexpected token was encountered + raise LaTeXParsingError("A comma token was expected, but some other token was encountered.") + return False + return True + + return filter(remove_tokens, tokens) + + def function_applied(self, tokens): + return sympy.Function(tokens[0])(*tokens[2]) + + def min(self, tokens): + return sympy.Min(*tokens[2]) + + def max(self, tokens): + return sympy.Max(*tokens[2]) + + def bra(self, tokens): + from sympy.physics.quantum import Bra + return Bra(tokens[1]) + + def ket(self, tokens): + from sympy.physics.quantum import Ket + return Ket(tokens[1]) + + def inner_product(self, tokens): + from sympy.physics.quantum import Bra, Ket, InnerProduct + return InnerProduct(Bra(tokens[1]), Ket(tokens[3])) + + def sin(self, tokens): + return sympy.sin(tokens[1]) + + def cos(self, tokens): + return sympy.cos(tokens[1]) + + def tan(self, tokens): + return sympy.tan(tokens[1]) + + def csc(self, tokens): + return sympy.csc(tokens[1]) + + def sec(self, tokens): + return sympy.sec(tokens[1]) + + def cot(self, tokens): + return sympy.cot(tokens[1]) + + def sin_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.asin(tokens[-1]) + else: + return sympy.Pow(sympy.sin(tokens[-1]), exponent) + + def cos_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.acos(tokens[-1]) + else: + return sympy.Pow(sympy.cos(tokens[-1]), exponent) + + def tan_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.atan(tokens[-1]) + else: + return sympy.Pow(sympy.tan(tokens[-1]), exponent) + + def csc_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.acsc(tokens[-1]) + else: + return sympy.Pow(sympy.csc(tokens[-1]), exponent) + + def sec_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.asec(tokens[-1]) + else: + return sympy.Pow(sympy.sec(tokens[-1]), exponent) + + def cot_power(self, tokens): + exponent = tokens[2] + if exponent == -1: + return sympy.acot(tokens[-1]) + else: + return sympy.Pow(sympy.cot(tokens[-1]), exponent) + + def arcsin(self, tokens): + return sympy.asin(tokens[1]) + + def arccos(self, tokens): + return sympy.acos(tokens[1]) + + def arctan(self, tokens): + return sympy.atan(tokens[1]) + + def arccsc(self, tokens): + return sympy.acsc(tokens[1]) + + def arcsec(self, tokens): + return sympy.asec(tokens[1]) + + def arccot(self, tokens): + return sympy.acot(tokens[1]) + + def sinh(self, tokens): + return sympy.sinh(tokens[1]) + + def cosh(self, tokens): + return sympy.cosh(tokens[1]) + + def tanh(self, tokens): + return sympy.tanh(tokens[1]) + + def asinh(self, tokens): + return sympy.asinh(tokens[1]) + + def acosh(self, tokens): + return sympy.acosh(tokens[1]) + + def atanh(self, tokens): + return sympy.atanh(tokens[1]) + + def abs(self, tokens): + return sympy.Abs(tokens[1]) + + def floor(self, tokens): + return sympy.floor(tokens[1]) + + def ceil(self, tokens): + return sympy.ceiling(tokens[1]) + + def factorial(self, tokens): + return sympy.factorial(tokens[0]) + + def conjugate(self, tokens): + return sympy.conjugate(tokens[1]) + + def square_root(self, tokens): + if len(tokens) == 2: + # then there was no square bracket argument + return sympy.sqrt(tokens[1]) + elif len(tokens) == 3: + # then there _was_ a square bracket argument + return sympy.root(tokens[2], tokens[1]) + + def exponential(self, tokens): + return sympy.exp(tokens[1]) + + def log(self, tokens): + if tokens[0].type == "FUNC_LG": + # we don't need to check if there's an underscore or not because having one + # in this case would be meaningless + # TODO: ANTLR refers to ISO 80000-2:2019. should we keep base 10 or base 2? + return sympy.log(tokens[1], 10) + elif tokens[0].type == "FUNC_LN": + return sympy.log(tokens[1]) + elif tokens[0].type == "FUNC_LOG": + # we check if a base was specified or not + if "_" in tokens: + # then a base was specified + return sympy.log(tokens[3], tokens[2]) + else: + # a base was not specified + return sympy.log(tokens[1]) + + def _extract_differential_symbol(self, s: str): + differential_symbols = {"d", r"\text{d}", r"\mathrm{d}"} + + differential_symbol = next((symbol for symbol in differential_symbols if symbol in s), None) + + return differential_symbol + + def matrix(self, tokens): + def is_matrix_row(x): + return (isinstance(x, Tree) and x.data == "matrix_row") + + def is_not_col_delim(y): + return (not isinstance(y, Token) or y.type != "MATRIX_COL_DELIM") + + matrix_body = tokens[1].children + return sympy.Matrix([[y for y in x.children if is_not_col_delim(y)] + for x in matrix_body if is_matrix_row(x)]) + + def determinant(self, tokens): + if len(tokens) == 2: # \det A + if not self._obj_is_sympy_Matrix(tokens[1]): + raise LaTeXParsingError("Cannot take determinant of non-matrix.") + + return tokens[1].det() + + if len(tokens) == 3: # | A | + return self.matrix(tokens).det() + + def trace(self, tokens): + if not self._obj_is_sympy_Matrix(tokens[1]): + raise LaTeXParsingError("Cannot take trace of non-matrix.") + + return sympy.Trace(tokens[1]) + + def adjugate(self, tokens): + if not self._obj_is_sympy_Matrix(tokens[1]): + raise LaTeXParsingError("Cannot take adjugate of non-matrix.") + + # need .doit() since MatAdd does not support .adjugate() method + return tokens[1].doit().adjugate() + + def _obj_is_sympy_Matrix(self, obj): + if hasattr(obj, "is_Matrix"): + return obj.is_Matrix + + return isinstance(obj, sympy.Matrix) + + def _handle_division(self, numerator, denominator): + if self._obj_is_sympy_Matrix(denominator): + raise LaTeXParsingError("Cannot divide by matrices like this since " + "it is not clear if left or right multiplication " + "by the inverse is intended. Try explicitly " + "multiplying by the inverse instead.") + + if self._obj_is_sympy_Matrix(numerator): + return sympy.MatMul(numerator, sympy.Pow(denominator, -1)) + + return sympy.Mul(numerator, sympy.Pow(denominator, -1)) diff --git a/phivenv/Lib/site-packages/sympy/parsing/mathematica.py b/phivenv/Lib/site-packages/sympy/parsing/mathematica.py new file mode 100644 index 0000000000000000000000000000000000000000..b5824a8c33ee402d03e6c5617eeeea21d4a457d1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/mathematica.py @@ -0,0 +1,1085 @@ +from __future__ import annotations +import re +import typing +from itertools import product +from typing import Any, Callable + +import sympy +from sympy import Mul, Add, Pow, Rational, log, exp, sqrt, cos, sin, tan, asin, acos, acot, asec, acsc, sinh, cosh, tanh, asinh, \ + acosh, atanh, acoth, asech, acsch, expand, im, flatten, polylog, cancel, expand_trig, sign, simplify, \ + UnevaluatedExpr, S, atan, atan2, Mod, Max, Min, rf, Ei, Si, Ci, airyai, airyaiprime, airybi, primepi, prime, \ + isprime, cot, sec, csc, csch, sech, coth, Function, I, pi, Tuple, GreaterThan, StrictGreaterThan, StrictLessThan, \ + LessThan, Equality, Or, And, Lambda, Integer, Dummy, symbols +from sympy.core.sympify import sympify, _sympify +from sympy.functions.special.bessel import airybiprime +from sympy.functions.special.error_functions import li +from sympy.utilities.exceptions import sympy_deprecation_warning + + +def mathematica(s, additional_translations=None): + sympy_deprecation_warning( + """The ``mathematica`` function for the Mathematica parser is now +deprecated. Use ``parse_mathematica`` instead. +The parameter ``additional_translation`` can be replaced by SymPy's +.replace( ) or .subs( ) methods on the output expression instead.""", + deprecated_since_version="1.11", + active_deprecations_target="mathematica-parser-new", + ) + parser = MathematicaParser(additional_translations) + return sympify(parser._parse_old(s)) + + +def parse_mathematica(s): + """ + Translate a string containing a Wolfram Mathematica expression to a SymPy + expression. + + If the translator is unable to find a suitable SymPy expression, the + ``FullForm`` of the Mathematica expression will be output, using SymPy + ``Function`` objects as nodes of the syntax tree. + + Examples + ======== + + >>> from sympy.parsing.mathematica import parse_mathematica + >>> parse_mathematica("Sin[x]^2 Tan[y]") + sin(x)**2*tan(y) + >>> e = parse_mathematica("F[7,5,3]") + >>> e + F(7, 5, 3) + >>> from sympy import Function, Max, Min + >>> e.replace(Function("F"), lambda *x: Max(*x)*Min(*x)) + 21 + + Both standard input form and Mathematica full form are supported: + + >>> parse_mathematica("x*(a + b)") + x*(a + b) + >>> parse_mathematica("Times[x, Plus[a, b]]") + x*(a + b) + + To get a matrix from Wolfram's code: + + >>> m = parse_mathematica("{{a, b}, {c, d}}") + >>> m + ((a, b), (c, d)) + >>> from sympy import Matrix + >>> Matrix(m) + Matrix([ + [a, b], + [c, d]]) + + If the translation into equivalent SymPy expressions fails, an SymPy + expression equivalent to Wolfram Mathematica's "FullForm" will be created: + + >>> parse_mathematica("x_.") + Optional(Pattern(x, Blank())) + >>> parse_mathematica("Plus @@ {x, y, z}") + Apply(Plus, (x, y, z)) + >>> parse_mathematica("f[x_, 3] := x^3 /; x > 0") + SetDelayed(f(Pattern(x, Blank()), 3), Condition(x**3, x > 0)) + """ + parser = MathematicaParser() + return parser.parse(s) + + +def _parse_Function(*args): + if len(args) == 1: + arg = args[0] + Slot = Function("Slot") + slots = arg.atoms(Slot) + numbers = [a.args[0] for a in slots] + number_of_arguments = max(numbers) + if isinstance(number_of_arguments, Integer): + variables = symbols(f"dummy0:{number_of_arguments}", cls=Dummy) + return Lambda(variables, arg.xreplace({Slot(i+1): v for i, v in enumerate(variables)})) + return Lambda((), arg) + elif len(args) == 2: + variables = args[0] + body = args[1] + return Lambda(variables, body) + else: + raise SyntaxError("Function node expects 1 or 2 arguments") + + +def _deco(cls): + cls._initialize_class() + return cls + + +@_deco +class MathematicaParser: + """ + An instance of this class converts a string of a Wolfram Mathematica + expression to a SymPy expression. + + The main parser acts internally in three stages: + + 1. tokenizer: tokenizes the Mathematica expression and adds the missing * + operators. Handled by ``_from_mathematica_to_tokens(...)`` + 2. full form list: sort the list of strings output by the tokenizer into a + syntax tree of nested lists and strings, equivalent to Mathematica's + ``FullForm`` expression output. This is handled by the function + ``_from_tokens_to_fullformlist(...)``. + 3. SymPy expression: the syntax tree expressed as full form list is visited + and the nodes with equivalent classes in SymPy are replaced. Unknown + syntax tree nodes are cast to SymPy ``Function`` objects. This is + handled by ``_from_fullformlist_to_sympy(...)``. + + """ + + # left: Mathematica, right: SymPy + CORRESPONDENCES = { + 'Sqrt[x]': 'sqrt(x)', + 'Rational[x,y]': 'Rational(x,y)', + 'Exp[x]': 'exp(x)', + 'Log[x]': 'log(x)', + 'Log[x,y]': 'log(y,x)', + 'Log2[x]': 'log(x,2)', + 'Log10[x]': 'log(x,10)', + 'Mod[x,y]': 'Mod(x,y)', + 'Max[*x]': 'Max(*x)', + 'Min[*x]': 'Min(*x)', + 'Pochhammer[x,y]':'rf(x,y)', + 'ArcTan[x,y]':'atan2(y,x)', + 'ExpIntegralEi[x]': 'Ei(x)', + 'SinIntegral[x]': 'Si(x)', + 'CosIntegral[x]': 'Ci(x)', + 'AiryAi[x]': 'airyai(x)', + 'AiryAiPrime[x]': 'airyaiprime(x)', + 'AiryBi[x]' :'airybi(x)', + 'AiryBiPrime[x]' :'airybiprime(x)', + 'LogIntegral[x]':' li(x)', + 'PrimePi[x]': 'primepi(x)', + 'Prime[x]': 'prime(x)', + 'PrimeQ[x]': 'isprime(x)' + } + + # trigonometric, e.t.c. + for arc, tri, h in product(('', 'Arc'), ( + 'Sin', 'Cos', 'Tan', 'Cot', 'Sec', 'Csc'), ('', 'h')): + fm = arc + tri + h + '[x]' + if arc: # arc func + fs = 'a' + tri.lower() + h + '(x)' + else: # non-arc func + fs = tri.lower() + h + '(x)' + CORRESPONDENCES.update({fm: fs}) + + REPLACEMENTS = { + ' ': '', + '^': '**', + '{': '[', + '}': ']', + } + + RULES = { + # a single whitespace to '*' + 'whitespace': ( + re.compile(r''' + (?:(?<=[a-zA-Z\d])|(?<=\d\.)) # a letter or a number + \s+ # any number of whitespaces + (?:(?=[a-zA-Z\d])|(?=\.\d)) # a letter or a number + ''', re.VERBOSE), + '*'), + + # add omitted '*' character + 'add*_1': ( + re.compile(r''' + (?:(?<=[])\d])|(?<=\d\.)) # ], ) or a number + # '' + (?=[(a-zA-Z]) # ( or a single letter + ''', re.VERBOSE), + '*'), + + # add omitted '*' character (variable letter preceding) + 'add*_2': ( + re.compile(r''' + (?<=[a-zA-Z]) # a letter + \( # ( as a character + (?=.) # any characters + ''', re.VERBOSE), + '*('), + + # convert 'Pi' to 'pi' + 'Pi': ( + re.compile(r''' + (?: + \A|(?<=[^a-zA-Z]) + ) + Pi # 'Pi' is 3.14159... in Mathematica + (?=[^a-zA-Z]) + ''', re.VERBOSE), + 'pi'), + } + + # Mathematica function name pattern + FM_PATTERN = re.compile(r''' + (?: + \A|(?<=[^a-zA-Z]) # at the top or a non-letter + ) + [A-Z][a-zA-Z\d]* # Function + (?=\[) # [ as a character + ''', re.VERBOSE) + + # list or matrix pattern (for future usage) + ARG_MTRX_PATTERN = re.compile(r''' + \{.*\} + ''', re.VERBOSE) + + # regex string for function argument pattern + ARGS_PATTERN_TEMPLATE = r''' + (?: + \A|(?<=[^a-zA-Z]) + ) + {arguments} # model argument like x, y,... + (?=[^a-zA-Z]) + ''' + + # will contain transformed CORRESPONDENCES dictionary + TRANSLATIONS: dict[tuple[str, int], dict[str, Any]] = {} + + # cache for a raw users' translation dictionary + cache_original: dict[tuple[str, int], dict[str, Any]] = {} + + # cache for a compiled users' translation dictionary + cache_compiled: dict[tuple[str, int], dict[str, Any]] = {} + + @classmethod + def _initialize_class(cls): + # get a transformed CORRESPONDENCES dictionary + d = cls._compile_dictionary(cls.CORRESPONDENCES) + cls.TRANSLATIONS.update(d) + + def __init__(self, additional_translations=None): + self.translations = {} + + # update with TRANSLATIONS (class constant) + self.translations.update(self.TRANSLATIONS) + + if additional_translations is None: + additional_translations = {} + + # check the latest added translations + if self.__class__.cache_original != additional_translations: + if not isinstance(additional_translations, dict): + raise ValueError('The argument must be dict type') + + # get a transformed additional_translations dictionary + d = self._compile_dictionary(additional_translations) + + # update cache + self.__class__.cache_original = additional_translations + self.__class__.cache_compiled = d + + # merge user's own translations + self.translations.update(self.__class__.cache_compiled) + + @classmethod + def _compile_dictionary(cls, dic): + # for return + d = {} + + for fm, fs in dic.items(): + # check function form + cls._check_input(fm) + cls._check_input(fs) + + # uncover '*' hiding behind a whitespace + fm = cls._apply_rules(fm, 'whitespace') + fs = cls._apply_rules(fs, 'whitespace') + + # remove whitespace(s) + fm = cls._replace(fm, ' ') + fs = cls._replace(fs, ' ') + + # search Mathematica function name + m = cls.FM_PATTERN.search(fm) + + # if no-hit + if m is None: + err = "'{f}' function form is invalid.".format(f=fm) + raise ValueError(err) + + # get Mathematica function name like 'Log' + fm_name = m.group() + + # get arguments of Mathematica function + args, end = cls._get_args(m) + + # function side check. (e.g.) '2*Func[x]' is invalid. + if m.start() != 0 or end != len(fm): + err = "'{f}' function form is invalid.".format(f=fm) + raise ValueError(err) + + # check the last argument's 1st character + if args[-1][0] == '*': + key_arg = '*' + else: + key_arg = len(args) + + key = (fm_name, key_arg) + + # convert '*x' to '\\*x' for regex + re_args = [x if x[0] != '*' else '\\' + x for x in args] + + # for regex. Example: (?:(x|y|z)) + xyz = '(?:(' + '|'.join(re_args) + '))' + + # string for regex compile + patStr = cls.ARGS_PATTERN_TEMPLATE.format(arguments=xyz) + + pat = re.compile(patStr, re.VERBOSE) + + # update dictionary + d[key] = {} + d[key]['fs'] = fs # SymPy function template + d[key]['args'] = args # args are ['x', 'y'] for example + d[key]['pat'] = pat + + return d + + def _convert_function(self, s): + '''Parse Mathematica function to SymPy one''' + + # compiled regex object + pat = self.FM_PATTERN + + scanned = '' # converted string + cur = 0 # position cursor + while True: + m = pat.search(s) + + if m is None: + # append the rest of string + scanned += s + break + + # get Mathematica function name + fm = m.group() + + # get arguments, and the end position of fm function + args, end = self._get_args(m) + + # the start position of fm function + bgn = m.start() + + # convert Mathematica function to SymPy one + s = self._convert_one_function(s, fm, args, bgn, end) + + # update cursor + cur = bgn + + # append converted part + scanned += s[:cur] + + # shrink s + s = s[cur:] + + return scanned + + def _convert_one_function(self, s, fm, args, bgn, end): + # no variable-length argument + if (fm, len(args)) in self.translations: + key = (fm, len(args)) + + # x, y,... model arguments + x_args = self.translations[key]['args'] + + # make CORRESPONDENCES between model arguments and actual ones + d = dict(zip(x_args, args)) + + # with variable-length argument + elif (fm, '*') in self.translations: + key = (fm, '*') + + # x, y,..*args (model arguments) + x_args = self.translations[key]['args'] + + # make CORRESPONDENCES between model arguments and actual ones + d = {} + for i, x in enumerate(x_args): + if x[0] == '*': + d[x] = ','.join(args[i:]) + break + d[x] = args[i] + + # out of self.translations + else: + err = "'{f}' is out of the whitelist.".format(f=fm) + raise ValueError(err) + + # template string of converted function + template = self.translations[key]['fs'] + + # regex pattern for x_args + pat = self.translations[key]['pat'] + + scanned = '' + cur = 0 + while True: + m = pat.search(template) + + if m is None: + scanned += template + break + + # get model argument + x = m.group() + + # get a start position of the model argument + xbgn = m.start() + + # add the corresponding actual argument + scanned += template[:xbgn] + d[x] + + # update cursor to the end of the model argument + cur = m.end() + + # shrink template + template = template[cur:] + + # update to swapped string + s = s[:bgn] + scanned + s[end:] + + return s + + @classmethod + def _get_args(cls, m): + '''Get arguments of a Mathematica function''' + + s = m.string # whole string + anc = m.end() + 1 # pointing the first letter of arguments + square, curly = [], [] # stack for brackets + args = [] + + # current cursor + cur = anc + for i, c in enumerate(s[anc:], anc): + # extract one argument + if c == ',' and (not square) and (not curly): + args.append(s[cur:i]) # add an argument + cur = i + 1 # move cursor + + # handle list or matrix (for future usage) + if c == '{': + curly.append(c) + elif c == '}': + curly.pop() + + # seek corresponding ']' with skipping irrevant ones + if c == '[': + square.append(c) + elif c == ']': + if square: + square.pop() + else: # empty stack + args.append(s[cur:i]) + break + + # the next position to ']' bracket (the function end) + func_end = i + 1 + + return args, func_end + + @classmethod + def _replace(cls, s, bef): + aft = cls.REPLACEMENTS[bef] + s = s.replace(bef, aft) + return s + + @classmethod + def _apply_rules(cls, s, bef): + pat, aft = cls.RULES[bef] + return pat.sub(aft, s) + + @classmethod + def _check_input(cls, s): + for bracket in (('[', ']'), ('{', '}'), ('(', ')')): + if s.count(bracket[0]) != s.count(bracket[1]): + err = "'{f}' function form is invalid.".format(f=s) + raise ValueError(err) + + if '{' in s: + err = "Currently list is not supported." + raise ValueError(err) + + def _parse_old(self, s): + # input check + self._check_input(s) + + # uncover '*' hiding behind a whitespace + s = self._apply_rules(s, 'whitespace') + + # remove whitespace(s) + s = self._replace(s, ' ') + + # add omitted '*' character + s = self._apply_rules(s, 'add*_1') + s = self._apply_rules(s, 'add*_2') + + # translate function + s = self._convert_function(s) + + # '^' to '**' + s = self._replace(s, '^') + + # 'Pi' to 'pi' + s = self._apply_rules(s, 'Pi') + + # '{', '}' to '[', ']', respectively +# s = cls._replace(s, '{') # currently list is not taken into account +# s = cls._replace(s, '}') + + return s + + def parse(self, s): + s2 = self._from_mathematica_to_tokens(s) + s3 = self._from_tokens_to_fullformlist(s2) + s4 = self._from_fullformlist_to_sympy(s3) + return s4 + + INFIX = "Infix" + PREFIX = "Prefix" + POSTFIX = "Postfix" + FLAT = "Flat" + RIGHT = "Right" + LEFT = "Left" + + _mathematica_op_precedence: list[tuple[str, str | None, dict[str, str | Callable]]] = [ + (POSTFIX, None, {";": lambda x: x + ["Null"] if isinstance(x, list) and x and x[0] == "CompoundExpression" else ["CompoundExpression", x, "Null"]}), + (INFIX, FLAT, {";": "CompoundExpression"}), + (INFIX, RIGHT, {"=": "Set", ":=": "SetDelayed", "+=": "AddTo", "-=": "SubtractFrom", "*=": "TimesBy", "/=": "DivideBy"}), + (INFIX, LEFT, {"//": lambda x, y: [x, y]}), + (POSTFIX, None, {"&": "Function"}), + (INFIX, LEFT, {"/.": "ReplaceAll"}), + (INFIX, RIGHT, {"->": "Rule", ":>": "RuleDelayed"}), + (INFIX, LEFT, {"/;": "Condition"}), + (INFIX, FLAT, {"|": "Alternatives"}), + (POSTFIX, None, {"..": "Repeated", "...": "RepeatedNull"}), + (INFIX, FLAT, {"||": "Or"}), + (INFIX, FLAT, {"&&": "And"}), + (PREFIX, None, {"!": "Not"}), + (INFIX, FLAT, {"===": "SameQ", "=!=": "UnsameQ"}), + (INFIX, FLAT, {"==": "Equal", "!=": "Unequal", "<=": "LessEqual", "<": "Less", ">=": "GreaterEqual", ">": "Greater"}), + (INFIX, None, {";;": "Span"}), + (INFIX, FLAT, {"+": "Plus", "-": "Plus"}), + (INFIX, FLAT, {"*": "Times", "/": "Times"}), + (INFIX, FLAT, {".": "Dot"}), + (PREFIX, None, {"-": lambda x: MathematicaParser._get_neg(x), + "+": lambda x: x}), + (INFIX, RIGHT, {"^": "Power"}), + (INFIX, RIGHT, {"@@": "Apply", "/@": "Map", "//@": "MapAll", "@@@": lambda x, y: ["Apply", x, y, ["List", "1"]]}), + (POSTFIX, None, {"'": "Derivative", "!": "Factorial", "!!": "Factorial2", "--": "Decrement"}), + (INFIX, None, {"[": lambda x, y: [x, *y], "[[": lambda x, y: ["Part", x, *y]}), + (PREFIX, None, {"{": lambda x: ["List", *x], "(": lambda x: x[0]}), + (INFIX, None, {"?": "PatternTest"}), + (POSTFIX, None, { + "_": lambda x: ["Pattern", x, ["Blank"]], + "_.": lambda x: ["Optional", ["Pattern", x, ["Blank"]]], + "__": lambda x: ["Pattern", x, ["BlankSequence"]], + "___": lambda x: ["Pattern", x, ["BlankNullSequence"]], + }), + (INFIX, None, {"_": lambda x, y: ["Pattern", x, ["Blank", y]]}), + (PREFIX, None, {"#": "Slot", "##": "SlotSequence"}), + ] + + _missing_arguments_default = { + "#": lambda: ["Slot", "1"], + "##": lambda: ["SlotSequence", "1"], + } + + _literal = r"[A-Za-z][A-Za-z0-9]*" + _number = r"(?:[0-9]+(?:\.[0-9]*)?|\.[0-9]+)" + + _enclosure_open = ["(", "[", "[[", "{"] + _enclosure_close = [")", "]", "]]", "}"] + + @classmethod + def _get_neg(cls, x): + return f"-{x}" if isinstance(x, str) and re.match(MathematicaParser._number, x) else ["Times", "-1", x] + + @classmethod + def _get_inv(cls, x): + return ["Power", x, "-1"] + + _regex_tokenizer = None + + def _get_tokenizer(self): + if self._regex_tokenizer is not None: + # Check if the regular expression has already been compiled: + return self._regex_tokenizer + tokens = [self._literal, self._number] + tokens_escape = self._enclosure_open[:] + self._enclosure_close[:] + for typ, strat, symdict in self._mathematica_op_precedence: + for k in symdict: + tokens_escape.append(k) + tokens_escape.sort(key=lambda x: -len(x)) + tokens.extend(map(re.escape, tokens_escape)) + tokens.append(",") + tokens.append("\n") + tokenizer = re.compile("(" + "|".join(tokens) + ")") + self._regex_tokenizer = tokenizer + return self._regex_tokenizer + + def _from_mathematica_to_tokens(self, code: str): + tokenizer = self._get_tokenizer() + + # Find strings: + code_splits: list[str | list] = [] + while True: + string_start = code.find("\"") + if string_start == -1: + if len(code) > 0: + code_splits.append(code) + break + match_end = re.search(r'(? 0: + code_splits.append(code[:string_start]) + code_splits.append(["_Str", code[string_start+1:string_end].replace('\\"', '"')]) + code = code[string_end+1:] + + # Remove comments: + for i, code_split in enumerate(code_splits): + if isinstance(code_split, list): + continue + while True: + pos_comment_start = code_split.find("(*") + if pos_comment_start == -1: + break + pos_comment_end = code_split.find("*)") + if pos_comment_end == -1 or pos_comment_end < pos_comment_start: + raise SyntaxError("mismatch in comment (* *) code") + code_split = code_split[:pos_comment_start] + code_split[pos_comment_end+2:] + code_splits[i] = code_split + + # Tokenize the input strings with a regular expression: + token_lists = [tokenizer.findall(i) if isinstance(i, str) and i.isascii() else [i] for i in code_splits] + tokens = [j for i in token_lists for j in i] + + # Remove newlines at the beginning + while tokens and tokens[0] == "\n": + tokens.pop(0) + # Remove newlines at the end + while tokens and tokens[-1] == "\n": + tokens.pop(-1) + + return tokens + + def _is_op(self, token: str | list) -> bool: + if isinstance(token, list): + return False + if re.match(self._literal, token): + return False + if re.match("-?" + self._number, token): + return False + return True + + def _is_valid_star1(self, token: str | list) -> bool: + if token in (")", "}"): + return True + return not self._is_op(token) + + def _is_valid_star2(self, token: str | list) -> bool: + if token in ("(", "{"): + return True + return not self._is_op(token) + + def _from_tokens_to_fullformlist(self, tokens: list): + stack: list[list] = [[]] + open_seq = [] + pointer: int = 0 + while pointer < len(tokens): + token = tokens[pointer] + if token in self._enclosure_open: + stack[-1].append(token) + open_seq.append(token) + stack.append([]) + elif token == ",": + if len(stack[-1]) == 0 and stack[-2][-1] == open_seq[-1]: + raise SyntaxError("%s cannot be followed by comma ," % open_seq[-1]) + stack[-1] = self._parse_after_braces(stack[-1]) + stack.append([]) + elif token in self._enclosure_close: + ind = self._enclosure_close.index(token) + if self._enclosure_open[ind] != open_seq[-1]: + unmatched_enclosure = SyntaxError("unmatched enclosure") + if token == "]]" and open_seq[-1] == "[": + if open_seq[-2] == "[": + # These two lines would be logically correct, but are + # unnecessary: + # token = "]" + # tokens[pointer] = "]" + tokens.insert(pointer+1, "]") + elif open_seq[-2] == "[[": + if tokens[pointer+1] == "]": + tokens[pointer+1] = "]]" + elif tokens[pointer+1] == "]]": + tokens[pointer+1] = "]]" + tokens.insert(pointer+2, "]") + else: + raise unmatched_enclosure + else: + raise unmatched_enclosure + if len(stack[-1]) == 0 and stack[-2][-1] == "(": + raise SyntaxError("( ) not valid syntax") + last_stack = self._parse_after_braces(stack[-1], True) + stack[-1] = last_stack + new_stack_element = [] + while stack[-1][-1] != open_seq[-1]: + new_stack_element.append(stack.pop()) + new_stack_element.reverse() + if open_seq[-1] == "(" and len(new_stack_element) != 1: + raise SyntaxError("( must be followed by one expression, %i detected" % len(new_stack_element)) + stack[-1].append(new_stack_element) + open_seq.pop(-1) + else: + stack[-1].append(token) + pointer += 1 + if len(stack) != 1: + raise RuntimeError("Stack should have only one element") + return self._parse_after_braces(stack[0]) + + def _util_remove_newlines(self, lines: list, tokens: list, inside_enclosure: bool): + pointer = 0 + size = len(tokens) + while pointer < size: + token = tokens[pointer] + if token == "\n": + if inside_enclosure: + # Ignore newlines inside enclosures + tokens.pop(pointer) + size -= 1 + continue + if pointer == 0: + tokens.pop(0) + size -= 1 + continue + if pointer > 1: + try: + prev_expr = self._parse_after_braces(tokens[:pointer], inside_enclosure) + except SyntaxError: + tokens.pop(pointer) + size -= 1 + continue + else: + prev_expr = tokens[0] + if len(prev_expr) > 0 and prev_expr[0] == "CompoundExpression": + lines.extend(prev_expr[1:]) + else: + lines.append(prev_expr) + for i in range(pointer): + tokens.pop(0) + size -= pointer + pointer = 0 + continue + pointer += 1 + + def _util_add_missing_asterisks(self, tokens: list): + size: int = len(tokens) + pointer: int = 0 + while pointer < size: + if (pointer > 0 and + self._is_valid_star1(tokens[pointer - 1]) and + self._is_valid_star2(tokens[pointer])): + # This is a trick to add missing * operators in the expression, + # `"*" in op_dict` makes sure the precedence level is the same as "*", + # while `not self._is_op( ... )` makes sure this and the previous + # expression are not operators. + if tokens[pointer] == "(": + # ( has already been processed by now, replace: + tokens[pointer] = "*" + tokens[pointer + 1] = tokens[pointer + 1][0] + else: + tokens.insert(pointer, "*") + pointer += 1 + size += 1 + pointer += 1 + + def _parse_after_braces(self, tokens: list, inside_enclosure: bool = False): + op_dict: dict + changed: bool = False + lines: list = [] + + self._util_remove_newlines(lines, tokens, inside_enclosure) + + for op_type, grouping_strat, op_dict in reversed(self._mathematica_op_precedence): + if "*" in op_dict: + self._util_add_missing_asterisks(tokens) + size: int = len(tokens) + pointer: int = 0 + while pointer < size: + token = tokens[pointer] + if isinstance(token, str) and token in op_dict: + op_name: str | Callable = op_dict[token] + node: list + first_index: int + if isinstance(op_name, str): + node = [op_name] + first_index = 1 + else: + node = [] + first_index = 0 + if token in ("+", "-") and op_type == self.PREFIX and pointer > 0 and not self._is_op(tokens[pointer - 1]): + # Make sure that PREFIX + - don't match expressions like a + b or a - b, + # the INFIX + - are supposed to match that expression: + pointer += 1 + continue + if op_type == self.INFIX: + if pointer == 0 or pointer == size - 1 or self._is_op(tokens[pointer - 1]) or self._is_op(tokens[pointer + 1]): + pointer += 1 + continue + changed = True + tokens[pointer] = node + if op_type == self.INFIX: + arg1 = tokens.pop(pointer-1) + arg2 = tokens.pop(pointer) + if token == "/": + arg2 = self._get_inv(arg2) + elif token == "-": + arg2 = self._get_neg(arg2) + pointer -= 1 + size -= 2 + node.append(arg1) + node_p = node + if grouping_strat == self.FLAT: + while pointer + 2 < size and self._check_op_compatible(tokens[pointer+1], token): + node_p.append(arg2) + other_op = tokens.pop(pointer+1) + arg2 = tokens.pop(pointer+1) + if other_op == "/": + arg2 = self._get_inv(arg2) + elif other_op == "-": + arg2 = self._get_neg(arg2) + size -= 2 + node_p.append(arg2) + elif grouping_strat == self.RIGHT: + while pointer + 2 < size and tokens[pointer+1] == token: + node_p.append([op_name, arg2]) + node_p = node_p[-1] + tokens.pop(pointer+1) + arg2 = tokens.pop(pointer+1) + size -= 2 + node_p.append(arg2) + elif grouping_strat == self.LEFT: + while pointer + 1 < size and tokens[pointer+1] == token: + if isinstance(op_name, str): + node_p[first_index] = [op_name, node_p[first_index], arg2] + else: + node_p[first_index] = op_name(node_p[first_index], arg2) + tokens.pop(pointer+1) + arg2 = tokens.pop(pointer+1) + size -= 2 + node_p.append(arg2) + else: + node.append(arg2) + elif op_type == self.PREFIX: + if grouping_strat is not None: + raise TypeError("'Prefix' op_type should not have a grouping strat") + if pointer == size - 1 or self._is_op(tokens[pointer + 1]): + tokens[pointer] = self._missing_arguments_default[token]() + else: + node.append(tokens.pop(pointer+1)) + size -= 1 + elif op_type == self.POSTFIX: + if grouping_strat is not None: + raise TypeError("'Prefix' op_type should not have a grouping strat") + if pointer == 0 or self._is_op(tokens[pointer - 1]): + tokens[pointer] = self._missing_arguments_default[token]() + else: + node.append(tokens.pop(pointer-1)) + pointer -= 1 + size -= 1 + if isinstance(op_name, Callable): # type: ignore + op_call: Callable = typing.cast(Callable, op_name) + new_node = op_call(*node) + node.clear() + if isinstance(new_node, list): + node.extend(new_node) + else: + tokens[pointer] = new_node + pointer += 1 + if len(tokens) > 1 or (len(lines) == 0 and len(tokens) == 0): + if changed: + # Trick to deal with cases in which an operator with lower + # precedence should be transformed before an operator of higher + # precedence. Such as in the case of `#&[x]` (that is + # equivalent to `Lambda(d_, d_)(x)` in SymPy). In this case the + # operator `&` has lower precedence than `[`, but needs to be + # evaluated first because otherwise `# (&[x])` is not a valid + # expression: + return self._parse_after_braces(tokens, inside_enclosure) + raise SyntaxError("unable to create a single AST for the expression") + if len(lines) > 0: + if tokens[0] and tokens[0][0] == "CompoundExpression": + tokens = tokens[0][1:] + compound_expression = ["CompoundExpression", *lines, *tokens] + return compound_expression + return tokens[0] + + def _check_op_compatible(self, op1: str, op2: str): + if op1 == op2: + return True + muldiv = {"*", "/"} + addsub = {"+", "-"} + if op1 in muldiv and op2 in muldiv: + return True + if op1 in addsub and op2 in addsub: + return True + return False + + def _from_fullform_to_fullformlist(self, wmexpr: str): + """ + Parses FullForm[Downvalues[]] generated by Mathematica + """ + out: list = [] + stack = [out] + generator = re.finditer(r'[\[\],]', wmexpr) + last_pos = 0 + for match in generator: + if match is None: + break + position = match.start() + last_expr = wmexpr[last_pos:position].replace(',', '').replace(']', '').replace('[', '').strip() + + if match.group() == ',': + if last_expr != '': + stack[-1].append(last_expr) + elif match.group() == ']': + if last_expr != '': + stack[-1].append(last_expr) + stack.pop() + elif match.group() == '[': + stack[-1].append([last_expr]) + stack.append(stack[-1][-1]) + last_pos = match.end() + return out[0] + + def _from_fullformlist_to_fullformsympy(self, pylist: list): + from sympy import Function, Symbol + + def converter(expr): + if isinstance(expr, list): + if len(expr) > 0: + head = expr[0] + args = [converter(arg) for arg in expr[1:]] + return Function(head)(*args) + else: + raise ValueError("Empty list of expressions") + elif isinstance(expr, str): + return Symbol(expr) + else: + return _sympify(expr) + + return converter(pylist) + + _node_conversions = { + "Times": Mul, + "Plus": Add, + "Power": Pow, + "Rational": Rational, + "Log": lambda *a: log(*reversed(a)), + "Log2": lambda x: log(x, 2), + "Log10": lambda x: log(x, 10), + "Exp": exp, + "Sqrt": sqrt, + + "Sin": sin, + "Cos": cos, + "Tan": tan, + "Cot": cot, + "Sec": sec, + "Csc": csc, + + "ArcSin": asin, + "ArcCos": acos, + "ArcTan": lambda *a: atan2(*reversed(a)) if len(a) == 2 else atan(*a), + "ArcCot": acot, + "ArcSec": asec, + "ArcCsc": acsc, + + "Sinh": sinh, + "Cosh": cosh, + "Tanh": tanh, + "Coth": coth, + "Sech": sech, + "Csch": csch, + + "ArcSinh": asinh, + "ArcCosh": acosh, + "ArcTanh": atanh, + "ArcCoth": acoth, + "ArcSech": asech, + "ArcCsch": acsch, + + "Expand": expand, + "Im": im, + "Re": sympy.re, + "Flatten": flatten, + "Polylog": polylog, + "Cancel": cancel, + # Gamma=gamma, + "TrigExpand": expand_trig, + "Sign": sign, + "Simplify": simplify, + "Defer": UnevaluatedExpr, + "Identity": S, + # Sum=Sum_doit, + # Module=With, + # Block=With, + "Null": lambda *a: S.Zero, + "Mod": Mod, + "Max": Max, + "Min": Min, + "Pochhammer": rf, + "ExpIntegralEi": Ei, + "SinIntegral": Si, + "CosIntegral": Ci, + "AiryAi": airyai, + "AiryAiPrime": airyaiprime, + "AiryBi": airybi, + "AiryBiPrime": airybiprime, + "LogIntegral": li, + "PrimePi": primepi, + "Prime": prime, + "PrimeQ": isprime, + + "List": Tuple, + "Greater": StrictGreaterThan, + "GreaterEqual": GreaterThan, + "Less": StrictLessThan, + "LessEqual": LessThan, + "Equal": Equality, + "Or": Or, + "And": And, + + "Function": _parse_Function, + } + + _atom_conversions = { + "I": I, + "Pi": pi, + } + + def _from_fullformlist_to_sympy(self, full_form_list): + + def recurse(expr): + if isinstance(expr, list): + if isinstance(expr[0], list): + head = recurse(expr[0]) + else: + head = self._node_conversions.get(expr[0], Function(expr[0])) + return head(*[recurse(arg) for arg in expr[1:]]) + else: + return self._atom_conversions.get(expr, sympify(expr)) + + return recurse(full_form_list) + + def _from_fullformsympy_to_sympy(self, mform): + + expr = mform + for mma_form, sympy_node in self._node_conversions.items(): + expr = expr.replace(Function(mma_form), sympy_node) + return expr diff --git a/phivenv/Lib/site-packages/sympy/parsing/maxima.py b/phivenv/Lib/site-packages/sympy/parsing/maxima.py new file mode 100644 index 0000000000000000000000000000000000000000..7a8ee5b17bb03a36e338803cb10f9ebf22763c2c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/maxima.py @@ -0,0 +1,71 @@ +import re +from sympy.concrete.products import product +from sympy.concrete.summations import Sum +from sympy.core.sympify import sympify +from sympy.functions.elementary.trigonometric import (cos, sin) + + +class MaximaHelpers: + def maxima_expand(expr): + return expr.expand() + + def maxima_float(expr): + return expr.evalf() + + def maxima_trigexpand(expr): + return expr.expand(trig=True) + + def maxima_sum(a1, a2, a3, a4): + return Sum(a1, (a2, a3, a4)).doit() + + def maxima_product(a1, a2, a3, a4): + return product(a1, (a2, a3, a4)) + + def maxima_csc(expr): + return 1/sin(expr) + + def maxima_sec(expr): + return 1/cos(expr) + +sub_dict = { + 'pi': re.compile(r'%pi'), + 'E': re.compile(r'%e'), + 'I': re.compile(r'%i'), + '**': re.compile(r'\^'), + 'oo': re.compile(r'\binf\b'), + '-oo': re.compile(r'\bminf\b'), + "'-'": re.compile(r'\bminus\b'), + 'maxima_expand': re.compile(r'\bexpand\b'), + 'maxima_float': re.compile(r'\bfloat\b'), + 'maxima_trigexpand': re.compile(r'\btrigexpand'), + 'maxima_sum': re.compile(r'\bsum\b'), + 'maxima_product': re.compile(r'\bproduct\b'), + 'cancel': re.compile(r'\bratsimp\b'), + 'maxima_csc': re.compile(r'\bcsc\b'), + 'maxima_sec': re.compile(r'\bsec\b') +} + +var_name = re.compile(r'^\s*(\w+)\s*:') + + +def parse_maxima(str, globals=None, name_dict={}): + str = str.strip() + str = str.rstrip('; ') + + for k, v in sub_dict.items(): + str = v.sub(k, str) + + assign_var = None + var_match = var_name.search(str) + if var_match: + assign_var = var_match.group(1) + str = str[var_match.end():].strip() + + dct = MaximaHelpers.__dict__.copy() + dct.update(name_dict) + obj = sympify(str, locals=dct) + + if assign_var and globals: + globals[assign_var] = obj + + return obj diff --git a/phivenv/Lib/site-packages/sympy/parsing/sym_expr.py b/phivenv/Lib/site-packages/sympy/parsing/sym_expr.py new file mode 100644 index 0000000000000000000000000000000000000000..9dbd0e94eb51147b51825fcf15cbec5ae18bb1b6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/sym_expr.py @@ -0,0 +1,279 @@ +from sympy.printing import pycode, ccode, fcode +from sympy.external import import_module +from sympy.utilities.decorator import doctest_depends_on + +lfortran = import_module('lfortran') +cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']}) + +if lfortran: + from sympy.parsing.fortran.fortran_parser import src_to_sympy +if cin: + from sympy.parsing.c.c_parser import parse_c + +@doctest_depends_on(modules=['lfortran', 'clang.cindex']) +class SymPyExpression: # type: ignore + """Class to store and handle SymPy expressions + + This class will hold SymPy Expressions and handle the API for the + conversion to and from different languages. + + It works with the C and the Fortran Parser to generate SymPy expressions + which are stored here and which can be converted to multiple language's + source code. + + Notes + ===== + + The module and its API are currently under development and experimental + and can be changed during development. + + The Fortran parser does not support numeric assignments, so all the + variables have been Initialized to zero. + + The module also depends on external dependencies: + + - LFortran which is required to use the Fortran parser + - Clang which is required for the C parser + + Examples + ======== + + Example of parsing C code: + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src = ''' + ... int a,b; + ... float c = 2, d =4; + ... ''' + >>> a = SymPyExpression(src, 'c') + >>> a.return_expr() + [Declaration(Variable(a, type=intc)), + Declaration(Variable(b, type=intc)), + Declaration(Variable(c, type=float32, value=2.0)), + Declaration(Variable(d, type=float32, value=4.0))] + + An example of variable definition: + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src2 = ''' + ... integer :: a, b, c, d + ... real :: p, q, r, s + ... ''' + >>> p = SymPyExpression() + >>> p.convert_to_expr(src2, 'f') + >>> p.convert_to_c() + ['int a = 0', 'int b = 0', 'int c = 0', 'int d = 0', 'double p = 0.0', 'double q = 0.0', 'double r = 0.0', 'double s = 0.0'] + + An example of Assignment: + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src3 = ''' + ... integer :: a, b, c, d, e + ... d = a + b - c + ... e = b * d + c * e / a + ... ''' + >>> p = SymPyExpression(src3, 'f') + >>> p.convert_to_python() + ['a = 0', 'b = 0', 'c = 0', 'd = 0', 'e = 0', 'd = a + b - c', 'e = b*d + c*e/a'] + + An example of function definition: + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src = ''' + ... integer function f(a,b) + ... integer, intent(in) :: a, b + ... integer :: r + ... end function + ... ''' + >>> a = SymPyExpression(src, 'f') + >>> a.convert_to_python() + ['def f(a, b):\\n f = 0\\n r = 0\\n return f'] + + """ + + def __init__(self, source_code = None, mode = None): + """Constructor for SymPyExpression class""" + super().__init__() + if not(mode or source_code): + self._expr = [] + elif mode: + if source_code: + if mode.lower() == 'f': + if lfortran: + self._expr = src_to_sympy(source_code) + else: + raise ImportError("LFortran is not installed, cannot parse Fortran code") + elif mode.lower() == 'c': + if cin: + self._expr = parse_c(source_code) + else: + raise ImportError("Clang is not installed, cannot parse C code") + else: + raise NotImplementedError( + 'Parser for specified language is not implemented' + ) + else: + raise ValueError('Source code not present') + else: + raise ValueError('Please specify a mode for conversion') + + def convert_to_expr(self, src_code, mode): + """Converts the given source code to SymPy Expressions + + Attributes + ========== + + src_code : String + the source code or filename of the source code that is to be + converted + + mode: String + the mode to determine which parser is to be used according to + the language of the source code + f or F for Fortran + c or C for C/C++ + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src3 = ''' + ... integer function f(a,b) result(r) + ... integer, intent(in) :: a, b + ... integer :: x + ... r = a + b -x + ... end function + ... ''' + >>> p = SymPyExpression() + >>> p.convert_to_expr(src3, 'f') + >>> p.return_expr() + [FunctionDefinition(integer, name=f, parameters=(Variable(a), Variable(b)), body=CodeBlock( + Declaration(Variable(r, type=integer, value=0)), + Declaration(Variable(x, type=integer, value=0)), + Assignment(Variable(r), a + b - x), + Return(Variable(r)) + ))] + + + + + """ + if mode.lower() == 'f': + if lfortran: + self._expr = src_to_sympy(src_code) + else: + raise ImportError("LFortran is not installed, cannot parse Fortran code") + elif mode.lower() == 'c': + if cin: + self._expr = parse_c(src_code) + else: + raise ImportError("Clang is not installed, cannot parse C code") + else: + raise NotImplementedError( + "Parser for specified language has not been implemented" + ) + + def convert_to_python(self): + """Returns a list with Python code for the SymPy expressions + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src2 = ''' + ... integer :: a, b, c, d + ... real :: p, q, r, s + ... c = a/b + ... d = c/a + ... s = p/q + ... r = q/p + ... ''' + >>> p = SymPyExpression(src2, 'f') + >>> p.convert_to_python() + ['a = 0', 'b = 0', 'c = 0', 'd = 0', 'p = 0.0', 'q = 0.0', 'r = 0.0', 's = 0.0', 'c = a/b', 'd = c/a', 's = p/q', 'r = q/p'] + + """ + self._pycode = [] + for iter in self._expr: + self._pycode.append(pycode(iter)) + return self._pycode + + def convert_to_c(self): + """Returns a list with the c source code for the SymPy expressions + + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src2 = ''' + ... integer :: a, b, c, d + ... real :: p, q, r, s + ... c = a/b + ... d = c/a + ... s = p/q + ... r = q/p + ... ''' + >>> p = SymPyExpression() + >>> p.convert_to_expr(src2, 'f') + >>> p.convert_to_c() + ['int a = 0', 'int b = 0', 'int c = 0', 'int d = 0', 'double p = 0.0', 'double q = 0.0', 'double r = 0.0', 'double s = 0.0', 'c = a/b;', 'd = c/a;', 's = p/q;', 'r = q/p;'] + + """ + self._ccode = [] + for iter in self._expr: + self._ccode.append(ccode(iter)) + return self._ccode + + def convert_to_fortran(self): + """Returns a list with the fortran source code for the SymPy expressions + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src2 = ''' + ... integer :: a, b, c, d + ... real :: p, q, r, s + ... c = a/b + ... d = c/a + ... s = p/q + ... r = q/p + ... ''' + >>> p = SymPyExpression(src2, 'f') + >>> p.convert_to_fortran() + [' integer*4 a', ' integer*4 b', ' integer*4 c', ' integer*4 d', ' real*8 p', ' real*8 q', ' real*8 r', ' real*8 s', ' c = a/b', ' d = c/a', ' s = p/q', ' r = q/p'] + + """ + self._fcode = [] + for iter in self._expr: + self._fcode.append(fcode(iter)) + return self._fcode + + def return_expr(self): + """Returns the expression list + + Examples + ======== + + >>> from sympy.parsing.sym_expr import SymPyExpression + >>> src3 = ''' + ... integer function f(a,b) + ... integer, intent(in) :: a, b + ... integer :: r + ... r = a+b + ... f = r + ... end function + ... ''' + >>> p = SymPyExpression() + >>> p.convert_to_expr(src3, 'f') + >>> p.return_expr() + [FunctionDefinition(integer, name=f, parameters=(Variable(a), Variable(b)), body=CodeBlock( + Declaration(Variable(f, type=integer, value=0)), + Declaration(Variable(r, type=integer, value=0)), + Assignment(Variable(f), Variable(r)), + Return(Variable(f)) + ))] + + """ + return self._expr diff --git a/phivenv/Lib/site-packages/sympy/parsing/sympy_parser.py b/phivenv/Lib/site-packages/sympy/parsing/sympy_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9cfda9ce0f73ffa3773031c48b9e9c245f69fe0b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/sympy_parser.py @@ -0,0 +1,1270 @@ +"""Transform a string with Python-like source code into SymPy expression. """ +from __future__ import annotations +from tokenize import (generate_tokens, untokenize, TokenError, + NUMBER, STRING, NAME, OP, ENDMARKER, ERRORTOKEN, NEWLINE) + +from keyword import iskeyword + +import ast +import unicodedata +from io import StringIO +import builtins +import types +from typing import Any, Callable +from functools import reduce +from sympy.assumptions.ask import AssumptionKeys +from sympy.core.basic import Basic +from sympy.core import Symbol +from sympy.core.function import Function +from sympy.utilities.misc import func_name +from sympy.functions.elementary.miscellaneous import Max, Min + + +null = '' + +TOKEN = tuple[int, str] +DICT = dict[str, Any] +TRANS = Callable[[list[TOKEN], DICT, DICT], list[TOKEN]] + +def _token_splittable(token_name: str) -> bool: + """ + Predicate for whether a token name can be split into multiple tokens. + + A token is splittable if it does not contain an underscore character and + it is not the name of a Greek letter. This is used to implicitly convert + expressions like 'xyz' into 'x*y*z'. + """ + if '_' in token_name: + return False + try: + return not unicodedata.lookup('GREEK SMALL LETTER ' + token_name) + except KeyError: + return len(token_name) > 1 + + +def _token_callable(token: TOKEN, local_dict: DICT, global_dict: DICT, nextToken=None): + """ + Predicate for whether a token name represents a callable function. + + Essentially wraps ``callable``, but looks up the token name in the + locals and globals. + """ + func = local_dict.get(token[1]) + if not func: + func = global_dict.get(token[1]) + return callable(func) and not isinstance(func, Symbol) + + +def _add_factorial_tokens(name: str, result: list[TOKEN]) -> list[TOKEN]: + if result == [] or result[-1][1] == '(': + raise TokenError() + + beginning = [(NAME, name), (OP, '(')] + end = [(OP, ')')] + + diff = 0 + length = len(result) + + for index, token in enumerate(result[::-1]): + toknum, tokval = token + i = length - index - 1 + + if tokval == ')': + diff += 1 + elif tokval == '(': + diff -= 1 + + if diff == 0: + if i - 1 >= 0 and result[i - 1][0] == NAME: + return result[:i - 1] + beginning + result[i - 1:] + end + else: + return result[:i] + beginning + result[i:] + end + + return result + + +class ParenthesisGroup(list[TOKEN]): + """List of tokens representing an expression in parentheses.""" + pass + + +class AppliedFunction: + """ + A group of tokens representing a function and its arguments. + + `exponent` is for handling the shorthand sin^2, ln^2, etc. + """ + def __init__(self, function: TOKEN, args: ParenthesisGroup, exponent=None): + if exponent is None: + exponent = [] + self.function = function + self.args = args + self.exponent = exponent + self.items = ['function', 'args', 'exponent'] + + def expand(self) -> list[TOKEN]: + """Return a list of tokens representing the function""" + return [self.function, *self.args] + + def __getitem__(self, index): + return getattr(self, self.items[index]) + + def __repr__(self): + return "AppliedFunction(%s, %s, %s)" % (self.function, self.args, + self.exponent) + + +def _flatten(result: list[TOKEN | AppliedFunction]): + result2: list[TOKEN] = [] + for tok in result: + if isinstance(tok, AppliedFunction): + result2.extend(tok.expand()) + else: + result2.append(tok) + return result2 + + +def _group_parentheses(recursor: TRANS): + def _inner(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Group tokens between parentheses with ParenthesisGroup. + + Also processes those tokens recursively. + + """ + result: list[TOKEN | ParenthesisGroup] = [] + stacks: list[ParenthesisGroup] = [] + stacklevel = 0 + for token in tokens: + if token[0] == OP: + if token[1] == '(': + stacks.append(ParenthesisGroup([])) + stacklevel += 1 + elif token[1] == ')': + stacks[-1].append(token) + stack = stacks.pop() + + if len(stacks) > 0: + # We don't recurse here since the upper-level stack + # would reprocess these tokens + stacks[-1].extend(stack) + else: + # Recurse here to handle nested parentheses + # Strip off the outer parentheses to avoid an infinite loop + inner = stack[1:-1] + inner = recursor(inner, + local_dict, + global_dict) + parenGroup = [stack[0]] + inner + [stack[-1]] + result.append(ParenthesisGroup(parenGroup)) + stacklevel -= 1 + continue + if stacklevel: + stacks[-1].append(token) + else: + result.append(token) + if stacklevel: + raise TokenError("Mismatched parentheses") + return result + return _inner + + +def _apply_functions(tokens: list[TOKEN | ParenthesisGroup], local_dict: DICT, global_dict: DICT): + """Convert a NAME token + ParenthesisGroup into an AppliedFunction. + + Note that ParenthesisGroups, if not applied to any function, are + converted back into lists of tokens. + + """ + result: list[TOKEN | AppliedFunction] = [] + symbol = None + for tok in tokens: + if isinstance(tok, ParenthesisGroup): + if symbol and _token_callable(symbol, local_dict, global_dict): + result[-1] = AppliedFunction(symbol, tok) + symbol = None + else: + result.extend(tok) + elif tok[0] == NAME: + symbol = tok + result.append(tok) + else: + symbol = None + result.append(tok) + return result + + +def _implicit_multiplication(tokens: list[TOKEN | AppliedFunction], local_dict: DICT, global_dict: DICT): + """Implicitly adds '*' tokens. + + Cases: + + - Two AppliedFunctions next to each other ("sin(x)cos(x)") + + - AppliedFunction next to an open parenthesis ("sin x (cos x + 1)") + + - A close parenthesis next to an AppliedFunction ("(x+2)sin x")\ + + - A close parenthesis next to an open parenthesis ("(x+2)(x+3)") + + - AppliedFunction next to an implicitly applied function ("sin(x)cos x") + + """ + result: list[TOKEN | AppliedFunction] = [] + skip = False + for tok, nextTok in zip(tokens, tokens[1:]): + result.append(tok) + if skip: + skip = False + continue + if tok[0] == OP and tok[1] == '.' and nextTok[0] == NAME: + # Dotted name. Do not do implicit multiplication + skip = True + continue + if isinstance(tok, AppliedFunction): + if isinstance(nextTok, AppliedFunction): + result.append((OP, '*')) + elif nextTok == (OP, '('): + # Applied function followed by an open parenthesis + if tok.function[1] == "Function": + tok.function = (tok.function[0], 'Symbol') + result.append((OP, '*')) + elif nextTok[0] == NAME: + # Applied function followed by implicitly applied function + result.append((OP, '*')) + else: + if tok == (OP, ')'): + if isinstance(nextTok, AppliedFunction): + # Close parenthesis followed by an applied function + result.append((OP, '*')) + elif nextTok[0] == NAME: + # Close parenthesis followed by an implicitly applied function + result.append((OP, '*')) + elif nextTok == (OP, '('): + # Close parenthesis followed by an open parenthesis + result.append((OP, '*')) + elif tok[0] == NAME and not _token_callable(tok, local_dict, global_dict): + if isinstance(nextTok, AppliedFunction) or \ + (nextTok[0] == NAME and _token_callable(nextTok, local_dict, global_dict)): + # Constant followed by (implicitly applied) function + result.append((OP, '*')) + elif nextTok == (OP, '('): + # Constant followed by parenthesis + result.append((OP, '*')) + elif nextTok[0] == NAME: + # Constant followed by constant + result.append((OP, '*')) + if tokens: + result.append(tokens[-1]) + return result + + +def _implicit_application(tokens: list[TOKEN | AppliedFunction], local_dict: DICT, global_dict: DICT): + """Adds parentheses as needed after functions.""" + result: list[TOKEN | AppliedFunction] = [] + appendParen = 0 # number of closing parentheses to add + skip = 0 # number of tokens to delay before adding a ')' (to + # capture **, ^, etc.) + exponentSkip = False # skipping tokens before inserting parentheses to + # work with function exponentiation + for tok, nextTok in zip(tokens, tokens[1:]): + result.append(tok) + if (tok[0] == NAME and nextTok[0] not in [OP, ENDMARKER, NEWLINE]): + if _token_callable(tok, local_dict, global_dict, nextTok): # type: ignore + result.append((OP, '(')) + appendParen += 1 + # name followed by exponent - function exponentiation + elif (tok[0] == NAME and nextTok[0] == OP and nextTok[1] == '**'): + if _token_callable(tok, local_dict, global_dict): # type: ignore + exponentSkip = True + elif exponentSkip: + # if the last token added was an applied function (i.e. the + # power of the function exponent) OR a multiplication (as + # implicit multiplication would have added an extraneous + # multiplication) + if (isinstance(tok, AppliedFunction) + or (tok[0] == OP and tok[1] == '*')): + # don't add anything if the next token is a multiplication + # or if there's already a parenthesis (if parenthesis, still + # stop skipping tokens) + if not (nextTok[0] == OP and nextTok[1] == '*'): + if not(nextTok[0] == OP and nextTok[1] == '('): + result.append((OP, '(')) + appendParen += 1 + exponentSkip = False + elif appendParen: + if nextTok[0] == OP and nextTok[1] in ('^', '**', '*'): + skip = 1 + continue + if skip: + skip -= 1 + continue + result.append((OP, ')')) + appendParen -= 1 + + if tokens: + result.append(tokens[-1]) + + if appendParen: + result.extend([(OP, ')')] * appendParen) + return result + + +def function_exponentiation(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Allows functions to be exponentiated, e.g. ``cos**2(x)``. + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, function_exponentiation) + >>> transformations = standard_transformations + (function_exponentiation,) + >>> parse_expr('sin**4(x)', transformations=transformations) + sin(x)**4 + """ + result: list[TOKEN] = [] + exponent: list[TOKEN] = [] + consuming_exponent = False + level = 0 + for tok, nextTok in zip(tokens, tokens[1:]): + if tok[0] == NAME and nextTok[0] == OP and nextTok[1] == '**': + if _token_callable(tok, local_dict, global_dict): + consuming_exponent = True + elif consuming_exponent: + if tok[0] == NAME and tok[1] == 'Function': + tok = (NAME, 'Symbol') + exponent.append(tok) + + # only want to stop after hitting ) + if tok[0] == nextTok[0] == OP and tok[1] == ')' and nextTok[1] == '(': + consuming_exponent = False + # if implicit multiplication was used, we may have )*( instead + if tok[0] == nextTok[0] == OP and tok[1] == '*' and nextTok[1] == '(': + consuming_exponent = False + del exponent[-1] + continue + elif exponent and not consuming_exponent: + if tok[0] == OP: + if tok[1] == '(': + level += 1 + elif tok[1] == ')': + level -= 1 + if level == 0: + result.append(tok) + result.extend(exponent) + exponent = [] + continue + result.append(tok) + if tokens: + result.append(tokens[-1]) + if exponent: + result.extend(exponent) + return result + + +def split_symbols_custom(predicate: Callable[[str], bool]): + """Creates a transformation that splits symbol names. + + ``predicate`` should return True if the symbol name is to be split. + + For instance, to retain the default behavior but avoid splitting certain + symbol names, a predicate like this would work: + + + >>> from sympy.parsing.sympy_parser import (parse_expr, _token_splittable, + ... standard_transformations, implicit_multiplication, + ... split_symbols_custom) + >>> def can_split(symbol): + ... if symbol not in ('list', 'of', 'unsplittable', 'names'): + ... return _token_splittable(symbol) + ... return False + ... + >>> transformation = split_symbols_custom(can_split) + >>> parse_expr('unsplittable', transformations=standard_transformations + + ... (transformation, implicit_multiplication)) + unsplittable + """ + def _split_symbols(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + result: list[TOKEN] = [] + split = False + split_previous=False + + for tok in tokens: + if split_previous: + # throw out closing parenthesis of Symbol that was split + split_previous=False + continue + split_previous=False + + if tok[0] == NAME and tok[1] in ['Symbol', 'Function']: + split = True + + elif split and tok[0] == NAME: + symbol = tok[1][1:-1] + + if predicate(symbol): + tok_type = result[-2][1] # Symbol or Function + del result[-2:] # Get rid of the call to Symbol + + i = 0 + while i < len(symbol): + char = symbol[i] + if char in local_dict or char in global_dict: + result.append((NAME, "%s" % char)) + elif char.isdigit(): + chars = [char] + for i in range(i + 1, len(symbol)): + if not symbol[i].isdigit(): + i -= 1 + break + chars.append(symbol[i]) + char = ''.join(chars) + result.extend([(NAME, 'Number'), (OP, '('), + (NAME, "'%s'" % char), (OP, ')')]) + else: + use = tok_type if i == len(symbol) else 'Symbol' + result.extend([(NAME, use), (OP, '('), + (NAME, "'%s'" % char), (OP, ')')]) + i += 1 + + # Set split_previous=True so will skip + # the closing parenthesis of the original Symbol + split = False + split_previous = True + continue + + else: + split = False + + result.append(tok) + + return result + + return _split_symbols + + +#: Splits symbol names for implicit multiplication. +#: +#: Intended to let expressions like ``xyz`` be parsed as ``x*y*z``. Does not +#: split Greek character names, so ``theta`` will *not* become +#: ``t*h*e*t*a``. Generally this should be used with +#: ``implicit_multiplication``. +split_symbols = split_symbols_custom(_token_splittable) + + +def implicit_multiplication(tokens: list[TOKEN], local_dict: DICT, + global_dict: DICT) -> list[TOKEN]: + """Makes the multiplication operator optional in most cases. + + Use this before :func:`implicit_application`, otherwise expressions like + ``sin 2x`` will be parsed as ``x * sin(2)`` rather than ``sin(2*x)``. + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, implicit_multiplication) + >>> transformations = standard_transformations + (implicit_multiplication,) + >>> parse_expr('3 x y', transformations=transformations) + 3*x*y + """ + # These are interdependent steps, so we don't expose them separately + res1 = _group_parentheses(implicit_multiplication)(tokens, local_dict, global_dict) + res2 = _apply_functions(res1, local_dict, global_dict) + res3 = _implicit_multiplication(res2, local_dict, global_dict) + result = _flatten(res3) + return result + + +def implicit_application(tokens: list[TOKEN], local_dict: DICT, + global_dict: DICT) -> list[TOKEN]: + """Makes parentheses optional in some cases for function calls. + + Use this after :func:`implicit_multiplication`, otherwise expressions + like ``sin 2x`` will be parsed as ``x * sin(2)`` rather than + ``sin(2*x)``. + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, implicit_application) + >>> transformations = standard_transformations + (implicit_application,) + >>> parse_expr('cot z + csc z', transformations=transformations) + cot(z) + csc(z) + """ + res1 = _group_parentheses(implicit_application)(tokens, local_dict, global_dict) + res2 = _apply_functions(res1, local_dict, global_dict) + res3 = _implicit_application(res2, local_dict, global_dict) + result = _flatten(res3) + return result + + +def implicit_multiplication_application(result: list[TOKEN], local_dict: DICT, + global_dict: DICT) -> list[TOKEN]: + """Allows a slightly relaxed syntax. + + - Parentheses for single-argument method calls are optional. + + - Multiplication is implicit. + + - Symbol names can be split (i.e. spaces are not needed between + symbols). + + - Functions can be exponentiated. + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, implicit_multiplication_application) + >>> parse_expr("10sin**2 x**2 + 3xyz + tan theta", + ... transformations=(standard_transformations + + ... (implicit_multiplication_application,))) + 3*x*y*z + 10*sin(x**2)**2 + tan(theta) + + """ + for step in (split_symbols, implicit_multiplication, + implicit_application, function_exponentiation): + result = step(result, local_dict, global_dict) + + return result + + +def auto_symbol(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Inserts calls to ``Symbol``/``Function`` for undefined variables.""" + result: list[TOKEN] = [] + prevTok = (-1, '') + + tokens.append((-1, '')) # so zip traverses all tokens + for tok, nextTok in zip(tokens, tokens[1:]): + tokNum, tokVal = tok + nextTokNum, nextTokVal = nextTok + if tokNum == NAME: + name = tokVal + + if (name in ['True', 'False', 'None'] + or iskeyword(name) + # Don't convert attribute access + or (prevTok[0] == OP and prevTok[1] == '.') + # Don't convert keyword arguments + or (prevTok[0] == OP and prevTok[1] in ('(', ',') + and nextTokNum == OP and nextTokVal == '=') + # the name has already been defined + or name in local_dict and local_dict[name] is not null): + result.append((NAME, name)) + continue + elif name in local_dict: + local_dict.setdefault(null, set()).add(name) + if nextTokVal == '(': + local_dict[name] = Function(name) + else: + local_dict[name] = Symbol(name) + result.append((NAME, name)) + continue + elif name in global_dict: + obj = global_dict[name] + if isinstance(obj, (AssumptionKeys, Basic, type)) or callable(obj): + result.append((NAME, name)) + continue + + result.extend([ + (NAME, 'Symbol' if nextTokVal != '(' else 'Function'), + (OP, '('), + (NAME, repr(str(name))), + (OP, ')'), + ]) + else: + result.append((tokNum, tokVal)) + + prevTok = (tokNum, tokVal) + + return result + + +def lambda_notation(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Substitutes "lambda" with its SymPy equivalent Lambda(). + However, the conversion does not take place if only "lambda" + is passed because that is a syntax error. + + """ + result: list[TOKEN] = [] + flag = False + toknum, tokval = tokens[0] + tokLen = len(tokens) + + if toknum == NAME and tokval == 'lambda': + if tokLen == 2 or tokLen == 3 and tokens[1][0] == NEWLINE: + # In Python 3.6.7+, inputs without a newline get NEWLINE added to + # the tokens + result.extend(tokens) + elif tokLen > 2: + result.extend([ + (NAME, 'Lambda'), + (OP, '('), + (OP, '('), + (OP, ')'), + (OP, ')'), + ]) + for tokNum, tokVal in tokens[1:]: + if tokNum == OP and tokVal == ':': + tokVal = ',' + flag = True + if not flag and tokNum == OP and tokVal in ('*', '**'): + raise TokenError("Starred arguments in lambda not supported") + if flag: + result.insert(-1, (tokNum, tokVal)) + else: + result.insert(-2, (tokNum, tokVal)) + else: + result.extend(tokens) + + return result + + +def factorial_notation(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Allows standard notation for factorial.""" + result: list[TOKEN] = [] + nfactorial = 0 + for toknum, tokval in tokens: + if toknum == OP and tokval == "!": + # In Python 3.12 "!" are OP instead of ERRORTOKEN + nfactorial += 1 + elif toknum == ERRORTOKEN: + op = tokval + if op == '!': + nfactorial += 1 + else: + nfactorial = 0 + result.append((OP, op)) + else: + if nfactorial == 1: + result = _add_factorial_tokens('factorial', result) + elif nfactorial == 2: + result = _add_factorial_tokens('factorial2', result) + elif nfactorial > 2: + raise TokenError + nfactorial = 0 + result.append((toknum, tokval)) + return result + + +def convert_xor(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Treats XOR, ``^``, as exponentiation, ``**``.""" + result: list[TOKEN] = [] + for toknum, tokval in tokens: + if toknum == OP: + if tokval == '^': + result.append((OP, '**')) + else: + result.append((toknum, tokval)) + else: + result.append((toknum, tokval)) + + return result + + +def repeated_decimals(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """ + Allows 0.2[1] notation to represent the repeated decimal 0.2111... (19/90) + + Run this before auto_number. + + """ + result: list[TOKEN] = [] + + def is_digit(s): + return all(i in '0123456789_' for i in s) + + # num will running match any DECIMAL [ INTEGER ] + num: list[TOKEN] = [] + for toknum, tokval in tokens: + if toknum == NUMBER: + if (not num and '.' in tokval and 'e' not in tokval.lower() and + 'j' not in tokval.lower()): + num.append((toknum, tokval)) + elif is_digit(tokval) and (len(num) == 2 or + len(num) == 3 and is_digit(num[-1][1])): + num.append((toknum, tokval)) + else: + num = [] + elif toknum == OP: + if tokval == '[' and len(num) == 1: + num.append((OP, tokval)) + elif tokval == ']' and len(num) >= 3: + num.append((OP, tokval)) + elif tokval == '.' and not num: + # handle .[1] + num.append((NUMBER, '0.')) + else: + num = [] + else: + num = [] + + result.append((toknum, tokval)) + + if num and num[-1][1] == ']': + # pre.post[repetend] = a + b/c + d/e where a = pre, b/c = post, + # and d/e = repetend + result = result[:-len(num)] + pre, post = num[0][1].split('.') + repetend = num[2][1] + if len(num) == 5: + repetend += num[3][1] + + pre = pre.replace('_', '') + post = post.replace('_', '') + repetend = repetend.replace('_', '') + + zeros = '0'*len(post) + post, repetends = [w.lstrip('0') for w in [post, repetend]] + # or else interpreted as octal + + a = pre or '0' + b, c = post or '0', '1' + zeros + d, e = repetends, ('9'*len(repetend)) + zeros + + seq = [ + (OP, '('), + (NAME, 'Integer'), + (OP, '('), + (NUMBER, a), + (OP, ')'), + (OP, '+'), + (NAME, 'Rational'), + (OP, '('), + (NUMBER, b), + (OP, ','), + (NUMBER, c), + (OP, ')'), + (OP, '+'), + (NAME, 'Rational'), + (OP, '('), + (NUMBER, d), + (OP, ','), + (NUMBER, e), + (OP, ')'), + (OP, ')'), + ] + result.extend(seq) + num = [] + + return result + + +def auto_number(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """ + Converts numeric literals to use SymPy equivalents. + + Complex numbers use ``I``, integer literals use ``Integer``, and float + literals use ``Float``. + + """ + result: list[TOKEN] = [] + + for toknum, tokval in tokens: + if toknum == NUMBER: + number = tokval + postfix = [] + + if number.endswith(('j', 'J')): + number = number[:-1] + postfix = [(OP, '*'), (NAME, 'I')] + + if '.' in number or (('e' in number or 'E' in number) and + not (number.startswith(('0x', '0X')))): + seq = [(NAME, 'Float'), (OP, '('), + (NUMBER, repr(str(number))), (OP, ')')] + else: + seq = [(NAME, 'Integer'), (OP, '('), ( + NUMBER, number), (OP, ')')] + + result.extend(seq + postfix) + else: + result.append((toknum, tokval)) + + return result + + +def rationalize(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Converts floats into ``Rational``. Run AFTER ``auto_number``.""" + result: list[TOKEN] = [] + passed_float = False + for toknum, tokval in tokens: + if toknum == NAME: + if tokval == 'Float': + passed_float = True + tokval = 'Rational' + result.append((toknum, tokval)) + elif passed_float == True and toknum == NUMBER: + passed_float = False + result.append((STRING, tokval)) + else: + result.append((toknum, tokval)) + + return result + + +def _transform_equals_sign(tokens: list[TOKEN], local_dict: DICT, global_dict: DICT): + """Transforms the equals sign ``=`` to instances of Eq. + + This is a helper function for ``convert_equals_signs``. + Works with expressions containing one equals sign and no + nesting. Expressions like ``(1=2)=False`` will not work with this + and should be used with ``convert_equals_signs``. + + Examples: 1=2 to Eq(1,2) + 1*2=x to Eq(1*2, x) + + This does not deal with function arguments yet. + + """ + result: list[TOKEN] = [] + if (OP, "=") in tokens: + result.append((NAME, "Eq")) + result.append((OP, "(")) + for token in tokens: + if token == (OP, "="): + result.append((OP, ",")) + continue + result.append(token) + result.append((OP, ")")) + else: + result = tokens + return result + + +def convert_equals_signs(tokens: list[TOKEN], local_dict: DICT, + global_dict: DICT) -> list[TOKEN]: + """ Transforms all the equals signs ``=`` to instances of Eq. + + Parses the equals signs in the expression and replaces them with + appropriate Eq instances. Also works with nested equals signs. + + Does not yet play well with function arguments. + For example, the expression ``(x=y)`` is ambiguous and can be interpreted + as x being an argument to a function and ``convert_equals_signs`` will not + work for this. + + See also + ======== + convert_equality_operators + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import (parse_expr, + ... standard_transformations, convert_equals_signs) + >>> parse_expr("1*2=x", transformations=( + ... standard_transformations + (convert_equals_signs,))) + Eq(2, x) + >>> parse_expr("(1*2=x)=False", transformations=( + ... standard_transformations + (convert_equals_signs,))) + Eq(Eq(2, x), False) + + """ + res1 = _group_parentheses(convert_equals_signs)(tokens, local_dict, global_dict) + res2 = _apply_functions(res1, local_dict, global_dict) + res3 = _transform_equals_sign(res2, local_dict, global_dict) + result = _flatten(res3) + return result + + +#: Standard transformations for :func:`parse_expr`. +#: Inserts calls to :class:`~.Symbol`, :class:`~.Integer`, and other SymPy +#: datatypes and allows the use of standard factorial notation (e.g. ``x!``). +standard_transformations: tuple[TRANS, ...] \ + = (lambda_notation, auto_symbol, repeated_decimals, auto_number, + factorial_notation) + + +def stringify_expr(s: str, local_dict: DICT, global_dict: DICT, + transformations: tuple[TRANS, ...]) -> str: + """ + Converts the string ``s`` to Python code, in ``local_dict`` + + Generally, ``parse_expr`` should be used. + """ + + tokens = [] + input_code = StringIO(s.strip()) + for toknum, tokval, _, _, _ in generate_tokens(input_code.readline): + tokens.append((toknum, tokval)) + + for transform in transformations: + tokens = transform(tokens, local_dict, global_dict) + + return untokenize(tokens) + + +def eval_expr(code, local_dict: DICT, global_dict: DICT): + """ + Evaluate Python code generated by ``stringify_expr``. + + Generally, ``parse_expr`` should be used. + """ + expr = eval( + code, global_dict, local_dict) # take local objects in preference + return expr + + +def parse_expr(s: str, local_dict: DICT | None = None, + transformations: tuple[TRANS, ...] | str \ + = standard_transformations, + global_dict: DICT | None = None, evaluate=True): + """Converts the string ``s`` to a SymPy expression, in ``local_dict``. + + .. warning:: + Note that this function uses ``eval``, and thus shouldn't be used on + unsanitized input. + + Parameters + ========== + + s : str + The string to parse. + + local_dict : dict, optional + A dictionary of local variables to use when parsing. + + global_dict : dict, optional + A dictionary of global variables. By default, this is initialized + with ``from sympy import *``; provide this parameter to override + this behavior (for instance, to parse ``"Q & S"``). + + transformations : tuple or str + A tuple of transformation functions used to modify the tokens of the + parsed expression before evaluation. The default transformations + convert numeric literals into their SymPy equivalents, convert + undefined variables into SymPy symbols, and allow the use of standard + mathematical factorial notation (e.g. ``x!``). Selection via + string is available (see below). + + evaluate : bool, optional + When False, the order of the arguments will remain as they were in the + string and automatic simplification that would normally occur is + suppressed. (see examples) + + Examples + ======== + + >>> from sympy.parsing.sympy_parser import parse_expr + >>> parse_expr("1/2") + 1/2 + >>> type(_) + + >>> from sympy.parsing.sympy_parser import standard_transformations,\\ + ... implicit_multiplication_application + >>> transformations = (standard_transformations + + ... (implicit_multiplication_application,)) + >>> parse_expr("2x", transformations=transformations) + 2*x + + When evaluate=False, some automatic simplifications will not occur: + + >>> parse_expr("2**3"), parse_expr("2**3", evaluate=False) + (8, 2**3) + + In addition the order of the arguments will not be made canonical. + This feature allows one to tell exactly how the expression was entered: + + >>> a = parse_expr('1 + x', evaluate=False) + >>> b = parse_expr('x + 1', evaluate=False) + >>> a == b + False + >>> a.args + (1, x) + >>> b.args + (x, 1) + + Note, however, that when these expressions are printed they will + appear the same: + + >>> assert str(a) == str(b) + + As a convenience, transformations can be seen by printing ``transformations``: + + >>> from sympy.parsing.sympy_parser import transformations + + >>> print(transformations) + 0: lambda_notation + 1: auto_symbol + 2: repeated_decimals + 3: auto_number + 4: factorial_notation + 5: implicit_multiplication_application + 6: convert_xor + 7: implicit_application + 8: implicit_multiplication + 9: convert_equals_signs + 10: function_exponentiation + 11: rationalize + + The ``T`` object provides a way to select these transformations: + + >>> from sympy.parsing.sympy_parser import T + + If you print it, you will see the same list as shown above. + + >>> str(T) == str(transformations) + True + + Standard slicing will return a tuple of transformations: + + >>> T[:5] == standard_transformations + True + + So ``T`` can be used to specify the parsing transformations: + + >>> parse_expr("2x", transformations=T[:5]) + Traceback (most recent call last): + ... + SyntaxError: invalid syntax + >>> parse_expr("2x", transformations=T[:6]) + 2*x + >>> parse_expr('.3', transformations=T[3, 11]) + 3/10 + >>> parse_expr('.3x', transformations=T[:]) + 3*x/10 + + As a further convenience, strings 'implicit' and 'all' can be used + to select 0-5 and all the transformations, respectively. + + >>> parse_expr('.3x', transformations='all') + 3*x/10 + + See Also + ======== + + stringify_expr, eval_expr, standard_transformations, + implicit_multiplication_application + + """ + + if local_dict is None: + local_dict = {} + elif not isinstance(local_dict, dict): + raise TypeError('expecting local_dict to be a dict') + elif null in local_dict: + raise ValueError('cannot use "" in local_dict') + + if global_dict is None: + global_dict = {} + exec('from sympy import *', global_dict) + + builtins_dict = vars(builtins) + for name, obj in builtins_dict.items(): + if isinstance(obj, types.BuiltinFunctionType): + global_dict[name] = obj + global_dict['max'] = Max + global_dict['min'] = Min + + elif not isinstance(global_dict, dict): + raise TypeError('expecting global_dict to be a dict') + + transformations = transformations or () + if isinstance(transformations, str): + if transformations == 'all': + _transformations = T[:] + elif transformations == 'implicit': + _transformations = T[:6] + else: + raise ValueError('unknown transformation group name') + else: + _transformations = transformations + + code = stringify_expr(s, local_dict, global_dict, _transformations) + + if not evaluate: + code = compile(evaluateFalse(code), '', 'eval') # type: ignore + + try: + rv = eval_expr(code, local_dict, global_dict) + # restore neutral definitions for names + for i in local_dict.pop(null, ()): + local_dict[i] = null + return rv + except Exception as e: + # restore neutral definitions for names + for i in local_dict.pop(null, ()): + local_dict[i] = null + raise e from ValueError(f"Error from parse_expr with transformed code: {code!r}") + + +def evaluateFalse(s: str): + """ + Replaces operators with the SymPy equivalent and sets evaluate=False. + """ + node = ast.parse(s) + transformed_node = EvaluateFalseTransformer().visit(node) + # node is a Module, we want an Expression + transformed_node = ast.Expression(transformed_node.body[0].value) + + return ast.fix_missing_locations(transformed_node) + + +class EvaluateFalseTransformer(ast.NodeTransformer): + operators = { + ast.Add: 'Add', + ast.Mult: 'Mul', + ast.Pow: 'Pow', + ast.Sub: 'Add', + ast.Div: 'Mul', + ast.BitOr: 'Or', + ast.BitAnd: 'And', + ast.BitXor: 'Not', + } + functions = ( + 'Abs', 'im', 're', 'sign', 'arg', 'conjugate', + 'acos', 'acot', 'acsc', 'asec', 'asin', 'atan', + 'acosh', 'acoth', 'acsch', 'asech', 'asinh', 'atanh', + 'cos', 'cot', 'csc', 'sec', 'sin', 'tan', + 'cosh', 'coth', 'csch', 'sech', 'sinh', 'tanh', + 'exp', 'ln', 'log', 'sqrt', 'cbrt', + ) + + relational_operators = { + ast.NotEq: 'Ne', + ast.Lt: 'Lt', + ast.LtE: 'Le', + ast.Gt: 'Gt', + ast.GtE: 'Ge', + ast.Eq: 'Eq' + } + def visit_Compare(self, node): + def reducer(acc, op_right): + result, left = acc + op, right = op_right + if op.__class__ not in self.relational_operators: + raise ValueError("Only equation or inequality operators are supported") + new = ast.Call( + func=ast.Name( + id=self.relational_operators[op.__class__], ctx=ast.Load() + ), + args=[self.visit(left), self.visit(right)], + keywords=[ast.keyword(arg="evaluate", value=ast.Constant(value=False))], + ) + return result + [new], right + + args, _ = reduce( + reducer, zip(node.ops, node.comparators), ([], node.left) + ) + if len(args) == 1: + return args[0] + return ast.Call( + func=ast.Name(id=self.operators[ast.BitAnd], ctx=ast.Load()), + args=args, + keywords=[ast.keyword(arg="evaluate", value=ast.Constant(value=False))], + ) + + def flatten(self, args, func): + result = [] + for arg in args: + if isinstance(arg, ast.Call): + arg_func = arg.func + if isinstance(arg_func, ast.Call): + arg_func = arg_func.func + if arg_func.id == func: + result.extend(self.flatten(arg.args, func)) + else: + result.append(arg) + else: + result.append(arg) + return result + + def visit_BinOp(self, node): + if node.op.__class__ in self.operators: + sympy_class = self.operators[node.op.__class__] + right = self.visit(node.right) + left = self.visit(node.left) + + rev = False + if isinstance(node.op, ast.Sub): + right = ast.Call( + func=ast.Name(id='Mul', ctx=ast.Load()), + args=[ast.UnaryOp(op=ast.USub(), operand=ast.Constant(1)), right], + keywords=[ast.keyword(arg='evaluate', value=ast.Constant(value=False))] + ) + elif isinstance(node.op, ast.Div): + if isinstance(node.left, ast.UnaryOp): + left, right = right, left + rev = True + left = ast.Call( + func=ast.Name(id='Pow', ctx=ast.Load()), + args=[left, ast.UnaryOp(op=ast.USub(), operand=ast.Constant(1))], + keywords=[ast.keyword(arg='evaluate', value=ast.Constant(value=False))] + ) + else: + right = ast.Call( + func=ast.Name(id='Pow', ctx=ast.Load()), + args=[right, ast.UnaryOp(op=ast.USub(), operand=ast.Constant(1))], + keywords=[ast.keyword(arg='evaluate', value=ast.Constant(value=False))] + ) + + if rev: # undo reversal + left, right = right, left + new_node = ast.Call( + func=ast.Name(id=sympy_class, ctx=ast.Load()), + args=[left, right], + keywords=[ast.keyword(arg='evaluate', value=ast.Constant(value=False))] + ) + + if sympy_class in ('Add', 'Mul'): + # Denest Add or Mul as appropriate + new_node.args = self.flatten(new_node.args, sympy_class) + + return new_node + return node + + def visit_Call(self, node): + new_node = self.generic_visit(node) + if isinstance(node.func, ast.Name) and node.func.id in self.functions: + new_node.keywords.append(ast.keyword(arg='evaluate', value=ast.Constant(value=False))) + return new_node + + +_transformation = { # items can be added but never re-ordered +0: lambda_notation, +1: auto_symbol, +2: repeated_decimals, +3: auto_number, +4: factorial_notation, +5: implicit_multiplication_application, +6: convert_xor, +7: implicit_application, +8: implicit_multiplication, +9: convert_equals_signs, +10: function_exponentiation, +11: rationalize} + +transformations = '\n'.join('%s: %s' % (i, func_name(f)) for i, f in _transformation.items()) + + +class _T(): + """class to retrieve transformations from a given slice + + EXAMPLES + ======== + + >>> from sympy.parsing.sympy_parser import T, standard_transformations + >>> assert T[:5] == standard_transformations + """ + def __init__(self): + self.N = len(_transformation) + + def __str__(self): + return transformations + + def __getitem__(self, t): + if not type(t) is tuple: + t = (t,) + i = [] + for ti in t: + if type(ti) is int: + i.append(range(self.N)[ti]) + elif type(ti) is slice: + i.extend(range(*ti.indices(self.N))) + else: + raise TypeError('unexpected slice arg') + return tuple([_transformation[_] for _ in i]) + +T = _T() diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__init__.py b/phivenv/Lib/site-packages/sympy/parsing/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5985f72d72bc6c4ac3959e6ab3475c9dc377710a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_ast_parser.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_ast_parser.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..427788d458aeefe411326e2ebd5179b63191f871 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_ast_parser.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_autolev.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_autolev.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64a993f803853ef95250b2d37295dfe417448228 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_autolev.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_c_parser.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_c_parser.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10e63cd3f1e2130bcb4ec26685a142fac042ff91 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_c_parser.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_custom_latex.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_custom_latex.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d7e8057b2ec6aa9f79d2bb6db2c217e9dda893e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_custom_latex.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_fortran_parser.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_fortran_parser.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5775ab264259a5fed85c84a9d324e9000aadb488 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_fortran_parser.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_implicit_multiplication_application.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_implicit_multiplication_application.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96898c32ea443a3cf165b0e6ee74e9e841241a6b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_implicit_multiplication_application.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_latex.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_latex.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1884ecb0280bd8526140166d281d10405e16c096 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_latex.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_latex_deps.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_latex_deps.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3fe5aaa8fa961fa7fb9670348134512b76997fb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_latex_deps.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_latex_lark.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_latex_lark.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3aca7468612fc58b67b7ac1603dfa2524ebb3c7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_latex_lark.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_mathematica.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_mathematica.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..026426ec14cc918613653b5e86eeed217dbdb3a6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_mathematica.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_maxima.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_maxima.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c01bcac1570cfdf1a9fd05c27092470337c744c7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_maxima.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_sym_expr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_sym_expr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbdeb54a3ea4297777a7035d4699f571691253da Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_sym_expr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_sympy_parser.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_sympy_parser.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..322496e0efa77033e27396cd6417f12e8c86469c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/parsing/tests/__pycache__/test_sympy_parser.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/test_ast_parser.py b/phivenv/Lib/site-packages/sympy/parsing/tests/test_ast_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..24572190df72f9be11b5830355b0d6b9e3bb53ad --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/tests/test_ast_parser.py @@ -0,0 +1,25 @@ +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.parsing.ast_parser import parse_expr +from sympy.testing.pytest import raises +from sympy.core.sympify import SympifyError +import warnings + +def test_parse_expr(): + a, b = symbols('a, b') + # tests issue_16393 + assert parse_expr('a + b', {}) == a + b + raises(SympifyError, lambda: parse_expr('a + ', {})) + + # tests Transform.visit_Constant + assert parse_expr('1 + 2', {}) == S(3) + assert parse_expr('1 + 2.0', {}) == S(3.0) + + # tests Transform.visit_Name + assert parse_expr('Rational(1, 2)', {}) == S(1)/2 + assert parse_expr('a', {'a': a}) == a + + # tests issue_23092 + with warnings.catch_warnings(): + warnings.simplefilter('error') + assert parse_expr('6 * 7', {}) == S(42) diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/test_autolev.py b/phivenv/Lib/site-packages/sympy/parsing/tests/test_autolev.py new file mode 100644 index 0000000000000000000000000000000000000000..dfcaef13565c5e2187dc6e90113b407a7967c331 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/tests/test_autolev.py @@ -0,0 +1,178 @@ +import os + +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.external import import_module +from sympy.testing.pytest import skip +from sympy.parsing.autolev import parse_autolev + +antlr4 = import_module("antlr4") + +if not antlr4: + disabled = True + +FILE_DIR = os.path.dirname( + os.path.dirname(os.path.abspath(os.path.realpath(__file__)))) + + +def _test_examples(in_filename, out_filename, test_name=""): + + in_file_path = os.path.join(FILE_DIR, 'autolev', 'test-examples', + in_filename) + correct_file_path = os.path.join(FILE_DIR, 'autolev', 'test-examples', + out_filename) + with open(in_file_path) as f: + generated_code = parse_autolev(f, include_numeric=True) + + with open(correct_file_path) as f: + for idx, line1 in enumerate(f): + if line1.startswith("#"): + break + try: + line2 = generated_code.split('\n')[idx] + assert line1.rstrip() == line2.rstrip() + except Exception: + msg = 'mismatch in ' + test_name + ' in line no: {0}' + raise AssertionError(msg.format(idx+1)) + + +def test_rule_tests(): + + l = ["ruletest1", "ruletest2", "ruletest3", "ruletest4", "ruletest5", + "ruletest6", "ruletest7", "ruletest8", "ruletest9", "ruletest10", + "ruletest11", "ruletest12"] + + for i in l: + in_filepath = i + ".al" + out_filepath = i + ".py" + _test_examples(in_filepath, out_filepath, i) + + +def test_pydy_examples(): + + l = ["mass_spring_damper", "chaos_pendulum", "double_pendulum", + "non_min_pendulum"] + + for i in l: + in_filepath = os.path.join("pydy-example-repo", i + ".al") + out_filepath = os.path.join("pydy-example-repo", i + ".py") + _test_examples(in_filepath, out_filepath, i) + + +def test_autolev_tutorial(): + + dir_path = os.path.join(FILE_DIR, 'autolev', 'test-examples', + 'autolev-tutorial') + + if os.path.isdir(dir_path): + l = ["tutor1", "tutor2", "tutor3", "tutor4", "tutor5", "tutor6", + "tutor7"] + for i in l: + in_filepath = os.path.join("autolev-tutorial", i + ".al") + out_filepath = os.path.join("autolev-tutorial", i + ".py") + _test_examples(in_filepath, out_filepath, i) + + +def test_dynamics_online(): + + dir_path = os.path.join(FILE_DIR, 'autolev', 'test-examples', + 'dynamics-online') + + if os.path.isdir(dir_path): + ch1 = ["1-4", "1-5", "1-6", "1-7", "1-8", "1-9_1", "1-9_2", "1-9_3"] + ch2 = ["2-1", "2-2", "2-3", "2-4", "2-5", "2-6", "2-7", "2-8", "2-9", + "circular"] + ch3 = ["3-1_1", "3-1_2", "3-2_1", "3-2_2", "3-2_3", "3-2_4", "3-2_5", + "3-3"] + ch4 = ["4-1_1", "4-2_1", "4-4_1", "4-4_2", "4-5_1", "4-5_2"] + chapters = [(ch1, "ch1"), (ch2, "ch2"), (ch3, "ch3"), (ch4, "ch4")] + for ch, name in chapters: + for i in ch: + in_filepath = os.path.join("dynamics-online", name, i + ".al") + out_filepath = os.path.join("dynamics-online", name, i + ".py") + _test_examples(in_filepath, out_filepath, i) + + +def test_output_01(): + """Autolev example calculates the position, velocity, and acceleration of a + point and expresses in a single reference frame:: + + (1) FRAMES C,D,F + (2) VARIABLES FD'',DC'' + (3) CONSTANTS R,L + (4) POINTS O,E + (5) SIMPROT(F,D,1,FD) + -> (6) F_D = [1, 0, 0; 0, COS(FD), -SIN(FD); 0, SIN(FD), COS(FD)] + (7) SIMPROT(D,C,2,DC) + -> (8) D_C = [COS(DC), 0, SIN(DC); 0, 1, 0; -SIN(DC), 0, COS(DC)] + (9) W_C_F> = EXPRESS(W_C_F>, F) + -> (10) W_C_F> = FD'*F1> + COS(FD)*DC'*F2> + SIN(FD)*DC'*F3> + (11) P_O_E>=R*D2>-L*C1> + (12) P_O_E>=EXPRESS(P_O_E>, D) + -> (13) P_O_E> = -L*COS(DC)*D1> + R*D2> + L*SIN(DC)*D3> + (14) V_E_F>=EXPRESS(DT(P_O_E>,F),D) + -> (15) V_E_F> = L*SIN(DC)*DC'*D1> - L*SIN(DC)*FD'*D2> + (R*FD'+L*COS(DC)*DC')*D3> + (16) A_E_F>=EXPRESS(DT(V_E_F>,F),D) + -> (17) A_E_F> = L*(COS(DC)*DC'^2+SIN(DC)*DC'')*D1> + (-R*FD'^2-2*L*COS(DC)*DC'*FD'-L*SIN(DC)*FD'')*D2> + (R*FD''+L*COS(DC)*DC''-L*SIN(DC)*DC'^2-L*SIN(DC)*FD'^2)*D3> + + """ + + if not antlr4: + skip('Test skipped: antlr4 is not installed.') + + autolev_input = """\ +FRAMES C,D,F +VARIABLES FD'',DC'' +CONSTANTS R,L +POINTS O,E +SIMPROT(F,D,1,FD) +SIMPROT(D,C,2,DC) +W_C_F>=EXPRESS(W_C_F>,F) +P_O_E>=R*D2>-L*C1> +P_O_E>=EXPRESS(P_O_E>,D) +V_E_F>=EXPRESS(DT(P_O_E>,F),D) +A_E_F>=EXPRESS(DT(V_E_F>,F),D)\ +""" + + sympy_input = parse_autolev(autolev_input) + + g = {} + l = {} + exec(sympy_input, g, l) + + w_c_f = l['frame_c'].ang_vel_in(l['frame_f']) + # P_O_E> means "the position of point E wrt to point O" + p_o_e = l['point_e'].pos_from(l['point_o']) + v_e_f = l['point_e'].vel(l['frame_f']) + a_e_f = l['point_e'].acc(l['frame_f']) + + # NOTE : The Autolev outputs above were manually transformed into + # equivalent SymPy physics vector expressions. Would be nice to automate + # this transformation. + expected_w_c_f = (l['fd'].diff()*l['frame_f'].x + + cos(l['fd'])*l['dc'].diff()*l['frame_f'].y + + sin(l['fd'])*l['dc'].diff()*l['frame_f'].z) + + assert (w_c_f - expected_w_c_f).simplify() == 0 + + expected_p_o_e = (-l['l']*cos(l['dc'])*l['frame_d'].x + + l['r']*l['frame_d'].y + + l['l']*sin(l['dc'])*l['frame_d'].z) + + assert (p_o_e - expected_p_o_e).simplify() == 0 + + expected_v_e_f = (l['l']*sin(l['dc'])*l['dc'].diff()*l['frame_d'].x - + l['l']*sin(l['dc'])*l['fd'].diff()*l['frame_d'].y + + (l['r']*l['fd'].diff() + + l['l']*cos(l['dc'])*l['dc'].diff())*l['frame_d'].z) + assert (v_e_f - expected_v_e_f).simplify() == 0 + + expected_a_e_f = (l['l']*(cos(l['dc'])*l['dc'].diff()**2 + + sin(l['dc'])*l['dc'].diff().diff())*l['frame_d'].x + + (-l['r']*l['fd'].diff()**2 - + 2*l['l']*cos(l['dc'])*l['dc'].diff()*l['fd'].diff() - + l['l']*sin(l['dc'])*l['fd'].diff().diff())*l['frame_d'].y + + (l['r']*l['fd'].diff().diff() + + l['l']*cos(l['dc'])*l['dc'].diff().diff() - + l['l']*sin(l['dc'])*l['dc'].diff()**2 - + l['l']*sin(l['dc'])*l['fd'].diff()**2)*l['frame_d'].z) + assert (a_e_f - expected_a_e_f).simplify() == 0 diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/test_c_parser.py b/phivenv/Lib/site-packages/sympy/parsing/tests/test_c_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..b74622e40030cba180cb4fc354216ccca119baec --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/tests/test_c_parser.py @@ -0,0 +1,5248 @@ +from sympy.parsing.sym_expr import SymPyExpression +from sympy.testing.pytest import raises, XFAIL +from sympy.external import import_module + +cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']}) + +if cin: + from sympy.codegen.ast import (Variable, String, Return, + FunctionDefinition, Integer, Float, Declaration, CodeBlock, + FunctionPrototype, FunctionCall, NoneToken, Assignment, Type, + IntBaseType, SignedIntType, UnsignedIntType, FloatType, + AddAugmentedAssignment, SubAugmentedAssignment, + MulAugmentedAssignment, DivAugmentedAssignment, + ModAugmentedAssignment, While) + from sympy.codegen.cnodes import (PreDecrement, PostDecrement, + PreIncrement, PostIncrement) + from sympy.core import (Add, Mul, Mod, Pow, Rational, + StrictLessThan, LessThan, StrictGreaterThan, GreaterThan, + Equality, Unequality) + from sympy.logic.boolalg import And, Not, Or + from sympy.core.symbol import Symbol + from sympy.logic.boolalg import (false, true) + import os + + def test_variable(): + c_src1 = ( + 'int a;' + '\n' + + 'int b;' + '\n' + ) + c_src2 = ( + 'float a;' + '\n' + + 'float b;' + '\n' + ) + c_src3 = ( + 'int a;' + '\n' + + 'float b;' + '\n' + + 'int c;' + ) + c_src4 = ( + 'int x = 1, y = 6.78;' + '\n' + + 'float p = 2, q = 9.67;' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + + assert res1[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + + assert res1[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')) + ) + ) + + assert res2[0] == Declaration( + Variable( + Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ) + assert res2[1] == Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ) + + assert res3[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + + assert res3[1] == Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ) + + assert res3[2] == Declaration( + Variable( + Symbol('c'), + type=IntBaseType(String('intc')) + ) + ) + + assert res4[0] == Declaration( + Variable( + Symbol('x'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res4[1] == Declaration( + Variable( + Symbol('y'), + type=IntBaseType(String('intc')), + value=Integer(6) + ) + ) + + assert res4[2] == Declaration( + Variable( + Symbol('p'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.0', precision=53) + ) + ) + + assert res4[3] == Declaration( + Variable( + Symbol('q'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('9.67', precision=53) + ) + ) + + + def test_int(): + c_src1 = 'int a = 1;' + c_src2 = ( + 'int a = 1;' + '\n' + + 'int b = 2;' + '\n' + ) + c_src3 = 'int a = 2.345, b = 5.67;' + c_src4 = 'int p = 6, q = 23.45;' + c_src5 = "int x = '0', y = 'a';" + c_src6 = "int r = true, s = false;" + + # cin.TypeKind.UCHAR + c_src_type1 = ( + "signed char a = 1, b = 5.1;" + ) + + # cin.TypeKind.SHORT + c_src_type2 = ( + "short a = 1, b = 5.1;" + "signed short c = 1, d = 5.1;" + "short int e = 1, f = 5.1;" + "signed short int g = 1, h = 5.1;" + ) + + # cin.TypeKind.INT + c_src_type3 = ( + "signed int a = 1, b = 5.1;" + "int c = 1, d = 5.1;" + ) + + # cin.TypeKind.LONG + c_src_type4 = ( + "long a = 1, b = 5.1;" + "long int c = 1, d = 5.1;" + ) + + # cin.TypeKind.UCHAR + c_src_type5 = "unsigned char a = 1, b = 5.1;" + + # cin.TypeKind.USHORT + c_src_type6 = ( + "unsigned short a = 1, b = 5.1;" + "unsigned short int c = 1, d = 5.1;" + ) + + # cin.TypeKind.UINT + c_src_type7 = "unsigned int a = 1, b = 5.1;" + + # cin.TypeKind.ULONG + c_src_type8 = ( + "unsigned long a = 1, b = 5.1;" + "unsigned long int c = 1, d = 5.1;" + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + res6 = SymPyExpression(c_src6, 'c').return_expr() + + res_type1 = SymPyExpression(c_src_type1, 'c').return_expr() + res_type2 = SymPyExpression(c_src_type2, 'c').return_expr() + res_type3 = SymPyExpression(c_src_type3, 'c').return_expr() + res_type4 = SymPyExpression(c_src_type4, 'c').return_expr() + res_type5 = SymPyExpression(c_src_type5, 'c').return_expr() + res_type6 = SymPyExpression(c_src_type6, 'c').return_expr() + res_type7 = SymPyExpression(c_src_type7, 'c').return_expr() + res_type8 = SymPyExpression(c_src_type8, 'c').return_expr() + + assert res1[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res2[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res2[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res3[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res3[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ) + + assert res4[0] == Declaration( + Variable( + Symbol('p'), + type=IntBaseType(String('intc')), + value=Integer(6) + ) + ) + + assert res4[1] == Declaration( + Variable( + Symbol('q'), + type=IntBaseType(String('intc')), + value=Integer(23) + ) + ) + + assert res5[0] == Declaration( + Variable( + Symbol('x'), + type=IntBaseType(String('intc')), + value=Integer(48) + ) + ) + + assert res5[1] == Declaration( + Variable( + Symbol('y'), + type=IntBaseType(String('intc')), + value=Integer(97) + ) + ) + + assert res6[0] == Declaration( + Variable( + Symbol('r'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res6[1] == Declaration( + Variable( + Symbol('s'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ) + + assert res_type1[0] == Declaration( + Variable( + Symbol('a'), + type=SignedIntType( + String('int8'), + nbits=Integer(8) + ), + value=Integer(1) + ) + ) + + assert res_type1[1] == Declaration( + Variable( + Symbol('b'), + type=SignedIntType( + String('int8'), + nbits=Integer(8) + ), + value=Integer(5) + ) + ) + + assert res_type2[0] == Declaration( + Variable( + Symbol('a'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type2[1] == Declaration( + Variable( + Symbol('b'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type2[2] == Declaration( + Variable(Symbol('c'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type2[3] == Declaration( + Variable( + Symbol('d'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type2[4] == Declaration( + Variable( + Symbol('e'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type2[5] == Declaration( + Variable( + Symbol('f'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type2[6] == Declaration( + Variable( + Symbol('g'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type2[7] == Declaration( + Variable( + Symbol('h'), + type=SignedIntType( + String('int16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type3[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res_type3[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ) + + assert res_type3[2] == Declaration( + Variable( + Symbol('c'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res_type3[3] == Declaration( + Variable( + Symbol('d'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ) + + assert res_type4[0] == Declaration( + Variable( + Symbol('a'), + type=SignedIntType( + String('int64'), + nbits=Integer(64) + ), + value=Integer(1) + ) + ) + + assert res_type4[1] == Declaration( + Variable( + Symbol('b'), + type=SignedIntType( + String('int64'), + nbits=Integer(64) + ), + value=Integer(5) + ) + ) + + assert res_type4[2] == Declaration( + Variable( + Symbol('c'), + type=SignedIntType( + String('int64'), + nbits=Integer(64) + ), + value=Integer(1) + ) + ) + + assert res_type4[3] == Declaration( + Variable( + Symbol('d'), + type=SignedIntType( + String('int64'), + nbits=Integer(64) + ), + value=Integer(5) + ) + ) + + assert res_type5[0] == Declaration( + Variable( + Symbol('a'), + type=UnsignedIntType( + String('uint8'), + nbits=Integer(8) + ), + value=Integer(1) + ) + ) + + assert res_type5[1] == Declaration( + Variable( + Symbol('b'), + type=UnsignedIntType( + String('uint8'), + nbits=Integer(8) + ), + value=Integer(5) + ) + ) + + assert res_type6[0] == Declaration( + Variable( + Symbol('a'), + type=UnsignedIntType( + String('uint16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type6[1] == Declaration( + Variable( + Symbol('b'), + type=UnsignedIntType( + String('uint16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type6[2] == Declaration( + Variable( + Symbol('c'), + type=UnsignedIntType( + String('uint16'), + nbits=Integer(16) + ), + value=Integer(1) + ) + ) + + assert res_type6[3] == Declaration( + Variable( + Symbol('d'), + type=UnsignedIntType( + String('uint16'), + nbits=Integer(16) + ), + value=Integer(5) + ) + ) + + assert res_type7[0] == Declaration( + Variable( + Symbol('a'), + type=UnsignedIntType( + String('uint32'), + nbits=Integer(32) + ), + value=Integer(1) + ) + ) + + assert res_type7[1] == Declaration( + Variable( + Symbol('b'), + type=UnsignedIntType( + String('uint32'), + nbits=Integer(32) + ), + value=Integer(5) + ) + ) + + assert res_type8[0] == Declaration( + Variable( + Symbol('a'), + type=UnsignedIntType( + String('uint64'), + nbits=Integer(64) + ), + value=Integer(1) + ) + ) + + assert res_type8[1] == Declaration( + Variable( + Symbol('b'), + type=UnsignedIntType( + String('uint64'), + nbits=Integer(64) + ), + value=Integer(5) + ) + ) + + assert res_type8[2] == Declaration( + Variable( + Symbol('c'), + type=UnsignedIntType( + String('uint64'), + nbits=Integer(64) + ), + value=Integer(1) + ) + ) + + assert res_type8[3] == Declaration( + Variable( + Symbol('d'), + type=UnsignedIntType( + String('uint64'), + nbits=Integer(64) + ), + value=Integer(5) + ) + ) + + + def test_float(): + c_src1 = 'float a = 1.0;' + c_src2 = ( + 'float a = 1.25;' + '\n' + + 'float b = 2.39;' + '\n' + ) + c_src3 = 'float x = 1, y = 2;' + c_src4 = 'float p = 5, e = 7.89;' + c_src5 = 'float r = true, s = false;' + + # cin.TypeKind.FLOAT + c_src_type1 = 'float x = 1, y = 2.5;' + + # cin.TypeKind.DOUBLE + c_src_type2 = 'double x = 1, y = 2.5;' + + # cin.TypeKind.LONGDOUBLE + c_src_type3 = 'long double x = 1, y = 2.5;' + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + + res_type1 = SymPyExpression(c_src_type1, 'c').return_expr() + res_type2 = SymPyExpression(c_src_type2, 'c').return_expr() + res_type3 = SymPyExpression(c_src_type3, 'c').return_expr() + + assert res1[0] == Declaration( + Variable( + Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res2[0] == Declaration( + Variable( + Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.25', precision=53) + ) + ) + + assert res2[1] == Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.3900000000000001', precision=53) + ) + ) + + assert res3[0] == Declaration( + Variable( + Symbol('x'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res3[1] == Declaration( + Variable( + Symbol('y'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.0', precision=53) + ) + ) + + assert res4[0] == Declaration( + Variable( + Symbol('p'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('5.0', precision=53) + ) + ) + + assert res4[1] == Declaration( + Variable( + Symbol('e'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('7.89', precision=53) + ) + ) + + assert res5[0] == Declaration( + Variable( + Symbol('r'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res5[1] == Declaration( + Variable( + Symbol('s'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('0.0', precision=53) + ) + ) + + assert res_type1[0] == Declaration( + Variable( + Symbol('x'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res_type1[1] == Declaration( + Variable( + Symbol('y'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ) + assert res_type2[0] == Declaration( + Variable( + Symbol('x'), + type=FloatType( + String('float64'), + nbits=Integer(64), + nmant=Integer(52), + nexp=Integer(11) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res_type2[1] == Declaration( + Variable( + Symbol('y'), + type=FloatType( + String('float64'), + nbits=Integer(64), + nmant=Integer(52), + nexp=Integer(11) + ), + value=Float('2.5', precision=53) + ) + ) + + assert res_type3[0] == Declaration( + Variable( + Symbol('x'), + type=FloatType( + String('float80'), + nbits=Integer(80), + nmant=Integer(63), + nexp=Integer(15) + ), + value=Float('1.0', precision=53) + ) + ) + + assert res_type3[1] == Declaration( + Variable( + Symbol('y'), + type=FloatType( + String('float80'), + nbits=Integer(80), + nmant=Integer(63), + nexp=Integer(15) + ), + value=Float('2.5', precision=53) + ) + ) + + + def test_bool(): + c_src1 = ( + 'bool a = true, b = false;' + ) + + c_src2 = ( + 'bool a = 1, b = 0;' + ) + + c_src3 = ( + 'bool a = 10, b = 20;' + ) + + c_src4 = ( + 'bool a = 19.1, b = 9.0, c = 0.0;' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + + assert res1[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true + ) + ) + + assert res1[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=false + ) + ) + + assert res2[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true) + ) + + assert res2[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=false + ) + ) + + assert res3[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true + ) + ) + + assert res3[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=true + ) + ) + + assert res4[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true) + ) + + assert res4[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=true + ) + ) + + assert res4[2] == Declaration( + Variable(Symbol('c'), + type=Type(String('bool')), + value=false + ) + ) + + @XFAIL # this is expected to fail because of a bug in the C parser. + def test_function(): + c_src1 = ( + 'void fun1()' + '\n' + + '{' + '\n' + + 'int a;' + '\n' + + '}' + ) + c_src2 = ( + 'int fun2()' + '\n' + + '{'+ '\n' + + 'int a;' + '\n' + + 'return a;' + '\n' + + '}' + ) + c_src3 = ( + 'float fun3()' + '\n' + + '{' + '\n' + + 'float b;' + '\n' + + 'return b;' + '\n' + + '}' + ) + c_src4 = ( + 'float fun4()' + '\n' + + '{}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('fun1'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + ) + ) + + assert res2[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun2'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Return('a') + ) + ) + + assert res3[0] == FunctionDefinition( + FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + name=String('fun3'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Return('b') + ) + ) + + assert res4[0] == FunctionPrototype( + FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + name=String('fun4'), + parameters=() + ) + + @XFAIL # this is expected to fail because of a bug in the C parser. + def test_parameters(): + c_src1 = ( + 'void fun1( int a)' + '\n' + + '{' + '\n' + + 'int i;' + '\n' + + '}' + ) + c_src2 = ( + 'int fun2(float x, float y)' + '\n' + + '{'+ '\n' + + 'int a;' + '\n' + + 'return a;' + '\n' + + '}' + ) + c_src3 = ( + 'float fun3(int p, float q, int r)' + '\n' + + '{' + '\n' + + 'float b;' + '\n' + + 'return b;' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('fun1'), + parameters=( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ), + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('i'), + type=IntBaseType(String('intc')) + ) + ) + ) + ) + + assert res2[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun2'), + parameters=( + Variable( + Symbol('x'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ), + Variable( + Symbol('y'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Return('a') + ) + ) + + assert res3[0] == FunctionDefinition( + FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + name=String('fun3'), + parameters=( + Variable( + Symbol('p'), + type=IntBaseType(String('intc')) + ), + Variable( + Symbol('q'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ), + Variable( + Symbol('r'), + type=IntBaseType(String('intc')) + ) + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Return('b') + ) + ) + + @XFAIL # this is expected to fail because of a bug in the C parser. + def test_function_call(): + c_src1 = ( + 'int fun1(int x)' + '\n' + + '{' + '\n' + + 'return x;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'int x = fun1(2);' + '\n' + + '}' + ) + + c_src2 = ( + 'int fun2(int a, int b, int c)' + '\n' + + '{' + '\n' + + 'return a;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'int y = fun2(2, 3, 4);' + '\n' + + '}' + ) + + c_src3 = ( + 'int fun3(int a, int b, int c)' + '\n' + + '{' + '\n' + + 'return b;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'int p;' + '\n' + + 'int q;' + '\n' + + 'int r;' + '\n' + + 'int z = fun3(p, q, r);' + '\n' + + '}' + ) + + c_src4 = ( + 'int fun4(float a, float b, int c)' + '\n' + + '{' + '\n' + + 'return c;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'float x;' + '\n' + + 'float y;' + '\n' + + 'int z;' + '\n' + + 'int i = fun4(x, y, z)' + '\n' + + '}' + ) + + c_src5 = ( + 'int fun()' + '\n' + + '{' + '\n' + + 'return 1;' + '\n' + + '}' + '\n' + + 'void caller()' + '\n' + + '{' + '\n' + + 'int a = fun()' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + + + assert res1[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun1'), + parameters=(Variable(Symbol('x'), + type=IntBaseType(String('intc')) + ), + ), + body=CodeBlock( + Return('x') + ) + ) + + assert res1[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('x'), + value=FunctionCall(String('fun1'), + function_args=( + Integer(2), + ) + ) + ) + ) + ) + ) + + assert res2[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun2'), + parameters=(Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ), + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ), + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + body=CodeBlock( + Return('a') + ) + ) + + assert res2[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('y'), + value=FunctionCall( + String('fun2'), + function_args=( + Integer(2), + Integer(3), + Integer(4) + ) + ) + ) + ) + ) + ) + + assert res3[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun3'), + parameters=( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ), + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ), + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + body=CodeBlock( + Return('b') + ) + ) + + assert res3[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('p'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('q'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('r'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('z'), + value=FunctionCall( + String('fun3'), + function_args=( + Symbol('p'), + Symbol('q'), + Symbol('r') + ) + ) + ) + ) + ) + ) + + assert res4[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun4'), + parameters=(Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ), + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ), + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + body=CodeBlock( + Return('c') + ) + ) + + assert res4[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('x'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Declaration( + Variable(Symbol('y'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Declaration( + Variable(Symbol('z'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('i'), + value=FunctionCall(String('fun4'), + function_args=( + Symbol('x'), + Symbol('y'), + Symbol('z') + ) + ) + ) + ) + ) + ) + + assert res5[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('fun'), + parameters=(), + body=CodeBlock( + Return('') + ) + ) + + assert res5[1] == FunctionDefinition( + NoneToken(), + name=String('caller'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + value=FunctionCall(String('fun'), + function_args=() + ) + ) + ) + ) + ) + + + def test_parse(): + c_src1 = ( + 'int a;' + '\n' + + 'int b;' + '\n' + ) + c_src2 = ( + 'void fun1()' + '\n' + + '{' + '\n' + + 'int a;' + '\n' + + '}' + ) + + f1 = open('..a.h', 'w') + f2 = open('..b.h', 'w') + + f1.write(c_src1) + f2. write(c_src2) + + f1.close() + f2.close() + + res1 = SymPyExpression('..a.h', 'c').return_expr() + res2 = SymPyExpression('..b.h', 'c').return_expr() + + os.remove('..a.h') + os.remove('..b.h') + + assert res1[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + assert res1[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')) + ) + ) + assert res2[0] == FunctionDefinition( + NoneToken(), + name=String('fun1'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + ) + ) + + + def test_binary_operators(): + c_src1 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 1;' + '\n' + + '}' + ) + c_src2 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 0;' + '\n' + + 'a = a + 1;' + '\n' + + 'a = 3*a - 10;' + '\n' + + '}' + ) + c_src3 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'a = 1 + a - 3 * 6;' + '\n' + + '}' + ) + c_src4 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'a = 100;' + '\n' + + 'b = a*a + a*a + a + 19*a + 1 + 24;' + '\n' + + '}' + ) + c_src5 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'int c;' + '\n' + + 'int d;' + '\n' + + 'a = 1;' + '\n' + + 'b = 2;' + '\n' + + 'c = b;' + '\n' + + 'd = ((a+b)*(a+c))*((c-d)*(a+c));' + '\n' + + '}' + ) + c_src6 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'int c;' + '\n' + + 'int d;' + '\n' + + 'a = 1;' + '\n' + + 'b = 2;' + '\n' + + 'c = 3;' + '\n' + + 'd = (a*a*a*a + 3*b*b + b + b + c*d);' + '\n' + + '}' + ) + c_src7 = ( + 'void func()'+ + '{' + '\n' + + 'float a;' + '\n' + + 'a = 1.01;' + '\n' + + '}' + ) + + c_src8 = ( + 'void func()'+ + '{' + '\n' + + 'float a;' + '\n' + + 'a = 10.0 + 2.5;' + '\n' + + '}' + ) + + c_src9 = ( + 'void func()'+ + '{' + '\n' + + 'float a;' + '\n' + + 'a = 10.0 / 2.5;' + '\n' + + '}' + ) + + c_src10 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 100 / 4;' + '\n' + + '}' + ) + + c_src11 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 20 - 100 / 4 * 5 + 10;' + '\n' + + '}' + ) + + c_src12 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = (20 - 100) / 4 * (5 + 10);' + '\n' + + '}' + ) + + c_src13 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'float c;' + '\n' + + 'c = b/a;' + '\n' + + '}' + ) + + c_src14 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 2;' + '\n' + + 'int d = 5;' + '\n' + + 'int n = 10;' + '\n' + + 'int s;' + '\n' + + 's = (a/2)*(2*a + (n-1)*d);' + '\n' + + '}' + ) + + c_src15 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 1 % 2;' + '\n' + + '}' + ) + + c_src16 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 2;' + '\n' + + 'int b;' + '\n' + + 'b = a % 3;' + '\n' + + '}' + ) + + c_src17 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int c;' + '\n' + + 'c = a % b;' + '\n' + + '}' + ) + + c_src18 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int mod = 1000000007;' + '\n' + + 'int c;' + '\n' + + 'c = (a + b * (100/a)) % mod;' + '\n' + + '}' + ) + + c_src19 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int mod = 1000000007;' + '\n' + + 'int c;' + '\n' + + 'c = ((a % mod + b % mod) % mod' \ + '* (a % mod - b % mod) % mod) % mod;' + '\n' + + '}' + ) + + c_src20 = ( + 'void func()'+ + '{' + '\n' + + 'bool a' + '\n' + + 'bool b;' + '\n' + + 'a = 1 == 2;' + '\n' + + 'b = 1 != 2;' + '\n' + + '}' + ) + + c_src21 = ( + 'void func()'+ + '{' + '\n' + + 'bool a;' + '\n' + + 'bool b;' + '\n' + + 'bool c;' + '\n' + + 'bool d;' + '\n' + + 'a = 1 == 2;' + '\n' + + 'b = 1 <= 2;' + '\n' + + 'c = 1 > 2;' + '\n' + + 'd = 1 >= 2;' + '\n' + + '}' + ) + + c_src22 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 1;' + '\n' + + 'int b = 2;' + '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + 'bool c7;' + '\n' + + 'bool c8;' + '\n' + + + 'c1 = a == 1;' + '\n' + + 'c2 = b == 2;' + '\n' + + + 'c3 = 1 != a;' + '\n' + + 'c4 = 1 != b;' + '\n' + + + 'c5 = a < 0;' + '\n' + + 'c6 = b <= 10;' + '\n' + + 'c7 = a > 0;' + '\n' + + 'c8 = b >= 11;' + '\n' + + '}' + ) + + c_src23 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 3;' + '\n' + + 'int b = 4;' + '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = a == b;' + '\n' + + 'c2 = a != b;' + '\n' + + 'c3 = a < b;' + '\n' + + 'c4 = a <= b;' + '\n' + + 'c5 = a > b;' + '\n' + + 'c6 = a >= b;' + '\n' + + '}' + ) + + c_src24 = ( + 'void func()'+ + '{' + '\n' + + 'float a = 1.25' + 'float b = 2.5;' + '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + + 'c1 = a == 1.25;' + '\n' + + 'c2 = b == 2.54;' + '\n' + + + 'c3 = 1.2 != a;' + '\n' + + 'c4 = 1.5 != b;' + '\n' + + '}' + ) + + c_src25 = ( + 'void func()'+ + '{' + '\n' + + 'float a = 1.25' + '\n' + + 'float b = 2.5;' + '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = a == b;' + '\n' + + 'c2 = a != b;' + '\n' + + 'c3 = a < b;' + '\n' + + 'c4 = a <= b;' + '\n' + + 'c5 = a > b;' + '\n' + + 'c6 = a >= b;' + '\n' + + '}' + ) + + c_src26 = ( + 'void func()'+ + '{' + '\n' + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = true == true;' + '\n' + + 'c2 = true == false;' + '\n' + + 'c3 = false == false;' + '\n' + + + 'c4 = true != true;' + '\n' + + 'c5 = true != false;' + '\n' + + 'c6 = false != false;' + '\n' + + '}' + ) + + c_src27 = ( + 'void func()'+ + '{' + '\n' + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = true && true;' + '\n' + + 'c2 = true && false;' + '\n' + + 'c3 = false && false;' + '\n' + + + 'c4 = true || true;' + '\n' + + 'c5 = true || false;' + '\n' + + 'c6 = false || false;' + '\n' + + '}' + ) + + c_src28 = ( + 'void func()'+ + '{' + '\n' + + 'bool a;' + '\n' + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + + 'c1 = a && true;' + '\n' + + 'c2 = false && a;' + '\n' + + + 'c3 = true || a;' + '\n' + + 'c4 = a || false;' + '\n' + + '}' + ) + + c_src29 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + + 'c1 = a && 1;' + '\n' + + 'c2 = a && 0;' + '\n' + + + 'c3 = a || 1;' + '\n' + + 'c4 = 0 || a;' + '\n' + + '}' + ) + + c_src30 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'int b;' + '\n' + + 'bool c;'+ '\n' + + 'bool d;'+ '\n' + + + 'bool c1;' + '\n' + + 'bool c2;' + '\n' + + 'bool c3;' + '\n' + + 'bool c4;' + '\n' + + 'bool c5;' + '\n' + + 'bool c6;' + '\n' + + + 'c1 = a && b;' + '\n' + + 'c2 = a && c;' + '\n' + + 'c3 = c && d;' + '\n' + + + 'c4 = a || b;' + '\n' + + 'c5 = a || c;' + '\n' + + 'c6 = c || d;' + '\n' + + '}' + ) + + c_src_raise1 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = -1;' + '\n' + + '}' + ) + + c_src_raise2 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = -+1;' + '\n' + + '}' + ) + + c_src_raise3 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = 2*-2;' + '\n' + + '}' + ) + + c_src_raise4 = ( + 'void func()'+ + '{' + '\n' + + 'int a;' + '\n' + + 'a = (int)2.0;' + '\n' + + '}' + ) + + c_src_raise5 = ( + 'void func()'+ + '{' + '\n' + + 'int a=100;' + '\n' + + 'a = (a==100)?(1):(0);' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + res6 = SymPyExpression(c_src6, 'c').return_expr() + res7 = SymPyExpression(c_src7, 'c').return_expr() + res8 = SymPyExpression(c_src8, 'c').return_expr() + res9 = SymPyExpression(c_src9, 'c').return_expr() + res10 = SymPyExpression(c_src10, 'c').return_expr() + res11 = SymPyExpression(c_src11, 'c').return_expr() + res12 = SymPyExpression(c_src12, 'c').return_expr() + res13 = SymPyExpression(c_src13, 'c').return_expr() + res14 = SymPyExpression(c_src14, 'c').return_expr() + res15 = SymPyExpression(c_src15, 'c').return_expr() + res16 = SymPyExpression(c_src16, 'c').return_expr() + res17 = SymPyExpression(c_src17, 'c').return_expr() + res18 = SymPyExpression(c_src18, 'c').return_expr() + res19 = SymPyExpression(c_src19, 'c').return_expr() + res20 = SymPyExpression(c_src20, 'c').return_expr() + res21 = SymPyExpression(c_src21, 'c').return_expr() + res22 = SymPyExpression(c_src22, 'c').return_expr() + res23 = SymPyExpression(c_src23, 'c').return_expr() + res24 = SymPyExpression(c_src24, 'c').return_expr() + res25 = SymPyExpression(c_src25, 'c').return_expr() + res26 = SymPyExpression(c_src26, 'c').return_expr() + res27 = SymPyExpression(c_src27, 'c').return_expr() + res28 = SymPyExpression(c_src28, 'c').return_expr() + res29 = SymPyExpression(c_src29, 'c').return_expr() + res30 = SymPyExpression(c_src30, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment(Variable(Symbol('a')), Integer(1)) + ) + ) + + assert res2[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(0))), + Assignment( + Variable(Symbol('a')), + Add(Symbol('a'), + Integer(1)) + ), + Assignment(Variable(Symbol('a')), + Add( + Mul( + Integer(3), + Symbol('a')), + Integer(-10) + ) + ) + ) + ) + + assert res3[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Assignment( + Variable(Symbol('a')), + Add( + Symbol('a'), + Integer(-17) + ) + ) + ) + ) + + assert res4[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(100)), + Assignment( + Variable(Symbol('b')), + Add( + Mul( + Integer(2), + Pow( + Symbol('a'), + Integer(2)) + ), + Mul( + Integer(20), + Symbol('a')), + Integer(25) + ) + ) + ) + ) + + assert res5[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(1)), + Assignment( + Variable(Symbol('b')), + Integer(2) + ), + Assignment( + Variable(Symbol('c')), + Symbol('b')), + Assignment( + Variable(Symbol('d')), + Mul( + Add( + Symbol('a'), + Symbol('b')), + Pow( + Add( + Symbol('a'), + Symbol('c') + ), + Integer(2) + ), + Add( + Symbol('c'), + Mul( + Integer(-1), + Symbol('d') + ) + ) + ) + ) + ) + ) + + assert res6[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(1) + ), + Assignment( + Variable(Symbol('b')), + Integer(2) + ), + Assignment( + Variable(Symbol('c')), + Integer(3) + ), + Assignment( + Variable(Symbol('d')), + Add( + Pow( + Symbol('a'), + Integer(4) + ), + Mul( + Integer(3), + Pow( + Symbol('b'), + Integer(2) + ) + ), + Mul( + Integer(2), + Symbol('b') + ), + Mul( + Symbol('c'), + Symbol('d') + ) + ) + ) + ) + ) + + assert res7[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Assignment( + Variable(Symbol('a')), + Float('1.01', precision=53) + ) + ) + ) + + assert res8[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Assignment( + Variable(Symbol('a')), + Float('12.5', precision=53) + ) + ) + ) + + assert res9[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Assignment( + Variable(Symbol('a')), + Float('4.0', precision=53) + ) + ) + ) + + assert res10[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(25) + ) + ) + ) + + assert res11[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(-95) + ) + ) + ) + + assert res12[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(-300) + ) + ) + ) + + assert res13[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Assignment( + Variable(Symbol('c')), + Mul( + Pow( + Symbol('a'), + Integer(-1) + ), + Symbol('b') + ) + ) + ) + ) + + assert res14[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ), + Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ), + Declaration( + Variable(Symbol('n'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Declaration( + Variable(Symbol('s'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('s')), + Mul( + Rational(1, 2), + Symbol('a'), + Add( + Mul( + Integer(2), + Symbol('a') + ), + Mul( + Symbol('d'), + Add( + Symbol('n'), + Integer(-1) + ) + ) + ) + ) + ) + ) + ) + + assert res15[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('a')), + Integer(1) + ) + ) + ) + + assert res16[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('b')), + Mod( + Symbol('a'), + Integer(3) + ) + ) + ) + ) + + assert res17[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('c')), + Mod( + Symbol('a'), + Symbol('b') + ) + ) + ) + ) + + assert res18[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ), + Declaration( + Variable(Symbol('mod'), + type=IntBaseType(String('intc')), + value=Integer(1000000007) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('c')), + Mod( + Add( + Symbol('a'), + Mul( + Integer(100), + Pow( + Symbol('a'), + Integer(-1) + ), + Symbol('b') + ) + ), + Symbol('mod') + ) + ) + ) + ) + + assert res19[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ), + Declaration( + Variable(Symbol('mod'), + type=IntBaseType(String('intc')), + value=Integer(1000000007) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')) + ) + ), + Assignment( + Variable(Symbol('c')), + Mod( + Mul( + Add( + Mod( + Symbol('a'), + Symbol('mod') + ), + Mul( + Integer(-1), + Mod( + Symbol('b'), + Symbol('mod') + ) + ) + ), + Mod( + Add( + Symbol('a'), + Symbol('b') + ), + Symbol('mod') + ) + ), + Symbol('mod') + ) + ) + ) + ) + + assert res20[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('a')), + false + ), + Assignment( + Variable(Symbol('b')), + true + ) + ) + ) + + assert res21[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('d'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('a')), + false + ), + Assignment( + Variable(Symbol('b')), + true + ), + Assignment( + Variable(Symbol('c')), + false + ), + Assignment( + Variable(Symbol('d')), + false + ) + ) + ) + + assert res22[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c7'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c8'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Equality( + Symbol('a'), + Integer(1) + ) + ), + Assignment( + Variable(Symbol('c2')), + Equality( + Symbol('b'), + Integer(2) + ) + ), + Assignment( + Variable(Symbol('c3')), + Unequality( + Integer(1), + Symbol('a') + ) + ), + Assignment( + Variable(Symbol('c4')), + Unequality( + Integer(1), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c5')), + StrictLessThan( + Symbol('a'), + Integer(0) + ) + ), + Assignment( + Variable(Symbol('c6')), + LessThan( + Symbol('b'), + Integer(10) + ) + ), + Assignment( + Variable(Symbol('c7')), + StrictGreaterThan( + Symbol('a'), + Integer(0) + ) + ), + Assignment( + Variable(Symbol('c8')), + GreaterThan( + Symbol('b'), + Integer(11) + ) + ) + ) + ) + + assert res23[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(4) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Equality( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c2')), + Unequality( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c3')), + StrictLessThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c4')), + LessThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c5')), + StrictGreaterThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c6')), + GreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + ) + + assert res24[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Equality( + Symbol('a'), + Float('1.25', precision=53) + ) + ), + Assignment( + Variable(Symbol('c3')), + Unequality( + Float('1.2', precision=53), + Symbol('a') + ) + ) + ) + ) + + + assert res25[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.25', precision=53) + ) + ), + Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool') + ) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Equality( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c2')), + Unequality( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c3')), + StrictLessThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c4')), + LessThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c5')), + StrictGreaterThan( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c6')), + GreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + ) + + assert res26[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), body=CodeBlock( + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + true + ), + Assignment( + Variable(Symbol('c2')), + false + ), + Assignment( + Variable(Symbol('c3')), + true + ), + Assignment( + Variable(Symbol('c4')), + false + ), + Assignment( + Variable(Symbol('c5')), + true + ), + Assignment( + Variable(Symbol('c6')), + false + ) + ) + ) + + assert res27[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + true + ), + Assignment( + Variable(Symbol('c2')), + false + ), + Assignment( + Variable(Symbol('c3')), + false + ), + Assignment( + Variable(Symbol('c4')), + true + ), + Assignment( + Variable(Symbol('c5')), + true + ), + Assignment( + Variable(Symbol('c6')), + false) + ) + ) + + assert res28[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Symbol('a') + ), + Assignment( + Variable(Symbol('c2')), + false + ), + Assignment( + Variable(Symbol('c3')), + true + ), + Assignment( + Variable(Symbol('c4')), + Symbol('a') + ) + ) + ) + + assert res29[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + Symbol('a') + ), + Assignment( + Variable(Symbol('c2')), + false + ), + Assignment( + Variable(Symbol('c3')), + true + ), + Assignment( + Variable(Symbol('c4')), + Symbol('a') + ) + ) + ) + + assert res30[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')) + ) + ), + Declaration( + Variable(Symbol('c'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('d'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')) + ) + ), + Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')) + ) + ), + Assignment( + Variable(Symbol('c1')), + And( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c2')), + And( + Symbol('a'), + Symbol('c') + ) + ), + Assignment( + Variable(Symbol('c3')), + And( + Symbol('c'), + Symbol('d') + ) + ), + Assignment( + Variable(Symbol('c4')), + Or( + Symbol('a'), + Symbol('b') + ) + ), + Assignment( + Variable(Symbol('c5')), + Or( + Symbol('a'), + Symbol('c') + ) + ), + Assignment( + Variable(Symbol('c6')), + Or( + Symbol('c'), + Symbol('d') + ) + ) + ) + ) + + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise1, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise2, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise3, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise4, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise5, 'c')) + + + @XFAIL + def test_var_decl(): + c_src1 = ( + 'int b = 100;' + '\n' + + 'int a = b;' + '\n' + ) + + c_src2 = ( + 'int a = 1;' + '\n' + + 'int b = a + 1;' + '\n' + ) + + c_src3 = ( + 'float a = 10.0 + 2.5;' + '\n' + + 'float b = a * 20.0;' + '\n' + ) + + c_src4 = ( + 'int a = 1 + 100 - 3 * 6;' + '\n' + ) + + c_src5 = ( + 'int a = (((1 + 100) * 12) - 3) * (6 - 10);' + '\n' + ) + + c_src6 = ( + 'int b = 2;' + '\n' + + 'int c = 3;' + '\n' + + 'int a = b + c * 4;' + '\n' + ) + + c_src7 = ( + 'int b = 1;' + '\n' + + 'int c = b + 2;' + '\n' + + 'int a = 10 * b * b * c;' + '\n' + ) + + c_src8 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 1;' + '\n' + + 'int b = 2;' + '\n' + + 'int temp = a;' + '\n' + + 'a = b;' + '\n' + + 'b = temp;' + '\n' + + '}' + ) + + c_src9 = ( + 'int a = 1;' + '\n' + + 'int b = 2;' + '\n' + + 'int c = a;' + '\n' + + 'int d = a + b + c;' + '\n' + + 'int e = a*a*a + 3*a*a*b + 3*a*b*b + b*b*b;' + '\n' + 'int f = (a + b + c) * (a + b - c);' + '\n' + + 'int g = (a + b + c + d)*(a + b + c + d)*(a * (b - c));' + + '\n' + ) + + c_src10 = ( + 'float a = 10.0;' + '\n' + + 'float b = 2.5;' + '\n' + + 'float c = a*a + 2*a*b + b*b;' + '\n' + ) + + c_src11 = ( + 'float a = 10.0 / 2.5;' + '\n' + ) + + c_src12 = ( + 'int a = 100 / 4;' + '\n' + ) + + c_src13 = ( + 'int a = 20 - 100 / 4 * 5 + 10;' + '\n' + ) + + c_src14 = ( + 'int a = (20 - 100) / 4 * (5 + 10);' + '\n' + ) + + c_src15 = ( + 'int a = 4;' + '\n' + + 'int b = 2;' + '\n' + + 'float c = b/a;' + '\n' + ) + + c_src16 = ( + 'int a = 2;' + '\n' + + 'int d = 5;' + '\n' + + 'int n = 10;' + '\n' + + 'int s = (a/2)*(2*a + (n-1)*d);' + '\n' + ) + + c_src17 = ( + 'int a = 1 % 2;' + '\n' + ) + + c_src18 = ( + 'int a = 2;' + '\n' + + 'int b = a % 3;' + '\n' + ) + + c_src19 = ( + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int c = a % b;' + '\n' + ) + + c_src20 = ( + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int mod = 1000000007;' + '\n' + + 'int c = (a + b * (100/a)) % mod;' + '\n' + ) + + c_src21 = ( + 'int a = 100;' + '\n' + + 'int b = 3;' + '\n' + + 'int mod = 1000000007;' + '\n' + + 'int c = ((a % mod + b % mod) % mod *' \ + '(a % mod - b % mod) % mod) % mod;' + '\n' + ) + + c_src22 = ( + 'bool a = 1 == 2, b = 1 != 2;' + ) + + c_src23 = ( + 'bool a = 1 < 2, b = 1 <= 2, c = 1 > 2, d = 1 >= 2;' + ) + + c_src24 = ( + 'int a = 1, b = 2;' + '\n' + + + 'bool c1 = a == 1;' + '\n' + + 'bool c2 = b == 2;' + '\n' + + + 'bool c3 = 1 != a;' + '\n' + + 'bool c4 = 1 != b;' + '\n' + + + 'bool c5 = a < 0;' + '\n' + + 'bool c6 = b <= 10;' + '\n' + + 'bool c7 = a > 0;' + '\n' + + 'bool c8 = b >= 11;' + + ) + + c_src25 = ( + 'int a = 3, b = 4;' + '\n' + + + 'bool c1 = a == b;' + '\n' + + 'bool c2 = a != b;' + '\n' + + 'bool c3 = a < b;' + '\n' + + 'bool c4 = a <= b;' + '\n' + + 'bool c5 = a > b;' + '\n' + + 'bool c6 = a >= b;' + ) + + c_src26 = ( + 'float a = 1.25, b = 2.5;' + '\n' + + + 'bool c1 = a == 1.25;' + '\n' + + 'bool c2 = b == 2.54;' + '\n' + + + 'bool c3 = 1.2 != a;' + '\n' + + 'bool c4 = 1.5 != b;' + ) + + c_src27 = ( + 'float a = 1.25, b = 2.5;' + '\n' + + + 'bool c1 = a == b;' + '\n' + + 'bool c2 = a != b;' + '\n' + + 'bool c3 = a < b;' + '\n' + + 'bool c4 = a <= b;' + '\n' + + 'bool c5 = a > b;' + '\n' + + 'bool c6 = a >= b;' + ) + + c_src28 = ( + 'bool c1 = true == true;' + '\n' + + 'bool c2 = true == false;' + '\n' + + 'bool c3 = false == false;' + '\n' + + + 'bool c4 = true != true;' + '\n' + + 'bool c5 = true != false;' + '\n' + + 'bool c6 = false != false;' + ) + + c_src29 = ( + 'bool c1 = true && true;' + '\n' + + 'bool c2 = true && false;' + '\n' + + 'bool c3 = false && false;' + '\n' + + + 'bool c4 = true || true;' + '\n' + + 'bool c5 = true || false;' + '\n' + + 'bool c6 = false || false;' + ) + + c_src30 = ( + 'bool a = false;' + '\n' + + + 'bool c1 = a && true;' + '\n' + + 'bool c2 = false && a;' + '\n' + + + 'bool c3 = true || a;' + '\n' + + 'bool c4 = a || false;' + ) + + c_src31 = ( + 'int a = 1;' + '\n' + + + 'bool c1 = a && 1;' + '\n' + + 'bool c2 = a && 0;' + '\n' + + + 'bool c3 = a || 1;' + '\n' + + 'bool c4 = 0 || a;' + ) + + c_src32 = ( + 'int a = 1, b = 0;' + '\n' + + 'bool c = false, d = true;'+ '\n' + + + 'bool c1 = a && b;' + '\n' + + 'bool c2 = a && c;' + '\n' + + 'bool c3 = c && d;' + '\n' + + + 'bool c4 = a || b;' + '\n' + + 'bool c5 = a || c;' + '\n' + + 'bool c6 = c || d;' + ) + + c_src_raise1 = ( + "char a = 'b';" + ) + + c_src_raise2 = ( + 'int a[] = {10, 20};' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + res6 = SymPyExpression(c_src6, 'c').return_expr() + res7 = SymPyExpression(c_src7, 'c').return_expr() + res8 = SymPyExpression(c_src8, 'c').return_expr() + res9 = SymPyExpression(c_src9, 'c').return_expr() + res10 = SymPyExpression(c_src10, 'c').return_expr() + res11 = SymPyExpression(c_src11, 'c').return_expr() + res12 = SymPyExpression(c_src12, 'c').return_expr() + res13 = SymPyExpression(c_src13, 'c').return_expr() + res14 = SymPyExpression(c_src14, 'c').return_expr() + res15 = SymPyExpression(c_src15, 'c').return_expr() + res16 = SymPyExpression(c_src16, 'c').return_expr() + res17 = SymPyExpression(c_src17, 'c').return_expr() + res18 = SymPyExpression(c_src18, 'c').return_expr() + res19 = SymPyExpression(c_src19, 'c').return_expr() + res20 = SymPyExpression(c_src20, 'c').return_expr() + res21 = SymPyExpression(c_src21, 'c').return_expr() + res22 = SymPyExpression(c_src22, 'c').return_expr() + res23 = SymPyExpression(c_src23, 'c').return_expr() + res24 = SymPyExpression(c_src24, 'c').return_expr() + res25 = SymPyExpression(c_src25, 'c').return_expr() + res26 = SymPyExpression(c_src26, 'c').return_expr() + res27 = SymPyExpression(c_src27, 'c').return_expr() + res28 = SymPyExpression(c_src28, 'c').return_expr() + res29 = SymPyExpression(c_src29, 'c').return_expr() + res30 = SymPyExpression(c_src30, 'c').return_expr() + res31 = SymPyExpression(c_src31, 'c').return_expr() + res32 = SymPyExpression(c_src32, 'c').return_expr() + + assert res1[0] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ) + + assert res1[1] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Symbol('b') + ) + ) + + assert res2[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res2[1] == Declaration(Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('a'), + Integer(1) + ) + ) + ) + + assert res3[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('12.5', precision=53) + ) + ) + + assert res3[1] == Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Mul( + Float('20.0', precision=53), + Symbol('a') + ) + ) + ) + + assert res4[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(83) + ) + ) + + assert res5[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(-4836) + ) + ) + + assert res6[0] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res6[1] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res6[2] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('b'), + Mul( + Integer(4), + Symbol('c') + ) + ) + ) + ) + + assert res7[0] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res7[1] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('b'), + Integer(2) + ) + ) + ) + + assert res7[2] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Mul( + Integer(10), + Pow( + Symbol('b'), + Integer(2) + ), + Symbol('c') + ) + ) + ) + + assert res8[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ), + Declaration( + Variable(Symbol('temp'), + type=IntBaseType(String('intc')), + value=Symbol('a') + ) + ), + Assignment( + Variable(Symbol('a')), + Symbol('b') + ), + Assignment( + Variable(Symbol('b')), + Symbol('temp') + ) + ) + ) + + assert res9[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res9[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res9[2] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Symbol('a') + ) + ) + + assert res9[3] == Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('a'), + Symbol('b'), + Symbol('c') + ) + ) + ) + + assert res9[4] == Declaration( + Variable(Symbol('e'), + type=IntBaseType(String('intc')), + value=Add( + Pow( + Symbol('a'), + Integer(3) + ), + Mul( + Integer(3), + Pow( + Symbol('a'), + Integer(2) + ), + Symbol('b') + ), + Mul( + Integer(3), + Symbol('a'), + Pow( + Symbol('b'), + Integer(2) + ) + ), + Pow( + Symbol('b'), + Integer(3) + ) + ) + ) + ) + + assert res9[5] == Declaration( + Variable(Symbol('f'), + type=IntBaseType(String('intc')), + value=Mul( + Add( + Symbol('a'), + Symbol('b'), + Mul( + Integer(-1), + Symbol('c') + ) + ), + Add( + Symbol('a'), + Symbol('b'), + Symbol('c') + ) + ) + ) + ) + + assert res9[6] == Declaration( + Variable(Symbol('g'), + type=IntBaseType(String('intc')), + value=Mul( + Symbol('a'), + Add( + Symbol('b'), + Mul( + Integer(-1), + Symbol('c') + ) + ), + Pow( + Add( + Symbol('a'), + Symbol('b'), + Symbol('c'), + Symbol('d') + ), + Integer(2) + ) + ) + ) + ) + + assert res10[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('10.0', precision=53) + ) + ) + + assert res10[1] == Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ) + + assert res10[2] == Declaration( + Variable(Symbol('c'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Add( + Pow( + Symbol('a'), + Integer(2) + ), + Mul( + Integer(2), + Symbol('a'), + Symbol('b') + ), + Pow( + Symbol('b'), + Integer(2) + ) + ) + ) + ) + + assert res11[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('4.0', precision=53) + ) + ) + + assert res12[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(25) + ) + ) + + assert res13[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(-95) + ) + ) + + assert res14[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(-300) + ) + ) + + assert res15[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(4) + ) + ) + + assert res15[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res15[2] == Declaration( + Variable(Symbol('c'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Mul( + Pow( + Symbol('a'), + Integer(-1) + ), + Symbol('b') + ) + ) + ) + + assert res16[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res16[1] == Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=Integer(5) + ) + ) + + assert res16[2] == Declaration( + Variable(Symbol('n'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ) + + assert res16[3] == Declaration( + Variable(Symbol('s'), + type=IntBaseType(String('intc')), + value=Mul( + Rational(1, 2), + Symbol('a'), + Add( + Mul( + Integer(2), + Symbol('a') + ), + Mul( + Symbol('d'), + Add( + Symbol('n'), + Integer(-1) + ) + ) + ) + ) + ) + ) + + assert res17[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res18[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res18[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Mod( + Symbol('a'), + Integer(3) + ) + ) + ) + + assert res19[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ) + assert res19[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res19[2] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Mod( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res20[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ) + + assert res20[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res20[2] == Declaration( + Variable(Symbol('mod'), + type=IntBaseType(String('intc')), + value=Integer(1000000007) + ) + ) + + assert res20[3] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Mod( + Add( + Symbol('a'), + Mul( + Integer(100), + Pow( + Symbol('a'), + Integer(-1) + ), + Symbol('b') + ) + ), + Symbol('mod') + ) + ) + ) + + assert res21[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ) + + assert res21[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res21[2] == Declaration( + Variable(Symbol('mod'), + type=IntBaseType(String('intc')), + value=Integer(1000000007) + ) + ) + + assert res21[3] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Mod( + Mul( + Add( + Symbol('a'), + Mul( + Integer(-1), + Symbol('b') + ) + ), + Add( + Symbol('a'), + Symbol('b') + ) + ), + Symbol('mod') + ) + ) + ) + + assert res22[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=false + ) + ) + + assert res22[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=true + ) + ) + + assert res23[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=true + ) + ) + + assert res23[1] == Declaration( + Variable(Symbol('b'), + type=Type(String('bool')), + value=true + ) + ) + + assert res23[2] == Declaration( + Variable(Symbol('c'), + type=Type(String('bool')), + value=false + ) + ) + + assert res23[3] == Declaration( + Variable(Symbol('d'), + type=Type(String('bool')), + value=false + ) + ) + + assert res24[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res24[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res24[2] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Equality( + Symbol('a'), + Integer(1) + ) + ) + ) + + assert res24[3] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=Equality( + Symbol('b'), + Integer(2) + ) + ) + ) + + assert res24[4] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=Unequality( + Integer(1), + Symbol('a') + ) + ) + ) + + assert res24[5] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Unequality( + Integer(1), + Symbol('b') + ) + ) + ) + + assert res24[6] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=StrictLessThan(Symbol('a'), + Integer(0) + ) + ) + ) + + assert res24[7] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=LessThan( + Symbol('b'), + Integer(10) + ) + ) + ) + + assert res24[8] == Declaration( + Variable(Symbol('c7'), + type=Type(String('bool')), + value=StrictGreaterThan( + Symbol('a'), + Integer(0) + ) + ) + ) + + assert res24[9] == Declaration( + Variable(Symbol('c8'), + type=Type(String('bool')), + value=GreaterThan( + Symbol('b'), + Integer(11) + ) + ) + ) + + assert res25[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res25[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(4) + ) + ) + + assert res25[2] == Declaration(Variable(Symbol('c1'), + type=Type(String('bool')), + value=Equality( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[3] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=Unequality( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[4] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=StrictLessThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[5] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=LessThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[6] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=StrictGreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res25[7] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=GreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res26[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.25', precision=53) + ) + ) + + assert res26[1] == Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ) + + assert res26[2] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Equality( + Symbol('a'), + Float('1.25', precision=53) + ) + ) + ) + + assert res26[3] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=Equality( + Symbol('b'), + Float('2.54', precision=53) + ) + ) + ) + + assert res26[4] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=Unequality( + Float('1.2', precision=53), + Symbol('a') + ) + ) + ) + + assert res26[5] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Unequality( + Float('1.5', precision=53), + Symbol('b') + ) + ) + ) + + assert res27[0] == Declaration( + Variable(Symbol('a'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('1.25', precision=53) + ) + ) + + assert res27[1] == Declaration( + Variable(Symbol('b'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.5', precision=53) + ) + ) + + assert res27[2] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Equality( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[3] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=Unequality( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[4] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=StrictLessThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[5] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=LessThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[6] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=StrictGreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res27[7] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=GreaterThan( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res28[0] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=true + ) + ) + + assert res28[1] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=false + ) + ) + + assert res28[2] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=true + ) + ) + + assert res28[3] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=false + ) + ) + + assert res28[4] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=true + ) + ) + + assert res28[5] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=false + ) + ) + + assert res29[0] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=true + ) + ) + + assert res29[1] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=false + ) + ) + + assert res29[2] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=false + ) + ) + + assert res29[3] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=true + ) + ) + + assert res29[4] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=true + ) + ) + + assert res29[5] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=false + ) + ) + + assert res30[0] == Declaration( + Variable(Symbol('a'), + type=Type(String('bool')), + value=false + ) + ) + + assert res30[1] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Symbol('a') + ) + ) + + assert res30[2] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=false + ) + ) + + assert res30[3] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=true + ) + ) + + assert res30[4] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Symbol('a') + ) + ) + + assert res31[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res31[1] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=Symbol('a') + ) + ) + + assert res31[2] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=false + ) + ) + + assert res31[3] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=true + ) + ) + + assert res31[4] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Symbol('a') + ) + ) + + assert res32[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res32[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ) + + assert res32[2] == Declaration( + Variable(Symbol('c'), + type=Type(String('bool')), + value=false + ) + ) + + assert res32[3] == Declaration( + Variable(Symbol('d'), + type=Type(String('bool')), + value=true + ) + ) + + assert res32[4] == Declaration( + Variable(Symbol('c1'), + type=Type(String('bool')), + value=And( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res32[5] == Declaration( + Variable(Symbol('c2'), + type=Type(String('bool')), + value=And( + Symbol('a'), + Symbol('c') + ) + ) + ) + + assert res32[6] == Declaration( + Variable(Symbol('c3'), + type=Type(String('bool')), + value=And( + Symbol('c'), + Symbol('d') + ) + ) + ) + + assert res32[7] == Declaration( + Variable(Symbol('c4'), + type=Type(String('bool')), + value=Or( + Symbol('a'), + Symbol('b') + ) + ) + ) + + assert res32[8] == Declaration( + Variable(Symbol('c5'), + type=Type(String('bool')), + value=Or( + Symbol('a'), + Symbol('c') + ) + ) + ) + + assert res32[9] == Declaration( + Variable(Symbol('c6'), + type=Type(String('bool')), + value=Or( + Symbol('c'), + Symbol('d') + ) + ) + ) + + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise1, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise2, 'c')) + + + def test_paren_expr(): + c_src1 = ( + 'int a = (1);' + 'int b = (1 + 2 * 3);' + ) + + c_src2 = ( + 'int a = 1, b = 2, c = 3;' + 'int d = (a);' + 'int e = (a + 1);' + 'int f = (a + b * c - d / e);' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + + assert res1[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res1[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(7) + ) + ) + + assert res2[0] == Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(1) + ) + ) + + assert res2[1] == Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(2) + ) + ) + + assert res2[2] == Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Integer(3) + ) + ) + + assert res2[3] == Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=Symbol('a') + ) + ) + + assert res2[4] == Declaration( + Variable(Symbol('e'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('a'), + Integer(1) + ) + ) + ) + + assert res2[5] == Declaration( + Variable(Symbol('f'), + type=IntBaseType(String('intc')), + value=Add( + Symbol('a'), + Mul( + Symbol('b'), + Symbol('c') + ), + Mul( + Integer(-1), + Symbol('d'), + Pow( + Symbol('e'), + Integer(-1) + ) + ) + ) + ) + ) + + + def test_unary_operators(): + c_src1 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'int b = 20;' + '\n' + + '++a;' + '\n' + + '--b;' + '\n' + + 'a++;' + '\n' + + 'b--;' + '\n' + + '}' + ) + + c_src2 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'int b = -100;' + '\n' + + 'int c = +19;' + '\n' + + 'int d = ++a;' + '\n' + + 'int e = --b;' + '\n' + + 'int f = a++;' + '\n' + + 'int g = b--;' + '\n' + + 'bool h = !false;' + '\n' + + 'bool i = !d;' + '\n' + + 'bool j = !0;' + '\n' + + 'bool k = !10.0;' + '\n' + + '}' + ) + + c_src_raise1 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'int b = ~a;' + '\n' + + '}' + ) + + c_src_raise2 = ( + 'void func()'+ + '{' + '\n' + + 'int a = 10;' + '\n' + + 'int b = *&a;' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(20) + ) + ), + PreIncrement(Symbol('a')), + PreDecrement(Symbol('b')), + PostIncrement(Symbol('a')), + PostDecrement(Symbol('b')) + ) + ) + + assert res2[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Declaration( + Variable(Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(-100) + ) + ), + Declaration( + Variable(Symbol('c'), + type=IntBaseType(String('intc')), + value=Integer(19) + ) + ), + Declaration( + Variable(Symbol('d'), + type=IntBaseType(String('intc')), + value=PreIncrement(Symbol('a')) + ) + ), + Declaration( + Variable(Symbol('e'), + type=IntBaseType(String('intc')), + value=PreDecrement(Symbol('b')) + ) + ), + Declaration( + Variable(Symbol('f'), + type=IntBaseType(String('intc')), + value=PostIncrement(Symbol('a')) + ) + ), + Declaration( + Variable(Symbol('g'), + type=IntBaseType(String('intc')), + value=PostDecrement(Symbol('b')) + ) + ), + Declaration( + Variable(Symbol('h'), + type=Type(String('bool')), + value=true + ) + ), + Declaration( + Variable(Symbol('i'), + type=Type(String('bool')), + value=Not(Symbol('d')) + ) + ), + Declaration( + Variable(Symbol('j'), + type=Type(String('bool')), + value=true + ) + ), + Declaration( + Variable(Symbol('k'), + type=Type(String('bool')), + value=false + ) + ) + ) + ) + + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise1, 'c')) + raises(NotImplementedError, lambda: SymPyExpression(c_src_raise2, 'c')) + + + def test_compound_assignment_operator(): + c_src = ( + 'void func()'+ + '{' + '\n' + + 'int a = 100;' + '\n' + + 'a += 10;' + '\n' + + 'a -= 10;' + '\n' + + 'a *= 10;' + '\n' + + 'a /= 10;' + '\n' + + 'a %= 10;' + '\n' + + '}' + ) + + res = SymPyExpression(c_src, 'c').return_expr() + + assert res[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')), + value=Integer(100) + ) + ), + AddAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ), + SubAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ), + MulAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ), + DivAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ), + ModAugmentedAssignment( + Variable(Symbol('a')), + Integer(10) + ) + ) + ) + + @XFAIL # this is expected to fail because of a bug in the C parser. + def test_while_stmt(): + c_src1 = ( + 'void func()'+ + '{' + '\n' + + 'int i = 0;' + '\n' + + 'while(i < 10)' + '\n' + + '{' + '\n' + + 'i++;' + '\n' + + '}' + '}' + ) + + c_src2 = ( + 'void func()'+ + '{' + '\n' + + 'int i = 0;' + '\n' + + 'while(i < 10)' + '\n' + + 'i++;' + '\n' + + '}' + ) + + c_src3 = ( + 'void func()'+ + '{' + '\n' + + 'int i = 10;' + '\n' + + 'int cnt = 0;' + '\n' + + 'while(i > 0)' + '\n' + + '{' + '\n' + + 'i--;' + '\n' + + 'cnt++;' + '\n' + + '}' + '\n' + + '}' + ) + + c_src4 = ( + 'int digit_sum(int n)'+ + '{' + '\n' + + 'int sum = 0;' + '\n' + + 'while(n > 0)' + '\n' + + '{' + '\n' + + 'sum += (n % 10);' + '\n' + + 'n /= 10;' + '\n' + + '}' + '\n' + + 'return sum;' + '\n' + + '}' + ) + + c_src5 = ( + 'void func()'+ + '{' + '\n' + + 'while(1);' + '\n' + + '}' + ) + + res1 = SymPyExpression(c_src1, 'c').return_expr() + res2 = SymPyExpression(c_src2, 'c').return_expr() + res3 = SymPyExpression(c_src3, 'c').return_expr() + res4 = SymPyExpression(c_src4, 'c').return_expr() + res5 = SymPyExpression(c_src5, 'c').return_expr() + + assert res1[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable(Symbol('i'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ), + While( + StrictLessThan( + Symbol('i'), + Integer(10) + ), + body=CodeBlock( + PostIncrement( + Symbol('i') + ) + ) + ) + ) + ) + + assert res2[0] == res1[0] + + assert res3[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + Declaration( + Variable( + Symbol('i'), + type=IntBaseType(String('intc')), + value=Integer(10) + ) + ), + Declaration( + Variable( + Symbol('cnt'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ), + While( + StrictGreaterThan( + Symbol('i'), + Integer(0) + ), + body=CodeBlock( + PostDecrement( + Symbol('i') + ), + PostIncrement( + Symbol('cnt') + ) + ) + ) + ) + ) + + assert res4[0] == FunctionDefinition( + IntBaseType(String('intc')), + name=String('digit_sum'), + parameters=( + Variable( + Symbol('n'), + type=IntBaseType(String('intc')) + ), + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('sum'), + type=IntBaseType(String('intc')), + value=Integer(0) + ) + ), + While( + StrictGreaterThan( + Symbol('n'), + Integer(0) + ), + body=CodeBlock( + AddAugmentedAssignment( + Variable( + Symbol('sum') + ), + Mod( + Symbol('n'), + Integer(10) + ) + ), + DivAugmentedAssignment( + Variable( + Symbol('n') + ), + Integer(10) + ) + ) + ), + Return('sum') + ) + ) + + assert res5[0] == FunctionDefinition( + NoneToken(), + name=String('func'), + parameters=(), + body=CodeBlock( + While( + Integer(1), + body=CodeBlock( + NoneToken() + ) + ) + ) + ) + + +else: + def test_raise(): + from sympy.parsing.c.c_parser import CCodeConverter + raises(ImportError, lambda: CCodeConverter()) + raises(ImportError, lambda: SymPyExpression(' ', mode = 'c')) diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/test_custom_latex.py b/phivenv/Lib/site-packages/sympy/parsing/tests/test_custom_latex.py new file mode 100644 index 0000000000000000000000000000000000000000..f5eff1c9ec79528c7f9e3a06cf9e2f84c86091ee --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/tests/test_custom_latex.py @@ -0,0 +1,69 @@ +import os +import tempfile +from pathlib import Path + +import sympy +from sympy.testing.pytest import raises +from sympy.parsing.latex.lark import LarkLaTeXParser, TransformToSymPyExpr, parse_latex_lark +from sympy.external import import_module + +lark = import_module("lark") + +# disable tests if lark is not present +disabled = lark is None + +grammar_file = os.path.join(os.path.dirname(__file__), "../latex/lark/grammar/latex.lark") + +modification1 = """ +%override DIV_SYMBOL: DIV +%override MUL_SYMBOL: MUL | CMD_TIMES +""" + +modification2 = r""" +%override number: /\d+(,\d*)?/ +""" + +def init_custom_parser(modification, transformer=None): + latex_grammar = Path(grammar_file).read_text(encoding="utf-8") + latex_grammar += modification + + with tempfile.NamedTemporaryFile() as f: + f.write(bytes(latex_grammar, encoding="utf8")) + f.flush() + + parser = LarkLaTeXParser(grammar_file=f.name, transformer=transformer) + + return parser + +def test_custom1(): + # Removes the parser's ability to understand \cdot and \div. + + parser = init_custom_parser(modification1) + + with raises(lark.exceptions.UnexpectedCharacters): + parser.doparse(r"a \cdot b") + parser.doparse(r"x \div y") + +class CustomTransformer(TransformToSymPyExpr): + def number(self, tokens): + if "," in tokens[0]: + # The Float constructor expects a dot as the decimal separator + return sympy.core.numbers.Float(tokens[0].replace(",", ".")) + else: + return sympy.core.numbers.Integer(tokens[0]) + +def test_custom2(): + # Makes the parser parse commas as the decimal separator instead of dots + + parser = init_custom_parser(modification2, CustomTransformer) + + with raises(lark.exceptions.UnexpectedCharacters): + # Asserting that the default parser cannot parse numbers which have commas as + # the decimal separator + parse_latex_lark("100,1") + parse_latex_lark("0,009") + + parser.doparse("100,1") + parser.doparse("0,009") + parser.doparse("2,71828") + parser.doparse("3,14159") diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/test_fortran_parser.py b/phivenv/Lib/site-packages/sympy/parsing/tests/test_fortran_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..9bcd54533ef231dd0a116910453dff0e993bc727 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/tests/test_fortran_parser.py @@ -0,0 +1,406 @@ +from sympy.testing.pytest import raises +from sympy.parsing.sym_expr import SymPyExpression +from sympy.external import import_module + +lfortran = import_module('lfortran') + +if lfortran: + from sympy.codegen.ast import (Variable, IntBaseType, FloatBaseType, String, + Return, FunctionDefinition, Assignment, + Declaration, CodeBlock) + from sympy.core import Integer, Float, Add + from sympy.core.symbol import Symbol + + + expr1 = SymPyExpression() + expr2 = SymPyExpression() + src = """\ + integer :: a, b, c, d + real :: p, q, r, s + """ + + + def test_sym_expr(): + src1 = ( + src + + """\ + d = a + b -c + """ + ) + expr3 = SymPyExpression(src,'f') + expr4 = SymPyExpression(src1,'f') + ls1 = expr3.return_expr() + ls2 = expr4.return_expr() + for i in range(0, 7): + assert isinstance(ls1[i], Declaration) + assert isinstance(ls2[i], Declaration) + assert isinstance(ls2[8], Assignment) + assert ls1[0] == Declaration( + Variable( + Symbol('a'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls1[1] == Declaration( + Variable( + Symbol('b'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls1[2] == Declaration( + Variable( + Symbol('c'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls1[3] == Declaration( + Variable( + Symbol('d'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls1[4] == Declaration( + Variable( + Symbol('p'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls1[5] == Declaration( + Variable( + Symbol('q'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls1[6] == Declaration( + Variable( + Symbol('r'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls1[7] == Declaration( + Variable( + Symbol('s'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls2[8] == Assignment( + Variable(Symbol('d')), + Symbol('a') + Symbol('b') - Symbol('c') + ) + + def test_assignment(): + src1 = ( + src + + """\ + a = b + c = d + p = q + r = s + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(0, 12): + if iter < 8: + assert isinstance(ls1[iter], Declaration) + else: + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('a')), + Variable(Symbol('b')) + ) + assert ls1[9] == Assignment( + Variable(Symbol('c')), + Variable(Symbol('d')) + ) + assert ls1[10] == Assignment( + Variable(Symbol('p')), + Variable(Symbol('q')) + ) + assert ls1[11] == Assignment( + Variable(Symbol('r')), + Variable(Symbol('s')) + ) + + + def test_binop_add(): + src1 = ( + src + + """\ + c = a + b + d = a + c + s = p + q + r + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 11): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('c')), + Symbol('a') + Symbol('b') + ) + assert ls1[9] == Assignment( + Variable(Symbol('d')), + Symbol('a') + Symbol('c') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') + Symbol('q') + Symbol('r') + ) + + + def test_binop_sub(): + src1 = ( + src + + """\ + c = a - b + d = a - c + s = p - q - r + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 11): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('c')), + Symbol('a') - Symbol('b') + ) + assert ls1[9] == Assignment( + Variable(Symbol('d')), + Symbol('a') - Symbol('c') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') - Symbol('q') - Symbol('r') + ) + + + def test_binop_mul(): + src1 = ( + src + + """\ + c = a * b + d = a * c + s = p * q * r + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 11): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('c')), + Symbol('a') * Symbol('b') + ) + assert ls1[9] == Assignment( + Variable(Symbol('d')), + Symbol('a') * Symbol('c') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') * Symbol('q') * Symbol('r') + ) + + + def test_binop_div(): + src1 = ( + src + + """\ + c = a / b + d = a / c + s = p / q + r = q / p + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 12): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('c')), + Symbol('a') / Symbol('b') + ) + assert ls1[9] == Assignment( + Variable(Symbol('d')), + Symbol('a') / Symbol('c') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') / Symbol('q') + ) + assert ls1[11] == Assignment( + Variable(Symbol('r')), + Symbol('q') / Symbol('p') + ) + + def test_mul_binop(): + src1 = ( + src + + """\ + d = a + b - c + c = a * b + d + s = p * q / r + r = p * s + q / p + """ + ) + expr1.convert_to_expr(src1, 'f') + ls1 = expr1.return_expr() + for iter in range(8, 12): + assert isinstance(ls1[iter], Assignment) + assert ls1[8] == Assignment( + Variable(Symbol('d')), + Symbol('a') + Symbol('b') - Symbol('c') + ) + assert ls1[9] == Assignment( + Variable(Symbol('c')), + Symbol('a') * Symbol('b') + Symbol('d') + ) + assert ls1[10] == Assignment( + Variable(Symbol('s')), + Symbol('p') * Symbol('q') / Symbol('r') + ) + assert ls1[11] == Assignment( + Variable(Symbol('r')), + Symbol('p') * Symbol('s') + Symbol('q') / Symbol('p') + ) + + + def test_function(): + src1 = """\ + integer function f(a,b) + integer :: x, y + f = x + y + end function + """ + expr1.convert_to_expr(src1, 'f') + for iter in expr1.return_expr(): + assert isinstance(iter, FunctionDefinition) + assert iter == FunctionDefinition( + IntBaseType(String('integer')), + name=String('f'), + parameters=( + Variable(Symbol('a')), + Variable(Symbol('b')) + ), + body=CodeBlock( + Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Declaration( + Variable( + Symbol('f'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Declaration( + Variable( + Symbol('x'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Declaration( + Variable( + Symbol('y'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ), + Assignment( + Variable(Symbol('f')), + Add(Symbol('x'), Symbol('y')) + ), + Return(Variable(Symbol('f'))) + ) + ) + + + def test_var(): + expr1.convert_to_expr(src, 'f') + ls = expr1.return_expr() + for iter in expr1.return_expr(): + assert isinstance(iter, Declaration) + assert ls[0] == Declaration( + Variable( + Symbol('a'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls[1] == Declaration( + Variable( + Symbol('b'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls[2] == Declaration( + Variable( + Symbol('c'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls[3] == Declaration( + Variable( + Symbol('d'), + type = IntBaseType(String('integer')), + value = Integer(0) + ) + ) + assert ls[4] == Declaration( + Variable( + Symbol('p'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls[5] == Declaration( + Variable( + Symbol('q'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls[6] == Declaration( + Variable( + Symbol('r'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + assert ls[7] == Declaration( + Variable( + Symbol('s'), + type = FloatBaseType(String('real')), + value = Float(0.0) + ) + ) + +else: + def test_raise(): + from sympy.parsing.fortran.fortran_parser import ASR2PyVisitor + raises(ImportError, lambda: ASR2PyVisitor()) + raises(ImportError, lambda: SymPyExpression(' ', mode = 'f')) diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/test_implicit_multiplication_application.py b/phivenv/Lib/site-packages/sympy/parsing/tests/test_implicit_multiplication_application.py new file mode 100644 index 0000000000000000000000000000000000000000..56df361e77b0c0f94bdb53b03e0dc30a8a10899f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/tests/test_implicit_multiplication_application.py @@ -0,0 +1,195 @@ +import sympy +from sympy.parsing.sympy_parser import ( + parse_expr, + standard_transformations, + convert_xor, + implicit_multiplication_application, + implicit_multiplication, + implicit_application, + function_exponentiation, + split_symbols, + split_symbols_custom, + _token_splittable +) +from sympy.testing.pytest import raises + + +def test_implicit_multiplication(): + cases = { + '5x': '5*x', + 'abc': 'a*b*c', + '3sin(x)': '3*sin(x)', + '(x+1)(x+2)': '(x+1)*(x+2)', + '(5 x**2)sin(x)': '(5*x**2)*sin(x)', + '2 sin(x) cos(x)': '2*sin(x)*cos(x)', + 'pi x': 'pi*x', + 'x pi': 'x*pi', + 'E x': 'E*x', + 'EulerGamma y': 'EulerGamma*y', + 'E pi': 'E*pi', + 'pi (x + 2)': 'pi*(x+2)', + '(x + 2) pi': '(x+2)*pi', + 'pi sin(x)': 'pi*sin(x)', + } + transformations = standard_transformations + (convert_xor,) + transformations2 = transformations + (split_symbols, + implicit_multiplication) + for case in cases: + implicit = parse_expr(case, transformations=transformations2) + normal = parse_expr(cases[case], transformations=transformations) + assert(implicit == normal) + + application = ['sin x', 'cos 2*x', 'sin cos x'] + for case in application: + raises(SyntaxError, + lambda: parse_expr(case, transformations=transformations2)) + raises(TypeError, + lambda: parse_expr('sin**2(x)', transformations=transformations2)) + + +def test_implicit_application(): + cases = { + 'factorial': 'factorial', + 'sin x': 'sin(x)', + 'tan y**3': 'tan(y**3)', + 'cos 2*x': 'cos(2*x)', + '(cot)': 'cot', + 'sin cos tan x': 'sin(cos(tan(x)))' + } + transformations = standard_transformations + (convert_xor,) + transformations2 = transformations + (implicit_application,) + for case in cases: + implicit = parse_expr(case, transformations=transformations2) + normal = parse_expr(cases[case], transformations=transformations) + assert(implicit == normal), (implicit, normal) + + multiplication = ['x y', 'x sin x', '2x'] + for case in multiplication: + raises(SyntaxError, + lambda: parse_expr(case, transformations=transformations2)) + raises(TypeError, + lambda: parse_expr('sin**2(x)', transformations=transformations2)) + + +def test_function_exponentiation(): + cases = { + 'sin**2(x)': 'sin(x)**2', + 'exp^y(z)': 'exp(z)^y', + 'sin**2(E^(x))': 'sin(E^(x))**2' + } + transformations = standard_transformations + (convert_xor,) + transformations2 = transformations + (function_exponentiation,) + for case in cases: + implicit = parse_expr(case, transformations=transformations2) + normal = parse_expr(cases[case], transformations=transformations) + assert(implicit == normal) + + other_implicit = ['x y', 'x sin x', '2x', 'sin x', + 'cos 2*x', 'sin cos x'] + for case in other_implicit: + raises(SyntaxError, + lambda: parse_expr(case, transformations=transformations2)) + + assert parse_expr('x**2', local_dict={ 'x': sympy.Symbol('x') }, + transformations=transformations2) == parse_expr('x**2') + + +def test_symbol_splitting(): + # By default Greek letter names should not be split (lambda is a keyword + # so skip it) + transformations = standard_transformations + (split_symbols,) + greek_letters = ('alpha', 'beta', 'gamma', 'delta', 'epsilon', 'zeta', + 'eta', 'theta', 'iota', 'kappa', 'mu', 'nu', 'xi', + 'omicron', 'pi', 'rho', 'sigma', 'tau', 'upsilon', + 'phi', 'chi', 'psi', 'omega') + + for letter in greek_letters: + assert(parse_expr(letter, transformations=transformations) == + parse_expr(letter)) + + # Make sure symbol splitting resolves names + transformations += (implicit_multiplication,) + local_dict = { 'e': sympy.E } + cases = { + 'xe': 'E*x', + 'Iy': 'I*y', + 'ee': 'E*E', + } + for case, expected in cases.items(): + assert(parse_expr(case, local_dict=local_dict, + transformations=transformations) == + parse_expr(expected)) + + # Make sure custom splitting works + def can_split(symbol): + if symbol not in ('unsplittable', 'names'): + return _token_splittable(symbol) + return False + transformations = standard_transformations + transformations += (split_symbols_custom(can_split), + implicit_multiplication) + + assert(parse_expr('unsplittable', transformations=transformations) == + parse_expr('unsplittable')) + assert(parse_expr('names', transformations=transformations) == + parse_expr('names')) + assert(parse_expr('xy', transformations=transformations) == + parse_expr('x*y')) + for letter in greek_letters: + assert(parse_expr(letter, transformations=transformations) == + parse_expr(letter)) + + +def test_all_implicit_steps(): + cases = { + '2x': '2*x', # implicit multiplication + 'x y': 'x*y', + 'xy': 'x*y', + 'sin x': 'sin(x)', # add parentheses + '2sin x': '2*sin(x)', + 'x y z': 'x*y*z', + 'sin(2 * 3x)': 'sin(2 * 3 * x)', + 'sin(x) (1 + cos(x))': 'sin(x) * (1 + cos(x))', + '(x + 2) sin(x)': '(x + 2) * sin(x)', + '(x + 2) sin x': '(x + 2) * sin(x)', + 'sin(sin x)': 'sin(sin(x))', + 'sin x!': 'sin(factorial(x))', + 'sin x!!': 'sin(factorial2(x))', + 'factorial': 'factorial', # don't apply a bare function + 'x sin x': 'x * sin(x)', # both application and multiplication + 'xy sin x': 'x * y * sin(x)', + '(x+2)(x+3)': '(x + 2) * (x+3)', + 'x**2 + 2xy + y**2': 'x**2 + 2 * x * y + y**2', # split the xy + 'pi': 'pi', # don't mess with constants + 'None': 'None', + 'ln sin x': 'ln(sin(x))', # multiple implicit function applications + 'sin x**2': 'sin(x**2)', # implicit application to an exponential + 'alpha': 'Symbol("alpha")', # don't split Greek letters/subscripts + 'x_2': 'Symbol("x_2")', + 'sin^2 x**2': 'sin(x**2)**2', # function raised to a power + 'sin**3(x)': 'sin(x)**3', + '(factorial)': 'factorial', + 'tan 3x': 'tan(3*x)', + 'sin^2(3*E^(x))': 'sin(3*E**(x))**2', + 'sin**2(E^(3x))': 'sin(E**(3*x))**2', + 'sin^2 (3x*E^(x))': 'sin(3*x*E^x)**2', + 'pi sin x': 'pi*sin(x)', + } + transformations = standard_transformations + (convert_xor,) + transformations2 = transformations + (implicit_multiplication_application,) + for case in cases: + implicit = parse_expr(case, transformations=transformations2) + normal = parse_expr(cases[case], transformations=transformations) + assert(implicit == normal) + + +def test_no_methods_implicit_multiplication(): + # Issue 21020 + u = sympy.Symbol('u') + transformations = standard_transformations + \ + (implicit_multiplication,) + expr = parse_expr('x.is_polynomial(x)', transformations=transformations) + assert expr == True + expr = parse_expr('(exp(x) / (1 + exp(2x))).subs(exp(x), u)', + transformations=transformations) + assert expr == u/(u**2 + 1) diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/test_latex.py b/phivenv/Lib/site-packages/sympy/parsing/tests/test_latex.py new file mode 100644 index 0000000000000000000000000000000000000000..49a48966eacaa1cd7a242dcd0e7699c992bb1268 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/tests/test_latex.py @@ -0,0 +1,358 @@ +from sympy.testing.pytest import raises, XFAIL +from sympy.external import import_module + +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.function import (Derivative, Function) +from sympy.core.mul import Mul +from sympy.core.numbers import (E, oo) +from sympy.core.power import Pow +from sympy.core.relational import (GreaterThan, LessThan, StrictGreaterThan, StrictLessThan, Unequality) +from sympy.core.symbol import Symbol +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.complexes import (Abs, conjugate) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.integers import (ceiling, floor) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (asin, cos, csc, sec, sin, tan) +from sympy.integrals.integrals import Integral +from sympy.series.limits import Limit + +from sympy.core.relational import Eq, Ne, Lt, Le, Gt, Ge +from sympy.physics.quantum.state import Bra, Ket +from sympy.abc import x, y, z, a, b, c, t, k, n +antlr4 = import_module("antlr4") + +# disable tests if antlr4-python3-runtime is not present +disabled = antlr4 is None + +theta = Symbol('theta') +f = Function('f') + + +# shorthand definitions +def _Add(a, b): + return Add(a, b, evaluate=False) + + +def _Mul(a, b): + return Mul(a, b, evaluate=False) + + +def _Pow(a, b): + return Pow(a, b, evaluate=False) + + +def _Sqrt(a): + return sqrt(a, evaluate=False) + + +def _Conjugate(a): + return conjugate(a, evaluate=False) + + +def _Abs(a): + return Abs(a, evaluate=False) + + +def _factorial(a): + return factorial(a, evaluate=False) + + +def _exp(a): + return exp(a, evaluate=False) + + +def _log(a, b): + return log(a, b, evaluate=False) + + +def _binomial(n, k): + return binomial(n, k, evaluate=False) + + +def test_import(): + from sympy.parsing.latex._build_latex_antlr import ( + build_parser, + check_antlr_version, + dir_latex_antlr + ) + # XXX: It would be better to come up with a test for these... + del build_parser, check_antlr_version, dir_latex_antlr + + +# These LaTeX strings should parse to the corresponding SymPy expression +GOOD_PAIRS = [ + (r"0", 0), + (r"1", 1), + (r"-3.14", -3.14), + (r"(-7.13)(1.5)", _Mul(-7.13, 1.5)), + (r"x", x), + (r"2x", 2*x), + (r"x^2", x**2), + (r"x^\frac{1}{2}", _Pow(x, _Pow(2, -1))), + (r"x^{3 + 1}", x**_Add(3, 1)), + (r"-c", -c), + (r"a \cdot b", a * b), + (r"a / b", a / b), + (r"a \div b", a / b), + (r"a + b", a + b), + (r"a + b - a", _Add(a+b, -a)), + (r"a^2 + b^2 = c^2", Eq(a**2 + b**2, c**2)), + (r"(x + y) z", _Mul(_Add(x, y), z)), + (r"a'b+ab'", _Add(_Mul(Symbol("a'"), b), _Mul(a, Symbol("b'")))), + (r"y''_1", Symbol("y_{1}''")), + (r"y_1''", Symbol("y_{1}''")), + (r"\left(x + y\right) z", _Mul(_Add(x, y), z)), + (r"\left( x + y\right ) z", _Mul(_Add(x, y), z)), + (r"\left( x + y\right ) z", _Mul(_Add(x, y), z)), + (r"\left[x + y\right] z", _Mul(_Add(x, y), z)), + (r"\left\{x + y\right\} z", _Mul(_Add(x, y), z)), + (r"1+1", _Add(1, 1)), + (r"0+1", _Add(0, 1)), + (r"1*2", _Mul(1, 2)), + (r"0*1", _Mul(0, 1)), + (r"1 \times 2 ", _Mul(1, 2)), + (r"x = y", Eq(x, y)), + (r"x \neq y", Ne(x, y)), + (r"x < y", Lt(x, y)), + (r"x > y", Gt(x, y)), + (r"x \leq y", Le(x, y)), + (r"x \geq y", Ge(x, y)), + (r"x \le y", Le(x, y)), + (r"x \ge y", Ge(x, y)), + (r"\lfloor x \rfloor", floor(x)), + (r"\lceil x \rceil", ceiling(x)), + (r"\langle x |", Bra('x')), + (r"| x \rangle", Ket('x')), + (r"\sin \theta", sin(theta)), + (r"\sin(\theta)", sin(theta)), + (r"\sin^{-1} a", asin(a)), + (r"\sin a \cos b", _Mul(sin(a), cos(b))), + (r"\sin \cos \theta", sin(cos(theta))), + (r"\sin(\cos \theta)", sin(cos(theta))), + (r"\frac{a}{b}", a / b), + (r"\dfrac{a}{b}", a / b), + (r"\tfrac{a}{b}", a / b), + (r"\frac12", _Pow(2, -1)), + (r"\frac12y", _Mul(_Pow(2, -1), y)), + (r"\frac1234", _Mul(_Pow(2, -1), 34)), + (r"\frac2{3}", _Mul(2, _Pow(3, -1))), + (r"\frac{\sin{x}}2", _Mul(sin(x), _Pow(2, -1))), + (r"\frac{a + b}{c}", _Mul(a + b, _Pow(c, -1))), + (r"\frac{7}{3}", _Mul(7, _Pow(3, -1))), + (r"(\csc x)(\sec y)", csc(x)*sec(y)), + (r"\lim_{x \to 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \rightarrow 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \Rightarrow 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \longrightarrow 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \Longrightarrow 3} a", Limit(a, x, 3, dir='+-')), + (r"\lim_{x \to 3^{+}} a", Limit(a, x, 3, dir='+')), + (r"\lim_{x \to 3^{-}} a", Limit(a, x, 3, dir='-')), + (r"\lim_{x \to 3^+} a", Limit(a, x, 3, dir='+')), + (r"\lim_{x \to 3^-} a", Limit(a, x, 3, dir='-')), + (r"\infty", oo), + (r"\lim_{x \to \infty} \frac{1}{x}", Limit(_Pow(x, -1), x, oo)), + (r"\frac{d}{dx} x", Derivative(x, x)), + (r"\frac{d}{dt} x", Derivative(x, t)), + (r"f(x)", f(x)), + (r"f(x, y)", f(x, y)), + (r"f(x, y, z)", f(x, y, z)), + (r"f'_1(x)", Function("f_{1}'")(x)), + (r"f_{1}''(x+y)", Function("f_{1}''")(x+y)), + (r"\frac{d f(x)}{dx}", Derivative(f(x), x)), + (r"\frac{d\theta(x)}{dx}", Derivative(Function('theta')(x), x)), + (r"x \neq y", Unequality(x, y)), + (r"|x|", _Abs(x)), + (r"||x||", _Abs(Abs(x))), + (r"|x||y|", _Abs(x)*_Abs(y)), + (r"||x||y||", _Abs(_Abs(x)*_Abs(y))), + (r"\pi^{|xy|}", Symbol('pi')**_Abs(x*y)), + (r"\int x dx", Integral(x, x)), + (r"\int x d\theta", Integral(x, theta)), + (r"\int (x^2 - y)dx", Integral(x**2 - y, x)), + (r"\int x + a dx", Integral(_Add(x, a), x)), + (r"\int da", Integral(1, a)), + (r"\int_0^7 dx", Integral(1, (x, 0, 7))), + (r"\int\limits_{0}^{1} x dx", Integral(x, (x, 0, 1))), + (r"\int_a^b x dx", Integral(x, (x, a, b))), + (r"\int^b_a x dx", Integral(x, (x, a, b))), + (r"\int_{a}^b x dx", Integral(x, (x, a, b))), + (r"\int^{b}_a x dx", Integral(x, (x, a, b))), + (r"\int_{a}^{b} x dx", Integral(x, (x, a, b))), + (r"\int^{b}_{a} x dx", Integral(x, (x, a, b))), + (r"\int_{f(a)}^{f(b)} f(z) dz", Integral(f(z), (z, f(a), f(b)))), + (r"\int (x+a)", Integral(_Add(x, a), x)), + (r"\int a + b + c dx", Integral(_Add(_Add(a, b), c), x)), + (r"\int \frac{dz}{z}", Integral(Pow(z, -1), z)), + (r"\int \frac{3 dz}{z}", Integral(3*Pow(z, -1), z)), + (r"\int \frac{1}{x} dx", Integral(Pow(x, -1), x)), + (r"\int \frac{1}{a} + \frac{1}{b} dx", + Integral(_Add(_Pow(a, -1), Pow(b, -1)), x)), + (r"\int \frac{3 \cdot d\theta}{\theta}", + Integral(3*_Pow(theta, -1), theta)), + (r"\int \frac{1}{x} + 1 dx", Integral(_Add(_Pow(x, -1), 1), x)), + (r"x_0", Symbol('x_{0}')), + (r"x_{1}", Symbol('x_{1}')), + (r"x_a", Symbol('x_{a}')), + (r"x_{b}", Symbol('x_{b}')), + (r"h_\theta", Symbol('h_{theta}')), + (r"h_{\theta}", Symbol('h_{theta}')), + (r"h_{\theta}(x_0, x_1)", + Function('h_{theta}')(Symbol('x_{0}'), Symbol('x_{1}'))), + (r"x!", _factorial(x)), + (r"100!", _factorial(100)), + (r"\theta!", _factorial(theta)), + (r"(x + 1)!", _factorial(_Add(x, 1))), + (r"(x!)!", _factorial(_factorial(x))), + (r"x!!!", _factorial(_factorial(_factorial(x)))), + (r"5!7!", _Mul(_factorial(5), _factorial(7))), + (r"\sqrt{x}", sqrt(x)), + (r"\sqrt{x + b}", sqrt(_Add(x, b))), + (r"\sqrt[3]{\sin x}", root(sin(x), 3)), + (r"\sqrt[y]{\sin x}", root(sin(x), y)), + (r"\sqrt[\theta]{\sin x}", root(sin(x), theta)), + (r"\sqrt{\frac{12}{6}}", _Sqrt(_Mul(12, _Pow(6, -1)))), + (r"\overline{z}", _Conjugate(z)), + (r"\overline{\overline{z}}", _Conjugate(_Conjugate(z))), + (r"\overline{x + y}", _Conjugate(_Add(x, y))), + (r"\overline{x} + \overline{y}", _Conjugate(x) + _Conjugate(y)), + (r"x < y", StrictLessThan(x, y)), + (r"x \leq y", LessThan(x, y)), + (r"x > y", StrictGreaterThan(x, y)), + (r"x \geq y", GreaterThan(x, y)), + (r"\mathit{x}", Symbol('x')), + (r"\mathit{test}", Symbol('test')), + (r"\mathit{TEST}", Symbol('TEST')), + (r"\mathit{HELLO world}", Symbol('HELLO world')), + (r"\sum_{k = 1}^{3} c", Sum(c, (k, 1, 3))), + (r"\sum_{k = 1}^3 c", Sum(c, (k, 1, 3))), + (r"\sum^{3}_{k = 1} c", Sum(c, (k, 1, 3))), + (r"\sum^3_{k = 1} c", Sum(c, (k, 1, 3))), + (r"\sum_{k = 1}^{10} k^2", Sum(k**2, (k, 1, 10))), + (r"\sum_{n = 0}^{\infty} \frac{1}{n!}", + Sum(_Pow(_factorial(n), -1), (n, 0, oo))), + (r"\prod_{a = b}^{c} x", Product(x, (a, b, c))), + (r"\prod_{a = b}^c x", Product(x, (a, b, c))), + (r"\prod^{c}_{a = b} x", Product(x, (a, b, c))), + (r"\prod^c_{a = b} x", Product(x, (a, b, c))), + (r"\exp x", _exp(x)), + (r"\exp(x)", _exp(x)), + (r"\lg x", _log(x, 10)), + (r"\ln x", _log(x, E)), + (r"\ln xy", _log(x*y, E)), + (r"\log x", _log(x, E)), + (r"\log xy", _log(x*y, E)), + (r"\log_{2} x", _log(x, 2)), + (r"\log_{a} x", _log(x, a)), + (r"\log_{11} x", _log(x, 11)), + (r"\log_{a^2} x", _log(x, _Pow(a, 2))), + (r"[x]", x), + (r"[a + b]", _Add(a, b)), + (r"\frac{d}{dx} [ \tan x ]", Derivative(tan(x), x)), + (r"\binom{n}{k}", _binomial(n, k)), + (r"\tbinom{n}{k}", _binomial(n, k)), + (r"\dbinom{n}{k}", _binomial(n, k)), + (r"\binom{n}{0}", _binomial(n, 0)), + (r"x^\binom{n}{k}", _Pow(x, _binomial(n, k))), + (r"a \, b", _Mul(a, b)), + (r"a \thinspace b", _Mul(a, b)), + (r"a \: b", _Mul(a, b)), + (r"a \medspace b", _Mul(a, b)), + (r"a \; b", _Mul(a, b)), + (r"a \thickspace b", _Mul(a, b)), + (r"a \quad b", _Mul(a, b)), + (r"a \qquad b", _Mul(a, b)), + (r"a \! b", _Mul(a, b)), + (r"a \negthinspace b", _Mul(a, b)), + (r"a \negmedspace b", _Mul(a, b)), + (r"a \negthickspace b", _Mul(a, b)), + (r"\int x \, dx", Integral(x, x)), + (r"\log_2 x", _log(x, 2)), + (r"\log_a x", _log(x, a)), + (r"5^0 - 4^0", _Add(_Pow(5, 0), _Mul(-1, _Pow(4, 0)))), + (r"3x - 1", _Add(_Mul(3, x), -1)) +] + + +def test_parseable(): + from sympy.parsing.latex import parse_latex + for latex_str, sympy_expr in GOOD_PAIRS: + assert parse_latex(latex_str) == sympy_expr, latex_str + +# These bad LaTeX strings should raise a LaTeXParsingError when parsed +BAD_STRINGS = [ + r"(", + r")", + r"\frac{d}{dx}", + r"(\frac{d}{dx})", + r"\sqrt{}", + r"\sqrt", + r"\overline{}", + r"\overline", + r"{", + r"}", + r"\mathit{x + y}", + r"\mathit{21}", + r"\frac{2}{}", + r"\frac{}{2}", + r"\int", + r"!", + r"!0", + r"_", + r"^", + r"|", + r"||x|", + r"()", + r"((((((((((((((((()))))))))))))))))", + r"-", + r"\frac{d}{dx} + \frac{d}{dt}", + r"f(x,,y)", + r"f(x,y,", + r"\sin^x", + r"\cos^2", + r"@", + r"#", + r"$", + r"%", + r"&", + r"*", + r"" "\\", + r"~", + r"\frac{(2 + x}{1 - x)}", +] + +def test_not_parseable(): + from sympy.parsing.latex import parse_latex, LaTeXParsingError + for latex_str in BAD_STRINGS: + with raises(LaTeXParsingError): + parse_latex(latex_str) + +# At time of migration from latex2sympy, should fail but doesn't +FAILING_BAD_STRINGS = [ + r"\cos 1 \cos", + r"f(,", + r"f()", + r"a \div \div b", + r"a \cdot \cdot b", + r"a // b", + r"a +", + r"1.1.1", + r"1 +", + r"a / b /", +] + +@XFAIL +def test_failing_not_parseable(): + from sympy.parsing.latex import parse_latex, LaTeXParsingError + for latex_str in FAILING_BAD_STRINGS: + with raises(LaTeXParsingError): + parse_latex(latex_str) + +# In strict mode, FAILING_BAD_STRINGS would fail +def test_strict_mode(): + from sympy.parsing.latex import parse_latex, LaTeXParsingError + for latex_str in FAILING_BAD_STRINGS: + with raises(LaTeXParsingError): + parse_latex(latex_str, strict=True) diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/test_latex_deps.py b/phivenv/Lib/site-packages/sympy/parsing/tests/test_latex_deps.py new file mode 100644 index 0000000000000000000000000000000000000000..7df44c2b19e34024db6e898f7c4eac962dcaa1c9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/tests/test_latex_deps.py @@ -0,0 +1,16 @@ +from sympy.external import import_module +from sympy.testing.pytest import ignore_warnings, raises + +antlr4 = import_module("antlr4", warn_not_installed=False) + +# disable tests if antlr4-python3-runtime is not present +if antlr4: + disabled = True + + +def test_no_import(): + from sympy.parsing.latex import parse_latex + + with ignore_warnings(UserWarning): + with raises(ImportError): + parse_latex('1 + 1') diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/test_latex_lark.py b/phivenv/Lib/site-packages/sympy/parsing/tests/test_latex_lark.py new file mode 100644 index 0000000000000000000000000000000000000000..dd1f72a66c788ac41d923005ea988664d05a16c1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/tests/test_latex_lark.py @@ -0,0 +1,872 @@ +from sympy.testing.pytest import XFAIL +from sympy.parsing.latex.lark import parse_latex_lark +from sympy.external import import_module + +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.function import Derivative, Function +from sympy.core.numbers import E, oo, Rational +from sympy.core.power import Pow +from sympy.core.parameters import evaluate +from sympy.core.relational import GreaterThan, LessThan, StrictGreaterThan, StrictLessThan, Unequality +from sympy.core.symbol import Symbol +from sympy.functions.combinatorial.factorials import binomial, factorial +from sympy.functions.elementary.complexes import Abs, conjugate +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.miscellaneous import root, sqrt, Min, Max +from sympy.functions.elementary.trigonometric import asin, cos, csc, sec, sin, tan +from sympy.integrals.integrals import Integral +from sympy.series.limits import Limit +from sympy import Matrix, MatAdd, MatMul, Transpose, Trace +from sympy import I + +from sympy.core.relational import Eq, Ne, Lt, Le, Gt, Ge +from sympy.physics.quantum import Bra, Ket, InnerProduct +from sympy.abc import x, y, z, a, b, c, d, t, k, n + +from .test_latex import theta, f, _Add, _Mul, _Pow, _Sqrt, _Conjugate, _Abs, _factorial, _exp, _binomial + +lark = import_module("lark") + +# disable tests if lark is not present +disabled = lark is None + +# shorthand definitions that are only needed for the Lark LaTeX parser +def _Min(*args): + return Min(*args, evaluate=False) + + +def _Max(*args): + return Max(*args, evaluate=False) + + +def _log(a, b=E): + if b == E: + return log(a, evaluate=False) + else: + return log(a, b, evaluate=False) + + +def _MatAdd(a, b): + return MatAdd(a, b, evaluate=False) + + +def _MatMul(a, b): + return MatMul(a, b, evaluate=False) + + +# These LaTeX strings should parse to the corresponding SymPy expression +SYMBOL_EXPRESSION_PAIRS = [ + (r"x_0", Symbol('x_{0}')), + (r"x_{1}", Symbol('x_{1}')), + (r"x_a", Symbol('x_{a}')), + (r"x_{b}", Symbol('x_{b}')), + (r"h_\theta", Symbol('h_{theta}')), + (r"h_{\theta}", Symbol('h_{theta}')), + (r"y''_1", Symbol("y''_{1}")), + (r"y_1''", Symbol("y_{1}''")), + (r"\mathit{x}", Symbol('x')), + (r"\mathit{test}", Symbol('test')), + (r"\mathit{TEST}", Symbol('TEST')), + (r"\mathit{HELLO world}", Symbol('HELLO world')), + (r"a'", Symbol("a'")), + (r"a''", Symbol("a''")), + (r"\alpha'", Symbol("alpha'")), + (r"\alpha''", Symbol("alpha''")), + (r"a_b", Symbol("a_{b}")), + (r"a_b'", Symbol("a_{b}'")), + (r"a'_b", Symbol("a'_{b}")), + (r"a'_b'", Symbol("a'_{b}'")), + (r"a_{b'}", Symbol("a_{b'}")), + (r"a_{b'}'", Symbol("a_{b'}'")), + (r"a'_{b'}", Symbol("a'_{b'}")), + (r"a'_{b'}'", Symbol("a'_{b'}'")), + (r"\mathit{foo}'", Symbol("foo'")), + (r"\mathit{foo'}", Symbol("foo'")), + (r"\mathit{foo'}'", Symbol("foo''")), + (r"a_b''", Symbol("a_{b}''")), + (r"a''_b", Symbol("a''_{b}")), + (r"a''_b'''", Symbol("a''_{b}'''")), + (r"a_{b''}", Symbol("a_{b''}")), + (r"a_{b''}''", Symbol("a_{b''}''")), + (r"a''_{b''}", Symbol("a''_{b''}")), + (r"a''_{b''}'''", Symbol("a''_{b''}'''")), + (r"\mathit{foo}''", Symbol("foo''")), + (r"\mathit{foo''}", Symbol("foo''")), + (r"\mathit{foo''}'''", Symbol("foo'''''")), + (r"a_\alpha", Symbol("a_{alpha}")), + (r"a_\alpha'", Symbol("a_{alpha}'")), + (r"a'_\alpha", Symbol("a'_{alpha}")), + (r"a'_\alpha'", Symbol("a'_{alpha}'")), + (r"a_{\alpha'}", Symbol("a_{alpha'}")), + (r"a_{\alpha'}'", Symbol("a_{alpha'}'")), + (r"a'_{\alpha'}", Symbol("a'_{alpha'}")), + (r"a'_{\alpha'}'", Symbol("a'_{alpha'}'")), + (r"a_\alpha''", Symbol("a_{alpha}''")), + (r"a''_\alpha", Symbol("a''_{alpha}")), + (r"a''_\alpha'''", Symbol("a''_{alpha}'''")), + (r"a_{\alpha''}", Symbol("a_{alpha''}")), + (r"a_{\alpha''}''", Symbol("a_{alpha''}''")), + (r"a''_{\alpha''}", Symbol("a''_{alpha''}")), + (r"a''_{\alpha''}'''", Symbol("a''_{alpha''}'''")), + (r"\alpha_b", Symbol("alpha_{b}")), + (r"\alpha_b'", Symbol("alpha_{b}'")), + (r"\alpha'_b", Symbol("alpha'_{b}")), + (r"\alpha'_b'", Symbol("alpha'_{b}'")), + (r"\alpha_{b'}", Symbol("alpha_{b'}")), + (r"\alpha_{b'}'", Symbol("alpha_{b'}'")), + (r"\alpha'_{b'}", Symbol("alpha'_{b'}")), + (r"\alpha'_{b'}'", Symbol("alpha'_{b'}'")), + (r"\alpha_b''", Symbol("alpha_{b}''")), + (r"\alpha''_b", Symbol("alpha''_{b}")), + (r"\alpha''_b'''", Symbol("alpha''_{b}'''")), + (r"\alpha_{b''}", Symbol("alpha_{b''}")), + (r"\alpha_{b''}''", Symbol("alpha_{b''}''")), + (r"\alpha''_{b''}", Symbol("alpha''_{b''}")), + (r"\alpha''_{b''}'''", Symbol("alpha''_{b''}'''")), + (r"\alpha_\beta", Symbol("alpha_{beta}")), + (r"\alpha_{\beta}", Symbol("alpha_{beta}")), + (r"\alpha_{\beta'}", Symbol("alpha_{beta'}")), + (r"\alpha_{\beta''}", Symbol("alpha_{beta''}")), + (r"\alpha'_\beta", Symbol("alpha'_{beta}")), + (r"\alpha'_{\beta}", Symbol("alpha'_{beta}")), + (r"\alpha'_{\beta'}", Symbol("alpha'_{beta'}")), + (r"\alpha'_{\beta''}", Symbol("alpha'_{beta''}")), + (r"\alpha''_\beta", Symbol("alpha''_{beta}")), + (r"\alpha''_{\beta}", Symbol("alpha''_{beta}")), + (r"\alpha''_{\beta'}", Symbol("alpha''_{beta'}")), + (r"\alpha''_{\beta''}", Symbol("alpha''_{beta''}")), + (r"\alpha_\beta'", Symbol("alpha_{beta}'")), + (r"\alpha_{\beta}'", Symbol("alpha_{beta}'")), + (r"\alpha_{\beta'}'", Symbol("alpha_{beta'}'")), + (r"\alpha_{\beta''}'", Symbol("alpha_{beta''}'")), + (r"\alpha'_\beta'", Symbol("alpha'_{beta}'")), + (r"\alpha'_{\beta}'", Symbol("alpha'_{beta}'")), + (r"\alpha'_{\beta'}'", Symbol("alpha'_{beta'}'")), + (r"\alpha'_{\beta''}'", Symbol("alpha'_{beta''}'")), + (r"\alpha''_\beta'", Symbol("alpha''_{beta}'")), + (r"\alpha''_{\beta}'", Symbol("alpha''_{beta}'")), + (r"\alpha''_{\beta'}'", Symbol("alpha''_{beta'}'")), + (r"\alpha''_{\beta''}'", Symbol("alpha''_{beta''}'")), + (r"\alpha_\beta''", Symbol("alpha_{beta}''")), + (r"\alpha_{\beta}''", Symbol("alpha_{beta}''")), + (r"\alpha_{\beta'}''", Symbol("alpha_{beta'}''")), + (r"\alpha_{\beta''}''", Symbol("alpha_{beta''}''")), + (r"\alpha'_\beta''", Symbol("alpha'_{beta}''")), + (r"\alpha'_{\beta}''", Symbol("alpha'_{beta}''")), + (r"\alpha'_{\beta'}''", Symbol("alpha'_{beta'}''")), + (r"\alpha'_{\beta''}''", Symbol("alpha'_{beta''}''")), + (r"\alpha''_\beta''", Symbol("alpha''_{beta}''")), + (r"\alpha''_{\beta}''", Symbol("alpha''_{beta}''")), + (r"\alpha''_{\beta'}''", Symbol("alpha''_{beta'}''")), + (r"\alpha''_{\beta''}''", Symbol("alpha''_{beta''}''")) + +] + +UNEVALUATED_SIMPLE_EXPRESSION_PAIRS = [ + (r"0", 0), + (r"1", 1), + (r"-3.14", -3.14), + (r"(-7.13)(1.5)", _Mul(-7.13, 1.5)), + (r"1+1", _Add(1, 1)), + (r"0+1", _Add(0, 1)), + (r"1*2", _Mul(1, 2)), + (r"0*1", _Mul(0, 1)), + (r"x", x), + (r"2x", 2 * x), + (r"3x - 1", _Add(_Mul(3, x), -1)), + (r"-c", -c), + (r"\infty", oo), + (r"a \cdot b", a * b), + (r"1 \times 2 ", _Mul(1, 2)), + (r"a / b", a / b), + (r"a \div b", a / b), + (r"a + b", a + b), + (r"a + b - a", _Add(a + b, -a)), + (r"(x + y) z", _Mul(_Add(x, y), z)), + (r"a'b+ab'", _Add(_Mul(Symbol("a'"), b), _Mul(a, Symbol("b'")))) +] + +EVALUATED_SIMPLE_EXPRESSION_PAIRS = [ + (r"(-7.13)(1.5)", -10.695), + (r"1+1", 2), + (r"0+1", 1), + (r"1*2", 2), + (r"0*1", 0), + (r"2x", 2 * x), + (r"3x - 1", 3 * x - 1), + (r"-c", -c), + (r"a \cdot b", a * b), + (r"1 \times 2 ", 2), + (r"a / b", a / b), + (r"a \div b", a / b), + (r"a + b", a + b), + (r"a + b - a", b), + (r"(x + y) z", (x + y) * z), +] + +UNEVALUATED_FRACTION_EXPRESSION_PAIRS = [ + (r"\frac{a}{b}", a / b), + (r"\dfrac{a}{b}", a / b), + (r"\tfrac{a}{b}", a / b), + (r"\frac12", _Mul(1, _Pow(2, -1))), + (r"\frac12y", _Mul(_Mul(1, _Pow(2, -1)), y)), + (r"\frac1234", _Mul(_Mul(1, _Pow(2, -1)), 34)), + (r"\frac2{3}", _Mul(2, _Pow(3, -1))), + (r"\frac{a + b}{c}", _Mul(a + b, _Pow(c, -1))), + (r"\frac{7}{3}", _Mul(7, _Pow(3, -1))) +] + +EVALUATED_FRACTION_EXPRESSION_PAIRS = [ + (r"\frac{a}{b}", a / b), + (r"\dfrac{a}{b}", a / b), + (r"\tfrac{a}{b}", a / b), + (r"\frac12", Rational(1, 2)), + (r"\frac12y", y / 2), + (r"\frac1234", 17), + (r"\frac2{3}", Rational(2, 3)), + (r"\frac{a + b}{c}", (a + b) / c), + (r"\frac{7}{3}", Rational(7, 3)) +] + +RELATION_EXPRESSION_PAIRS = [ + (r"x = y", Eq(x, y)), + (r"x \neq y", Ne(x, y)), + (r"x < y", Lt(x, y)), + (r"x > y", Gt(x, y)), + (r"x \leq y", Le(x, y)), + (r"x \geq y", Ge(x, y)), + (r"x \le y", Le(x, y)), + (r"x \ge y", Ge(x, y)), + (r"x < y", StrictLessThan(x, y)), + (r"x \leq y", LessThan(x, y)), + (r"x > y", StrictGreaterThan(x, y)), + (r"x \geq y", GreaterThan(x, y)), + (r"x \neq y", Unequality(x, y)), # same as 2nd one in the list + (r"a^2 + b^2 = c^2", Eq(a**2 + b**2, c**2)) +] + +UNEVALUATED_POWER_EXPRESSION_PAIRS = [ + (r"x^2", x ** 2), + (r"x^\frac{1}{2}", _Pow(x, _Mul(1, _Pow(2, -1)))), + (r"x^{3 + 1}", x ** _Add(3, 1)), + (r"\pi^{|xy|}", Symbol('pi') ** _Abs(x * y)), + (r"5^0 - 4^0", _Add(_Pow(5, 0), _Mul(-1, _Pow(4, 0)))) +] + +EVALUATED_POWER_EXPRESSION_PAIRS = [ + (r"x^2", x ** 2), + (r"x^\frac{1}{2}", sqrt(x)), + (r"x^{3 + 1}", x ** 4), + (r"\pi^{|xy|}", Symbol('pi') ** _Abs(x * y)), + (r"5^0 - 4^0", 0) +] + +UNEVALUATED_INTEGRAL_EXPRESSION_PAIRS = [ + (r"\int x dx", Integral(_Mul(1, x), x)), + (r"\int x \, dx", Integral(_Mul(1, x), x)), + (r"\int x d\theta", Integral(_Mul(1, x), theta)), + (r"\int (x^2 - y)dx", Integral(_Mul(1, x ** 2 - y), x)), + (r"\int x + a dx", Integral(_Mul(1, _Add(x, a)), x)), + (r"\int da", Integral(_Mul(1, 1), a)), + (r"\int_0^7 dx", Integral(_Mul(1, 1), (x, 0, 7))), + (r"\int\limits_{0}^{1} x dx", Integral(_Mul(1, x), (x, 0, 1))), + (r"\int_a^b x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int^b_a x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int_{a}^b x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int^{b}_a x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int_{a}^{b} x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int^{b}_{a} x dx", Integral(_Mul(1, x), (x, a, b))), + (r"\int_{f(a)}^{f(b)} f(z) dz", Integral(f(z), (z, f(a), f(b)))), + (r"\int a + b + c dx", Integral(_Mul(1, _Add(_Add(a, b), c)), x)), + (r"\int \frac{dz}{z}", Integral(_Mul(1, _Mul(1, Pow(z, -1))), z)), + (r"\int \frac{3 dz}{z}", Integral(_Mul(1, _Mul(3, _Pow(z, -1))), z)), + (r"\int \frac{1}{x} dx", Integral(_Mul(1, _Mul(1, Pow(x, -1))), x)), + (r"\int \frac{1}{a} + \frac{1}{b} dx", + Integral(_Mul(1, _Add(_Mul(1, _Pow(a, -1)), _Mul(1, Pow(b, -1)))), x)), + (r"\int \frac{1}{x} + 1 dx", Integral(_Mul(1, _Add(_Mul(1, _Pow(x, -1)), 1)), x)) +] + +EVALUATED_INTEGRAL_EXPRESSION_PAIRS = [ + (r"\int x dx", Integral(x, x)), + (r"\int x \, dx", Integral(x, x)), + (r"\int x d\theta", Integral(x, theta)), + (r"\int (x^2 - y)dx", Integral(x ** 2 - y, x)), + (r"\int x + a dx", Integral(x + a, x)), + (r"\int da", Integral(1, a)), + (r"\int_0^7 dx", Integral(1, (x, 0, 7))), + (r"\int\limits_{0}^{1} x dx", Integral(x, (x, 0, 1))), + (r"\int_a^b x dx", Integral(x, (x, a, b))), + (r"\int^b_a x dx", Integral(x, (x, a, b))), + (r"\int_{a}^b x dx", Integral(x, (x, a, b))), + (r"\int^{b}_a x dx", Integral(x, (x, a, b))), + (r"\int_{a}^{b} x dx", Integral(x, (x, a, b))), + (r"\int^{b}_{a} x dx", Integral(x, (x, a, b))), + (r"\int_{f(a)}^{f(b)} f(z) dz", Integral(f(z), (z, f(a), f(b)))), + (r"\int a + b + c dx", Integral(a + b + c, x)), + (r"\int \frac{dz}{z}", Integral(Pow(z, -1), z)), + (r"\int \frac{3 dz}{z}", Integral(3 * Pow(z, -1), z)), + (r"\int \frac{1}{x} dx", Integral(1 / x, x)), + (r"\int \frac{1}{a} + \frac{1}{b} dx", Integral(1 / a + 1 / b, x)), + (r"\int \frac{1}{a} - \frac{1}{b} dx", Integral(1 / a - 1 / b, x)), + (r"\int \frac{1}{x} + 1 dx", Integral(1 / x + 1, x)) +] + +DERIVATIVE_EXPRESSION_PAIRS = [ + (r"\frac{d}{dx} x", Derivative(x, x)), + (r"\frac{d}{dt} x", Derivative(x, t)), + (r"\frac{d}{dx} ( \tan x )", Derivative(tan(x), x)), + (r"\frac{d f(x)}{dx}", Derivative(f(x), x)), + (r"\frac{d\theta(x)}{dx}", Derivative(Function('theta')(x), x)) +] + +TRIGONOMETRIC_EXPRESSION_PAIRS = [ + (r"\sin \theta", sin(theta)), + (r"\sin(\theta)", sin(theta)), + (r"\sin^{-1} a", asin(a)), + (r"\sin a \cos b", _Mul(sin(a), cos(b))), + (r"\sin \cos \theta", sin(cos(theta))), + (r"\sin(\cos \theta)", sin(cos(theta))), + (r"(\csc x)(\sec y)", csc(x) * sec(y)), + (r"\frac{\sin{x}}2", _Mul(sin(x), _Pow(2, -1))) +] + +UNEVALUATED_LIMIT_EXPRESSION_PAIRS = [ + (r"\lim_{x \to 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \rightarrow 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \Rightarrow 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \longrightarrow 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \Longrightarrow 3} a", Limit(a, x, 3, dir="+-")), + (r"\lim_{x \to 3^{+}} a", Limit(a, x, 3, dir="+")), + (r"\lim_{x \to 3^{-}} a", Limit(a, x, 3, dir="-")), + (r"\lim_{x \to 3^+} a", Limit(a, x, 3, dir="+")), + (r"\lim_{x \to 3^-} a", Limit(a, x, 3, dir="-")), + (r"\lim_{x \to \infty} \frac{1}{x}", Limit(_Mul(1, _Pow(x, -1)), x, oo)) +] + +EVALUATED_LIMIT_EXPRESSION_PAIRS = [ + (r"\lim_{x \to \infty} \frac{1}{x}", Limit(1 / x, x, oo)) +] + +UNEVALUATED_SQRT_EXPRESSION_PAIRS = [ + (r"\sqrt{x}", sqrt(x)), + (r"\sqrt{x + b}", sqrt(_Add(x, b))), + (r"\sqrt[3]{\sin x}", _Pow(sin(x), _Pow(3, -1))), + # the above test needed to be handled differently than the ones below because root + # acts differently if its second argument is a number + (r"\sqrt[y]{\sin x}", root(sin(x), y)), + (r"\sqrt[\theta]{\sin x}", root(sin(x), theta)), + (r"\sqrt{\frac{12}{6}}", _Sqrt(_Mul(12, _Pow(6, -1)))) +] + +EVALUATED_SQRT_EXPRESSION_PAIRS = [ + (r"\sqrt{x}", sqrt(x)), + (r"\sqrt{x + b}", sqrt(x + b)), + (r"\sqrt[3]{\sin x}", root(sin(x), 3)), + (r"\sqrt[y]{\sin x}", root(sin(x), y)), + (r"\sqrt[\theta]{\sin x}", root(sin(x), theta)), + (r"\sqrt{\frac{12}{6}}", sqrt(2)) +] + +UNEVALUATED_FACTORIAL_EXPRESSION_PAIRS = [ + (r"x!", _factorial(x)), + (r"100!", _factorial(100)), + (r"\theta!", _factorial(theta)), + (r"(x + 1)!", _factorial(_Add(x, 1))), + (r"(x!)!", _factorial(_factorial(x))), + (r"x!!!", _factorial(_factorial(_factorial(x)))), + (r"5!7!", _Mul(_factorial(5), _factorial(7))) +] + +EVALUATED_FACTORIAL_EXPRESSION_PAIRS = [ + (r"x!", factorial(x)), + (r"100!", factorial(100)), + (r"\theta!", factorial(theta)), + (r"(x + 1)!", factorial(x + 1)), + (r"(x!)!", factorial(factorial(x))), + (r"x!!!", factorial(factorial(factorial(x)))), + (r"5!7!", factorial(5) * factorial(7)), + (r"24! \times 24!", factorial(24) * factorial(24)) +] + +UNEVALUATED_SUM_EXPRESSION_PAIRS = [ + (r"\sum_{k = 1}^{3} c", Sum(_Mul(1, c), (k, 1, 3))), + (r"\sum_{k = 1}^3 c", Sum(_Mul(1, c), (k, 1, 3))), + (r"\sum^{3}_{k = 1} c", Sum(_Mul(1, c), (k, 1, 3))), + (r"\sum^3_{k = 1} c", Sum(_Mul(1, c), (k, 1, 3))), + (r"\sum_{k = 1}^{10} k^2", Sum(_Mul(1, k ** 2), (k, 1, 10))), + (r"\sum_{n = 0}^{\infty} \frac{1}{n!}", + Sum(_Mul(1, _Mul(1, _Pow(_factorial(n), -1))), (n, 0, oo))) +] + +EVALUATED_SUM_EXPRESSION_PAIRS = [ + (r"\sum_{k = 1}^{3} c", Sum(c, (k, 1, 3))), + (r"\sum_{k = 1}^3 c", Sum(c, (k, 1, 3))), + (r"\sum^{3}_{k = 1} c", Sum(c, (k, 1, 3))), + (r"\sum^3_{k = 1} c", Sum(c, (k, 1, 3))), + (r"\sum_{k = 1}^{10} k^2", Sum(k ** 2, (k, 1, 10))), + (r"\sum_{n = 0}^{\infty} \frac{1}{n!}", Sum(1 / factorial(n), (n, 0, oo))) +] + +UNEVALUATED_PRODUCT_EXPRESSION_PAIRS = [ + (r"\prod_{a = b}^{c} x", Product(x, (a, b, c))), + (r"\prod_{a = b}^c x", Product(x, (a, b, c))), + (r"\prod^{c}_{a = b} x", Product(x, (a, b, c))), + (r"\prod^c_{a = b} x", Product(x, (a, b, c))) +] + +APPLIED_FUNCTION_EXPRESSION_PAIRS = [ + (r"f(x)", f(x)), + (r"f(x, y)", f(x, y)), + (r"f(x, y, z)", f(x, y, z)), + (r"f'_1(x)", Function("f_{1}'")(x)), + (r"f_{1}''(x+y)", Function("f_{1}''")(x + y)), + (r"h_{\theta}(x_0, x_1)", + Function('h_{theta}')(Symbol('x_{0}'), Symbol('x_{1}'))) +] + +UNEVALUATED_COMMON_FUNCTION_EXPRESSION_PAIRS = [ + (r"|x|", _Abs(x)), + (r"||x||", _Abs(Abs(x))), + (r"|x||y|", _Abs(x) * _Abs(y)), + (r"||x||y||", _Abs(_Abs(x) * _Abs(y))), + (r"\lfloor x \rfloor", floor(x)), + (r"\lceil x \rceil", ceiling(x)), + (r"\exp x", _exp(x)), + (r"\exp(x)", _exp(x)), + (r"\lg x", _log(x, 10)), + (r"\ln x", _log(x)), + (r"\ln xy", _log(x * y)), + (r"\log x", _log(x)), + (r"\log xy", _log(x * y)), + (r"\log_{2} x", _log(x, 2)), + (r"\log_{a} x", _log(x, a)), + (r"\log_{11} x", _log(x, 11)), + (r"\log_{a^2} x", _log(x, _Pow(a, 2))), + (r"\log_2 x", _log(x, 2)), + (r"\log_a x", _log(x, a)), + (r"\overline{z}", _Conjugate(z)), + (r"\overline{\overline{z}}", _Conjugate(_Conjugate(z))), + (r"\overline{x + y}", _Conjugate(_Add(x, y))), + (r"\overline{x} + \overline{y}", _Conjugate(x) + _Conjugate(y)), + (r"\min(a, b)", _Min(a, b)), + (r"\min(a, b, c - d, xy)", _Min(a, b, c - d, x * y)), + (r"\max(a, b)", _Max(a, b)), + (r"\max(a, b, c - d, xy)", _Max(a, b, c - d, x * y)), + # physics things don't have an `evaluate=False` variant + (r"\langle x |", Bra('x')), + (r"| x \rangle", Ket('x')), + (r"\langle x | y \rangle", InnerProduct(Bra('x'), Ket('y'))), +] + +EVALUATED_COMMON_FUNCTION_EXPRESSION_PAIRS = [ + (r"|x|", Abs(x)), + (r"||x||", Abs(Abs(x))), + (r"|x||y|", Abs(x) * Abs(y)), + (r"||x||y||", Abs(Abs(x) * Abs(y))), + (r"\lfloor x \rfloor", floor(x)), + (r"\lceil x \rceil", ceiling(x)), + (r"\exp x", exp(x)), + (r"\exp(x)", exp(x)), + (r"\lg x", log(x, 10)), + (r"\ln x", log(x)), + (r"\ln xy", log(x * y)), + (r"\log x", log(x)), + (r"\log xy", log(x * y)), + (r"\log_{2} x", log(x, 2)), + (r"\log_{a} x", log(x, a)), + (r"\log_{11} x", log(x, 11)), + (r"\log_{a^2} x", log(x, _Pow(a, 2))), + (r"\log_2 x", log(x, 2)), + (r"\log_a x", log(x, a)), + (r"\overline{z}", conjugate(z)), + (r"\overline{\overline{z}}", conjugate(conjugate(z))), + (r"\overline{x + y}", conjugate(x + y)), + (r"\overline{x} + \overline{y}", conjugate(x) + conjugate(y)), + (r"\min(a, b)", Min(a, b)), + (r"\min(a, b, c - d, xy)", Min(a, b, c - d, x * y)), + (r"\max(a, b)", Max(a, b)), + (r"\max(a, b, c - d, xy)", Max(a, b, c - d, x * y)), + (r"\langle x |", Bra('x')), + (r"| x \rangle", Ket('x')), + (r"\langle x | y \rangle", InnerProduct(Bra('x'), Ket('y'))), +] + +SPACING_RELATED_EXPRESSION_PAIRS = [ + (r"a \, b", _Mul(a, b)), + (r"a \thinspace b", _Mul(a, b)), + (r"a \: b", _Mul(a, b)), + (r"a \medspace b", _Mul(a, b)), + (r"a \; b", _Mul(a, b)), + (r"a \thickspace b", _Mul(a, b)), + (r"a \quad b", _Mul(a, b)), + (r"a \qquad b", _Mul(a, b)), + (r"a \! b", _Mul(a, b)), + (r"a \negthinspace b", _Mul(a, b)), + (r"a \negmedspace b", _Mul(a, b)), + (r"a \negthickspace b", _Mul(a, b)) +] + +UNEVALUATED_BINOMIAL_EXPRESSION_PAIRS = [ + (r"\binom{n}{k}", _binomial(n, k)), + (r"\tbinom{n}{k}", _binomial(n, k)), + (r"\dbinom{n}{k}", _binomial(n, k)), + (r"\binom{n}{0}", _binomial(n, 0)), + (r"x^\binom{n}{k}", _Pow(x, _binomial(n, k))) +] + +EVALUATED_BINOMIAL_EXPRESSION_PAIRS = [ + (r"\binom{n}{k}", binomial(n, k)), + (r"\tbinom{n}{k}", binomial(n, k)), + (r"\dbinom{n}{k}", binomial(n, k)), + (r"\binom{n}{0}", binomial(n, 0)), + (r"x^\binom{n}{k}", x ** binomial(n, k)) +] + +MISCELLANEOUS_EXPRESSION_PAIRS = [ + (r"\left(x + y\right) z", _Mul(_Add(x, y), z)), + (r"\left( x + y\right ) z", _Mul(_Add(x, y), z)), + (r"\left( x + y\right ) z", _Mul(_Add(x, y), z)), +] + +UNEVALUATED_LITERAL_COMPLEX_NUMBER_EXPRESSION_PAIRS = [ + (r"\imaginaryunit^2", _Pow(I, 2)), + (r"|\imaginaryunit|", _Abs(I)), + (r"\overline{\imaginaryunit}", _Conjugate(I)), + (r"\imaginaryunit+\imaginaryunit", _Add(I, I)), + (r"\imaginaryunit-\imaginaryunit", _Add(I, -I)), + (r"\imaginaryunit*\imaginaryunit", _Mul(I, I)), + (r"\imaginaryunit/\imaginaryunit", _Mul(I, _Pow(I, -1))), + (r"(1+\imaginaryunit)/|1+\imaginaryunit|", _Mul(_Add(1, I), _Pow(_Abs(_Add(1, I)), -1))) +] + +UNEVALUATED_MATRIX_EXPRESSION_PAIRS = [ + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}", + Matrix([[a, b], [x, y]])), + (r"\begin{pmatrix}a & b \\x & y\\\end{pmatrix}", + Matrix([[a, b], [x, y]])), + (r"\begin{bmatrix}a & b \\x & y\end{bmatrix}", + Matrix([[a, b], [x, y]])), + (r"\left(\begin{matrix}a & b \\x & y\end{matrix}\right)", + Matrix([[a, b], [x, y]])), + (r"\left[\begin{matrix}a & b \\x & y\end{matrix}\right]", + Matrix([[a, b], [x, y]])), + (r"\left[\begin{array}{cc}a & b \\x & y\end{array}\right]", + Matrix([[a, b], [x, y]])), + (r"\left(\begin{array}{cc}a & b \\x & y\end{array}\right)", + Matrix([[a, b], [x, y]])), + (r"\left( { \begin{array}{cc}a & b \\x & y\end{array} } \right)", + Matrix([[a, b], [x, y]])), + (r"+\begin{pmatrix}a & b \\x & y\end{pmatrix}", + Matrix([[a, b], [x, y]])), + ((r"\begin{pmatrix}x & y \\a & b\end{pmatrix}+" + r"\begin{pmatrix}a & b \\x & y\end{pmatrix}"), + _MatAdd(Matrix([[x, y], [a, b]]), Matrix([[a, b], [x, y]]))), + (r"-\begin{pmatrix}a & b \\x & y\end{pmatrix}", + _MatMul(-1, Matrix([[a, b], [x, y]]))), + ((r"\begin{pmatrix}x & y \\a & b\end{pmatrix}-" + r"\begin{pmatrix}a & b \\x & y\end{pmatrix}"), + _MatAdd(Matrix([[x, y], [a, b]]), _MatMul(-1, Matrix([[a, b], [x, y]])))), + ((r"\begin{pmatrix}a & b & c \\x & y & z \\a & b & c \end{pmatrix}*" + r"\begin{pmatrix}x & y & z \\a & b & c \\a & b & c \end{pmatrix}*" + r"\begin{pmatrix}a & b & c \\x & y & z \\x & y & z \end{pmatrix}"), + _MatMul(_MatMul(Matrix([[a, b, c], [x, y, z], [a, b, c]]), + Matrix([[x, y, z], [a, b, c], [a, b, c]])), + Matrix([[a, b, c], [x, y, z], [x, y, z]]))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}/2", + _MatMul(Matrix([[a, b], [x, y]]), _Pow(2, -1))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^2", + _Pow(Matrix([[a, b], [x, y]]), 2)), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^{-1}", + _Pow(Matrix([[a, b], [x, y]]), -1)), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^T", + Transpose(Matrix([[a, b], [x, y]]))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^{T}", + Transpose(Matrix([[a, b], [x, y]]))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}^\mathit{T}", + Transpose(Matrix([[a, b], [x, y]]))), + (r"\begin{pmatrix}1 & 2 \\3 & 4\end{pmatrix}^T", + Transpose(Matrix([[1, 2], [3, 4]]))), + ((r"(\begin{pmatrix}1 & 2 \\3 & 4\end{pmatrix}+" + r"\begin{pmatrix}1 & 2 \\3 & 4\end{pmatrix}^T)*" + r"\begin{bmatrix}1\\0\end{bmatrix}"), + _MatMul(_MatAdd(Matrix([[1, 2], [3, 4]]), + Transpose(Matrix([[1, 2], [3, 4]]))), + Matrix([[1], [0]]))), + ((r"(\begin{pmatrix}a & b \\x & y\end{pmatrix}+" + r"\begin{pmatrix}x & y \\a & b\end{pmatrix})^2"), + _Pow(_MatAdd(Matrix([[a, b], [x, y]]), + Matrix([[x, y], [a, b]])), 2)), + ((r"(\begin{pmatrix}a & b \\x & y\end{pmatrix}+" + r"\begin{pmatrix}x & y \\a & b\end{pmatrix})^T"), + Transpose(_MatAdd(Matrix([[a, b], [x, y]]), + Matrix([[x, y], [a, b]])))), + (r"\overline{\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}}", + _Conjugate(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))) +] + +EVALUATED_MATRIX_EXPRESSION_PAIRS = [ + (r"\det\left(\left[ { \begin{array}{cc}a&b\\x&y\end{array} } \right]\right)", + Matrix([[a, b], [x, y]]).det()), + (r"\det \begin{pmatrix}1&2\\3&4\end{pmatrix}", -2), + (r"\det{\begin{pmatrix}1&2\\3&4\end{pmatrix}}", -2), + (r"\det(\begin{pmatrix}1&2\\3&4\end{pmatrix})", -2), + (r"\det\left(\begin{pmatrix}1&2\\3&4\end{pmatrix}\right)", -2), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}/\begin{vmatrix}a & b \\x & y\end{vmatrix}", + _MatMul(Matrix([[a, b], [x, y]]), _Pow(Matrix([[a, b], [x, y]]).det(), -1))), + (r"\begin{pmatrix}a & b \\x & y\end{pmatrix}/|\begin{matrix}a & b \\x & y\end{matrix}|", + _MatMul(Matrix([[a, b], [x, y]]), _Pow(Matrix([[a, b], [x, y]]).det(), -1))), + (r"\frac{\begin{pmatrix}a & b \\x & y\end{pmatrix}}{| { \begin{matrix}a & b \\x & y\end{matrix} } |}", + _MatMul(Matrix([[a, b], [x, y]]), _Pow(Matrix([[a, b], [x, y]]).det(), -1))), + (r"\overline{\begin{pmatrix}\imaginaryunit & 1+\imaginaryunit \\-\imaginaryunit & 4\end{pmatrix}}", + Matrix([[-I, 1-I], [I, 4]])), + (r"\begin{pmatrix}\imaginaryunit & 1+\imaginaryunit \\-\imaginaryunit & 4\end{pmatrix}^H", + Matrix([[-I, I], [1-I, 4]])), + (r"\trace(\begin{pmatrix}\imaginaryunit & 1+\imaginaryunit \\-\imaginaryunit & 4\end{pmatrix})", + Trace(Matrix([[I, 1+I], [-I, 4]]))), + (r"\adjugate(\begin{pmatrix}1 & 2 \\3 & 4\end{pmatrix})", + Matrix([[4, -2], [-3, 1]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^\ast", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\ast}", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\ast\ast}", + Matrix([[2*I, 4], [6, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\ast\ast\ast}", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{*}", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{**}", + Matrix([[2*I, 4], [6, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{***}", + Matrix([[-2*I, 6], [4, 8]])), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^\prime", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\prime}", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\prime\prime}", + _MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]]))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{\prime\prime\prime}", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{'}", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{''}", + _MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]]))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^{'''}", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})'", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})''", + _MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]]))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})'''", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"\det(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})", + (_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]]))).det()), + (r"\trace(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})", + Trace(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"\adjugate(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})", + (Matrix([[8, -4], [-6, 2*I]]))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^T", + Transpose(_MatAdd(Matrix([[I, 2], [3, 4]]), + Matrix([[I, 2], [3, 4]])))), + (r"(\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix}+\begin{pmatrix}\imaginaryunit&2\\3&4\end{pmatrix})^H", + (Matrix([[-2*I, 6], [4, 8]]))) +] + + +def test_symbol_expressions(): + expected_failures = {6, 7} + for i, (latex_str, sympy_expr) in enumerate(SYMBOL_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_simple_expressions(): + expected_failures = {20} + for i, (latex_str, sympy_expr) in enumerate(UNEVALUATED_SIMPLE_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for i, (latex_str, sympy_expr) in enumerate(EVALUATED_SIMPLE_EXPRESSION_PAIRS): + if i in expected_failures: + continue + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_fraction_expressions(): + for latex_str, sympy_expr in UNEVALUATED_FRACTION_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_FRACTION_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_relation_expressions(): + for latex_str, sympy_expr in RELATION_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + +def test_power_expressions(): + expected_failures = {3} + for i, (latex_str, sympy_expr) in enumerate(UNEVALUATED_POWER_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for i, (latex_str, sympy_expr) in enumerate(EVALUATED_POWER_EXPRESSION_PAIRS): + if i in expected_failures: + continue + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_integral_expressions(): + expected_failures = {14} + for i, (latex_str, sympy_expr) in enumerate(UNEVALUATED_INTEGRAL_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, i + + for i, (latex_str, sympy_expr) in enumerate(EVALUATED_INTEGRAL_EXPRESSION_PAIRS): + if i in expected_failures: + continue + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_derivative_expressions(): + expected_failures = {3, 4} + for i, (latex_str, sympy_expr) in enumerate(DERIVATIVE_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for i, (latex_str, sympy_expr) in enumerate(DERIVATIVE_EXPRESSION_PAIRS): + if i in expected_failures: + continue + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_trigonometric_expressions(): + expected_failures = {3} + for i, (latex_str, sympy_expr) in enumerate(TRIGONOMETRIC_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_limit_expressions(): + for latex_str, sympy_expr in UNEVALUATED_LIMIT_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_square_root_expressions(): + for latex_str, sympy_expr in UNEVALUATED_SQRT_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_SQRT_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_factorial_expressions(): + for latex_str, sympy_expr in UNEVALUATED_FACTORIAL_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_FACTORIAL_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_sum_expressions(): + for latex_str, sympy_expr in UNEVALUATED_SUM_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_SUM_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_product_expressions(): + for latex_str, sympy_expr in UNEVALUATED_PRODUCT_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + +@XFAIL +def test_applied_function_expressions(): + expected_failures = {0, 3, 4} # 0 is ambiguous, and the others require not-yet-added features + # not sure why 1, and 2 are failing + for i, (latex_str, sympy_expr) in enumerate(APPLIED_FUNCTION_EXPRESSION_PAIRS): + if i in expected_failures: + continue + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_common_function_expressions(): + for latex_str, sympy_expr in UNEVALUATED_COMMON_FUNCTION_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_COMMON_FUNCTION_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +# unhandled bug causing these to fail +@XFAIL +def test_spacing(): + for latex_str, sympy_expr in SPACING_RELATED_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_binomial_expressions(): + for latex_str, sympy_expr in UNEVALUATED_BINOMIAL_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_BINOMIAL_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_miscellaneous_expressions(): + for latex_str, sympy_expr in MISCELLANEOUS_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_literal_complex_number_expressions(): + for latex_str, sympy_expr in UNEVALUATED_LITERAL_COMPLEX_NUMBER_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + +def test_matrix_expressions(): + for latex_str, sympy_expr in UNEVALUATED_MATRIX_EXPRESSION_PAIRS: + with evaluate(False): + assert parse_latex_lark(latex_str) == sympy_expr, latex_str + + for latex_str, sympy_expr in EVALUATED_MATRIX_EXPRESSION_PAIRS: + assert parse_latex_lark(latex_str) == sympy_expr, latex_str diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/test_mathematica.py b/phivenv/Lib/site-packages/sympy/parsing/tests/test_mathematica.py new file mode 100644 index 0000000000000000000000000000000000000000..df193b6d61f9c82778d8e0a40b893cbe6cb8f06a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/tests/test_mathematica.py @@ -0,0 +1,280 @@ +from sympy import sin, Function, symbols, Dummy, Lambda, cos +from sympy.parsing.mathematica import parse_mathematica, MathematicaParser +from sympy.core.sympify import sympify +from sympy.abc import n, w, x, y, z +from sympy.testing.pytest import raises + + +def test_mathematica(): + d = { + '- 6x': '-6*x', + 'Sin[x]^2': 'sin(x)**2', + '2(x-1)': '2*(x-1)', + '3y+8': '3*y+8', + 'ArcSin[2x+9(4-x)^2]/x': 'asin(2*x+9*(4-x)**2)/x', + 'x+y': 'x+y', + '355/113': '355/113', + '2.718281828': '2.718281828', + 'Cos(1/2 * π)': 'Cos(π/2)', + 'Sin[12]': 'sin(12)', + 'Exp[Log[4]]': 'exp(log(4))', + '(x+1)(x+3)': '(x+1)*(x+3)', + 'Cos[ArcCos[3.6]]': 'cos(acos(3.6))', + 'Cos[x]==Sin[y]': 'Eq(cos(x), sin(y))', + '2*Sin[x+y]': '2*sin(x+y)', + 'Sin[x]+Cos[y]': 'sin(x)+cos(y)', + 'Sin[Cos[x]]': 'sin(cos(x))', + '2*Sqrt[x+y]': '2*sqrt(x+y)', # Test case from the issue 4259 + '+Sqrt[2]': 'sqrt(2)', + '-Sqrt[2]': '-sqrt(2)', + '-1/Sqrt[2]': '-1/sqrt(2)', + '-(1/Sqrt[3])': '-(1/sqrt(3))', + '1/(2*Sqrt[5])': '1/(2*sqrt(5))', + 'Mod[5,3]': 'Mod(5,3)', + '-Mod[5,3]': '-Mod(5,3)', + '(x+1)y': '(x+1)*y', + 'x(y+1)': 'x*(y+1)', + 'Sin[x]Cos[y]': 'sin(x)*cos(y)', + 'Sin[x]^2Cos[y]^2': 'sin(x)**2*cos(y)**2', + 'Cos[x]^2(1 - Cos[y]^2)': 'cos(x)**2*(1-cos(y)**2)', + 'x y': 'x*y', + 'x y': 'x*y', + '2 x': '2*x', + 'x 8': 'x*8', + '2 8': '2*8', + '4.x': '4.*x', + '4. 3': '4.*3', + '4. 3.': '4.*3.', + '1 2 3': '1*2*3', + ' - 2 * Sqrt[ 2 3 * ( 1 + 5 ) ] ': '-2*sqrt(2*3*(1+5))', + 'Log[2,4]': 'log(4,2)', + 'Log[Log[2,4],4]': 'log(4,log(4,2))', + 'Exp[Sqrt[2]^2Log[2, 8]]': 'exp(sqrt(2)**2*log(8,2))', + 'ArcSin[Cos[0]]': 'asin(cos(0))', + 'Log2[16]': 'log(16,2)', + 'Max[1,-2,3,-4]': 'Max(1,-2,3,-4)', + 'Min[1,-2,3]': 'Min(1,-2,3)', + 'Exp[I Pi/2]': 'exp(I*pi/2)', + 'ArcTan[x,y]': 'atan2(y,x)', + 'Pochhammer[x,y]': 'rf(x,y)', + 'ExpIntegralEi[x]': 'Ei(x)', + 'SinIntegral[x]': 'Si(x)', + 'CosIntegral[x]': 'Ci(x)', + 'AiryAi[x]': 'airyai(x)', + 'AiryAiPrime[5]': 'airyaiprime(5)', + 'AiryBi[x]': 'airybi(x)', + 'AiryBiPrime[7]': 'airybiprime(7)', + 'LogIntegral[4]': ' li(4)', + 'PrimePi[7]': 'primepi(7)', + 'Prime[5]': 'prime(5)', + 'PrimeQ[5]': 'isprime(5)', + 'Rational[2,19]': 'Rational(2,19)', # test case for issue 25716 + } + + for e in d: + assert parse_mathematica(e) == sympify(d[e]) + + # The parsed form of this expression should not evaluate the Lambda object: + assert parse_mathematica("Sin[#]^2 + Cos[#]^2 &[x]") == sin(x)**2 + cos(x)**2 + + d1, d2, d3 = symbols("d1:4", cls=Dummy) + assert parse_mathematica("Sin[#] + Cos[#3] &").dummy_eq(Lambda((d1, d2, d3), sin(d1) + cos(d3))) + assert parse_mathematica("Sin[#^2] &").dummy_eq(Lambda(d1, sin(d1**2))) + assert parse_mathematica("Function[x, x^3]") == Lambda(x, x**3) + assert parse_mathematica("Function[{x, y}, x^2 + y^2]") == Lambda((x, y), x**2 + y**2) + + +def test_parser_mathematica_tokenizer(): + parser = MathematicaParser() + + chain = lambda expr: parser._from_tokens_to_fullformlist(parser._from_mathematica_to_tokens(expr)) + + # Basic patterns + assert chain("x") == "x" + assert chain("42") == "42" + assert chain(".2") == ".2" + assert chain("+x") == "x" + assert chain("-1") == "-1" + assert chain("- 3") == "-3" + assert chain("α") == "α" + assert chain("+Sin[x]") == ["Sin", "x"] + assert chain("-Sin[x]") == ["Times", "-1", ["Sin", "x"]] + assert chain("x(a+1)") == ["Times", "x", ["Plus", "a", "1"]] + assert chain("(x)") == "x" + assert chain("(+x)") == "x" + assert chain("-a") == ["Times", "-1", "a"] + assert chain("(-x)") == ["Times", "-1", "x"] + assert chain("(x + y)") == ["Plus", "x", "y"] + assert chain("3 + 4") == ["Plus", "3", "4"] + assert chain("a - 3") == ["Plus", "a", "-3"] + assert chain("a - b") == ["Plus", "a", ["Times", "-1", "b"]] + assert chain("7 * 8") == ["Times", "7", "8"] + assert chain("a + b*c") == ["Plus", "a", ["Times", "b", "c"]] + assert chain("a + b* c* d + 2 * e") == ["Plus", "a", ["Times", "b", "c", "d"], ["Times", "2", "e"]] + assert chain("a / b") == ["Times", "a", ["Power", "b", "-1"]] + + # Missing asterisk (*) patterns: + assert chain("x y") == ["Times", "x", "y"] + assert chain("3 4") == ["Times", "3", "4"] + assert chain("a[b] c") == ["Times", ["a", "b"], "c"] + assert chain("(x) (y)") == ["Times", "x", "y"] + assert chain("3 (a)") == ["Times", "3", "a"] + assert chain("(a) b") == ["Times", "a", "b"] + assert chain("4.2") == "4.2" + assert chain("4 2") == ["Times", "4", "2"] + assert chain("4 2") == ["Times", "4", "2"] + assert chain("3 . 4") == ["Dot", "3", "4"] + assert chain("4. 2") == ["Times", "4.", "2"] + assert chain("x.y") == ["Dot", "x", "y"] + assert chain("4.y") == ["Times", "4.", "y"] + assert chain("4 .y") == ["Dot", "4", "y"] + assert chain("x.4") == ["Times", "x", ".4"] + assert chain("x0.3") == ["Times", "x0", ".3"] + assert chain("x. 4") == ["Dot", "x", "4"] + + # Comments + assert chain("a (* +b *) + c") == ["Plus", "a", "c"] + assert chain("a (* + b *) + (**)c (* +d *) + e") == ["Plus", "a", "c", "e"] + assert chain("""a + (* + + b + *) c + (* d + *) e + """) == ["Plus", "a", "c", "e"] + + # Operators couples + and -, * and / are mutually associative: + # (i.e. expression gets flattened when mixing these operators) + assert chain("a*b/c") == ["Times", "a", "b", ["Power", "c", "-1"]] + assert chain("a/b*c") == ["Times", "a", ["Power", "b", "-1"], "c"] + assert chain("a+b-c") == ["Plus", "a", "b", ["Times", "-1", "c"]] + assert chain("a-b+c") == ["Plus", "a", ["Times", "-1", "b"], "c"] + assert chain("-a + b -c ") == ["Plus", ["Times", "-1", "a"], "b", ["Times", "-1", "c"]] + assert chain("a/b/c*d") == ["Times", "a", ["Power", "b", "-1"], ["Power", "c", "-1"], "d"] + assert chain("a/b/c") == ["Times", "a", ["Power", "b", "-1"], ["Power", "c", "-1"]] + assert chain("a-b-c") == ["Plus", "a", ["Times", "-1", "b"], ["Times", "-1", "c"]] + assert chain("1/a") == ["Times", "1", ["Power", "a", "-1"]] + assert chain("1/a/b") == ["Times", "1", ["Power", "a", "-1"], ["Power", "b", "-1"]] + assert chain("-1/a*b") == ["Times", "-1", ["Power", "a", "-1"], "b"] + + # Enclosures of various kinds, i.e. ( ) [ ] [[ ]] { } + assert chain("(a + b) + c") == ["Plus", ["Plus", "a", "b"], "c"] + assert chain(" a + (b + c) + d ") == ["Plus", "a", ["Plus", "b", "c"], "d"] + assert chain("a * (b + c)") == ["Times", "a", ["Plus", "b", "c"]] + assert chain("a b (c d)") == ["Times", "a", "b", ["Times", "c", "d"]] + assert chain("{a, b, 2, c}") == ["List", "a", "b", "2", "c"] + assert chain("{a, {b, c}}") == ["List", "a", ["List", "b", "c"]] + assert chain("{{a}}") == ["List", ["List", "a"]] + assert chain("a[b, c]") == ["a", "b", "c"] + assert chain("a[[b, c]]") == ["Part", "a", "b", "c"] + assert chain("a[b[c]]") == ["a", ["b", "c"]] + assert chain("a[[b, c[[d, {e,f}]]]]") == ["Part", "a", "b", ["Part", "c", "d", ["List", "e", "f"]]] + assert chain("a[b[[c,d]]]") == ["a", ["Part", "b", "c", "d"]] + assert chain("a[[b[c]]]") == ["Part", "a", ["b", "c"]] + assert chain("a[[b[[c]]]]") == ["Part", "a", ["Part", "b", "c"]] + assert chain("a[[b[c[[d]]]]]") == ["Part", "a", ["b", ["Part", "c", "d"]]] + assert chain("a[b[[c[d]]]]") == ["a", ["Part", "b", ["c", "d"]]] + assert chain("x[[a+1, b+2, c+3]]") == ["Part", "x", ["Plus", "a", "1"], ["Plus", "b", "2"], ["Plus", "c", "3"]] + assert chain("x[a+1, b+2, c+3]") == ["x", ["Plus", "a", "1"], ["Plus", "b", "2"], ["Plus", "c", "3"]] + assert chain("{a+1, b+2, c+3}") == ["List", ["Plus", "a", "1"], ["Plus", "b", "2"], ["Plus", "c", "3"]] + + # Flat operator: + assert chain("a*b*c*d*e") == ["Times", "a", "b", "c", "d", "e"] + assert chain("a +b + c+ d+e") == ["Plus", "a", "b", "c", "d", "e"] + + # Right priority operator: + assert chain("a^b") == ["Power", "a", "b"] + assert chain("a^b^c") == ["Power", "a", ["Power", "b", "c"]] + assert chain("a^b^c^d") == ["Power", "a", ["Power", "b", ["Power", "c", "d"]]] + + # Left priority operator: + assert chain("a/.b") == ["ReplaceAll", "a", "b"] + assert chain("a/.b/.c/.d") == ["ReplaceAll", ["ReplaceAll", ["ReplaceAll", "a", "b"], "c"], "d"] + + assert chain("a//b") == ["a", "b"] + assert chain("a//b//c") == [["a", "b"], "c"] + assert chain("a//b//c//d") == [[["a", "b"], "c"], "d"] + + # Compound expressions + assert chain("a;b") == ["CompoundExpression", "a", "b"] + assert chain("a;") == ["CompoundExpression", "a", "Null"] + assert chain("a;b;") == ["CompoundExpression", "a", "b", "Null"] + assert chain("a[b;c]") == ["a", ["CompoundExpression", "b", "c"]] + assert chain("a[b,c;d,e]") == ["a", "b", ["CompoundExpression", "c", "d"], "e"] + assert chain("a[b,c;,d]") == ["a", "b", ["CompoundExpression", "c", "Null"], "d"] + + # New lines + assert chain("a\nb\n") == ["CompoundExpression", "a", "b"] + assert chain("a\n\nb\n (c \nd) \n") == ["CompoundExpression", "a", "b", ["Times", "c", "d"]] + assert chain("\na; b\nc") == ["CompoundExpression", "a", "b", "c"] + assert chain("a + \nb\n") == ["Plus", "a", "b"] + assert chain("a\nb; c; d\n e; (f \n g); h + \n i") == ["CompoundExpression", "a", "b", "c", "d", "e", ["Times", "f", "g"], ["Plus", "h", "i"]] + assert chain("\n{\na\nb; c; d\n e (f \n g); h + \n i\n\n}\n") == ["List", ["CompoundExpression", ["Times", "a", "b"], "c", ["Times", "d", "e", ["Times", "f", "g"]], ["Plus", "h", "i"]]] + + # Patterns + assert chain("y_") == ["Pattern", "y", ["Blank"]] + assert chain("y_.") == ["Optional", ["Pattern", "y", ["Blank"]]] + assert chain("y__") == ["Pattern", "y", ["BlankSequence"]] + assert chain("y___") == ["Pattern", "y", ["BlankNullSequence"]] + assert chain("a[b_.,c_]") == ["a", ["Optional", ["Pattern", "b", ["Blank"]]], ["Pattern", "c", ["Blank"]]] + assert chain("b_. c") == ["Times", ["Optional", ["Pattern", "b", ["Blank"]]], "c"] + + # Slots for lambda functions + assert chain("#") == ["Slot", "1"] + assert chain("#3") == ["Slot", "3"] + assert chain("#n") == ["Slot", "n"] + assert chain("##") == ["SlotSequence", "1"] + assert chain("##a") == ["SlotSequence", "a"] + + # Lambda functions + assert chain("x&") == ["Function", "x"] + assert chain("#&") == ["Function", ["Slot", "1"]] + assert chain("#+3&") == ["Function", ["Plus", ["Slot", "1"], "3"]] + assert chain("#1 + #2&") == ["Function", ["Plus", ["Slot", "1"], ["Slot", "2"]]] + assert chain("# + #&") == ["Function", ["Plus", ["Slot", "1"], ["Slot", "1"]]] + assert chain("#&[x]") == [["Function", ["Slot", "1"]], "x"] + assert chain("#1 + #2 & [x, y]") == [["Function", ["Plus", ["Slot", "1"], ["Slot", "2"]]], "x", "y"] + assert chain("#1^2#2^3&") == ["Function", ["Times", ["Power", ["Slot", "1"], "2"], ["Power", ["Slot", "2"], "3"]]] + + # Strings inside Mathematica expressions: + assert chain('"abc"') == ["_Str", "abc"] + assert chain('"a\\"b"') == ["_Str", 'a"b'] + # This expression does not make sense mathematically, it's just testing the parser: + assert chain('x + "abc" ^ 3') == ["Plus", "x", ["Power", ["_Str", "abc"], "3"]] + assert chain('"a (* b *) c"') == ["_Str", "a (* b *) c"] + assert chain('"a" (* b *) ') == ["_Str", "a"] + assert chain('"a [ b] "') == ["_Str", "a [ b] "] + raises(SyntaxError, lambda: chain('"')) + raises(SyntaxError, lambda: chain('"\\"')) + raises(SyntaxError, lambda: chain('"abc')) + raises(SyntaxError, lambda: chain('"abc\\"def')) + + # Invalid expressions: + raises(SyntaxError, lambda: chain("(,")) + raises(SyntaxError, lambda: chain("()")) + raises(SyntaxError, lambda: chain("a (* b")) + + +def test_parser_mathematica_exp_alt(): + parser = MathematicaParser() + + convert_chain2 = lambda expr: parser._from_fullformlist_to_fullformsympy(parser._from_fullform_to_fullformlist(expr)) + convert_chain3 = lambda expr: parser._from_fullformsympy_to_sympy(convert_chain2(expr)) + + Sin, Times, Plus, Power = symbols("Sin Times Plus Power", cls=Function) + + full_form1 = "Sin[Times[x, y]]" + full_form2 = "Plus[Times[x, y], z]" + full_form3 = "Sin[Times[x, Plus[y, z], Power[w, n]]]]" + full_form4 = "Rational[Rational[x, y], z]" + + assert parser._from_fullform_to_fullformlist(full_form1) == ["Sin", ["Times", "x", "y"]] + assert parser._from_fullform_to_fullformlist(full_form2) == ["Plus", ["Times", "x", "y"], "z"] + assert parser._from_fullform_to_fullformlist(full_form3) == ["Sin", ["Times", "x", ["Plus", "y", "z"], ["Power", "w", "n"]]] + assert parser._from_fullform_to_fullformlist(full_form4) == ["Rational", ["Rational", "x", "y"], "z"] + + assert convert_chain2(full_form1) == Sin(Times(x, y)) + assert convert_chain2(full_form2) == Plus(Times(x, y), z) + assert convert_chain2(full_form3) == Sin(Times(x, Plus(y, z), Power(w, n))) + + assert convert_chain3(full_form1) == sin(x*y) + assert convert_chain3(full_form2) == x*y + z + assert convert_chain3(full_form3) == sin(x*(y + z)*w**n) diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/test_maxima.py b/phivenv/Lib/site-packages/sympy/parsing/tests/test_maxima.py new file mode 100644 index 0000000000000000000000000000000000000000..c0bc1db8f1385ed52e8c677a1bcc759f5118d01e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/tests/test_maxima.py @@ -0,0 +1,50 @@ +from sympy.parsing.maxima import parse_maxima +from sympy.core.numbers import (E, Rational, oo) +from sympy.core.symbol import Symbol +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.abc import x + +n = Symbol('n', integer=True) + + +def test_parser(): + assert Abs(parse_maxima('float(1/3)') - 0.333333333) < 10**(-5) + assert parse_maxima('13^26') == 91733330193268616658399616009 + assert parse_maxima('sin(%pi/2) + cos(%pi/3)') == Rational(3, 2) + assert parse_maxima('log(%e)') == 1 + + +def test_injection(): + parse_maxima('c: x+1', globals=globals()) + # c created by parse_maxima + assert c == x + 1 # noqa:F821 + + parse_maxima('g: sqrt(81)', globals=globals()) + # g created by parse_maxima + assert g == 9 # noqa:F821 + + +def test_maxima_functions(): + assert parse_maxima('expand( (x+1)^2)') == x**2 + 2*x + 1 + assert parse_maxima('factor( x**2 + 2*x + 1)') == (x + 1)**2 + assert parse_maxima('2*cos(x)^2 + sin(x)^2') == 2*cos(x)**2 + sin(x)**2 + assert parse_maxima('trigexpand(sin(2*x)+cos(2*x))') == \ + -1 + 2*cos(x)**2 + 2*cos(x)*sin(x) + assert parse_maxima('solve(x^2-4,x)') == [-2, 2] + assert parse_maxima('limit((1+1/x)^x,x,inf)') == E + assert parse_maxima('limit(sqrt(-x)/x,x,0,minus)') is -oo + assert parse_maxima('diff(x^x, x)') == x**x*(1 + log(x)) + assert parse_maxima('sum(k, k, 1, n)', name_dict={ + "n": Symbol('n', integer=True), + "k": Symbol('k', integer=True) + }) == (n**2 + n)/2 + assert parse_maxima('product(k, k, 1, n)', name_dict={ + "n": Symbol('n', integer=True), + "k": Symbol('k', integer=True) + }) == factorial(n) + assert parse_maxima('ratsimp((x^2-1)/(x+1))') == x - 1 + assert Abs( parse_maxima( + 'float(sec(%pi/3) + csc(%pi/3))') - 3.154700538379252) < 10**(-5) diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/test_sym_expr.py b/phivenv/Lib/site-packages/sympy/parsing/tests/test_sym_expr.py new file mode 100644 index 0000000000000000000000000000000000000000..99912805db381b96e7f41a348fe6f90d71adf781 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/tests/test_sym_expr.py @@ -0,0 +1,209 @@ +from sympy.parsing.sym_expr import SymPyExpression +from sympy.testing.pytest import raises +from sympy.external import import_module + +lfortran = import_module('lfortran') +cin = import_module('clang.cindex', import_kwargs = {'fromlist': ['cindex']}) + +if lfortran and cin: + from sympy.codegen.ast import (Variable, IntBaseType, FloatBaseType, String, + Declaration, FloatType) + from sympy.core import Integer, Float + from sympy.core.symbol import Symbol + + expr1 = SymPyExpression() + src = """\ + integer :: a, b, c, d + real :: p, q, r, s + """ + + def test_c_parse(): + src1 = """\ + int a, b = 4; + float c, d = 2.4; + """ + expr1.convert_to_expr(src1, 'c') + ls = expr1.return_expr() + + assert ls[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('intc')) + ) + ) + assert ls[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('intc')), + value=Integer(4) + ) + ) + assert ls[2] == Declaration( + Variable( + Symbol('c'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ) + ) + ) + assert ls[3] == Declaration( + Variable( + Symbol('d'), + type=FloatType( + String('float32'), + nbits=Integer(32), + nmant=Integer(23), + nexp=Integer(8) + ), + value=Float('2.3999999999999999', precision=53) + ) + ) + + + def test_fortran_parse(): + expr = SymPyExpression(src, 'f') + ls = expr.return_expr() + + assert ls[0] == Declaration( + Variable( + Symbol('a'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ) + assert ls[1] == Declaration( + Variable( + Symbol('b'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ) + assert ls[2] == Declaration( + Variable( + Symbol('c'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ) + assert ls[3] == Declaration( + Variable( + Symbol('d'), + type=IntBaseType(String('integer')), + value=Integer(0) + ) + ) + assert ls[4] == Declaration( + Variable( + Symbol('p'), + type=FloatBaseType(String('real')), + value=Float('0.0', precision=53) + ) + ) + assert ls[5] == Declaration( + Variable( + Symbol('q'), + type=FloatBaseType(String('real')), + value=Float('0.0', precision=53) + ) + ) + assert ls[6] == Declaration( + Variable( + Symbol('r'), + type=FloatBaseType(String('real')), + value=Float('0.0', precision=53) + ) + ) + assert ls[7] == Declaration( + Variable( + Symbol('s'), + type=FloatBaseType(String('real')), + value=Float('0.0', precision=53) + ) + ) + + + def test_convert_py(): + src1 = ( + src + + """\ + a = b + c + s = p * q / r + """ + ) + expr1.convert_to_expr(src1, 'f') + exp_py = expr1.convert_to_python() + assert exp_py == [ + 'a = 0', + 'b = 0', + 'c = 0', + 'd = 0', + 'p = 0.0', + 'q = 0.0', + 'r = 0.0', + 's = 0.0', + 'a = b + c', + 's = p*q/r' + ] + + + def test_convert_fort(): + src1 = ( + src + + """\ + a = b + c + s = p * q / r + """ + ) + expr1.convert_to_expr(src1, 'f') + exp_fort = expr1.convert_to_fortran() + assert exp_fort == [ + ' integer*4 a', + ' integer*4 b', + ' integer*4 c', + ' integer*4 d', + ' real*8 p', + ' real*8 q', + ' real*8 r', + ' real*8 s', + ' a = b + c', + ' s = p*q/r' + ] + + + def test_convert_c(): + src1 = ( + src + + """\ + a = b + c + s = p * q / r + """ + ) + expr1.convert_to_expr(src1, 'f') + exp_c = expr1.convert_to_c() + assert exp_c == [ + 'int a = 0', + 'int b = 0', + 'int c = 0', + 'int d = 0', + 'double p = 0.0', + 'double q = 0.0', + 'double r = 0.0', + 'double s = 0.0', + 'a = b + c;', + 's = p*q/r;' + ] + + + def test_exceptions(): + src = 'int a;' + raises(ValueError, lambda: SymPyExpression(src)) + raises(ValueError, lambda: SymPyExpression(mode = 'c')) + raises(NotImplementedError, lambda: SymPyExpression(src, mode = 'd')) + +elif not lfortran and not cin: + def test_raise(): + raises(ImportError, lambda: SymPyExpression('int a;', 'c')) + raises(ImportError, lambda: SymPyExpression('integer :: a', 'f')) diff --git a/phivenv/Lib/site-packages/sympy/parsing/tests/test_sympy_parser.py b/phivenv/Lib/site-packages/sympy/parsing/tests/test_sympy_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..43ecccbe262ffb4093248d891aa7423c8f62c628 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/parsing/tests/test_sympy_parser.py @@ -0,0 +1,371 @@ +# -*- coding: utf-8 -*- + + +import builtins +import types + +from sympy.assumptions import Q +from sympy.core import Symbol, Function, Float, Rational, Integer, I, Mul, Pow, Eq, Lt, Le, Gt, Ge, Ne +from sympy.functions import exp, factorial, factorial2, sin, Min, Max +from sympy.logic import And +from sympy.series import Limit +from sympy.testing.pytest import raises + +from sympy.parsing.sympy_parser import ( + parse_expr, standard_transformations, rationalize, TokenError, + split_symbols, implicit_multiplication, convert_equals_signs, + convert_xor, function_exponentiation, lambda_notation, auto_symbol, + repeated_decimals, implicit_multiplication_application, + auto_number, factorial_notation, implicit_application, + _transformation, T + ) + + +def test_sympy_parser(): + x = Symbol('x') + inputs = { + '2*x': 2 * x, + '3.00': Float(3), + '22/7': Rational(22, 7), + '2+3j': 2 + 3*I, + 'exp(x)': exp(x), + 'x!': factorial(x), + 'x!!': factorial2(x), + '(x + 1)! - 1': factorial(x + 1) - 1, + '3.[3]': Rational(10, 3), + '.0[3]': Rational(1, 30), + '3.2[3]': Rational(97, 30), + '1.3[12]': Rational(433, 330), + '1 + 3.[3]': Rational(13, 3), + '1 + .0[3]': Rational(31, 30), + '1 + 3.2[3]': Rational(127, 30), + '.[0011]': Rational(1, 909), + '0.1[00102] + 1': Rational(366697, 333330), + '1.[0191]': Rational(10190, 9999), + '10!': 3628800, + '-(2)': -Integer(2), + '[-1, -2, 3]': [Integer(-1), Integer(-2), Integer(3)], + 'Symbol("x").free_symbols': x.free_symbols, + "S('S(3).n(n=3)')": Float(3, 3), + 'factorint(12, visual=True)': Mul( + Pow(2, 2, evaluate=False), + Pow(3, 1, evaluate=False), + evaluate=False), + 'Limit(sin(x), x, 0, dir="-")': Limit(sin(x), x, 0, dir='-'), + 'Q.even(x)': Q.even(x), + + + } + for text, result in inputs.items(): + assert parse_expr(text) == result + + raises(TypeError, lambda: + parse_expr('x', standard_transformations)) + raises(TypeError, lambda: + parse_expr('x', transformations=lambda x,y: 1)) + raises(TypeError, lambda: + parse_expr('x', transformations=(lambda x,y: 1,))) + raises(TypeError, lambda: parse_expr('x', transformations=((),))) + raises(TypeError, lambda: parse_expr('x', {}, [], [])) + raises(TypeError, lambda: parse_expr('x', [], [], {})) + raises(TypeError, lambda: parse_expr('x', [], [], {})) + + +def test_rationalize(): + inputs = { + '0.123': Rational(123, 1000) + } + transformations = standard_transformations + (rationalize,) + for text, result in inputs.items(): + assert parse_expr(text, transformations=transformations) == result + + +def test_factorial_fail(): + inputs = ['x!!!', 'x!!!!', '(!)'] + + + for text in inputs: + try: + parse_expr(text) + assert False + except TokenError: + assert True + + +def test_repeated_fail(): + inputs = ['1[1]', '.1e1[1]', '0x1[1]', '1.1j[1]', '1.1[1 + 1]', + '0.1[[1]]', '0x1.1[1]'] + + + # All are valid Python, so only raise TypeError for invalid indexing + for text in inputs: + raises(TypeError, lambda: parse_expr(text)) + + + inputs = ['0.1[', '0.1[1', '0.1[]'] + for text in inputs: + raises((TokenError, SyntaxError), lambda: parse_expr(text)) + + +def test_repeated_dot_only(): + assert parse_expr('.[1]') == Rational(1, 9) + assert parse_expr('1 + .[1]') == Rational(10, 9) + + +def test_local_dict(): + local_dict = { + 'my_function': lambda x: x + 2 + } + inputs = { + 'my_function(2)': Integer(4) + } + for text, result in inputs.items(): + assert parse_expr(text, local_dict=local_dict) == result + + +def test_local_dict_split_implmult(): + t = standard_transformations + (split_symbols, implicit_multiplication,) + w = Symbol('w', real=True) + y = Symbol('y') + assert parse_expr('yx', local_dict={'x':w}, transformations=t) == y*w + + +def test_local_dict_symbol_to_fcn(): + x = Symbol('x') + d = {'foo': Function('bar')} + assert parse_expr('foo(x)', local_dict=d) == d['foo'](x) + d = {'foo': Symbol('baz')} + raises(TypeError, lambda: parse_expr('foo(x)', local_dict=d)) + + +def test_global_dict(): + global_dict = { + 'Symbol': Symbol + } + inputs = { + 'Q & S': And(Symbol('Q'), Symbol('S')) + } + for text, result in inputs.items(): + assert parse_expr(text, global_dict=global_dict) == result + + +def test_no_globals(): + + # Replicate creating the default global_dict: + default_globals = {} + exec('from sympy import *', default_globals) + builtins_dict = vars(builtins) + for name, obj in builtins_dict.items(): + if isinstance(obj, types.BuiltinFunctionType): + default_globals[name] = obj + default_globals['max'] = Max + default_globals['min'] = Min + + # Need to include Symbol or parse_expr will not work: + default_globals.pop('Symbol') + global_dict = {'Symbol':Symbol} + + for name in default_globals: + obj = parse_expr(name, global_dict=global_dict) + assert obj == Symbol(name) + + +def test_issue_2515(): + raises(TokenError, lambda: parse_expr('(()')) + raises(TokenError, lambda: parse_expr('"""')) + + +def test_issue_7663(): + x = Symbol('x') + e = '2*(x+1)' + assert parse_expr(e, evaluate=False) == parse_expr(e, evaluate=False) + assert parse_expr(e, evaluate=False).equals(2*(x+1)) + +def test_recursive_evaluate_false_10560(): + inputs = { + '4*-3' : '4*-3', + '-4*3' : '(-4)*3', + "-2*x*y": '(-2)*x*y', + "x*-4*x": "x*(-4)*x" + } + for text, result in inputs.items(): + assert parse_expr(text, evaluate=False) == parse_expr(result, evaluate=False) + + +def test_function_evaluate_false(): + inputs = [ + 'Abs(0)', 'im(0)', 're(0)', 'sign(0)', 'arg(0)', 'conjugate(0)', + 'acos(0)', 'acot(0)', 'acsc(0)', 'asec(0)', 'asin(0)', 'atan(0)', + 'acosh(0)', 'acoth(0)', 'acsch(0)', 'asech(0)', 'asinh(0)', 'atanh(0)', + 'cos(0)', 'cot(0)', 'csc(0)', 'sec(0)', 'sin(0)', 'tan(0)', + 'cosh(0)', 'coth(0)', 'csch(0)', 'sech(0)', 'sinh(0)', 'tanh(0)', + 'exp(0)', 'log(0)', 'sqrt(0)', + ] + for case in inputs: + expr = parse_expr(case, evaluate=False) + assert case == str(expr) != str(expr.doit()) + assert str(parse_expr('ln(0)', evaluate=False)) == 'log(0)' + assert str(parse_expr('cbrt(0)', evaluate=False)) == '0**(1/3)' + + +def test_issue_10773(): + inputs = { + '-10/5': '(-10)/5', + '-10/-5' : '(-10)/(-5)', + } + for text, result in inputs.items(): + assert parse_expr(text, evaluate=False) == parse_expr(result, evaluate=False) + + +def test_split_symbols(): + transformations = standard_transformations + \ + (split_symbols, implicit_multiplication,) + x = Symbol('x') + y = Symbol('y') + xy = Symbol('xy') + + + assert parse_expr("xy") == xy + assert parse_expr("xy", transformations=transformations) == x*y + + +def test_split_symbols_function(): + transformations = standard_transformations + \ + (split_symbols, implicit_multiplication,) + x = Symbol('x') + y = Symbol('y') + a = Symbol('a') + f = Function('f') + + + assert parse_expr("ay(x+1)", transformations=transformations) == a*y*(x+1) + assert parse_expr("af(x+1)", transformations=transformations, + local_dict={'f':f}) == a*f(x+1) + + +def test_functional_exponent(): + t = standard_transformations + (convert_xor, function_exponentiation) + x = Symbol('x') + y = Symbol('y') + a = Symbol('a') + yfcn = Function('y') + assert parse_expr("sin^2(x)", transformations=t) == (sin(x))**2 + assert parse_expr("sin^y(x)", transformations=t) == (sin(x))**y + assert parse_expr("exp^y(x)", transformations=t) == (exp(x))**y + assert parse_expr("E^y(x)", transformations=t) == exp(yfcn(x)) + assert parse_expr("a^y(x)", transformations=t) == a**(yfcn(x)) + + +def test_match_parentheses_implicit_multiplication(): + transformations = standard_transformations + \ + (implicit_multiplication,) + raises(TokenError, lambda: parse_expr('(1,2),(3,4]',transformations=transformations)) + + +def test_convert_equals_signs(): + transformations = standard_transformations + \ + (convert_equals_signs, ) + x = Symbol('x') + y = Symbol('y') + assert parse_expr("1*2=x", transformations=transformations) == Eq(2, x) + assert parse_expr("y = x", transformations=transformations) == Eq(y, x) + assert parse_expr("(2*y = x) = False", + transformations=transformations) == Eq(Eq(2*y, x), False) + + +def test_parse_function_issue_3539(): + x = Symbol('x') + f = Function('f') + assert parse_expr('f(x)') == f(x) + +def test_issue_24288(): + assert parse_expr("1 < 2", evaluate=False) == Lt(1, 2, evaluate=False) + assert parse_expr("1 <= 2", evaluate=False) == Le(1, 2, evaluate=False) + assert parse_expr("1 > 2", evaluate=False) == Gt(1, 2, evaluate=False) + assert parse_expr("1 >= 2", evaluate=False) == Ge(1, 2, evaluate=False) + assert parse_expr("1 != 2", evaluate=False) == Ne(1, 2, evaluate=False) + assert parse_expr("1 == 2", evaluate=False) == Eq(1, 2, evaluate=False) + assert parse_expr("1 < 2 < 3", evaluate=False) == And(Lt(1, 2, evaluate=False), Lt(2, 3, evaluate=False), evaluate=False) + assert parse_expr("1 <= 2 <= 3", evaluate=False) == And(Le(1, 2, evaluate=False), Le(2, 3, evaluate=False), evaluate=False) + assert parse_expr("1 < 2 <= 3 < 4", evaluate=False) == \ + And(Lt(1, 2, evaluate=False), Le(2, 3, evaluate=False), Lt(3, 4, evaluate=False), evaluate=False) + # Valid Python relational operators that SymPy does not decide how to handle them yet + raises(ValueError, lambda: parse_expr("1 in 2", evaluate=False)) + raises(ValueError, lambda: parse_expr("1 is 2", evaluate=False)) + raises(ValueError, lambda: parse_expr("1 not in 2", evaluate=False)) + raises(ValueError, lambda: parse_expr("1 is not 2", evaluate=False)) + +def test_split_symbols_numeric(): + transformations = ( + standard_transformations + + (implicit_multiplication_application,)) + + n = Symbol('n') + expr1 = parse_expr('2**n * 3**n') + expr2 = parse_expr('2**n3**n', transformations=transformations) + assert expr1 == expr2 == 2**n*3**n + + expr1 = parse_expr('n12n34', transformations=transformations) + assert expr1 == n*12*n*34 + + +def test_unicode_names(): + assert parse_expr('α') == Symbol('α') + + +def test_python3_features(): + assert parse_expr("123_456") == 123456 + assert parse_expr("1.2[3_4]") == parse_expr("1.2[34]") == Rational(611, 495) + assert parse_expr("1.2[012_012]") == parse_expr("1.2[012012]") == Rational(400, 333) + assert parse_expr('.[3_4]') == parse_expr('.[34]') == Rational(34, 99) + assert parse_expr('.1[3_4]') == parse_expr('.1[34]') == Rational(133, 990) + assert parse_expr('123_123.123_123[3_4]') == parse_expr('123123.123123[34]') == Rational(12189189189211, 99000000) + + +def test_issue_19501(): + x = Symbol('x') + eq = parse_expr('E**x(1+x)', local_dict={'x': x}, transformations=( + standard_transformations + + (implicit_multiplication_application,))) + assert eq.free_symbols == {x} + + +def test_parsing_definitions(): + from sympy.abc import x + assert len(_transformation) == 12 # if this changes, extend below + assert _transformation[0] == lambda_notation + assert _transformation[1] == auto_symbol + assert _transformation[2] == repeated_decimals + assert _transformation[3] == auto_number + assert _transformation[4] == factorial_notation + assert _transformation[5] == implicit_multiplication_application + assert _transformation[6] == convert_xor + assert _transformation[7] == implicit_application + assert _transformation[8] == implicit_multiplication + assert _transformation[9] == convert_equals_signs + assert _transformation[10] == function_exponentiation + assert _transformation[11] == rationalize + assert T[:5] == T[0,1,2,3,4] == standard_transformations + t = _transformation + assert T[-1, 0] == (t[len(t) - 1], t[0]) + assert T[:5, 8] == standard_transformations + (t[8],) + assert parse_expr('0.3x^2', transformations='all') == 3*x**2/10 + assert parse_expr('sin 3x', transformations='implicit') == sin(3*x) + + +def test_builtins(): + cases = [ + ('abs(x)', 'Abs(x)'), + ('max(x, y)', 'Max(x, y)'), + ('min(x, y)', 'Min(x, y)'), + ('pow(x, y)', 'Pow(x, y)'), + ] + for built_in_func_call, sympy_func_call in cases: + assert parse_expr(built_in_func_call) == parse_expr(sympy_func_call) + assert str(parse_expr('pow(38, -1, 97)')) == '23' + + +def test_issue_22822(): + raises(ValueError, lambda: parse_expr('x', {'': 1})) + data = {'some_parameter': None} + assert parse_expr('some_parameter is None', data) is True diff --git a/phivenv/Lib/site-packages/sympy/physics/__init__.py b/phivenv/Lib/site-packages/sympy/physics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..60989896ae8b3f69efc7d2350add8f6f19d85669 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/__init__.py @@ -0,0 +1,12 @@ +""" +A module that helps solving problems in physics. +""" + +from . import units +from .matrices import mgamma, msigma, minkowski_tensor, mdft + +__all__ = [ + 'units', + + 'mgamma', 'msigma', 'minkowski_tensor', 'mdft', +] diff --git a/phivenv/Lib/site-packages/sympy/physics/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..34630913f7913df139eae9374cf9f84aaf5c249a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/__pycache__/hydrogen.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/__pycache__/hydrogen.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fdf29667457ea821b002cdfa4cabdeb168b5c21 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/__pycache__/hydrogen.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/__pycache__/matrices.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/__pycache__/matrices.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b5950efde853f1964a056b2d014fc70b9d21aad Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/__pycache__/matrices.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/__pycache__/paulialgebra.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/__pycache__/paulialgebra.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab7ee5701d15bd4a31fc3d165704646d4ab2a5d7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/__pycache__/paulialgebra.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/__pycache__/pring.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/__pycache__/pring.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9fb30200b8b520f6d8d8096a5b71e0263c4603b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/__pycache__/pring.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/__pycache__/qho_1d.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/__pycache__/qho_1d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b878088a49254eb9d14094443f828cee2823eabe Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/__pycache__/qho_1d.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/__pycache__/secondquant.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/__pycache__/secondquant.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03c72f13ca32a2008cb6360989a5aaa1d9343a44 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/__pycache__/secondquant.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/__pycache__/sho.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/__pycache__/sho.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce8e093bf5ad85555ffce718296dbb77ec89fbb2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/__pycache__/sho.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/__pycache__/wigner.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/__pycache__/wigner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f69792d749c6a038eb8b1c8c3ed6efe7708b592 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/__pycache__/wigner.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/__init__.py b/phivenv/Lib/site-packages/sympy/physics/biomechanics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e0f687cc23c1862b65e55117841cfd7d2b8e3f0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/biomechanics/__init__.py @@ -0,0 +1,53 @@ +"""Biomechanics extension for SymPy. + +Includes biomechanics-related constructs which allows users to extend multibody +models created using `sympy.physics.mechanics` into biomechanical or +musculoskeletal models involding musculotendons and activation dynamics. + +""" + +from .activation import ( + ActivationBase, + FirstOrderActivationDeGroote2016, + ZerothOrderActivation, +) +from .curve import ( + CharacteristicCurveCollection, + CharacteristicCurveFunction, + FiberForceLengthActiveDeGroote2016, + FiberForceLengthPassiveDeGroote2016, + FiberForceLengthPassiveInverseDeGroote2016, + FiberForceVelocityDeGroote2016, + FiberForceVelocityInverseDeGroote2016, + TendonForceLengthDeGroote2016, + TendonForceLengthInverseDeGroote2016, +) +from .musculotendon import ( + MusculotendonBase, + MusculotendonDeGroote2016, + MusculotendonFormulation, +) + + +__all__ = [ + # Musculotendon characteristic curve functions + 'CharacteristicCurveCollection', + 'CharacteristicCurveFunction', + 'FiberForceLengthActiveDeGroote2016', + 'FiberForceLengthPassiveDeGroote2016', + 'FiberForceLengthPassiveInverseDeGroote2016', + 'FiberForceVelocityDeGroote2016', + 'FiberForceVelocityInverseDeGroote2016', + 'TendonForceLengthDeGroote2016', + 'TendonForceLengthInverseDeGroote2016', + + # Activation dynamics classes + 'ActivationBase', + 'FirstOrderActivationDeGroote2016', + 'ZerothOrderActivation', + + # Musculotendon classes + 'MusculotendonBase', + 'MusculotendonDeGroote2016', + 'MusculotendonFormulation', +] diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..615a7434655b48cec9ddcd3bfc45cd8edf3251de Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/_mixin.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/_mixin.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..074be01114a16877563b1953c7a5ac5b13d4881a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/_mixin.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/activation.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/activation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..baa9d16228a7305457b992d376ef83f2dadd4650 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/activation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/curve.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/curve.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fa0954ed17f045dfbc566b4225e42a4388f8f73 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/curve.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/musculotendon.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/musculotendon.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20766e88a514580cdff2bb96feb8355a534d0fad Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/biomechanics/__pycache__/musculotendon.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/_mixin.py b/phivenv/Lib/site-packages/sympy/physics/biomechanics/_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ff905100fb4d6f346aaf717cfe9a66b4c2cc9a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/biomechanics/_mixin.py @@ -0,0 +1,53 @@ +"""Mixin classes for sharing functionality between unrelated classes. + +This module is named with a leading underscore to signify to users that it's +"private" and only intended for internal use by the biomechanics module. + +""" + + +__all__ = ['_NamedMixin'] + + +class _NamedMixin: + """Mixin class for adding `name` properties. + + Valid names, as will typically be used by subclasses as a suffix when + naming automatically-instantiated symbol attributes, must be nonzero length + strings. + + Attributes + ========== + + name : str + The name identifier associated with the instance. Must be a string of + length at least 1. + + """ + + @property + def name(self) -> str: + """The name associated with the class instance.""" + return self._name + + @name.setter + def name(self, name: str) -> None: + if hasattr(self, '_name'): + msg = ( + f'Can\'t set attribute `name` to {repr(name)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + if not isinstance(name, str): + msg = ( + f'Name {repr(name)} passed to `name` was of type ' + f'{type(name)}, must be {str}.' + ) + raise TypeError(msg) + if name in {''}: + msg = ( + f'Name {repr(name)} is invalid, must be a nonzero length ' + f'{type(str)}.' + ) + raise ValueError(msg) + self._name = name diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/activation.py b/phivenv/Lib/site-packages/sympy/physics/biomechanics/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..908d9bd2e7b433f91ef6678426c2e4896ab82f27 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/biomechanics/activation.py @@ -0,0 +1,869 @@ +r"""Activation dynamics for musclotendon models. + +Musculotendon models are able to produce active force when they are activated, +which is when a chemical process has taken place within the muscle fibers +causing them to voluntarily contract. Biologically this chemical process (the +diffusion of :math:`\textrm{Ca}^{2+}` ions) is not the input in the system, +electrical signals from the nervous system are. These are termed excitations. +Activation dynamics, which relates the normalized excitation level to the +normalized activation level, can be modeled by the models present in this +module. + +""" + +from abc import ABC, abstractmethod +from functools import cached_property + +from sympy.core.symbol import Symbol +from sympy.core.numbers import Float, Integer, Rational +from sympy.functions.elementary.hyperbolic import tanh +from sympy.matrices.dense import MutableDenseMatrix as Matrix, zeros +from sympy.physics.biomechanics._mixin import _NamedMixin +from sympy.physics.mechanics import dynamicsymbols + + +__all__ = [ + 'ActivationBase', + 'FirstOrderActivationDeGroote2016', + 'ZerothOrderActivation', +] + + +class ActivationBase(ABC, _NamedMixin): + """Abstract base class for all activation dynamics classes to inherit from. + + Notes + ===== + + Instances of this class cannot be directly instantiated by users. However, + it can be used to created custom activation dynamics types through + subclassing. + + """ + + def __init__(self, name): + """Initializer for ``ActivationBase``.""" + self.name = str(name) + + # Symbols + self._e = dynamicsymbols(f"e_{name}") + self._a = dynamicsymbols(f"a_{name}") + + @classmethod + @abstractmethod + def with_defaults(cls, name): + """Alternate constructor that provides recommended defaults for + constants.""" + pass + + @property + def excitation(self): + """Dynamic symbol representing excitation. + + Explanation + =========== + + The alias ``e`` can also be used to access the same attribute. + + """ + return self._e + + @property + def e(self): + """Dynamic symbol representing excitation. + + Explanation + =========== + + The alias ``excitation`` can also be used to access the same attribute. + + """ + return self._e + + @property + def activation(self): + """Dynamic symbol representing activation. + + Explanation + =========== + + The alias ``a`` can also be used to access the same attribute. + + """ + return self._a + + @property + def a(self): + """Dynamic symbol representing activation. + + Explanation + =========== + + The alias ``activation`` can also be used to access the same attribute. + + """ + return self._a + + @property + @abstractmethod + def order(self): + """Order of the (differential) equation governing activation.""" + pass + + @property + @abstractmethod + def state_vars(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``x`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def x(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``state_vars`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def input_vars(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``r`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def r(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``input_vars`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def constants(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + The alias ``p`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def p(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + The alias ``constants`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def M(self): + """Ordered square matrix of coefficients on the LHS of ``M x' = F``. + + Explanation + =========== + + The square matrix that forms part of the LHS of the linear system of + ordinary differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + pass + + @property + @abstractmethod + def F(self): + """Ordered column matrix of equations on the RHS of ``M x' = F``. + + Explanation + =========== + + The column matrix that forms the RHS of the linear system of ordinary + differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + pass + + @abstractmethod + def rhs(self): + """ + + Explanation + =========== + + The solution to the linear system of ordinary differential equations + governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + pass + + def __eq__(self, other): + """Equality check for activation dynamics.""" + if type(self) != type(other): + return False + if self.name != other.name: + return False + return True + + def __repr__(self): + """Default representation of activation dynamics.""" + return f'{self.__class__.__name__}({self.name!r})' + + +class ZerothOrderActivation(ActivationBase): + """Simple zeroth-order activation dynamics mapping excitation to + activation. + + Explanation + =========== + + Zeroth-order activation dynamics are useful in instances where you want to + reduce the complexity of your musculotendon dynamics as they simple map + exictation to activation. As a result, no additional state equations are + introduced to your system. They also remove a potential source of delay + between the input and dynamics of your system as no (ordinary) differential + equations are involved. + + """ + + def __init__(self, name): + """Initializer for ``ZerothOrderActivation``. + + Parameters + ========== + + name : str + The name identifier associated with the instance. Must be a string + of length at least 1. + + """ + super().__init__(name) + + # Zeroth-order activation dynamics has activation equal excitation so + # overwrite the symbol for activation with the excitation symbol. + self._a = self._e + + @classmethod + def with_defaults(cls, name): + """Alternate constructor that provides recommended defaults for + constants. + + Explanation + =========== + + As this concrete class doesn't implement any constants associated with + its dynamics, this ``classmethod`` simply creates a standard instance + of ``ZerothOrderActivation``. An implementation is provided to ensure + a consistent interface between all ``ActivationBase`` concrete classes. + + """ + return cls(name) + + @property + def order(self): + """Order of the (differential) equation governing activation.""" + return 0 + + @property + def state_vars(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + As zeroth-order activation dynamics simply maps excitation to + activation, this class has no associated state variables and so this + property return an empty column ``Matrix`` with shape (0, 1). + + The alias ``x`` can also be used to access the same attribute. + + """ + return zeros(0, 1) + + @property + def x(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + As zeroth-order activation dynamics simply maps excitation to + activation, this class has no associated state variables and so this + property return an empty column ``Matrix`` with shape (0, 1). + + The alias ``state_vars`` can also be used to access the same attribute. + + """ + return zeros(0, 1) + + @property + def input_vars(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + Excitation is the only input in zeroth-order activation dynamics and so + this property returns a column ``Matrix`` with one entry, ``e``, and + shape (1, 1). + + The alias ``r`` can also be used to access the same attribute. + + """ + return Matrix([self._e]) + + @property + def r(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + Excitation is the only input in zeroth-order activation dynamics and so + this property returns a column ``Matrix`` with one entry, ``e``, and + shape (1, 1). + + The alias ``input_vars`` can also be used to access the same attribute. + + """ + return Matrix([self._e]) + + @property + def constants(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + As zeroth-order activation dynamics simply maps excitation to + activation, this class has no associated constants and so this property + return an empty column ``Matrix`` with shape (0, 1). + + The alias ``p`` can also be used to access the same attribute. + + """ + return zeros(0, 1) + + @property + def p(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + As zeroth-order activation dynamics simply maps excitation to + activation, this class has no associated constants and so this property + return an empty column ``Matrix`` with shape (0, 1). + + The alias ``constants`` can also be used to access the same attribute. + + """ + return zeros(0, 1) + + @property + def M(self): + """Ordered square matrix of coefficients on the LHS of ``M x' = F``. + + Explanation + =========== + + The square matrix that forms part of the LHS of the linear system of + ordinary differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear system has dimension 0 and therefore ``M`` is an empty square + ``Matrix`` with shape (0, 0). + + """ + return Matrix([]) + + @property + def F(self): + """Ordered column matrix of equations on the RHS of ``M x' = F``. + + Explanation + =========== + + The column matrix that forms the RHS of the linear system of ordinary + differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear system has dimension 0 and therefore ``F`` is an empty column + ``Matrix`` with shape (0, 1). + + """ + return zeros(0, 1) + + def rhs(self): + """Ordered column matrix of equations for the solution of ``M x' = F``. + + Explanation + =========== + + The solution to the linear system of ordinary differential equations + governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear has dimension 0 and therefore this method returns an empty + column ``Matrix`` with shape (0, 1). + + """ + return zeros(0, 1) + + +class FirstOrderActivationDeGroote2016(ActivationBase): + r"""First-order activation dynamics based on De Groote et al., 2016 [1]_. + + Explanation + =========== + + Gives the first-order activation dynamics equation for the rate of change + of activation with respect to time as a function of excitation and + activation. + + The function is defined by the equation: + + .. math:: + + \frac{da}{dt} = \left(\frac{\frac{1}{2} + a0}{\tau_a \left(\frac{1}{2} + + \frac{3a}{2}\right)} + \frac{\left(\frac{1}{2} + + \frac{3a}{2}\right) \left(\frac{1}{2} - a0\right)}{\tau_d}\right) + \left(e - a\right) + + where + + .. math:: + + a0 = \frac{\tanh{\left(b \left(e - a\right) \right)}}{2} + + with constant values of :math:`tau_a = 0.015`, :math:`tau_d = 0.060`, and + :math:`b = 10`. + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + def __init__(self, + name, + activation_time_constant=None, + deactivation_time_constant=None, + smoothing_rate=None, + ): + """Initializer for ``FirstOrderActivationDeGroote2016``. + + Parameters + ========== + activation time constant : Symbol | Number | None + The value of the activation time constant governing the delay + between excitation and activation when excitation exceeds + activation. + deactivation time constant : Symbol | Number | None + The value of the deactivation time constant governing the delay + between excitation and activation when activation exceeds + excitation. + smoothing_rate : Symbol | Number | None + The slope of the hyperbolic tangent function used to smooth between + the switching of the equations where excitation exceed activation + and where activation exceeds excitation. The recommended value to + use is ``10``, but values between ``0.1`` and ``100`` can be used. + + """ + super().__init__(name) + + # Symbols + self.activation_time_constant = activation_time_constant + self.deactivation_time_constant = deactivation_time_constant + self.smoothing_rate = smoothing_rate + + @classmethod + def with_defaults(cls, name): + r"""Alternate constructor that will use the published constants. + + Explanation + =========== + + Returns an instance of ``FirstOrderActivationDeGroote2016`` using the + three constant values specified in the original publication. + + These have the values: + + :math:`tau_a = 0.015` + :math:`tau_d = 0.060` + :math:`b = 10` + + """ + tau_a = Float('0.015') + tau_d = Float('0.060') + b = Float('10.0') + return cls(name, tau_a, tau_d, b) + + @property + def activation_time_constant(self): + """Delay constant for activation. + + Explanation + =========== + + The alias ```tau_a`` can also be used to access the same attribute. + + """ + return self._tau_a + + @activation_time_constant.setter + def activation_time_constant(self, tau_a): + if hasattr(self, '_tau_a'): + msg = ( + f'Can\'t set attribute `activation_time_constant` to ' + f'{repr(tau_a)} as it is immutable and already has value ' + f'{self._tau_a}.' + ) + raise AttributeError(msg) + self._tau_a = Symbol(f'tau_a_{self.name}') if tau_a is None else tau_a + + @property + def tau_a(self): + """Delay constant for activation. + + Explanation + =========== + + The alias ``activation_time_constant`` can also be used to access the + same attribute. + + """ + return self._tau_a + + @property + def deactivation_time_constant(self): + """Delay constant for deactivation. + + Explanation + =========== + + The alias ``tau_d`` can also be used to access the same attribute. + + """ + return self._tau_d + + @deactivation_time_constant.setter + def deactivation_time_constant(self, tau_d): + if hasattr(self, '_tau_d'): + msg = ( + f'Can\'t set attribute `deactivation_time_constant` to ' + f'{repr(tau_d)} as it is immutable and already has value ' + f'{self._tau_d}.' + ) + raise AttributeError(msg) + self._tau_d = Symbol(f'tau_d_{self.name}') if tau_d is None else tau_d + + @property + def tau_d(self): + """Delay constant for deactivation. + + Explanation + =========== + + The alias ``deactivation_time_constant`` can also be used to access the + same attribute. + + """ + return self._tau_d + + @property + def smoothing_rate(self): + """Smoothing constant for the hyperbolic tangent term. + + Explanation + =========== + + The alias ``b`` can also be used to access the same attribute. + + """ + return self._b + + @smoothing_rate.setter + def smoothing_rate(self, b): + if hasattr(self, '_b'): + msg = ( + f'Can\'t set attribute `smoothing_rate` to {b!r} as it is ' + f'immutable and already has value {self._b!r}.' + ) + raise AttributeError(msg) + self._b = Symbol(f'b_{self.name}') if b is None else b + + @property + def b(self): + """Smoothing constant for the hyperbolic tangent term. + + Explanation + =========== + + The alias ``smoothing_rate`` can also be used to access the same + attribute. + + """ + return self._b + + @property + def order(self): + """Order of the (differential) equation governing activation.""" + return 1 + + @property + def state_vars(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``x`` can also be used to access the same attribute. + + """ + return Matrix([self._a]) + + @property + def x(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``state_vars`` can also be used to access the same attribute. + + """ + return Matrix([self._a]) + + @property + def input_vars(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``r`` can also be used to access the same attribute. + + """ + return Matrix([self._e]) + + @property + def r(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``input_vars`` can also be used to access the same attribute. + + """ + return Matrix([self._e]) + + @property + def constants(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + The alias ``p`` can also be used to access the same attribute. + + """ + constants = [self._tau_a, self._tau_d, self._b] + symbolic_constants = [c for c in constants if not c.is_number] + return Matrix(symbolic_constants) if symbolic_constants else zeros(0, 1) + + @property + def p(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Explanation + =========== + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + The alias ``constants`` can also be used to access the same attribute. + + """ + constants = [self._tau_a, self._tau_d, self._b] + symbolic_constants = [c for c in constants if not c.is_number] + return Matrix(symbolic_constants) if symbolic_constants else zeros(0, 1) + + @property + def M(self): + """Ordered square matrix of coefficients on the LHS of ``M x' = F``. + + Explanation + =========== + + The square matrix that forms part of the LHS of the linear system of + ordinary differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + return Matrix([Integer(1)]) + + @property + def F(self): + """Ordered column matrix of equations on the RHS of ``M x' = F``. + + Explanation + =========== + + The column matrix that forms the RHS of the linear system of ordinary + differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + return Matrix([self._da_eqn]) + + def rhs(self): + """Ordered column matrix of equations for the solution of ``M x' = F``. + + Explanation + =========== + + The solution to the linear system of ordinary differential equations + governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + return Matrix([self._da_eqn]) + + @cached_property + def _da_eqn(self): + HALF = Rational(1, 2) + a0 = HALF * tanh(self._b * (self._e - self._a)) + a1 = (HALF + Rational(3, 2) * self._a) + a2 = (HALF + a0) / (self._tau_a * a1) + a3 = a1 * (HALF - a0) / self._tau_d + activation_dynamics_equation = (a2 + a3) * (self._e - self._a) + return activation_dynamics_equation + + def __eq__(self, other): + """Equality check for ``FirstOrderActivationDeGroote2016``.""" + if type(self) != type(other): + return False + self_attrs = (self.name, self.tau_a, self.tau_d, self.b) + other_attrs = (other.name, other.tau_a, other.tau_d, other.b) + if self_attrs == other_attrs: + return True + return False + + def __repr__(self): + """Representation of ``FirstOrderActivationDeGroote2016``.""" + return ( + f'{self.__class__.__name__}({self.name!r}, ' + f'activation_time_constant={self.tau_a!r}, ' + f'deactivation_time_constant={self.tau_d!r}, ' + f'smoothing_rate={self.b!r})' + ) diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/curve.py b/phivenv/Lib/site-packages/sympy/physics/biomechanics/curve.py new file mode 100644 index 0000000000000000000000000000000000000000..50535271f51493acc2183d257ce89ff0da4dde5e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/biomechanics/curve.py @@ -0,0 +1,1763 @@ +"""Implementations of characteristic curves for musculotendon models.""" + +from dataclasses import dataclass + +from sympy.core.expr import UnevaluatedExpr +from sympy.core.function import ArgumentIndexError, Function +from sympy.core.numbers import Float, Integer +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.hyperbolic import cosh, sinh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.printing.precedence import PRECEDENCE + + +__all__ = [ + 'CharacteristicCurveCollection', + 'CharacteristicCurveFunction', + 'FiberForceLengthActiveDeGroote2016', + 'FiberForceLengthPassiveDeGroote2016', + 'FiberForceLengthPassiveInverseDeGroote2016', + 'FiberForceVelocityDeGroote2016', + 'FiberForceVelocityInverseDeGroote2016', + 'TendonForceLengthDeGroote2016', + 'TendonForceLengthInverseDeGroote2016', +] + + +class CharacteristicCurveFunction(Function): + """Base class for all musculotendon characteristic curve functions.""" + + @classmethod + def eval(cls): + msg = ( + f'Cannot directly instantiate {cls.__name__!r}, instances of ' + f'characteristic curves must be of a concrete subclass.' + + ) + raise TypeError(msg) + + def _print_code(self, printer): + """Print code for the function defining the curve using a printer. + + Explanation + =========== + + The order of operations may need to be controlled as constant folding + the numeric terms within the equations of a musculotendon + characteristic curve can sometimes results in a numerically-unstable + expression. + + Parameters + ========== + + printer : Printer + The printer to be used to print a string representation of the + characteristic curve as valid code in the target language. + + """ + return printer._print(printer.parenthesize( + self.doit(deep=False, evaluate=False), PRECEDENCE['Atom'], + )) + + _ccode = _print_code + _cupycode = _print_code + _cxxcode = _print_code + _fcode = _print_code + _jaxcode = _print_code + _lambdacode = _print_code + _mpmathcode = _print_code + _octave = _print_code + _pythoncode = _print_code + _numpycode = _print_code + _scipycode = _print_code + + +class TendonForceLengthDeGroote2016(CharacteristicCurveFunction): + r"""Tendon force-length curve based on De Groote et al., 2016 [1]_. + + Explanation + =========== + + Gives the normalized tendon force produced as a function of normalized + tendon length. + + The function is defined by the equation: + + $fl^T = c_0 \exp{c_3 \left( \tilde{l}^T - c_1 \right)} - c_2$ + + with constant values of $c_0 = 0.2$, $c_1 = 0.995$, $c_2 = 0.25$, and + $c_3 = 33.93669377311689$. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces no + force when the tendon is in an unstrained state. It also produces a force + of 1 normalized unit when the tendon is under a 5% strain. + + Examples + ======== + + The preferred way to instantiate :class:`TendonForceLengthDeGroote2016` is using + the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized tendon length. We'll create a + :class:`~.Symbol` called ``l_T_tilde`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import TendonForceLengthDeGroote2016 + >>> l_T_tilde = Symbol('l_T_tilde') + >>> fl_T = TendonForceLengthDeGroote2016.with_defaults(l_T_tilde) + >>> fl_T + TendonForceLengthDeGroote2016(l_T_tilde, 0.2, 0.995, 0.25, + 33.93669377311689) + + It's also possible to populate the four constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3 = symbols('c0 c1 c2 c3') + >>> fl_T = TendonForceLengthDeGroote2016(l_T_tilde, c0, c1, c2, c3) + >>> fl_T + TendonForceLengthDeGroote2016(l_T_tilde, c0, c1, c2, c3) + + You don't just have to use symbols as the arguments, it's also possible to + use expressions. Let's create a new pair of symbols, ``l_T`` and + ``l_T_slack``, representing tendon length and tendon slack length + respectively. We can then represent ``l_T_tilde`` as an expression, the + ratio of these. + + >>> l_T, l_T_slack = symbols('l_T l_T_slack') + >>> l_T_tilde = l_T/l_T_slack + >>> fl_T = TendonForceLengthDeGroote2016.with_defaults(l_T_tilde) + >>> fl_T + TendonForceLengthDeGroote2016(l_T/l_T_slack, 0.2, 0.995, 0.25, + 33.93669377311689) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> fl_T.doit(evaluate=False) + -0.25 + 0.2*exp(33.93669377311689*(l_T/l_T_slack - 0.995)) + + The function can also be differentiated. We'll differentiate with respect + to l_T using the ``diff`` method on an instance with the single positional + argument ``l_T``. + + >>> fl_T.diff(l_T) + 6.787338754623378*exp(33.93669377311689*(l_T/l_T_slack - 0.995))/l_T_slack + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, l_T_tilde): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the tendon force-length function using the + four constant values specified in the original publication. + + These have the values: + + $c_0 = 0.2$ + $c_1 = 0.995$ + $c_2 = 0.25$ + $c_3 = 33.93669377311689$ + + Parameters + ========== + + l_T_tilde : Any (sympifiable) + Normalized tendon length. + + """ + c0 = Float('0.2') + c1 = Float('0.995') + c2 = Float('0.25') + c3 = Float('33.93669377311689') + return cls(l_T_tilde, c0, c1, c2, c3) + + @classmethod + def eval(cls, l_T_tilde, c0, c1, c2, c3): + """Evaluation of basic inputs. + + Parameters + ========== + + l_T_tilde : Any (sympifiable) + Normalized tendon length. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.2``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``0.995``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``0.25``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``33.93669377311689``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_T_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + l_T_tilde, *constants = self.args + if deep: + hints['evaluate'] = evaluate + l_T_tilde = l_T_tilde.doit(deep=deep, **hints) + c0, c1, c2, c3 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1, c2, c3 = constants + + if evaluate: + return c0*exp(c3*(l_T_tilde - c1)) - c2 + + return c0*exp(c3*UnevaluatedExpr(l_T_tilde - c1)) - c2 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + l_T_tilde, c0, c1, c2, c3 = self.args + if argindex == 1: + return c0*c3*exp(c3*UnevaluatedExpr(l_T_tilde - c1)) + elif argindex == 2: + return exp(c3*UnevaluatedExpr(l_T_tilde - c1)) + elif argindex == 3: + return -c0*c3*exp(c3*UnevaluatedExpr(l_T_tilde - c1)) + elif argindex == 4: + return Integer(-1) + elif argindex == 5: + return c0*(l_T_tilde - c1)*exp(c3*UnevaluatedExpr(l_T_tilde - c1)) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return TendonForceLengthInverseDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + l_T_tilde = self.args[0] + _l_T_tilde = printer._print(l_T_tilde) + return r'\operatorname{fl}^T \left( %s \right)' % _l_T_tilde + + +class TendonForceLengthInverseDeGroote2016(CharacteristicCurveFunction): + r"""Inverse tendon force-length curve based on De Groote et al., 2016 [1]_. + + Explanation + =========== + + Gives the normalized tendon length that produces a specific normalized + tendon force. + + The function is defined by the equation: + + ${fl^T}^{-1} = frac{\log{\frac{fl^T + c_2}{c_0}}}{c_3} + c_1$ + + with constant values of $c_0 = 0.2$, $c_1 = 0.995$, $c_2 = 0.25$, and + $c_3 = 33.93669377311689$. This function is the exact analytical inverse + of the related tendon force-length curve ``TendonForceLengthDeGroote2016``. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces no + force when the tendon is in an unstrained state. It also produces a force + of 1 normalized unit when the tendon is under a 5% strain. + + Examples + ======== + + The preferred way to instantiate :class:`TendonForceLengthInverseDeGroote2016` is + using the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized tendon force-length, which is + equal to the tendon force. We'll create a :class:`~.Symbol` called ``fl_T`` to + represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import TendonForceLengthInverseDeGroote2016 + >>> fl_T = Symbol('fl_T') + >>> l_T_tilde = TendonForceLengthInverseDeGroote2016.with_defaults(fl_T) + >>> l_T_tilde + TendonForceLengthInverseDeGroote2016(fl_T, 0.2, 0.995, 0.25, + 33.93669377311689) + + It's also possible to populate the four constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3 = symbols('c0 c1 c2 c3') + >>> l_T_tilde = TendonForceLengthInverseDeGroote2016(fl_T, c0, c1, c2, c3) + >>> l_T_tilde + TendonForceLengthInverseDeGroote2016(fl_T, c0, c1, c2, c3) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> l_T_tilde.doit(evaluate=False) + c1 + log((c2 + fl_T)/c0)/c3 + + The function can also be differentiated. We'll differentiate with respect + to l_T using the ``diff`` method on an instance with the single positional + argument ``l_T``. + + >>> l_T_tilde.diff(fl_T) + 1/(c3*(c2 + fl_T)) + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, fl_T): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the inverse tendon force-length function + using the four constant values specified in the original publication. + + These have the values: + + $c_0 = 0.2$ + $c_1 = 0.995$ + $c_2 = 0.25$ + $c_3 = 33.93669377311689$ + + Parameters + ========== + + fl_T : Any (sympifiable) + Normalized tendon force as a function of tendon length. + + """ + c0 = Float('0.2') + c1 = Float('0.995') + c2 = Float('0.25') + c3 = Float('33.93669377311689') + return cls(fl_T, c0, c1, c2, c3) + + @classmethod + def eval(cls, fl_T, c0, c1, c2, c3): + """Evaluation of basic inputs. + + Parameters + ========== + + fl_T : Any (sympifiable) + Normalized tendon force as a function of tendon length. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.2``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``0.995``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``0.25``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``33.93669377311689``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_T_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + fl_T, *constants = self.args + if deep: + hints['evaluate'] = evaluate + fl_T = fl_T.doit(deep=deep, **hints) + c0, c1, c2, c3 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1, c2, c3 = constants + + if evaluate: + return log((fl_T + c2)/c0)/c3 + c1 + + return log(UnevaluatedExpr((fl_T + c2)/c0))/c3 + c1 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + fl_T, c0, c1, c2, c3 = self.args + if argindex == 1: + return 1/(c3*(fl_T + c2)) + elif argindex == 2: + return -1/(c0*c3) + elif argindex == 3: + return Integer(1) + elif argindex == 4: + return 1/(c3*(fl_T + c2)) + elif argindex == 5: + return -log(UnevaluatedExpr((fl_T + c2)/c0))/c3**2 + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return TendonForceLengthDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + fl_T = self.args[0] + _fl_T = printer._print(fl_T) + return r'\left( \operatorname{fl}^T \right)^{-1} \left( %s \right)' % _fl_T + + +class FiberForceLengthPassiveDeGroote2016(CharacteristicCurveFunction): + r"""Passive muscle fiber force-length curve based on De Groote et al., 2016 + [1]_. + + Explanation + =========== + + The function is defined by the equation: + + $fl^M_{pas} = \frac{\frac{\exp{c_1 \left(\tilde{l^M} - 1\right)}}{c_0} - 1}{\exp{c_1} - 1}$ + + with constant values of $c_0 = 0.6$ and $c_1 = 4.0$. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + passive fiber force very close to 0 for all normalized fiber lengths + between 0 and 1. + + Examples + ======== + + The preferred way to instantiate :class:`FiberForceLengthPassiveDeGroote2016` is + using the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized muscle fiber length. We'll + create a :class:`~.Symbol` called ``l_M_tilde`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceLengthPassiveDeGroote2016 + >>> l_M_tilde = Symbol('l_M_tilde') + >>> fl_M = FiberForceLengthPassiveDeGroote2016.with_defaults(l_M_tilde) + >>> fl_M + FiberForceLengthPassiveDeGroote2016(l_M_tilde, 0.6, 4.0) + + It's also possible to populate the two constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1 = symbols('c0 c1') + >>> fl_M = FiberForceLengthPassiveDeGroote2016(l_M_tilde, c0, c1) + >>> fl_M + FiberForceLengthPassiveDeGroote2016(l_M_tilde, c0, c1) + + You don't just have to use symbols as the arguments, it's also possible to + use expressions. Let's create a new pair of symbols, ``l_M`` and + ``l_M_opt``, representing muscle fiber length and optimal muscle fiber + length respectively. We can then represent ``l_M_tilde`` as an expression, + the ratio of these. + + >>> l_M, l_M_opt = symbols('l_M l_M_opt') + >>> l_M_tilde = l_M/l_M_opt + >>> fl_M = FiberForceLengthPassiveDeGroote2016.with_defaults(l_M_tilde) + >>> fl_M + FiberForceLengthPassiveDeGroote2016(l_M/l_M_opt, 0.6, 4.0) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> fl_M.doit(evaluate=False) + 0.0186573603637741*(-1 + exp(6.66666666666667*(l_M/l_M_opt - 1))) + + The function can also be differentiated. We'll differentiate with respect + to l_M using the ``diff`` method on an instance with the single positional + argument ``l_M``. + + >>> fl_M.diff(l_M) + 0.12438240242516*exp(6.66666666666667*(l_M/l_M_opt - 1))/l_M_opt + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, l_M_tilde): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the muscle fiber passive force-length + function using the four constant values specified in the original + publication. + + These have the values: + + $c_0 = 0.6$ + $c_1 = 4.0$ + + Parameters + ========== + + l_M_tilde : Any (sympifiable) + Normalized muscle fiber length. + + """ + c0 = Float('0.6') + c1 = Float('4.0') + return cls(l_M_tilde, c0, c1) + + @classmethod + def eval(cls, l_M_tilde, c0, c1): + """Evaluation of basic inputs. + + Parameters + ========== + + l_M_tilde : Any (sympifiable) + Normalized muscle fiber length. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.6``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``4.0``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_T_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + l_M_tilde, *constants = self.args + if deep: + hints['evaluate'] = evaluate + l_M_tilde = l_M_tilde.doit(deep=deep, **hints) + c0, c1 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1 = constants + + if evaluate: + return (exp((c1*(l_M_tilde - 1))/c0) - 1)/(exp(c1) - 1) + + return (exp((c1*UnevaluatedExpr(l_M_tilde - 1))/c0) - 1)/(exp(c1) - 1) + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + l_M_tilde, c0, c1 = self.args + if argindex == 1: + return c1*exp(c1*UnevaluatedExpr(l_M_tilde - 1)/c0)/(c0*(exp(c1) - 1)) + elif argindex == 2: + return ( + -c1*exp(c1*UnevaluatedExpr(l_M_tilde - 1)/c0) + *UnevaluatedExpr(l_M_tilde - 1)/(c0**2*(exp(c1) - 1)) + ) + elif argindex == 3: + return ( + -exp(c1)*(-1 + exp(c1*UnevaluatedExpr(l_M_tilde - 1)/c0))/(exp(c1) - 1)**2 + + exp(c1*UnevaluatedExpr(l_M_tilde - 1)/c0)*(l_M_tilde - 1)/(c0*(exp(c1) - 1)) + ) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return FiberForceLengthPassiveInverseDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + l_M_tilde = self.args[0] + _l_M_tilde = printer._print(l_M_tilde) + return r'\operatorname{fl}^M_{pas} \left( %s \right)' % _l_M_tilde + + +class FiberForceLengthPassiveInverseDeGroote2016(CharacteristicCurveFunction): + r"""Inverse passive muscle fiber force-length curve based on De Groote et + al., 2016 [1]_. + + Explanation + =========== + + Gives the normalized muscle fiber length that produces a specific normalized + passive muscle fiber force. + + The function is defined by the equation: + + ${fl^M_{pas}}^{-1} = \frac{c_0 \log{\left(\exp{c_1} - 1\right)fl^M_pas + 1}}{c_1} + 1$ + + with constant values of $c_0 = 0.6$ and $c_1 = 4.0$. This function is the + exact analytical inverse of the related tendon force-length curve + ``FiberForceLengthPassiveDeGroote2016``. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + passive fiber force very close to 0 for all normalized fiber lengths + between 0 and 1. + + Examples + ======== + + The preferred way to instantiate + :class:`FiberForceLengthPassiveInverseDeGroote2016` is using the + :meth:`~.with_defaults` constructor because this will automatically populate the + constants within the characteristic curve equation with the floating point + values from the original publication. This constructor takes a single + argument corresponding to the normalized passive muscle fiber length-force + component of the muscle fiber force. We'll create a :class:`~.Symbol` called + ``fl_M_pas`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceLengthPassiveInverseDeGroote2016 + >>> fl_M_pas = Symbol('fl_M_pas') + >>> l_M_tilde = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(fl_M_pas) + >>> l_M_tilde + FiberForceLengthPassiveInverseDeGroote2016(fl_M_pas, 0.6, 4.0) + + It's also possible to populate the two constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1 = symbols('c0 c1') + >>> l_M_tilde = FiberForceLengthPassiveInverseDeGroote2016(fl_M_pas, c0, c1) + >>> l_M_tilde + FiberForceLengthPassiveInverseDeGroote2016(fl_M_pas, c0, c1) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> l_M_tilde.doit(evaluate=False) + c0*log(1 + fl_M_pas*(exp(c1) - 1))/c1 + 1 + + The function can also be differentiated. We'll differentiate with respect + to fl_M_pas using the ``diff`` method on an instance with the single positional + argument ``fl_M_pas``. + + >>> l_M_tilde.diff(fl_M_pas) + c0*(exp(c1) - 1)/(c1*(fl_M_pas*(exp(c1) - 1) + 1)) + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, fl_M_pas): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the inverse muscle fiber passive force-length + function using the four constant values specified in the original + publication. + + These have the values: + + $c_0 = 0.6$ + $c_1 = 4.0$ + + Parameters + ========== + + fl_M_pas : Any (sympifiable) + Normalized passive muscle fiber force as a function of muscle fiber + length. + + """ + c0 = Float('0.6') + c1 = Float('4.0') + return cls(fl_M_pas, c0, c1) + + @classmethod + def eval(cls, fl_M_pas, c0, c1): + """Evaluation of basic inputs. + + Parameters + ========== + + fl_M_pas : Any (sympifiable) + Normalized passive muscle fiber force. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.6``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``4.0``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_T_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + fl_M_pas, *constants = self.args + if deep: + hints['evaluate'] = evaluate + fl_M_pas = fl_M_pas.doit(deep=deep, **hints) + c0, c1 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1 = constants + + if evaluate: + return c0*log(fl_M_pas*(exp(c1) - 1) + 1)/c1 + 1 + + return c0*log(UnevaluatedExpr(fl_M_pas*(exp(c1) - 1)) + 1)/c1 + 1 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + fl_M_pas, c0, c1 = self.args + if argindex == 1: + return c0*(exp(c1) - 1)/(c1*(fl_M_pas*(exp(c1) - 1) + 1)) + elif argindex == 2: + return log(fl_M_pas*(exp(c1) - 1) + 1)/c1 + elif argindex == 3: + return ( + c0*fl_M_pas*exp(c1)/(c1*(fl_M_pas*(exp(c1) - 1) + 1)) + - c0*log(fl_M_pas*(exp(c1) - 1) + 1)/c1**2 + ) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return FiberForceLengthPassiveDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + fl_M_pas = self.args[0] + _fl_M_pas = printer._print(fl_M_pas) + return r'\left( \operatorname{fl}^M_{pas} \right)^{-1} \left( %s \right)' % _fl_M_pas + + +class FiberForceLengthActiveDeGroote2016(CharacteristicCurveFunction): + r"""Active muscle fiber force-length curve based on De Groote et al., 2016 + [1]_. + + Explanation + =========== + + The function is defined by the equation: + + $fl_{\text{act}}^M = c_0 \exp\left(-\frac{1}{2}\left(\frac{\tilde{l}^M - c_1}{c_2 + c_3 \tilde{l}^M}\right)^2\right) + + c_4 \exp\left(-\frac{1}{2}\left(\frac{\tilde{l}^M - c_5}{c_6 + c_7 \tilde{l}^M}\right)^2\right) + + c_8 \exp\left(-\frac{1}{2}\left(\frac{\tilde{l}^M - c_9}{c_{10} + c_{11} \tilde{l}^M}\right)^2\right)$ + + with constant values of $c0 = 0.814$, $c1 = 1.06$, $c2 = 0.162$, + $c3 = 0.0633$, $c4 = 0.433$, $c5 = 0.717$, $c6 = -0.0299$, $c7 = 0.2$, + $c8 = 0.1$, $c9 = 1.0$, $c10 = 0.354$, and $c11 = 0.0$. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + active fiber force of 1 at a normalized fiber length of 1, and an active + fiber force of 0 at normalized fiber lengths of 0 and 2. + + Examples + ======== + + The preferred way to instantiate :class:`FiberForceLengthActiveDeGroote2016` is + using the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized muscle fiber length. We'll + create a :class:`~.Symbol` called ``l_M_tilde`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceLengthActiveDeGroote2016 + >>> l_M_tilde = Symbol('l_M_tilde') + >>> fl_M = FiberForceLengthActiveDeGroote2016.with_defaults(l_M_tilde) + >>> fl_M + FiberForceLengthActiveDeGroote2016(l_M_tilde, 0.814, 1.06, 0.162, 0.0633, + 0.433, 0.717, -0.0299, 0.2, 0.1, 1.0, 0.354, 0.0) + + It's also possible to populate the two constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11 = symbols('c0:12') + >>> fl_M = FiberForceLengthActiveDeGroote2016(l_M_tilde, c0, c1, c2, c3, + ... c4, c5, c6, c7, c8, c9, c10, c11) + >>> fl_M + FiberForceLengthActiveDeGroote2016(l_M_tilde, c0, c1, c2, c3, c4, c5, c6, + c7, c8, c9, c10, c11) + + You don't just have to use symbols as the arguments, it's also possible to + use expressions. Let's create a new pair of symbols, ``l_M`` and + ``l_M_opt``, representing muscle fiber length and optimal muscle fiber + length respectively. We can then represent ``l_M_tilde`` as an expression, + the ratio of these. + + >>> l_M, l_M_opt = symbols('l_M l_M_opt') + >>> l_M_tilde = l_M/l_M_opt + >>> fl_M = FiberForceLengthActiveDeGroote2016.with_defaults(l_M_tilde) + >>> fl_M + FiberForceLengthActiveDeGroote2016(l_M/l_M_opt, 0.814, 1.06, 0.162, 0.0633, + 0.433, 0.717, -0.0299, 0.2, 0.1, 1.0, 0.354, 0.0) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> fl_M.doit(evaluate=False) + 0.814*exp(-(l_M/l_M_opt + - 1.06)**2/(2*(0.0633*l_M/l_M_opt + 0.162)**2)) + + 0.433*exp(-(l_M/l_M_opt - 0.717)**2/(2*(0.2*l_M/l_M_opt - 0.0299)**2)) + + 0.1*exp(-3.98991349867535*(l_M/l_M_opt - 1.0)**2) + + The function can also be differentiated. We'll differentiate with respect + to l_M using the ``diff`` method on an instance with the single positional + argument ``l_M``. + + >>> fl_M.diff(l_M) + ((-0.79798269973507*l_M/l_M_opt + + 0.79798269973507)*exp(-3.98991349867535*(l_M/l_M_opt - 1.0)**2) + + (0.433*(-l_M/l_M_opt + 0.717)/(0.2*l_M/l_M_opt - 0.0299)**2 + + 0.0866*(l_M/l_M_opt - 0.717)**2/(0.2*l_M/l_M_opt + - 0.0299)**3)*exp(-(l_M/l_M_opt - 0.717)**2/(2*(0.2*l_M/l_M_opt - 0.0299)**2)) + + (0.814*(-l_M/l_M_opt + 1.06)/(0.0633*l_M/l_M_opt + + 0.162)**2 + 0.0515262*(l_M/l_M_opt + - 1.06)**2/(0.0633*l_M/l_M_opt + + 0.162)**3)*exp(-(l_M/l_M_opt + - 1.06)**2/(2*(0.0633*l_M/l_M_opt + 0.162)**2)))/l_M_opt + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, l_M_tilde): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the inverse muscle fiber act force-length + function using the four constant values specified in the original + publication. + + These have the values: + + $c0 = 0.814$ + $c1 = 1.06$ + $c2 = 0.162$ + $c3 = 0.0633$ + $c4 = 0.433$ + $c5 = 0.717$ + $c6 = -0.0299$ + $c7 = 0.2$ + $c8 = 0.1$ + $c9 = 1.0$ + $c10 = 0.354$ + $c11 = 0.0$ + + Parameters + ========== + + fl_M_act : Any (sympifiable) + Normalized passive muscle fiber force as a function of muscle fiber + length. + + """ + c0 = Float('0.814') + c1 = Float('1.06') + c2 = Float('0.162') + c3 = Float('0.0633') + c4 = Float('0.433') + c5 = Float('0.717') + c6 = Float('-0.0299') + c7 = Float('0.2') + c8 = Float('0.1') + c9 = Float('1.0') + c10 = Float('0.354') + c11 = Float('0.0') + return cls(l_M_tilde, c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11) + + @classmethod + def eval(cls, l_M_tilde, c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11): + """Evaluation of basic inputs. + + Parameters + ========== + + l_M_tilde : Any (sympifiable) + Normalized muscle fiber length. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.814``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``1.06``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``0.162``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``0.0633``. + c4 : Any (sympifiable) + The fifth constant in the characteristic equation. The published + value is ``0.433``. + c5 : Any (sympifiable) + The sixth constant in the characteristic equation. The published + value is ``0.717``. + c6 : Any (sympifiable) + The seventh constant in the characteristic equation. The published + value is ``-0.0299``. + c7 : Any (sympifiable) + The eighth constant in the characteristic equation. The published + value is ``0.2``. + c8 : Any (sympifiable) + The ninth constant in the characteristic equation. The published + value is ``0.1``. + c9 : Any (sympifiable) + The tenth constant in the characteristic equation. The published + value is ``1.0``. + c10 : Any (sympifiable) + The eleventh constant in the characteristic equation. The published + value is ``0.354``. + c11 : Any (sympifiable) + The tweflth constant in the characteristic equation. The published + value is ``0.0``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_M_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + l_M_tilde, *constants = self.args + if deep: + hints['evaluate'] = evaluate + l_M_tilde = l_M_tilde.doit(deep=deep, **hints) + constants = [c.doit(deep=deep, **hints) for c in constants] + c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11 = constants + + if evaluate: + return ( + c0*exp(-(((l_M_tilde - c1)/(c2 + c3*l_M_tilde))**2)/2) + + c4*exp(-(((l_M_tilde - c5)/(c6 + c7*l_M_tilde))**2)/2) + + c8*exp(-(((l_M_tilde - c9)/(c10 + c11*l_M_tilde))**2)/2) + ) + + return ( + c0*exp(-((UnevaluatedExpr(l_M_tilde - c1)/(c2 + c3*l_M_tilde))**2)/2) + + c4*exp(-((UnevaluatedExpr(l_M_tilde - c5)/(c6 + c7*l_M_tilde))**2)/2) + + c8*exp(-((UnevaluatedExpr(l_M_tilde - c9)/(c10 + c11*l_M_tilde))**2)/2) + ) + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + l_M_tilde, c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11 = self.args + if argindex == 1: + return ( + c0*( + c3*(l_M_tilde - c1)**2/(c2 + c3*l_M_tilde)**3 + + (c1 - l_M_tilde)/((c2 + c3*l_M_tilde)**2) + )*exp(-(l_M_tilde - c1)**2/(2*(c2 + c3*l_M_tilde)**2)) + + c4*( + c7*(l_M_tilde - c5)**2/(c6 + c7*l_M_tilde)**3 + + (c5 - l_M_tilde)/((c6 + c7*l_M_tilde)**2) + )*exp(-(l_M_tilde - c5)**2/(2*(c6 + c7*l_M_tilde)**2)) + + c8*( + c11*(l_M_tilde - c9)**2/(c10 + c11*l_M_tilde)**3 + + (c9 - l_M_tilde)/((c10 + c11*l_M_tilde)**2) + )*exp(-(l_M_tilde - c9)**2/(2*(c10 + c11*l_M_tilde)**2)) + ) + elif argindex == 2: + return exp(-(l_M_tilde - c1)**2/(2*(c2 + c3*l_M_tilde)**2)) + elif argindex == 3: + return ( + c0*(l_M_tilde - c1)/(c2 + c3*l_M_tilde)**2 + *exp(-(l_M_tilde - c1)**2 /(2*(c2 + c3*l_M_tilde)**2)) + ) + elif argindex == 4: + return ( + c0*(l_M_tilde - c1)**2/(c2 + c3*l_M_tilde)**3 + *exp(-(l_M_tilde - c1)**2/(2*(c2 + c3*l_M_tilde)**2)) + ) + elif argindex == 5: + return ( + c0*l_M_tilde*(l_M_tilde - c1)**2/(c2 + c3*l_M_tilde)**3 + *exp(-(l_M_tilde - c1)**2/(2*(c2 + c3*l_M_tilde)**2)) + ) + elif argindex == 6: + return exp(-(l_M_tilde - c5)**2/(2*(c6 + c7*l_M_tilde)**2)) + elif argindex == 7: + return ( + c4*(l_M_tilde - c5)/(c6 + c7*l_M_tilde)**2 + *exp(-(l_M_tilde - c5)**2 /(2*(c6 + c7*l_M_tilde)**2)) + ) + elif argindex == 8: + return ( + c4*(l_M_tilde - c5)**2/(c6 + c7*l_M_tilde)**3 + *exp(-(l_M_tilde - c5)**2/(2*(c6 + c7*l_M_tilde)**2)) + ) + elif argindex == 9: + return ( + c4*l_M_tilde*(l_M_tilde - c5)**2/(c6 + c7*l_M_tilde)**3 + *exp(-(l_M_tilde - c5)**2/(2*(c6 + c7*l_M_tilde)**2)) + ) + elif argindex == 10: + return exp(-(l_M_tilde - c9)**2/(2*(c10 + c11*l_M_tilde)**2)) + elif argindex == 11: + return ( + c8*(l_M_tilde - c9)/(c10 + c11*l_M_tilde)**2 + *exp(-(l_M_tilde - c9)**2 /(2*(c10 + c11*l_M_tilde)**2)) + ) + elif argindex == 12: + return ( + c8*(l_M_tilde - c9)**2/(c10 + c11*l_M_tilde)**3 + *exp(-(l_M_tilde - c9)**2/(2*(c10 + c11*l_M_tilde)**2)) + ) + elif argindex == 13: + return ( + c8*l_M_tilde*(l_M_tilde - c9)**2/(c10 + c11*l_M_tilde)**3 + *exp(-(l_M_tilde - c9)**2/(2*(c10 + c11*l_M_tilde)**2)) + ) + + raise ArgumentIndexError(self, argindex) + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + l_M_tilde = self.args[0] + _l_M_tilde = printer._print(l_M_tilde) + return r'\operatorname{fl}^M_{act} \left( %s \right)' % _l_M_tilde + + +class FiberForceVelocityDeGroote2016(CharacteristicCurveFunction): + r"""Muscle fiber force-velocity curve based on De Groote et al., 2016 [1]_. + + Explanation + =========== + + Gives the normalized muscle fiber force produced as a function of + normalized tendon velocity. + + The function is defined by the equation: + + $fv^M = c_0 \log{\left(c_1 \tilde{v}_m + c_2\right) + \sqrt{\left(c_1 \tilde{v}_m + c_2\right)^2 + 1}} + c_3$ + + with constant values of $c_0 = -0.318$, $c_1 = -8.149$, $c_2 = -0.374$, and + $c_3 = 0.886$. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + normalized muscle fiber force of 1 when the muscle fibers are contracting + isometrically (they have an extension rate of 0). + + Examples + ======== + + The preferred way to instantiate :class:`FiberForceVelocityDeGroote2016` is using + the :meth:`~.with_defaults` constructor because this will automatically populate + the constants within the characteristic curve equation with the floating + point values from the original publication. This constructor takes a single + argument corresponding to normalized muscle fiber extension velocity. We'll + create a :class:`~.Symbol` called ``v_M_tilde`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceVelocityDeGroote2016 + >>> v_M_tilde = Symbol('v_M_tilde') + >>> fv_M = FiberForceVelocityDeGroote2016.with_defaults(v_M_tilde) + >>> fv_M + FiberForceVelocityDeGroote2016(v_M_tilde, -0.318, -8.149, -0.374, 0.886) + + It's also possible to populate the four constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3 = symbols('c0 c1 c2 c3') + >>> fv_M = FiberForceVelocityDeGroote2016(v_M_tilde, c0, c1, c2, c3) + >>> fv_M + FiberForceVelocityDeGroote2016(v_M_tilde, c0, c1, c2, c3) + + You don't just have to use symbols as the arguments, it's also possible to + use expressions. Let's create a new pair of symbols, ``v_M`` and + ``v_M_max``, representing muscle fiber extension velocity and maximum + muscle fiber extension velocity respectively. We can then represent + ``v_M_tilde`` as an expression, the ratio of these. + + >>> v_M, v_M_max = symbols('v_M v_M_max') + >>> v_M_tilde = v_M/v_M_max + >>> fv_M = FiberForceVelocityDeGroote2016.with_defaults(v_M_tilde) + >>> fv_M + FiberForceVelocityDeGroote2016(v_M/v_M_max, -0.318, -8.149, -0.374, 0.886) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> fv_M.doit(evaluate=False) + 0.886 - 0.318*log(-8.149*v_M/v_M_max - 0.374 + sqrt(1 + (-8.149*v_M/v_M_max + - 0.374)**2)) + + The function can also be differentiated. We'll differentiate with respect + to v_M using the ``diff`` method on an instance with the single positional + argument ``v_M``. + + >>> fv_M.diff(v_M) + 2.591382*(1 + (-8.149*v_M/v_M_max - 0.374)**2)**(-1/2)/v_M_max + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, v_M_tilde): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the muscle fiber force-velocity function + using the four constant values specified in the original publication. + + These have the values: + + $c_0 = -0.318$ + $c_1 = -8.149$ + $c_2 = -0.374$ + $c_3 = 0.886$ + + Parameters + ========== + + v_M_tilde : Any (sympifiable) + Normalized muscle fiber extension velocity. + + """ + c0 = Float('-0.318') + c1 = Float('-8.149') + c2 = Float('-0.374') + c3 = Float('0.886') + return cls(v_M_tilde, c0, c1, c2, c3) + + @classmethod + def eval(cls, v_M_tilde, c0, c1, c2, c3): + """Evaluation of basic inputs. + + Parameters + ========== + + v_M_tilde : Any (sympifiable) + Normalized muscle fiber extension velocity. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``-0.318``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``-8.149``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``-0.374``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``0.886``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``v_M_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + v_M_tilde, *constants = self.args + if deep: + hints['evaluate'] = evaluate + v_M_tilde = v_M_tilde.doit(deep=deep, **hints) + c0, c1, c2, c3 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1, c2, c3 = constants + + if evaluate: + return c0*log(c1*v_M_tilde + c2 + sqrt((c1*v_M_tilde + c2)**2 + 1)) + c3 + + return c0*log(c1*v_M_tilde + c2 + sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1)) + c3 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + v_M_tilde, c0, c1, c2, c3 = self.args + if argindex == 1: + return c0*c1/sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1) + elif argindex == 2: + return log( + c1*v_M_tilde + c2 + + sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1) + ) + elif argindex == 3: + return c0*v_M_tilde/sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1) + elif argindex == 4: + return c0/sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1) + elif argindex == 5: + return Integer(1) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return FiberForceVelocityInverseDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + v_M_tilde = self.args[0] + _v_M_tilde = printer._print(v_M_tilde) + return r'\operatorname{fv}^M \left( %s \right)' % _v_M_tilde + + +class FiberForceVelocityInverseDeGroote2016(CharacteristicCurveFunction): + r"""Inverse muscle fiber force-velocity curve based on De Groote et al., + 2016 [1]_. + + Explanation + =========== + + Gives the normalized muscle fiber velocity that produces a specific + normalized muscle fiber force. + + The function is defined by the equation: + + ${fv^M}^{-1} = \frac{\sinh{\frac{fv^M - c_3}{c_0}} - c_2}{c_1}$ + + with constant values of $c_0 = -0.318$, $c_1 = -8.149$, $c_2 = -0.374$, and + $c_3 = 0.886$. This function is the exact analytical inverse of the related + muscle fiber force-velocity curve ``FiberForceVelocityDeGroote2016``. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + normalized muscle fiber force of 1 when the muscle fibers are contracting + isometrically (they have an extension rate of 0). + + Examples + ======== + + The preferred way to instantiate :class:`FiberForceVelocityInverseDeGroote2016` + is using the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized muscle fiber force-velocity + component of the muscle fiber force. We'll create a :class:`~.Symbol` called + ``fv_M`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceVelocityInverseDeGroote2016 + >>> fv_M = Symbol('fv_M') + >>> v_M_tilde = FiberForceVelocityInverseDeGroote2016.with_defaults(fv_M) + >>> v_M_tilde + FiberForceVelocityInverseDeGroote2016(fv_M, -0.318, -8.149, -0.374, 0.886) + + It's also possible to populate the four constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3 = symbols('c0 c1 c2 c3') + >>> v_M_tilde = FiberForceVelocityInverseDeGroote2016(fv_M, c0, c1, c2, c3) + >>> v_M_tilde + FiberForceVelocityInverseDeGroote2016(fv_M, c0, c1, c2, c3) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> v_M_tilde.doit(evaluate=False) + (-c2 + sinh((-c3 + fv_M)/c0))/c1 + + The function can also be differentiated. We'll differentiate with respect + to fv_M using the ``diff`` method on an instance with the single positional + argument ``fv_M``. + + >>> v_M_tilde.diff(fv_M) + cosh((-c3 + fv_M)/c0)/(c0*c1) + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, fv_M): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the inverse muscle fiber force-velocity + function using the four constant values specified in the original + publication. + + These have the values: + + $c_0 = -0.318$ + $c_1 = -8.149$ + $c_2 = -0.374$ + $c_3 = 0.886$ + + Parameters + ========== + + fv_M : Any (sympifiable) + Normalized muscle fiber extension velocity. + + """ + c0 = Float('-0.318') + c1 = Float('-8.149') + c2 = Float('-0.374') + c3 = Float('0.886') + return cls(fv_M, c0, c1, c2, c3) + + @classmethod + def eval(cls, fv_M, c0, c1, c2, c3): + """Evaluation of basic inputs. + + Parameters + ========== + + fv_M : Any (sympifiable) + Normalized muscle fiber force as a function of muscle fiber + extension velocity. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``-0.318``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``-8.149``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``-0.374``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``0.886``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``fv_M`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + fv_M, *constants = self.args + if deep: + hints['evaluate'] = evaluate + fv_M = fv_M.doit(deep=deep, **hints) + c0, c1, c2, c3 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1, c2, c3 = constants + + if evaluate: + return (sinh((fv_M - c3)/c0) - c2)/c1 + + return (sinh(UnevaluatedExpr(fv_M - c3)/c0) - c2)/c1 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + fv_M, c0, c1, c2, c3 = self.args + if argindex == 1: + return cosh((fv_M - c3)/c0)/(c0*c1) + elif argindex == 2: + return (c3 - fv_M)*cosh((fv_M - c3)/c0)/(c0**2*c1) + elif argindex == 3: + return (c2 - sinh((fv_M - c3)/c0))/c1**2 + elif argindex == 4: + return -1/c1 + elif argindex == 5: + return -cosh((fv_M - c3)/c0)/(c0*c1) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return FiberForceVelocityDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + fv_M = self.args[0] + _fv_M = printer._print(fv_M) + return r'\left( \operatorname{fv}^M \right)^{-1} \left( %s \right)' % _fv_M + + +@dataclass(frozen=True) +class CharacteristicCurveCollection: + """Simple data container to group together related characteristic curves.""" + tendon_force_length: CharacteristicCurveFunction + tendon_force_length_inverse: CharacteristicCurveFunction + fiber_force_length_passive: CharacteristicCurveFunction + fiber_force_length_passive_inverse: CharacteristicCurveFunction + fiber_force_length_active: CharacteristicCurveFunction + fiber_force_velocity: CharacteristicCurveFunction + fiber_force_velocity_inverse: CharacteristicCurveFunction + + def __iter__(self): + """Iterator support for ``CharacteristicCurveCollection``.""" + yield self.tendon_force_length + yield self.tendon_force_length_inverse + yield self.fiber_force_length_passive + yield self.fiber_force_length_passive_inverse + yield self.fiber_force_length_active + yield self.fiber_force_velocity + yield self.fiber_force_velocity_inverse diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/musculotendon.py b/phivenv/Lib/site-packages/sympy/physics/biomechanics/musculotendon.py new file mode 100644 index 0000000000000000000000000000000000000000..e16d66373da9107adee2e3b8418f657ee5879298 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/biomechanics/musculotendon.py @@ -0,0 +1,1424 @@ +"""Implementations of musculotendon models. + +Musculotendon models are a critical component of biomechanical models, one that +differentiates them from pure multibody systems. Musculotendon models produce a +force dependent on their level of activation, their length, and their +extension velocity. Length- and extension velocity-dependent force production +are governed by force-length and force-velocity characteristics. +These are normalized functions that are dependent on the musculotendon's state +and are specific to a given musculotendon model. + +""" + +from abc import abstractmethod +from enum import IntEnum, unique + +from sympy.core.numbers import Float, Integer +from sympy.core.symbol import Symbol, symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.matrices.dense import MutableDenseMatrix as Matrix, diag, eye, zeros +from sympy.physics.biomechanics.activation import ActivationBase +from sympy.physics.biomechanics.curve import ( + CharacteristicCurveCollection, + FiberForceLengthActiveDeGroote2016, + FiberForceLengthPassiveDeGroote2016, + FiberForceLengthPassiveInverseDeGroote2016, + FiberForceVelocityDeGroote2016, + FiberForceVelocityInverseDeGroote2016, + TendonForceLengthDeGroote2016, + TendonForceLengthInverseDeGroote2016, +) +from sympy.physics.biomechanics._mixin import _NamedMixin +from sympy.physics.mechanics.actuator import ForceActuator +from sympy.physics.vector.functions import dynamicsymbols + + +__all__ = [ + 'MusculotendonBase', + 'MusculotendonDeGroote2016', + 'MusculotendonFormulation', +] + + +@unique +class MusculotendonFormulation(IntEnum): + """Enumeration of types of musculotendon dynamics formulations. + + Explanation + =========== + + An (integer) enumeration is used as it allows for clearer selection of the + different formulations of musculotendon dynamics. + + Members + ======= + + RIGID_TENDON : 0 + A rigid tendon model. + FIBER_LENGTH_EXPLICIT : 1 + An explicit elastic tendon model with the muscle fiber length (l_M) as + the state variable. + TENDON_FORCE_EXPLICIT : 2 + An explicit elastic tendon model with the tendon force (F_T) as the + state variable. + FIBER_LENGTH_IMPLICIT : 3 + An implicit elastic tendon model with the muscle fiber length (l_M) as + the state variable and the muscle fiber velocity as an additional input + variable. + TENDON_FORCE_IMPLICIT : 4 + An implicit elastic tendon model with the tendon force (F_T) as the + state variable as the muscle fiber velocity as an additional input + variable. + + """ + + RIGID_TENDON = 0 + FIBER_LENGTH_EXPLICIT = 1 + TENDON_FORCE_EXPLICIT = 2 + FIBER_LENGTH_IMPLICIT = 3 + TENDON_FORCE_IMPLICIT = 4 + + def __str__(self): + """Returns a string representation of the enumeration value. + + Notes + ===== + + This hard coding is required due to an incompatibility between the + ``IntEnum`` implementations in Python 3.10 and Python 3.11 + (https://github.com/python/cpython/issues/84247). From Python 3.11 + onwards, the ``__str__`` method uses ``int.__str__``, whereas prior it + used ``Enum.__str__``. Once Python 3.11 becomes the minimum version + supported by SymPy, this method override can be removed. + + """ + return str(self.value) + + +_DEFAULT_MUSCULOTENDON_FORMULATION = MusculotendonFormulation.RIGID_TENDON + + +class MusculotendonBase(ForceActuator, _NamedMixin): + r"""Abstract base class for all musculotendon classes to inherit from. + + Explanation + =========== + + A musculotendon generates a contractile force based on its activation, + length, and shortening velocity. This abstract base class is to be inherited + by all musculotendon subclasses that implement different characteristic + musculotendon curves. Characteristic musculotendon curves are required for + the tendon force-length, passive fiber force-length, active fiber force- + length, and fiber force-velocity relationships. + + Parameters + ========== + + name : str + The name identifier associated with the musculotendon. This name is used + as a suffix when automatically generated symbols are instantiated. It + must be a string of nonzero length. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + activation_dynamics : ActivationBase + The activation dynamics that will be modeled within the musculotendon. + This must be an instance of a concrete subclass of ``ActivationBase``, + e.g. ``FirstOrderActivationDeGroote2016``. + musculotendon_dynamics : MusculotendonFormulation | int + The formulation of musculotendon dynamics that should be used + internally, i.e. rigid or elastic tendon model, the choice of + musculotendon state etc. This must be a member of the integer + enumeration ``MusculotendonFormulation`` or an integer that can be cast + to a member. To use a rigid tendon formulation, set this to + ``MusculotendonFormulation.RIGID_TENDON`` (or the integer value ``0``, + which will be cast to the enumeration member). There are four possible + formulations for an elastic tendon model. To use an explicit formulation + with the fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_EXPLICIT`` (or the integer value + ``1``). To use an explicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_EXPLICIT`` + (or the integer value ``2``). To use an implicit formulation with the + fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_IMPLICIT`` (or the integer value + ``3``). To use an implicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_IMPLICIT`` + (or the integer value ``4``). The default is + ``MusculotendonFormulation.RIGID_TENDON``, which corresponds to a rigid + tendon formulation. + tendon_slack_length : Expr | None + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + peak_isometric_force : Expr | None + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In all + musculotendon models, peak isometric force is used to normalized tendon + and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + optimal_fiber_length : Expr | None + The muscle fiber length at which the muscle fibers produce no passive + force and their maximum active force. In all musculotendon models, + optimal fiber length is used to normalize muscle fiber length to give + :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + maximal_fiber_velocity : Expr | None + The fiber velocity at which, during muscle fiber shortening, the muscle + fibers are unable to produce any active force. In all musculotendon + models, maximal fiber velocity is used to normalize muscle fiber + extension velocity to give :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + optimal_pennation_angle : Expr | None + The pennation angle when muscle fiber length equals the optimal fiber + length. + fiber_damping_coefficient : Expr | None + The coefficient of damping to be used in the damping element in the + muscle fiber model. + with_defaults : bool + Whether ``with_defaults`` alternate constructors should be used when + automatically constructing child classes. Default is ``False``. + + """ + + def __init__( + self, + name, + pathway, + activation_dynamics, + *, + musculotendon_dynamics=_DEFAULT_MUSCULOTENDON_FORMULATION, + tendon_slack_length=None, + peak_isometric_force=None, + optimal_fiber_length=None, + maximal_fiber_velocity=None, + optimal_pennation_angle=None, + fiber_damping_coefficient=None, + with_defaults=False, + ): + self.name = name + + # Supply a placeholder force to the super initializer, this will be + # replaced later + super().__init__(Symbol('F'), pathway) + + # Activation dynamics + if not isinstance(activation_dynamics, ActivationBase): + msg = ( + f'Can\'t set attribute `activation_dynamics` to ' + f'{activation_dynamics} as it must be of type ' + f'`ActivationBase`, not {type(activation_dynamics)}.' + ) + raise TypeError(msg) + self._activation_dynamics = activation_dynamics + self._child_objects = (self._activation_dynamics, ) + + # Constants + if tendon_slack_length is not None: + self._l_T_slack = tendon_slack_length + else: + self._l_T_slack = Symbol(f'l_T_slack_{self.name}') + if peak_isometric_force is not None: + self._F_M_max = peak_isometric_force + else: + self._F_M_max = Symbol(f'F_M_max_{self.name}') + if optimal_fiber_length is not None: + self._l_M_opt = optimal_fiber_length + else: + self._l_M_opt = Symbol(f'l_M_opt_{self.name}') + if maximal_fiber_velocity is not None: + self._v_M_max = maximal_fiber_velocity + else: + self._v_M_max = Symbol(f'v_M_max_{self.name}') + if optimal_pennation_angle is not None: + self._alpha_opt = optimal_pennation_angle + else: + self._alpha_opt = Symbol(f'alpha_opt_{self.name}') + if fiber_damping_coefficient is not None: + self._beta = fiber_damping_coefficient + else: + self._beta = Symbol(f'beta_{self.name}') + + # Musculotendon dynamics + self._with_defaults = with_defaults + if musculotendon_dynamics == MusculotendonFormulation.RIGID_TENDON: + self._rigid_tendon_musculotendon_dynamics() + elif musculotendon_dynamics == MusculotendonFormulation.FIBER_LENGTH_EXPLICIT: + self._fiber_length_explicit_musculotendon_dynamics() + elif musculotendon_dynamics == MusculotendonFormulation.TENDON_FORCE_EXPLICIT: + self._tendon_force_explicit_musculotendon_dynamics() + elif musculotendon_dynamics == MusculotendonFormulation.FIBER_LENGTH_IMPLICIT: + self._fiber_length_implicit_musculotendon_dynamics() + elif musculotendon_dynamics == MusculotendonFormulation.TENDON_FORCE_IMPLICIT: + self._tendon_force_implicit_musculotendon_dynamics() + else: + msg = ( + f'Musculotendon dynamics {repr(musculotendon_dynamics)} ' + f'passed to `musculotendon_dynamics` was of type ' + f'{type(musculotendon_dynamics)}, must be ' + f'{MusculotendonFormulation}.' + ) + raise TypeError(msg) + self._musculotendon_dynamics = musculotendon_dynamics + + # Must override the placeholder value in `self._force` now that the + # actual force has been calculated by + # `self.__musculotendon_dynamics`. + # Note that `self._force` assumes forces are expansile, musculotendon + # forces are contractile hence the minus sign preceding `self._F_T` + # (the tendon force). + self._force = -self._F_T + + @classmethod + def with_defaults( + cls, + name, + pathway, + activation_dynamics, + *, + musculotendon_dynamics=_DEFAULT_MUSCULOTENDON_FORMULATION, + tendon_slack_length=None, + peak_isometric_force=None, + optimal_fiber_length=None, + maximal_fiber_velocity=Float('10.0'), + optimal_pennation_angle=Float('0.0'), + fiber_damping_coefficient=Float('0.1'), + ): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the musculotendon class using recommended + values for ``v_M_max``, ``alpha_opt``, and ``beta``. The values are: + + :math:`v^M_{max} = 10` + :math:`\alpha_{opt} = 0` + :math:`\beta = \frac{1}{10}` + + The musculotendon curves are also instantiated using the constants from + the original publication. + + Parameters + ========== + + name : str + The name identifier associated with the musculotendon. This name is + used as a suffix when automatically generated symbols are + instantiated. It must be a string of nonzero length. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + activation_dynamics : ActivationBase + The activation dynamics that will be modeled within the + musculotendon. This must be an instance of a concrete subclass of + ``ActivationBase``, e.g. ``FirstOrderActivationDeGroote2016``. + musculotendon_dynamics : MusculotendonFormulation | int + The formulation of musculotendon dynamics that should be used + internally, i.e. rigid or elastic tendon model, the choice of + musculotendon state etc. This must be a member of the integer + enumeration ``MusculotendonFormulation`` or an integer that can be + cast to a member. To use a rigid tendon formulation, set this to + ``MusculotendonFormulation.RIGID_TENDON`` (or the integer value + ``0``, which will be cast to the enumeration member). There are four + possible formulations for an elastic tendon model. To use an + explicit formulation with the fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_EXPLICIT`` (or the integer + value ``1``). To use an explicit formulation with the tendon force + as the state, set this to + ``MusculotendonFormulation.TENDON_FORCE_EXPLICIT`` (or the integer + value ``2``). To use an implicit formulation with the fiber length + as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_IMPLICIT`` (or the integer + value ``3``). To use an implicit formulation with the tendon force + as the state, set this to + ``MusculotendonFormulation.TENDON_FORCE_IMPLICIT`` (or the integer + value ``4``). The default is + ``MusculotendonFormulation.RIGID_TENDON``, which corresponds to a + rigid tendon formulation. + tendon_slack_length : Expr | None + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + peak_isometric_force : Expr | None + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In + all musculotendon models, peak isometric force is used to normalized + tendon and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + optimal_fiber_length : Expr | None + The muscle fiber length at which the muscle fibers produce no + passive force and their maximum active force. In all musculotendon + models, optimal fiber length is used to normalize muscle fiber + length to give :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + maximal_fiber_velocity : Expr | None + The fiber velocity at which, during muscle fiber shortening, the + muscle fibers are unable to produce any active force. In all + musculotendon models, maximal fiber velocity is used to normalize + muscle fiber extension velocity to give + :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + optimal_pennation_angle : Expr | None + The pennation angle when muscle fiber length equals the optimal + fiber length. + fiber_damping_coefficient : Expr | None + The coefficient of damping to be used in the damping element in the + muscle fiber model. + + """ + return cls( + name, + pathway, + activation_dynamics=activation_dynamics, + musculotendon_dynamics=musculotendon_dynamics, + tendon_slack_length=tendon_slack_length, + peak_isometric_force=peak_isometric_force, + optimal_fiber_length=optimal_fiber_length, + maximal_fiber_velocity=maximal_fiber_velocity, + optimal_pennation_angle=optimal_pennation_angle, + fiber_damping_coefficient=fiber_damping_coefficient, + with_defaults=True, + ) + + @abstractmethod + def curves(cls): + """Return a ``CharacteristicCurveCollection`` of the curves related to + the specific model.""" + pass + + @property + def tendon_slack_length(self): + r"""Symbol or value corresponding to the tendon slack length constant. + + Explanation + =========== + + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + + The alias ``l_T_slack`` can also be used to access the same attribute. + + """ + return self._l_T_slack + + @property + def l_T_slack(self): + r"""Symbol or value corresponding to the tendon slack length constant. + + Explanation + =========== + + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + + The alias ``tendon_slack_length`` can also be used to access the same + attribute. + + """ + return self._l_T_slack + + @property + def peak_isometric_force(self): + r"""Symbol or value corresponding to the peak isometric force constant. + + Explanation + =========== + + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In all + musculotendon models, peak isometric force is used to normalized tendon + and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + + The alias ``F_M_max`` can also be used to access the same attribute. + + """ + return self._F_M_max + + @property + def F_M_max(self): + r"""Symbol or value corresponding to the peak isometric force constant. + + Explanation + =========== + + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In all + musculotendon models, peak isometric force is used to normalized tendon + and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + + The alias ``peak_isometric_force`` can also be used to access the same + attribute. + + """ + return self._F_M_max + + @property + def optimal_fiber_length(self): + r"""Symbol or value corresponding to the optimal fiber length constant. + + Explanation + =========== + + The muscle fiber length at which the muscle fibers produce no passive + force and their maximum active force. In all musculotendon models, + optimal fiber length is used to normalize muscle fiber length to give + :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + + The alias ``l_M_opt`` can also be used to access the same attribute. + + """ + return self._l_M_opt + + @property + def l_M_opt(self): + r"""Symbol or value corresponding to the optimal fiber length constant. + + Explanation + =========== + + The muscle fiber length at which the muscle fibers produce no passive + force and their maximum active force. In all musculotendon models, + optimal fiber length is used to normalize muscle fiber length to give + :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + + The alias ``optimal_fiber_length`` can also be used to access the same + attribute. + + """ + return self._l_M_opt + + @property + def maximal_fiber_velocity(self): + r"""Symbol or value corresponding to the maximal fiber velocity constant. + + Explanation + =========== + + The fiber velocity at which, during muscle fiber shortening, the muscle + fibers are unable to produce any active force. In all musculotendon + models, maximal fiber velocity is used to normalize muscle fiber + extension velocity to give :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + + The alias ``v_M_max`` can also be used to access the same attribute. + + """ + return self._v_M_max + + @property + def v_M_max(self): + r"""Symbol or value corresponding to the maximal fiber velocity constant. + + Explanation + =========== + + The fiber velocity at which, during muscle fiber shortening, the muscle + fibers are unable to produce any active force. In all musculotendon + models, maximal fiber velocity is used to normalize muscle fiber + extension velocity to give :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + + The alias ``maximal_fiber_velocity`` can also be used to access the same + attribute. + + """ + return self._v_M_max + + @property + def optimal_pennation_angle(self): + """Symbol or value corresponding to the optimal pennation angle + constant. + + Explanation + =========== + + The pennation angle when muscle fiber length equals the optimal fiber + length. + + The alias ``alpha_opt`` can also be used to access the same attribute. + + """ + return self._alpha_opt + + @property + def alpha_opt(self): + """Symbol or value corresponding to the optimal pennation angle + constant. + + Explanation + =========== + + The pennation angle when muscle fiber length equals the optimal fiber + length. + + The alias ``optimal_pennation_angle`` can also be used to access the + same attribute. + + """ + return self._alpha_opt + + @property + def fiber_damping_coefficient(self): + """Symbol or value corresponding to the fiber damping coefficient + constant. + + Explanation + =========== + + The coefficient of damping to be used in the damping element in the + muscle fiber model. + + The alias ``beta`` can also be used to access the same attribute. + + """ + return self._beta + + @property + def beta(self): + """Symbol or value corresponding to the fiber damping coefficient + constant. + + Explanation + =========== + + The coefficient of damping to be used in the damping element in the + muscle fiber model. + + The alias ``fiber_damping_coefficient`` can also be used to access the + same attribute. + + """ + return self._beta + + @property + def activation_dynamics(self): + """Activation dynamics model governing this musculotendon's activation. + + Explanation + =========== + + Returns the instance of a subclass of ``ActivationBase`` that governs + the relationship between excitation and activation that is used to + represent the activation dynamics of this musculotendon. + + """ + return self._activation_dynamics + + @property + def excitation(self): + """Dynamic symbol representing excitation. + + Explanation + =========== + + The alias ``e`` can also be used to access the same attribute. + + """ + return self._activation_dynamics._e + + @property + def e(self): + """Dynamic symbol representing excitation. + + Explanation + =========== + + The alias ``excitation`` can also be used to access the same attribute. + + """ + return self._activation_dynamics._e + + @property + def activation(self): + """Dynamic symbol representing activation. + + Explanation + =========== + + The alias ``a`` can also be used to access the same attribute. + + """ + return self._activation_dynamics._a + + @property + def a(self): + """Dynamic symbol representing activation. + + Explanation + =========== + + The alias ``activation`` can also be used to access the same attribute. + + """ + return self._activation_dynamics._a + + @property + def musculotendon_dynamics(self): + """The choice of rigid or type of elastic tendon musculotendon dynamics. + + Explanation + =========== + + The formulation of musculotendon dynamics that should be used + internally, i.e. rigid or elastic tendon model, the choice of + musculotendon state etc. This must be a member of the integer + enumeration ``MusculotendonFormulation`` or an integer that can be cast + to a member. To use a rigid tendon formulation, set this to + ``MusculotendonFormulation.RIGID_TENDON`` (or the integer value ``0``, + which will be cast to the enumeration member). There are four possible + formulations for an elastic tendon model. To use an explicit formulation + with the fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_EXPLICIT`` (or the integer value + ``1``). To use an explicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_EXPLICIT`` + (or the integer value ``2``). To use an implicit formulation with the + fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_IMPLICIT`` (or the integer value + ``3``). To use an implicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_IMPLICIT`` + (or the integer value ``4``). The default is + ``MusculotendonFormulation.RIGID_TENDON``, which corresponds to a rigid + tendon formulation. + + """ + return self._musculotendon_dynamics + + def _rigid_tendon_musculotendon_dynamics(self): + """Rigid tendon musculotendon.""" + self._l_MT = self.pathway.length + self._v_MT = self.pathway.extension_velocity + self._l_T = self._l_T_slack + self._l_T_tilde = Integer(1) + self._l_M = sqrt((self._l_MT - self._l_T)**2 + (self._l_M_opt*sin(self._alpha_opt))**2) + self._l_M_tilde = self._l_M/self._l_M_opt + self._v_M = self._v_MT*(self._l_MT - self._l_T_slack)/self._l_M + self._v_M_tilde = self._v_M/self._v_M_max + if self._with_defaults: + self._fl_T = self.curves.tendon_force_length.with_defaults(self._l_T_tilde) + self._fl_M_pas = self.curves.fiber_force_length_passive.with_defaults(self._l_M_tilde) + self._fl_M_act = self.curves.fiber_force_length_active.with_defaults(self._l_M_tilde) + self._fv_M = self.curves.fiber_force_velocity.with_defaults(self._v_M_tilde) + else: + fl_T_constants = symbols(f'c_0:4_fl_T_{self.name}') + self._fl_T = self.curves.tendon_force_length(self._l_T_tilde, *fl_T_constants) + fl_M_pas_constants = symbols(f'c_0:2_fl_M_pas_{self.name}') + self._fl_M_pas = self.curves.fiber_force_length_passive(self._l_M_tilde, *fl_M_pas_constants) + fl_M_act_constants = symbols(f'c_0:12_fl_M_act_{self.name}') + self._fl_M_act = self.curves.fiber_force_length_active(self._l_M_tilde, *fl_M_act_constants) + fv_M_constants = symbols(f'c_0:4_fv_M_{self.name}') + self._fv_M = self.curves.fiber_force_velocity(self._v_M_tilde, *fv_M_constants) + self._F_M_tilde = self.a*self._fl_M_act*self._fv_M + self._fl_M_pas + self._beta*self._v_M_tilde + self._F_T_tilde = self._F_M_tilde + self._F_M = self._F_M_tilde*self._F_M_max + self._cos_alpha = cos(self._alpha_opt) + self._F_T = self._F_M*self._cos_alpha + + # Containers + self._state_vars = zeros(0, 1) + self._input_vars = zeros(0, 1) + self._state_eqns = zeros(0, 1) + self._curve_constants = Matrix( + fl_T_constants + + fl_M_pas_constants + + fl_M_act_constants + + fv_M_constants + ) if not self._with_defaults else zeros(0, 1) + + def _fiber_length_explicit_musculotendon_dynamics(self): + """Elastic tendon musculotendon using `l_M_tilde` as a state.""" + self._l_M_tilde = dynamicsymbols(f'l_M_tilde_{self.name}') + self._l_MT = self.pathway.length + self._v_MT = self.pathway.extension_velocity + self._l_M = self._l_M_tilde*self._l_M_opt + self._l_T = self._l_MT - sqrt(self._l_M**2 - (self._l_M_opt*sin(self._alpha_opt))**2) + self._l_T_tilde = self._l_T/self._l_T_slack + self._cos_alpha = (self._l_MT - self._l_T)/self._l_M + if self._with_defaults: + self._fl_T = self.curves.tendon_force_length.with_defaults(self._l_T_tilde) + self._fl_M_pas = self.curves.fiber_force_length_passive.with_defaults(self._l_M_tilde) + self._fl_M_act = self.curves.fiber_force_length_active.with_defaults(self._l_M_tilde) + else: + fl_T_constants = symbols(f'c_0:4_fl_T_{self.name}') + self._fl_T = self.curves.tendon_force_length(self._l_T_tilde, *fl_T_constants) + fl_M_pas_constants = symbols(f'c_0:2_fl_M_pas_{self.name}') + self._fl_M_pas = self.curves.fiber_force_length_passive(self._l_M_tilde, *fl_M_pas_constants) + fl_M_act_constants = symbols(f'c_0:12_fl_M_act_{self.name}') + self._fl_M_act = self.curves.fiber_force_length_active(self._l_M_tilde, *fl_M_act_constants) + self._F_T_tilde = self._fl_T + self._F_T = self._F_T_tilde*self._F_M_max + self._F_M = self._F_T/self._cos_alpha + self._F_M_tilde = self._F_M/self._F_M_max + self._fv_M = (self._F_M_tilde - self._fl_M_pas)/(self.a*self._fl_M_act) + if self._with_defaults: + self._v_M_tilde = self.curves.fiber_force_velocity_inverse.with_defaults(self._fv_M) + else: + fv_M_constants = symbols(f'c_0:4_fv_M_{self.name}') + self._v_M_tilde = self.curves.fiber_force_velocity_inverse(self._fv_M, *fv_M_constants) + self._dl_M_tilde_dt = (self._v_M_max/self._l_M_opt)*self._v_M_tilde + + self._state_vars = Matrix([self._l_M_tilde]) + self._input_vars = zeros(0, 1) + self._state_eqns = Matrix([self._dl_M_tilde_dt]) + self._curve_constants = Matrix( + fl_T_constants + + fl_M_pas_constants + + fl_M_act_constants + + fv_M_constants + ) if not self._with_defaults else zeros(0, 1) + + def _tendon_force_explicit_musculotendon_dynamics(self): + """Elastic tendon musculotendon using `F_T_tilde` as a state.""" + self._F_T_tilde = dynamicsymbols(f'F_T_tilde_{self.name}') + self._l_MT = self.pathway.length + self._v_MT = self.pathway.extension_velocity + self._fl_T = self._F_T_tilde + if self._with_defaults: + self._fl_T_inv = self.curves.tendon_force_length_inverse.with_defaults(self._fl_T) + else: + fl_T_constants = symbols(f'c_0:4_fl_T_{self.name}') + self._fl_T_inv = self.curves.tendon_force_length_inverse(self._fl_T, *fl_T_constants) + self._l_T_tilde = self._fl_T_inv + self._l_T = self._l_T_tilde*self._l_T_slack + self._l_M = sqrt((self._l_MT - self._l_T)**2 + (self._l_M_opt*sin(self._alpha_opt))**2) + self._l_M_tilde = self._l_M/self._l_M_opt + if self._with_defaults: + self._fl_M_pas = self.curves.fiber_force_length_passive.with_defaults(self._l_M_tilde) + self._fl_M_act = self.curves.fiber_force_length_active.with_defaults(self._l_M_tilde) + else: + fl_M_pas_constants = symbols(f'c_0:2_fl_M_pas_{self.name}') + self._fl_M_pas = self.curves.fiber_force_length_passive(self._l_M_tilde, *fl_M_pas_constants) + fl_M_act_constants = symbols(f'c_0:12_fl_M_act_{self.name}') + self._fl_M_act = self.curves.fiber_force_length_active(self._l_M_tilde, *fl_M_act_constants) + self._cos_alpha = (self._l_MT - self._l_T)/self._l_M + self._F_T = self._F_T_tilde*self._F_M_max + self._F_M = self._F_T/self._cos_alpha + self._F_M_tilde = self._F_M/self._F_M_max + self._fv_M = (self._F_M_tilde - self._fl_M_pas)/(self.a*self._fl_M_act) + if self._with_defaults: + self._fv_M_inv = self.curves.fiber_force_velocity_inverse.with_defaults(self._fv_M) + else: + fv_M_constants = symbols(f'c_0:4_fv_M_{self.name}') + self._fv_M_inv = self.curves.fiber_force_velocity_inverse(self._fv_M, *fv_M_constants) + self._v_M_tilde = self._fv_M_inv + self._v_M = self._v_M_tilde*self._v_M_max + self._v_T = self._v_MT - (self._v_M/self._cos_alpha) + self._v_T_tilde = self._v_T/self._l_T_slack + if self._with_defaults: + self._fl_T = self.curves.tendon_force_length.with_defaults(self._l_T_tilde) + else: + self._fl_T = self.curves.tendon_force_length(self._l_T_tilde, *fl_T_constants) + self._dF_T_tilde_dt = self._fl_T.diff(dynamicsymbols._t).subs({self._l_T_tilde.diff(dynamicsymbols._t): self._v_T_tilde}) + + self._state_vars = Matrix([self._F_T_tilde]) + self._input_vars = zeros(0, 1) + self._state_eqns = Matrix([self._dF_T_tilde_dt]) + self._curve_constants = Matrix( + fl_T_constants + + fl_M_pas_constants + + fl_M_act_constants + + fv_M_constants + ) if not self._with_defaults else zeros(0, 1) + + def _fiber_length_implicit_musculotendon_dynamics(self): + raise NotImplementedError + + def _tendon_force_implicit_musculotendon_dynamics(self): + raise NotImplementedError + + @property + def state_vars(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``x`` can also be used to access the same attribute. + + """ + state_vars = [self._state_vars] + for child in self._child_objects: + state_vars.append(child.state_vars) + return Matrix.vstack(*state_vars) + + @property + def x(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``state_vars`` can also be used to access the same attribute. + + """ + state_vars = [self._state_vars] + for child in self._child_objects: + state_vars.append(child.state_vars) + return Matrix.vstack(*state_vars) + + @property + def input_vars(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``r`` can also be used to access the same attribute. + + """ + input_vars = [self._input_vars] + for child in self._child_objects: + input_vars.append(child.input_vars) + return Matrix.vstack(*input_vars) + + @property + def r(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``input_vars`` can also be used to access the same attribute. + + """ + input_vars = [self._input_vars] + for child in self._child_objects: + input_vars.append(child.input_vars) + return Matrix.vstack(*input_vars) + + @property + def constants(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Explanation + =========== + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + The alias ``p`` can also be used to access the same attribute. + + """ + musculotendon_constants = [ + self._l_T_slack, + self._F_M_max, + self._l_M_opt, + self._v_M_max, + self._alpha_opt, + self._beta, + ] + musculotendon_constants = [ + c for c in musculotendon_constants if not c.is_number + ] + constants = [ + Matrix(musculotendon_constants) + if musculotendon_constants + else zeros(0, 1) + ] + for child in self._child_objects: + constants.append(child.constants) + constants.append(self._curve_constants) + return Matrix.vstack(*constants) + + @property + def p(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Explanation + =========== + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + The alias ``constants`` can also be used to access the same attribute. + + """ + musculotendon_constants = [ + self._l_T_slack, + self._F_M_max, + self._l_M_opt, + self._v_M_max, + self._alpha_opt, + self._beta, + ] + musculotendon_constants = [ + c for c in musculotendon_constants if not c.is_number + ] + constants = [ + Matrix(musculotendon_constants) + if musculotendon_constants + else zeros(0, 1) + ] + for child in self._child_objects: + constants.append(child.constants) + constants.append(self._curve_constants) + return Matrix.vstack(*constants) + + @property + def M(self): + """Ordered square matrix of coefficients on the LHS of ``M x' = F``. + + Explanation + =========== + + The square matrix that forms part of the LHS of the linear system of + ordinary differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear system has dimension 0 and therefore ``M`` is an empty square + ``Matrix`` with shape (0, 0). + + """ + M = [eye(len(self._state_vars))] + for child in self._child_objects: + M.append(child.M) + return diag(*M) + + @property + def F(self): + """Ordered column matrix of equations on the RHS of ``M x' = F``. + + Explanation + =========== + + The column matrix that forms the RHS of the linear system of ordinary + differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear system has dimension 0 and therefore ``F`` is an empty column + ``Matrix`` with shape (0, 1). + + """ + F = [self._state_eqns] + for child in self._child_objects: + F.append(child.F) + return Matrix.vstack(*F) + + def rhs(self): + """Ordered column matrix of equations for the solution of ``M x' = F``. + + Explanation + =========== + + The solution to the linear system of ordinary differential equations + governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear has dimension 0 and therefore this method returns an empty + column ``Matrix`` with shape (0, 1). + + """ + is_explicit = ( + MusculotendonFormulation.FIBER_LENGTH_EXPLICIT, + MusculotendonFormulation.TENDON_FORCE_EXPLICIT, + ) + if self.musculotendon_dynamics is MusculotendonFormulation.RIGID_TENDON: + child_rhs = [child.rhs() for child in self._child_objects] + return Matrix.vstack(*child_rhs) + elif self.musculotendon_dynamics in is_explicit: + rhs = self._state_eqns + child_rhs = [child.rhs() for child in self._child_objects] + return Matrix.vstack(rhs, *child_rhs) + return self.M.solve(self.F) + + def __repr__(self): + """Returns a string representation to reinstantiate the model.""" + return ( + f'{self.__class__.__name__}({self.name!r}, ' + f'pathway={self.pathway!r}, ' + f'activation_dynamics={self.activation_dynamics!r}, ' + f'musculotendon_dynamics={self.musculotendon_dynamics}, ' + f'tendon_slack_length={self._l_T_slack!r}, ' + f'peak_isometric_force={self._F_M_max!r}, ' + f'optimal_fiber_length={self._l_M_opt!r}, ' + f'maximal_fiber_velocity={self._v_M_max!r}, ' + f'optimal_pennation_angle={self._alpha_opt!r}, ' + f'fiber_damping_coefficient={self._beta!r})' + ) + + def __str__(self): + """Returns a string representation of the expression for musculotendon + force.""" + return str(self.force) + + +class MusculotendonDeGroote2016(MusculotendonBase): + r"""Musculotendon model using the curves of De Groote et al., 2016 [1]_. + + Examples + ======== + + This class models the musculotendon actuator parametrized by the + characteristic curves described in De Groote et al., 2016 [1]_. Like all + musculotendon models in SymPy's biomechanics module, it requires a pathway + to define its line of action. We'll begin by creating a simple + ``LinearPathway`` between two points that our musculotendon will follow. + We'll create a point ``O`` to represent the musculotendon's origin and + another ``I`` to represent its insertion. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (LinearPathway, Point, + ... ReferenceFrame, dynamicsymbols) + + >>> N = ReferenceFrame('N') + >>> O, I = O, P = symbols('O, I', cls=Point) + >>> q, u = dynamicsymbols('q, u', real=True) + >>> I.set_pos(O, q*N.x) + >>> O.set_vel(N, 0) + >>> I.set_vel(N, u*N.x) + >>> pathway = LinearPathway(O, I) + >>> pathway.attachments + (O, I) + >>> pathway.length + Abs(q(t)) + >>> pathway.extension_velocity + sign(q(t))*Derivative(q(t), t) + + A musculotendon also takes an instance of an activation dynamics model as + this will be used to provide symbols for the activation in the formulation + of the musculotendon dynamics. We'll use an instance of + ``FirstOrderActivationDeGroote2016`` to represent first-order activation + dynamics. Note that a single name argument needs to be provided as SymPy + will use this as a suffix. + + >>> from sympy.physics.biomechanics import FirstOrderActivationDeGroote2016 + + >>> activation = FirstOrderActivationDeGroote2016('muscle') + >>> activation.x + Matrix([[a_muscle(t)]]) + >>> activation.r + Matrix([[e_muscle(t)]]) + >>> activation.p + Matrix([ + [tau_a_muscle], + [tau_d_muscle], + [ b_muscle]]) + >>> activation.rhs() + Matrix([[((1/2 - tanh(b_muscle*(-a_muscle(t) + e_muscle(t)))/2)*(3*...]]) + + The musculotendon class requires symbols or values to be passed to represent + the constants in the musculotendon dynamics. We'll use SymPy's ``symbols`` + function to create symbols for the maximum isometric force ``F_M_max``, + optimal fiber length ``l_M_opt``, tendon slack length ``l_T_slack``, maximum + fiber velocity ``v_M_max``, optimal pennation angle ``alpha_opt, and fiber + damping coefficient ``beta``. + + >>> F_M_max = symbols('F_M_max', real=True) + >>> l_M_opt = symbols('l_M_opt', real=True) + >>> l_T_slack = symbols('l_T_slack', real=True) + >>> v_M_max = symbols('v_M_max', real=True) + >>> alpha_opt = symbols('alpha_opt', real=True) + >>> beta = symbols('beta', real=True) + + We can then import the class ``MusculotendonDeGroote2016`` from the + biomechanics module and create an instance by passing in the various objects + we have previously instantiated. By default, a musculotendon model with + rigid tendon musculotendon dynamics will be created. + + >>> from sympy.physics.biomechanics import MusculotendonDeGroote2016 + + >>> rigid_tendon_muscle = MusculotendonDeGroote2016( + ... 'muscle', + ... pathway, + ... activation, + ... tendon_slack_length=l_T_slack, + ... peak_isometric_force=F_M_max, + ... optimal_fiber_length=l_M_opt, + ... maximal_fiber_velocity=v_M_max, + ... optimal_pennation_angle=alpha_opt, + ... fiber_damping_coefficient=beta, + ... ) + + We can inspect the various properties of the musculotendon, including + getting the symbolic expression describing the force it produces using its + ``force`` attribute. + + >>> rigid_tendon_muscle.force + -F_M_max*(beta*(-l_T_slack + Abs(q(t)))*sign(q(t))*Derivative(q(t), t)... + + When we created the musculotendon object, we passed in an instance of an + activation dynamics object that governs the activation within the + musculotendon. SymPy makes a design choice here that the activation dynamics + instance will be treated as a child object of the musculotendon dynamics. + Therefore, if we want to inspect the state and input variables associated + with the musculotendon model, we will also be returned the state and input + variables associated with the child object, or the activation dynamics in + this case. As the musculotendon model that we created here uses rigid tendon + dynamics, no additional states or inputs relating to the musculotendon are + introduces. Consequently, the model has a single state associated with it, + the activation, and a single input associated with it, the excitation. The + states and inputs can be inspected using the ``x`` and ``r`` attributes + respectively. Note that both ``x`` and ``r`` have the alias attributes of + ``state_vars`` and ``input_vars``. + + >>> rigid_tendon_muscle.x + Matrix([[a_muscle(t)]]) + >>> rigid_tendon_muscle.r + Matrix([[e_muscle(t)]]) + + To see which constants are symbolic in the musculotendon model, we can use + the ``p`` or ``constants`` attribute. This returns a ``Matrix`` populated + by the constants that are represented by a ``Symbol`` rather than a numeric + value. + + >>> rigid_tendon_muscle.p + Matrix([ + [ l_T_slack], + [ F_M_max], + [ l_M_opt], + [ v_M_max], + [ alpha_opt], + [ beta], + [ tau_a_muscle], + [ tau_d_muscle], + [ b_muscle], + [ c_0_fl_T_muscle], + [ c_1_fl_T_muscle], + [ c_2_fl_T_muscle], + [ c_3_fl_T_muscle], + [ c_0_fl_M_pas_muscle], + [ c_1_fl_M_pas_muscle], + [ c_0_fl_M_act_muscle], + [ c_1_fl_M_act_muscle], + [ c_2_fl_M_act_muscle], + [ c_3_fl_M_act_muscle], + [ c_4_fl_M_act_muscle], + [ c_5_fl_M_act_muscle], + [ c_6_fl_M_act_muscle], + [ c_7_fl_M_act_muscle], + [ c_8_fl_M_act_muscle], + [ c_9_fl_M_act_muscle], + [c_10_fl_M_act_muscle], + [c_11_fl_M_act_muscle], + [ c_0_fv_M_muscle], + [ c_1_fv_M_muscle], + [ c_2_fv_M_muscle], + [ c_3_fv_M_muscle]]) + + Finally, we can call the ``rhs`` method to return a ``Matrix`` that + contains as its elements the righthand side of the ordinary differential + equations corresponding to each of the musculotendon's states. Like the + method with the same name on the ``Method`` classes in SymPy's mechanics + module, this returns a column vector where the number of rows corresponds to + the number of states. For our example here, we have a single state, the + dynamic symbol ``a_muscle(t)``, so the returned value is a 1-by-1 + ``Matrix``. + + >>> rigid_tendon_muscle.rhs() + Matrix([[((1/2 - tanh(b_muscle*(-a_muscle(t) + e_muscle(t)))/2)*(3*...]]) + + The musculotendon class supports elastic tendon musculotendon models in + addition to rigid tendon ones. You can choose to either use the fiber length + or tendon force as an additional state. You can also specify whether an + explicit or implicit formulation should be used. To select a formulation, + pass a member of the ``MusculotendonFormulation`` enumeration to the + ``musculotendon_dynamics`` parameter when calling the constructor. This + enumeration is an ``IntEnum``, so you can also pass an integer, however it + is recommended to use the enumeration as it is clearer which formulation you + are actually selecting. Below, we'll use the ``FIBER_LENGTH_EXPLICIT`` + member to create a musculotendon with an elastic tendon that will use the + (normalized) muscle fiber length as an additional state and will produce + the governing ordinary differential equation in explicit form. + + >>> from sympy.physics.biomechanics import MusculotendonFormulation + + >>> elastic_tendon_muscle = MusculotendonDeGroote2016( + ... 'muscle', + ... pathway, + ... activation, + ... musculotendon_dynamics=MusculotendonFormulation.FIBER_LENGTH_EXPLICIT, + ... tendon_slack_length=l_T_slack, + ... peak_isometric_force=F_M_max, + ... optimal_fiber_length=l_M_opt, + ... maximal_fiber_velocity=v_M_max, + ... optimal_pennation_angle=alpha_opt, + ... fiber_damping_coefficient=beta, + ... ) + + >>> elastic_tendon_muscle.force + -F_M_max*TendonForceLengthDeGroote2016((-sqrt(l_M_opt**2*... + >>> elastic_tendon_muscle.x + Matrix([ + [l_M_tilde_muscle(t)], + [ a_muscle(t)]]) + >>> elastic_tendon_muscle.r + Matrix([[e_muscle(t)]]) + >>> elastic_tendon_muscle.p + Matrix([ + [ l_T_slack], + [ F_M_max], + [ l_M_opt], + [ v_M_max], + [ alpha_opt], + [ beta], + [ tau_a_muscle], + [ tau_d_muscle], + [ b_muscle], + [ c_0_fl_T_muscle], + [ c_1_fl_T_muscle], + [ c_2_fl_T_muscle], + [ c_3_fl_T_muscle], + [ c_0_fl_M_pas_muscle], + [ c_1_fl_M_pas_muscle], + [ c_0_fl_M_act_muscle], + [ c_1_fl_M_act_muscle], + [ c_2_fl_M_act_muscle], + [ c_3_fl_M_act_muscle], + [ c_4_fl_M_act_muscle], + [ c_5_fl_M_act_muscle], + [ c_6_fl_M_act_muscle], + [ c_7_fl_M_act_muscle], + [ c_8_fl_M_act_muscle], + [ c_9_fl_M_act_muscle], + [c_10_fl_M_act_muscle], + [c_11_fl_M_act_muscle], + [ c_0_fv_M_muscle], + [ c_1_fv_M_muscle], + [ c_2_fv_M_muscle], + [ c_3_fv_M_muscle]]) + >>> elastic_tendon_muscle.rhs() + Matrix([ + [v_M_max*FiberForceVelocityInverseDeGroote2016((l_M_opt*...], + [ ((1/2 - tanh(b_muscle*(-a_muscle(t) + e_muscle(t)))/2)*(3*...]]) + + It is strongly recommended to use the alternate ``with_defaults`` + constructor when creating an instance because this will ensure that the + published constants are used in the musculotendon characteristic curves. + + >>> elastic_tendon_muscle = MusculotendonDeGroote2016.with_defaults( + ... 'muscle', + ... pathway, + ... activation, + ... musculotendon_dynamics=MusculotendonFormulation.FIBER_LENGTH_EXPLICIT, + ... tendon_slack_length=l_T_slack, + ... peak_isometric_force=F_M_max, + ... optimal_fiber_length=l_M_opt, + ... ) + + >>> elastic_tendon_muscle.x + Matrix([ + [l_M_tilde_muscle(t)], + [ a_muscle(t)]]) + >>> elastic_tendon_muscle.r + Matrix([[e_muscle(t)]]) + >>> elastic_tendon_muscle.p + Matrix([ + [ l_T_slack], + [ F_M_max], + [ l_M_opt], + [tau_a_muscle], + [tau_d_muscle], + [ b_muscle]]) + + Parameters + ========== + + name : str + The name identifier associated with the musculotendon. This name is used + as a suffix when automatically generated symbols are instantiated. It + must be a string of nonzero length. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + activation_dynamics : ActivationBase + The activation dynamics that will be modeled within the musculotendon. + This must be an instance of a concrete subclass of ``ActivationBase``, + e.g. ``FirstOrderActivationDeGroote2016``. + musculotendon_dynamics : MusculotendonFormulation | int + The formulation of musculotendon dynamics that should be used + internally, i.e. rigid or elastic tendon model, the choice of + musculotendon state etc. This must be a member of the integer + enumeration ``MusculotendonFormulation`` or an integer that can be cast + to a member. To use a rigid tendon formulation, set this to + ``MusculotendonFormulation.RIGID_TENDON`` (or the integer value ``0``, + which will be cast to the enumeration member). There are four possible + formulations for an elastic tendon model. To use an explicit formulation + with the fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_EXPLICIT`` (or the integer value + ``1``). To use an explicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_EXPLICIT`` + (or the integer value ``2``). To use an implicit formulation with the + fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_IMPLICIT`` (or the integer value + ``3``). To use an implicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_IMPLICIT`` + (or the integer value ``4``). The default is + ``MusculotendonFormulation.RIGID_TENDON``, which corresponds to a rigid + tendon formulation. + tendon_slack_length : Expr | None + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + peak_isometric_force : Expr | None + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In all + musculotendon models, peak isometric force is used to normalized tendon + and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + optimal_fiber_length : Expr | None + The muscle fiber length at which the muscle fibers produce no passive + force and their maximum active force. In all musculotendon models, + optimal fiber length is used to normalize muscle fiber length to give + :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + maximal_fiber_velocity : Expr | None + The fiber velocity at which, during muscle fiber shortening, the muscle + fibers are unable to produce any active force. In all musculotendon + models, maximal fiber velocity is used to normalize muscle fiber + extension velocity to give :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + optimal_pennation_angle : Expr | None + The pennation angle when muscle fiber length equals the optimal fiber + length. + fiber_damping_coefficient : Expr | None + The coefficient of damping to be used in the damping element in the + muscle fiber model. + with_defaults : bool + Whether ``with_defaults`` alternate constructors should be used when + automatically constructing child classes. Default is ``False``. + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + curves = CharacteristicCurveCollection( + tendon_force_length=TendonForceLengthDeGroote2016, + tendon_force_length_inverse=TendonForceLengthInverseDeGroote2016, + fiber_force_length_passive=FiberForceLengthPassiveDeGroote2016, + fiber_force_length_passive_inverse=FiberForceLengthPassiveInverseDeGroote2016, + fiber_force_length_active=FiberForceLengthActiveDeGroote2016, + fiber_force_velocity=FiberForceVelocityDeGroote2016, + fiber_force_velocity_inverse=FiberForceVelocityInverseDeGroote2016, + ) diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__init__.py b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d38cb1cd2a2bd7a4024e203800105e64c2a59db Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/test_activation.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/test_activation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9a7b36949700dc32dec083ed2f8df04fb62e7e7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/test_activation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/test_curve.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/test_curve.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df76cf20acbdfa9105c5eae9328774edd367ece2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/test_curve.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/test_mixin.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/test_mixin.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28ace7b7ae0a6d501ee8fce58e6e8628737dd04e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/test_mixin.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/test_musculotendon.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/test_musculotendon.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61be900fa229bcbbdba30309efd26a1332f6899f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/__pycache__/test_musculotendon.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/test_activation.py b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/test_activation.py new file mode 100644 index 0000000000000000000000000000000000000000..a38742f0d42af48dff95295eae869b2c5ef269de --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/test_activation.py @@ -0,0 +1,348 @@ +"""Tests for the ``sympy.physics.biomechanics.activation.py`` module.""" + +import pytest + +from sympy import Symbol +from sympy.core.numbers import Float, Integer, Rational +from sympy.functions.elementary.hyperbolic import tanh +from sympy.matrices import Matrix +from sympy.matrices.dense import zeros +from sympy.physics.mechanics import dynamicsymbols +from sympy.physics.biomechanics import ( + ActivationBase, + FirstOrderActivationDeGroote2016, + ZerothOrderActivation, +) +from sympy.physics.biomechanics._mixin import _NamedMixin +from sympy.simplify.simplify import simplify + + +class TestZerothOrderActivation: + + @staticmethod + def test_class(): + assert issubclass(ZerothOrderActivation, ActivationBase) + assert issubclass(ZerothOrderActivation, _NamedMixin) + assert ZerothOrderActivation.__name__ == 'ZerothOrderActivation' + + @pytest.fixture(autouse=True) + def _zeroth_order_activation_fixture(self): + self.name = 'name' + self.e = dynamicsymbols('e_name') + self.instance = ZerothOrderActivation(self.name) + + def test_instance(self): + instance = ZerothOrderActivation(self.name) + assert isinstance(instance, ZerothOrderActivation) + + def test_with_defaults(self): + instance = ZerothOrderActivation.with_defaults(self.name) + assert isinstance(instance, ZerothOrderActivation) + assert instance == ZerothOrderActivation(self.name) + + def test_name(self): + assert hasattr(self.instance, 'name') + assert self.instance.name == self.name + + def test_order(self): + assert hasattr(self.instance, 'order') + assert self.instance.order == 0 + + def test_excitation_attribute(self): + assert hasattr(self.instance, 'e') + assert hasattr(self.instance, 'excitation') + e_expected = dynamicsymbols('e_name') + assert self.instance.e == e_expected + assert self.instance.excitation == e_expected + assert self.instance.e is self.instance.excitation + + def test_activation_attribute(self): + assert hasattr(self.instance, 'a') + assert hasattr(self.instance, 'activation') + a_expected = dynamicsymbols('e_name') + assert self.instance.a == a_expected + assert self.instance.activation == a_expected + assert self.instance.a is self.instance.activation is self.instance.e + + def test_state_vars_attribute(self): + assert hasattr(self.instance, 'x') + assert hasattr(self.instance, 'state_vars') + assert self.instance.x == self.instance.state_vars + x_expected = zeros(0, 1) + assert self.instance.x == x_expected + assert self.instance.state_vars == x_expected + assert isinstance(self.instance.x, Matrix) + assert isinstance(self.instance.state_vars, Matrix) + assert self.instance.x.shape == (0, 1) + assert self.instance.state_vars.shape == (0, 1) + + def test_input_vars_attribute(self): + assert hasattr(self.instance, 'r') + assert hasattr(self.instance, 'input_vars') + assert self.instance.r == self.instance.input_vars + r_expected = Matrix([self.e]) + assert self.instance.r == r_expected + assert self.instance.input_vars == r_expected + assert isinstance(self.instance.r, Matrix) + assert isinstance(self.instance.input_vars, Matrix) + assert self.instance.r.shape == (1, 1) + assert self.instance.input_vars.shape == (1, 1) + + def test_constants_attribute(self): + assert hasattr(self.instance, 'p') + assert hasattr(self.instance, 'constants') + assert self.instance.p == self.instance.constants + p_expected = zeros(0, 1) + assert self.instance.p == p_expected + assert self.instance.constants == p_expected + assert isinstance(self.instance.p, Matrix) + assert isinstance(self.instance.constants, Matrix) + assert self.instance.p.shape == (0, 1) + assert self.instance.constants.shape == (0, 1) + + def test_M_attribute(self): + assert hasattr(self.instance, 'M') + M_expected = Matrix([]) + assert self.instance.M == M_expected + assert isinstance(self.instance.M, Matrix) + assert self.instance.M.shape == (0, 0) + + def test_F(self): + assert hasattr(self.instance, 'F') + F_expected = zeros(0, 1) + assert self.instance.F == F_expected + assert isinstance(self.instance.F, Matrix) + assert self.instance.F.shape == (0, 1) + + def test_rhs(self): + assert hasattr(self.instance, 'rhs') + rhs_expected = zeros(0, 1) + rhs = self.instance.rhs() + assert rhs == rhs_expected + assert isinstance(rhs, Matrix) + assert rhs.shape == (0, 1) + + def test_repr(self): + expected = 'ZerothOrderActivation(\'name\')' + assert repr(self.instance) == expected + + +class TestFirstOrderActivationDeGroote2016: + + @staticmethod + def test_class(): + assert issubclass(FirstOrderActivationDeGroote2016, ActivationBase) + assert issubclass(FirstOrderActivationDeGroote2016, _NamedMixin) + assert FirstOrderActivationDeGroote2016.__name__ == 'FirstOrderActivationDeGroote2016' + + @pytest.fixture(autouse=True) + def _first_order_activation_de_groote_2016_fixture(self): + self.name = 'name' + self.e = dynamicsymbols('e_name') + self.a = dynamicsymbols('a_name') + self.tau_a = Symbol('tau_a') + self.tau_d = Symbol('tau_d') + self.b = Symbol('b') + self.instance = FirstOrderActivationDeGroote2016( + self.name, + self.tau_a, + self.tau_d, + self.b, + ) + + def test_instance(self): + instance = FirstOrderActivationDeGroote2016(self.name) + assert isinstance(instance, FirstOrderActivationDeGroote2016) + + def test_with_defaults(self): + instance = FirstOrderActivationDeGroote2016.with_defaults(self.name) + assert isinstance(instance, FirstOrderActivationDeGroote2016) + assert instance.tau_a == Float('0.015') + assert instance.activation_time_constant == Float('0.015') + assert instance.tau_d == Float('0.060') + assert instance.deactivation_time_constant == Float('0.060') + assert instance.b == Float('10.0') + assert instance.smoothing_rate == Float('10.0') + + def test_name(self): + assert hasattr(self.instance, 'name') + assert self.instance.name == self.name + + def test_order(self): + assert hasattr(self.instance, 'order') + assert self.instance.order == 1 + + def test_excitation(self): + assert hasattr(self.instance, 'e') + assert hasattr(self.instance, 'excitation') + e_expected = dynamicsymbols('e_name') + assert self.instance.e == e_expected + assert self.instance.excitation == e_expected + assert self.instance.e is self.instance.excitation + + def test_excitation_is_immutable(self): + with pytest.raises(AttributeError): + self.instance.e = None + with pytest.raises(AttributeError): + self.instance.excitation = None + + def test_activation(self): + assert hasattr(self.instance, 'a') + assert hasattr(self.instance, 'activation') + a_expected = dynamicsymbols('a_name') + assert self.instance.a == a_expected + assert self.instance.activation == a_expected + + def test_activation_is_immutable(self): + with pytest.raises(AttributeError): + self.instance.a = None + with pytest.raises(AttributeError): + self.instance.activation = None + + @pytest.mark.parametrize( + 'tau_a, expected', + [ + (None, Symbol('tau_a_name')), + (Symbol('tau_a'), Symbol('tau_a')), + (Float('0.015'), Float('0.015')), + ] + ) + def test_activation_time_constant(self, tau_a, expected): + instance = FirstOrderActivationDeGroote2016( + 'name', activation_time_constant=tau_a, + ) + assert instance.tau_a == expected + assert instance.activation_time_constant == expected + assert instance.tau_a is instance.activation_time_constant + + def test_activation_time_constant_is_immutable(self): + with pytest.raises(AttributeError): + self.instance.tau_a = None + with pytest.raises(AttributeError): + self.instance.activation_time_constant = None + + @pytest.mark.parametrize( + 'tau_d, expected', + [ + (None, Symbol('tau_d_name')), + (Symbol('tau_d'), Symbol('tau_d')), + (Float('0.060'), Float('0.060')), + ] + ) + def test_deactivation_time_constant(self, tau_d, expected): + instance = FirstOrderActivationDeGroote2016( + 'name', deactivation_time_constant=tau_d, + ) + assert instance.tau_d == expected + assert instance.deactivation_time_constant == expected + assert instance.tau_d is instance.deactivation_time_constant + + def test_deactivation_time_constant_is_immutable(self): + with pytest.raises(AttributeError): + self.instance.tau_d = None + with pytest.raises(AttributeError): + self.instance.deactivation_time_constant = None + + @pytest.mark.parametrize( + 'b, expected', + [ + (None, Symbol('b_name')), + (Symbol('b'), Symbol('b')), + (Integer('10'), Integer('10')), + ] + ) + def test_smoothing_rate(self, b, expected): + instance = FirstOrderActivationDeGroote2016( + 'name', smoothing_rate=b, + ) + assert instance.b == expected + assert instance.smoothing_rate == expected + assert instance.b is instance.smoothing_rate + + def test_smoothing_rate_is_immutable(self): + with pytest.raises(AttributeError): + self.instance.b = None + with pytest.raises(AttributeError): + self.instance.smoothing_rate = None + + def test_state_vars(self): + assert hasattr(self.instance, 'x') + assert hasattr(self.instance, 'state_vars') + assert self.instance.x == self.instance.state_vars + x_expected = Matrix([self.a]) + assert self.instance.x == x_expected + assert self.instance.state_vars == x_expected + assert isinstance(self.instance.x, Matrix) + assert isinstance(self.instance.state_vars, Matrix) + assert self.instance.x.shape == (1, 1) + assert self.instance.state_vars.shape == (1, 1) + + def test_input_vars(self): + assert hasattr(self.instance, 'r') + assert hasattr(self.instance, 'input_vars') + assert self.instance.r == self.instance.input_vars + r_expected = Matrix([self.e]) + assert self.instance.r == r_expected + assert self.instance.input_vars == r_expected + assert isinstance(self.instance.r, Matrix) + assert isinstance(self.instance.input_vars, Matrix) + assert self.instance.r.shape == (1, 1) + assert self.instance.input_vars.shape == (1, 1) + + def test_constants(self): + assert hasattr(self.instance, 'p') + assert hasattr(self.instance, 'constants') + assert self.instance.p == self.instance.constants + p_expected = Matrix([self.tau_a, self.tau_d, self.b]) + assert self.instance.p == p_expected + assert self.instance.constants == p_expected + assert isinstance(self.instance.p, Matrix) + assert isinstance(self.instance.constants, Matrix) + assert self.instance.p.shape == (3, 1) + assert self.instance.constants.shape == (3, 1) + + def test_M(self): + assert hasattr(self.instance, 'M') + M_expected = Matrix([1]) + assert self.instance.M == M_expected + assert isinstance(self.instance.M, Matrix) + assert self.instance.M.shape == (1, 1) + + def test_F(self): + assert hasattr(self.instance, 'F') + da_expr = ( + ((1/(self.tau_a*(Rational(1, 2) + Rational(3, 2)*self.a))) + *(Rational(1, 2) + Rational(1, 2)*tanh(self.b*(self.e - self.a))) + + ((Rational(1, 2) + Rational(3, 2)*self.a)/self.tau_d) + *(Rational(1, 2) - Rational(1, 2)*tanh(self.b*(self.e - self.a)))) + *(self.e - self.a) + ) + F_expected = Matrix([da_expr]) + assert self.instance.F == F_expected + assert isinstance(self.instance.F, Matrix) + assert self.instance.F.shape == (1, 1) + + def test_rhs(self): + assert hasattr(self.instance, 'rhs') + da_expr = ( + ((1/(self.tau_a*(Rational(1, 2) + Rational(3, 2)*self.a))) + *(Rational(1, 2) + Rational(1, 2)*tanh(self.b*(self.e - self.a))) + + ((Rational(1, 2) + Rational(3, 2)*self.a)/self.tau_d) + *(Rational(1, 2) - Rational(1, 2)*tanh(self.b*(self.e - self.a)))) + *(self.e - self.a) + ) + rhs_expected = Matrix([da_expr]) + rhs = self.instance.rhs() + assert rhs == rhs_expected + assert isinstance(rhs, Matrix) + assert rhs.shape == (1, 1) + assert simplify(self.instance.M.solve(self.instance.F) - rhs) == zeros(1) + + def test_repr(self): + expected = ( + 'FirstOrderActivationDeGroote2016(\'name\', ' + 'activation_time_constant=tau_a, ' + 'deactivation_time_constant=tau_d, ' + 'smoothing_rate=b)' + ) + assert repr(self.instance) == expected diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/test_curve.py b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/test_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8fcbccdb8b4190376b051093b376e936d9d5d3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/test_curve.py @@ -0,0 +1,1695 @@ +"""Tests for the ``sympy.physics.biomechanics.characteristic.py`` module.""" + +import pytest + +from sympy.core.expr import UnevaluatedExpr +from sympy.core.function import Function +from sympy.core.numbers import Float, Integer +from sympy.core.symbol import Symbol, symbols +from sympy.external.importtools import import_module +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.hyperbolic import cosh, sinh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.biomechanics.curve import ( + CharacteristicCurveCollection, + CharacteristicCurveFunction, + FiberForceLengthActiveDeGroote2016, + FiberForceLengthPassiveDeGroote2016, + FiberForceLengthPassiveInverseDeGroote2016, + FiberForceVelocityDeGroote2016, + FiberForceVelocityInverseDeGroote2016, + TendonForceLengthDeGroote2016, + TendonForceLengthInverseDeGroote2016, +) +from sympy.printing.c import C89CodePrinter, C99CodePrinter, C11CodePrinter +from sympy.printing.cxx import ( + CXX98CodePrinter, + CXX11CodePrinter, + CXX17CodePrinter, +) +from sympy.printing.fortran import FCodePrinter +from sympy.printing.lambdarepr import LambdaPrinter +from sympy.printing.latex import LatexPrinter +from sympy.printing.octave import OctaveCodePrinter +from sympy.printing.numpy import ( + CuPyPrinter, + JaxPrinter, + NumPyPrinter, + SciPyPrinter, +) +from sympy.printing.pycode import MpmathPrinter, PythonCodePrinter +from sympy.utilities.lambdify import lambdify + +jax = import_module('jax') +numpy = import_module('numpy') + +if jax: + jax.config.update('jax_enable_x64', True) + + +class TestCharacteristicCurveFunction: + + @staticmethod + @pytest.mark.parametrize( + 'code_printer, expected', + [ + (C89CodePrinter, '(a + b)*(c + d)*(e + f)'), + (C99CodePrinter, '(a + b)*(c + d)*(e + f)'), + (C11CodePrinter, '(a + b)*(c + d)*(e + f)'), + (CXX98CodePrinter, '(a + b)*(c + d)*(e + f)'), + (CXX11CodePrinter, '(a + b)*(c + d)*(e + f)'), + (CXX17CodePrinter, '(a + b)*(c + d)*(e + f)'), + (FCodePrinter, ' (a + b)*(c + d)*(e + f)'), + (OctaveCodePrinter, '(a + b).*(c + d).*(e + f)'), + (PythonCodePrinter, '(a + b)*(c + d)*(e + f)'), + (NumPyPrinter, '(a + b)*(c + d)*(e + f)'), + (SciPyPrinter, '(a + b)*(c + d)*(e + f)'), + (CuPyPrinter, '(a + b)*(c + d)*(e + f)'), + (JaxPrinter, '(a + b)*(c + d)*(e + f)'), + (MpmathPrinter, '(a + b)*(c + d)*(e + f)'), + (LambdaPrinter, '(a + b)*(c + d)*(e + f)'), + ] + ) + def test_print_code_parenthesize(code_printer, expected): + + class ExampleFunction(CharacteristicCurveFunction): + + @classmethod + def eval(cls, a, b): + pass + + def doit(self, **kwargs): + a, b = self.args + return a + b + + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + f1 = ExampleFunction(a, b) + f2 = ExampleFunction(c, d) + f3 = ExampleFunction(e, f) + assert code_printer().doprint(f1*f2*f3) == expected + + +class TestTendonForceLengthDeGroote2016: + + @pytest.fixture(autouse=True) + def _tendon_force_length_arguments_fixture(self): + self.l_T_tilde = Symbol('l_T_tilde') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.c2 = Symbol('c_2') + self.c3 = Symbol('c_3') + self.constants = (self.c0, self.c1, self.c2, self.c3) + + @staticmethod + def test_class(): + assert issubclass(TendonForceLengthDeGroote2016, Function) + assert issubclass(TendonForceLengthDeGroote2016, CharacteristicCurveFunction) + assert TendonForceLengthDeGroote2016.__name__ == 'TendonForceLengthDeGroote2016' + + def test_instance(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + assert isinstance(fl_T, TendonForceLengthDeGroote2016) + assert str(fl_T) == 'TendonForceLengthDeGroote2016(l_T_tilde, c_0, c_1, c_2, c_3)' + + def test_doit(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants).doit() + assert fl_T == self.c0*exp(self.c3*(self.l_T_tilde - self.c1)) - self.c2 + + def test_doit_evaluate_false(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants).doit(evaluate=False) + assert fl_T == self.c0*exp(self.c3*UnevaluatedExpr(self.l_T_tilde - self.c1)) - self.c2 + + def test_with_defaults(self): + constants = ( + Float('0.2'), + Float('0.995'), + Float('0.25'), + Float('33.93669377311689'), + ) + fl_T_manual = TendonForceLengthDeGroote2016(self.l_T_tilde, *constants) + fl_T_constants = TendonForceLengthDeGroote2016.with_defaults(self.l_T_tilde) + assert fl_T_manual == fl_T_constants + + def test_differentiate_wrt_l_T_tilde(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = self.c0*self.c3*exp(self.c3*UnevaluatedExpr(-self.c1 + self.l_T_tilde)) + assert fl_T.diff(self.l_T_tilde) == expected + + def test_differentiate_wrt_c0(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = exp(self.c3*UnevaluatedExpr(-self.c1 + self.l_T_tilde)) + assert fl_T.diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = -self.c0*self.c3*exp(self.c3*UnevaluatedExpr(self.l_T_tilde - self.c1)) + assert fl_T.diff(self.c1) == expected + + def test_differentiate_wrt_c2(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = Integer(-1) + assert fl_T.diff(self.c2) == expected + + def test_differentiate_wrt_c3(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = self.c0*(self.l_T_tilde - self.c1)*exp(self.c3*UnevaluatedExpr(self.l_T_tilde - self.c1)) + assert fl_T.diff(self.c3) == expected + + def test_inverse(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + assert fl_T.inverse() is TendonForceLengthInverseDeGroote2016 + + def test_function_print_latex(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = r'\operatorname{fl}^T \left( l_{T tilde} \right)' + assert LatexPrinter().doprint(fl_T) == expected + + def test_expression_print_latex(self): + fl_T = TendonForceLengthDeGroote2016(self.l_T_tilde, *self.constants) + expected = r'c_{0} e^{c_{3} \left(- c_{1} + l_{T tilde}\right)} - c_{2}' + assert LatexPrinter().doprint(fl_T.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + '(-0.25 + 0.20000000000000001*exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + C99CodePrinter, + '(-0.25 + 0.20000000000000001*exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + C11CodePrinter, + '(-0.25 + 0.20000000000000001*exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + CXX98CodePrinter, + '(-0.25 + 0.20000000000000001*exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + CXX11CodePrinter, + '(-0.25 + 0.20000000000000001*std::exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + CXX17CodePrinter, + '(-0.25 + 0.20000000000000001*std::exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + FCodePrinter, + ' (-0.25d0 + 0.2d0*exp(33.93669377311689d0*(l_T_tilde - 0.995d0)))', + ), + ( + OctaveCodePrinter, + '(-0.25 + 0.2*exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + PythonCodePrinter, + '(-0.25 + 0.2*math.exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + NumPyPrinter, + '(-0.25 + 0.2*numpy.exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + SciPyPrinter, + '(-0.25 + 0.2*numpy.exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + CuPyPrinter, + '(-0.25 + 0.2*cupy.exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + JaxPrinter, + '(-0.25 + 0.2*jax.numpy.exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ( + MpmathPrinter, + '(mpmath.mpf((1, 1, -2, 1)) + mpmath.mpf((0, 3602879701896397, -54, 52))' + '*mpmath.exp(mpmath.mpf((0, 9552330089424741, -48, 54))*(l_T_tilde + ' + 'mpmath.mpf((1, 8962163258467287, -53, 53)))))', + ), + ( + LambdaPrinter, + '(-0.25 + 0.2*math.exp(33.93669377311689*(l_T_tilde - 0.995)))', + ), + ] + ) + def test_print_code(self, code_printer, expected): + fl_T = TendonForceLengthDeGroote2016.with_defaults(self.l_T_tilde) + assert code_printer().doprint(fl_T) == expected + + def test_derivative_print_code(self): + fl_T = TendonForceLengthDeGroote2016.with_defaults(self.l_T_tilde) + dfl_T_dl_T_tilde = fl_T.diff(self.l_T_tilde) + expected = '6.787338754623378*math.exp(33.93669377311689*(l_T_tilde - 0.995))' + assert PythonCodePrinter().doprint(dfl_T_dl_T_tilde) == expected + + def test_lambdify(self): + fl_T = TendonForceLengthDeGroote2016.with_defaults(self.l_T_tilde) + fl_T_callable = lambdify(self.l_T_tilde, fl_T) + assert fl_T_callable(1.0) == pytest.approx(-0.013014055039221595) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fl_T = TendonForceLengthDeGroote2016.with_defaults(self.l_T_tilde) + fl_T_callable = lambdify(self.l_T_tilde, fl_T, 'numpy') + l_T_tilde = numpy.array([0.95, 1.0, 1.01, 1.05]) + expected = numpy.array([ + -0.2065693181344816, + -0.0130140550392216, + 0.0827421191989246, + 1.04314889144172, + ]) + numpy.testing.assert_allclose(fl_T_callable(l_T_tilde), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fl_T = TendonForceLengthDeGroote2016.with_defaults(self.l_T_tilde) + fl_T_callable = jax.jit(lambdify(self.l_T_tilde, fl_T, 'jax')) + l_T_tilde = jax.numpy.array([0.95, 1.0, 1.01, 1.05]) + expected = jax.numpy.array([ + -0.2065693181344816, + -0.0130140550392216, + 0.0827421191989246, + 1.04314889144172, + ]) + numpy.testing.assert_allclose(fl_T_callable(l_T_tilde), expected) + + +class TestTendonForceLengthInverseDeGroote2016: + + @pytest.fixture(autouse=True) + def _tendon_force_length_inverse_arguments_fixture(self): + self.fl_T = Symbol('fl_T') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.c2 = Symbol('c_2') + self.c3 = Symbol('c_3') + self.constants = (self.c0, self.c1, self.c2, self.c3) + + @staticmethod + def test_class(): + assert issubclass(TendonForceLengthInverseDeGroote2016, Function) + assert issubclass(TendonForceLengthInverseDeGroote2016, CharacteristicCurveFunction) + assert TendonForceLengthInverseDeGroote2016.__name__ == 'TendonForceLengthInverseDeGroote2016' + + def test_instance(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + assert isinstance(fl_T_inv, TendonForceLengthInverseDeGroote2016) + assert str(fl_T_inv) == 'TendonForceLengthInverseDeGroote2016(fl_T, c_0, c_1, c_2, c_3)' + + def test_doit(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants).doit() + assert fl_T_inv == log((self.fl_T + self.c2)/self.c0)/self.c3 + self.c1 + + def test_doit_evaluate_false(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants).doit(evaluate=False) + assert fl_T_inv == log(UnevaluatedExpr((self.fl_T + self.c2)/self.c0))/self.c3 + self.c1 + + def test_with_defaults(self): + constants = ( + Float('0.2'), + Float('0.995'), + Float('0.25'), + Float('33.93669377311689'), + ) + fl_T_inv_manual = TendonForceLengthInverseDeGroote2016(self.fl_T, *constants) + fl_T_inv_constants = TendonForceLengthInverseDeGroote2016.with_defaults(self.fl_T) + assert fl_T_inv_manual == fl_T_inv_constants + + def test_differentiate_wrt_fl_T(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = 1/(self.c3*(self.fl_T + self.c2)) + assert fl_T_inv.diff(self.fl_T) == expected + + def test_differentiate_wrt_c0(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = -1/(self.c0*self.c3) + assert fl_T_inv.diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = Integer(1) + assert fl_T_inv.diff(self.c1) == expected + + def test_differentiate_wrt_c2(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = 1/(self.c3*(self.fl_T + self.c2)) + assert fl_T_inv.diff(self.c2) == expected + + def test_differentiate_wrt_c3(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = -log(UnevaluatedExpr((self.fl_T + self.c2)/self.c0))/self.c3**2 + assert fl_T_inv.diff(self.c3) == expected + + def test_inverse(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + assert fl_T_inv.inverse() is TendonForceLengthDeGroote2016 + + def test_function_print_latex(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = r'\left( \operatorname{fl}^T \right)^{-1} \left( fl_{T} \right)' + assert LatexPrinter().doprint(fl_T_inv) == expected + + def test_expression_print_latex(self): + fl_T = TendonForceLengthInverseDeGroote2016(self.fl_T, *self.constants) + expected = r'c_{1} + \frac{\log{\left(\frac{c_{2} + fl_{T}}{c_{0}} \right)}}{c_{3}}' + assert LatexPrinter().doprint(fl_T.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + '(0.995 + 0.029466630034306838*log(5.0*fl_T + 1.25))', + ), + ( + C99CodePrinter, + '(0.995 + 0.029466630034306838*log(5.0*fl_T + 1.25))', + ), + ( + C11CodePrinter, + '(0.995 + 0.029466630034306838*log(5.0*fl_T + 1.25))', + ), + ( + CXX98CodePrinter, + '(0.995 + 0.029466630034306838*log(5.0*fl_T + 1.25))', + ), + ( + CXX11CodePrinter, + '(0.995 + 0.029466630034306838*std::log(5.0*fl_T + 1.25))', + ), + ( + CXX17CodePrinter, + '(0.995 + 0.029466630034306838*std::log(5.0*fl_T + 1.25))', + ), + ( + FCodePrinter, + ' (0.995d0 + 0.02946663003430684d0*log(5.0d0*fl_T + 1.25d0))', + ), + ( + OctaveCodePrinter, + '(0.995 + 0.02946663003430684*log(5.0*fl_T + 1.25))', + ), + ( + PythonCodePrinter, + '(0.995 + 0.02946663003430684*math.log(5.0*fl_T + 1.25))', + ), + ( + NumPyPrinter, + '(0.995 + 0.02946663003430684*numpy.log(5.0*fl_T + 1.25))', + ), + ( + SciPyPrinter, + '(0.995 + 0.02946663003430684*numpy.log(5.0*fl_T + 1.25))', + ), + ( + CuPyPrinter, + '(0.995 + 0.02946663003430684*cupy.log(5.0*fl_T + 1.25))', + ), + ( + JaxPrinter, + '(0.995 + 0.02946663003430684*jax.numpy.log(5.0*fl_T + 1.25))', + ), + ( + MpmathPrinter, + '(mpmath.mpf((0, 8962163258467287, -53, 53))' + ' + mpmath.mpf((0, 33972711434846347, -60, 55))' + '*mpmath.log(mpmath.mpf((0, 5, 0, 3))*fl_T + mpmath.mpf((0, 5, -2, 3))))', + ), + ( + LambdaPrinter, + '(0.995 + 0.02946663003430684*math.log(5.0*fl_T + 1.25))', + ), + ] + ) + def test_print_code(self, code_printer, expected): + fl_T_inv = TendonForceLengthInverseDeGroote2016.with_defaults(self.fl_T) + assert code_printer().doprint(fl_T_inv) == expected + + def test_derivative_print_code(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016.with_defaults(self.fl_T) + dfl_T_inv_dfl_T = fl_T_inv.diff(self.fl_T) + expected = '1/(33.93669377311689*fl_T + 8.484173443279222)' + assert PythonCodePrinter().doprint(dfl_T_inv_dfl_T) == expected + + def test_lambdify(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016.with_defaults(self.fl_T) + fl_T_inv_callable = lambdify(self.fl_T, fl_T_inv) + assert fl_T_inv_callable(0.0) == pytest.approx(1.0015752885) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016.with_defaults(self.fl_T) + fl_T_inv_callable = lambdify(self.fl_T, fl_T_inv, 'numpy') + fl_T = numpy.array([-0.2, -0.01, 0.0, 1.01, 1.02, 1.05]) + expected = numpy.array([ + 0.9541505769, + 1.0003724019, + 1.0015752885, + 1.0492347951, + 1.0494677341, + 1.0501557022, + ]) + numpy.testing.assert_allclose(fl_T_inv_callable(fl_T), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fl_T_inv = TendonForceLengthInverseDeGroote2016.with_defaults(self.fl_T) + fl_T_inv_callable = jax.jit(lambdify(self.fl_T, fl_T_inv, 'jax')) + fl_T = jax.numpy.array([-0.2, -0.01, 0.0, 1.01, 1.02, 1.05]) + expected = jax.numpy.array([ + 0.9541505769, + 1.0003724019, + 1.0015752885, + 1.0492347951, + 1.0494677341, + 1.0501557022, + ]) + numpy.testing.assert_allclose(fl_T_inv_callable(fl_T), expected) + + +class TestFiberForceLengthPassiveDeGroote2016: + + @pytest.fixture(autouse=True) + def _fiber_force_length_passive_arguments_fixture(self): + self.l_M_tilde = Symbol('l_M_tilde') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.constants = (self.c0, self.c1) + + @staticmethod + def test_class(): + assert issubclass(FiberForceLengthPassiveDeGroote2016, Function) + assert issubclass(FiberForceLengthPassiveDeGroote2016, CharacteristicCurveFunction) + assert FiberForceLengthPassiveDeGroote2016.__name__ == 'FiberForceLengthPassiveDeGroote2016' + + def test_instance(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + assert isinstance(fl_M_pas, FiberForceLengthPassiveDeGroote2016) + assert str(fl_M_pas) == 'FiberForceLengthPassiveDeGroote2016(l_M_tilde, c_0, c_1)' + + def test_doit(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants).doit() + assert fl_M_pas == (exp((self.c1*(self.l_M_tilde - 1))/self.c0) - 1)/(exp(self.c1) - 1) + + def test_doit_evaluate_false(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants).doit(evaluate=False) + assert fl_M_pas == (exp((self.c1*UnevaluatedExpr(self.l_M_tilde - 1))/self.c0) - 1)/(exp(self.c1) - 1) + + def test_with_defaults(self): + constants = ( + Float('0.6'), + Float('4.0'), + ) + fl_M_pas_manual = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *constants) + fl_M_pas_constants = FiberForceLengthPassiveDeGroote2016.with_defaults(self.l_M_tilde) + assert fl_M_pas_manual == fl_M_pas_constants + + def test_differentiate_wrt_l_M_tilde(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = self.c1*exp(self.c1*UnevaluatedExpr(self.l_M_tilde - 1)/self.c0)/(self.c0*(exp(self.c1) - 1)) + assert fl_M_pas.diff(self.l_M_tilde) == expected + + def test_differentiate_wrt_c0(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + -self.c1*exp(self.c1*UnevaluatedExpr(self.l_M_tilde - 1)/self.c0) + *UnevaluatedExpr(self.l_M_tilde - 1)/(self.c0**2*(exp(self.c1) - 1)) + ) + assert fl_M_pas.diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + -exp(self.c1)*(-1 + exp(self.c1*UnevaluatedExpr(self.l_M_tilde - 1)/self.c0))/(exp(self.c1) - 1)**2 + + exp(self.c1*UnevaluatedExpr(self.l_M_tilde - 1)/self.c0)*(self.l_M_tilde - 1)/(self.c0*(exp(self.c1) - 1)) + ) + assert fl_M_pas.diff(self.c1) == expected + + def test_inverse(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + assert fl_M_pas.inverse() is FiberForceLengthPassiveInverseDeGroote2016 + + def test_function_print_latex(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = r'\operatorname{fl}^M_{pas} \left( l_{M tilde} \right)' + assert LatexPrinter().doprint(fl_M_pas) == expected + + def test_expression_print_latex(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = r'\frac{e^{\frac{c_{1} \left(l_{M tilde} - 1\right)}{c_{0}}} - 1}{e^{c_{1}} - 1}' + assert LatexPrinter().doprint(fl_M_pas.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + '(0.01865736036377405*(-1 + exp(6.666666666666667*(l_M_tilde - 1))))', + ), + ( + C99CodePrinter, + '(0.01865736036377405*(-1 + exp(6.666666666666667*(l_M_tilde - 1))))', + ), + ( + C11CodePrinter, + '(0.01865736036377405*(-1 + exp(6.666666666666667*(l_M_tilde - 1))))', + ), + ( + CXX98CodePrinter, + '(0.01865736036377405*(-1 + exp(6.666666666666667*(l_M_tilde - 1))))', + ), + ( + CXX11CodePrinter, + '(0.01865736036377405*(-1 + std::exp(6.666666666666667*(l_M_tilde - 1))))', + ), + ( + CXX17CodePrinter, + '(0.01865736036377405*(-1 + std::exp(6.666666666666667*(l_M_tilde - 1))))', + ), + ( + FCodePrinter, + ' (0.0186573603637741d0*(-1 + exp(6.666666666666667d0*(l_M_tilde - 1\n' + ' @ ))))', + ), + ( + OctaveCodePrinter, + '(0.0186573603637741*(-1 + exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ( + PythonCodePrinter, + '(0.0186573603637741*(-1 + math.exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ( + NumPyPrinter, + '(0.0186573603637741*(-1 + numpy.exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ( + SciPyPrinter, + '(0.0186573603637741*(-1 + numpy.exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ( + CuPyPrinter, + '(0.0186573603637741*(-1 + cupy.exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ( + JaxPrinter, + '(0.0186573603637741*(-1 + jax.numpy.exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ( + MpmathPrinter, + '(mpmath.mpf((0, 672202249456079, -55, 50))*(-1 + mpmath.exp(' + 'mpmath.mpf((0, 7505999378950827, -50, 53))*(l_M_tilde - 1))))', + ), + ( + LambdaPrinter, + '(0.0186573603637741*(-1 + math.exp(6.66666666666667*(l_M_tilde - 1))))', + ), + ] + ) + def test_print_code(self, code_printer, expected): + fl_M_pas = FiberForceLengthPassiveDeGroote2016.with_defaults(self.l_M_tilde) + assert code_printer().doprint(fl_M_pas) == expected + + def test_derivative_print_code(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_pas_dl_M_tilde = fl_M_pas.diff(self.l_M_tilde) + expected = '0.12438240242516*math.exp(6.66666666666667*(l_M_tilde - 1))' + assert PythonCodePrinter().doprint(fl_M_pas_dl_M_tilde) == expected + + def test_lambdify(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_pas_callable = lambdify(self.l_M_tilde, fl_M_pas) + assert fl_M_pas_callable(1.0) == pytest.approx(0.0) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_pas_callable = lambdify(self.l_M_tilde, fl_M_pas, 'numpy') + l_M_tilde = numpy.array([0.5, 0.8, 0.9, 1.0, 1.1, 1.2, 1.5]) + expected = numpy.array([ + -0.0179917778, + -0.0137393336, + -0.0090783522, + 0.0, + 0.0176822155, + 0.0521224686, + 0.5043387669, + ]) + numpy.testing.assert_allclose(fl_M_pas_callable(l_M_tilde), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fl_M_pas = FiberForceLengthPassiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_pas_callable = jax.jit(lambdify(self.l_M_tilde, fl_M_pas, 'jax')) + l_M_tilde = jax.numpy.array([0.5, 0.8, 0.9, 1.0, 1.1, 1.2, 1.5]) + expected = jax.numpy.array([ + -0.0179917778, + -0.0137393336, + -0.0090783522, + 0.0, + 0.0176822155, + 0.0521224686, + 0.5043387669, + ]) + numpy.testing.assert_allclose(fl_M_pas_callable(l_M_tilde), expected) + + +class TestFiberForceLengthPassiveInverseDeGroote2016: + + @pytest.fixture(autouse=True) + def _fiber_force_length_passive_arguments_fixture(self): + self.fl_M_pas = Symbol('fl_M_pas') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.constants = (self.c0, self.c1) + + @staticmethod + def test_class(): + assert issubclass(FiberForceLengthPassiveInverseDeGroote2016, Function) + assert issubclass(FiberForceLengthPassiveInverseDeGroote2016, CharacteristicCurveFunction) + assert FiberForceLengthPassiveInverseDeGroote2016.__name__ == 'FiberForceLengthPassiveInverseDeGroote2016' + + def test_instance(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + assert isinstance(fl_M_pas_inv, FiberForceLengthPassiveInverseDeGroote2016) + assert str(fl_M_pas_inv) == 'FiberForceLengthPassiveInverseDeGroote2016(fl_M_pas, c_0, c_1)' + + def test_doit(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants).doit() + assert fl_M_pas_inv == self.c0*log(self.fl_M_pas*(exp(self.c1) - 1) + 1)/self.c1 + 1 + + def test_doit_evaluate_false(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants).doit(evaluate=False) + assert fl_M_pas_inv == self.c0*log(UnevaluatedExpr(self.fl_M_pas*(exp(self.c1) - 1)) + 1)/self.c1 + 1 + + def test_with_defaults(self): + constants = ( + Float('0.6'), + Float('4.0'), + ) + fl_M_pas_inv_manual = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *constants) + fl_M_pas_inv_constants = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(self.fl_M_pas) + assert fl_M_pas_inv_manual == fl_M_pas_inv_constants + + def test_differentiate_wrt_fl_T(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + expected = self.c0*(exp(self.c1) - 1)/(self.c1*(self.fl_M_pas*(exp(self.c1) - 1) + 1)) + assert fl_M_pas_inv.diff(self.fl_M_pas) == expected + + def test_differentiate_wrt_c0(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + expected = log(self.fl_M_pas*(exp(self.c1) - 1) + 1)/self.c1 + assert fl_M_pas_inv.diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + expected = ( + self.c0*self.fl_M_pas*exp(self.c1)/(self.c1*(self.fl_M_pas*(exp(self.c1) - 1) + 1)) + - self.c0*log(self.fl_M_pas*(exp(self.c1) - 1) + 1)/self.c1**2 + ) + assert fl_M_pas_inv.diff(self.c1) == expected + + def test_inverse(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + assert fl_M_pas_inv.inverse() is FiberForceLengthPassiveDeGroote2016 + + def test_function_print_latex(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + expected = r'\left( \operatorname{fl}^M_{pas} \right)^{-1} \left( fl_{M pas} \right)' + assert LatexPrinter().doprint(fl_M_pas_inv) == expected + + def test_expression_print_latex(self): + fl_T = FiberForceLengthPassiveInverseDeGroote2016(self.fl_M_pas, *self.constants) + expected = r'\frac{c_{0} \log{\left(fl_{M pas} \left(e^{c_{1}} - 1\right) + 1 \right)}}{c_{1}} + 1' + assert LatexPrinter().doprint(fl_T.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + '(1 + 0.14999999999999999*log(1 + 53.598150033144236*fl_M_pas))', + ), + ( + C99CodePrinter, + '(1 + 0.14999999999999999*log(1 + 53.598150033144236*fl_M_pas))', + ), + ( + C11CodePrinter, + '(1 + 0.14999999999999999*log(1 + 53.598150033144236*fl_M_pas))', + ), + ( + CXX98CodePrinter, + '(1 + 0.14999999999999999*log(1 + 53.598150033144236*fl_M_pas))', + ), + ( + CXX11CodePrinter, + '(1 + 0.14999999999999999*std::log(1 + 53.598150033144236*fl_M_pas))', + ), + ( + CXX17CodePrinter, + '(1 + 0.14999999999999999*std::log(1 + 53.598150033144236*fl_M_pas))', + ), + ( + FCodePrinter, + ' (1 + 0.15d0*log(1.0d0 + 53.5981500331442d0*fl_M_pas))', + ), + ( + OctaveCodePrinter, + '(1 + 0.15*log(1 + 53.5981500331442*fl_M_pas))', + ), + ( + PythonCodePrinter, + '(1 + 0.15*math.log(1 + 53.5981500331442*fl_M_pas))', + ), + ( + NumPyPrinter, + '(1 + 0.15*numpy.log(1 + 53.5981500331442*fl_M_pas))', + ), + ( + SciPyPrinter, + '(1 + 0.15*numpy.log(1 + 53.5981500331442*fl_M_pas))', + ), + ( + CuPyPrinter, + '(1 + 0.15*cupy.log(1 + 53.5981500331442*fl_M_pas))', + ), + ( + JaxPrinter, + '(1 + 0.15*jax.numpy.log(1 + 53.5981500331442*fl_M_pas))', + ), + ( + MpmathPrinter, + '(1 + mpmath.mpf((0, 5404319552844595, -55, 53))*mpmath.log(1 ' + '+ mpmath.mpf((0, 942908627019595, -44, 50))*fl_M_pas))', + ), + ( + LambdaPrinter, + '(1 + 0.15*math.log(1 + 53.5981500331442*fl_M_pas))', + ), + ] + ) + def test_print_code(self, code_printer, expected): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(self.fl_M_pas) + assert code_printer().doprint(fl_M_pas_inv) == expected + + def test_derivative_print_code(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(self.fl_M_pas) + dfl_M_pas_inv_dfl_T = fl_M_pas_inv.diff(self.fl_M_pas) + expected = '32.1588900198865/(214.392600132577*fl_M_pas + 4.0)' + assert PythonCodePrinter().doprint(dfl_M_pas_inv_dfl_T) == expected + + def test_lambdify(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(self.fl_M_pas) + fl_M_pas_inv_callable = lambdify(self.fl_M_pas, fl_M_pas_inv) + assert fl_M_pas_inv_callable(0.0) == pytest.approx(1.0) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(self.fl_M_pas) + fl_M_pas_inv_callable = lambdify(self.fl_M_pas, fl_M_pas_inv, 'numpy') + fl_M_pas = numpy.array([-0.01, 0.0, 0.01, 0.02, 0.05, 0.1]) + expected = numpy.array([ + 0.8848253714, + 1.0, + 1.0643754386, + 1.1092744701, + 1.1954331425, + 1.2774998934, + ]) + numpy.testing.assert_allclose(fl_M_pas_inv_callable(fl_M_pas), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fl_M_pas_inv = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(self.fl_M_pas) + fl_M_pas_inv_callable = jax.jit(lambdify(self.fl_M_pas, fl_M_pas_inv, 'jax')) + fl_M_pas = jax.numpy.array([-0.01, 0.0, 0.01, 0.02, 0.05, 0.1]) + expected = jax.numpy.array([ + 0.8848253714, + 1.0, + 1.0643754386, + 1.1092744701, + 1.1954331425, + 1.2774998934, + ]) + numpy.testing.assert_allclose(fl_M_pas_inv_callable(fl_M_pas), expected) + + +class TestFiberForceLengthActiveDeGroote2016: + + @pytest.fixture(autouse=True) + def _fiber_force_length_active_arguments_fixture(self): + self.l_M_tilde = Symbol('l_M_tilde') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.c2 = Symbol('c_2') + self.c3 = Symbol('c_3') + self.c4 = Symbol('c_4') + self.c5 = Symbol('c_5') + self.c6 = Symbol('c_6') + self.c7 = Symbol('c_7') + self.c8 = Symbol('c_8') + self.c9 = Symbol('c_9') + self.c10 = Symbol('c_10') + self.c11 = Symbol('c_11') + self.constants = ( + self.c0, self.c1, self.c2, self.c3, self.c4, self.c5, + self.c6, self.c7, self.c8, self.c9, self.c10, self.c11, + ) + + @staticmethod + def test_class(): + assert issubclass(FiberForceLengthActiveDeGroote2016, Function) + assert issubclass(FiberForceLengthActiveDeGroote2016, CharacteristicCurveFunction) + assert FiberForceLengthActiveDeGroote2016.__name__ == 'FiberForceLengthActiveDeGroote2016' + + def test_instance(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + assert isinstance(fl_M_act, FiberForceLengthActiveDeGroote2016) + assert str(fl_M_act) == ( + 'FiberForceLengthActiveDeGroote2016(l_M_tilde, c_0, c_1, c_2, c_3, ' + 'c_4, c_5, c_6, c_7, c_8, c_9, c_10, c_11)' + ) + + def test_doit(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants).doit() + assert fl_M_act == ( + self.c0*exp(-(((self.l_M_tilde - self.c1)/(self.c2 + self.c3*self.l_M_tilde))**2)/2) + + self.c4*exp(-(((self.l_M_tilde - self.c5)/(self.c6 + self.c7*self.l_M_tilde))**2)/2) + + self.c8*exp(-(((self.l_M_tilde - self.c9)/(self.c10 + self.c11*self.l_M_tilde))**2)/2) + ) + + def test_doit_evaluate_false(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants).doit(evaluate=False) + assert fl_M_act == ( + self.c0*exp(-((UnevaluatedExpr(self.l_M_tilde - self.c1)/(self.c2 + self.c3*self.l_M_tilde))**2)/2) + + self.c4*exp(-((UnevaluatedExpr(self.l_M_tilde - self.c5)/(self.c6 + self.c7*self.l_M_tilde))**2)/2) + + self.c8*exp(-((UnevaluatedExpr(self.l_M_tilde - self.c9)/(self.c10 + self.c11*self.l_M_tilde))**2)/2) + ) + + def test_with_defaults(self): + constants = ( + Float('0.814'), + Float('1.06'), + Float('0.162'), + Float('0.0633'), + Float('0.433'), + Float('0.717'), + Float('-0.0299'), + Float('0.2'), + Float('0.1'), + Float('1.0'), + Float('0.354'), + Float('0.0'), + ) + fl_M_act_manual = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *constants) + fl_M_act_constants = FiberForceLengthActiveDeGroote2016.with_defaults(self.l_M_tilde) + assert fl_M_act_manual == fl_M_act_constants + + def test_differentiate_wrt_l_M_tilde(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c0*( + self.c3*(self.l_M_tilde - self.c1)**2/(self.c2 + self.c3*self.l_M_tilde)**3 + + (self.c1 - self.l_M_tilde)/((self.c2 + self.c3*self.l_M_tilde)**2) + )*exp(-(self.l_M_tilde - self.c1)**2/(2*(self.c2 + self.c3*self.l_M_tilde)**2)) + + self.c4*( + self.c7*(self.l_M_tilde - self.c5)**2/(self.c6 + self.c7*self.l_M_tilde)**3 + + (self.c5 - self.l_M_tilde)/((self.c6 + self.c7*self.l_M_tilde)**2) + )*exp(-(self.l_M_tilde - self.c5)**2/(2*(self.c6 + self.c7*self.l_M_tilde)**2)) + + self.c8*( + self.c11*(self.l_M_tilde - self.c9)**2/(self.c10 + self.c11*self.l_M_tilde)**3 + + (self.c9 - self.l_M_tilde)/((self.c10 + self.c11*self.l_M_tilde)**2) + )*exp(-(self.l_M_tilde - self.c9)**2/(2*(self.c10 + self.c11*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.l_M_tilde) == expected + + def test_differentiate_wrt_c0(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = exp(-(self.l_M_tilde - self.c1)**2/(2*(self.c2 + self.c3*self.l_M_tilde)**2)) + assert fl_M_act.doit().diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c0*(self.l_M_tilde - self.c1)/(self.c2 + self.c3*self.l_M_tilde)**2 + *exp(-(self.l_M_tilde - self.c1)**2/(2*(self.c2 + self.c3*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c1) == expected + + def test_differentiate_wrt_c2(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c0*(self.l_M_tilde - self.c1)**2/(self.c2 + self.c3*self.l_M_tilde)**3 + *exp(-(self.l_M_tilde - self.c1)**2/(2*(self.c2 + self.c3*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c2) == expected + + def test_differentiate_wrt_c3(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c0*self.l_M_tilde*(self.l_M_tilde - self.c1)**2/(self.c2 + self.c3*self.l_M_tilde)**3 + *exp(-(self.l_M_tilde - self.c1)**2/(2*(self.c2 + self.c3*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c3) == expected + + def test_differentiate_wrt_c4(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = exp(-(self.l_M_tilde - self.c5)**2/(2*(self.c6 + self.c7*self.l_M_tilde)**2)) + assert fl_M_act.diff(self.c4) == expected + + def test_differentiate_wrt_c5(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c4*(self.l_M_tilde - self.c5)/(self.c6 + self.c7*self.l_M_tilde)**2 + *exp(-(self.l_M_tilde - self.c5)**2/(2*(self.c6 + self.c7*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c5) == expected + + def test_differentiate_wrt_c6(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c4*(self.l_M_tilde - self.c5)**2/(self.c6 + self.c7*self.l_M_tilde)**3 + *exp(-(self.l_M_tilde - self.c5)**2/(2*(self.c6 + self.c7*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c6) == expected + + def test_differentiate_wrt_c7(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c4*self.l_M_tilde*(self.l_M_tilde - self.c5)**2/(self.c6 + self.c7*self.l_M_tilde)**3 + *exp(-(self.l_M_tilde - self.c5)**2/(2*(self.c6 + self.c7*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c7) == expected + + def test_differentiate_wrt_c8(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = exp(-(self.l_M_tilde - self.c9)**2/(2*(self.c10 + self.c11*self.l_M_tilde)**2)) + assert fl_M_act.diff(self.c8) == expected + + def test_differentiate_wrt_c9(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c8*(self.l_M_tilde - self.c9)/(self.c10 + self.c11*self.l_M_tilde)**2 + *exp(-(self.l_M_tilde - self.c9)**2/(2*(self.c10 + self.c11*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c9) == expected + + def test_differentiate_wrt_c10(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c8*(self.l_M_tilde - self.c9)**2/(self.c10 + self.c11*self.l_M_tilde)**3 + *exp(-(self.l_M_tilde - self.c9)**2/(2*(self.c10 + self.c11*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c10) == expected + + def test_differentiate_wrt_c11(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + self.c8*self.l_M_tilde*(self.l_M_tilde - self.c9)**2/(self.c10 + self.c11*self.l_M_tilde)**3 + *exp(-(self.l_M_tilde - self.c9)**2/(2*(self.c10 + self.c11*self.l_M_tilde)**2)) + ) + assert fl_M_act.diff(self.c11) == expected + + def test_function_print_latex(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = r'\operatorname{fl}^M_{act} \left( l_{M tilde} \right)' + assert LatexPrinter().doprint(fl_M_act) == expected + + def test_expression_print_latex(self): + fl_M_act = FiberForceLengthActiveDeGroote2016(self.l_M_tilde, *self.constants) + expected = ( + r'c_{0} e^{- \frac{\left(- c_{1} + l_{M tilde}\right)^{2}}{2 \left(c_{2} + c_{3} l_{M tilde}\right)^{2}}} ' + r'+ c_{4} e^{- \frac{\left(- c_{5} + l_{M tilde}\right)^{2}}{2 \left(c_{6} + c_{7} l_{M tilde}\right)^{2}}} ' + r'+ c_{8} e^{- \frac{\left(- c_{9} + l_{M tilde}\right)^{2}}{2 \left(c_{10} + c_{11} l_{M tilde}\right)^{2}}}' + ) + assert LatexPrinter().doprint(fl_M_act.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + ( + '(0.81399999999999995*exp(-1.0/2.0*pow(l_M_tilde - 1.0600000000000001, 2)/pow(0.063299999999999995*l_M_tilde + 0.16200000000000001, 2)) + 0.433*exp(-1.0/2.0*pow(l_M_tilde - 0.71699999999999997, 2)/pow(0.20000000000000001*l_M_tilde - 0.029899999999999999, 2)) + 0.10000000000000001*exp(-3.9899134986753491*pow(l_M_tilde - 1.0, 2)))' + ), + ), + ( + C99CodePrinter, + ( + '(0.81399999999999995*exp(-1.0/2.0*pow(l_M_tilde - 1.0600000000000001, 2)/pow(0.063299999999999995*l_M_tilde + 0.16200000000000001, 2)) + 0.433*exp(-1.0/2.0*pow(l_M_tilde - 0.71699999999999997, 2)/pow(0.20000000000000001*l_M_tilde - 0.029899999999999999, 2)) + 0.10000000000000001*exp(-3.9899134986753491*pow(l_M_tilde - 1.0, 2)))' + ), + ), + ( + C11CodePrinter, + ( + '(0.81399999999999995*exp(-1.0/2.0*pow(l_M_tilde - 1.0600000000000001, 2)/pow(0.063299999999999995*l_M_tilde + 0.16200000000000001, 2)) + 0.433*exp(-1.0/2.0*pow(l_M_tilde - 0.71699999999999997, 2)/pow(0.20000000000000001*l_M_tilde - 0.029899999999999999, 2)) + 0.10000000000000001*exp(-3.9899134986753491*pow(l_M_tilde - 1.0, 2)))' + ), + ), + ( + CXX98CodePrinter, + ( + '(0.81399999999999995*exp(-1.0/2.0*std::pow(l_M_tilde - 1.0600000000000001, 2)/std::pow(0.063299999999999995*l_M_tilde + 0.16200000000000001, 2)) + 0.433*exp(-1.0/2.0*std::pow(l_M_tilde - 0.71699999999999997, 2)/std::pow(0.20000000000000001*l_M_tilde - 0.029899999999999999, 2)) + 0.10000000000000001*exp(-3.9899134986753491*std::pow(l_M_tilde - 1.0, 2)))' + ), + ), + ( + CXX11CodePrinter, + ( + '(0.81399999999999995*std::exp(-1.0/2.0*std::pow(l_M_tilde - 1.0600000000000001, 2)/std::pow(0.063299999999999995*l_M_tilde + 0.16200000000000001, 2)) + 0.433*std::exp(-1.0/2.0*std::pow(l_M_tilde - 0.71699999999999997, 2)/std::pow(0.20000000000000001*l_M_tilde - 0.029899999999999999, 2)) + 0.10000000000000001*std::exp(-3.9899134986753491*std::pow(l_M_tilde - 1.0, 2)))' + ), + ), + ( + CXX17CodePrinter, + ( + '(0.81399999999999995*std::exp(-1.0/2.0*std::pow(l_M_tilde - 1.0600000000000001, 2)/std::pow(0.063299999999999995*l_M_tilde + 0.16200000000000001, 2)) + 0.433*std::exp(-1.0/2.0*std::pow(l_M_tilde - 0.71699999999999997, 2)/std::pow(0.20000000000000001*l_M_tilde - 0.029899999999999999, 2)) + 0.10000000000000001*std::exp(-3.9899134986753491*std::pow(l_M_tilde - 1.0, 2)))' + ), + ), + ( + FCodePrinter, + ( + ' (0.814d0*exp(-0.5d0*(l_M_tilde - 1.06d0)**2/(\n' + ' @ 0.063299999999999995d0*l_M_tilde + 0.16200000000000001d0)**2) +\n' + ' @ 0.433d0*exp(-0.5d0*(l_M_tilde - 0.717d0)**2/(\n' + ' @ 0.20000000000000001d0*l_M_tilde - 0.029899999999999999d0)**2) +\n' + ' @ 0.1d0*exp(-3.9899134986753491d0*(l_M_tilde - 1.0d0)**2))' + ), + ), + ( + OctaveCodePrinter, + ( + '(0.814*exp(-(l_M_tilde - 1.06).^2./(2*(0.0633*l_M_tilde + 0.162).^2)) + 0.433*exp(-(l_M_tilde - 0.717).^2./(2*(0.2*l_M_tilde - 0.0299).^2)) + 0.1*exp(-3.98991349867535*(l_M_tilde - 1.0).^2))' + ), + ), + ( + PythonCodePrinter, + ( + '(0.814*math.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2) + 0.433*math.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + 0.1*math.exp(-3.98991349867535*(l_M_tilde - 1.0)**2))' + ), + ), + ( + NumPyPrinter, + ( + '(0.814*numpy.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2) + 0.433*numpy.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + 0.1*numpy.exp(-3.98991349867535*(l_M_tilde - 1.0)**2))' + ), + ), + ( + SciPyPrinter, + ( + '(0.814*numpy.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2) + 0.433*numpy.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + 0.1*numpy.exp(-3.98991349867535*(l_M_tilde - 1.0)**2))' + ), + ), + ( + CuPyPrinter, + ( + '(0.814*cupy.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2) + 0.433*cupy.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + 0.1*cupy.exp(-3.98991349867535*(l_M_tilde - 1.0)**2))' + ), + ), + ( + JaxPrinter, + ( + '(0.814*jax.numpy.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2) + 0.433*jax.numpy.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + 0.1*jax.numpy.exp(-3.98991349867535*(l_M_tilde - 1.0)**2))' + ), + ), + ( + MpmathPrinter, + ( + '(mpmath.mpf((0, 7331860193359167, -53, 53))*mpmath.exp(-mpmath.mpf(1)/mpmath.mpf(2)*(l_M_tilde + mpmath.mpf((1, 2386907802506363, -51, 52)))**2/(mpmath.mpf((0, 2280622851300419, -55, 52))*l_M_tilde + mpmath.mpf((0, 5836665117072163, -55, 53)))**2) + mpmath.mpf((0, 7800234554605699, -54, 53))*mpmath.exp(-mpmath.mpf(1)/mpmath.mpf(2)*(l_M_tilde + mpmath.mpf((1, 6458161865649291, -53, 53)))**2/(mpmath.mpf((0, 3602879701896397, -54, 52))*l_M_tilde + mpmath.mpf((1, 8618088246936181, -58, 53)))**2) + mpmath.mpf((0, 3602879701896397, -55, 52))*mpmath.exp(-mpmath.mpf((0, 8984486472937407, -51, 53))*(l_M_tilde + mpmath.mpf((1, 1, 0, 1)))**2))' + ), + ), + ( + LambdaPrinter, + ( + '(0.814*math.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2) + 0.433*math.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + 0.1*math.exp(-3.98991349867535*(l_M_tilde - 1.0)**2))' + ), + ), + ] + ) + def test_print_code(self, code_printer, expected): + fl_M_act = FiberForceLengthActiveDeGroote2016.with_defaults(self.l_M_tilde) + assert code_printer().doprint(fl_M_act) == expected + + def test_derivative_print_code(self): + fl_M_act = FiberForceLengthActiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_act_dl_M_tilde = fl_M_act.diff(self.l_M_tilde) + expected = ( + '(0.79798269973507 - 0.79798269973507*l_M_tilde)*math.exp(-3.98991349867535*(l_M_tilde - 1.0)**2) + (0.433*(0.717 - l_M_tilde)/(0.2*l_M_tilde - 0.0299)**2 + 0.0866*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**3)*math.exp(-1/2*(l_M_tilde - 0.717)**2/(0.2*l_M_tilde - 0.0299)**2) + (0.814*(1.06 - l_M_tilde)/(0.0633*l_M_tilde + 0.162)**2 + 0.0515262*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**3)*math.exp(-1/2*(l_M_tilde - 1.06)**2/(0.0633*l_M_tilde + 0.162)**2)' + ) + assert PythonCodePrinter().doprint(fl_M_act_dl_M_tilde) == expected + + def test_lambdify(self): + fl_M_act = FiberForceLengthActiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_act_callable = lambdify(self.l_M_tilde, fl_M_act) + assert fl_M_act_callable(1.0) == pytest.approx(0.9941398866) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fl_M_act = FiberForceLengthActiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_act_callable = lambdify(self.l_M_tilde, fl_M_act, 'numpy') + l_M_tilde = numpy.array([0.0, 0.5, 1.0, 1.5, 2.0]) + expected = numpy.array([ + 0.0018501319, + 0.0529122812, + 0.9941398866, + 0.2312431531, + 0.0069595432, + ]) + numpy.testing.assert_allclose(fl_M_act_callable(l_M_tilde), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fl_M_act = FiberForceLengthActiveDeGroote2016.with_defaults(self.l_M_tilde) + fl_M_act_callable = jax.jit(lambdify(self.l_M_tilde, fl_M_act, 'jax')) + l_M_tilde = jax.numpy.array([0.0, 0.5, 1.0, 1.5, 2.0]) + expected = jax.numpy.array([ + 0.0018501319, + 0.0529122812, + 0.9941398866, + 0.2312431531, + 0.0069595432, + ]) + numpy.testing.assert_allclose(fl_M_act_callable(l_M_tilde), expected) + + +class TestFiberForceVelocityDeGroote2016: + + @pytest.fixture(autouse=True) + def _muscle_fiber_force_velocity_arguments_fixture(self): + self.v_M_tilde = Symbol('v_M_tilde') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.c2 = Symbol('c_2') + self.c3 = Symbol('c_3') + self.constants = (self.c0, self.c1, self.c2, self.c3) + + @staticmethod + def test_class(): + assert issubclass(FiberForceVelocityDeGroote2016, Function) + assert issubclass(FiberForceVelocityDeGroote2016, CharacteristicCurveFunction) + assert FiberForceVelocityDeGroote2016.__name__ == 'FiberForceVelocityDeGroote2016' + + def test_instance(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + assert isinstance(fv_M, FiberForceVelocityDeGroote2016) + assert str(fv_M) == 'FiberForceVelocityDeGroote2016(v_M_tilde, c_0, c_1, c_2, c_3)' + + def test_doit(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants).doit() + expected = ( + self.c0 * log((self.c1 * self.v_M_tilde + self.c2) + + sqrt((self.c1 * self.v_M_tilde + self.c2)**2 + 1)) + self.c3 + ) + assert fv_M == expected + + def test_doit_evaluate_false(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants).doit(evaluate=False) + expected = ( + self.c0 * log((self.c1 * self.v_M_tilde + self.c2) + + sqrt(UnevaluatedExpr(self.c1 * self.v_M_tilde + self.c2)**2 + 1)) + self.c3 + ) + assert fv_M == expected + + def test_with_defaults(self): + constants = ( + Float('-0.318'), + Float('-8.149'), + Float('-0.374'), + Float('0.886'), + ) + fv_M_manual = FiberForceVelocityDeGroote2016(self.v_M_tilde, *constants) + fv_M_constants = FiberForceVelocityDeGroote2016.with_defaults(self.v_M_tilde) + assert fv_M_manual == fv_M_constants + + def test_differentiate_wrt_v_M_tilde(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = ( + self.c0*self.c1 + /sqrt(UnevaluatedExpr(self.c1*self.v_M_tilde + self.c2)**2 + 1) + ) + assert fv_M.diff(self.v_M_tilde) == expected + + def test_differentiate_wrt_c0(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = log( + self.c1*self.v_M_tilde + self.c2 + + sqrt(UnevaluatedExpr(self.c1*self.v_M_tilde + self.c2)**2 + 1) + ) + assert fv_M.diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = ( + self.c0*self.v_M_tilde + /sqrt(UnevaluatedExpr(self.c1*self.v_M_tilde + self.c2)**2 + 1) + ) + assert fv_M.diff(self.c1) == expected + + def test_differentiate_wrt_c2(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = ( + self.c0 + /sqrt(UnevaluatedExpr(self.c1*self.v_M_tilde + self.c2)**2 + 1) + ) + assert fv_M.diff(self.c2) == expected + + def test_differentiate_wrt_c3(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = Integer(1) + assert fv_M.diff(self.c3) == expected + + def test_inverse(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + assert fv_M.inverse() is FiberForceVelocityInverseDeGroote2016 + + def test_function_print_latex(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = r'\operatorname{fv}^M \left( v_{M tilde} \right)' + assert LatexPrinter().doprint(fv_M) == expected + + def test_expression_print_latex(self): + fv_M = FiberForceVelocityDeGroote2016(self.v_M_tilde, *self.constants) + expected = ( + r'c_{0} \log{\left(c_{1} v_{M tilde} + c_{2} + \sqrt{\left(c_{1} ' + r'v_{M tilde} + c_{2}\right)^{2} + 1} \right)} + c_{3}' + ) + assert LatexPrinter().doprint(fv_M.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + '(0.88600000000000001 - 0.318*log(-8.1489999999999991*v_M_tilde ' + '- 0.374 + sqrt(1 + pow(-8.1489999999999991*v_M_tilde - 0.374, 2))))', + ), + ( + C99CodePrinter, + '(0.88600000000000001 - 0.318*log(-8.1489999999999991*v_M_tilde ' + '- 0.374 + sqrt(1 + pow(-8.1489999999999991*v_M_tilde - 0.374, 2))))', + ), + ( + C11CodePrinter, + '(0.88600000000000001 - 0.318*log(-8.1489999999999991*v_M_tilde ' + '- 0.374 + sqrt(1 + pow(-8.1489999999999991*v_M_tilde - 0.374, 2))))', + ), + ( + CXX98CodePrinter, + '(0.88600000000000001 - 0.318*log(-8.1489999999999991*v_M_tilde ' + '- 0.374 + std::sqrt(1 + std::pow(-8.1489999999999991*v_M_tilde - 0.374, 2))))', + ), + ( + CXX11CodePrinter, + '(0.88600000000000001 - 0.318*std::log(-8.1489999999999991*v_M_tilde ' + '- 0.374 + std::sqrt(1 + std::pow(-8.1489999999999991*v_M_tilde - 0.374, 2))))', + ), + ( + CXX17CodePrinter, + '(0.88600000000000001 - 0.318*std::log(-8.1489999999999991*v_M_tilde ' + '- 0.374 + std::sqrt(1 + std::pow(-8.1489999999999991*v_M_tilde - 0.374, 2))))', + ), + ( + FCodePrinter, + ' (0.886d0 - 0.318d0*log(-8.1489999999999991d0*v_M_tilde - 0.374d0 +\n' + ' @ sqrt(1.0d0 + (-8.149d0*v_M_tilde - 0.374d0)**2)))', + ), + ( + OctaveCodePrinter, + '(0.886 - 0.318*log(-8.149*v_M_tilde - 0.374 ' + '+ sqrt(1 + (-8.149*v_M_tilde - 0.374).^2)))', + ), + ( + PythonCodePrinter, + '(0.886 - 0.318*math.log(-8.149*v_M_tilde - 0.374 ' + '+ math.sqrt(1 + (-8.149*v_M_tilde - 0.374)**2)))', + ), + ( + NumPyPrinter, + '(0.886 - 0.318*numpy.log(-8.149*v_M_tilde - 0.374 ' + '+ numpy.sqrt(1 + (-8.149*v_M_tilde - 0.374)**2)))', + ), + ( + SciPyPrinter, + '(0.886 - 0.318*numpy.log(-8.149*v_M_tilde - 0.374 ' + '+ numpy.sqrt(1 + (-8.149*v_M_tilde - 0.374)**2)))', + ), + ( + CuPyPrinter, + '(0.886 - 0.318*cupy.log(-8.149*v_M_tilde - 0.374 ' + '+ cupy.sqrt(1 + (-8.149*v_M_tilde - 0.374)**2)))', + ), + ( + JaxPrinter, + '(0.886 - 0.318*jax.numpy.log(-8.149*v_M_tilde - 0.374 ' + '+ jax.numpy.sqrt(1 + (-8.149*v_M_tilde - 0.374)**2)))', + ), + ( + MpmathPrinter, + '(mpmath.mpf((0, 7980378539700519, -53, 53)) ' + '- mpmath.mpf((0, 5728578726015271, -54, 53))' + '*mpmath.log(-mpmath.mpf((0, 4587479170430271, -49, 53))*v_M_tilde ' + '+ mpmath.mpf((1, 3368692521273131, -53, 52)) ' + '+ mpmath.sqrt(1 + (-mpmath.mpf((0, 4587479170430271, -49, 53))*v_M_tilde ' + '+ mpmath.mpf((1, 3368692521273131, -53, 52)))**2)))', + ), + ( + LambdaPrinter, + '(0.886 - 0.318*math.log(-8.149*v_M_tilde - 0.374 ' + '+ sqrt(1 + (-8.149*v_M_tilde - 0.374)**2)))', + ), + ] + ) + def test_print_code(self, code_printer, expected): + fv_M = FiberForceVelocityDeGroote2016.with_defaults(self.v_M_tilde) + assert code_printer().doprint(fv_M) == expected + + def test_derivative_print_code(self): + fv_M = FiberForceVelocityDeGroote2016.with_defaults(self.v_M_tilde) + dfv_M_dv_M_tilde = fv_M.diff(self.v_M_tilde) + expected = '2.591382*(1 + (-8.149*v_M_tilde - 0.374)**2)**(-1/2)' + assert PythonCodePrinter().doprint(dfv_M_dv_M_tilde) == expected + + def test_lambdify(self): + fv_M = FiberForceVelocityDeGroote2016.with_defaults(self.v_M_tilde) + fv_M_callable = lambdify(self.v_M_tilde, fv_M) + assert fv_M_callable(0.0) == pytest.approx(1.002320622548512) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fv_M = FiberForceVelocityDeGroote2016.with_defaults(self.v_M_tilde) + fv_M_callable = lambdify(self.v_M_tilde, fv_M, 'numpy') + v_M_tilde = numpy.array([-1.0, -0.5, 0.0, 0.5]) + expected = numpy.array([ + 0.0120816781, + 0.2438336294, + 1.0023206225, + 1.5850003903, + ]) + numpy.testing.assert_allclose(fv_M_callable(v_M_tilde), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fv_M = FiberForceVelocityDeGroote2016.with_defaults(self.v_M_tilde) + fv_M_callable = jax.jit(lambdify(self.v_M_tilde, fv_M, 'jax')) + v_M_tilde = jax.numpy.array([-1.0, -0.5, 0.0, 0.5]) + expected = jax.numpy.array([ + 0.0120816781, + 0.2438336294, + 1.0023206225, + 1.5850003903, + ]) + numpy.testing.assert_allclose(fv_M_callable(v_M_tilde), expected) + + +class TestFiberForceVelocityInverseDeGroote2016: + + @pytest.fixture(autouse=True) + def _tendon_force_length_inverse_arguments_fixture(self): + self.fv_M = Symbol('fv_M') + self.c0 = Symbol('c_0') + self.c1 = Symbol('c_1') + self.c2 = Symbol('c_2') + self.c3 = Symbol('c_3') + self.constants = (self.c0, self.c1, self.c2, self.c3) + + @staticmethod + def test_class(): + assert issubclass(FiberForceVelocityInverseDeGroote2016, Function) + assert issubclass(FiberForceVelocityInverseDeGroote2016, CharacteristicCurveFunction) + assert FiberForceVelocityInverseDeGroote2016.__name__ == 'FiberForceVelocityInverseDeGroote2016' + + def test_instance(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + assert isinstance(fv_M_inv, FiberForceVelocityInverseDeGroote2016) + assert str(fv_M_inv) == 'FiberForceVelocityInverseDeGroote2016(fv_M, c_0, c_1, c_2, c_3)' + + def test_doit(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants).doit() + assert fv_M_inv == (sinh((self.fv_M - self.c3)/self.c0) - self.c2)/self.c1 + + def test_doit_evaluate_false(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants).doit(evaluate=False) + assert fv_M_inv == (sinh(UnevaluatedExpr(self.fv_M - self.c3)/self.c0) - self.c2)/self.c1 + + def test_with_defaults(self): + constants = ( + Float('-0.318'), + Float('-8.149'), + Float('-0.374'), + Float('0.886'), + ) + fv_M_inv_manual = FiberForceVelocityInverseDeGroote2016(self.fv_M, *constants) + fv_M_inv_constants = FiberForceVelocityInverseDeGroote2016.with_defaults(self.fv_M) + assert fv_M_inv_manual == fv_M_inv_constants + + def test_differentiate_wrt_fv_M(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = cosh((self.fv_M - self.c3)/self.c0)/(self.c0*self.c1) + assert fv_M_inv.diff(self.fv_M) == expected + + def test_differentiate_wrt_c0(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = (self.c3 - self.fv_M)*cosh((self.fv_M - self.c3)/self.c0)/(self.c0**2*self.c1) + assert fv_M_inv.diff(self.c0) == expected + + def test_differentiate_wrt_c1(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = (self.c2 - sinh((self.fv_M - self.c3)/self.c0))/self.c1**2 + assert fv_M_inv.diff(self.c1) == expected + + def test_differentiate_wrt_c2(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = -1/self.c1 + assert fv_M_inv.diff(self.c2) == expected + + def test_differentiate_wrt_c3(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = -cosh((self.fv_M - self.c3)/self.c0)/(self.c0*self.c1) + assert fv_M_inv.diff(self.c3) == expected + + def test_inverse(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + assert fv_M_inv.inverse() is FiberForceVelocityDeGroote2016 + + def test_function_print_latex(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = r'\left( \operatorname{fv}^M \right)^{-1} \left( fv_{M} \right)' + assert LatexPrinter().doprint(fv_M_inv) == expected + + def test_expression_print_latex(self): + fv_M = FiberForceVelocityInverseDeGroote2016(self.fv_M, *self.constants) + expected = r'\frac{- c_{2} + \sinh{\left(\frac{- c_{3} + fv_{M}}{c_{0}} \right)}}{c_{1}}' + assert LatexPrinter().doprint(fv_M.doit()) == expected + + @pytest.mark.parametrize( + 'code_printer, expected', + [ + ( + C89CodePrinter, + '(-0.12271444348999878*(0.374 - sinh(3.1446540880503142*(fv_M ' + '- 0.88600000000000001))))', + ), + ( + C99CodePrinter, + '(-0.12271444348999878*(0.374 - sinh(3.1446540880503142*(fv_M ' + '- 0.88600000000000001))))', + ), + ( + C11CodePrinter, + '(-0.12271444348999878*(0.374 - sinh(3.1446540880503142*(fv_M ' + '- 0.88600000000000001))))', + ), + ( + CXX98CodePrinter, + '(-0.12271444348999878*(0.374 - sinh(3.1446540880503142*(fv_M ' + '- 0.88600000000000001))))', + ), + ( + CXX11CodePrinter, + '(-0.12271444348999878*(0.374 - std::sinh(3.1446540880503142' + '*(fv_M - 0.88600000000000001))))', + ), + ( + CXX17CodePrinter, + '(-0.12271444348999878*(0.374 - std::sinh(3.1446540880503142' + '*(fv_M - 0.88600000000000001))))', + ), + ( + FCodePrinter, + ' (-0.122714443489999d0*(0.374d0 - sinh(3.1446540880503142d0*(fv_M -\n' + ' @ 0.886d0))))', + ), + ( + OctaveCodePrinter, + '(-0.122714443489999*(0.374 - sinh(3.14465408805031*(fv_M ' + '- 0.886))))', + ), + ( + PythonCodePrinter, + '(-0.122714443489999*(0.374 - math.sinh(3.14465408805031*(fv_M ' + '- 0.886))))', + ), + ( + NumPyPrinter, + '(-0.122714443489999*(0.374 - numpy.sinh(3.14465408805031' + '*(fv_M - 0.886))))', + ), + ( + SciPyPrinter, + '(-0.122714443489999*(0.374 - numpy.sinh(3.14465408805031' + '*(fv_M - 0.886))))', + ), + ( + CuPyPrinter, + '(-0.122714443489999*(0.374 - cupy.sinh(3.14465408805031*(fv_M ' + '- 0.886))))', + ), + ( + JaxPrinter, + '(-0.122714443489999*(0.374 - jax.numpy.sinh(3.14465408805031' + '*(fv_M - 0.886))))', + ), + ( + MpmathPrinter, + '(-mpmath.mpf((0, 8842507551592581, -56, 53))*(mpmath.mpf((0, ' + '3368692521273131, -53, 52)) - mpmath.sinh(mpmath.mpf((0, ' + '7081131489576251, -51, 53))*(fv_M + mpmath.mpf((1, ' + '7980378539700519, -53, 53))))))', + ), + ( + LambdaPrinter, + '(-0.122714443489999*(0.374 - math.sinh(3.14465408805031*(fv_M ' + '- 0.886))))', + ), + ] + ) + def test_print_code(self, code_printer, expected): + fv_M_inv = FiberForceVelocityInverseDeGroote2016.with_defaults(self.fv_M) + assert code_printer().doprint(fv_M_inv) == expected + + def test_derivative_print_code(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016.with_defaults(self.fv_M) + dfv_M_inv_dfv_M = fv_M_inv.diff(self.fv_M) + expected = ( + '0.385894476383644*math.cosh(3.14465408805031*fv_M ' + '- 2.78616352201258)' + ) + assert PythonCodePrinter().doprint(dfv_M_inv_dfv_M) == expected + + def test_lambdify(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016.with_defaults(self.fv_M) + fv_M_inv_callable = lambdify(self.fv_M, fv_M_inv) + assert fv_M_inv_callable(1.0) == pytest.approx(-0.0009548832444487479) + + @pytest.mark.skipif(numpy is None, reason='NumPy not installed') + def test_lambdify_numpy(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016.with_defaults(self.fv_M) + fv_M_inv_callable = lambdify(self.fv_M, fv_M_inv, 'numpy') + fv_M = numpy.array([0.8, 0.9, 1.0, 1.1, 1.2]) + expected = numpy.array([ + -0.0794881459, + -0.0404909338, + -0.0009548832, + 0.043061991, + 0.0959484397, + ]) + numpy.testing.assert_allclose(fv_M_inv_callable(fv_M), expected) + + @pytest.mark.skipif(jax is None, reason='JAX not installed') + def test_lambdify_jax(self): + fv_M_inv = FiberForceVelocityInverseDeGroote2016.with_defaults(self.fv_M) + fv_M_inv_callable = jax.jit(lambdify(self.fv_M, fv_M_inv, 'jax')) + fv_M = jax.numpy.array([0.8, 0.9, 1.0, 1.1, 1.2]) + expected = jax.numpy.array([ + -0.0794881459, + -0.0404909338, + -0.0009548832, + 0.043061991, + 0.0959484397, + ]) + numpy.testing.assert_allclose(fv_M_inv_callable(fv_M), expected) + + +class TestCharacteristicCurveCollection: + + @staticmethod + def test_valid_constructor(): + curves = CharacteristicCurveCollection( + tendon_force_length=TendonForceLengthDeGroote2016, + tendon_force_length_inverse=TendonForceLengthInverseDeGroote2016, + fiber_force_length_passive=FiberForceLengthPassiveDeGroote2016, + fiber_force_length_passive_inverse=FiberForceLengthPassiveInverseDeGroote2016, + fiber_force_length_active=FiberForceLengthActiveDeGroote2016, + fiber_force_velocity=FiberForceVelocityDeGroote2016, + fiber_force_velocity_inverse=FiberForceVelocityInverseDeGroote2016, + ) + assert curves.tendon_force_length is TendonForceLengthDeGroote2016 + assert curves.tendon_force_length_inverse is TendonForceLengthInverseDeGroote2016 + assert curves.fiber_force_length_passive is FiberForceLengthPassiveDeGroote2016 + assert curves.fiber_force_length_passive_inverse is FiberForceLengthPassiveInverseDeGroote2016 + assert curves.fiber_force_length_active is FiberForceLengthActiveDeGroote2016 + assert curves.fiber_force_velocity is FiberForceVelocityDeGroote2016 + assert curves.fiber_force_velocity_inverse is FiberForceVelocityInverseDeGroote2016 + + @staticmethod + @pytest.mark.skip(reason='kw_only dataclasses only valid in Python >3.10') + def test_invalid_constructor_keyword_only(): + with pytest.raises(TypeError): + _ = CharacteristicCurveCollection( + TendonForceLengthDeGroote2016, + TendonForceLengthInverseDeGroote2016, + FiberForceLengthPassiveDeGroote2016, + FiberForceLengthPassiveInverseDeGroote2016, + FiberForceLengthActiveDeGroote2016, + FiberForceVelocityDeGroote2016, + FiberForceVelocityInverseDeGroote2016, + ) + + @staticmethod + @pytest.mark.parametrize( + 'kwargs', + [ + {'tendon_force_length': TendonForceLengthDeGroote2016}, + { + 'tendon_force_length': TendonForceLengthDeGroote2016, + 'tendon_force_length_inverse': TendonForceLengthInverseDeGroote2016, + 'fiber_force_length_passive': FiberForceLengthPassiveDeGroote2016, + 'fiber_force_length_passive_inverse': FiberForceLengthPassiveInverseDeGroote2016, + 'fiber_force_length_active': FiberForceLengthActiveDeGroote2016, + 'fiber_force_velocity': FiberForceVelocityDeGroote2016, + 'fiber_force_velocity_inverse': FiberForceVelocityInverseDeGroote2016, + 'extra_kwarg': None, + }, + ] + ) + def test_invalid_constructor_wrong_number_args(kwargs): + with pytest.raises(TypeError): + _ = CharacteristicCurveCollection(**kwargs) + + @staticmethod + def test_instance_is_immutable(): + curves = CharacteristicCurveCollection( + tendon_force_length=TendonForceLengthDeGroote2016, + tendon_force_length_inverse=TendonForceLengthInverseDeGroote2016, + fiber_force_length_passive=FiberForceLengthPassiveDeGroote2016, + fiber_force_length_passive_inverse=FiberForceLengthPassiveInverseDeGroote2016, + fiber_force_length_active=FiberForceLengthActiveDeGroote2016, + fiber_force_velocity=FiberForceVelocityDeGroote2016, + fiber_force_velocity_inverse=FiberForceVelocityInverseDeGroote2016, + ) + with pytest.raises(AttributeError): + curves.tendon_force_length = None + with pytest.raises(AttributeError): + curves.tendon_force_length_inverse = None + with pytest.raises(AttributeError): + curves.fiber_force_length_passive = None + with pytest.raises(AttributeError): + curves.fiber_force_length_passive_inverse = None + with pytest.raises(AttributeError): + curves.fiber_force_length_active = None + with pytest.raises(AttributeError): + curves.fiber_force_velocity = None + with pytest.raises(AttributeError): + curves.fiber_force_velocity_inverse = None diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/test_mixin.py b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/test_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..be079c195f3d961a88f52c94b695666f2a4f2bb5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/test_mixin.py @@ -0,0 +1,48 @@ +"""Tests for the ``sympy.physics.biomechanics._mixin.py`` module.""" + +import pytest + +from sympy.physics.biomechanics._mixin import _NamedMixin + + +class TestNamedMixin: + + @staticmethod + def test_subclass(): + + class Subclass(_NamedMixin): + + def __init__(self, name): + self.name = name + + instance = Subclass('name') + assert instance.name == 'name' + + @pytest.fixture(autouse=True) + def _named_mixin_fixture(self): + + class Subclass(_NamedMixin): + + def __init__(self, name): + self.name = name + + self.Subclass = Subclass + + @pytest.mark.parametrize('name', ['a', 'name', 'long_name']) + def test_valid_name_argument(self, name): + instance = self.Subclass(name) + assert instance.name == name + + @pytest.mark.parametrize('invalid_name', [0, 0.0, None, False]) + def test_invalid_name_argument_not_str(self, invalid_name): + with pytest.raises(TypeError): + _ = self.Subclass(invalid_name) + + def test_invalid_name_argument_zero_length_str(self): + with pytest.raises(ValueError): + _ = self.Subclass('') + + def test_name_attribute_is_immutable(self): + instance = self.Subclass('name') + with pytest.raises(AttributeError): + instance.name = 'new_name' diff --git a/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/test_musculotendon.py b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/test_musculotendon.py new file mode 100644 index 0000000000000000000000000000000000000000..d0c5a1088214049aaaaa3666854e232d26f77786 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/biomechanics/tests/test_musculotendon.py @@ -0,0 +1,837 @@ +"""Tests for the ``sympy.physics.biomechanics.musculotendon.py`` module.""" + +import abc + +import pytest + +from sympy.core.expr import UnevaluatedExpr +from sympy.core.numbers import Float, Integer, Rational +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.matrices.dense import MutableDenseMatrix as Matrix, eye, zeros +from sympy.physics.biomechanics.activation import ( + FirstOrderActivationDeGroote2016 +) +from sympy.physics.biomechanics.curve import ( + CharacteristicCurveCollection, + FiberForceLengthActiveDeGroote2016, + FiberForceLengthPassiveDeGroote2016, + FiberForceLengthPassiveInverseDeGroote2016, + FiberForceVelocityDeGroote2016, + FiberForceVelocityInverseDeGroote2016, + TendonForceLengthDeGroote2016, + TendonForceLengthInverseDeGroote2016, +) +from sympy.physics.biomechanics.musculotendon import ( + MusculotendonBase, + MusculotendonDeGroote2016, + MusculotendonFormulation, +) +from sympy.physics.biomechanics._mixin import _NamedMixin +from sympy.physics.mechanics.actuator import ForceActuator +from sympy.physics.mechanics.pathway import LinearPathway +from sympy.physics.vector.frame import ReferenceFrame +from sympy.physics.vector.functions import dynamicsymbols +from sympy.physics.vector.point import Point +from sympy.simplify.simplify import simplify + + +class TestMusculotendonFormulation: + @staticmethod + def test_rigid_tendon_member(): + assert MusculotendonFormulation(0) == 0 + assert MusculotendonFormulation.RIGID_TENDON == 0 + + @staticmethod + def test_fiber_length_explicit_member(): + assert MusculotendonFormulation(1) == 1 + assert MusculotendonFormulation.FIBER_LENGTH_EXPLICIT == 1 + + @staticmethod + def test_tendon_force_explicit_member(): + assert MusculotendonFormulation(2) == 2 + assert MusculotendonFormulation.TENDON_FORCE_EXPLICIT == 2 + + @staticmethod + def test_fiber_length_implicit_member(): + assert MusculotendonFormulation(3) == 3 + assert MusculotendonFormulation.FIBER_LENGTH_IMPLICIT == 3 + + @staticmethod + def test_tendon_force_implicit_member(): + assert MusculotendonFormulation(4) == 4 + assert MusculotendonFormulation.TENDON_FORCE_IMPLICIT == 4 + + +class TestMusculotendonBase: + + @staticmethod + def test_is_abstract_base_class(): + assert issubclass(MusculotendonBase, abc.ABC) + + @staticmethod + def test_class(): + assert issubclass(MusculotendonBase, ForceActuator) + assert issubclass(MusculotendonBase, _NamedMixin) + assert MusculotendonBase.__name__ == 'MusculotendonBase' + + @staticmethod + def test_cannot_instantiate_directly(): + with pytest.raises(TypeError): + _ = MusculotendonBase() + + +@pytest.mark.parametrize('musculotendon_concrete', [MusculotendonDeGroote2016]) +class TestMusculotendonRigidTendon: + + @pytest.fixture(autouse=True) + def _musculotendon_rigid_tendon_fixture(self, musculotendon_concrete): + self.name = 'name' + self.N = ReferenceFrame('N') + self.q = dynamicsymbols('q') + self.origin = Point('pO') + self.insertion = Point('pI') + self.insertion.set_pos(self.origin, self.q*self.N.x) + self.pathway = LinearPathway(self.origin, self.insertion) + self.activation = FirstOrderActivationDeGroote2016(self.name) + self.e = self.activation.excitation + self.a = self.activation.activation + self.tau_a = self.activation.activation_time_constant + self.tau_d = self.activation.deactivation_time_constant + self.b = self.activation.smoothing_rate + self.formulation = MusculotendonFormulation.RIGID_TENDON + self.l_T_slack = Symbol('l_T_slack') + self.F_M_max = Symbol('F_M_max') + self.l_M_opt = Symbol('l_M_opt') + self.v_M_max = Symbol('v_M_max') + self.alpha_opt = Symbol('alpha_opt') + self.beta = Symbol('beta') + self.instance = musculotendon_concrete( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=self.formulation, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + ) + self.da_expr = ( + (1/(self.tau_a*(Rational(1, 2) + Rational(3, 2)*self.a))) + *(Rational(1, 2) + Rational(1, 2)*tanh(self.b*(self.e - self.a))) + + ((Rational(1, 2) + Rational(3, 2)*self.a)/self.tau_d) + *(Rational(1, 2) - Rational(1, 2)*tanh(self.b*(self.e - self.a))) + )*(self.e - self.a) + + def test_state_vars(self): + assert hasattr(self.instance, 'x') + assert hasattr(self.instance, 'state_vars') + assert self.instance.x == self.instance.state_vars + x_expected = Matrix([self.a]) + assert self.instance.x == x_expected + assert self.instance.state_vars == x_expected + assert isinstance(self.instance.x, Matrix) + assert isinstance(self.instance.state_vars, Matrix) + assert self.instance.x.shape == (1, 1) + assert self.instance.state_vars.shape == (1, 1) + + def test_input_vars(self): + assert hasattr(self.instance, 'r') + assert hasattr(self.instance, 'input_vars') + assert self.instance.r == self.instance.input_vars + r_expected = Matrix([self.e]) + assert self.instance.r == r_expected + assert self.instance.input_vars == r_expected + assert isinstance(self.instance.r, Matrix) + assert isinstance(self.instance.input_vars, Matrix) + assert self.instance.r.shape == (1, 1) + assert self.instance.input_vars.shape == (1, 1) + + def test_constants(self): + assert hasattr(self.instance, 'p') + assert hasattr(self.instance, 'constants') + assert self.instance.p == self.instance.constants + p_expected = Matrix( + [ + self.l_T_slack, + self.F_M_max, + self.l_M_opt, + self.v_M_max, + self.alpha_opt, + self.beta, + self.tau_a, + self.tau_d, + self.b, + Symbol('c_0_fl_T_name'), + Symbol('c_1_fl_T_name'), + Symbol('c_2_fl_T_name'), + Symbol('c_3_fl_T_name'), + Symbol('c_0_fl_M_pas_name'), + Symbol('c_1_fl_M_pas_name'), + Symbol('c_0_fl_M_act_name'), + Symbol('c_1_fl_M_act_name'), + Symbol('c_2_fl_M_act_name'), + Symbol('c_3_fl_M_act_name'), + Symbol('c_4_fl_M_act_name'), + Symbol('c_5_fl_M_act_name'), + Symbol('c_6_fl_M_act_name'), + Symbol('c_7_fl_M_act_name'), + Symbol('c_8_fl_M_act_name'), + Symbol('c_9_fl_M_act_name'), + Symbol('c_10_fl_M_act_name'), + Symbol('c_11_fl_M_act_name'), + Symbol('c_0_fv_M_name'), + Symbol('c_1_fv_M_name'), + Symbol('c_2_fv_M_name'), + Symbol('c_3_fv_M_name'), + ] + ) + assert self.instance.p == p_expected + assert self.instance.constants == p_expected + assert isinstance(self.instance.p, Matrix) + assert isinstance(self.instance.constants, Matrix) + assert self.instance.p.shape == (31, 1) + assert self.instance.constants.shape == (31, 1) + + def test_M(self): + assert hasattr(self.instance, 'M') + M_expected = Matrix([1]) + assert self.instance.M == M_expected + assert isinstance(self.instance.M, Matrix) + assert self.instance.M.shape == (1, 1) + + def test_F(self): + assert hasattr(self.instance, 'F') + F_expected = Matrix([self.da_expr]) + assert self.instance.F == F_expected + assert isinstance(self.instance.F, Matrix) + assert self.instance.F.shape == (1, 1) + + def test_rhs(self): + assert hasattr(self.instance, 'rhs') + rhs_expected = Matrix([self.da_expr]) + rhs = self.instance.rhs() + assert isinstance(rhs, Matrix) + assert rhs.shape == (1, 1) + assert simplify(rhs - rhs_expected) == zeros(1) + + +@pytest.mark.parametrize( + 'musculotendon_concrete, curve', + [ + ( + MusculotendonDeGroote2016, + CharacteristicCurveCollection( + tendon_force_length=TendonForceLengthDeGroote2016, + tendon_force_length_inverse=TendonForceLengthInverseDeGroote2016, + fiber_force_length_passive=FiberForceLengthPassiveDeGroote2016, + fiber_force_length_passive_inverse=FiberForceLengthPassiveInverseDeGroote2016, + fiber_force_length_active=FiberForceLengthActiveDeGroote2016, + fiber_force_velocity=FiberForceVelocityDeGroote2016, + fiber_force_velocity_inverse=FiberForceVelocityInverseDeGroote2016, + ), + ) + ], +) +class TestFiberLengthExplicit: + + @pytest.fixture(autouse=True) + def _musculotendon_fiber_length_explicit_fixture( + self, + musculotendon_concrete, + curve, + ): + self.name = 'name' + self.N = ReferenceFrame('N') + self.q = dynamicsymbols('q') + self.origin = Point('pO') + self.insertion = Point('pI') + self.insertion.set_pos(self.origin, self.q*self.N.x) + self.pathway = LinearPathway(self.origin, self.insertion) + self.activation = FirstOrderActivationDeGroote2016(self.name) + self.e = self.activation.excitation + self.a = self.activation.activation + self.tau_a = self.activation.activation_time_constant + self.tau_d = self.activation.deactivation_time_constant + self.b = self.activation.smoothing_rate + self.formulation = MusculotendonFormulation.FIBER_LENGTH_EXPLICIT + self.l_T_slack = Symbol('l_T_slack') + self.F_M_max = Symbol('F_M_max') + self.l_M_opt = Symbol('l_M_opt') + self.v_M_max = Symbol('v_M_max') + self.alpha_opt = Symbol('alpha_opt') + self.beta = Symbol('beta') + self.instance = musculotendon_concrete( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=self.formulation, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + with_defaults=True, + ) + self.l_M_tilde = dynamicsymbols('l_M_tilde_name') + l_MT = self.pathway.length + l_M = self.l_M_tilde*self.l_M_opt + l_T = l_MT - sqrt(l_M**2 - (self.l_M_opt*sin(self.alpha_opt))**2) + fl_T = curve.tendon_force_length.with_defaults(l_T/self.l_T_slack) + fl_M_pas = curve.fiber_force_length_passive.with_defaults(self.l_M_tilde) + fl_M_act = curve.fiber_force_length_active.with_defaults(self.l_M_tilde) + v_M_tilde = curve.fiber_force_velocity_inverse.with_defaults( + ((((fl_T*self.F_M_max)/((l_MT - l_T)/l_M))/self.F_M_max) - fl_M_pas) + /(self.a*fl_M_act) + ) + self.dl_M_tilde_expr = (self.v_M_max/self.l_M_opt)*v_M_tilde + self.da_expr = ( + (1/(self.tau_a*(Rational(1, 2) + Rational(3, 2)*self.a))) + *(Rational(1, 2) + Rational(1, 2)*tanh(self.b*(self.e - self.a))) + + ((Rational(1, 2) + Rational(3, 2)*self.a)/self.tau_d) + *(Rational(1, 2) - Rational(1, 2)*tanh(self.b*(self.e - self.a))) + )*(self.e - self.a) + + def test_state_vars(self): + assert hasattr(self.instance, 'x') + assert hasattr(self.instance, 'state_vars') + assert self.instance.x == self.instance.state_vars + x_expected = Matrix([self.l_M_tilde, self.a]) + assert self.instance.x == x_expected + assert self.instance.state_vars == x_expected + assert isinstance(self.instance.x, Matrix) + assert isinstance(self.instance.state_vars, Matrix) + assert self.instance.x.shape == (2, 1) + assert self.instance.state_vars.shape == (2, 1) + + def test_input_vars(self): + assert hasattr(self.instance, 'r') + assert hasattr(self.instance, 'input_vars') + assert self.instance.r == self.instance.input_vars + r_expected = Matrix([self.e]) + assert self.instance.r == r_expected + assert self.instance.input_vars == r_expected + assert isinstance(self.instance.r, Matrix) + assert isinstance(self.instance.input_vars, Matrix) + assert self.instance.r.shape == (1, 1) + assert self.instance.input_vars.shape == (1, 1) + + def test_constants(self): + assert hasattr(self.instance, 'p') + assert hasattr(self.instance, 'constants') + assert self.instance.p == self.instance.constants + p_expected = Matrix( + [ + self.l_T_slack, + self.F_M_max, + self.l_M_opt, + self.v_M_max, + self.alpha_opt, + self.beta, + self.tau_a, + self.tau_d, + self.b, + ] + ) + assert self.instance.p == p_expected + assert self.instance.constants == p_expected + assert isinstance(self.instance.p, Matrix) + assert isinstance(self.instance.constants, Matrix) + assert self.instance.p.shape == (9, 1) + assert self.instance.constants.shape == (9, 1) + + def test_M(self): + assert hasattr(self.instance, 'M') + M_expected = eye(2) + assert self.instance.M == M_expected + assert isinstance(self.instance.M, Matrix) + assert self.instance.M.shape == (2, 2) + + def test_F(self): + assert hasattr(self.instance, 'F') + F_expected = Matrix([self.dl_M_tilde_expr, self.da_expr]) + assert self.instance.F == F_expected + assert isinstance(self.instance.F, Matrix) + assert self.instance.F.shape == (2, 1) + + def test_rhs(self): + assert hasattr(self.instance, 'rhs') + rhs_expected = Matrix([self.dl_M_tilde_expr, self.da_expr]) + rhs = self.instance.rhs() + assert isinstance(rhs, Matrix) + assert rhs.shape == (2, 1) + assert simplify(rhs - rhs_expected) == zeros(2, 1) + + +@pytest.mark.parametrize( + 'musculotendon_concrete, curve', + [ + ( + MusculotendonDeGroote2016, + CharacteristicCurveCollection( + tendon_force_length=TendonForceLengthDeGroote2016, + tendon_force_length_inverse=TendonForceLengthInverseDeGroote2016, + fiber_force_length_passive=FiberForceLengthPassiveDeGroote2016, + fiber_force_length_passive_inverse=FiberForceLengthPassiveInverseDeGroote2016, + fiber_force_length_active=FiberForceLengthActiveDeGroote2016, + fiber_force_velocity=FiberForceVelocityDeGroote2016, + fiber_force_velocity_inverse=FiberForceVelocityInverseDeGroote2016, + ), + ) + ], +) +class TestTendonForceExplicit: + + @pytest.fixture(autouse=True) + def _musculotendon_tendon_force_explicit_fixture( + self, + musculotendon_concrete, + curve, + ): + self.name = 'name' + self.N = ReferenceFrame('N') + self.q = dynamicsymbols('q') + self.origin = Point('pO') + self.insertion = Point('pI') + self.insertion.set_pos(self.origin, self.q*self.N.x) + self.pathway = LinearPathway(self.origin, self.insertion) + self.activation = FirstOrderActivationDeGroote2016(self.name) + self.e = self.activation.excitation + self.a = self.activation.activation + self.tau_a = self.activation.activation_time_constant + self.tau_d = self.activation.deactivation_time_constant + self.b = self.activation.smoothing_rate + self.formulation = MusculotendonFormulation.TENDON_FORCE_EXPLICIT + self.l_T_slack = Symbol('l_T_slack') + self.F_M_max = Symbol('F_M_max') + self.l_M_opt = Symbol('l_M_opt') + self.v_M_max = Symbol('v_M_max') + self.alpha_opt = Symbol('alpha_opt') + self.beta = Symbol('beta') + self.instance = musculotendon_concrete( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=self.formulation, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + with_defaults=True, + ) + self.F_T_tilde = dynamicsymbols('F_T_tilde_name') + l_T_tilde = curve.tendon_force_length_inverse.with_defaults(self.F_T_tilde) + l_MT = self.pathway.length + v_MT = self.pathway.extension_velocity + l_T = l_T_tilde*self.l_T_slack + l_M = sqrt((l_MT - l_T)**2 + (self.l_M_opt*sin(self.alpha_opt))**2) + l_M_tilde = l_M/self.l_M_opt + cos_alpha = (l_MT - l_T)/l_M + F_T = self.F_T_tilde*self.F_M_max + F_M = F_T/cos_alpha + F_M_tilde = F_M/self.F_M_max + fl_M_pas = curve.fiber_force_length_passive.with_defaults(l_M_tilde) + fl_M_act = curve.fiber_force_length_active.with_defaults(l_M_tilde) + fv_M = (F_M_tilde - fl_M_pas)/(self.a*fl_M_act) + v_M_tilde = curve.fiber_force_velocity_inverse.with_defaults(fv_M) + v_M = v_M_tilde*self.v_M_max + v_T = v_MT - v_M/cos_alpha + v_T_tilde = v_T/self.l_T_slack + self.dF_T_tilde_expr = ( + Float('0.2')*Float('33.93669377311689')*exp( + Float('33.93669377311689')*UnevaluatedExpr(l_T_tilde - Float('0.995')) + )*v_T_tilde + ) + self.da_expr = ( + (1/(self.tau_a*(Rational(1, 2) + Rational(3, 2)*self.a))) + *(Rational(1, 2) + Rational(1, 2)*tanh(self.b*(self.e - self.a))) + + ((Rational(1, 2) + Rational(3, 2)*self.a)/self.tau_d) + *(Rational(1, 2) - Rational(1, 2)*tanh(self.b*(self.e - self.a))) + )*(self.e - self.a) + + def test_state_vars(self): + assert hasattr(self.instance, 'x') + assert hasattr(self.instance, 'state_vars') + assert self.instance.x == self.instance.state_vars + x_expected = Matrix([self.F_T_tilde, self.a]) + assert self.instance.x == x_expected + assert self.instance.state_vars == x_expected + assert isinstance(self.instance.x, Matrix) + assert isinstance(self.instance.state_vars, Matrix) + assert self.instance.x.shape == (2, 1) + assert self.instance.state_vars.shape == (2, 1) + + def test_input_vars(self): + assert hasattr(self.instance, 'r') + assert hasattr(self.instance, 'input_vars') + assert self.instance.r == self.instance.input_vars + r_expected = Matrix([self.e]) + assert self.instance.r == r_expected + assert self.instance.input_vars == r_expected + assert isinstance(self.instance.r, Matrix) + assert isinstance(self.instance.input_vars, Matrix) + assert self.instance.r.shape == (1, 1) + assert self.instance.input_vars.shape == (1, 1) + + def test_constants(self): + assert hasattr(self.instance, 'p') + assert hasattr(self.instance, 'constants') + assert self.instance.p == self.instance.constants + p_expected = Matrix( + [ + self.l_T_slack, + self.F_M_max, + self.l_M_opt, + self.v_M_max, + self.alpha_opt, + self.beta, + self.tau_a, + self.tau_d, + self.b, + ] + ) + assert self.instance.p == p_expected + assert self.instance.constants == p_expected + assert isinstance(self.instance.p, Matrix) + assert isinstance(self.instance.constants, Matrix) + assert self.instance.p.shape == (9, 1) + assert self.instance.constants.shape == (9, 1) + + def test_M(self): + assert hasattr(self.instance, 'M') + M_expected = eye(2) + assert self.instance.M == M_expected + assert isinstance(self.instance.M, Matrix) + assert self.instance.M.shape == (2, 2) + + def test_F(self): + assert hasattr(self.instance, 'F') + F_expected = Matrix([self.dF_T_tilde_expr, self.da_expr]) + assert self.instance.F == F_expected + assert isinstance(self.instance.F, Matrix) + assert self.instance.F.shape == (2, 1) + + def test_rhs(self): + assert hasattr(self.instance, 'rhs') + rhs_expected = Matrix([self.dF_T_tilde_expr, self.da_expr]) + rhs = self.instance.rhs() + assert isinstance(rhs, Matrix) + assert rhs.shape == (2, 1) + assert simplify(rhs - rhs_expected) == zeros(2, 1) + + +class TestMusculotendonDeGroote2016: + + @staticmethod + def test_class(): + assert issubclass(MusculotendonDeGroote2016, ForceActuator) + assert issubclass(MusculotendonDeGroote2016, _NamedMixin) + assert MusculotendonDeGroote2016.__name__ == 'MusculotendonDeGroote2016' + + @staticmethod + def test_instance(): + origin = Point('pO') + insertion = Point('pI') + insertion.set_pos(origin, dynamicsymbols('q')*ReferenceFrame('N').x) + pathway = LinearPathway(origin, insertion) + activation = FirstOrderActivationDeGroote2016('name') + l_T_slack = Symbol('l_T_slack') + F_M_max = Symbol('F_M_max') + l_M_opt = Symbol('l_M_opt') + v_M_max = Symbol('v_M_max') + alpha_opt = Symbol('alpha_opt') + beta = Symbol('beta') + instance = MusculotendonDeGroote2016( + 'name', + pathway, + activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=l_T_slack, + peak_isometric_force=F_M_max, + optimal_fiber_length=l_M_opt, + maximal_fiber_velocity=v_M_max, + optimal_pennation_angle=alpha_opt, + fiber_damping_coefficient=beta, + ) + assert isinstance(instance, MusculotendonDeGroote2016) + + @pytest.fixture(autouse=True) + def _musculotendon_fixture(self): + self.name = 'name' + self.N = ReferenceFrame('N') + self.q = dynamicsymbols('q') + self.origin = Point('pO') + self.insertion = Point('pI') + self.insertion.set_pos(self.origin, self.q*self.N.x) + self.pathway = LinearPathway(self.origin, self.insertion) + self.activation = FirstOrderActivationDeGroote2016(self.name) + self.l_T_slack = Symbol('l_T_slack') + self.F_M_max = Symbol('F_M_max') + self.l_M_opt = Symbol('l_M_opt') + self.v_M_max = Symbol('v_M_max') + self.alpha_opt = Symbol('alpha_opt') + self.beta = Symbol('beta') + + def test_with_defaults(self): + origin = Point('pO') + insertion = Point('pI') + insertion.set_pos(origin, dynamicsymbols('q')*ReferenceFrame('N').x) + pathway = LinearPathway(origin, insertion) + activation = FirstOrderActivationDeGroote2016('name') + l_T_slack = Symbol('l_T_slack') + F_M_max = Symbol('F_M_max') + l_M_opt = Symbol('l_M_opt') + v_M_max = Float('10.0') + alpha_opt = Float('0.0') + beta = Float('0.1') + instance = MusculotendonDeGroote2016.with_defaults( + 'name', + pathway, + activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=l_T_slack, + peak_isometric_force=F_M_max, + optimal_fiber_length=l_M_opt, + ) + assert instance.tendon_slack_length == l_T_slack + assert instance.peak_isometric_force == F_M_max + assert instance.optimal_fiber_length == l_M_opt + assert instance.maximal_fiber_velocity == v_M_max + assert instance.optimal_pennation_angle == alpha_opt + assert instance.fiber_damping_coefficient == beta + + @pytest.mark.parametrize( + 'l_T_slack, expected', + [ + (None, Symbol('l_T_slack_name')), + (Symbol('l_T_slack'), Symbol('l_T_slack')), + (Rational(1, 2), Rational(1, 2)), + (Float('0.5'), Float('0.5')), + ], + ) + def test_tendon_slack_length(self, l_T_slack, expected): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + ) + assert instance.l_T_slack == expected + assert instance.tendon_slack_length == expected + + @pytest.mark.parametrize( + 'F_M_max, expected', + [ + (None, Symbol('F_M_max_name')), + (Symbol('F_M_max'), Symbol('F_M_max')), + (Integer(1000), Integer(1000)), + (Float('1000.0'), Float('1000.0')), + ], + ) + def test_peak_isometric_force(self, F_M_max, expected): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + ) + assert instance.F_M_max == expected + assert instance.peak_isometric_force == expected + + @pytest.mark.parametrize( + 'l_M_opt, expected', + [ + (None, Symbol('l_M_opt_name')), + (Symbol('l_M_opt'), Symbol('l_M_opt')), + (Rational(1, 2), Rational(1, 2)), + (Float('0.5'), Float('0.5')), + ], + ) + def test_optimal_fiber_length(self, l_M_opt, expected): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + ) + assert instance.l_M_opt == expected + assert instance.optimal_fiber_length == expected + + @pytest.mark.parametrize( + 'v_M_max, expected', + [ + (None, Symbol('v_M_max_name')), + (Symbol('v_M_max'), Symbol('v_M_max')), + (Integer(10), Integer(10)), + (Float('10.0'), Float('10.0')), + ], + ) + def test_maximal_fiber_velocity(self, v_M_max, expected): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + ) + assert instance.v_M_max == expected + assert instance.maximal_fiber_velocity == expected + + @pytest.mark.parametrize( + 'alpha_opt, expected', + [ + (None, Symbol('alpha_opt_name')), + (Symbol('alpha_opt'), Symbol('alpha_opt')), + (Integer(0), Integer(0)), + (Float('0.1'), Float('0.1')), + ], + ) + def test_optimal_pennation_angle(self, alpha_opt, expected): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=alpha_opt, + fiber_damping_coefficient=self.beta, + ) + assert instance.alpha_opt == expected + assert instance.optimal_pennation_angle == expected + + @pytest.mark.parametrize( + 'beta, expected', + [ + (None, Symbol('beta_name')), + (Symbol('beta'), Symbol('beta')), + (Integer(0), Integer(0)), + (Rational(1, 10), Rational(1, 10)), + (Float('0.1'), Float('0.1')), + ], + ) + def test_fiber_damping_coefficient(self, beta, expected): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=beta, + ) + assert instance.beta == expected + assert instance.fiber_damping_coefficient == expected + + def test_excitation(self): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + ) + assert hasattr(instance, 'e') + assert hasattr(instance, 'excitation') + e_expected = dynamicsymbols('e_name') + assert instance.e == e_expected + assert instance.excitation == e_expected + assert instance.e is instance.excitation + + def test_excitation_is_immutable(self): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + ) + with pytest.raises(AttributeError): + instance.e = None + with pytest.raises(AttributeError): + instance.excitation = None + + def test_activation(self): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + ) + assert hasattr(instance, 'a') + assert hasattr(instance, 'activation') + a_expected = dynamicsymbols('a_name') + assert instance.a == a_expected + assert instance.activation == a_expected + + def test_activation_is_immutable(self): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + ) + with pytest.raises(AttributeError): + instance.a = None + with pytest.raises(AttributeError): + instance.activation = None + + def test_repr(self): + instance = MusculotendonDeGroote2016( + self.name, + self.pathway, + self.activation, + musculotendon_dynamics=MusculotendonFormulation.RIGID_TENDON, + tendon_slack_length=self.l_T_slack, + peak_isometric_force=self.F_M_max, + optimal_fiber_length=self.l_M_opt, + maximal_fiber_velocity=self.v_M_max, + optimal_pennation_angle=self.alpha_opt, + fiber_damping_coefficient=self.beta, + ) + expected = ( + 'MusculotendonDeGroote2016(\'name\', ' + 'pathway=LinearPathway(pO, pI), ' + 'activation_dynamics=FirstOrderActivationDeGroote2016(\'name\', ' + 'activation_time_constant=tau_a_name, ' + 'deactivation_time_constant=tau_d_name, ' + 'smoothing_rate=b_name), ' + 'musculotendon_dynamics=0, ' + 'tendon_slack_length=l_T_slack, ' + 'peak_isometric_force=F_M_max, ' + 'optimal_fiber_length=l_M_opt, ' + 'maximal_fiber_velocity=v_M_max, ' + 'optimal_pennation_angle=alpha_opt, ' + 'fiber_damping_coefficient=beta)' + ) + assert repr(instance) == expected diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__init__.py b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..781429110cab760f8990961c6536e7267a2a371a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__init__.py @@ -0,0 +1,10 @@ +__all__ = ['Beam', + 'Truss', + 'Cable', + 'Arch' + ] + +from .beam import Beam +from .truss import Truss +from .cable import Cable +from .arch import Arch diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41af314d499482be2d3d0aad276580e3acdef94b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__pycache__/arch.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__pycache__/arch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f48d9b25b995186986165ab51705e2a5a44f592 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__pycache__/arch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__pycache__/cable.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__pycache__/cable.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc73a8e948565117b34bb576f973eea078a33cb2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__pycache__/cable.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__pycache__/truss.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__pycache__/truss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..709375332c211e0229b46c10c1c3170efc17e3df Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/__pycache__/truss.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/arch.py b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/arch.py new file mode 100644 index 0000000000000000000000000000000000000000..31e2b41e841638f6a8002da1a7c843a9f5b35555 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/arch.py @@ -0,0 +1,1025 @@ +""" +This module can be used to solve probelsm related to 2D parabolic arches +""" +from sympy.core.sympify import sympify +from sympy.core.symbol import Symbol,symbols +from sympy import diff, sqrt, cos , sin, atan, rad, Min +from sympy.core.relational import Eq +from sympy.solvers.solvers import solve +from sympy.functions import Piecewise +from sympy.plotting import plot +from sympy import limit +from sympy.utilities.decorator import doctest_depends_on +from sympy.external.importtools import import_module + +numpy = import_module('numpy', import_kwargs={'fromlist':['arange']}) + +class Arch: + """ + This class is used to solve problems related to a three hinged arch(determinate) structure.\n + An arch is a curved vertical structure spanning an open space underneath it.\n + Arches can be used to reduce the bending moments in long-span structures.\n + + Arches are used in structural engineering(over windows, door and even bridges)\n + because they can support a very large mass placed on top of them. + + Example + ======== + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + >>> a.get_shape_eqn + 5 - (x - 5)**2/5 + + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,1),crown_x=6) + >>> a.get_shape_eqn + 9/5 - (x - 6)**2/20 + """ + def __init__(self,left_support,right_support,**kwargs): + self._shape_eqn = None + self._left_support = (sympify(left_support[0]),sympify(left_support[1])) + self._right_support = (sympify(right_support[0]),sympify(right_support[1])) + self._crown_x = None + self._crown_y = None + if 'crown_x' in kwargs: + self._crown_x = sympify(kwargs['crown_x']) + if 'crown_y' in kwargs: + self._crown_y = sympify(kwargs['crown_y']) + self._shape_eqn = self.get_shape_eqn + self._conc_loads = {} + self._distributed_loads = {} + self._loads = {'concentrated': self._conc_loads, 'distributed':self._distributed_loads} + self._loads_applied = {} + self._supports = {'left':'hinge', 'right':'hinge'} + self._member = None + self._member_force = None + self._reaction_force = {Symbol('R_A_x'):0, Symbol('R_A_y'):0, Symbol('R_B_x'):0, Symbol('R_B_y'):0} + self._points_disc_x = set() + self._points_disc_y = set() + self._moment_x = {} + self._moment_y = {} + self._load_x = {} + self._load_y = {} + self._moment_x_func = Piecewise((0,True)) + self._moment_y_func = Piecewise((0,True)) + self._load_x_func = Piecewise((0,True)) + self._load_y_func = Piecewise((0,True)) + self._bending_moment = None + self._shear_force = None + self._axial_force = None + # self._crown = (sympify(crown[0]),sympify(crown[1])) + + @property + def get_shape_eqn(self): + "returns the equation of the shape of arch developed" + if self._shape_eqn: + return self._shape_eqn + + x,y,c = symbols('x y c') + a = Symbol('a',positive=False) + if self._crown_x and self._crown_y: + x0 = self._crown_x + y0 = self._crown_y + parabola_eqn = a*(x-x0)**2 + y0 - y + eq1 = parabola_eqn.subs({x:self._left_support[0], y:self._left_support[1]}) + solution = solve((eq1),(a)) + parabola_eqn = solution[0]*(x-x0)**2 + y0 + if(parabola_eqn.subs({x:self._right_support[0]}) != self._right_support[1]): + raise ValueError("provided coordinates of crown and supports are not consistent with parabolic arch") + + elif self._crown_x: + x0 = self._crown_x + parabola_eqn = a*(x-x0)**2 + c - y + eq1 = parabola_eqn.subs({x:self._left_support[0], y:self._left_support[1]}) + eq2 = parabola_eqn.subs({x:self._right_support[0], y:self._right_support[1]}) + solution = solve((eq1,eq2),(a,c)) + if len(solution) <2 or solution[a] == 0: + raise ValueError("parabolic arch cannot be constructed with the provided coordinates, try providing crown_y") + parabola_eqn = solution[a]*(x-x0)**2+ solution[c] + self._crown_y = solution[c] + + else: + raise KeyError("please provide crown_x to construct arch") + + return parabola_eqn + + @property + def get_loads(self): + """ + return the position of the applied load and angle (for concentrated loads) + """ + return self._loads + + @property + def supports(self): + """ + Returns the type of support + """ + return self._supports + + @property + def left_support(self): + """ + Returns the position of the left support. + """ + return self._left_support + + @property + def right_support(self): + """ + Returns the position of the right support. + """ + return self._right_support + + @property + def reaction_force(self): + """ + return the reaction forces generated + """ + return self._reaction_force + + def apply_load(self,order,label,start,mag,end=None,angle=None): + """ + This method adds load to the Arch. + + Parameters + ========== + + order : Integer + Order of the applied load. + + - For point/concentrated loads, order = -1 + - For distributed load, order = 0 + + label : String or Symbol + The label of the load + - should not use 'A' or 'B' as it is used for supports. + + start : Float + + - For concentrated/point loads, start is the x coordinate + - For distributed loads, start is the starting position of distributed load + + mag : Sympifyable + Magnitude of the applied load. Must be positive + + end : Float + Required for distributed loads + + - For concentrated/point load , end is None(may not be given) + - For distributed loads, end is the end position of distributed load + + angle: Sympifyable + The angle in degrees, the load vector makes with the horizontal + in the counter-clockwise direction. + + Examples + ======== + For applying distributed load + + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + >>> a.apply_load(0,'C',start=3,end=5,mag=-10) + + For applying point/concentrated_loads + + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + >>> a.apply_load(-1,'C',start=2,mag=15,angle=45) + + """ + y = Symbol('y') + x = Symbol('x') + x0 = Symbol('x0') + # y0 = Symbol('y0') + order= sympify(order) + mag = sympify(mag) + angle = sympify(angle) + + if label in self._loads_applied: + raise ValueError("load with the given label already exists") + + if label in ['A','B']: + raise ValueError("cannot use the given label, reserved for supports") + + if order == 0: + if end is None or end>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + >>> a.apply_load(0,'C',start=3,end=5,mag=-10) + >>> a.remove_load('C') + removed load C: {'start': 3, 'end': 5, 'f_y': -10} + """ + y = Symbol('y') + x = Symbol('x') + x0 = Symbol('x0') + + if label in self._distributed_loads : + + self._loads_applied.pop(label) + start = self._distributed_loads[label]['start'] + end = self._distributed_loads[label]['end'] + mag = self._distributed_loads[label]['f_y'] + self._points_disc_y.remove(start) + self._load_y[start] -= mag*(Min(x,end)-start) + self._moment_y[start] += mag*(Min(x,end)-start)*(x0-(start+(Min(x,end)))/2) + val = self._distributed_loads.pop(label) + print(f"removed load {label}: {val}") + + elif label in self._conc_loads : + + self._loads_applied.pop(label) + start = self._conc_loads[label]['x'] + self._points_disc_x.remove(start) + self._points_disc_y.remove(start) + self._moment_y[start] += self._conc_loads[label]['f_y']*(x0-start) + self._moment_x[start] -= self._conc_loads[label]['f_x']*(y-self._conc_loads[label]['y']) + self._load_x[start] -= self._conc_loads[label]['f_x'] + self._load_y[start] -= self._conc_loads[label]['f_y'] + val = self._conc_loads.pop(label) + print(f"removed load {label}: {val}") + + else : + raise ValueError("label not found") + + def change_support_position(self, left_support=None, right_support=None): + """ + Change position of supports. + If not provided , defaults to the old value. + Parameters + ========== + + left_support: tuple (x, y) + x: float + x-coordinate value of the left_support + + y: float + y-coordinate value of the left_support + + right_support: tuple (x, y) + x: float + x-coordinate value of the right_support + + y: float + y-coordinate value of the right_support + """ + if left_support is not None: + self._left_support = (left_support[0],left_support[1]) + + if right_support is not None: + self._right_support = (right_support[0],right_support[1]) + + self._shape_eqn = None + self._shape_eqn = self.get_shape_eqn + + def change_crown_position(self,crown_x=None,crown_y=None): + """ + Change the position of the crown/hinge of the arch + + Parameters + ========== + + crown_x: Float + The x coordinate of the position of the hinge + - if not provided, defaults to old value + + crown_y: Float + The y coordinate of the position of the hinge + - if not provided defaults to None + """ + self._crown_x = crown_x + self._crown_y = crown_y + self._shape_eqn = None + self._shape_eqn = self.get_shape_eqn + + def change_support_type(self,left_support=None,right_support=None): + """ + Add the type for support at each end. + Can use roller or hinge support at each end. + + Parameters + ========== + + left_support, right_support : string + Type of support at respective end + + - For roller support , left_support/right_support = "roller" + - For hinged support, left_support/right_support = "hinge" + - defaults to hinge if value not provided + + Examples + ======== + + For applying roller support at right end + + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + >>> a.change_support_type(right_support="roller") + + """ + support_types = ['roller','hinge'] + if left_support: + if left_support not in support_types: + raise ValueError("supports must only be roller or hinge") + + self._supports['left'] = left_support + + if right_support: + if right_support not in support_types: + raise ValueError("supports must only be roller or hinge") + + self._supports['right'] = right_support + + def add_member(self,y): + """ + This method adds a member/rod at a particular height y. + A rod is used for stability of the structure in case of a roller support. + """ + if y>self._crown_y or y>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + >>> a.apply_load(0,'C',start=3,end=5,mag=-10) + >>> a.solve() + >>> a.reaction_force + {R_A_x: 8, R_A_y: 12, R_B_x: -8, R_B_y: 8} + + >>> from sympy import Symbol + >>> t = Symbol('t') + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(16,0),crown_x=8,crown_y=5) + >>> a.apply_load(0,'C',start=3,end=5,mag=t) + >>> a.solve() + >>> a.reaction_force + {R_A_x: -4*t/5, R_A_y: -3*t/2, R_B_x: 4*t/5, R_B_y: -t/2} + + >>> a.bending_moment_at(4) + -5*t/2 + """ + y = Symbol('y') + x = Symbol('x') + x0 = Symbol('x0') + + discontinuity_points_x = sorted(self._points_disc_x) + discontinuity_points_y = sorted(self._points_disc_y) + + self._moment_x_func = Piecewise((0,True)) + self._moment_y_func = Piecewise((0,True)) + + self._load_x_func = Piecewise((0,True)) + self._load_y_func = Piecewise((0,True)) + + accumulated_x_moment = 0 + accumulated_y_moment = 0 + + accumulated_x_load = 0 + accumulated_y_load = 0 + + for point in discontinuity_points_x: + cond = (x >= point) + accumulated_x_load += self._load_x[point] + accumulated_x_moment += self._moment_x[point] + self._load_x_func = Piecewise((accumulated_x_load,cond),(self._load_x_func,True)) + self._moment_x_func = Piecewise((accumulated_x_moment,cond),(self._moment_x_func,True)) + + for point in discontinuity_points_y: + cond = (x >= point) + accumulated_y_moment += self._moment_y[point] + accumulated_y_load += self._load_y[point] + self._load_y_func = Piecewise((accumulated_y_load,cond),(self._load_y_func,True)) + self._moment_y_func = Piecewise((accumulated_y_moment,cond),(self._moment_y_func,True)) + + moment_A = self._moment_y_func.subs(x,self._right_support[0]).subs(x0,self._left_support[0]) +\ + self._moment_x_func.subs(x,self._right_support[0]).subs(y,self._left_support[1]) + + moment_hinge_left = self._moment_y_func.subs(x,self._crown_x).subs(x0,self._crown_x) +\ + self._moment_x_func.subs(x,self._crown_x).subs(y,self._crown_y) + + moment_hinge_right = self._moment_y_func.subs(x,self._right_support[0]).subs(x0,self._crown_x)- \ + self._moment_y_func.subs(x,self._crown_x).subs(x0,self._crown_x) +\ + self._moment_x_func.subs(x,self._right_support[0]).subs(y,self._crown_y) -\ + self._moment_x_func.subs(x,self._crown_x).subs(y,self._crown_y) + + net_x = self._load_x_func.subs(x,self._right_support[0]) + net_y = self._load_y_func.subs(x,self._right_support[0]) + + if (self._supports['left']=='roller' or self._supports['right']=='roller') and not self._member: + print("member must be added if any of the supports is roller") + return + + R_A_x, R_A_y, R_B_x, R_B_y, T = symbols('R_A_x R_A_y R_B_x R_B_y T') + + if self._supports['left'] == 'roller' and self._supports['right'] == 'roller': + + if self._member[2]>=max(self._left_support[1],self._right_support[1]): + + if net_x!=0: + raise ValueError("net force in x direction not possible under the specified conditions") + + else: + eq1 = Eq(R_A_x ,0) + eq2 = Eq(R_B_x, 0) + eq3 = Eq(R_A_y + R_B_y + net_y,0) + + eq4 = Eq(R_B_y*(self._right_support[0]-self._left_support[0])-\ + R_B_x*(self._right_support[1]-self._left_support[1])+moment_A,0) + + eq5 = Eq(moment_hinge_right + R_B_y*(self._right_support[0]-self._crown_x) +\ + T*(self._member[2]-self._crown_y),0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._member[2]>=self._left_support[1]: + eq1 = Eq(R_A_x ,0) + eq2 = Eq(R_B_x, 0) + eq3 = Eq(R_A_y + R_B_y + net_y,0) + eq4 = Eq(R_B_y*(self._right_support[0]-self._left_support[0])-\ + T*(self._member[2]-self._left_support[1])+moment_A,0) + eq5 = Eq(T+net_x,0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._member[2]>=self._right_support[1]: + eq1 = Eq(R_A_x ,0) + eq2 = Eq(R_B_x, 0) + eq3 = Eq(R_A_y + R_B_y + net_y,0) + eq4 = Eq(R_B_y*(self._right_support[0]-self._left_support[0])+\ + T*(self._member[2]-self._left_support[1])+moment_A,0) + eq5 = Eq(T-net_x,0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._supports['left'] == 'roller': + if self._member[2]>=max(self._left_support[1], self._right_support[1]): + eq1 = Eq(R_A_x ,0) + eq2 = Eq(R_B_x+net_x,0) + eq3 = Eq(R_A_y + R_B_y + net_y,0) + eq4 = Eq(R_B_y*(self._right_support[0]-self._left_support[0])-\ + R_B_x*(self._right_support[1]-self._left_support[1])+moment_A,0) + eq5 = Eq(moment_hinge_left + R_A_y*(self._left_support[0]-self._crown_x) -\ + T*(self._member[2]-self._crown_y),0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._member[2]>=self._left_support[1]: + eq1 = Eq(R_A_x ,0) + eq2 = Eq(R_B_x+ T +net_x,0) + eq3 = Eq(R_A_y + R_B_y + net_y,0) + eq4 = Eq(R_B_y*(self._right_support[0]-self._left_support[0])-\ + R_B_x*(self._right_support[1]-self._left_support[1])-\ + T*(self._member[2]-self._left_support[0])+moment_A,0) + eq5 = Eq(moment_hinge_left + R_A_y*(self._left_support[0]-self._crown_x)-\ + T*(self._member[2]-self._crown_y),0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._member[2]>=self._right_support[0]: + eq1 = Eq(R_A_x,0) + eq2 = Eq(R_B_x- T +net_x,0) + eq3 = Eq(R_A_y + R_B_y + net_y,0) + eq4 = Eq(moment_hinge_left+R_A_y*(self._left_support[0]-self._crown_x),0) + eq5 = Eq(moment_A+R_B_y*(self._right_support[0]-self._left_support[0])-\ + R_B_x*(self._right_support[1]-self._left_support[1])+\ + T*(self._member[2]-self._left_support[1]),0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._supports['right'] == 'roller': + if self._member[2]>=max(self._left_support[1], self._right_support[1]): + eq1 = Eq(R_B_x,0) + eq2 = Eq(R_A_x+net_x,0) + eq3 = Eq(R_A_y+R_B_y+net_y,0) + eq4 = Eq(moment_hinge_right+R_B_y*(self._right_support[0]-self._crown_x)+\ + T*(self._member[2]-self._crown_y),0) + eq5 = Eq(moment_A+R_B_y*(self._right_support[0]-self._left_support[0]),0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._member[2]>=self._left_support[1]: + eq1 = Eq(R_B_x,0) + eq2 = Eq(R_A_x+T+net_x,0) + eq3 = Eq(R_A_y+R_B_y+net_y,0) + eq4 = Eq(moment_hinge_right+R_B_y*(self._right_support[0]-self._crown_x),0) + eq5 = Eq(moment_A-T*(self._member[2]-self._left_support[1])+\ + R_B_y*(self._right_support[0]-self._left_support[0]),0) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + + elif self._member[2]>=self._right_support[1]: + eq1 = Eq(R_B_x,0) + eq2 = Eq(R_A_x-T+net_x,0) + eq3 = Eq(R_A_y+R_B_y+net_y,0) + eq4 = Eq(moment_hinge_right+R_B_y*(self._right_support[0]-self._crown_x)+\ + T*(self._member[2]-self._crown_y),0) + eq5 = Eq(moment_A+T*(self._member[2]-self._left_support[1])+\ + R_B_y*(self._right_support[0]-self._left_support[0])) + solution = solve((eq1,eq2,eq3,eq4,eq5),(R_A_x,R_A_y,R_B_x,R_B_y,T)) + else: + eq1 = Eq(R_A_x + R_B_x + net_x,0) + eq2 = Eq(R_A_y + R_B_y + net_y,0) + eq3 = Eq(R_B_y*(self._right_support[0]-self._left_support[0])-\ + R_B_x*(self._right_support[1]-self._left_support[1])+moment_A,0) + eq4 = Eq(moment_hinge_right + R_B_y*(self._right_support[0]-self._crown_x) -\ + R_B_x*(self._right_support[1]-self._crown_y),0) + solution = solve((eq1,eq2,eq3,eq4),(R_A_x,R_A_y,R_B_x,R_B_y)) + + for symb in self._reaction_force: + self._reaction_force[symb] = solution[symb] + + self._bending_moment = - (self._moment_x_func.subs(x,x0) + self._moment_y_func.subs(x,x0) -\ + solution[R_A_y]*(x0-self._left_support[0]) +\ + solution[R_A_x]*(self._shape_eqn.subs({x:x0})-self._left_support[1])) + + angle = atan(diff(self._shape_eqn,x)) + + fx = (self._load_x_func+solution[R_A_x]) + fy = (self._load_y_func+solution[R_A_y]) + + axial_force = fx*cos(angle) + fy*sin(angle) + shear_force = -fx*sin(angle) + fy*cos(angle) + + self._axial_force = axial_force + self._shear_force = shear_force + + @doctest_depends_on(modules=('numpy',)) + def draw(self): + """ + This method returns a plot object containing the diagram of the specified arch along with the supports + and forces applied to the structure. + + Examples + ======== + + >>> from sympy import Symbol + >>> t = Symbol('t') + >>> from sympy.physics.continuum_mechanics.arch import Arch + >>> a = Arch((0,0),(40,0),crown_x=20,crown_y=12) + >>> a.apply_load(-1,'C',8,150,angle=270) + >>> a.apply_load(0,'D',start=20,end=40,mag=-4) + >>> a.apply_load(-1,'E',10,t,angle=300) + >>> p = a.draw() + >>> p # doctest: +ELLIPSIS + Plot object containing: + [0]: cartesian line: 11.325 - 3*(x - 20)**2/100 for x over (0.0, 40.0) + [1]: cartesian line: 12 - 3*(x - 20)**2/100 for x over (0.0, 40.0) + ... + >>> p.show() + + """ + x = Symbol('x') + markers = [] + annotations = self._draw_loads() + rectangles = [] + supports = self._draw_supports() + markers+=supports + + xmax = self._right_support[0] + xmin = self._left_support[0] + ymin = min(self._left_support[1],self._right_support[1]) + ymax = self._crown_y + + lim = max(xmax*1.1-xmin*0.8+1, ymax*1.1-ymin*0.8+1) + + rectangles = self._draw_rectangles() + + filler = self._draw_filler() + rectangles+=filler + + if self._member is not None: + if(self._member[2]>=self._right_support[1]): + markers.append( + { + 'args':[[self._member[1]+0.005*lim],[self._member[2]]], + 'marker':'o', + 'markersize': 4, + 'color': 'white', + 'markerfacecolor':'none' + } + ) + + if(self._member[2]>=self._left_support[1]): + markers.append( + { + 'args':[[self._member[0]-0.005*lim],[self._member[2]]], + 'marker':'o', + 'markersize': 4, + 'color': 'white', + 'markerfacecolor':'none' + } + ) + + + + markers.append({ + 'args':[[self._crown_x],[self._crown_y-0.005*lim]], + 'marker':'o', + 'markersize': 5, + 'color':'white', + 'markerfacecolor':'none', + }) + + if lim==xmax*1.1-xmin*0.8+1: + + sing_plot = plot(self._shape_eqn-0.015*lim, + self._shape_eqn, + (x, self._left_support[0], self._right_support[0]), + markers=markers, + show=False, + annotations=annotations, + rectangles = rectangles, + xlim=(xmin-0.05*lim, xmax*1.1), + ylim=(xmin-0.05*lim, xmax*1.1), + axis=False, + line_color='brown') + + else: + sing_plot = plot(self._shape_eqn-0.015*lim, + self._shape_eqn, + (x, self._left_support[0], self._right_support[0]), + markers=markers, + show=False, + annotations=annotations, + rectangles = rectangles, + xlim=(ymin-0.05*lim, ymax*1.1), + ylim=(ymin-0.05*lim, ymax*1.1), + axis=False, + line_color='brown') + + return sing_plot + + + def _draw_supports(self): + support_markers = [] + + xmax = self._right_support[0] + xmin = self._left_support[0] + ymin = min(self._left_support[1],self._right_support[1]) + ymax = self._crown_y + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + if self._supports['left']=='roller': + support_markers.append( + { + 'args':[ + [self._left_support[0]], + [self._left_support[1]-0.02*max_diff] + ], + 'marker':'o', + 'markersize':11, + 'color':'black', + 'markerfacecolor':'none' + } + ) + else: + support_markers.append( + { + 'args':[ + [self._left_support[0]], + [self._left_support[1]-0.007*max_diff] + ], + 'marker':6, + 'markersize':15, + 'color':'black', + 'markerfacecolor':'none' + } + ) + + if self._supports['right']=='roller': + support_markers.append( + { + 'args':[ + [self._right_support[0]], + [self._right_support[1]-0.02*max_diff] + ], + 'marker':'o', + 'markersize':11, + 'color':'black', + 'markerfacecolor':'none' + } + ) + else: + support_markers.append( + { + 'args':[ + [self._right_support[0]], + [self._right_support[1]-0.007*max_diff] + ], + 'marker':6, + 'markersize':15, + 'color':'black', + 'markerfacecolor':'none' + } + ) + + support_markers.append( + { + 'args':[ + [self._right_support[0]], + [self._right_support[1]-0.036*max_diff] + ], + 'marker':'_', + 'markersize':15, + 'color':'black', + 'markerfacecolor':'none' + } + ) + + support_markers.append( + { + 'args':[ + [self._left_support[0]], + [self._left_support[1]-0.036*max_diff] + ], + 'marker':'_', + 'markersize':15, + 'color':'black', + 'markerfacecolor':'none' + } + ) + + return support_markers + + def _draw_rectangles(self): + member = [] + + xmax = self._right_support[0] + xmin = self._left_support[0] + ymin = min(self._left_support[1],self._right_support[1]) + ymax = self._crown_y + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + if self._member is not None: + if self._member[2]>= max(self._left_support[1],self._right_support[1]): + member.append( + { + 'xy':(self._member[0],self._member[2]-0.005*max_diff), + 'width':self._member[1]-self._member[0], + 'height': 0.01*max_diff, + 'angle': 0, + 'color':'brown', + } + ) + + elif self._member[2]>=self._left_support[1]: + member.append( + { + 'xy':(self._member[0],self._member[2]-0.005*max_diff), + 'width':self._right_support[0]-self._member[0], + 'height': 0.01*max_diff, + 'angle': 0, + 'color':'brown', + } + ) + + else: + member.append( + { + 'xy':(self._member[1],self._member[2]-0.005*max_diff), + 'width':abs(self._left_support[0]-self._member[1]), + 'height': 0.01*max_diff, + 'angle': 180, + 'color':'brown', + } + ) + + if self._distributed_loads: + for loads in self._distributed_loads: + + start = self._distributed_loads[loads]['start'] + end = self._distributed_loads[loads]['end'] + + member.append( + { + 'xy':(start,self._crown_y+max_diff*0.15), + 'width': (end-start), + 'height': max_diff*0.01, + 'color': 'orange' + } + ) + + + return member + + def _draw_loads(self): + load_annotations = [] + + xmax = self._right_support[0] + xmin = self._left_support[0] + ymin = min(self._left_support[1],self._right_support[1]) + ymax = self._crown_y + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + for load in self._conc_loads: + x = self._conc_loads[load]['x'] + y = self._conc_loads[load]['y'] + angle = self._conc_loads[load]['angle'] + mag = self._conc_loads[load]['mag'] + load_annotations.append( + { + 'text':'', + 'xy':( + x+cos(rad(angle))*max_diff*0.08, + y+sin(rad(angle))*max_diff*0.08 + ), + 'xytext':(x,y), + 'fontsize':10, + 'fontweight': 'bold', + 'arrowprops':{'width':1.5, 'headlength':5, 'headwidth':5, 'facecolor':'blue','edgecolor':'blue'} + } + ) + load_annotations.append( + { + 'text':f'{load}: {mag} N', + 'fontsize':10, + 'fontweight': 'bold', + 'xy': (x+cos(rad(angle))*max_diff*0.12,y+sin(rad(angle))*max_diff*0.12) + } + ) + + for load in self._distributed_loads: + start = self._distributed_loads[load]['start'] + end = self._distributed_loads[load]['end'] + mag = self._distributed_loads[load]['f_y'] + x_points = numpy.arange(start,end,(end-start)/(max_diff*0.25)) + x_points = numpy.append(x_points,end) + for point in x_points: + if(mag<0): + load_annotations.append( + { + 'text':'', + 'xy':(point,self._crown_y+max_diff*0.05), + 'xytext': (point,self._crown_y+max_diff*0.15), + 'arrowprops':{'width':1.5, 'headlength':5, 'headwidth':5, 'facecolor':'orange','edgecolor':'orange'} + } + ) + else: + load_annotations.append( + { + 'text':'', + 'xy':(point,self._crown_y+max_diff*0.2), + 'xytext': (point,self._crown_y+max_diff*0.15), + 'arrowprops':{'width':1.5, 'headlength':5, 'headwidth':5, 'facecolor':'orange','edgecolor':'orange'} + } + ) + if(mag<0): + load_annotations.append( + { + 'text':f'{load}: {abs(mag)} N/m', + 'fontsize':10, + 'fontweight': 'bold', + 'xy':((start+end)/2,self._crown_y+max_diff*0.175) + } + ) + else: + load_annotations.append( + { + 'text':f'{load}: {abs(mag)} N/m', + 'fontsize':10, + 'fontweight': 'bold', + 'xy':((start+end)/2,self._crown_y+max_diff*0.125) + } + ) + return load_annotations + + def _draw_filler(self): + x = Symbol('x') + filler = [] + xmax = self._right_support[0] + xmin = self._left_support[0] + ymin = min(self._left_support[1],self._right_support[1]) + ymax = self._crown_y + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + x_points = numpy.arange(self._left_support[0],self._right_support[0],(self._right_support[0]-self._left_support[0])/(max_diff*max_diff)) + + for point in x_points: + filler.append( + { + 'xy':(point,self._shape_eqn.subs(x,point)-max_diff*0.015), + 'width': (self._right_support[0]-self._left_support[0])/(max_diff*max_diff), + 'height': max_diff*0.015, + 'color': 'brown' + } + ) + + return filler diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/beam.py b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/beam.py new file mode 100644 index 0000000000000000000000000000000000000000..dfdfc6d3594da6de44c7c42def3e3f5539cb988e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/beam.py @@ -0,0 +1,3903 @@ +""" +This module can be used to solve 2D beam bending problems with +singularity functions in mechanics. +""" + +from sympy.core import S, Symbol, diff, symbols +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import (Derivative, Function) +from sympy.core.mul import Mul +from sympy.core.relational import Eq +from sympy.core.sympify import sympify +from sympy.solvers import linsolve +from sympy.solvers.ode.ode import dsolve +from sympy.solvers.solvers import solve +from sympy.printing import sstr +from sympy.functions import SingularityFunction, Piecewise, factorial +from sympy.integrals import integrate +from sympy.series import limit +from sympy.plotting import plot, PlotGrid +from sympy.geometry.entity import GeometryEntity +from sympy.external import import_module +from sympy.sets.sets import Interval +from sympy.utilities.lambdify import lambdify +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.iterables import iterable +import warnings + + +__doctest_requires__ = { + ('Beam.draw', + 'Beam.plot_bending_moment', + 'Beam.plot_deflection', + 'Beam.plot_ild_moment', + 'Beam.plot_ild_shear', + 'Beam.plot_shear_force', + 'Beam.plot_shear_stress', + 'Beam.plot_slope'): ['matplotlib'], +} + + +numpy = import_module('numpy', import_kwargs={'fromlist':['arange']}) + + +class Beam: + """ + A Beam is a structural element that is capable of withstanding load + primarily by resisting against bending. Beams are characterized by + their cross sectional profile(Second moment of area), their length + and their material. + + .. note:: + A consistent sign convention must be used while solving a beam + bending problem; the results will + automatically follow the chosen sign convention. However, the + chosen sign convention must respect the rule that, on the positive + side of beam's axis (in respect to current section), a loading force + giving positive shear yields a negative moment, as below (the + curved arrow shows the positive moment and rotation): + + .. image:: allowed-sign-conventions.png + + Examples + ======== + There is a beam of length 4 meters. A constant distributed load of 6 N/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. The deflection of the beam at the end is restricted. + + Using the sign convention of downwards forces being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols, Piecewise + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(4, E, I) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(6, 2, 0) + >>> b.apply_load(R2, 4, -1) + >>> b.bc_deflection = [(0, 0), (4, 0)] + >>> b.boundary_conditions + {'bending_moment': [], 'deflection': [(0, 0), (4, 0)], 'shear_force': [], 'slope': []} + >>> b.load + R1*SingularityFunction(x, 0, -1) + R2*SingularityFunction(x, 4, -1) + 6*SingularityFunction(x, 2, 0) + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.load + -3*SingularityFunction(x, 0, -1) + 6*SingularityFunction(x, 2, 0) - 9*SingularityFunction(x, 4, -1) + >>> b.shear_force() + 3*SingularityFunction(x, 0, 0) - 6*SingularityFunction(x, 2, 1) + 9*SingularityFunction(x, 4, 0) + >>> b.bending_moment() + 3*SingularityFunction(x, 0, 1) - 3*SingularityFunction(x, 2, 2) + 9*SingularityFunction(x, 4, 1) + >>> b.slope() + (-3*SingularityFunction(x, 0, 2)/2 + SingularityFunction(x, 2, 3) - 9*SingularityFunction(x, 4, 2)/2 + 7)/(E*I) + >>> b.deflection() + (7*x - SingularityFunction(x, 0, 3)/2 + SingularityFunction(x, 2, 4)/4 - 3*SingularityFunction(x, 4, 3)/2)/(E*I) + >>> b.deflection().rewrite(Piecewise) + (7*x - Piecewise((x**3, x >= 0), (0, True))/2 + - 3*Piecewise(((x - 4)**3, x >= 4), (0, True))/2 + + Piecewise(((x - 2)**4, x >= 2), (0, True))/4)/(E*I) + + Calculate the support reactions for a fully symbolic beam of length L. + There are two simple supports below the beam, one at the starting point + and another at the ending point of the beam. The deflection of the beam + at the end is restricted. The beam is loaded with: + + * a downward point load P1 applied at L/4 + * an upward point load P2 applied at L/8 + * a counterclockwise moment M1 applied at L/2 + * a clockwise moment M2 applied at 3*L/4 + * a distributed constant load q1, applied downward, starting from L/2 + up to 3*L/4 + * a distributed constant load q2, applied upward, starting from 3*L/4 + up to L + + No assumptions are needed for symbolic loads. However, defining a positive + length will help the algorithm to compute the solution. + + >>> E, I = symbols('E, I') + >>> L = symbols("L", positive=True) + >>> P1, P2, M1, M2, q1, q2 = symbols("P1, P2, M1, M2, q1, q2") + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(L, E, I) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, L, -1) + >>> b.apply_load(P1, L/4, -1) + >>> b.apply_load(-P2, L/8, -1) + >>> b.apply_load(M1, L/2, -2) + >>> b.apply_load(-M2, 3*L/4, -2) + >>> b.apply_load(q1, L/2, 0, 3*L/4) + >>> b.apply_load(-q2, 3*L/4, 0, L) + >>> b.bc_deflection = [(0, 0), (L, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> print(b.reaction_loads[R1]) + (-3*L**2*q1 + L**2*q2 - 24*L*P1 + 28*L*P2 - 32*M1 + 32*M2)/(32*L) + >>> print(b.reaction_loads[R2]) + (-5*L**2*q1 + 7*L**2*q2 - 8*L*P1 + 4*L*P2 + 32*M1 - 32*M2)/(32*L) + """ + + def __init__(self, length, elastic_modulus, second_moment, area=Symbol('A'), variable=Symbol('x'), base_char='C', ild_variable=Symbol('a')): + """Initializes the class. + + Parameters + ========== + + length : Sympifyable + A Symbol or value representing the Beam's length. + + elastic_modulus : Sympifyable + A SymPy expression representing the Beam's Modulus of Elasticity. + It is a measure of the stiffness of the Beam material. It can + also be a continuous function of position along the beam. + + second_moment : Sympifyable or Geometry object + Describes the cross-section of the beam via a SymPy expression + representing the Beam's second moment of area. It is a geometrical + property of an area which reflects how its points are distributed + with respect to its neutral axis. It can also be a continuous + function of position along the beam. Alternatively ``second_moment`` + can be a shape object such as a ``Polygon`` from the geometry module + representing the shape of the cross-section of the beam. In such cases, + it is assumed that the x-axis of the shape object is aligned with the + bending axis of the beam. The second moment of area will be computed + from the shape object internally. + + area : Symbol/float + Represents the cross-section area of beam + + variable : Symbol, optional + A Symbol object that will be used as the variable along the beam + while representing the load, shear, moment, slope and deflection + curve. By default, it is set to ``Symbol('x')``. + + base_char : String, optional + A String that will be used as base character to generate sequential + symbols for integration constants in cases where boundary conditions + are not sufficient to solve them. + + ild_variable : Symbol, optional + A Symbol object that will be used as the variable specifying the + location of the moving load in ILD calculations. By default, it + is set to ``Symbol('a')``. + """ + self.length = length + self.elastic_modulus = elastic_modulus + if isinstance(second_moment, GeometryEntity): + self.cross_section = second_moment + else: + self.cross_section = None + self.second_moment = second_moment + self.variable = variable + self.ild_variable = ild_variable + self._base_char = base_char + self._boundary_conditions = {'deflection': [], 'slope': [], 'bending_moment': [], 'shear_force': []} + self._load = 0 + self.area = area + self._applied_supports = [] + self._applied_rotation_hinges = [] + self._applied_sliding_hinges = [] + self._rotation_hinge_symbols = [] + self._sliding_hinge_symbols = [] + self._support_as_loads = [] + self._applied_loads = [] + self._reaction_loads = {} + self._ild_reactions = {} + self._ild_shear = 0 + self._ild_moment = 0 + # _original_load is a copy of _load equations with unsubstituted reaction + # forces. It is used for calculating reaction forces in case of I.L.D. + self._original_load = 0 + self._joined_beam = False + + def __str__(self): + shape_description = self._cross_section if self._cross_section else self._second_moment + str_sol = 'Beam({}, {}, {})'.format(sstr(self._length), sstr(self._elastic_modulus), sstr(shape_description)) + return str_sol + + @property + def reaction_loads(self): + """ Returns the reaction forces in a dictionary.""" + return self._reaction_loads + + @property + def rotation_jumps(self): + """ + Returns the value for the rotation jumps in rotation hinges in a dictionary. + The rotation jump is the rotation (in radian) in a rotation hinge. This can + be seen as a jump in the slope plot. + """ + return self._rotation_jumps + + @property + def deflection_jumps(self): + """ + Returns the deflection jumps in sliding hinges in a dictionary. + The deflection jump is the deflection (in meters) in a sliding hinge. + This can be seen as a jump in the deflection plot. + """ + return self._deflection_jumps + + @property + def ild_shear(self): + """ Returns the I.L.D. shear equation.""" + return self._ild_shear + + @property + def ild_reactions(self): + """ Returns the I.L.D. reaction forces in a dictionary.""" + return self._ild_reactions + + @property + def ild_rotation_jumps(self): + """ + Returns the I.L.D. rotation jumps in rotation hinges in a dictionary. + The rotation jump is the rotation (in radian) in a rotation hinge. This can + be seen as a jump in the slope plot. + """ + return self._ild_rotations_jumps + + @property + def ild_deflection_jumps(self): + """ + Returns the I.L.D. deflection jumps in sliding hinges in a dictionary. + The deflection jump is the deflection (in meters) in a sliding hinge. + This can be seen as a jump in the deflection plot. + """ + return self._ild_deflection_jumps + + @property + def ild_moment(self): + """ Returns the I.L.D. moment equation.""" + return self._ild_moment + + @property + def length(self): + """Length of the Beam.""" + return self._length + + @length.setter + def length(self, l): + self._length = sympify(l) + + @property + def area(self): + """Cross-sectional area of the Beam. """ + return self._area + + @area.setter + def area(self, a): + self._area = sympify(a) + + @property + def variable(self): + """ + A symbol that can be used as a variable along the length of the beam + while representing load distribution, shear force curve, bending + moment, slope curve and the deflection curve. By default, it is set + to ``Symbol('x')``, but this property is mutable. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I, A = symbols('E, I, A') + >>> x, y, z = symbols('x, y, z') + >>> b = Beam(4, E, I) + >>> b.variable + x + >>> b.variable = y + >>> b.variable + y + >>> b = Beam(4, E, I, A, z) + >>> b.variable + z + """ + return self._variable + + @variable.setter + def variable(self, v): + if isinstance(v, Symbol): + self._variable = v + else: + raise TypeError("""The variable should be a Symbol object.""") + + @property + def elastic_modulus(self): + """Young's Modulus of the Beam. """ + return self._elastic_modulus + + @elastic_modulus.setter + def elastic_modulus(self, e): + self._elastic_modulus = sympify(e) + + @property + def second_moment(self): + """Second moment of area of the Beam. """ + return self._second_moment + + @second_moment.setter + def second_moment(self, i): + self._cross_section = None + if isinstance(i, GeometryEntity): + raise ValueError("To update cross-section geometry use `cross_section` attribute") + else: + self._second_moment = sympify(i) + + @property + def cross_section(self): + """Cross-section of the beam""" + return self._cross_section + + @cross_section.setter + def cross_section(self, s): + if s: + self._second_moment = s.second_moment_of_area()[0] + self._cross_section = s + + @property + def boundary_conditions(self): + """ + Returns a dictionary of boundary conditions applied on the beam. + The dictionary has three keywords namely moment, slope and deflection. + The value of each keyword is a list of tuple, where each tuple + contains location and value of a boundary condition in the format + (location, value). + + Examples + ======== + There is a beam of length 4 meters. The bending moment at 0 should be 4 + and at 4 it should be 0. The slope of the beam should be 1 at 0. The + deflection should be 2 at 0. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.bc_deflection = [(0, 2)] + >>> b.bc_slope = [(0, 1)] + >>> b.boundary_conditions + {'bending_moment': [], 'deflection': [(0, 2)], 'shear_force': [], 'slope': [(0, 1)]} + + Here the deflection of the beam should be ``2`` at ``0``. + Similarly, the slope of the beam should be ``1`` at ``0``. + """ + return self._boundary_conditions + + @property + def bc_shear_force(self): + return self._boundary_conditions['shear_force'] + + @bc_shear_force.setter + def bc_shear_force(self, sf_bcs): + self._boundary_conditions['shear_force'] = sf_bcs + + @property + def bc_bending_moment(self): + return self._boundary_conditions['bending_moment'] + + @bc_bending_moment.setter + def bc_bending_moment(self, bm_bcs): + self._boundary_conditions['bending_moment'] = bm_bcs + + @property + def bc_slope(self): + return self._boundary_conditions['slope'] + + @bc_slope.setter + def bc_slope(self, s_bcs): + self._boundary_conditions['slope'] = s_bcs + + @property + def bc_deflection(self): + return self._boundary_conditions['deflection'] + + @bc_deflection.setter + def bc_deflection(self, d_bcs): + self._boundary_conditions['deflection'] = d_bcs + + def join(self, beam, via="fixed"): + """ + This method joins two beams to make a new composite beam system. + Passed Beam class instance is attached to the right end of calling + object. This method can be used to form beams having Discontinuous + values of Elastic modulus or Second moment. + + Parameters + ========== + beam : Beam class object + The Beam object which would be connected to the right of calling + object. + via : String + States the way two Beam object would get connected + - For axially fixed Beams, via="fixed" + - For Beams connected via rotation hinge, via="hinge" + + Examples + ======== + There is a cantilever beam of length 4 meters. For first 2 meters + its moment of inertia is `1.5*I` and `I` for the other end. + A pointload of magnitude 4 N is applied from the top at its free end. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b1 = Beam(2, E, 1.5*I) + >>> b2 = Beam(2, E, I) + >>> b = b1.join(b2, "fixed") + >>> b.apply_load(20, 4, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 0, -2) + >>> b.bc_slope = [(0, 0)] + >>> b.bc_deflection = [(0, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.load + 80*SingularityFunction(x, 0, -2) - 20*SingularityFunction(x, 0, -1) + 20*SingularityFunction(x, 4, -1) + >>> b.slope() + (-((-80*SingularityFunction(x, 0, 1) + 10*SingularityFunction(x, 0, 2) - 10*SingularityFunction(x, 4, 2))/I + 120/I)/E + 80.0/(E*I))*SingularityFunction(x, 2, 0) + - 0.666666666666667*(-80*SingularityFunction(x, 0, 1) + 10*SingularityFunction(x, 0, 2) - 10*SingularityFunction(x, 4, 2))*SingularityFunction(x, 0, 0)/(E*I) + + 0.666666666666667*(-80*SingularityFunction(x, 0, 1) + 10*SingularityFunction(x, 0, 2) - 10*SingularityFunction(x, 4, 2))*SingularityFunction(x, 2, 0)/(E*I) + """ + x = self.variable + E = self.elastic_modulus + new_length = self.length + beam.length + if self.elastic_modulus != beam.elastic_modulus: + raise NotImplementedError('Joining beams with different Elastic modulus is not implemented.') + + if self.second_moment != beam.second_moment: + new_second_moment = Piecewise((self.second_moment, x<=self.length), + (beam.second_moment, x<=new_length)) + else: + new_second_moment = self.second_moment + + if via == "fixed": + new_beam = Beam(new_length, E, new_second_moment, x) + new_beam._joined_beam = True + return new_beam + + if via == "hinge": + new_beam = Beam(new_length, E, new_second_moment, x) + new_beam._joined_beam = True + new_beam.apply_rotation_hinge(self.length) + return new_beam + + def apply_support(self, loc, type="fixed"): + """ + This method applies support to a particular beam object and returns + the symbol of the unknown reaction load(s). + + Parameters + ========== + loc : Sympifyable + Location of point at which support is applied. + type : String + Determines type of Beam support applied. To apply support structure + with + - zero degree of freedom, type = "fixed" + - one degree of freedom, type = "pin" + - two degrees of freedom, type = "roller" + + Returns + ======= + Symbol or tuple of Symbol + The unknown reaction load as a symbol. + - Symbol(reaction_force) if type = "pin" or "roller" + - Symbol(reaction_force), Symbol(reaction_moment) if type = "fixed" + + Examples + ======== + There is a beam of length 20 meters. A moment of magnitude 100 Nm is + applied in the clockwise direction at the end of the beam. A pointload + of magnitude 8 N is applied from the top of the beam at a distance of 10 meters. + There is one fixed support at the start of the beam and a roller at the end. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(20, E, I) + >>> p0, m0 = b.apply_support(0, 'fixed') + >>> p1 = b.apply_support(20, 'roller') + >>> b.apply_load(-8, 10, -1) + >>> b.apply_load(100, 20, -2) + >>> b.solve_for_reaction_loads(p0, m0, p1) + >>> b.reaction_loads + {M_0: 20, R_0: -2, R_20: 10} + >>> b.reaction_loads[p0] + -2 + >>> b.load + 20*SingularityFunction(x, 0, -2) - 2*SingularityFunction(x, 0, -1) + - 8*SingularityFunction(x, 10, -1) + 100*SingularityFunction(x, 20, -2) + + 10*SingularityFunction(x, 20, -1) + """ + loc = sympify(loc) + + self._applied_supports.append((loc, type)) + if type in ("pin", "roller"): + reaction_load = Symbol('R_'+str(loc)) + self.apply_load(reaction_load, loc, -1) + self.bc_deflection.append((loc, 0)) + else: + reaction_load = Symbol('R_'+str(loc)) + reaction_moment = Symbol('M_'+str(loc)) + self.apply_load(reaction_load, loc, -1) + self.apply_load(reaction_moment, loc, -2) + self.bc_deflection.append((loc, 0)) + self.bc_slope.append((loc, 0)) + self._support_as_loads.append((reaction_moment, loc, -2, None)) + + self._support_as_loads.append((reaction_load, loc, -1, None)) + + if type in ("pin", "roller"): + return reaction_load + else: + return reaction_load, reaction_moment + + def _get_I(self, loc): + """ + Helper function that returns the Second moment (I) at a location in the beam. + """ + I = self.second_moment + if not isinstance(I, Piecewise): + return I + else: + for i in range(len(I.args)): + if loc <= I.args[i][1].args[1]: + return I.args[i][0] + + def apply_rotation_hinge(self, loc): + """ + This method applies a rotation hinge at a single location on the beam. + + Parameters + ---------- + loc : Sympifyable + Location of point at which hinge is applied. + + Returns + ======= + Symbol + The unknown rotation jump multiplied by the elastic modulus and second moment as a symbol. + + Examples + ======== + There is a beam of length 15 meters. Pin supports are placed at distances + of 0 and 10 meters. There is a fixed support at the end. There are two rotation hinges + in the structure, one at 5 meters and one at 10 meters. A pointload of magnitude + 10 kN is applied on the hinge at 5 meters. A distributed load of 5 kN works on + the structure from 10 meters to the end. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import Symbol + >>> E = Symbol('E') + >>> I = Symbol('I') + >>> b = Beam(15, E, I) + >>> r0 = b.apply_support(0, type='pin') + >>> r10 = b.apply_support(10, type='pin') + >>> r15, m15 = b.apply_support(15, type='fixed') + >>> p5 = b.apply_rotation_hinge(5) + >>> p12 = b.apply_rotation_hinge(12) + >>> b.apply_load(-10, 5, -1) + >>> b.apply_load(-5, 10, 0, 15) + >>> b.solve_for_reaction_loads(r0, r10, r15, m15) + >>> b.reaction_loads + {M_15: -75/2, R_0: 0, R_10: 40, R_15: -5} + >>> b.rotation_jumps + {P_12: -1875/(16*E*I), P_5: 9625/(24*E*I)} + >>> b.rotation_jumps[p12] + -1875/(16*E*I) + >>> b.bending_moment() + -9625*SingularityFunction(x, 5, -1)/24 + 10*SingularityFunction(x, 5, 1) + - 40*SingularityFunction(x, 10, 1) + 5*SingularityFunction(x, 10, 2)/2 + + 1875*SingularityFunction(x, 12, -1)/16 + 75*SingularityFunction(x, 15, 0)/2 + + 5*SingularityFunction(x, 15, 1) - 5*SingularityFunction(x, 15, 2)/2 + """ + loc = sympify(loc) + E = self.elastic_modulus + I = self._get_I(loc) + + rotation_jump = Symbol('P_'+str(loc)) + self._applied_rotation_hinges.append(loc) + self._rotation_hinge_symbols.append(rotation_jump) + self.apply_load(E * I * rotation_jump, loc, -3) + self.bc_bending_moment.append((loc, 0)) + return rotation_jump + + def apply_sliding_hinge(self, loc): + """ + This method applies a sliding hinge at a single location on the beam. + + Parameters + ---------- + loc : Sympifyable + Location of point at which hinge is applied. + + Returns + ======= + Symbol + The unknown deflection jump multiplied by the elastic modulus and second moment as a symbol. + + Examples + ======== + There is a beam of length 13 meters. A fixed support is placed at the beginning. + There is a pin support at the end. There is a sliding hinge at a location of 8 meters. + A pointload of magnitude 10 kN is applied on the hinge at 5 meters. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> b = Beam(13, 20, 20) + >>> r0, m0 = b.apply_support(0, type="fixed") + >>> s8 = b.apply_sliding_hinge(8) + >>> r13 = b.apply_support(13, type="pin") + >>> b.apply_load(-10, 5, -1) + >>> b.solve_for_reaction_loads(r0, m0, r13) + >>> b.reaction_loads + {M_0: -50, R_0: 10, R_13: 0} + >>> b.deflection_jumps + {W_8: 85/24} + >>> b.deflection_jumps[s8] + 85/24 + >>> b.bending_moment() + 50*SingularityFunction(x, 0, 0) - 10*SingularityFunction(x, 0, 1) + + 10*SingularityFunction(x, 5, 1) - 4250*SingularityFunction(x, 8, -2)/3 + >>> b.deflection() + -SingularityFunction(x, 0, 2)/16 + SingularityFunction(x, 0, 3)/240 + - SingularityFunction(x, 5, 3)/240 + 85*SingularityFunction(x, 8, 0)/24 + """ + loc = sympify(loc) + E = self.elastic_modulus + I = self._get_I(loc) + + deflection_jump = Symbol('W_' + str(loc)) + self._applied_sliding_hinges.append(loc) + self._sliding_hinge_symbols.append(deflection_jump) + self.apply_load(E * I * deflection_jump, loc, -4) + self.bc_shear_force.append((loc, 0)) + return deflection_jump + + def apply_load(self, value, start, order, end=None): + """ + This method adds up the loads given to a particular beam object. + + Parameters + ========== + value : Sympifyable + The value inserted should have the units [Force/(Distance**(n+1)] + where n is the order of applied load. + Units for applied loads: + + - For moments, unit = kN*m + - For point loads, unit = kN + - For constant distributed load, unit = kN/m + - For ramp loads, unit = kN/m/m + - For parabolic ramp loads, unit = kN/m/m/m + - ... so on. + + start : Sympifyable + The starting point of the applied load. For point moments and + point forces this is the location of application. + order : Integer + The order of the applied load. + + - For moments, order = -2 + - For point loads, order =-1 + - For constant distributed load, order = 0 + - For ramp loads, order = 1 + - For parabolic ramp loads, order = 2 + - ... so on. + + end : Sympifyable, optional + An optional argument that can be used if the load has an end point + within the length of the beam. + + Examples + ======== + There is a beam of length 4 meters. A moment of magnitude 3 Nm is + applied in the clockwise direction at the starting point of the beam. + A point load of magnitude 4 N is applied from the top of the beam at + 2 meters from the starting point and a parabolic ramp load of magnitude + 2 N/m is applied below the beam starting from 2 meters to 3 meters + away from the starting point of the beam. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.apply_load(-3, 0, -2) + >>> b.apply_load(4, 2, -1) + >>> b.apply_load(-2, 2, 2, end=3) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) - 2*SingularityFunction(x, 2, 2) + 2*SingularityFunction(x, 3, 0) + 4*SingularityFunction(x, 3, 1) + 2*SingularityFunction(x, 3, 2) + + """ + x = self.variable + value = sympify(value) + start = sympify(start) + order = sympify(order) + + self._applied_loads.append((value, start, order, end)) + self._load += value*SingularityFunction(x, start, order) + self._original_load += value*SingularityFunction(x, start, order) + + if end: + # load has an end point within the length of the beam. + self._handle_end(x, value, start, order, end, type="apply") + + def remove_load(self, value, start, order, end=None): + """ + This method removes a particular load present on the beam object. + Returns a ValueError if the load passed as an argument is not + present on the beam. + + Parameters + ========== + value : Sympifyable + The magnitude of an applied load. + start : Sympifyable + The starting point of the applied load. For point moments and + point forces this is the location of application. + order : Integer + The order of the applied load. + - For moments, order= -2 + - For point loads, order=-1 + - For constant distributed load, order=0 + - For ramp loads, order=1 + - For parabolic ramp loads, order=2 + - ... so on. + end : Sympifyable, optional + An optional argument that can be used if the load has an end point + within the length of the beam. + + Examples + ======== + There is a beam of length 4 meters. A moment of magnitude 3 Nm is + applied in the clockwise direction at the starting point of the beam. + A pointload of magnitude 4 N is applied from the top of the beam at + 2 meters from the starting point and a parabolic ramp load of magnitude + 2 N/m is applied below the beam starting from 2 meters to 3 meters + away from the starting point of the beam. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.apply_load(-3, 0, -2) + >>> b.apply_load(4, 2, -1) + >>> b.apply_load(-2, 2, 2, end=3) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) - 2*SingularityFunction(x, 2, 2) + 2*SingularityFunction(x, 3, 0) + 4*SingularityFunction(x, 3, 1) + 2*SingularityFunction(x, 3, 2) + >>> b.remove_load(-2, 2, 2, end = 3) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) + """ + x = self.variable + value = sympify(value) + start = sympify(start) + order = sympify(order) + + if (value, start, order, end) in self._applied_loads: + self._load -= value*SingularityFunction(x, start, order) + self._original_load -= value*SingularityFunction(x, start, order) + self._applied_loads.remove((value, start, order, end)) + else: + msg = "No such load distribution exists on the beam object." + raise ValueError(msg) + + if end: + # load has an end point within the length of the beam. + self._handle_end(x, value, start, order, end, type="remove") + + def _handle_end(self, x, value, start, order, end, type): + """ + This functions handles the optional `end` value in the + `apply_load` and `remove_load` functions. When the value + of end is not NULL, this function will be executed. + """ + if order.is_negative: + msg = ("If 'end' is provided the 'order' of the load cannot " + "be negative, i.e. 'end' is only valid for distributed " + "loads.") + raise ValueError(msg) + # NOTE : A Taylor series can be used to define the summation of + # singularity functions that subtract from the load past the end + # point such that it evaluates to zero past 'end'. + f = value*x**order + + if type == "apply": + # iterating for "apply_load" method + for i in range(0, order + 1): + self._load -= (f.diff(x, i).subs(x, end - start) * + SingularityFunction(x, end, i)/factorial(i)) + self._original_load -= (f.diff(x, i).subs(x, end - start) * + SingularityFunction(x, end, i)/factorial(i)) + elif type == "remove": + # iterating for "remove_load" method + for i in range(0, order + 1): + self._load += (f.diff(x, i).subs(x, end - start) * + SingularityFunction(x, end, i)/factorial(i)) + self._original_load += (f.diff(x, i).subs(x, end - start) * + SingularityFunction(x, end, i)/factorial(i)) + + + @property + def load(self): + """ + Returns a Singularity Function expression which represents + the load distribution curve of the Beam object. + + Examples + ======== + There is a beam of length 4 meters. A moment of magnitude 3 Nm is + applied in the clockwise direction at the starting point of the beam. + A point load of magnitude 4 N is applied from the top of the beam at + 2 meters from the starting point and a parabolic ramp load of magnitude + 2 N/m is applied below the beam starting from 3 meters away from the + starting point of the beam. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.apply_load(-3, 0, -2) + >>> b.apply_load(4, 2, -1) + >>> b.apply_load(-2, 3, 2) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) - 2*SingularityFunction(x, 3, 2) + """ + return self._load + + @property + def applied_loads(self): + """ + Returns a list of all loads applied on the beam object. + Each load in the list is a tuple of form (value, start, order, end). + + Examples + ======== + There is a beam of length 4 meters. A moment of magnitude 3 Nm is + applied in the clockwise direction at the starting point of the beam. + A pointload of magnitude 4 N is applied from the top of the beam at + 2 meters from the starting point. Another pointload of magnitude 5 N + is applied at same position. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.apply_load(-3, 0, -2) + >>> b.apply_load(4, 2, -1) + >>> b.apply_load(5, 2, -1) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 9*SingularityFunction(x, 2, -1) + >>> b.applied_loads + [(-3, 0, -2, None), (4, 2, -1, None), (5, 2, -1, None)] + """ + return self._applied_loads + + def solve_for_reaction_loads(self, *reactions): + """ + Solves for the reaction forces. + + Examples + ======== + There is a beam of length 30 meters. A moment of magnitude 120 Nm is + applied in the clockwise direction at the end of the beam. A pointload + of magnitude 8 N is applied from the top of the beam at the starting + point. There are two simple supports below the beam. One at the end + and another one at a distance of 10 meters from the start. The + deflection is restricted at both the supports. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) # Reaction force at x = 10 + >>> b.apply_load(R2, 30, -1) # Reaction force at x = 30 + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.load + R1*SingularityFunction(x, 10, -1) + R2*SingularityFunction(x, 30, -1) + - 8*SingularityFunction(x, 0, -1) + 120*SingularityFunction(x, 30, -2) + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.reaction_loads + {R1: 6, R2: 2} + >>> b.load + -8*SingularityFunction(x, 0, -1) + 6*SingularityFunction(x, 10, -1) + + 120*SingularityFunction(x, 30, -2) + 2*SingularityFunction(x, 30, -1) + """ + + x = self.variable + l = self.length + C3 = Symbol('C3') + C4 = Symbol('C4') + rotation_jumps = tuple(self._rotation_hinge_symbols) + deflection_jumps = tuple(self._sliding_hinge_symbols) + + shear_curve = limit(self.shear_force(), x, l) + moment_curve = limit(self.bending_moment(), x, l) + + shear_force_eqs = [] + bending_moment_eqs = [] + slope_eqs = [] + deflection_eqs = [] + + for position, value in self._boundary_conditions['shear_force']: + eqs = self.shear_force().subs(x, position) - value + new_eqs = sum(arg for arg in eqs.args if not any(num.is_infinite for num in arg.args)) + shear_force_eqs.append(new_eqs) + + for position, value in self._boundary_conditions['bending_moment']: + eqs = self.bending_moment().subs(x, position) - value + new_eqs = sum(arg for arg in eqs.args if not any(num.is_infinite for num in arg.args)) + bending_moment_eqs.append(new_eqs) + + slope_curve = integrate(self.bending_moment(), x) + C3 + for position, value in self._boundary_conditions['slope']: + eqs = slope_curve.subs(x, position) - value + slope_eqs.append(eqs) + + deflection_curve = integrate(slope_curve, x) + C4 + for position, value in self._boundary_conditions['deflection']: + eqs = deflection_curve.subs(x, position) - value + deflection_eqs.append(eqs) + + solution = list((linsolve([shear_curve, moment_curve] + shear_force_eqs + bending_moment_eqs + slope_eqs + + deflection_eqs, (C3, C4) + reactions + rotation_jumps + deflection_jumps).args)[0]) + reaction_index = 2+len(reactions) + rotation_index = reaction_index + len(rotation_jumps) + reaction_solution = solution[2:reaction_index] + rotation_solution = solution[reaction_index:rotation_index] + deflection_solution = solution[rotation_index:] + + self._reaction_loads = dict(zip(reactions, reaction_solution)) + self._rotation_jumps = dict(zip(rotation_jumps, rotation_solution)) + self._deflection_jumps = dict(zip(deflection_jumps, deflection_solution)) + self._load = self._load.subs(self._reaction_loads) + self._load = self._load.subs(self._rotation_jumps) + self._load = self._load.subs(self._deflection_jumps) + + def shear_force(self): + """ + Returns a Singularity Function expression which represents + the shear force curve of the Beam object. + + Examples + ======== + There is a beam of length 30 meters. A moment of magnitude 120 Nm is + applied in the clockwise direction at the end of the beam. A pointload + of magnitude 8 N is applied from the top of the beam at the starting + point. There are two simple supports below the beam. One at the end + and another one at a distance of 10 meters from the start. The + deflection is restricted at both the supports. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) + >>> b.apply_load(R2, 30, -1) + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.shear_force() + 8*SingularityFunction(x, 0, 0) - 6*SingularityFunction(x, 10, 0) - 120*SingularityFunction(x, 30, -1) - 2*SingularityFunction(x, 30, 0) + """ + x = self.variable + return -integrate(self.load, x) + + def max_shear_force(self): + """Returns maximum Shear force and its coordinate + in the Beam object.""" + shear_curve = self.shear_force() + x = self.variable + + terms = shear_curve.args + singularity = [] # Points at which shear function changes + for term in terms: + if isinstance(term, Mul): + term = term.args[-1] # SingularityFunction in the term + singularity.append(term.args[1]) + singularity = list(set(singularity)) + singularity.sort() + + intervals = [] # List of Intervals with discrete value of shear force + shear_values = [] # List of values of shear force in each interval + for i, s in enumerate(singularity): + if s == 0: + continue + try: + shear_slope = Piecewise((float("nan"), x<=singularity[i-1]),(self._load.rewrite(Piecewise), x>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) + >>> b.apply_load(R2, 30, -1) + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.bending_moment() + 8*SingularityFunction(x, 0, 1) - 6*SingularityFunction(x, 10, 1) - 120*SingularityFunction(x, 30, 0) - 2*SingularityFunction(x, 30, 1) + """ + x = self.variable + return integrate(self.shear_force(), x) + + def max_bmoment(self): + """Returns maximum Shear force and its coordinate + in the Beam object.""" + bending_curve = self.bending_moment() + x = self.variable + + terms = bending_curve.args + singularity = [] # Points at which bending moment changes + for term in terms: + if isinstance(term, Mul): + term = term.args[-1] # SingularityFunction in the term + singularity.append(term.args[1]) + singularity = list(set(singularity)) + singularity.sort() + + intervals = [] # List of Intervals with discrete value of bending moment + moment_values = [] # List of values of bending moment in each interval + for i, s in enumerate(singularity): + if s == 0: + continue + try: + moment_slope = Piecewise( + (float("nan"), x <= singularity[i - 1]), + (self.shear_force().rewrite(Piecewise), x < s), + (float("nan"), True)) + points = solve(moment_slope, x) + val = [] + for point in points: + val.append(abs(bending_curve.subs(x, point))) + points.extend([singularity[i-1], s]) + val += [abs(limit(bending_curve, x, singularity[i-1], '+')), abs(limit(bending_curve, x, s, '-'))] + max_moment = max(val) + moment_values.append(max_moment) + intervals.append(points[val.index(max_moment)]) + + # If bending moment in a particular Interval has zero or constant + # slope, then above block gives NotImplementedError as solve + # can't represent Interval solutions. + except NotImplementedError: + initial_moment = limit(bending_curve, x, singularity[i-1], '+') + final_moment = limit(bending_curve, x, s, '-') + # If bending_curve has a constant slope(it is a line). + if bending_curve.subs(x, (singularity[i-1] + s)/2) == (initial_moment + final_moment)/2 and initial_moment != final_moment: + moment_values.extend([initial_moment, final_moment]) + intervals.extend([singularity[i-1], s]) + else: # bending_curve has same value in whole Interval + moment_values.append(final_moment) + intervals.append(Interval(singularity[i-1], s)) + + moment_values = list(map(abs, moment_values)) + maximum_moment = max(moment_values) + point = intervals[moment_values.index(maximum_moment)] + return (point, maximum_moment) + + def point_cflexure(self): + """ + Returns a Set of point(s) with zero bending moment and + where bending moment curve of the beam object changes + its sign from negative to positive or vice versa. + + Examples + ======== + There is is 10 meter long overhanging beam. There are + two simple supports below the beam. One at the start + and another one at a distance of 6 meters from the start. + Point loads of magnitude 10KN and 20KN are applied at + 2 meters and 4 meters from start respectively. A Uniformly + distribute load of magnitude of magnitude 3KN/m is also + applied on top starting from 6 meters away from starting + point till end. + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(10, E, I) + >>> b.apply_load(-4, 0, -1) + >>> b.apply_load(-46, 6, -1) + >>> b.apply_load(10, 2, -1) + >>> b.apply_load(20, 4, -1) + >>> b.apply_load(3, 6, 0) + >>> b.point_cflexure() + [10/3] + """ + #Removes the singularity functions of order < 0 from the bending moment equation used in this method + non_singular_bending_moment = sum(arg for arg in self.bending_moment().args if not arg.args[1].args[2] < 0) + + # To restrict the range within length of the Beam + moment_curve = Piecewise((float("nan"), self.variable<=0), + (non_singular_bending_moment, self.variable>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) + >>> b.apply_load(R2, 30, -1) + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.slope() + (-4*SingularityFunction(x, 0, 2) + 3*SingularityFunction(x, 10, 2) + + 120*SingularityFunction(x, 30, 1) + SingularityFunction(x, 30, 2) + 4000/3)/(E*I) + """ + x = self.variable + E = self.elastic_modulus + I = self.second_moment + + if not self._boundary_conditions['slope']: + return diff(self.deflection(), x) + if isinstance(I, Piecewise) and self._joined_beam: + args = I.args + slope = 0 + prev_slope = 0 + prev_end = 0 + for i in range(len(args)): + if i != 0: + prev_end = args[i-1][1].args[1] + slope_value = -S.One/E*integrate(self.bending_moment()/args[i][0], (x, prev_end, x)) + if i != len(args) - 1: + slope += (prev_slope + slope_value)*SingularityFunction(x, prev_end, 0) - \ + (prev_slope + slope_value)*SingularityFunction(x, args[i][1].args[1], 0) + else: + slope += (prev_slope + slope_value)*SingularityFunction(x, prev_end, 0) + prev_slope = slope_value.subs(x, args[i][1].args[1]) + return slope + + C3 = Symbol('C3') + slope_curve = -integrate(S.One/(E*I)*self.bending_moment(), x) + C3 + + bc_eqs = [] + for position, value in self._boundary_conditions['slope']: + eqs = slope_curve.subs(x, position) - value + bc_eqs.append(eqs) + constants = list(linsolve(bc_eqs, C3)) + slope_curve = slope_curve.subs({C3: constants[0][0]}) + return slope_curve + + def deflection(self): + """ + Returns a Singularity Function expression which represents + the elastic curve or deflection of the Beam object. + + Examples + ======== + There is a beam of length 30 meters. A moment of magnitude 120 Nm is + applied in the clockwise direction at the end of the beam. A pointload + of magnitude 8 N is applied from the top of the beam at the starting + point. There are two simple supports below the beam. One at the end + and another one at a distance of 10 meters from the start. The + deflection is restricted at both the supports. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) + >>> b.apply_load(R2, 30, -1) + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.deflection() + (4000*x/3 - 4*SingularityFunction(x, 0, 3)/3 + SingularityFunction(x, 10, 3) + + 60*SingularityFunction(x, 30, 2) + SingularityFunction(x, 30, 3)/3 - 12000)/(E*I) + """ + x = self.variable + E = self.elastic_modulus + I = self.second_moment + if not self._boundary_conditions['deflection'] and not self._boundary_conditions['slope']: + if isinstance(I, Piecewise) and self._joined_beam: + args = I.args + prev_slope = 0 + prev_def = 0 + prev_end = 0 + deflection = 0 + for i in range(len(args)): + if i != 0: + prev_end = args[i-1][1].args[1] + slope_value = -S.One/E*integrate(self.bending_moment()/args[i][0], (x, prev_end, x)) + recent_segment_slope = prev_slope + slope_value + deflection_value = integrate(recent_segment_slope, (x, prev_end, x)) + if i != len(args) - 1: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) \ + - (prev_def + deflection_value)*SingularityFunction(x, args[i][1].args[1], 0) + else: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) + prev_slope = slope_value.subs(x, args[i][1].args[1]) + prev_def = deflection_value.subs(x, args[i][1].args[1]) + return deflection + base_char = self._base_char + constants = symbols(base_char + '3:5') + return S.One/(E*I)*integrate(-integrate(self.bending_moment(), x), x) + constants[0]*x + constants[1] + elif not self._boundary_conditions['deflection']: + base_char = self._base_char + constant = symbols(base_char + '4') + return integrate(self.slope(), x) + constant + elif not self._boundary_conditions['slope'] and self._boundary_conditions['deflection']: + if isinstance(I, Piecewise) and self._joined_beam: + args = I.args + prev_slope = 0 + prev_def = 0 + prev_end = 0 + deflection = 0 + for i in range(len(args)): + if i != 0: + prev_end = args[i-1][1].args[1] + slope_value = -S.One/E*integrate(self.bending_moment()/args[i][0], (x, prev_end, x)) + recent_segment_slope = prev_slope + slope_value + deflection_value = integrate(recent_segment_slope, (x, prev_end, x)) + if i != len(args) - 1: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) \ + - (prev_def + deflection_value)*SingularityFunction(x, args[i][1].args[1], 0) + else: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) + prev_slope = slope_value.subs(x, args[i][1].args[1]) + prev_def = deflection_value.subs(x, args[i][1].args[1]) + return deflection + base_char = self._base_char + C3, C4 = symbols(base_char + '3:5') # Integration constants + slope_curve = -integrate(self.bending_moment(), x) + C3 + deflection_curve = integrate(slope_curve, x) + C4 + bc_eqs = [] + for position, value in self._boundary_conditions['deflection']: + eqs = deflection_curve.subs(x, position) - value + bc_eqs.append(eqs) + constants = list(linsolve(bc_eqs, (C3, C4))) + deflection_curve = deflection_curve.subs({C3: constants[0][0], C4: constants[0][1]}) + return S.One/(E*I)*deflection_curve + + if isinstance(I, Piecewise) and self._joined_beam: + args = I.args + prev_slope = 0 + prev_def = 0 + prev_end = 0 + deflection = 0 + for i in range(len(args)): + if i != 0: + prev_end = args[i-1][1].args[1] + slope_value = S.One/E*integrate(self.bending_moment()/args[i][0], (x, prev_end, x)) + recent_segment_slope = prev_slope + slope_value + deflection_value = integrate(recent_segment_slope, (x, prev_end, x)) + if i != len(args) - 1: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) \ + - (prev_def + deflection_value)*SingularityFunction(x, args[i][1].args[1], 0) + else: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) + prev_slope = slope_value.subs(x, args[i][1].args[1]) + prev_def = deflection_value.subs(x, args[i][1].args[1]) + return deflection + + C4 = Symbol('C4') + deflection_curve = integrate(self.slope(), x) + C4 + + bc_eqs = [] + for position, value in self._boundary_conditions['deflection']: + eqs = deflection_curve.subs(x, position) - value + bc_eqs.append(eqs) + + constants = list(linsolve(bc_eqs, C4)) + deflection_curve = deflection_curve.subs({C4: constants[0][0]}) + return deflection_curve + + def max_deflection(self): + """ + Returns point of max deflection and its corresponding deflection value + in a Beam object. + """ + + # To restrict the range within length of the Beam + slope_curve = Piecewise((float("nan"), self.variable<=0), + (self.slope(), self.variable>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6), 2) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_shear_stress() + Plot object containing: + [0]: cartesian line: 6875*SingularityFunction(x, 0, 0) - 2500*SingularityFunction(x, 2, 0) + - 5000*SingularityFunction(x, 4, 1) + 15625*SingularityFunction(x, 8, 0) + + 5000*SingularityFunction(x, 8, 1) for x over (0.0, 8.0) + """ + + shear_stress = self.shear_stress() + x = self.variable + length = self.length + + if subs is None: + subs = {} + for sym in shear_stress.atoms(Symbol): + if sym != x and sym not in subs: + raise ValueError('value of %s was not passed.' %sym) + + if length in subs: + length = subs[length] + + # Returns Plot of Shear Stress + return plot (shear_stress.subs(subs), (x, 0, length), + title='Shear Stress', xlabel=r'$\mathrm{x}$', ylabel=r'$\tau$', + line_color='r') + + + def plot_shear_force(self, subs=None): + """ + + Returns a plot for Shear force present in the Beam object. + + Parameters + ========== + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_shear_force() + Plot object containing: + [0]: cartesian line: 13750*SingularityFunction(x, 0, 0) - 5000*SingularityFunction(x, 2, 0) + - 10000*SingularityFunction(x, 4, 1) + 31250*SingularityFunction(x, 8, 0) + + 10000*SingularityFunction(x, 8, 1) for x over (0.0, 8.0) + """ + shear_force = self.shear_force() + if subs is None: + subs = {} + for sym in shear_force.atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + return plot(shear_force.subs(subs), (self.variable, 0, length), title='Shear Force', + xlabel=r'$\mathrm{x}$', ylabel=r'$\mathrm{V}$', line_color='g') + + def plot_bending_moment(self, subs=None): + """ + + Returns a plot for Bending moment present in the Beam object. + + Parameters + ========== + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_bending_moment() + Plot object containing: + [0]: cartesian line: 13750*SingularityFunction(x, 0, 1) - 5000*SingularityFunction(x, 2, 1) + - 5000*SingularityFunction(x, 4, 2) + 31250*SingularityFunction(x, 8, 1) + + 5000*SingularityFunction(x, 8, 2) for x over (0.0, 8.0) + """ + bending_moment = self.bending_moment() + if subs is None: + subs = {} + for sym in bending_moment.atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + return plot(bending_moment.subs(subs), (self.variable, 0, length), title='Bending Moment', + xlabel=r'$\mathrm{x}$', ylabel=r'$\mathrm{M}$', line_color='b') + + def plot_slope(self, subs=None): + """ + + Returns a plot for slope of deflection curve of the Beam object. + + Parameters + ========== + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_slope() + Plot object containing: + [0]: cartesian line: -8.59375e-5*SingularityFunction(x, 0, 2) + 3.125e-5*SingularityFunction(x, 2, 2) + + 2.08333333333333e-5*SingularityFunction(x, 4, 3) - 0.0001953125*SingularityFunction(x, 8, 2) + - 2.08333333333333e-5*SingularityFunction(x, 8, 3) + 0.00138541666666667 for x over (0.0, 8.0) + """ + slope = self.slope() + if subs is None: + subs = {} + for sym in slope.atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + return plot(slope.subs(subs), (self.variable, 0, length), title='Slope', + xlabel=r'$\mathrm{x}$', ylabel=r'$\theta$', line_color='m') + + def plot_deflection(self, subs=None): + """ + + Returns a plot for deflection curve of the Beam object. + + Parameters + ========== + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_deflection() + Plot object containing: + [0]: cartesian line: 0.00138541666666667*x - 2.86458333333333e-5*SingularityFunction(x, 0, 3) + + 1.04166666666667e-5*SingularityFunction(x, 2, 3) + 5.20833333333333e-6*SingularityFunction(x, 4, 4) + - 6.51041666666667e-5*SingularityFunction(x, 8, 3) - 5.20833333333333e-6*SingularityFunction(x, 8, 4) + for x over (0.0, 8.0) + """ + deflection = self.deflection() + if subs is None: + subs = {} + for sym in deflection.atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + return plot(deflection.subs(subs), (self.variable, 0, length), + title='Deflection', xlabel=r'$\mathrm{x}$', ylabel=r'$\delta$', + line_color='r') + + + def plot_loading_results(self, subs=None): + """ + Returns a subplot of Shear Force, Bending Moment, + Slope and Deflection of the Beam object. + + Parameters + ========== + + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> axes = b.plot_loading_results() + """ + length = self.length + variable = self.variable + if subs is None: + subs = {} + for sym in self.deflection().atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if length in subs: + length = subs[length] + ax1 = plot(self.shear_force().subs(subs), (variable, 0, length), + title="Shear Force", xlabel=r'$\mathrm{x}$', ylabel=r'$\mathrm{V}$', + line_color='g', show=False) + ax2 = plot(self.bending_moment().subs(subs), (variable, 0, length), + title="Bending Moment", xlabel=r'$\mathrm{x}$', ylabel=r'$\mathrm{M}$', + line_color='b', show=False) + ax3 = plot(self.slope().subs(subs), (variable, 0, length), + title="Slope", xlabel=r'$\mathrm{x}$', ylabel=r'$\theta$', + line_color='m', show=False) + ax4 = plot(self.deflection().subs(subs), (variable, 0, length), + title="Deflection", xlabel=r'$\mathrm{x}$', ylabel=r'$\delta$', + line_color='r', show=False) + + return PlotGrid(4, 1, ax1, ax2, ax3, ax4) + + def _solve_for_ild_equations(self, value): + """ + + Helper function for I.L.D. It takes the unsubstituted + copy of the load equation and uses it to calculate shear force and bending + moment equations. + """ + x = self.variable + a = self.ild_variable + load = self._load + value * SingularityFunction(x, a, -1) + shear_force = -integrate(load, x) + bending_moment = integrate(shear_force, x) + + return shear_force, bending_moment + + def solve_for_ild_reactions(self, value, *reactions): + """ + + Determines the Influence Line Diagram equations for reaction + forces under the effect of a moving load. + + Parameters + ========== + value : Integer + Magnitude of moving load + reactions : + The reaction forces applied on the beam. + + Warning + ======= + This method creates equations that can give incorrect results when + substituting a = 0 or a = l, with l the length of the beam. + + Examples + ======== + + There is a beam of length 10 meters. There are two simple supports + below the beam, one at the starting point and another at the ending + point of the beam. Calculate the I.L.D. equations for reaction forces + under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_10 = symbols('R_0, R_10') + >>> b = Beam(10, E, I) + >>> p0 = b.apply_support(0, 'pin') + >>> p10 = b.apply_support(10, 'roller') + >>> b.solve_for_ild_reactions(1,R_0,R_10) + >>> b.ild_reactions + {R_0: -SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/10 - SingularityFunction(a, 10, 1)/10, + R_10: -SingularityFunction(a, 0, 1)/10 + SingularityFunction(a, 10, 0) + SingularityFunction(a, 10, 1)/10} + + """ + shear_force, bending_moment = self._solve_for_ild_equations(value) + x = self.variable + l = self.length + a = self.ild_variable + + rotation_jumps = tuple(self._rotation_hinge_symbols) + deflection_jumps = tuple(self._sliding_hinge_symbols) + + C3 = Symbol('C3') + C4 = Symbol('C4') + + shear_curve = limit(shear_force, x, l) - value*(SingularityFunction(a, 0, 0) - SingularityFunction(a, l, 0)) + moment_curve = (limit(bending_moment, x, l) - value * (l * SingularityFunction(a, 0, 0) + - SingularityFunction(a, 0, 1) + + SingularityFunction(a, l, 1))) + + shear_force_eqs = [] + bending_moment_eqs = [] + slope_eqs = [] + deflection_eqs = [] + + for position, val in self._boundary_conditions['shear_force']: + eqs = self.shear_force().subs(x, position) - val + eqs_without_inf = sum(arg for arg in eqs.args if not any(num.is_infinite for num in arg.args)) + shear_sinc = value * (SingularityFunction(- a, - position, 0) - SingularityFunction(-a, 0, 0)) + eqs_with_shear_sinc = eqs_without_inf - shear_sinc + shear_force_eqs.append(eqs_with_shear_sinc) + + for position, val in self._boundary_conditions['bending_moment']: + eqs = self.bending_moment().subs(x, position) - val + eqs_without_inf = sum(arg for arg in eqs.args if not any(num.is_infinite for num in arg.args)) + moment_sinc = value * (position * SingularityFunction(a, 0, 0) + - SingularityFunction(a, 0, 1) + SingularityFunction(a, position, 1)) + eqs_with_moment_sinc = eqs_without_inf - moment_sinc + bending_moment_eqs.append(eqs_with_moment_sinc) + + slope_curve = integrate(bending_moment, x) + C3 + for position, val in self._boundary_conditions['slope']: + eqs = slope_curve.subs(x, position) - val + value * (SingularityFunction(-a, 0, 1) + position * SingularityFunction(-a, 0, 0))**2 / 2 + slope_eqs.append(eqs) + + deflection_curve = integrate(slope_curve, x) + C4 + for position, val in self._boundary_conditions['deflection']: + eqs = deflection_curve.subs(x, position) - val + value * (SingularityFunction(-a, 0, 1) + position * SingularityFunction(-a, 0, 0)) ** 3 / 6 + deflection_eqs.append(eqs) + + solution = list((linsolve([shear_curve, moment_curve] + shear_force_eqs + bending_moment_eqs + slope_eqs + + deflection_eqs, (C3, C4) + reactions + rotation_jumps + deflection_jumps).args)[0]) + + reaction_index = 2 + len(reactions) + rotation_index = reaction_index + len(rotation_jumps) + reaction_solution = solution[2:reaction_index] + rotation_solution = solution[reaction_index:rotation_index] + deflection_solution = solution[rotation_index:] + + self._ild_reactions = dict(zip(reactions, reaction_solution)) + self._ild_rotations_jumps = dict(zip(rotation_jumps, rotation_solution)) + self._ild_deflection_jumps = dict(zip(deflection_jumps, deflection_solution)) + + def plot_ild_reactions(self, subs=None): + """ + + Plots the Influence Line Diagram of Reaction Forces + under the effect of a moving load. This function + should be called after calling solve_for_ild_reactions(). + + Parameters + ========== + + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Warning + ======= + The values for a = 0 and a = l, with l the length of the beam, in + the plot can be incorrect. + + Examples + ======== + + There is a beam of length 10 meters. A point load of magnitude 5KN + is also applied from top of the beam, at a distance of 4 meters + from the starting point. There are two simple supports below the + beam, located at the starting point and at a distance of 7 meters + from the starting point. Plot the I.L.D. equations for reactions + at both support points under the effect of a moving load + of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_7 = symbols('R_0, R_7') + >>> b = Beam(10, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p7 = b.apply_support(7, 'roller') + >>> b.apply_load(5,4,-1) + >>> b.solve_for_ild_reactions(1,R_0,R_7) + >>> b.ild_reactions + {R_0: -SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/7 + - 3*SingularityFunction(a, 10, 0)/7 - SingularityFunction(a, 10, 1)/7 - 15/7, + R_7: -SingularityFunction(a, 0, 1)/7 + 10*SingularityFunction(a, 10, 0)/7 + SingularityFunction(a, 10, 1)/7 - 20/7} + >>> b.plot_ild_reactions() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: -SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/7 + - 3*SingularityFunction(a, 10, 0)/7 - SingularityFunction(a, 10, 1)/7 - 15/7 for a over (0.0, 10.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -SingularityFunction(a, 0, 1)/7 + 10*SingularityFunction(a, 10, 0)/7 + + SingularityFunction(a, 10, 1)/7 - 20/7 for a over (0.0, 10.0) + + """ + if not self._ild_reactions: + raise ValueError("I.L.D. reaction equations not found. Please use solve_for_ild_reactions() to generate the I.L.D. reaction equations.") + + a = self.ild_variable + ildplots = [] + + if subs is None: + subs = {} + + for reaction in self._ild_reactions: + for sym in self._ild_reactions[reaction].atoms(Symbol): + if sym != a and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + for sym in self._length.atoms(Symbol): + if sym != a and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + for reaction in self._ild_reactions: + ildplots.append(plot(self._ild_reactions[reaction].subs(subs), + (a, 0, self._length.subs(subs)), title='I.L.D. for Reactions', + xlabel=a, ylabel=reaction, line_color='blue', show=False)) + + return PlotGrid(len(ildplots), 1, *ildplots) + + def solve_for_ild_shear(self, distance, value, *reactions): + """ + + Determines the Influence Line Diagram equations for shear at a + specified point under the effect of a moving load. + + Parameters + ========== + distance : Integer + Distance of the point from the start of the beam + for which equations are to be determined + value : Integer + Magnitude of moving load + reactions : + The reaction forces applied on the beam. + + Warning + ======= + This method creates equations that can give incorrect results when + substituting a = 0 or a = l, with l the length of the beam. + + Examples + ======== + + There is a beam of length 12 meters. There are two simple supports + below the beam, one at the starting point and another at a distance + of 8 meters. Calculate the I.L.D. equations for Shear at a distance + of 4 meters under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_8 = symbols('R_0, R_8') + >>> b = Beam(12, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p8 = b.apply_support(8, 'roller') + >>> b.solve_for_ild_reactions(1, R_0, R_8) + >>> b.solve_for_ild_shear(4, 1, R_0, R_8) + >>> b.ild_shear + -(-SingularityFunction(a, 0, 0) + SingularityFunction(a, 12, 0) + 2)*SingularityFunction(a, 4, 0) + - SingularityFunction(-a, 0, 0) - SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/8 + + SingularityFunction(a, 12, 0)/2 - SingularityFunction(a, 12, 1)/8 + 1 + + """ + + x = self.variable + l = self.length + a = self.ild_variable + + shear_force, _ = self._solve_for_ild_equations(value) + + shear_curve1 = value - limit(shear_force, x, distance) + shear_curve2 = (limit(shear_force, x, l) - limit(shear_force, x, distance)) - value + + for reaction in reactions: + shear_curve1 = shear_curve1.subs(reaction,self._ild_reactions[reaction]) + shear_curve2 = shear_curve2.subs(reaction,self._ild_reactions[reaction]) + + shear_eq = (shear_curve1 - (shear_curve1 - shear_curve2) * SingularityFunction(a, distance, 0) + - value * SingularityFunction(-a, 0, 0) + value * SingularityFunction(a, l, 0)) + + self._ild_shear = shear_eq + + def plot_ild_shear(self,subs=None): + """ + + Plots the Influence Line Diagram for Shear under the effect + of a moving load. This function should be called after + calling solve_for_ild_shear(). + + Parameters + ========== + + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Warning + ======= + The values for a = 0 and a = l, with l the length of the beam, in + the plot can be incorrect. + + Examples + ======== + + There is a beam of length 12 meters. There are two simple supports + below the beam, one at the starting point and another at a distance + of 8 meters. Plot the I.L.D. for Shear at a distance + of 4 meters under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_8 = symbols('R_0, R_8') + >>> b = Beam(12, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p8 = b.apply_support(8, 'roller') + >>> b.solve_for_ild_reactions(1, R_0, R_8) + >>> b.solve_for_ild_shear(4, 1, R_0, R_8) + >>> b.ild_shear + -(-SingularityFunction(a, 0, 0) + SingularityFunction(a, 12, 0) + 2)*SingularityFunction(a, 4, 0) + - SingularityFunction(-a, 0, 0) - SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/8 + + SingularityFunction(a, 12, 0)/2 - SingularityFunction(a, 12, 1)/8 + 1 + >>> b.plot_ild_shear() + Plot object containing: + [0]: cartesian line: -(-SingularityFunction(a, 0, 0) + SingularityFunction(a, 12, 0) + 2)*SingularityFunction(a, 4, 0) + - SingularityFunction(-a, 0, 0) - SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/8 + + SingularityFunction(a, 12, 0)/2 - SingularityFunction(a, 12, 1)/8 + 1 for a over (0.0, 12.0) + + """ + + if not self._ild_shear: + raise ValueError("I.L.D. shear equation not found. Please use solve_for_ild_shear() to generate the I.L.D. shear equations.") + + l = self._length + a = self.ild_variable + + if subs is None: + subs = {} + + for sym in self._ild_shear.atoms(Symbol): + if sym != a and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + for sym in self._length.atoms(Symbol): + if sym != a and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + return plot(self._ild_shear.subs(subs), (a, 0, l), title='I.L.D. for Shear', + xlabel=r'$\mathrm{a}$', ylabel=r'$\mathrm{V}$', line_color='blue',show=True) + + def solve_for_ild_moment(self, distance, value, *reactions): + """ + + Determines the Influence Line Diagram equations for moment at a + specified point under the effect of a moving load. + + Parameters + ========== + distance : Integer + Distance of the point from the start of the beam + for which equations are to be determined + value : Integer + Magnitude of moving load + reactions : + The reaction forces applied on the beam. + + Warning + ======= + This method creates equations that can give incorrect results when + substituting a = 0 or a = l, with l the length of the beam. + + Examples + ======== + + There is a beam of length 12 meters. There are two simple supports + below the beam, one at the starting point and another at a distance + of 8 meters. Calculate the I.L.D. equations for Moment at a distance + of 4 meters under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_8 = symbols('R_0, R_8') + >>> b = Beam(12, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p8 = b.apply_support(8, 'roller') + >>> b.solve_for_ild_reactions(1, R_0, R_8) + >>> b.solve_for_ild_moment(4, 1, R_0, R_8) + >>> b.ild_moment + -(4*SingularityFunction(a, 0, 0) - SingularityFunction(a, 0, 1) + SingularityFunction(a, 4, 1))*SingularityFunction(a, 4, 0) + - SingularityFunction(a, 0, 1)/2 + SingularityFunction(a, 4, 1) - 2*SingularityFunction(a, 12, 0) + - SingularityFunction(a, 12, 1)/2 + + """ + + x = self.variable + l = self.length + a = self.ild_variable + + _, moment = self._solve_for_ild_equations(value) + + moment_curve1 = value*(distance * SingularityFunction(a, 0, 0) - SingularityFunction(a, 0, 1) + + SingularityFunction(a, distance, 1)) - limit(moment, x, distance) + moment_curve2 = (limit(moment, x, l)-limit(moment, x, distance) + - value * (l * SingularityFunction(a, 0, 0) - SingularityFunction(a, 0, 1) + + SingularityFunction(a, l, 1))) + + for reaction in reactions: + moment_curve1 = moment_curve1.subs(reaction, self._ild_reactions[reaction]) + moment_curve2 = moment_curve2.subs(reaction, self._ild_reactions[reaction]) + + moment_eq = moment_curve1 - (moment_curve1 - moment_curve2) * SingularityFunction(a, distance, 0) + + self._ild_moment = moment_eq + + def plot_ild_moment(self,subs=None): + """ + + Plots the Influence Line Diagram for Moment under the effect + of a moving load. This function should be called after + calling solve_for_ild_moment(). + + Parameters + ========== + + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Warning + ======= + The values for a = 0 and a = l, with l the length of the beam, in + the plot can be incorrect. + + Examples + ======== + + There is a beam of length 12 meters. There are two simple supports + below the beam, one at the starting point and another at a distance + of 8 meters. Plot the I.L.D. for Moment at a distance + of 4 meters under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_8 = symbols('R_0, R_8') + >>> b = Beam(12, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p8 = b.apply_support(8, 'roller') + >>> b.solve_for_ild_reactions(1, R_0, R_8) + >>> b.solve_for_ild_moment(4, 1, R_0, R_8) + >>> b.ild_moment + -(4*SingularityFunction(a, 0, 0) - SingularityFunction(a, 0, 1) + SingularityFunction(a, 4, 1))*SingularityFunction(a, 4, 0) + - SingularityFunction(a, 0, 1)/2 + SingularityFunction(a, 4, 1) - 2*SingularityFunction(a, 12, 0) + - SingularityFunction(a, 12, 1)/2 + >>> b.plot_ild_moment() + Plot object containing: + [0]: cartesian line: -(4*SingularityFunction(a, 0, 0) - SingularityFunction(a, 0, 1) + + SingularityFunction(a, 4, 1))*SingularityFunction(a, 4, 0) - SingularityFunction(a, 0, 1)/2 + + SingularityFunction(a, 4, 1) - 2*SingularityFunction(a, 12, 0) - SingularityFunction(a, 12, 1)/2 for a over (0.0, 12.0) + + """ + + if not self._ild_moment: + raise ValueError("I.L.D. moment equation not found. Please use solve_for_ild_moment() to generate the I.L.D. moment equations.") + + a = self.ild_variable + + if subs is None: + subs = {} + + for sym in self._ild_moment.atoms(Symbol): + if sym != a and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + for sym in self._length.atoms(Symbol): + if sym != a and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + return plot(self._ild_moment.subs(subs), (a, 0, self._length), title='I.L.D. for Moment', + xlabel=r'$\mathrm{a}$', ylabel=r'$\mathrm{M}$', line_color='blue', show=True) + + @doctest_depends_on(modules=('numpy',)) + def draw(self, pictorial=True): + """ + Returns a plot object representing the beam diagram of the beam. + In particular, the diagram might include: + + * the beam. + * vertical black arrows represent point loads and support reaction + forces (the latter if they have been added with the ``apply_load`` + method). + * circular arrows represent moments. + * shaded areas represent distributed loads. + * the support, if ``apply_support`` has been executed. + * if a composite beam has been created with the ``join`` method and + a hinge has been specified, it will be shown with a white disc. + + The diagram shows positive loads on the upper side of the beam, + and negative loads on the lower side. If two or more distributed + loads acts along the same direction over the same region, the + function will add them up together. + + .. note:: + The user must be careful while entering load values. + The draw function assumes a sign convention which is used + for plotting loads. + Given a right handed coordinate system with XYZ coordinates, + the beam's length is assumed to be along the positive X axis. + The draw function recognizes positive loads(with n>-2) as loads + acting along negative Y direction and positive moments acting + along positive Z direction. + + Parameters + ========== + + pictorial: Boolean (default=True) + Setting ``pictorial=True`` would simply create a pictorial (scaled) + view of the beam diagram. On the other hand, ``pictorial=False`` + would create a beam diagram with the exact dimensions on the plot. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> P1, P2, M = symbols('P1, P2, M') + >>> E, I = symbols('E, I') + >>> b = Beam(50, 20, 30) + >>> b.apply_load(-10, 2, -1) + >>> b.apply_load(15, 26, -1) + >>> b.apply_load(P1, 10, -1) + >>> b.apply_load(-P2, 40, -1) + >>> b.apply_load(90, 5, 0, 23) + >>> b.apply_load(10, 30, 1, 50) + >>> b.apply_load(M, 15, -2) + >>> b.apply_load(-M, 30, -2) + >>> p50 = b.apply_support(50, "pin") + >>> p0, m0 = b.apply_support(0, "fixed") + >>> p20 = b.apply_support(20, "roller") + >>> p = b.draw() # doctest: +SKIP + >>> p # doctest: +ELLIPSIS,+SKIP + Plot object containing: + [0]: cartesian line: 25*SingularityFunction(x, 5, 0) - 25*SingularityFunction(x, 23, 0) + + SingularityFunction(x, 30, 1) - 20*SingularityFunction(x, 50, 0) + - SingularityFunction(x, 50, 1) + 5 for x over (0.0, 50.0) + [1]: cartesian line: 5 for x over (0.0, 50.0) + ... + >>> p.show() # doctest: +SKIP + + """ + if not numpy: + raise ImportError("To use this function numpy module is required") + + loads = list(set(self.applied_loads) - set(self._support_as_loads)) + if (not pictorial) and any((len(l[0].free_symbols) > 0) and (l[2] >= 0) for l in loads): + raise ValueError("`pictorial=False` requires numerical " + "distributed loads. Instead, symbolic loads were found. " + "Cannot continue.") + + x = self.variable + + # checking whether length is an expression in terms of any Symbol. + if isinstance(self.length, Expr): + l = list(self.length.atoms(Symbol)) + # assigning every Symbol a default value of 10 + l = dict.fromkeys(l, 10) + length = self.length.subs(l) + else: + l = {} + length = self.length + height = length/10 + + rectangles = [] + rectangles.append({'xy':(0, 0), 'width':length, 'height': height, 'facecolor':"brown"}) + annotations, markers, load_eq,load_eq1, fill = self._draw_load(pictorial, length, l) + support_markers, support_rectangles = self._draw_supports(length, l) + + rectangles += support_rectangles + markers += support_markers + + for loc in self._applied_rotation_hinges: + ratio = loc / self.length + x_pos = float(ratio) * length + markers += [{'args':[[x_pos], [height / 2]], 'marker':'o', 'markersize':6, 'color':"white"}] + + for loc in self._applied_sliding_hinges: + ratio = loc / self.length + x_pos = float(ratio) * length + markers += [{'args': [[x_pos], [height / 2]], 'marker':'|', 'markersize':12, 'color':"white"}] + + ylim = (-length, 1.25*length) + if fill: + # when distributed loads are presents, they might get clipped out + # in the figure by the ylim settings. + # It might be necessary to compute new limits. + _min = min(min(fill["y2"]), min(r["xy"][1] for r in rectangles)) + _max = max(max(fill["y1"]), max(r["xy"][1] for r in rectangles)) + if (_min < ylim[0]) or (_max > ylim[1]): + offset = abs(_max - _min) * 0.1 + ylim = (_min - offset, _max + offset) + + sing_plot = plot(height + load_eq, height + load_eq1, (x, 0, length), + xlim=(-height, length + height), ylim=ylim, + annotations=annotations, markers=markers, rectangles=rectangles, + line_color='brown', fill=fill, axis=False, show=False) + + return sing_plot + + + def _is_load_negative(self, load): + """Try to determine if a load is negative or positive, using + expansion and doit if necessary. + + Returns + ======= + True: if the load is negative + False: if the load is positive + None: if it is indeterminate + + """ + rv = load.is_negative + if load.is_Atom or rv is not None: + return rv + return load.doit().expand().is_negative + + def _draw_load(self, pictorial, length, l): + loads = list(set(self.applied_loads) - set(self._support_as_loads)) + height = length/10 + x = self.variable + + annotations = [] + markers = [] + load_args = [] + scaled_load = 0 + load_args1 = [] + scaled_load1 = 0 + load_eq = S.Zero # For positive valued higher order loads + load_eq1 = S.Zero # For negative valued higher order loads + fill = None + + # schematic view should use the class convention as much as possible. + # However, users can add expressions as symbolic loads, for example + # P1 - P2: is this load positive or negative? We can't say. + # On these occasions it is better to inform users about the + # indeterminate state of those loads. + warning_head = "Please, note that this schematic view might not be " \ + "in agreement with the sign convention used by the Beam class " \ + "for load-related computations, because it was not possible " \ + "to determine the sign (hence, the direction) of the " \ + "following loads:\n" + warning_body = "" + + for load in loads: + # check if the position of load is in terms of the beam length. + if l: + pos = load[1].subs(l) + else: + pos = load[1] + + # point loads + if load[2] == -1: + iln = self._is_load_negative(load[0]) + if iln is None: + warning_body += "* Point load %s located at %s\n" % (load[0], load[1]) + if iln: + annotations.append({'text':'', 'xy':(pos, 0), 'xytext':(pos, height - 4*height), 'arrowprops':{'width': 1.5, 'headlength': 5, 'headwidth': 5, 'facecolor': 'black'}}) + else: + annotations.append({'text':'', 'xy':(pos, height), 'xytext':(pos, height*4), 'arrowprops':{"width": 1.5, "headlength": 4, "headwidth": 4, "facecolor": 'black'}}) + # moment loads + elif load[2] == -2: + iln = self._is_load_negative(load[0]) + if iln is None: + warning_body += "* Moment %s located at %s\n" % (load[0], load[1]) + if self._is_load_negative(load[0]): + markers.append({'args':[[pos], [height/2]], 'marker': r'$\circlearrowright$', 'markersize':15}) + else: + markers.append({'args':[[pos], [height/2]], 'marker': r'$\circlearrowleft$', 'markersize':15}) + # higher order loads + elif load[2] >= 0: + # `fill` will be assigned only when higher order loads are present + value, start, order, end = load + + iln = self._is_load_negative(value) + if iln is None: + warning_body += "* Distributed load %s from %s to %s\n" % (value, start, end) + + # Positive loads have their separate equations + if not iln: + # if pictorial is True we remake the load equation again with + # some constant magnitude values. + if pictorial: + # remake the load equation again with some constant + # magnitude values. + value = 10**(1-order) if order > 0 else length/2 + scaled_load += value*SingularityFunction(x, start, order) + if end: + f2 = value*x**order if order >= 0 else length/2*x**order + for i in range(0, order + 1): + scaled_load -= (f2.diff(x, i).subs(x, end - start)* + SingularityFunction(x, end, i)/factorial(i)) + + if isinstance(scaled_load, Add): + load_args = scaled_load.args + else: + # when the load equation consists of only a single term + load_args = (scaled_load,) + load_eq = Add(*[i.subs(l) for i in load_args]) + + # For loads with negative value + else: + if pictorial: + # remake the load equation again with some constant + # magnitude values. + value = 10**(1-order) if order > 0 else length/2 + scaled_load1 += abs(value)*SingularityFunction(x, start, order) + if end: + f2 = abs(value)*x**order if order >= 0 else length/2*x**order + for i in range(0, order + 1): + scaled_load1 -= (f2.diff(x, i).subs(x, end - start)* + SingularityFunction(x, end, i)/factorial(i)) + + if isinstance(scaled_load1, Add): + load_args1 = scaled_load1.args + else: + # when the load equation consists of only a single term + load_args1 = (scaled_load1,) + load_eq1 = [i.subs(l) for i in load_args1] + load_eq1 = -Add(*load_eq1) - height + + if len(warning_body) > 0: + warnings.warn(warning_head + warning_body) + + xx = numpy.arange(0, float(length), 0.001) + yy1 = lambdify([x], height + load_eq.rewrite(Piecewise))(xx) + yy2 = lambdify([x], height + load_eq1.rewrite(Piecewise))(xx) + if not isinstance(yy1, numpy.ndarray): + yy1 *= numpy.ones_like(xx) + if not isinstance(yy2, numpy.ndarray): + yy2 *= numpy.ones_like(xx) + fill = {'x': xx, 'y1': yy1, 'y2': yy2, + 'color':'darkkhaki', "zorder": -1} + return annotations, markers, load_eq, load_eq1, fill + + + def _draw_supports(self, length, l): + height = float(length/10) + + support_markers = [] + support_rectangles = [] + for support in self._applied_supports: + if l: + pos = support[0].subs(l) + else: + pos = support[0] + + if support[1] == "pin": + support_markers.append({'args':[pos, [0]], 'marker':6, 'markersize':13, 'color':"black"}) + + elif support[1] == "roller": + support_markers.append({'args':[pos, [-height/2.5]], 'marker':'o', 'markersize':11, 'color':"black"}) + + elif support[1] == "fixed": + if pos == 0: + support_rectangles.append({'xy':(0, -3*height), 'width':-length/20, 'height':6*height + height, 'fill':False, 'hatch':'/////'}) + else: + support_rectangles.append({'xy':(length, -3*height), 'width':length/20, 'height': 6*height + height, 'fill':False, 'hatch':'/////'}) + + return support_markers, support_rectangles + + +class Beam3D(Beam): + """ + This class handles loads applied in any direction of a 3D space along + with unequal values of Second moment along different axes. + + .. note:: + A consistent sign convention must be used while solving a beam + bending problem; the results will + automatically follow the chosen sign convention. + This class assumes that any kind of distributed load/moment is + applied through out the span of a beam. + + Examples + ======== + There is a beam of l meters long. A constant distributed load of magnitude q + is applied along y-axis from start till the end of beam. A constant distributed + moment of magnitude m is also applied along z-axis from start till the end of beam. + Beam is fixed at both of its end. So, deflection of the beam at the both ends + is restricted. + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols, simplify, collect, factor + >>> l, E, G, I, A = symbols('l, E, G, I, A') + >>> b = Beam3D(l, E, G, I, A) + >>> x, q, m = symbols('x, q, m') + >>> b.apply_load(q, 0, 0, dir="y") + >>> b.apply_moment_load(m, 0, -1, dir="z") + >>> b.shear_force() + [0, -q*x, 0] + >>> b.bending_moment() + [0, 0, -m*x + q*x**2/2] + >>> b.bc_slope = [(0, [0, 0, 0]), (l, [0, 0, 0])] + >>> b.bc_deflection = [(0, [0, 0, 0]), (l, [0, 0, 0])] + >>> b.solve_slope_deflection() + >>> factor(b.slope()) + [0, 0, x*(-l + x)*(-A*G*l**3*q + 2*A*G*l**2*q*x - 12*E*I*l*q + - 72*E*I*m + 24*E*I*q*x)/(12*E*I*(A*G*l**2 + 12*E*I))] + >>> dx, dy, dz = b.deflection() + >>> dy = collect(simplify(dy), x) + >>> dx == dz == 0 + True + >>> dy == (x*(12*E*I*l*(A*G*l**2*q - 2*A*G*l*m + 12*E*I*q) + ... + x*(A*G*l*(3*l*(A*G*l**2*q - 2*A*G*l*m + 12*E*I*q) + x*(-2*A*G*l**2*q + 4*A*G*l*m - 24*E*I*q)) + ... + A*G*(A*G*l**2 + 12*E*I)*(-2*l**2*q + 6*l*m - 4*m*x + q*x**2) + ... - 12*E*I*q*(A*G*l**2 + 12*E*I)))/(24*A*E*G*I*(A*G*l**2 + 12*E*I))) + True + + References + ========== + + .. [1] https://homes.civil.aau.dk/jc/FemteSemester/Beams3D.pdf + + """ + + def __init__(self, length, elastic_modulus, shear_modulus, second_moment, + area, variable=Symbol('x')): + """Initializes the class. + + Parameters + ========== + length : Sympifyable + A Symbol or value representing the Beam's length. + elastic_modulus : Sympifyable + A SymPy expression representing the Beam's Modulus of Elasticity. + It is a measure of the stiffness of the Beam material. + shear_modulus : Sympifyable + A SymPy expression representing the Beam's Modulus of rigidity. + It is a measure of rigidity of the Beam material. + second_moment : Sympifyable or list + A list of two elements having SymPy expression representing the + Beam's Second moment of area. First value represent Second moment + across y-axis and second across z-axis. + Single SymPy expression can be passed if both values are same + area : Sympifyable + A SymPy expression representing the Beam's cross-sectional area + in a plane perpendicular to length of the Beam. + variable : Symbol, optional + A Symbol object that will be used as the variable along the beam + while representing the load, shear, moment, slope and deflection + curve. By default, it is set to ``Symbol('x')``. + """ + super().__init__(length, elastic_modulus, second_moment, variable) + self.shear_modulus = shear_modulus + self.area = area + self._load_vector = [0, 0, 0] + self._moment_load_vector = [0, 0, 0] + self._torsion_moment = {} + self._load_Singularity = [0, 0, 0] + self._slope = [0, 0, 0] + self._deflection = [0, 0, 0] + self._angular_deflection = 0 + + @property + def shear_modulus(self): + """Young's Modulus of the Beam. """ + return self._shear_modulus + + @shear_modulus.setter + def shear_modulus(self, e): + self._shear_modulus = sympify(e) + + @property + def second_moment(self): + """Second moment of area of the Beam. """ + return self._second_moment + + @second_moment.setter + def second_moment(self, i): + if isinstance(i, list): + i = [sympify(x) for x in i] + self._second_moment = i + else: + self._second_moment = sympify(i) + + @property + def area(self): + """Cross-sectional area of the Beam. """ + return self._area + + @area.setter + def area(self, a): + self._area = sympify(a) + + @property + def load_vector(self): + """ + Returns a three element list representing the load vector. + """ + return self._load_vector + + @property + def moment_load_vector(self): + """ + Returns a three element list representing moment loads on Beam. + """ + return self._moment_load_vector + + @property + def boundary_conditions(self): + """ + Returns a dictionary of boundary conditions applied on the beam. + The dictionary has two keywords namely slope and deflection. + The value of each keyword is a list of tuple, where each tuple + contains location and value of a boundary condition in the format + (location, value). Further each value is a list corresponding to + slope or deflection(s) values along three axes at that location. + + Examples + ======== + There is a beam of length 4 meters. The slope at 0 should be 4 along + the x-axis and 0 along others. At the other end of beam, deflection + along all the three axes should be zero. + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(30, E, G, I, A, x) + >>> b.bc_slope = [(0, (4, 0, 0))] + >>> b.bc_deflection = [(4, [0, 0, 0])] + >>> b.boundary_conditions + {'bending_moment': [], 'deflection': [(4, [0, 0, 0])], 'shear_force': [], 'slope': [(0, (4, 0, 0))]} + + Here the deflection of the beam should be ``0`` along all the three axes at ``4``. + Similarly, the slope of the beam should be ``4`` along x-axis and ``0`` + along y and z axis at ``0``. + """ + return self._boundary_conditions + + def polar_moment(self): + """ + Returns the polar moment of area of the beam + about the X axis with respect to the centroid. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A = symbols('l, E, G, I, A') + >>> b = Beam3D(l, E, G, I, A) + >>> b.polar_moment() + 2*I + >>> I1 = [9, 15] + >>> b = Beam3D(l, E, G, I1, A) + >>> b.polar_moment() + 24 + """ + if not iterable(self.second_moment): + return 2*self.second_moment + return sum(self.second_moment) + + def apply_load(self, value, start, order, dir="y"): + """ + This method adds up the force load to a particular beam object. + + Parameters + ========== + value : Sympifyable + The magnitude of an applied load. + dir : String + Axis along which load is applied. + order : Integer + The order of the applied load. + - For point loads, order=-1 + - For constant distributed load, order=0 + - For ramp loads, order=1 + - For parabolic ramp loads, order=2 + - ... so on. + """ + x = self.variable + value = sympify(value) + start = sympify(start) + order = sympify(order) + + if dir == "x": + if not order == -1: + self._load_vector[0] += value + self._load_Singularity[0] += value*SingularityFunction(x, start, order) + + elif dir == "y": + if not order == -1: + self._load_vector[1] += value + self._load_Singularity[1] += value*SingularityFunction(x, start, order) + + else: + if not order == -1: + self._load_vector[2] += value + self._load_Singularity[2] += value*SingularityFunction(x, start, order) + + def apply_moment_load(self, value, start, order, dir="y"): + """ + This method adds up the moment loads to a particular beam object. + + Parameters + ========== + value : Sympifyable + The magnitude of an applied moment. + dir : String + Axis along which moment is applied. + order : Integer + The order of the applied load. + - For point moments, order=-2 + - For constant distributed moment, order=-1 + - For ramp moments, order=0 + - For parabolic ramp moments, order=1 + - ... so on. + """ + x = self.variable + value = sympify(value) + start = sympify(start) + order = sympify(order) + + if dir == "x": + if not order == -2: + self._moment_load_vector[0] += value + else: + if start in list(self._torsion_moment): + self._torsion_moment[start] += value + else: + self._torsion_moment[start] = value + self._load_Singularity[0] += value*SingularityFunction(x, start, order) + elif dir == "y": + if not order == -2: + self._moment_load_vector[1] += value + self._load_Singularity[0] += value*SingularityFunction(x, start, order) + else: + if not order == -2: + self._moment_load_vector[2] += value + self._load_Singularity[0] += value*SingularityFunction(x, start, order) + + def apply_support(self, loc, type="fixed"): + if type in ("pin", "roller"): + reaction_load = Symbol('R_'+str(loc)) + self._reaction_loads[reaction_load] = reaction_load + self.bc_deflection.append((loc, [0, 0, 0])) + else: + reaction_load = Symbol('R_'+str(loc)) + reaction_moment = Symbol('M_'+str(loc)) + self._reaction_loads[reaction_load] = [reaction_load, reaction_moment] + self.bc_deflection.append((loc, [0, 0, 0])) + self.bc_slope.append((loc, [0, 0, 0])) + + def solve_for_reaction_loads(self, *reaction): + """ + Solves for the reaction forces. + + Examples + ======== + There is a beam of length 30 meters. It it supported by rollers at + of its end. A constant distributed load of magnitude 8 N is applied + from start till its end along y-axis. Another linear load having + slope equal to 9 is applied along z-axis. + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(30, E, G, I, A, x) + >>> b.apply_load(8, start=0, order=0, dir="y") + >>> b.apply_load(9*x, start=0, order=0, dir="z") + >>> b.bc_deflection = [(0, [0, 0, 0]), (30, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="y") + >>> b.apply_load(R2, start=30, order=-1, dir="y") + >>> b.apply_load(R3, start=0, order=-1, dir="z") + >>> b.apply_load(R4, start=30, order=-1, dir="z") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.reaction_loads + {R1: -120, R2: -120, R3: -1350, R4: -2700} + """ + x = self.variable + l = self.length + q = self._load_Singularity + shear_curves = [integrate(load, x) for load in q] + moment_curves = [integrate(shear, x) for shear in shear_curves] + for i in range(3): + react = [r for r in reaction if (shear_curves[i].has(r) or moment_curves[i].has(r))] + if len(react) == 0: + continue + shear_curve = limit(shear_curves[i], x, l) + moment_curve = limit(moment_curves[i], x, l) + sol = list((linsolve([shear_curve, moment_curve], react).args)[0]) + sol_dict = dict(zip(react, sol)) + reaction_loads = self._reaction_loads + # Check if any of the evaluated reaction exists in another direction + # and if it exists then it should have same value. + for key in sol_dict: + if key in reaction_loads and sol_dict[key] != reaction_loads[key]: + raise ValueError("Ambiguous solution for %s in different directions." % key) + self._reaction_loads.update(sol_dict) + + def shear_force(self): + """ + Returns a list of three expressions which represents the shear force + curve of the Beam object along all three axes. + """ + x = self.variable + q = self._load_vector + return [integrate(-q[0], x), integrate(-q[1], x), integrate(-q[2], x)] + + def axial_force(self): + """ + Returns expression of Axial shear force present inside the Beam object. + """ + return self.shear_force()[0] + + def shear_stress(self): + """ + Returns a list of three expressions which represents the shear stress + curve of the Beam object along all three axes. + """ + return [self.shear_force()[0]/self._area, self.shear_force()[1]/self._area, self.shear_force()[2]/self._area] + + def axial_stress(self): + """ + Returns expression of Axial stress present inside the Beam object. + """ + return self.axial_force()/self._area + + def bending_moment(self): + """ + Returns a list of three expressions which represents the bending moment + curve of the Beam object along all three axes. + """ + x = self.variable + m = self._moment_load_vector + shear = self.shear_force() + + return [integrate(-m[0], x), integrate(-m[1] + shear[2], x), + integrate(-m[2] - shear[1], x) ] + + def torsional_moment(self): + """ + Returns expression of Torsional moment present inside the Beam object. + """ + return self.bending_moment()[0] + + def solve_for_torsion(self): + """ + Solves for the angular deflection due to the torsional effects of + moments being applied in the x-direction i.e. out of or into the beam. + + Here, a positive torque means the direction of the torque is positive + i.e. out of the beam along the beam-axis. Likewise, a negative torque + signifies a torque into the beam cross-section. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, A, x) + >>> b.apply_moment_load(4, 4, -2, dir='x') + >>> b.apply_moment_load(4, 8, -2, dir='x') + >>> b.apply_moment_load(4, 8, -2, dir='x') + >>> b.solve_for_torsion() + >>> b.angular_deflection().subs(x, 3) + 18/(G*I) + """ + x = self.variable + sum_moments = 0 + for point in list(self._torsion_moment): + sum_moments += self._torsion_moment[point] + list(self._torsion_moment).sort() + pointsList = list(self._torsion_moment) + torque_diagram = Piecewise((sum_moments, x<=pointsList[0]), (0, x>=pointsList[0])) + for i in range(len(pointsList))[1:]: + sum_moments -= self._torsion_moment[pointsList[i-1]] + torque_diagram += Piecewise((0, x<=pointsList[i-1]), (sum_moments, x<=pointsList[i]), (0, x>=pointsList[i])) + integrated_torque_diagram = integrate(torque_diagram) + self._angular_deflection = integrated_torque_diagram/(self.shear_modulus*self.polar_moment()) + + def solve_slope_deflection(self): + x = self.variable + l = self.length + E = self.elastic_modulus + G = self.shear_modulus + I = self.second_moment + if isinstance(I, list): + I_y, I_z = I[0], I[1] + else: + I_y = I_z = I + A = self._area + load = self._load_vector + moment = self._moment_load_vector + defl = Function('defl') + theta = Function('theta') + + # Finding deflection along x-axis(and corresponding slope value by differentiating it) + # Equation used: Derivative(E*A*Derivative(def_x(x), x), x) + load_x = 0 + eq = Derivative(E*A*Derivative(defl(x), x), x) + load[0] + def_x = dsolve(Eq(eq, 0), defl(x)).args[1] + # Solving constants originated from dsolve + C1 = Symbol('C1') + C2 = Symbol('C2') + constants = list((linsolve([def_x.subs(x, 0), def_x.subs(x, l)], C1, C2).args)[0]) + def_x = def_x.subs({C1:constants[0], C2:constants[1]}) + slope_x = def_x.diff(x) + self._deflection[0] = def_x + self._slope[0] = slope_x + + # Finding deflection along y-axis and slope across z-axis. System of equation involved: + # 1: Derivative(E*I_z*Derivative(theta_z(x), x), x) + G*A*(Derivative(defl_y(x), x) - theta_z(x)) + moment_z = 0 + # 2: Derivative(G*A*(Derivative(defl_y(x), x) - theta_z(x)), x) + load_y = 0 + C_i = Symbol('C_i') + # Substitute value of `G*A*(Derivative(defl_y(x), x) - theta_z(x))` from (2) in (1) + eq1 = Derivative(E*I_z*Derivative(theta(x), x), x) + (integrate(-load[1], x) + C_i) + moment[2] + slope_z = dsolve(Eq(eq1, 0)).args[1] + + # Solve for constants originated from using dsolve on eq1 + constants = list((linsolve([slope_z.subs(x, 0), slope_z.subs(x, l)], C1, C2).args)[0]) + slope_z = slope_z.subs({C1:constants[0], C2:constants[1]}) + + # Put value of slope obtained back in (2) to solve for `C_i` and find deflection across y-axis + eq2 = G*A*(Derivative(defl(x), x)) + load[1]*x - C_i - G*A*slope_z + def_y = dsolve(Eq(eq2, 0), defl(x)).args[1] + # Solve for constants originated from using dsolve on eq2 + constants = list((linsolve([def_y.subs(x, 0), def_y.subs(x, l)], C1, C_i).args)[0]) + self._deflection[1] = def_y.subs({C1:constants[0], C_i:constants[1]}) + self._slope[2] = slope_z.subs(C_i, constants[1]) + + # Finding deflection along z-axis and slope across y-axis. System of equation involved: + # 1: Derivative(E*I_y*Derivative(theta_y(x), x), x) - G*A*(Derivative(defl_z(x), x) + theta_y(x)) + moment_y = 0 + # 2: Derivative(G*A*(Derivative(defl_z(x), x) + theta_y(x)), x) + load_z = 0 + + # Substitute value of `G*A*(Derivative(defl_y(x), x) + theta_z(x))` from (2) in (1) + eq1 = Derivative(E*I_y*Derivative(theta(x), x), x) + (integrate(load[2], x) - C_i) + moment[1] + slope_y = dsolve(Eq(eq1, 0)).args[1] + # Solve for constants originated from using dsolve on eq1 + constants = list((linsolve([slope_y.subs(x, 0), slope_y.subs(x, l)], C1, C2).args)[0]) + slope_y = slope_y.subs({C1:constants[0], C2:constants[1]}) + + # Put value of slope obtained back in (2) to solve for `C_i` and find deflection across z-axis + eq2 = G*A*(Derivative(defl(x), x)) + load[2]*x - C_i + G*A*slope_y + def_z = dsolve(Eq(eq2,0)).args[1] + # Solve for constants originated from using dsolve on eq2 + constants = list((linsolve([def_z.subs(x, 0), def_z.subs(x, l)], C1, C_i).args)[0]) + self._deflection[2] = def_z.subs({C1:constants[0], C_i:constants[1]}) + self._slope[1] = slope_y.subs(C_i, constants[1]) + + def slope(self): + """ + Returns a three element list representing slope of deflection curve + along all the three axes. + """ + return self._slope + + def deflection(self): + """ + Returns a three element list representing deflection curve along all + the three axes. + """ + return self._deflection + + def angular_deflection(self): + """ + Returns a function in x depicting how the angular deflection, due to moments + in the x-axis on the beam, varies with x. + """ + return self._angular_deflection + + def _plot_shear_force(self, dir, subs=None): + + shear_force = self.shear_force() + + if dir == 'x': + dir_num = 0 + color = 'r' + + elif dir == 'y': + dir_num = 1 + color = 'g' + + elif dir == 'z': + dir_num = 2 + color = 'b' + + if subs is None: + subs = {} + + for sym in shear_force[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + return plot(shear_force[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Shear Force along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{V(%c)}$'%dir, line_color=color) + + def plot_shear_force(self, dir="all", subs=None): + + """ + + Returns a plot for Shear force along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which shear force plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, A, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.plot_shear_force() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -6*x**2 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: -15*x for x over (0.0, 20.0) + + """ + + dir = dir.lower() + # For shear force along x direction + if dir == "x": + Px = self._plot_shear_force('x', subs) + return Px.show() + # For shear force along y direction + elif dir == "y": + Py = self._plot_shear_force('y', subs) + return Py.show() + # For shear force along z direction + elif dir == "z": + Pz = self._plot_shear_force('z', subs) + return Pz.show() + # For shear force along all direction + else: + Px = self._plot_shear_force('x', subs) + Py = self._plot_shear_force('y', subs) + Pz = self._plot_shear_force('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def _plot_bending_moment(self, dir, subs=None): + + bending_moment = self.bending_moment() + + if dir == 'x': + dir_num = 0 + color = 'g' + + elif dir == 'y': + dir_num = 1 + color = 'c' + + elif dir == 'z': + dir_num = 2 + color = 'm' + + if subs is None: + subs = {} + + for sym in bending_moment[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + return plot(bending_moment[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Bending Moment along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{M(%c)}$'%dir, line_color=color) + + def plot_bending_moment(self, dir="all", subs=None): + + """ + + Returns a plot for bending moment along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which bending moment plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, A, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.plot_bending_moment() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -15*x**2/2 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: 2*x**3 for x over (0.0, 20.0) + + """ + + dir = dir.lower() + # For bending moment along x direction + if dir == "x": + Px = self._plot_bending_moment('x', subs) + return Px.show() + # For bending moment along y direction + elif dir == "y": + Py = self._plot_bending_moment('y', subs) + return Py.show() + # For bending moment along z direction + elif dir == "z": + Pz = self._plot_bending_moment('z', subs) + return Pz.show() + # For bending moment along all direction + else: + Px = self._plot_bending_moment('x', subs) + Py = self._plot_bending_moment('y', subs) + Pz = self._plot_bending_moment('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def _plot_slope(self, dir, subs=None): + + slope = self.slope() + + if dir == 'x': + dir_num = 0 + color = 'b' + + elif dir == 'y': + dir_num = 1 + color = 'm' + + elif dir == 'z': + dir_num = 2 + color = 'g' + + if subs is None: + subs = {} + + for sym in slope[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + + return plot(slope[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Slope along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{\theta(%c)}$'%dir, line_color=color) + + def plot_slope(self, dir="all", subs=None): + + """ + + Returns a plot for Slope along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which Slope plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as keys and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.solve_slope_deflection() + >>> b.plot_slope() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -x**3/1600 + 3*x**2/160 - x/8 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: x**4/8000 - 19*x**2/172 + 52*x/43 for x over (0.0, 20.0) + + """ + + dir = dir.lower() + # For Slope along x direction + if dir == "x": + Px = self._plot_slope('x', subs) + return Px.show() + # For Slope along y direction + elif dir == "y": + Py = self._plot_slope('y', subs) + return Py.show() + # For Slope along z direction + elif dir == "z": + Pz = self._plot_slope('z', subs) + return Pz.show() + # For Slope along all direction + else: + Px = self._plot_slope('x', subs) + Py = self._plot_slope('y', subs) + Pz = self._plot_slope('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def _plot_deflection(self, dir, subs=None): + + deflection = self.deflection() + + if dir == 'x': + dir_num = 0 + color = 'm' + + elif dir == 'y': + dir_num = 1 + color = 'r' + + elif dir == 'z': + dir_num = 2 + color = 'c' + + if subs is None: + subs = {} + + for sym in deflection[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + return plot(deflection[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Deflection along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{\delta(%c)}$'%dir, line_color=color) + + def plot_deflection(self, dir="all", subs=None): + + """ + + Returns a plot for Deflection along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which deflection plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as keys and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.solve_slope_deflection() + >>> b.plot_deflection() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: x**5/40000 - 4013*x**3/90300 + 26*x**2/43 + 1520*x/903 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: x**4/6400 - x**3/160 + 27*x**2/560 + 2*x/7 for x over (0.0, 20.0) + + + """ + + dir = dir.lower() + # For deflection along x direction + if dir == "x": + Px = self._plot_deflection('x', subs) + return Px.show() + # For deflection along y direction + elif dir == "y": + Py = self._plot_deflection('y', subs) + return Py.show() + # For deflection along z direction + elif dir == "z": + Pz = self._plot_deflection('z', subs) + return Pz.show() + # For deflection along all direction + else: + Px = self._plot_deflection('x', subs) + Py = self._plot_deflection('y', subs) + Pz = self._plot_deflection('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def plot_loading_results(self, dir='x', subs=None): + + """ + + Returns a subplot of Shear Force, Bending Moment, + Slope and Deflection of the Beam object along the direction specified. + + Parameters + ========== + + dir : string (default : "x") + Direction along which plots are required. + If no direction is specified, plots along x-axis are displayed. + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, A, x) + >>> subs = {E:40, G:21, I:100, A:25} + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.solve_slope_deflection() + >>> b.plot_loading_results('y',subs) + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: -6*x**2 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -15*x**2/2 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: -x**3/1600 + 3*x**2/160 - x/8 for x over (0.0, 20.0) + Plot[3]:Plot object containing: + [0]: cartesian line: x**5/40000 - 4013*x**3/90300 + 26*x**2/43 + 1520*x/903 for x over (0.0, 20.0) + + """ + + dir = dir.lower() + if subs is None: + subs = {} + + ax1 = self._plot_shear_force(dir, subs) + ax2 = self._plot_bending_moment(dir, subs) + ax3 = self._plot_slope(dir, subs) + ax4 = self._plot_deflection(dir, subs) + + return PlotGrid(4, 1, ax1, ax2, ax3, ax4) + + def _plot_shear_stress(self, dir, subs=None): + + shear_stress = self.shear_stress() + + if dir == 'x': + dir_num = 0 + color = 'r' + + elif dir == 'y': + dir_num = 1 + color = 'g' + + elif dir == 'z': + dir_num = 2 + color = 'b' + + if subs is None: + subs = {} + + for sym in shear_stress[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + return plot(shear_stress[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Shear stress along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\tau(%c)$'%dir, line_color=color) + + def plot_shear_stress(self, dir="all", subs=None): + + """ + + Returns a plot for Shear Stress along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which shear stress plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters and area of cross section 2 square + meters. It is supported by rollers at both of its ends. A linear load having + slope equal to 12 is applied along y-axis. A constant distributed load + of magnitude 15 N is applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, 2, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.plot_shear_stress() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -3*x**2 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: -15*x/2 for x over (0.0, 20.0) + + """ + + dir = dir.lower() + # For shear stress along x direction + if dir == "x": + Px = self._plot_shear_stress('x', subs) + return Px.show() + # For shear stress along y direction + elif dir == "y": + Py = self._plot_shear_stress('y', subs) + return Py.show() + # For shear stress along z direction + elif dir == "z": + Pz = self._plot_shear_stress('z', subs) + return Pz.show() + # For shear stress along all direction + else: + Px = self._plot_shear_stress('x', subs) + Py = self._plot_shear_stress('y', subs) + Pz = self._plot_shear_stress('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def _max_shear_force(self, dir): + """ + Helper function for max_shear_force(). + """ + + dir = dir.lower() + + if dir == 'x': + dir_num = 0 + + elif dir == 'y': + dir_num = 1 + + elif dir == 'z': + dir_num = 2 + + if not self.shear_force()[dir_num]: + return (0,0) + # To restrict the range within length of the Beam + load_curve = Piecewise((float("nan"), self.variable<=0), + (self._load_vector[dir_num], self.variable>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.max_shear_force() + [(0, 0), (20, 2400), (20, 300)] + """ + + max_shear = [] + max_shear.append(self._max_shear_force('x')) + max_shear.append(self._max_shear_force('y')) + max_shear.append(self._max_shear_force('z')) + return max_shear + + def _max_bending_moment(self, dir): + """ + Helper function for max_bending_moment(). + """ + + dir = dir.lower() + + if dir == 'x': + dir_num = 0 + + elif dir == 'y': + dir_num = 1 + + elif dir == 'z': + dir_num = 2 + + if not self.bending_moment()[dir_num]: + return (0,0) + # To restrict the range within length of the Beam + shear_curve = Piecewise((float("nan"), self.variable<=0), + (self.shear_force()[dir_num], self.variable>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.max_bending_moment() + [(0, 0), (20, 3000), (20, 16000)] + """ + + max_bmoment = [] + max_bmoment.append(self._max_bending_moment('x')) + max_bmoment.append(self._max_bending_moment('y')) + max_bmoment.append(self._max_bending_moment('z')) + return max_bmoment + + max_bmoment = max_bending_moment + + def _max_deflection(self, dir): + """ + Helper function for max_Deflection() + """ + + dir = dir.lower() + + if dir == 'x': + dir_num = 0 + + elif dir == 'y': + dir_num = 1 + + elif dir == 'z': + dir_num = 2 + + if not self.deflection()[dir_num]: + return (0,0) + # To restrict the range within length of the Beam + slope_curve = Piecewise((float("nan"), self.variable<=0), + (self.slope()[dir_num], self.variable>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.solve_slope_deflection() + >>> b.max_deflection() + [(0, 0), (10, 495/14), (-10 + 10*sqrt(10793)/43, (10 - 10*sqrt(10793)/43)**3/160 - 20/7 + (10 - 10*sqrt(10793)/43)**4/6400 + 20*sqrt(10793)/301 + 27*(10 - 10*sqrt(10793)/43)**2/560)] + """ + + max_def = [] + max_def.append(self._max_deflection('x')) + max_def.append(self._max_deflection('y')) + max_def.append(self._max_deflection('z')) + return max_def diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/cable.py b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/cable.py new file mode 100644 index 0000000000000000000000000000000000000000..e38c6601b0a12cad83bc7e87597e79937f4667a4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/cable.py @@ -0,0 +1,815 @@ +""" +This module can be used to solve problems related +to 2D Cables. +""" + +from sympy.core.sympify import sympify +from sympy.core.symbol import Symbol,symbols +from sympy import sin, cos, pi, atan, diff, Piecewise, solve, rad +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.solvers.solveset import linsolve +from sympy.matrices import Matrix +from sympy.plotting import plot + +class Cable: + """ + Cables are structures in engineering that support + the applied transverse loads through the tensile + resistance developed in its members. + + Cables are widely used in suspension bridges, tension + leg offshore platforms, transmission lines, and find + use in several other engineering applications. + + Examples + ======== + A cable is supported at (0, 10) and (10, 10). Two point loads + acting vertically downwards act on the cable, one with magnitude 3 kN + and acting 2 meters from the left support and 3 meters below it, while + the other with magnitude 2 kN is 6 meters from the left support and + 6 meters below it. + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_load(-1, ('P', 2, 7, 3, 270)) + >>> c.apply_load(-1, ('Q', 6, 4, 2, 270)) + >>> c.loads + {'distributed': {}, 'point_load': {'P': [3, 270], 'Q': [2, 270]}} + >>> c.loads_position + {'P': [2, 7], 'Q': [6, 4]} + """ + def __init__(self, support_1, support_2): + """ + Initializes the class. + + Parameters + ========== + + support_1 and support_2 are tuples of the form + (label, x, y), where + + label : String or symbol + The label of the support + + x : Sympifyable + The x coordinate of the position of the support + + y : Sympifyable + The y coordinate of the position of the support + """ + self._left_support = [] + self._right_support = [] + self._supports = {} + self._support_labels = [] + self._loads = {"distributed": {}, "point_load": {}} + self._loads_position = {} + self._length = 0 + self._reaction_loads = {} + self._tension = {} + self._lowest_x_global = sympify(0) + self._lowest_y_global = sympify(0) + self._cable_eqn = None + self._tension_func = None + if support_1[0] == support_2[0]: + raise ValueError("Supports can not have the same label") + + elif support_1[1] == support_2[1]: + raise ValueError("Supports can not be at the same location") + + x1 = sympify(support_1[1]) + y1 = sympify(support_1[2]) + self._supports[support_1[0]] = [x1, y1] + + x2 = sympify(support_2[1]) + y2 = sympify(support_2[2]) + self._supports[support_2[0]] = [x2, y2] + + if support_1[1] < support_2[1]: + self._left_support.append(x1) + self._left_support.append(y1) + self._right_support.append(x2) + self._right_support.append(y2) + self._support_labels.append(support_1[0]) + self._support_labels.append(support_2[0]) + + else: + self._left_support.append(x2) + self._left_support.append(y2) + self._right_support.append(x1) + self._right_support.append(y1) + self._support_labels.append(support_2[0]) + self._support_labels.append(support_1[0]) + + for i in self._support_labels: + self._reaction_loads[Symbol("R_"+ i +"_x")] = 0 + self._reaction_loads[Symbol("R_"+ i +"_y")] = 0 + + @property + def supports(self): + """ + Returns the supports of the cable along with their + positions. + """ + return self._supports + + @property + def left_support(self): + """ + Returns the position of the left support. + """ + return self._left_support + + @property + def right_support(self): + """ + Returns the position of the right support. + """ + return self._right_support + + @property + def loads(self): + """ + Returns the magnitude and direction of the loads + acting on the cable. + """ + return self._loads + + @property + def loads_position(self): + """ + Returns the position of the point loads acting on the + cable. + """ + return self._loads_position + + @property + def length(self): + """ + Returns the length of the cable. + """ + return self._length + + @property + def reaction_loads(self): + """ + Returns the reaction forces at the supports, which are + initialized to 0. + """ + return self._reaction_loads + + @property + def tension(self): + """ + Returns the tension developed in the cable due to the loads + applied. + """ + return self._tension + + def tension_at(self, x): + """ + Returns the tension at a given value of x developed due to + distributed load. + """ + if 'distributed' not in self._tension.keys(): + raise ValueError("No distributed load added or solve method not called") + + if x > self._right_support[0] or x < self._left_support[0]: + raise ValueError("The value of x should be between the two supports") + + A = self._tension['distributed'] + X = Symbol('X') + + return A.subs({X:(x-self._lowest_x_global)}) + + def apply_length(self, length): + """ + This method specifies the length of the cable + + Parameters + ========== + + length : Sympifyable + The length of the cable + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_length(20) + >>> c.length + 20 + """ + dist = ((self._left_support[0] - self._right_support[0])**2 + - (self._left_support[1] - self._right_support[1])**2)**(1/2) + + if length < dist: + raise ValueError("length should not be less than the distance between the supports") + + self._length = length + + def change_support(self, label, new_support): + """ + This method changes the mentioned support with a new support. + + Parameters + ========== + label: String or symbol + The label of the support to be changed + + new_support: Tuple of the form (new_label, x, y) + new_label: String or symbol + The label of the new support + + x: Sympifyable + The x-coordinate of the position of the new support. + + y: Sympifyable + The y-coordinate of the position of the new support. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.supports + {'A': [0, 10], 'B': [10, 10]} + >>> c.change_support('B', ('C', 5, 6)) + >>> c.supports + {'A': [0, 10], 'C': [5, 6]} + """ + if label not in self._supports: + raise ValueError("No support exists with the given label") + + i = self._support_labels.index(label) + rem_label = self._support_labels[(i+1)%2] + x1 = self._supports[rem_label][0] + y1 = self._supports[rem_label][1] + + x = sympify(new_support[1]) + y = sympify(new_support[2]) + + for l in self._loads_position: + if l[0] >= max(x, x1) or l[0] <= min(x, x1): + raise ValueError("The change in support will throw an existing load out of range") + + self._supports.pop(label) + self._left_support.clear() + self._right_support.clear() + self._reaction_loads.clear() + self._support_labels.remove(label) + + self._supports[new_support[0]] = [x, y] + + if x1 < x: + self._left_support.append(x1) + self._left_support.append(y1) + self._right_support.append(x) + self._right_support.append(y) + self._support_labels.append(new_support[0]) + + else: + self._left_support.append(x) + self._left_support.append(y) + self._right_support.append(x1) + self._right_support.append(y1) + self._support_labels.insert(0, new_support[0]) + + for i in self._support_labels: + self._reaction_loads[Symbol("R_"+ i +"_x")] = 0 + self._reaction_loads[Symbol("R_"+ i +"_y")] = 0 + + def apply_load(self, order, load): + """ + This method adds load to the cable. + + Parameters + ========== + + order : Integer + The order of the applied load. + + - For point loads, order = -1 + - For distributed load, order = 0 + + load : tuple + + * For point loads, load is of the form (label, x, y, magnitude, direction), where: + + label : String or symbol + The label of the load + + x : Sympifyable + The x coordinate of the position of the load + + y : Sympifyable + The y coordinate of the position of the load + + magnitude : Sympifyable + The magnitude of the load. It must always be positive + + direction : Sympifyable + The angle, in degrees, that the load vector makes with the horizontal + in the counter-clockwise direction. It takes the values 0 to 360, + inclusive. + + + * For uniformly distributed load, load is of the form (label, magnitude) + + label : String or symbol + The label of the load + + magnitude : Sympifyable + The magnitude of the load. It must always be positive + + Examples + ======== + + For a point load of magnitude 12 units inclined at 30 degrees with the horizontal: + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_load(-1, ('Z', 5, 5, 12, 30)) + >>> c.loads + {'distributed': {}, 'point_load': {'Z': [12, 30]}} + >>> c.loads_position + {'Z': [5, 5]} + + + For a uniformly distributed load of magnitude 9 units: + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_load(0, ('X', 9)) + >>> c.loads + {'distributed': {'X': 9}, 'point_load': {}} + """ + if order == -1: + if len(self._loads["distributed"]) != 0: + raise ValueError("Distributed load already exists") + + label = load[0] + if label in self._loads["point_load"]: + raise ValueError("Label already exists") + + x = sympify(load[1]) + y = sympify(load[2]) + + if x > self._right_support[0] or x < self._left_support[0]: + raise ValueError("The load should be positioned between the supports") + + magnitude = sympify(load[3]) + direction = sympify(load[4]) + + self._loads["point_load"][label] = [magnitude, direction] + self._loads_position[label] = [x, y] + + elif order == 0: + if len(self._loads_position) != 0: + raise ValueError("Point load(s) already exist") + + label = load[0] + if label in self._loads["distributed"]: + raise ValueError("Label already exists") + + magnitude = sympify(load[1]) + + self._loads["distributed"][label] = magnitude + + else: + raise ValueError("Order should be either -1 or 0") + + def remove_loads(self, *args): + """ + This methods removes the specified loads. + + Parameters + ========== + This input takes multiple label(s) as input + label(s): String or symbol + The label(s) of the loads to be removed. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_load(-1, ('Z', 5, 5, 12, 30)) + >>> c.loads + {'distributed': {}, 'point_load': {'Z': [12, 30]}} + >>> c.remove_loads('Z') + >>> c.loads + {'distributed': {}, 'point_load': {}} + """ + for i in args: + if len(self._loads_position) == 0: + if i not in self._loads['distributed']: + raise ValueError("Error removing load " + i + ": no such load exists") + + else: + self._loads['disrtibuted'].pop(i) + + else: + if i not in self._loads['point_load']: + raise ValueError("Error removing load " + i + ": no such load exists") + + else: + self._loads['point_load'].pop(i) + self._loads_position.pop(i) + + def solve(self, *args): + """ + This method solves for the reaction forces at the supports, the tension developed in + the cable, and updates the length of the cable. + + Parameters + ========== + This method requires no input when solving for point loads + For distributed load, the x and y coordinates of the lowest point of the cable are + required as + + x: Sympifyable + The x coordinate of the lowest point + + y: Sympifyable + The y coordinate of the lowest point + + Examples + ======== + For point loads, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(("A", 0, 10), ("B", 10, 10)) + >>> c.apply_load(-1, ('Z', 2, 7.26, 3, 270)) + >>> c.apply_load(-1, ('X', 4, 6, 8, 270)) + >>> c.solve() + >>> c.tension + {A_Z: 8.91403453669861, X_B: 19*sqrt(13)/10, Z_X: 4.79150773600774} + >>> c.reaction_loads + {R_A_x: -5.25547445255474, R_A_y: 7.2, R_B_x: 5.25547445255474, R_B_y: 3.8} + >>> c.length + 5.7560958484519 + 2*sqrt(13) + + For distributed load, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c=Cable(("A", 0, 40),("B", 100, 20)) + >>> c.apply_load(0, ("X", 850)) + >>> c.solve(58.58) + >>> c.tension + {'distributed': 36465.0*sqrt(0.00054335718671383*X**2 + 1)} + >>> c.tension_at(0) + 61717.4130533677 + >>> c.reaction_loads + {R_A_x: 36465.0, R_A_y: -49793.0, R_B_x: 44399.9537590861, R_B_y: 42868.2071025955} + """ + + if len(self._loads_position) != 0: + sorted_position = sorted(self._loads_position.items(), key = lambda item : item[1][0]) + + sorted_position.append(self._support_labels[1]) + sorted_position.insert(0, self._support_labels[0]) + + self._tension.clear() + moment_sum_from_left_support = 0 + moment_sum_from_right_support = 0 + F_x = 0 + F_y = 0 + self._length = 0 + tension_func = [] + x = symbols('x') + for i in range(1, len(sorted_position)-1): + if i == 1: + self._length+=sqrt((self._left_support[0] - self._loads_position[sorted_position[i][0]][0])**2 + (self._left_support[1] - self._loads_position[sorted_position[i][0]][1])**2) + + else: + self._length+=sqrt((self._loads_position[sorted_position[i-1][0]][0] - self._loads_position[sorted_position[i][0]][0])**2 + (self._loads_position[sorted_position[i-1][0]][1] - self._loads_position[sorted_position[i][0]][1])**2) + + if i == len(sorted_position)-2: + self._length+=sqrt((self._right_support[0] - self._loads_position[sorted_position[i][0]][0])**2 + (self._right_support[1] - self._loads_position[sorted_position[i][0]][1])**2) + + moment_sum_from_left_support += self._loads['point_load'][sorted_position[i][0]][0] * cos(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) * abs(self._left_support[1] - self._loads_position[sorted_position[i][0]][1]) + moment_sum_from_left_support += self._loads['point_load'][sorted_position[i][0]][0] * sin(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) * abs(self._left_support[0] - self._loads_position[sorted_position[i][0]][0]) + + F_x += self._loads['point_load'][sorted_position[i][0]][0] * cos(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) + F_y += self._loads['point_load'][sorted_position[i][0]][0] * sin(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) + + label = Symbol(sorted_position[i][0]+"_"+sorted_position[i+1][0]) + y2 = self._loads_position[sorted_position[i][0]][1] + x2 = self._loads_position[sorted_position[i][0]][0] + y1 = 0 + x1 = 0 + + if i == len(sorted_position)-2: + x1 = self._right_support[0] + y1 = self._right_support[1] + + else: + x1 = self._loads_position[sorted_position[i+1][0]][0] + y1 = self._loads_position[sorted_position[i+1][0]][1] + + angle_with_horizontal = atan((y1 - y2)/(x1 - x2)) + + tension = -(moment_sum_from_left_support)/(abs(self._left_support[1] - self._loads_position[sorted_position[i][0]][1])*cos(angle_with_horizontal) + abs(self._left_support[0] - self._loads_position[sorted_position[i][0]][0])*sin(angle_with_horizontal)) + self._tension[label] = tension + tension_func.append((tension, x<=x1)) + moment_sum_from_right_support += self._loads['point_load'][sorted_position[i][0]][0] * cos(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) * abs(self._right_support[1] - self._loads_position[sorted_position[i][0]][1]) + moment_sum_from_right_support += self._loads['point_load'][sorted_position[i][0]][0] * sin(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) * abs(self._right_support[0] - self._loads_position[sorted_position[i][0]][0]) + + label = Symbol(sorted_position[0][0]+"_"+sorted_position[1][0]) + y2 = self._loads_position[sorted_position[1][0]][1] + x2 = self._loads_position[sorted_position[1][0]][0] + x1 = self._left_support[0] + y1 = self._left_support[1] + + angle_with_horizontal = -atan((y2 - y1)/(x2 - x1)) + tension = -(moment_sum_from_right_support)/(abs(self._right_support[1] - self._loads_position[sorted_position[1][0]][1])*cos(angle_with_horizontal) + abs(self._right_support[0] - self._loads_position[sorted_position[1][0]][0])*sin(angle_with_horizontal)) + self._tension[label] = tension + + tension_func.insert(0,(tension, x<=x2)) + self._tension_func = Piecewise(*tension_func) + angle_with_horizontal = pi/2 - angle_with_horizontal + label = self._support_labels[0] + self._reaction_loads[Symbol("R_"+label+"_x")] = -sin(angle_with_horizontal) * tension + F_x += -sin(angle_with_horizontal) * tension + self._reaction_loads[Symbol("R_"+label+"_y")] = cos(angle_with_horizontal) * tension + F_y += cos(angle_with_horizontal) * tension + + label = self._support_labels[1] + self._reaction_loads[Symbol("R_"+label+"_x")] = -F_x + self._reaction_loads[Symbol("R_"+label+"_y")] = -F_y + + elif len(self._loads['distributed']) != 0 : + + if len(args) == 0: + raise ValueError("Provide the lowest point of the cable") + + lowest_x = sympify(args[0]) + self._lowest_x_global = lowest_x + + a = Symbol('a', positive=True) + c = Symbol('c') + # augmented matrix form of linsolve + + M = Matrix( + [[(self._left_support[0]-lowest_x)**2, 1, self._left_support[1]], + [(self._right_support[0]-lowest_x)**2, 1, self._right_support[1]], + ]) + + coefficient_solution = list(linsolve(M, (a, c))) + if len(coefficient_solution) ==0 or coefficient_solution[0][0]== 0: + raise ValueError("The lowest point is inconsistent with the supports") + + A = coefficient_solution[0][0] + C = coefficient_solution[0][1] + coefficient_solution[0][0]*lowest_x**2 + B = -2*coefficient_solution[0][0]*lowest_x + self._lowest_y_global = coefficient_solution[0][1] + lowest_y = self._lowest_y_global + + # y = A*x**2 + B*x + C + # shifting origin to lowest point + X = Symbol('X') + Y = Symbol('Y') + Y = A*(X + lowest_x)**2 + B*(X + lowest_x) + C - lowest_y + + temp_list = list(self._loads['distributed'].values()) + applied_force = temp_list[0] + + horizontal_force_constant = (applied_force * (self._right_support[0] - lowest_x)**2) / (2 * (self._right_support[1] - lowest_y)) + + self._tension.clear() + tangent_slope_to_curve = diff(Y, X) + self._tension['distributed'] = horizontal_force_constant / (cos(atan(tangent_slope_to_curve))) + + label = self._support_labels[0] + self._reaction_loads[Symbol("R_"+label+"_x")] = self.tension_at(self._left_support[0]) * cos(atan(tangent_slope_to_curve.subs(X, self._left_support[0] - lowest_x))) + self._reaction_loads[Symbol("R_"+label+"_y")] = self.tension_at(self._left_support[0]) * sin(atan(tangent_slope_to_curve.subs(X, self._left_support[0] - lowest_x))) + + label = self._support_labels[1] + self._reaction_loads[Symbol("R_"+label+"_x")] = self.tension_at(self._left_support[0]) * cos(atan(tangent_slope_to_curve.subs(X, self._right_support[0] - lowest_x))) + self._reaction_loads[Symbol("R_"+label+"_y")] = self.tension_at(self._left_support[0]) * sin(atan(tangent_slope_to_curve.subs(X, self._right_support[0] - lowest_x))) + + def draw(self): + """ + This method is used to obtain a plot for the specified cable with its supports, + shape and loads. + + Examples + ======== + + For point loads, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(("A", 0, 10), ("B", 10, 10)) + >>> c.apply_load(-1, ('Z', 2, 7.26, 3, 270)) + >>> c.apply_load(-1, ('X', 4, 6, 8, 270)) + >>> c.solve() + >>> p = c.draw() + >>> p # doctest: +ELLIPSIS + Plot object containing: + [0]: cartesian line: Piecewise((10 - 1.37*x, x <= 2), (8.52 - 0.63*x, x <= 4), (2*x/3 + 10/3, x <= 10)) for x over (0.0, 10.0) + ... + >>> p.show() + + For uniformly distributed loads, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c=Cable(("A", 0, 40),("B", 100, 20)) + >>> c.apply_load(0, ("X", 850)) + >>> c.solve(58.58) + >>> p = c.draw() + >>> p # doctest: +ELLIPSIS + Plot object containing: + [0]: cartesian line: 0.0116550116550117*(x - 58.58)**2 + 0.00447086247086247 for x over (0.0, 100.0) + [1]: cartesian line: -7.49552913752915 for x over (0.0, 100.0) + ... + >>> p.show() + """ + x = Symbol("x") + annotations = [] + support_rectangles = self._draw_supports() + + xy_min = min(self._left_support[0],self._lowest_y_global) + xy_max = max(self._right_support[0], max(self._right_support[1],self._left_support[1])) + max_diff = xy_max - xy_min + if len(self._loads_position) != 0: + self._cable_eqn = self._draw_cable(-1) + annotations += self._draw_loads(-1) + + elif len(self._loads['distributed']) != 0 : + self._cable_eqn = self._draw_cable(0) + annotations += self._draw_loads(0) + + if not self._cable_eqn: + raise ValueError("solve method not called and/or values provided for loads and supports not adequate") + + cab_plot = plot(*self._cable_eqn,(x,self._left_support[0],self._right_support[0]), + xlim=(xy_min-0.5*max_diff,xy_max+0.5*max_diff), + ylim=(xy_min-0.5*max_diff,xy_max+0.5*max_diff), + rectangles=support_rectangles,show= False,annotations=annotations, axis=False) + + return cab_plot + + def _draw_supports(self): + member_rectangles = [] + xy_min = min(self._left_support[0],self._lowest_y_global) + xy_max = max(self._right_support[0], max(self._right_support[1],self._left_support[1])) + max_diff = xy_max - xy_min + + supp_width = 0.075*max_diff + + member_rectangles.append( + { + 'xy': (self._left_support[0]-supp_width,self._left_support[1]), + 'width': supp_width, + 'height':supp_width, + 'color':'brown', + 'fill': False + } + ) + + member_rectangles.append( + { + 'xy': (self._right_support[0],self._right_support[1]), + 'width': supp_width, + 'height':supp_width, + 'color':'brown', + 'fill': False + } + ) + + return member_rectangles + + def _draw_cable(self,order): + xy_min = min(self._left_support[0],self._lowest_y_global) + xy_max = max(self._right_support[0], max(self._right_support[1],self._left_support[1])) + max_diff = xy_max - xy_min + if order == -1 : + x,y = symbols('x y') + line_func = [] + sorted_position = sorted(self._loads_position.items(), key = lambda item : item[1][0]) + + for i in range(len(sorted_position)): + if(i==0): + y = ((sorted_position[i][1][1] - self._left_support[1])*(x-self._left_support[0]))/(sorted_position[i][1][0]- self._left_support[0]) + self._left_support[1] + else: + y = ((sorted_position[i][1][1] - sorted_position[i-1][1][1] )*(x-sorted_position[i-1][1][0]))/(sorted_position[i][1][0]- sorted_position[i-1][1][0]) + sorted_position[i-1][1][1] + line_func.append((y,x<=sorted_position[i][1][0])) + + y = ((sorted_position[len(sorted_position)-1][1][1] - self._right_support[1])*(x-self._right_support[0]))/(sorted_position[i][1][0]- self._right_support[0]) + self._right_support[1] + line_func.append((y,x<=self._right_support[0])) + return [Piecewise(*line_func)] + + elif order == 0: + x0 = self._lowest_x_global + diff_force_height = max_diff*0.075 + + a,c,x,y = symbols('a c x y') + parabola_eqn = a*(x-x0)**2 + c - y + + points = [(self._left_support[0],self._left_support[1]),(self._right_support[0],self._right_support[1])] + equations = [] + for px, py in points: + equations.append(parabola_eqn.subs({x: px, y: py})) + solution = solve(equations, (a, c)) + parabola_eqn = solution[a]*(x-x0)**2 + solution[c] + return [parabola_eqn, self._lowest_y_global - diff_force_height] + + def _draw_loads(self,order): + xy_min = min(self._left_support[0],self._lowest_y_global) + xy_max = max(self._right_support[0], max(self._right_support[1],self._left_support[1])) + max_diff = xy_max - xy_min + if(order==-1): + arrow_length = max_diff*0.1 + force_arrows = [] + for key in self._loads['point_load']: + force_arrows.append( + { + 'text': '', + 'xy':(self._loads_position[key][0]+arrow_length*cos(rad(self._loads['point_load'][key][1])),\ + self._loads_position[key][1] + arrow_length*sin(rad(self._loads['point_load'][key][1]))), + 'xytext': (self._loads_position[key][0],self._loads_position[key][1]), + 'arrowprops': {'width': 1, 'headlength':3, 'headwidth':3 , 'facecolor': 'black', } + } + ) + mag = self._loads['point_load'][key][0] + force_arrows.append( + { + 'text':f'{mag}N', + 'xy': (self._loads_position[key][0]+arrow_length*1.6*cos(rad(self._loads['point_load'][key][1])),\ + self._loads_position[key][1] + arrow_length*1.6*sin(rad(self._loads['point_load'][key][1]))), + } + ) + return force_arrows + + elif (order == 0): + x = symbols('x') + force_arrows = [] + x_val = [self._left_support[0] + ((self._right_support[0]-self._left_support[0])/10)*i for i in range(1,10)] + for i in x_val: + force_arrows.append( + { + 'text':'', + 'xytext':( + i, + self._cable_eqn[0].subs(x,i) + ), + 'xy':( + i, + self._cable_eqn[1].subs(x,i) + ), + 'arrowprops':{'width':1, 'headlength':3.5, 'headwidth':3.5, 'facecolor':'black'} + } + ) + mag = 0 + for key in self._loads['distributed']: + mag += self._loads['distributed'][key] + + force_arrows.append( + { + 'text':f'{mag} N/m', + 'xy':((self._left_support[0]+self._right_support[0])/2,self._lowest_y_global - max_diff*0.15) + } + ) + return force_arrows + + def plot_tension(self): + """ + Returns the diagram/plot of the tension generated in the cable at various points. + + Examples + ======== + + For point loads, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(("A", 0, 10), ("B", 10, 10)) + >>> c.apply_load(-1, ('Z', 2, 7.26, 3, 270)) + >>> c.apply_load(-1, ('X', 4, 6, 8, 270)) + >>> c.solve() + >>> p = c.plot_tension() + >>> p + Plot object containing: + [0]: cartesian line: Piecewise((8.91403453669861, x <= 2), (4.79150773600774, x <= 4), (19*sqrt(13)/10, x <= 10)) for x over (0.0, 10.0) + >>> p.show() + + For uniformly distributed loads, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c=Cable(("A", 0, 40),("B", 100, 20)) + >>> c.apply_load(0, ("X", 850)) + >>> c.solve(58.58) + >>> p = c.plot_tension() + >>> p + Plot object containing: + [0]: cartesian line: 36465.0*sqrt(0.00054335718671383*X**2 + 1) for X over (0.0, 100.0) + >>> p.show() + + """ + if len(self._loads_position) != 0: + x = symbols('x') + tension_plot = plot(self._tension_func, (x,self._left_support[0],self._right_support[0]), show=False) + else: + X = symbols('X') + tension_plot = plot(self._tension['distributed'], (X,self._left_support[0],self._right_support[0]), show=False) + return tension_plot diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__init__.py b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca7b9f41dcaa678710787632952ad149f06e56c9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/test_arch.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/test_arch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..feb3feea13ae9bc5a0358362d9b12d05b3c16619 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/test_arch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/test_beam.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/test_beam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e2be3db8628e8febdc29014a399682a09fb3d0e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/test_beam.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/test_cable.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/test_cable.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b4df6b22cff78261c6416b2a529340221dfe7eb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/test_cable.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/test_truss.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/test_truss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97f79cf4a04a806adf684435a2f02d0a55d4e05a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/__pycache__/test_truss.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/test_arch.py b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/test_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..3d77062702222d7381a89450e8230b52bac4028c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/test_arch.py @@ -0,0 +1,61 @@ +from sympy.physics.continuum_mechanics.arch import Arch +from sympy import Symbol, simplify + +x = Symbol('x') +t = Symbol('t') + +def test_arch_init(): + a = Arch((0,0),(10,0),crown_x=5,crown_y=5) + assert a.get_loads == {'distributed': {}, 'concentrated': {}} + assert a.reaction_force == {Symbol('R_A_x'):0, Symbol('R_A_y'):0, Symbol('R_B_x'):0, Symbol('R_B_y'):0} + assert a.supports == {'left':'hinge', 'right':'hinge'} + assert a.left_support == (0,0) + assert a.right_support == (10,0) + assert a.get_shape_eqn == 5 - ((x-5)**2)/5 + + a = Arch((0,0),(10,1),crown_x=6) + a.change_support_type(left_support='roller') + a.add_member(0.5) + assert a.supports == {'left':'roller', 'right':'hinge'} + assert simplify(a.get_shape_eqn) == simplify(9/5 - (x - 6)**2/20) + +def test_arch_support(): + a = Arch((0,0),(40,0),crown_x=20,crown_y=12) + a.apply_load(-1,'C',8,150,angle=270) + a.apply_load(0,'D',start=20,end=40,mag=-4) + a.solve() + assert abs(a.reaction_force[Symbol("R_A_x")] - 83.33333333333333) < 10e-12 + assert abs(a.reaction_force[Symbol("R_B_y")] - 90.00000000000000) < 10e-12 + assert abs(a.reaction_force[Symbol("R_B_x")] + 83.33333333333333) < 10e-12 + assert abs(a.reaction_force[Symbol("R_A_y")] - 140.00000000000000) < 10e-12 + +def test_arch_member(): + a = Arch((0,0),(40,0),crown_x=20,crown_y=15) + a.change_support_type(right_support='roller') + a.add_member(0) + a.apply_load(-1,'D',start=12,mag=3,angle=270) + a.apply_load(-1,'E',start=6,mag=4,angle=270) + a.apply_load(-1,'C',start=30,mag=5,angle=270) + a.solve() + assert a.reaction_force[Symbol("R_A_x")] == 0 + assert abs(a.reaction_force[Symbol("R_A_y")] - 6.750000000000000) < 10e-12 + assert a.reaction_force[Symbol("R_B_x")] == 0 + assert abs(a.reaction_force[Symbol("R_B_y")] - 5.250000000000000) < 10e-12 + +def test_symbol_magnitude(): + a = Arch((0,0),(16,0),crown_x=8,crown_y=5) + a.apply_load(0,'C',start=3,end=5,mag=t) + a.solve() + assert a.reaction_force[Symbol("R_A_x")] == -(4*t)/5 + assert a.reaction_force[Symbol("R_A_y")] == -(3*t)/2 + assert a.reaction_force[Symbol("R_B_x")] == (4*t)/5 + assert a.reaction_force[Symbol("R_B_y")] == -t/2 + assert a.bending_moment_at(4) == -5*t/2 + +def test_forces(): + a = Arch((0,0),(40,0),crown_x=20,crown_y=12) + a.apply_load(-1,'C',8,150,angle=270) + a.apply_load(0,'D',start=20,end=40,mag=-4) + a.solve() + assert abs(a.axial_force_at(7.999999999999999)-149.430523405935) < 1e-12 + assert abs(a.shear_force_at(7.999999999999999)-64.9227473161196) < 1e-12 diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/test_beam.py b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/test_beam.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a36fb030f99d9d384e52d4a239351688c7626b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/test_beam.py @@ -0,0 +1,1118 @@ +from sympy.core.function import expand +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.sets.sets import Interval +from sympy.simplify.simplify import simplify +from sympy.physics.continuum_mechanics.beam import Beam +from sympy.functions import SingularityFunction, Piecewise, meijerg, Abs, log +from sympy.testing.pytest import raises +from sympy.physics.units import meter, newton, kilo, giga, milli +from sympy.physics.continuum_mechanics.beam import Beam3D +from sympy.geometry import Circle, Polygon, Point2D, Triangle +from sympy.core.sympify import sympify + +x = Symbol('x') +y = Symbol('y') +R1, R2 = symbols('R1, R2') + + +def test_Beam(): + E = Symbol('E') + E_1 = Symbol('E_1') + I = Symbol('I') + I_1 = Symbol('I_1') + A = Symbol('A') + + b = Beam(1, E, I) + assert b.length == 1 + assert b.elastic_modulus == E + assert b.second_moment == I + assert b.variable == x + + # Test the length setter + b.length = 4 + assert b.length == 4 + + # Test the E setter + b.elastic_modulus = E_1 + assert b.elastic_modulus == E_1 + + # Test the I setter + b.second_moment = I_1 + assert b.second_moment is I_1 + + # Test the variable setter + b.variable = y + assert b.variable is y + + # Test for all boundary conditions. + b.bc_deflection = [(0, 2)] + b.bc_slope = [(0, 1)] + b.bc_bending_moment = [(0, 5)] + b.bc_shear_force = [(2, 1)] + assert b.boundary_conditions == {'deflection': [(0, 2)], 'slope': [(0, 1)], + 'bending_moment': [(0, 5)], 'shear_force': [(2, 1)]} + + # Test for shear force boundary condition method + b.bc_shear_force.extend([(1, 1), (2, 3)]) + sf_bcs = b.bc_shear_force + assert sf_bcs == [(2, 1), (1, 1), (2, 3)] + + # Test for slope boundary condition method + b.bc_bending_moment.extend([(1, 3), (5, 3)]) + bm_bcs = b.bc_bending_moment + assert bm_bcs == [(0, 5), (1, 3), (5, 3)] + + # Test for slope boundary condition method + b.bc_slope.extend([(4, 3), (5, 0)]) + s_bcs = b.bc_slope + assert s_bcs == [(0, 1), (4, 3), (5, 0)] + + # Test for deflection boundary condition method + b.bc_deflection.extend([(4, 3), (5, 0)]) + d_bcs = b.bc_deflection + assert d_bcs == [(0, 2), (4, 3), (5, 0)] + + # Test for updated boundary conditions + bcs_new = b.boundary_conditions + assert bcs_new == { + 'deflection': [(0, 2), (4, 3), (5, 0)], + 'slope': [(0, 1), (4, 3), (5, 0)], + 'bending_moment': [(0, 5), (1, 3), (5, 3)], + 'shear_force': [(2, 1), (1, 1), (2, 3)]} + + b1 = Beam(30, E, I) + b1.apply_load(-8, 0, -1) + b1.apply_load(R1, 10, -1) + b1.apply_load(R2, 30, -1) + b1.apply_load(120, 30, -2) + b1.bc_deflection = [(10, 0), (30, 0)] + b1.solve_for_reaction_loads(R1, R2) + + # Test for finding reaction forces + p = b1.reaction_loads + q = {R1: 6, R2: 2} + assert p == q + + # Test for load distribution function. + p = b1.load + q = -8*SingularityFunction(x, 0, -1) + 6*SingularityFunction(x, 10, -1) \ + + 120*SingularityFunction(x, 30, -2) + 2*SingularityFunction(x, 30, -1) + assert p == q + + # Test for shear force distribution function + p = b1.shear_force() + q = 8*SingularityFunction(x, 0, 0) - 6*SingularityFunction(x, 10, 0) \ + - 120*SingularityFunction(x, 30, -1) - 2*SingularityFunction(x, 30, 0) + assert p == q + + # Test for shear stress distribution function + p = b1.shear_stress() + q = (8*SingularityFunction(x, 0, 0) - 6*SingularityFunction(x, 10, 0) \ + - 120*SingularityFunction(x, 30, -1) \ + - 2*SingularityFunction(x, 30, 0))/A + assert p==q + + # Test for bending moment distribution function + p = b1.bending_moment() + q = 8*SingularityFunction(x, 0, 1) - 6*SingularityFunction(x, 10, 1) \ + - 120*SingularityFunction(x, 30, 0) - 2*SingularityFunction(x, 30, 1) + assert p == q + + # Test for slope distribution function + p = b1.slope() + q = -4*SingularityFunction(x, 0, 2) + 3*SingularityFunction(x, 10, 2) \ + + 120*SingularityFunction(x, 30, 1) + SingularityFunction(x, 30, 2) \ + + Rational(4000, 3) + assert p == q/(E*I) + + # Test for deflection distribution function + p = b1.deflection() + q = x*Rational(4000, 3) - 4*SingularityFunction(x, 0, 3)/3 \ + + SingularityFunction(x, 10, 3) + 60*SingularityFunction(x, 30, 2) \ + + SingularityFunction(x, 30, 3)/3 - 12000 + assert p == q/(E*I) + + # Test using symbols + l = Symbol('l') + w0 = Symbol('w0') + w2 = Symbol('w2') + a1 = Symbol('a1') + c = Symbol('c') + c1 = Symbol('c1') + d = Symbol('d') + e = Symbol('e') + f = Symbol('f') + + b2 = Beam(l, E, I) + + b2.apply_load(w0, a1, 1) + b2.apply_load(w2, c1, -1) + + b2.bc_deflection = [(c, d)] + b2.bc_slope = [(e, f)] + + # Test for load distribution function. + p = b2.load + q = w0*SingularityFunction(x, a1, 1) + w2*SingularityFunction(x, c1, -1) + assert p == q + + # Test for shear force distribution function + p = b2.shear_force() + q = -w0*SingularityFunction(x, a1, 2)/2 \ + - w2*SingularityFunction(x, c1, 0) + assert p == q + + # Test for shear stress distribution function + p = b2.shear_stress() + q = (-w0*SingularityFunction(x, a1, 2)/2 \ + - w2*SingularityFunction(x, c1, 0))/A + assert p == q + + # Test for bending moment distribution function + p = b2.bending_moment() + q = -w0*SingularityFunction(x, a1, 3)/6 - w2*SingularityFunction(x, c1, 1) + assert p == q + + # Test for slope distribution function + p = b2.slope() + q = (w0*SingularityFunction(x, a1, 4)/24 + w2*SingularityFunction(x, c1, 2)/2)/(E*I) + (E*I*f - w0*SingularityFunction(e, a1, 4)/24 - w2*SingularityFunction(e, c1, 2)/2)/(E*I) + assert expand(p) == expand(q) + + # Test for deflection distribution function + p = b2.deflection() + q = x*(E*I*f - w0*SingularityFunction(e, a1, 4)/24 \ + - w2*SingularityFunction(e, c1, 2)/2)/(E*I) \ + + (w0*SingularityFunction(x, a1, 5)/120 \ + + w2*SingularityFunction(x, c1, 3)/6)/(E*I) \ + + (E*I*(-c*f + d) + c*w0*SingularityFunction(e, a1, 4)/24 \ + + c*w2*SingularityFunction(e, c1, 2)/2 \ + - w0*SingularityFunction(c, a1, 5)/120 \ + - w2*SingularityFunction(c, c1, 3)/6)/(E*I) + assert simplify(p - q) == 0 + + b3 = Beam(9, E, I, 2) + b3.apply_load(value=-2, start=2, order=2, end=3) + b3.bc_slope.append((0, 2)) + C3 = symbols('C3') + C4 = symbols('C4') + + p = b3.load + q = -2*SingularityFunction(x, 2, 2) + 2*SingularityFunction(x, 3, 0) \ + + 4*SingularityFunction(x, 3, 1) + 2*SingularityFunction(x, 3, 2) + assert p == q + + p = b3.shear_force() + q = 2*SingularityFunction(x, 2, 3)/3 - 2*SingularityFunction(x, 3, 1) \ + - 2*SingularityFunction(x, 3, 2) - 2*SingularityFunction(x, 3, 3)/3 + assert p == q + + p = b3.shear_stress() + q = SingularityFunction(x, 2, 3)/3 - 1*SingularityFunction(x, 3, 1) \ + - 1*SingularityFunction(x, 3, 2) - 1*SingularityFunction(x, 3, 3)/3 + assert p == q + + p = b3.slope() + q = 2 - (SingularityFunction(x, 2, 5)/30 - SingularityFunction(x, 3, 3)/3 \ + - SingularityFunction(x, 3, 4)/6 - SingularityFunction(x, 3, 5)/30)/(E*I) + assert p == q + + p = b3.deflection() + q = 2*x - (SingularityFunction(x, 2, 6)/180 \ + - SingularityFunction(x, 3, 4)/12 - SingularityFunction(x, 3, 5)/30 \ + - SingularityFunction(x, 3, 6)/180)/(E*I) + assert p == q + C4 + + b4 = Beam(4, E, I, 3) + b4.apply_load(-3, 0, 0, end=3) + + p = b4.load + q = -3*SingularityFunction(x, 0, 0) + 3*SingularityFunction(x, 3, 0) + assert p == q + + p = b4.shear_force() + q = 3*SingularityFunction(x, 0, 1) \ + - 3*SingularityFunction(x, 3, 1) + assert p == q + + p = b4.shear_stress() + q = SingularityFunction(x, 0, 1) - SingularityFunction(x, 3, 1) + assert p == q + + p = b4.slope() + q = -3*SingularityFunction(x, 0, 3)/6 + 3*SingularityFunction(x, 3, 3)/6 + assert p == q/(E*I) + C3 + + p = b4.deflection() + q = -3*SingularityFunction(x, 0, 4)/24 + 3*SingularityFunction(x, 3, 4)/24 + assert p == q/(E*I) + C3*x + C4 + + # can't use end with point loads + raises(ValueError, lambda: b4.apply_load(-3, 0, -1, end=3)) + with raises(TypeError): + b4.variable = 1 + + +def test_insufficient_bconditions(): + # Test cases when required number of boundary conditions + # are not provided to solve the integration constants. + L = symbols('L', positive=True) + E, I, P, a3, a4 = symbols('E I P a3 a4') + + b = Beam(L, E, I, base_char='a') + b.apply_load(R2, L, -1) + b.apply_load(R1, 0, -1) + b.apply_load(-P, L/2, -1) + b.solve_for_reaction_loads(R1, R2) + + p = b.slope() + q = P*SingularityFunction(x, 0, 2)/4 - P*SingularityFunction(x, L/2, 2)/2 + P*SingularityFunction(x, L, 2)/4 + assert p == q/(E*I) + a3 + + p = b.deflection() + q = P*SingularityFunction(x, 0, 3)/12 - P*SingularityFunction(x, L/2, 3)/6 + P*SingularityFunction(x, L, 3)/12 + assert p == q/(E*I) + a3*x + a4 + + b.bc_deflection = [(0, 0)] + p = b.deflection() + q = a3*x + P*SingularityFunction(x, 0, 3)/12 - P*SingularityFunction(x, L/2, 3)/6 + P*SingularityFunction(x, L, 3)/12 + assert p == q/(E*I) + + b.bc_deflection = [(0, 0), (L, 0)] + p = b.deflection() + q = -L**2*P*x/16 + P*SingularityFunction(x, 0, 3)/12 - P*SingularityFunction(x, L/2, 3)/6 + P*SingularityFunction(x, L, 3)/12 + assert p == q/(E*I) + + +def test_statically_indeterminate(): + E = Symbol('E') + I = Symbol('I') + M1, M2 = symbols('M1, M2') + F = Symbol('F') + l = Symbol('l', positive=True) + + b5 = Beam(l, E, I) + b5.bc_deflection = [(0, 0),(l, 0)] + b5.bc_slope = [(0, 0),(l, 0)] + + b5.apply_load(R1, 0, -1) + b5.apply_load(M1, 0, -2) + b5.apply_load(R2, l, -1) + b5.apply_load(M2, l, -2) + b5.apply_load(-F, l/2, -1) + + b5.solve_for_reaction_loads(R1, R2, M1, M2) + p = b5.reaction_loads + q = {R1: F/2, R2: F/2, M1: -F*l/8, M2: F*l/8} + assert p == q + + +def test_beam_units(): + E = Symbol('E') + I = Symbol('I') + R1, R2 = symbols('R1, R2') + + kN = kilo*newton + gN = giga*newton + + b = Beam(8*meter, 200*gN/meter**2, 400*1000000*(milli*meter)**4) + b.apply_load(5*kN, 2*meter, -1) + b.apply_load(R1, 0*meter, -1) + b.apply_load(R2, 8*meter, -1) + b.apply_load(10*kN/meter, 4*meter, 0, end=8*meter) + b.bc_deflection = [(0*meter, 0*meter), (8*meter, 0*meter)] + b.solve_for_reaction_loads(R1, R2) + assert b.reaction_loads == {R1: -13750*newton, R2: -31250*newton} + + b = Beam(3*meter, E*newton/meter**2, I*meter**4) + b.apply_load(8*kN, 1*meter, -1) + b.apply_load(R1, 0*meter, -1) + b.apply_load(R2, 3*meter, -1) + b.apply_load(12*kN*meter, 2*meter, -2) + b.bc_deflection = [(0*meter, 0*meter), (3*meter, 0*meter)] + b.solve_for_reaction_loads(R1, R2) + assert b.reaction_loads == {R1: newton*Rational(-28000, 3), R2: newton*Rational(4000, 3)} + assert b.deflection().subs(x, 1*meter) == 62000*meter/(9*E*I) + + +def test_variable_moment(): + E = Symbol('E') + I = Symbol('I') + + b = Beam(4, E, 2*(4 - x)) + b.apply_load(20, 4, -1) + R, M = symbols('R, M') + b.apply_load(R, 0, -1) + b.apply_load(M, 0, -2) + b.bc_deflection = [(0, 0)] + b.bc_slope = [(0, 0)] + b.solve_for_reaction_loads(R, M) + assert b.slope().expand() == ((10*x*SingularityFunction(x, 0, 0) + - 10*(x - 4)*SingularityFunction(x, 4, 0))/E).expand() + assert b.deflection().expand() == ((5*x**2*SingularityFunction(x, 0, 0) + - 10*Piecewise((0, Abs(x)/4 < 1), (x**2*meijerg(((-1, 1), ()), ((), (-2, 0)), x/4), True)) + + 40*SingularityFunction(x, 4, 1))/E).expand() + + b = Beam(4, E - x, I) + b.apply_load(20, 4, -1) + R, M = symbols('R, M') + b.apply_load(R, 0, -1) + b.apply_load(M, 0, -2) + b.bc_deflection = [(0, 0)] + b.bc_slope = [(0, 0)] + b.solve_for_reaction_loads(R, M) + assert b.slope().expand() == ((-80*(-log(-E) + log(-E + x))*SingularityFunction(x, 0, 0) + + 80*(-log(-E + 4) + log(-E + x))*SingularityFunction(x, 4, 0) + 20*(-E*log(-E) + + E*log(-E + x) + x)*SingularityFunction(x, 0, 0) - 20*(-E*log(-E + 4) + E*log(-E + x) + + x - 4)*SingularityFunction(x, 4, 0))/I).expand() + + +def test_composite_beam(): + E = Symbol('E') + I = Symbol('I') + b1 = Beam(2, E, 1.5*I) + b2 = Beam(2, E, I) + b = b1.join(b2, "fixed") + b.apply_load(-20, 0, -1) + b.apply_load(80, 0, -2) + b.apply_load(20, 4, -1) + b.bc_slope = [(0, 0)] + b.bc_deflection = [(0, 0)] + assert b.length == 4 + assert b.second_moment == Piecewise((1.5*I, x <= 2), (I, x <= 4)) + assert b.slope().subs(x, 4) == 120.0/(E*I) + assert b.slope().subs(x, 2) == 80.0/(E*I) + assert int(b.deflection().subs(x, 4).args[0]) == -302 # Coefficient of 1/(E*I) + + l = symbols('l', positive=True) + R1, M1, R2, R3, P = symbols('R1 M1 R2 R3 P') + b1 = Beam(2*l, E, I) + b2 = Beam(2*l, E, I) + b = b1.join(b2,"hinge") + b.apply_load(M1, 0, -2) + b.apply_load(R1, 0, -1) + b.apply_load(R2, l, -1) + b.apply_load(R3, 4*l, -1) + b.apply_load(P, 3*l, -1) + b.bc_slope = [(0, 0)] + b.bc_deflection = [(0, 0), (l, 0), (4*l, 0)] + b.solve_for_reaction_loads(M1, R1, R2, R3) + assert b.reaction_loads == {R3: -P/2, R2: P*Rational(-5, 4), M1: -P*l/4, R1: P*Rational(3, 4)} + assert b.slope().subs(x, 3*l) == -7*P*l**2/(48*E*I) + assert b.deflection().subs(x, 2*l) == 7*P*l**3/(24*E*I) + assert b.deflection().subs(x, 3*l) == 5*P*l**3/(16*E*I) + + # When beams having same second moment are joined. + b1 = Beam(2, 500, 10) + b2 = Beam(2, 500, 10) + b = b1.join(b2, "fixed") + b.apply_load(M1, 0, -2) + b.apply_load(R1, 0, -1) + b.apply_load(R2, 1, -1) + b.apply_load(R3, 4, -1) + b.apply_load(10, 3, -1) + b.bc_slope = [(0, 0)] + b.bc_deflection = [(0, 0), (1, 0), (4, 0)] + b.solve_for_reaction_loads(M1, R1, R2, R3) + assert b.slope() == -2*SingularityFunction(x, 0, 1)/5625 + SingularityFunction(x, 0, 2)/1875\ + - 133*SingularityFunction(x, 1, 2)/135000 + SingularityFunction(x, 3, 2)/1000\ + - 37*SingularityFunction(x, 4, 2)/67500 + assert b.deflection() == -SingularityFunction(x, 0, 2)/5625 + SingularityFunction(x, 0, 3)/5625\ + - 133*SingularityFunction(x, 1, 3)/405000 + SingularityFunction(x, 3, 3)/3000\ + - 37*SingularityFunction(x, 4, 3)/202500 + + +def test_point_cflexure(): + E = Symbol('E') + I = Symbol('I') + b = Beam(10, E, I) + b.apply_load(-4, 0, -1) + b.apply_load(-46, 6, -1) + b.apply_load(10, 2, -1) + b.apply_load(20, 4, -1) + b.apply_load(3, 6, 0) + assert b.point_cflexure() == [Rational(10, 3)] + + E = Symbol('E') + I = Symbol('I') + b = Beam(15, E, I) + r0 = b.apply_support(0, type='pin') + r10 = b.apply_support(10, type='pin') + r15, m15 = b.apply_support(15, type='fixed') + b.apply_rotation_hinge(12) + b.apply_load(-10, 5, -1) + b.apply_load(-5, 10, 0, 15) + b.solve_for_reaction_loads(r0, r10, r15, m15) + assert b.point_cflexure() == [Rational(1200, 163), 12, Rational(163, 12)] + + E = Symbol('E') + I = Symbol('I') + b = Beam(15, E, I) + r0 = b.apply_support(0, type='pin') + r10 = b.apply_support(10, type='pin') + r15, m15 = b.apply_support(15, type='fixed') + b.apply_rotation_hinge(5) + b.apply_rotation_hinge(12) + b.apply_load(-10, 5, -1) + b.apply_load(-5, 10, 0, 15) + b.solve_for_reaction_loads(r0, r10, r15, m15) + with raises(NotImplementedError): + b.point_cflexure() + +def test_remove_load(): + E = Symbol('E') + I = Symbol('I') + b = Beam(4, E, I) + + try: + b.remove_load(2, 1, -1) + # As no load is applied on beam, ValueError should be returned. + except ValueError: + assert True + else: + assert False + + b.apply_load(-3, 0, -2) + b.apply_load(4, 2, -1) + b.apply_load(-2, 2, 2, end = 3) + b.remove_load(-2, 2, 2, end = 3) + assert b.load == -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) + assert b.applied_loads == [(-3, 0, -2, None), (4, 2, -1, None)] + + try: + b.remove_load(1, 2, -1) + # As load of this magnitude was never applied at + # this position, method should return a ValueError. + except ValueError: + assert True + else: + assert False + + b.remove_load(-3, 0, -2) + b.remove_load(4, 2, -1) + assert b.load == 0 + assert b.applied_loads == [] + + +def test_apply_support(): + E = Symbol('E') + I = Symbol('I') + + b = Beam(4, E, I) + b.apply_support(0, "cantilever") + b.apply_load(20, 4, -1) + M_0, R_0 = symbols('M_0, R_0') + b.solve_for_reaction_loads(R_0, M_0) + assert simplify(b.slope()) == simplify((80*SingularityFunction(x, 0, 1) - 10*SingularityFunction(x, 0, 2) + + 10*SingularityFunction(x, 4, 2))/(E*I)) + assert simplify(b.deflection()) == simplify((40*SingularityFunction(x, 0, 2) - 10*SingularityFunction(x, 0, 3)/3 + + 10*SingularityFunction(x, 4, 3)/3)/(E*I)) + + b = Beam(30, E, I) + p0 = b.apply_support(10, "pin") + p1 = b.apply_support(30, "roller") + b.apply_load(-8, 0, -1) + b.apply_load(120, 30, -2) + b.solve_for_reaction_loads(p0, p1) + assert b.slope() == (-4*SingularityFunction(x, 0, 2) + 3*SingularityFunction(x, 10, 2) + + 120*SingularityFunction(x, 30, 1) + SingularityFunction(x, 30, 2) + Rational(4000, 3))/(E*I) + assert b.deflection() == (x*Rational(4000, 3) - 4*SingularityFunction(x, 0, 3)/3 + SingularityFunction(x, 10, 3) + + 60*SingularityFunction(x, 30, 2) + SingularityFunction(x, 30, 3)/3 - 12000)/(E*I) + R_10 = Symbol('R_10') + R_30 = Symbol('R_30') + assert p0 == R_10 + assert b.reaction_loads == {R_10: 6, R_30: 2} + assert b.reaction_loads[p0] == 6 + + b = Beam(8, E, I) + p0, m0 = b.apply_support(0, "fixed") + p1 = b.apply_support(8, "roller") + b.apply_load(-5, 0, 0, 8) + b.solve_for_reaction_loads(p0, m0, p1) + R_0 = Symbol('R_0') + M_0 = Symbol('M_0') + R_8 = Symbol('R_8') + assert p0 == R_0 + assert m0 == M_0 + assert p1 == R_8 + assert b.reaction_loads == {R_0: 25, M_0: -40, R_8: 15} + assert b.reaction_loads[m0] == -40 + + P = Symbol('P', positive=True) + L = Symbol('L', positive=True) + b = Beam(L, E, I) + b.apply_support(0, type='fixed') + b.apply_support(L, type='fixed') + b.apply_load(-P, L/2, -1) + R_0, R_L, M_0, M_L = symbols('R_0, R_L, M_0, M_L') + b.solve_for_reaction_loads(R_0, R_L, M_0, M_L) + assert b.reaction_loads == {R_0: P/2, R_L: P/2, M_0: -L*P/8, M_L: L*P/8} + +def test_apply_rotation_hinge(): + b = Beam(15, 20, 20) + r0, m0 = b.apply_support(0, type='fixed') + r10 = b.apply_support(10, type='pin') + r15 = b.apply_support(15, type='pin') + p7 = b.apply_rotation_hinge(7) + p12 = b.apply_rotation_hinge(12) + b.apply_load(-10, 7, -1) + b.apply_load(-2, 10, 0, 15) + b.solve_for_reaction_loads(r0, m0, r10, r15) + R_0, M_0, R_10, R_15, P_7, P_12 = symbols('R_0, M_0, R_10, R_15, P_7, P_12') + expected_reactions = {R_0: 20/3, M_0: -140/3, R_10: 31/3, R_15: 3} + expected_rotations = {P_7: 2281/2160, P_12: -5137/5184} + reaction_symbols = [r0, m0, r10, r15] + rotation_symbols = [p7, p12] + tolerance = 1e-6 + assert all(abs(b.reaction_loads[r] - expected_reactions[r]) < tolerance for r in reaction_symbols) + assert all(abs(b.rotation_jumps[r] - expected_rotations[r]) < tolerance for r in rotation_symbols) + expected_bending_moment = (140 * SingularityFunction(x, 0, 0) / 3 - 20 * SingularityFunction(x, 0, 1) / 3 + - 11405 * SingularityFunction(x, 7, -1) / 27 + 10 * SingularityFunction(x, 7, 1) + - 31 * SingularityFunction(x, 10, 1) / 3 + SingularityFunction(x, 10, 2) + + 128425 * SingularityFunction(x, 12, -1) / 324 - 3 * SingularityFunction(x, 15, 1) + - SingularityFunction(x, 15, 2)) + assert b.bending_moment().expand() == expected_bending_moment.expand() + expected_slope = (-7*SingularityFunction(x, 0, 1)/60 + SingularityFunction(x, 0, 2)/120 + + 2281*SingularityFunction(x, 7, 0)/2160 - SingularityFunction(x, 7, 2)/80 + + 31*SingularityFunction(x, 10, 2)/2400 - SingularityFunction(x, 10, 3)/1200 + - 5137*SingularityFunction(x, 12, 0)/5184 + 3*SingularityFunction(x, 15, 2)/800 + + SingularityFunction(x, 15, 3)/1200) + assert b.slope().expand() == expected_slope.expand() + expected_deflection = (-7 * SingularityFunction(x, 0, 2) / 120 + SingularityFunction(x, 0, 3) / 360 + + 2281 * SingularityFunction(x, 7, 1) / 2160 - SingularityFunction(x, 7, 3) / 240 + + 31 * SingularityFunction(x, 10, 3) / 7200 - SingularityFunction(x, 10, 4) / 4800 + - 5137 * SingularityFunction(x, 12, 1) / 5184 + SingularityFunction(x, 15, 3) / 800 + + SingularityFunction(x, 15, 4) / 4800) + assert b.deflection().expand() == expected_deflection.expand() + + E = Symbol('E') + I = Symbol('I') + F = Symbol('F') + b = Beam(10, E, I) + r0, m0 = b.apply_support(0, type="fixed") + r10 = b.apply_support(10, type="pin") + b.apply_rotation_hinge(6) + b.apply_load(F, 8, -1) + b.solve_for_reaction_loads(r0, m0, r10) + assert b.reaction_loads == {R_0: -F/2, M_0: 3*F, R_10: -F/2} + assert (b.bending_moment() == -3*F*SingularityFunction(x, 0, 0) + F*SingularityFunction(x, 0, 1)/2 + + 17*F*SingularityFunction(x, 6, -1) - F*SingularityFunction(x, 8, 1) + + F*SingularityFunction(x, 10, 1)/2) + expected_deflection = -(-3*F*SingularityFunction(x, 0, 2)/2 + F*SingularityFunction(x, 0, 3)/12 + + 17*F*SingularityFunction(x, 6, 1) - F*SingularityFunction(x, 8, 3)/6 + + F*SingularityFunction(x, 10, 3)/12)/(E*I) + assert b.deflection().expand() == expected_deflection.expand() + + E = Symbol('E') + I = Symbol('I') + F = Symbol('F') + l1 = Symbol('l1', positive=True) + l2 = Symbol('l2', positive=True) + l3 = Symbol('l3', positive=True) + L = l1 + l2 + l3 + b = Beam(L, E, I) + r0, m0 = b.apply_support(0, type="fixed") + r1 = b.apply_support(L, type="pin") + b.apply_rotation_hinge(l1) + b.apply_load(F, l1+l2, -1) + b.solve_for_reaction_loads(r0, m0, r1) + assert b.reaction_loads[r0] == -F*l3/(l2 + l3) + assert b.reaction_loads[m0] == F*l1*l3/(l2 + l3) + assert b.reaction_loads[r1] == -F*l2/(l2 + l3) + expected_bending_moment = (-F*l1*l3*SingularityFunction(x, 0, 0)/(l2 + l3) + + F*l2*SingularityFunction(x, l1 + l2 + l3, 1)/(l2 + l3) + + F*l3*SingularityFunction(x, 0, 1)/(l2 + l3) - F*SingularityFunction(x, l1 + l2, 1) + - (-2*F*l1**3*l3 - 3*F*l1**2*l2*l3 - 3*F*l1**2*l3**2 + F*l2**3*l3 + 3*F*l2**2*l3**2 + 2*F*l2*l3**3) + *SingularityFunction(x, l1, -1)/(6*l2**2 + 12*l2*l3 + 6*l3**2)) + assert simplify(b.bending_moment().expand()) == simplify(expected_bending_moment.expand()) + +def test_apply_sliding_hinge(): + b = Beam(13, 20, 20) + r0, m0 = b.apply_support(0, type="fixed") + w8 = b.apply_sliding_hinge(8) + r13 = b.apply_support(13, type="pin") + b.apply_load(-10, 5, -1) + b.solve_for_reaction_loads(r0, m0, r13) + R_0, M_0, R_13, W_8 = symbols('R_0, M_0, R_13 W_8') + assert b.reaction_loads == {R_0: 10, M_0: -50, R_13: 0} + tolerance = 1e-6 + assert abs(b.deflection_jumps[w8] - 85/24) < tolerance + assert (b.bending_moment() == 50*SingularityFunction(x, 0, 0) - 10*SingularityFunction(x, 0, 1) + + 10*SingularityFunction(x, 5, 1) - 4250*SingularityFunction(x, 8, -2)/3) + assert (b.deflection() == -SingularityFunction(x, 0, 2)/16 + SingularityFunction(x, 0, 3)/240 + - SingularityFunction(x, 5, 3)/240 + 85*SingularityFunction(x, 8, 0)/24) + + E = Symbol('E') + I = Symbol('I') + I2 = Symbol('I2') + b1 = Beam(5, E, I) + b2 = Beam(8, E, I2) + b = b1.join(b2) + r0, m0 = b.apply_support(0, type="fixed") + b.apply_sliding_hinge(8) + r13 = b.apply_support(13, type="pin") + b.apply_load(-10, 5, -1) + b.solve_for_reaction_loads(r0, m0, r13) + W_8 = Symbol('W_8') + assert b.deflection_jumps == {W_8: 4250/(3*E*I2)} + + E = Symbol('E') + I = Symbol('I') + q = Symbol('q') + l1 = Symbol('l1', positive=True) + l2 = Symbol('l2', positive=True) + l3 = Symbol('l3', positive=True) + L = l1 + l2 + l3 + b = Beam(L, E, I) + r0 = b.apply_support(0, type="pin") + r3 = b.apply_support(l1, type="pin") + b.apply_sliding_hinge(l1 + l2) + r10 = b.apply_support(L, type="pin") + b.apply_load(q, 0, 0, l1) + b.solve_for_reaction_loads(r0, r3, r10) + assert (b.bending_moment() == l1*q*SingularityFunction(x, 0, 1)/2 + l1*q*SingularityFunction(x, l1, 1)/2 + - q*SingularityFunction(x, 0, 2)/2 + q*SingularityFunction(x, l1, 2)/2 + + (-l1**3*l2*q/24 - l1**3*l3*q/24)*SingularityFunction(x, l1 + l2, -2)) + assert b.deflection() ==(l1**3*q*x/24 - l1*q*SingularityFunction(x, 0, 3)/12 + - l1*q*SingularityFunction(x, l1, 3)/12 + q*SingularityFunction(x, 0, 4)/24 + - q*SingularityFunction(x, l1, 4)/24 + + (l1**3*l2*q/24 + l1**3*l3*q/24)*SingularityFunction(x, l1 + l2, 0))/(E*I) + +def test_max_shear_force(): + E = Symbol('E') + I = Symbol('I') + + b = Beam(3, E, I) + R, M = symbols('R, M') + b.apply_load(R, 0, -1) + b.apply_load(M, 0, -2) + b.apply_load(2, 3, -1) + b.apply_load(4, 2, -1) + b.apply_load(2, 2, 0, end=3) + b.solve_for_reaction_loads(R, M) + assert b.max_shear_force() == (Interval(0, 2), 8) + + l = symbols('l', positive=True) + P = Symbol('P') + b = Beam(l, E, I) + R1, R2 = symbols('R1, R2') + b.apply_load(R1, 0, -1) + b.apply_load(R2, l, -1) + b.apply_load(P, 0, 0, end=l) + b.solve_for_reaction_loads(R1, R2) + max_shear = b.max_shear_force() + assert max_shear[0] == 0 + assert simplify(max_shear[1] - (l*Abs(P)/2)) == 0 + + +def test_max_bmoment(): + E = Symbol('E') + I = Symbol('I') + l, P = symbols('l, P', positive=True) + + b = Beam(l, E, I) + R1, R2 = symbols('R1, R2') + b.apply_load(R1, 0, -1) + b.apply_load(R2, l, -1) + b.apply_load(P, l/2, -1) + b.solve_for_reaction_loads(R1, R2) + b.reaction_loads + assert b.max_bmoment() == (l/2, P*l/4) + + b = Beam(l, E, I) + R1, R2 = symbols('R1, R2') + b.apply_load(R1, 0, -1) + b.apply_load(R2, l, -1) + b.apply_load(P, 0, 0, end=l) + b.solve_for_reaction_loads(R1, R2) + assert b.max_bmoment() == (l/2, P*l**2/8) + + +def test_max_deflection(): + E, I, l, F = symbols('E, I, l, F', positive=True) + b = Beam(l, E, I) + b.bc_deflection = [(0, 0),(l, 0)] + b.bc_slope = [(0, 0),(l, 0)] + b.apply_load(F/2, 0, -1) + b.apply_load(-F*l/8, 0, -2) + b.apply_load(F/2, l, -1) + b.apply_load(F*l/8, l, -2) + b.apply_load(-F, l/2, -1) + assert b.max_deflection() == (l/2, F*l**3/(192*E*I)) + +def test_solve_for_ild_reactions(): + E = Symbol('E') + I = Symbol('I') + b = Beam(10, E, I) + b.apply_support(0, type="pin") + b.apply_support(10, type="pin") + R_0, R_10 = symbols('R_0, R_10') + b.solve_for_ild_reactions(1, R_0, R_10) + a = b.ild_variable + assert b.ild_reactions == {R_0: -SingularityFunction(a, 0, 0) + SingularityFunction(a, 0, 1)/10 + - SingularityFunction(a, 10, 1)/10, + R_10: -SingularityFunction(a, 0, 1)/10 + SingularityFunction(a, 10, 0) + + SingularityFunction(a, 10, 1)/10} + + E = Symbol('E') + I = Symbol('I') + F = Symbol('F') + L = Symbol('L', positive=True) + b = Beam(L, E, I) + b.apply_support(L, type="fixed") + b.apply_load(F, 0, -1) + R_L, M_L = symbols('R_L, M_L') + b.solve_for_ild_reactions(F, R_L, M_L) + a = b.ild_variable + assert b.ild_reactions == {R_L: -F*SingularityFunction(a, 0, 0) + F*SingularityFunction(a, L, 0) - F, + M_L: -F*L*SingularityFunction(a, 0, 0) - F*L + F*SingularityFunction(a, 0, 1) + - F*SingularityFunction(a, L, 1)} + + E = Symbol('E') + I = Symbol('I') + b = Beam(20, E, I) + r0 = b.apply_support(0, type="pin") + r5 = b.apply_support(5, type="pin") + r10 = b.apply_support(10, type="pin") + r20, m20 = b.apply_support(20, type="fixed") + b.solve_for_ild_reactions(1, r0, r5, r10, r20, m20) + a = b.ild_variable + assert b.ild_reactions[r0].subs(a, 4) == -Rational(59, 475) + assert b.ild_reactions[r5].subs(a, 4) == -Rational(2296, 2375) + assert b.ild_reactions[r10].subs(a, 4) == Rational(243, 2375) + assert b.ild_reactions[r20].subs(a, 12) == -Rational(83, 475) + assert b.ild_reactions[m20].subs(a, 12) == -Rational(264, 475) + +def test_solve_for_ild_shear(): + E = Symbol('E') + I = Symbol('I') + F = Symbol('F') + L1 = Symbol('L1', positive=True) + L2 = Symbol('L2', positive=True) + b = Beam(L1 + L2, E, I) + r0 = b.apply_support(0, type="pin") + rL = b.apply_support(L1 + L2, type="pin") + b.solve_for_ild_reactions(F, r0, rL) + b.solve_for_ild_shear(L1, F, r0, rL) + a = b.ild_variable + expected_shear = (-F*L1*SingularityFunction(a, 0, 0)/(L1 + L2) - F*L2*SingularityFunction(a, 0, 0)/(L1 + L2) + - F*SingularityFunction(-a, 0, 0) + F*SingularityFunction(a, L1 + L2, 0) + F + + F*SingularityFunction(a, 0, 1)/(L1 + L2) - F*SingularityFunction(a, L1 + L2, 1)/(L1 + L2) + - (-F*L1*SingularityFunction(a, 0, 0)/(L1 + L2) + F*L1*SingularityFunction(a, L1 + L2, 0)/(L1 + L2) + - F*L2*SingularityFunction(a, 0, 0)/(L1 + L2) + F*L2*SingularityFunction(a, L1 + L2, 0)/(L1 + L2) + + 2*F)*SingularityFunction(a, L1, 0)) + assert b.ild_shear.expand() == expected_shear.expand() + + E = Symbol('E') + I = Symbol('I') + b = Beam(20, E, I) + r0 = b.apply_support(0, type="pin") + r5 = b.apply_support(5, type="pin") + r10 = b.apply_support(10, type="pin") + r20, m20 = b.apply_support(20, type="fixed") + b.solve_for_ild_reactions(1, r0, r5, r10, r20, m20) + b.solve_for_ild_shear(6, 1, r0, r5, r10, r20, m20) + a = b.ild_variable + assert b.ild_shear.subs(a, 12) == Rational(96, 475) + assert b.ild_shear.subs(a, 4) == -Rational(216, 2375) + +def test_solve_for_ild_moment(): + E = Symbol('E') + I = Symbol('I') + F = Symbol('F') + L1 = Symbol('L1', positive=True) + L2 = Symbol('L2', positive=True) + b = Beam(L1 + L2, E, I) + r0 = b.apply_support(0, type="pin") + rL = b.apply_support(L1 + L2, type="pin") + a = b.ild_variable + b.solve_for_ild_reactions(F, r0, rL) + b.solve_for_ild_moment(L1, F, r0, rL) + assert b.ild_moment.subs(a, 3).subs(L1, 5).subs(L2, 5) == -3*F/2 + + E = Symbol('E') + I = Symbol('I') + b = Beam(20, E, I) + r0 = b.apply_support(0, type="pin") + r5 = b.apply_support(5, type="pin") + r10 = b.apply_support(10, type="pin") + r20, m20 = b.apply_support(20, type="fixed") + b.solve_for_ild_reactions(1, r0, r5, r10, r20, m20) + b.solve_for_ild_moment(5, 1, r0, r5, r10, r20, m20) + assert b.ild_moment.subs(a, 12) == -Rational(96, 475) + assert b.ild_moment.subs(a, 4) == Rational(36, 95) + +def test_ild_with_rotation_hinge(): + E = Symbol('E') + I = Symbol('I') + F = Symbol('F') + L1 = Symbol('L1', positive=True) + L2 = Symbol('L2', positive=True) + L3 = Symbol('L3', positive=True) + b = Beam(L1 + L2 + L3, E, I) + r0 = b.apply_support(0, type="pin") + r1 = b.apply_support(L1 + L2, type="pin") + r2 = b.apply_support(L1 + L2 + L3, type="pin") + b.apply_rotation_hinge(L1 + L2) + b.solve_for_ild_reactions(F, r0, r1, r2) + a = b.ild_variable + assert b.ild_reactions[r0].subs(a, 4).subs(L1, 5).subs(L2, 5).subs(L3, 10) == -3*F/5 + assert b.ild_reactions[r0].subs(a, -10).subs(L1, 5).subs(L2, 5).subs(L3, 10) == 0 + assert b.ild_reactions[r0].subs(a, 25).subs(L1, 5).subs(L2, 5).subs(L3, 10) == 0 + assert b.ild_reactions[r1].subs(a, 4).subs(L1, 5).subs(L2, 5).subs(L3, 10) == -2*F/5 + assert b.ild_reactions[r2].subs(a, 18).subs(L1, 5).subs(L2, 5).subs(L3, 10) == -4*F/5 + b.solve_for_ild_shear(L1, F, r0, r1, r2) + assert b.ild_shear.subs(a, 7).subs(L1, 5).subs(L2, 5).subs(L3, 10) == -3*F/10 + assert b.ild_shear.subs(a, 70).subs(L1, 5).subs(L2, 5).subs(L3, 10) == 0 + b.solve_for_ild_moment(L1, F, r0, r1, r2) + assert b.ild_moment.subs(a, 1).subs(L1, 5).subs(L2, 5).subs(L3, 10) == -F/2 + assert b.ild_moment.subs(a, 8).subs(L1, 5).subs(L2, 5).subs(L3, 10) == -F + +def test_ild_with_sliding_hinge(): + b = Beam(13, 200, 200) + r0 = b.apply_support(0, type="pin") + r6 = b.apply_support(6, type="pin") + r13, m13 = b.apply_support(13, type="fixed") + w3 = b.apply_sliding_hinge(3) + b.solve_for_ild_reactions(1, r0, r6, r13, m13) + a = b.ild_variable + assert b.ild_reactions[r0].subs(a, 3) == -1 + assert b.ild_reactions[r6].subs(a, 3) == Rational(9, 14) + assert b.ild_reactions[r13].subs(a, 9) == -Rational(207, 343) + assert b.ild_reactions[m13].subs(a, 9) == -Rational(60, 49) + assert b.ild_reactions[m13].subs(a, 15) == 0 + assert b.ild_reactions[m13].subs(a, -3) == 0 + assert b.ild_deflection_jumps[w3].subs(a, 9) == -Rational(9, 35000) + b.solve_for_ild_shear(7, 1, r0, r6, r13, m13) + assert b.ild_shear.subs(a, 8) == -Rational(200, 343) + b.solve_for_ild_moment(8, 1, r0, r6, r13, m13) + assert b.ild_moment.subs(a, 3) == -Rational(12, 7) + +def test_Beam3D(): + l, E, G, I, A = symbols('l, E, G, I, A') + R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + + b = Beam3D(l, E, G, I, A) + m, q = symbols('m, q') + b.apply_load(q, 0, 0, dir="y") + b.apply_moment_load(m, 0, 0, dir="z") + b.bc_slope = [(0, [0, 0, 0]), (l, [0, 0, 0])] + b.bc_deflection = [(0, [0, 0, 0]), (l, [0, 0, 0])] + b.solve_slope_deflection() + + assert b.polar_moment() == 2*I + assert b.shear_force() == [0, -q*x, 0] + assert b.shear_stress() == [0, -q*x/A, 0] + assert b.axial_stress() == 0 + assert b.bending_moment() == [0, 0, -m*x + q*x**2/2] + expected_deflection = (x*(A*G*q*x**3/4 + A*G*x**2*(-l*(A*G*l*(l*q - 2*m) + + 12*E*I*q)/(A*G*l**2 + 12*E*I)/2 - m) + 3*E*I*l*(A*G*l*(l*q - 2*m) + + 12*E*I*q)/(A*G*l**2 + 12*E*I) + x*(-A*G*l**2*q/2 + + 3*A*G*l**2*(A*G*l*(l*q - 2*m) + 12*E*I*q)/(A*G*l**2 + 12*E*I)/4 + + A*G*l*m*Rational(3, 2) - 3*E*I*q))/(6*A*E*G*I)) + dx, dy, dz = b.deflection() + assert dx == dz == 0 + assert simplify(dy - expected_deflection) == 0 + + b2 = Beam3D(30, E, G, I, A, x) + b2.apply_load(50, start=0, order=0, dir="y") + b2.bc_deflection = [(0, [0, 0, 0]), (30, [0, 0, 0])] + b2.apply_load(R1, start=0, order=-1, dir="y") + b2.apply_load(R2, start=30, order=-1, dir="y") + b2.solve_for_reaction_loads(R1, R2) + assert b2.reaction_loads == {R1: -750, R2: -750} + + b2.solve_slope_deflection() + assert b2.slope() == [0, 0, 25*x**3/(3*E*I) - 375*x**2/(E*I) + 3750*x/(E*I)] + expected_deflection = 25*x**4/(12*E*I) - 125*x**3/(E*I) + 1875*x**2/(E*I) - \ + 25*x**2/(A*G) + 750*x/(A*G) + dx, dy, dz = b2.deflection() + assert dx == dz == 0 + assert dy == expected_deflection + + # Test for solve_for_reaction_loads + b3 = Beam3D(30, E, G, I, A, x) + b3.apply_load(8, start=0, order=0, dir="y") + b3.apply_load(9*x, start=0, order=0, dir="z") + b3.apply_load(R1, start=0, order=-1, dir="y") + b3.apply_load(R2, start=30, order=-1, dir="y") + b3.apply_load(R3, start=0, order=-1, dir="z") + b3.apply_load(R4, start=30, order=-1, dir="z") + b3.solve_for_reaction_loads(R1, R2, R3, R4) + assert b3.reaction_loads == {R1: -120, R2: -120, R3: -1350, R4: -2700} + + +def test_polar_moment_Beam3D(): + l, E, G, A, I1, I2 = symbols('l, E, G, A, I1, I2') + I = [I1, I2] + + b = Beam3D(l, E, G, I, A) + assert b.polar_moment() == I1 + I2 + + +def test_parabolic_loads(): + + E, I, L = symbols('E, I, L', positive=True, real=True) + R, M, P = symbols('R, M, P', real=True) + + # cantilever beam fixed at x=0 and parabolic distributed loading across + # length of beam + beam = Beam(L, E, I) + + beam.bc_deflection.append((0, 0)) + beam.bc_slope.append((0, 0)) + beam.apply_load(R, 0, -1) + beam.apply_load(M, 0, -2) + + # parabolic load + beam.apply_load(1, 0, 2) + + beam.solve_for_reaction_loads(R, M) + + assert beam.reaction_loads[R] == -L**3/3 + + # cantilever beam fixed at x=0 and parabolic distributed loading across + # first half of beam + beam = Beam(2*L, E, I) + + beam.bc_deflection.append((0, 0)) + beam.bc_slope.append((0, 0)) + beam.apply_load(R, 0, -1) + beam.apply_load(M, 0, -2) + + # parabolic load from x=0 to x=L + beam.apply_load(1, 0, 2, end=L) + + beam.solve_for_reaction_loads(R, M) + + # result should be the same as the prior example + assert beam.reaction_loads[R] == -L**3/3 + + # check constant load + beam = Beam(2*L, E, I) + beam.apply_load(P, 0, 0, end=L) + loading = beam.load.xreplace({L: 10, E: 20, I: 30, P: 40}) + assert loading.xreplace({x: 5}) == 40 + assert loading.xreplace({x: 15}) == 0 + + # check ramp load + beam = Beam(2*L, E, I) + beam.apply_load(P, 0, 1, end=L) + assert beam.load == (P*SingularityFunction(x, 0, 1) - + P*SingularityFunction(x, L, 1) - + P*L*SingularityFunction(x, L, 0)) + + # check higher order load: x**8 load from x=0 to x=L + beam = Beam(2*L, E, I) + beam.apply_load(P, 0, 8, end=L) + loading = beam.load.xreplace({L: 10, E: 20, I: 30, P: 40}) + assert loading.xreplace({x: 5}) == 40*5**8 + assert loading.xreplace({x: 15}) == 0 + + +def test_cross_section(): + I = Symbol('I') + l = Symbol('l') + E = Symbol('E') + C3, C4 = symbols('C3, C4') + a, c, g, h, r, n = symbols('a, c, g, h, r, n') + + # test for second_moment and cross_section setter + b0 = Beam(l, E, I) + assert b0.second_moment == I + assert b0.cross_section == None + b0.cross_section = Circle((0, 0), 5) + assert b0.second_moment == pi*Rational(625, 4) + assert b0.cross_section == Circle((0, 0), 5) + b0.second_moment = 2*n - 6 + assert b0.second_moment == 2*n-6 + assert b0.cross_section == None + with raises(ValueError): + b0.second_moment = Circle((0, 0), 5) + + # beam with a circular cross-section + b1 = Beam(50, E, Circle((0, 0), r)) + assert b1.cross_section == Circle((0, 0), r) + assert b1.second_moment == pi*r*Abs(r)**3/4 + + b1.apply_load(-10, 0, -1) + b1.apply_load(R1, 5, -1) + b1.apply_load(R2, 50, -1) + b1.apply_load(90, 45, -2) + b1.solve_for_reaction_loads(R1, R2) + assert b1.load == (-10*SingularityFunction(x, 0, -1) + 82*SingularityFunction(x, 5, -1)/S(9) + + 90*SingularityFunction(x, 45, -2) + 8*SingularityFunction(x, 50, -1)/9) + assert b1.bending_moment() == (10*SingularityFunction(x, 0, 1) - 82*SingularityFunction(x, 5, 1)/9 + - 90*SingularityFunction(x, 45, 0) - 8*SingularityFunction(x, 50, 1)/9) + q = (-5*SingularityFunction(x, 0, 2) + 41*SingularityFunction(x, 5, 2)/S(9) + + 90*SingularityFunction(x, 45, 1) + 4*SingularityFunction(x, 50, 2)/S(9))/(pi*E*r*Abs(r)**3) + assert b1.slope() == C3 + 4*q + q = (-5*SingularityFunction(x, 0, 3)/3 + 41*SingularityFunction(x, 5, 3)/27 + 45*SingularityFunction(x, 45, 2) + + 4*SingularityFunction(x, 50, 3)/27)/(pi*E*r*Abs(r)**3) + assert b1.deflection() == C3*x + C4 + 4*q + + # beam with a recatangular cross-section + b2 = Beam(20, E, Polygon((0, 0), (a, 0), (a, c), (0, c))) + assert b2.cross_section == Polygon((0, 0), (a, 0), (a, c), (0, c)) + assert b2.second_moment == a*c**3/12 + # beam with a triangular cross-section + b3 = Beam(15, E, Triangle((0, 0), (g, 0), (g/2, h))) + assert b3.cross_section == Triangle(Point2D(0, 0), Point2D(g, 0), Point2D(g/2, h)) + assert b3.second_moment == g*h**3/36 + + # composite beam + b = b2.join(b3, "fixed") + b.apply_load(-30, 0, -1) + b.apply_load(65, 0, -2) + b.apply_load(40, 0, -1) + b.bc_slope = [(0, 0)] + b.bc_deflection = [(0, 0)] + + assert b.second_moment == Piecewise((a*c**3/12, x <= 20), (g*h**3/36, x <= 35)) + assert b.cross_section == None + assert b.length == 35 + assert b.slope().subs(x, 7) == 8400/(E*a*c**3) + assert b.slope().subs(x, 25) == 52200/(E*g*h**3) + 39600/(E*a*c**3) + assert b.deflection().subs(x, 30) == -537000/(E*g*h**3) - 712000/(E*a*c**3) + +def test_max_shear_force_Beam3D(): + x = symbols('x') + b = Beam3D(20, 40, 21, 100, 25) + b.apply_load(15, start=0, order=0, dir="z") + b.apply_load(12*x, start=0, order=0, dir="y") + b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + assert b.max_shear_force() == [(0, 0), (20, 2400), (20, 300)] + +def test_max_bending_moment_Beam3D(): + x = symbols('x') + b = Beam3D(20, 40, 21, 100, 25) + b.apply_load(15, start=0, order=0, dir="z") + b.apply_load(12*x, start=0, order=0, dir="y") + b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + assert b.max_bmoment() == [(0, 0), (20, 3000), (20, 16000)] + +def test_max_deflection_Beam3D(): + x = symbols('x') + b = Beam3D(20, 40, 21, 100, 25) + b.apply_load(15, start=0, order=0, dir="z") + b.apply_load(12*x, start=0, order=0, dir="y") + b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + b.solve_slope_deflection() + c = sympify("495/14") + p = sympify("-10 + 10*sqrt(10793)/43") + q = sympify("(10 - 10*sqrt(10793)/43)**3/160 - 20/7 + (10 - 10*sqrt(10793)/43)**4/6400 + 20*sqrt(10793)/301 + 27*(10 - 10*sqrt(10793)/43)**2/560") + assert b.max_deflection() == [(0, 0), (10, c), (p, q)] + +def test_torsion_Beam3D(): + x = symbols('x') + b = Beam3D(20, 40, 21, 100, 25) + b.apply_moment_load(15, 5, -2, dir='x') + b.apply_moment_load(25, 10, -2, dir='x') + b.apply_moment_load(-5, 20, -2, dir='x') + b.solve_for_torsion() + assert b.angular_deflection().subs(x, 3) == sympify("1/40") + assert b.angular_deflection().subs(x, 9) == sympify("17/280") + assert b.angular_deflection().subs(x, 12) == sympify("53/840") + assert b.angular_deflection().subs(x, 17) == sympify("2/35") + assert b.angular_deflection().subs(x, 20) == sympify("3/56") diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/test_cable.py b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/test_cable.py new file mode 100644 index 0000000000000000000000000000000000000000..95ae7997af20f31cbd1b36df4a494f66968ecf53 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/test_cable.py @@ -0,0 +1,83 @@ +from sympy.physics.continuum_mechanics.cable import Cable +from sympy.core.symbol import Symbol + + +def test_cable(): + c = Cable(('A', 0, 10), ('B', 10, 10)) + assert c.supports == {'A': [0, 10], 'B': [10, 10]} + assert c.left_support == [0, 10] + assert c.right_support == [10, 10] + assert c.loads == {'distributed': {}, 'point_load': {}} + assert c.loads_position == {} + assert c.length == 0 + assert c.reaction_loads == {Symbol("R_A_x"): 0, Symbol("R_A_y"): 0, Symbol("R_B_x"): 0, Symbol("R_B_y"): 0} + + # tests for change_support method + c.change_support('A', ('C', 12, 3)) + assert c.supports == {'B': [10, 10], 'C': [12, 3]} + assert c.left_support == [10, 10] + assert c.right_support == [12, 3] + assert c.reaction_loads == {Symbol("R_B_x"): 0, Symbol("R_B_y"): 0, Symbol("R_C_x"): 0, Symbol("R_C_y"): 0} + + c.change_support('C', ('A', 0, 10)) + + # tests for apply_load method for point loads + c.apply_load(-1, ('X', 2, 5, 3, 30)) + c.apply_load(-1, ('Y', 5, 8, 5, 60)) + assert c.loads == {'distributed': {}, 'point_load': {'X': [3, 30], 'Y': [5, 60]}} + assert c.loads_position == {'X': [2, 5], 'Y': [5, 8]} + assert c.length == 0 + assert c.reaction_loads == {Symbol("R_A_x"): 0, Symbol("R_A_y"): 0, Symbol("R_B_x"): 0, Symbol("R_B_y"): 0} + + # tests for remove_loads method + c.remove_loads('X') + assert c.loads == {'distributed': {}, 'point_load': {'Y': [5, 60]}} + assert c.loads_position == {'Y': [5, 8]} + assert c.length == 0 + assert c.reaction_loads == {Symbol("R_A_x"): 0, Symbol("R_A_y"): 0, Symbol("R_B_x"): 0, Symbol("R_B_y"): 0} + + c.remove_loads('Y') + + #tests for apply_load method for distributed load + c.apply_load(0, ('Z', 9)) + assert c.loads == {'distributed': {'Z': 9}, 'point_load': {}} + assert c.loads_position == {} + assert c.length == 0 + assert c.reaction_loads == {Symbol("R_A_x"): 0, Symbol("R_A_y"): 0, Symbol("R_B_x"): 0, Symbol("R_B_y"): 0} + + # tests for apply_length method + c.apply_length(20) + assert c.length == 20 + + del c + # tests for solve method + # for point loads + c = Cable(("A", 0, 10), ("B", 5.5, 8)) + c.apply_load(-1, ('Z', 2, 7.26, 3, 270)) + c.apply_load(-1, ('X', 4, 6, 8, 270)) + c.solve() + #assert c.tension == {Symbol("Z_X"): 4.79150773600774, Symbol("X_B"): 6.78571428571429, Symbol("A_Z"): 6.89488895397307} + assert abs(c.tension[Symbol("A_Z")] - 6.89488895397307) < 10e-12 + assert abs(c.tension[Symbol("Z_X")] - 4.79150773600774) < 10e-12 + assert abs(c.tension[Symbol("X_B")] - 6.78571428571429) < 10e-12 + #assert c.reaction_loads == {Symbol("R_A_x"): -4.06504065040650, Symbol("R_A_y"): 5.56910569105691, Symbol("R_B_x"): 4.06504065040650, Symbol("R_B_y"): 5.43089430894309} + assert abs(c.reaction_loads[Symbol("R_A_x")] + 4.06504065040650) < 10e-12 + assert abs(c.reaction_loads[Symbol("R_A_y")] - 5.56910569105691) < 10e-12 + assert abs(c.reaction_loads[Symbol("R_B_x")] - 4.06504065040650) < 10e-12 + assert abs(c.reaction_loads[Symbol("R_B_y")] - 5.43089430894309) < 10e-12 + assert abs(c.length - 8.25609584845190) < 10e-12 + + del c + # tests for solve method + # for distributed loads + c=Cable(("A", 0, 40),("B", 100, 20)) + c.apply_load(0, ("X", 850)) + c.solve(58.58, 0) + + # assert c.tension['distributed'] == 36456.8485*sqrt(0.000543529004799705*(X + 0.00135624381275735)**2 + 1) + assert abs(c.tension_at(0) - 61717.4130533677) < 10e-11 + assert abs(c.tension_at(40) - 39738.0809048449) < 10e-11 + assert abs(c.reaction_loads[Symbol("R_A_x")] - 36465.0000000000) < 10e-11 + assert abs(c.reaction_loads[Symbol("R_A_y")] + 49793.0000000000) < 10e-11 + assert abs(c.reaction_loads[Symbol("R_B_x")] - 44399.9537590861) < 10e-11 + assert abs(c.reaction_loads[Symbol("R_B_y")] - 42868.2071025955 ) < 10e-11 diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/test_truss.py b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/test_truss.py new file mode 100644 index 0000000000000000000000000000000000000000..61c89c9e09386257c7c69909dfdb0f37cda8627d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/tests/test_truss.py @@ -0,0 +1,100 @@ +from sympy.core.symbol import Symbol, symbols +from sympy.physics.continuum_mechanics.truss import Truss +from sympy import sqrt + + +def test_truss(): + A = Symbol('A') + B = Symbol('B') + C = Symbol('C') + AB, BC, AC = symbols('AB, BC, AC') + P = Symbol('P') + + t = Truss() + assert t.nodes == [] + assert t.node_labels == [] + assert t.node_positions == [] + assert t.members == {} + assert t.loads == {} + assert t.supports == {} + assert t.reaction_loads == {} + assert t.internal_forces == {} + + # testing the add_node method + t.add_node((A, 0, 0), (B, 2, 2), (C, 3, 0)) + assert t.nodes == [(A, 0, 0), (B, 2, 2), (C, 3, 0)] + assert t.node_labels == [A, B, C] + assert t.node_positions == [(0, 0), (2, 2), (3, 0)] + assert t.loads == {} + assert t.supports == {} + assert t.reaction_loads == {} + + # testing the remove_node method + t.remove_node(C) + assert t.nodes == [(A, 0, 0), (B, 2, 2)] + assert t.node_labels == [A, B] + assert t.node_positions == [(0, 0), (2, 2)] + assert t.loads == {} + assert t.supports == {} + + t.add_node((C, 3, 0)) + + # testing the add_member method + t.add_member((AB, A, B), (BC, B, C), (AC, A, C)) + assert t.members == {AB: [A, B], BC: [B, C], AC: [A, C]} + assert t.internal_forces == {AB: 0, BC: 0, AC: 0} + + # testing the remove_member method + t.remove_member(BC) + assert t.members == {AB: [A, B], AC: [A, C]} + assert t.internal_forces == {AB: 0, AC: 0} + + t.add_member((BC, B, C)) + + D, CD = symbols('D, CD') + + # testing the change_label methods + t.change_node_label((B, D)) + assert t.nodes == [(A, 0, 0), (D, 2, 2), (C, 3, 0)] + assert t.node_labels == [A, D, C] + assert t.loads == {} + assert t.supports == {} + assert t.members == {AB: [A, D], BC: [D, C], AC: [A, C]} + + t.change_member_label((BC, CD)) + assert t.members == {AB: [A, D], CD: [D, C], AC: [A, C]} + assert t.internal_forces == {AB: 0, CD: 0, AC: 0} + + + # testing the apply_load method + t.apply_load((A, P, 90), (A, P/4, 90), (A, 2*P,45), (D, P/2, 90)) + assert t.loads == {A: [[P, 90], [P/4, 90], [2*P, 45]], D: [[P/2, 90]]} + assert t.loads[A] == [[P, 90], [P/4, 90], [2*P, 45]] + + # testing the remove_load method + t.remove_load((A, P/4, 90)) + assert t.loads == {A: [[P, 90], [2*P, 45]], D: [[P/2, 90]]} + assert t.loads[A] == [[P, 90], [2*P, 45]] + + # testing the apply_support method + t.apply_support((A, "pinned"), (D, "roller")) + assert t.supports == {A: 'pinned', D: 'roller'} + assert t.reaction_loads == {} + assert t.loads == {A: [[P, 90], [2*P, 45], [Symbol('R_A_x'), 0], [Symbol('R_A_y'), 90]], D: [[P/2, 90], [Symbol('R_D_y'), 90]]} + + # testing the remove_support method + t.remove_support(A) + assert t.supports == {D: 'roller'} + assert t.reaction_loads == {} + assert t.loads == {A: [[P, 90], [2*P, 45]], D: [[P/2, 90], [Symbol('R_D_y'), 90]]} + + t.apply_support((A, "pinned")) + + # testing the solve method + t.solve() + assert t.reaction_loads['R_A_x'] == -sqrt(2)*P + assert t.reaction_loads['R_A_y'] == -sqrt(2)*P - P + assert t.reaction_loads['R_D_y'] == -P/2 + assert t.internal_forces[AB]/P == 0 + assert t.internal_forces[CD] == 0 + assert t.internal_forces[AC] == 0 diff --git a/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/truss.py b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/truss.py new file mode 100644 index 0000000000000000000000000000000000000000..f7fd0ea3f5e18574f21e2f656477c7af987d8eb6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/continuum_mechanics/truss.py @@ -0,0 +1,1108 @@ +""" +This module can be used to solve problems related +to 2D Trusses. +""" + + +from cmath import atan, inf +from sympy.core.add import Add +from sympy.core.evalf import INF +from sympy.core.mul import Mul +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy import Matrix, pi +from sympy.external.importtools import import_module +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import zeros +import math +from sympy.physics.units.quantities import Quantity +from sympy.plotting import plot +from sympy.utilities.decorator import doctest_depends_on +from sympy import sin, cos + + +__doctest_requires__ = {('Truss.draw'): ['matplotlib']} + + +numpy = import_module('numpy', import_kwargs={'fromlist':['arange']}) + + +class Truss: + """ + A Truss is an assembly of members such as beams, + connected by nodes, that create a rigid structure. + In engineering, a truss is a structure that + consists of two-force members only. + + Trusses are extremely important in engineering applications + and can be seen in numerous real-world applications like bridges. + + Examples + ======== + + There is a Truss consisting of four nodes and five + members connecting the nodes. A force P acts + downward on the node D and there also exist pinned + and roller joints on the nodes A and B respectively. + + .. image:: truss_example.png + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(("node_1", 0, 0), ("node_2", 6, 0), ("node_3", 2, 2), ("node_4", 2, 0)) + >>> t.add_member(("member_1", "node_1", "node_4"), ("member_2", "node_2", "node_4"), ("member_3", "node_1", "node_3")) + >>> t.add_member(("member_4", "node_2", "node_3"), ("member_5", "node_3", "node_4")) + >>> t.apply_load(("node_4", 10, 270)) + >>> t.apply_support(("node_1", "pinned"), ("node_2", "roller")) + """ + + def __init__(self): + """ + Initializes the class + """ + self._nodes = [] + self._members = {} + self._loads = {} + self._supports = {} + self._node_labels = [] + self._node_positions = [] + self._node_position_x = [] + self._node_position_y = [] + self._nodes_occupied = {} + self._member_lengths = {} + self._reaction_loads = {} + self._internal_forces = {} + self._node_coordinates = {} + + @property + def nodes(self): + """ + Returns the nodes of the truss along with their positions. + """ + return self._nodes + + @property + def node_labels(self): + """ + Returns the node labels of the truss. + """ + return self._node_labels + + @property + def node_positions(self): + """ + Returns the positions of the nodes of the truss. + """ + return self._node_positions + + @property + def members(self): + """ + Returns the members of the truss along with the start and end points. + """ + return self._members + + @property + def member_lengths(self): + """ + Returns the length of each member of the truss. + """ + return self._member_lengths + + @property + def supports(self): + """ + Returns the nodes with provided supports along with the kind of support provided i.e. + pinned or roller. + """ + return self._supports + + @property + def loads(self): + """ + Returns the loads acting on the truss. + """ + return self._loads + + @property + def reaction_loads(self): + """ + Returns the reaction forces for all supports which are all initialized to 0. + """ + return self._reaction_loads + + @property + def internal_forces(self): + """ + Returns the internal forces for all members which are all initialized to 0. + """ + return self._internal_forces + + def add_node(self, *args): + """ + This method adds a node to the truss along with its name/label and its location. + Multiple nodes can be added at the same time. + + Parameters + ========== + The input(s) for this method are tuples of the form (label, x, y). + + label: String or a Symbol + The label for a node. It is the only way to identify a particular node. + + x: Sympifyable + The x-coordinate of the position of the node. + + y: Sympifyable + The y-coordinate of the position of the node. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0)) + >>> t.nodes + [('A', 0, 0)] + >>> t.add_node(('B', 3, 0), ('C', 4, 1)) + >>> t.nodes + [('A', 0, 0), ('B', 3, 0), ('C', 4, 1)] + """ + + for i in args: + label = i[0] + x = i[1] + x = sympify(x) + y=i[2] + y = sympify(y) + if label in self._node_coordinates: + raise ValueError("Node needs to have a unique label") + + elif [x, y] in self._node_coordinates.values(): + raise ValueError("A node already exists at the given position") + + else : + self._nodes.append((label, x, y)) + self._node_labels.append(label) + self._node_positions.append((x, y)) + self._node_position_x.append(x) + self._node_position_y.append(y) + self._node_coordinates[label] = [x, y] + + + + def remove_node(self, *args): + """ + This method removes a node from the truss. + Multiple nodes can be removed at the same time. + + Parameters + ========== + The input(s) for this method are the labels of the nodes to be removed. + + label: String or Symbol + The label of the node to be removed. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0), ('C', 5, 0)) + >>> t.nodes + [('A', 0, 0), ('B', 3, 0), ('C', 5, 0)] + >>> t.remove_node('A', 'C') + >>> t.nodes + [('B', 3, 0)] + """ + for label in args: + for i in range(len(self.nodes)): + if self._node_labels[i] == label: + x = self._node_position_x[i] + y = self._node_position_y[i] + + if label not in self._node_coordinates: + raise ValueError("No such node exists in the truss") + + else: + members_duplicate = self._members.copy() + for member in members_duplicate: + if label == self._members[member][0] or label == self._members[member][1]: + raise ValueError("The given node already has member attached to it") + self._nodes.remove((label, x, y)) + self._node_labels.remove(label) + self._node_positions.remove((x, y)) + self._node_position_x.remove(x) + self._node_position_y.remove(y) + if label in self._loads: + self._loads.pop(label) + if label in self._supports: + self._supports.pop(label) + self._node_coordinates.pop(label) + + + + def add_member(self, *args): + """ + This method adds a member between any two nodes in the given truss. + + Parameters + ========== + The input(s) of the method are tuple(s) of the form (label, start, end). + + label: String or Symbol + The label for a member. It is the only way to identify a particular member. + + start: String or Symbol + The label of the starting point/node of the member. + + end: String or Symbol + The label of the ending point/node of the member. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0), ('C', 2, 2)) + >>> t.add_member(('AB', 'A', 'B'), ('BC', 'B', 'C')) + >>> t.members + {'AB': ['A', 'B'], 'BC': ['B', 'C']} + """ + for i in args: + label = i[0] + start = i[1] + end = i[2] + + if start not in self._node_coordinates or end not in self._node_coordinates or start==end: + raise ValueError("The start and end points of the member must be unique nodes") + + elif label in self._members: + raise ValueError("A member with the same label already exists for the truss") + + elif self._nodes_occupied.get((start, end)): + raise ValueError("A member already exists between the two nodes") + + else: + self._members[label] = [start, end] + self._member_lengths[label] = sqrt((self._node_coordinates[end][0]-self._node_coordinates[start][0])**2 + (self._node_coordinates[end][1]-self._node_coordinates[start][1])**2) + self._nodes_occupied[start, end] = True + self._nodes_occupied[end, start] = True + self._internal_forces[label] = 0 + + def remove_member(self, *args): + """ + This method removes members from the given truss. + + Parameters + ========== + labels: String or Symbol + The label for the member to be removed. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0), ('C', 2, 2)) + >>> t.add_member(('AB', 'A', 'B'), ('AC', 'A', 'C'), ('BC', 'B', 'C')) + >>> t.members + {'AB': ['A', 'B'], 'AC': ['A', 'C'], 'BC': ['B', 'C']} + >>> t.remove_member('AC', 'BC') + >>> t.members + {'AB': ['A', 'B']} + """ + for label in args: + if label not in self._members: + raise ValueError("No such member exists in the Truss") + + else: + self._nodes_occupied.pop((self._members[label][0], self._members[label][1])) + self._nodes_occupied.pop((self._members[label][1], self._members[label][0])) + self._members.pop(label) + self._member_lengths.pop(label) + self._internal_forces.pop(label) + + def change_node_label(self, *args): + """ + This method changes the label(s) of the specified node(s). + + Parameters + ========== + The input(s) of this method are tuple(s) of the form (label, new_label). + + label: String or Symbol + The label of the node for which the label has + to be changed. + + new_label: String or Symbol + The new label of the node. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> t.nodes + [('A', 0, 0), ('B', 3, 0)] + >>> t.change_node_label(('A', 'C'), ('B', 'D')) + >>> t.nodes + [('C', 0, 0), ('D', 3, 0)] + """ + for i in args: + label = i[0] + new_label = i[1] + if label not in self._node_coordinates: + raise ValueError("No such node exists for the Truss") + elif new_label in self._node_coordinates: + raise ValueError("A node with the given label already exists") + else: + for node in self._nodes: + if node[0] == label: + self._nodes[self._nodes.index((label, node[1], node[2]))] = (new_label, node[1], node[2]) + self._node_labels[self._node_labels.index(node[0])] = new_label + self._node_coordinates[new_label] = self._node_coordinates[label] + self._node_coordinates.pop(label) + if node[0] in self._supports: + self._supports[new_label] = self._supports[node[0]] + self._supports.pop(node[0]) + if new_label in self._supports: + if self._supports[new_label] == 'pinned': + if 'R_'+str(label)+'_x' in self._reaction_loads and 'R_'+str(label)+'_y' in self._reaction_loads: + self._reaction_loads['R_'+str(new_label)+'_x'] = self._reaction_loads['R_'+str(label)+'_x'] + self._reaction_loads['R_'+str(new_label)+'_y'] = self._reaction_loads['R_'+str(label)+'_y'] + self._reaction_loads.pop('R_'+str(label)+'_x') + self._reaction_loads.pop('R_'+str(label)+'_y') + self._loads[new_label] = self._loads[label] + for load in self._loads[new_label]: + if load[1] == 90: + load[0] -= Symbol('R_'+str(label)+'_y') + if load[0] == 0: + self._loads[label].remove(load) + break + for load in self._loads[new_label]: + if load[1] == 0: + load[0] -= Symbol('R_'+str(label)+'_x') + if load[0] == 0: + self._loads[label].remove(load) + break + self.apply_load(new_label, Symbol('R_'+str(new_label)+'_x'), 0) + self.apply_load(new_label, Symbol('R_'+str(new_label)+'_y'), 90) + self._loads.pop(label) + elif self._supports[new_label] == 'roller': + self._loads[new_label] = self._loads[label] + for load in self._loads[label]: + if load[1] == 90: + load[0] -= Symbol('R_'+str(label)+'_y') + if load[0] == 0: + self._loads[label].remove(load) + break + self.apply_load(new_label, Symbol('R_'+str(new_label)+'_y'), 90) + self._loads.pop(label) + else: + if label in self._loads: + self._loads[new_label] = self._loads[label] + self._loads.pop(label) + for member in self._members: + if self._members[member][0] == node[0]: + self._members[member][0] = new_label + self._nodes_occupied[(new_label, self._members[member][1])] = True + self._nodes_occupied[(self._members[member][1], new_label)] = True + self._nodes_occupied.pop((label, self._members[member][1])) + self._nodes_occupied.pop((self._members[member][1], label)) + elif self._members[member][1] == node[0]: + self._members[member][1] = new_label + self._nodes_occupied[(self._members[member][0], new_label)] = True + self._nodes_occupied[(new_label, self._members[member][0])] = True + self._nodes_occupied.pop((self._members[member][0], label)) + self._nodes_occupied.pop((label, self._members[member][0])) + + def change_member_label(self, *args): + """ + This method changes the label(s) of the specified member(s). + + Parameters + ========== + The input(s) of this method are tuple(s) of the form (label, new_label) + + label: String or Symbol + The label of the member for which the label has + to be changed. + + new_label: String or Symbol + The new label of the member. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0), ('D', 5, 0)) + >>> t.nodes + [('A', 0, 0), ('B', 3, 0), ('D', 5, 0)] + >>> t.change_node_label(('A', 'C')) + >>> t.nodes + [('C', 0, 0), ('B', 3, 0), ('D', 5, 0)] + >>> t.add_member(('BC', 'B', 'C'), ('BD', 'B', 'D')) + >>> t.members + {'BC': ['B', 'C'], 'BD': ['B', 'D']} + >>> t.change_member_label(('BC', 'BC_new'), ('BD', 'BD_new')) + >>> t.members + {'BC_new': ['B', 'C'], 'BD_new': ['B', 'D']} + """ + for i in args: + label = i[0] + new_label = i[1] + if label not in self._members: + raise ValueError("No such member exists for the Truss") + else: + members_duplicate = list(self._members).copy() + for member in members_duplicate: + if member == label: + self._members[new_label] = [self._members[member][0], self._members[member][1]] + self._members.pop(label) + self._member_lengths[new_label] = self._member_lengths[label] + self._member_lengths.pop(label) + self._internal_forces[new_label] = self._internal_forces[label] + self._internal_forces.pop(label) + + def apply_load(self, *args): + """ + This method applies external load(s) at the specified node(s). + + Parameters + ========== + The input(s) of the method are tuple(s) of the form (location, magnitude, direction). + + location: String or Symbol + Label of the Node at which load is applied. + + magnitude: Sympifyable + Magnitude of the load applied. It must always be positive and any changes in + the direction of the load are not reflected here. + + direction: Sympifyable + The angle, in degrees, that the load vector makes with the horizontal + in the counter-clockwise direction. It takes the values 0 to 360, + inclusive. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> from sympy import symbols + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> P = symbols('P') + >>> t.apply_load(('A', P, 90), ('A', P/2, 45), ('A', P/4, 90)) + >>> t.loads + {'A': [[P, 90], [P/2, 45], [P/4, 90]]} + """ + for i in args: + location = i[0] + magnitude = i[1] + direction = i[2] + magnitude = sympify(magnitude) + direction = sympify(direction) + + if location not in self._node_coordinates: + raise ValueError("Load must be applied at a known node") + + else: + if location in self._loads: + self._loads[location].append([magnitude, direction]) + else: + self._loads[location] = [[magnitude, direction]] + + def remove_load(self, *args): + """ + This method removes already + present external load(s) at specified node(s). + + Parameters + ========== + The input(s) of this method are tuple(s) of the form (location, magnitude, direction). + + location: String or Symbol + Label of the Node at which load is applied and is to be removed. + + magnitude: Sympifyable + Magnitude of the load applied. + + direction: Sympifyable + The angle, in degrees, that the load vector makes with the horizontal + in the counter-clockwise direction. It takes the values 0 to 360, + inclusive. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> from sympy import symbols + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> P = symbols('P') + >>> t.apply_load(('A', P, 90), ('A', P/2, 45), ('A', P/4, 90)) + >>> t.loads + {'A': [[P, 90], [P/2, 45], [P/4, 90]]} + >>> t.remove_load(('A', P/4, 90), ('A', P/2, 45)) + >>> t.loads + {'A': [[P, 90]]} + """ + for i in args: + location = i[0] + magnitude = i[1] + direction = i[2] + magnitude = sympify(magnitude) + direction = sympify(direction) + + if location not in self._node_coordinates: + raise ValueError("Load must be removed from a known node") + + else: + if [magnitude, direction] not in self._loads[location]: + raise ValueError("No load of this magnitude and direction has been applied at this node") + else: + self._loads[location].remove([magnitude, direction]) + if self._loads[location] == []: + self._loads.pop(location) + + def apply_support(self, *args): + """ + This method adds a pinned or roller support at specified node(s). + + Parameters + ========== + The input(s) of this method are of the form (location, type). + + location: String or Symbol + Label of the Node at which support is added. + + type: String + Type of the support being provided at the node. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> t.apply_support(('A', 'pinned'), ('B', 'roller')) + >>> t.supports + {'A': 'pinned', 'B': 'roller'} + """ + for i in args: + location = i[0] + type = i[1] + if location not in self._node_coordinates: + raise ValueError("Support must be added on a known node") + + else: + if location not in self._supports: + if type == 'pinned': + self.apply_load((location, Symbol('R_'+str(location)+'_x'), 0)) + self.apply_load((location, Symbol('R_'+str(location)+'_y'), 90)) + elif type == 'roller': + self.apply_load((location, Symbol('R_'+str(location)+'_y'), 90)) + elif self._supports[location] == 'pinned': + if type == 'roller': + self.remove_load((location, Symbol('R_'+str(location)+'_x'), 0)) + elif self._supports[location] == 'roller': + if type == 'pinned': + self.apply_load((location, Symbol('R_'+str(location)+'_x'), 0)) + self._supports[location] = type + + def remove_support(self, *args): + """ + This method removes support from specified node(s.) + + Parameters + ========== + + locations: String or Symbol + Label of the Node(s) at which support is to be removed. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> t.apply_support(('A', 'pinned'), ('B', 'roller')) + >>> t.supports + {'A': 'pinned', 'B': 'roller'} + >>> t.remove_support('A','B') + >>> t.supports + {} + """ + for location in args: + + if location not in self._node_coordinates: + raise ValueError("No such node exists in the Truss") + + elif location not in self._supports: + raise ValueError("No support has been added to the given node") + + else: + if self._supports[location] == 'pinned': + self.remove_load((location, Symbol('R_'+str(location)+'_x'), 0)) + self.remove_load((location, Symbol('R_'+str(location)+'_y'), 90)) + elif self._supports[location] == 'roller': + self.remove_load((location, Symbol('R_'+str(location)+'_y'), 90)) + self._supports.pop(location) + + def solve(self): + """ + This method solves for all reaction forces of all supports and all internal forces + of all the members in the truss, provided the Truss is solvable. + + A Truss is solvable if the following condition is met, + + 2n >= r + m + + Where n is the number of nodes, r is the number of reaction forces, where each pinned + support has 2 reaction forces and each roller has 1, and m is the number of members. + + The given condition is derived from the fact that a system of equations is solvable + only when the number of variables is lesser than or equal to the number of equations. + Equilibrium Equations in x and y directions give two equations per node giving 2n number + equations. However, the truss needs to be stable as well and may be unstable if 2n > r + m. + The number of variables is simply the sum of the number of reaction forces and member + forces. + + .. note:: + The sign convention for the internal forces present in a member revolves around whether each + force is compressive or tensile. While forming equations for each node, internal force due + to a member on the node is assumed to be away from the node i.e. each force is assumed to + be compressive by default. Hence, a positive value for an internal force implies the + presence of compressive force in the member and a negative value implies a tensile force. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(("node_1", 0, 0), ("node_2", 6, 0), ("node_3", 2, 2), ("node_4", 2, 0)) + >>> t.add_member(("member_1", "node_1", "node_4"), ("member_2", "node_2", "node_4"), ("member_3", "node_1", "node_3")) + >>> t.add_member(("member_4", "node_2", "node_3"), ("member_5", "node_3", "node_4")) + >>> t.apply_load(("node_4", 10, 270)) + >>> t.apply_support(("node_1", "pinned"), ("node_2", "roller")) + >>> t.solve() + >>> t.reaction_loads + {'R_node_1_x': 0, 'R_node_1_y': 20/3, 'R_node_2_y': 10/3} + >>> t.internal_forces + {'member_1': 20/3, 'member_2': 20/3, 'member_3': -20*sqrt(2)/3, 'member_4': -10*sqrt(5)/3, 'member_5': 10} + """ + count_reaction_loads = 0 + for node in self._nodes: + if node[0] in self._supports: + if self._supports[node[0]]=='pinned': + count_reaction_loads += 2 + elif self._supports[node[0]]=='roller': + count_reaction_loads += 1 + if 2*len(self._nodes) != len(self._members) + count_reaction_loads: + raise ValueError("The given truss cannot be solved") + coefficients_matrix = [[0 for i in range(2*len(self._nodes))] for j in range(2*len(self._nodes))] + load_matrix = zeros(2*len(self.nodes), 1) + load_matrix_row = 0 + for node in self._nodes: + if node[0] in self._loads: + for load in self._loads[node[0]]: + if load[0]!=Symbol('R_'+str(node[0])+'_x') and load[0]!=Symbol('R_'+str(node[0])+'_y'): + load_matrix[load_matrix_row] -= load[0]*cos(pi*load[1]/180) + load_matrix[load_matrix_row + 1] -= load[0]*sin(pi*load[1]/180) + load_matrix_row += 2 + cols = 0 + row = 0 + for node in self._nodes: + if node[0] in self._supports: + if self._supports[node[0]]=='pinned': + coefficients_matrix[row][cols] += 1 + coefficients_matrix[row+1][cols+1] += 1 + cols += 2 + elif self._supports[node[0]]=='roller': + coefficients_matrix[row+1][cols] += 1 + cols += 1 + row += 2 + for member in self._members: + start = self._members[member][0] + end = self._members[member][1] + length = sqrt((self._node_coordinates[start][0]-self._node_coordinates[end][0])**2 + (self._node_coordinates[start][1]-self._node_coordinates[end][1])**2) + start_index = self._node_labels.index(start) + end_index = self._node_labels.index(end) + horizontal_component_start = (self._node_coordinates[end][0]-self._node_coordinates[start][0])/length + vertical_component_start = (self._node_coordinates[end][1]-self._node_coordinates[start][1])/length + horizontal_component_end = (self._node_coordinates[start][0]-self._node_coordinates[end][0])/length + vertical_component_end = (self._node_coordinates[start][1]-self._node_coordinates[end][1])/length + coefficients_matrix[start_index*2][cols] += horizontal_component_start + coefficients_matrix[start_index*2+1][cols] += vertical_component_start + coefficients_matrix[end_index*2][cols] += horizontal_component_end + coefficients_matrix[end_index*2+1][cols] += vertical_component_end + cols += 1 + forces_matrix = (Matrix(coefficients_matrix)**-1)*load_matrix + self._reaction_loads = {} + i = 0 + min_load = inf + for node in self._nodes: + if node[0] in self._loads: + for load in self._loads[node[0]]: + if type(load[0]) not in [Symbol, Mul, Add]: + min_load = min(min_load, load[0]) + for j in range(len(forces_matrix)): + if type(forces_matrix[j]) not in [Symbol, Mul, Add]: + if abs(forces_matrix[j]/min_load) <1E-10: + forces_matrix[j] = 0 + for node in self._nodes: + if node[0] in self._supports: + if self._supports[node[0]]=='pinned': + self._reaction_loads['R_'+str(node[0])+'_x'] = forces_matrix[i] + self._reaction_loads['R_'+str(node[0])+'_y'] = forces_matrix[i+1] + i += 2 + elif self._supports[node[0]]=='roller': + self._reaction_loads['R_'+str(node[0])+'_y'] = forces_matrix[i] + i += 1 + for member in self._members: + self._internal_forces[member] = forces_matrix[i] + i += 1 + return + + @doctest_depends_on(modules=('numpy',)) + def draw(self, subs_dict=None): + """ + Returns a plot object of the Truss with all its nodes, members, + supports and loads. + + .. note:: + The user must be careful while entering load values in their + directions. The draw function assumes a sign convention that + is used for plotting loads. + + Given a right-handed coordinate system with XYZ coordinates, + the supports are assumed to be such that the reaction forces of a + pinned support is in the +X and +Y direction while those of a + roller support is in the +Y direction. For the load, the range + of angles, one can input goes all the way to 360 degrees which, in the + the plot is the angle that the load vector makes with the positive x-axis in the anticlockwise direction. + + For example, for a 90-degree angle, the load will be a vertically + directed along +Y while a 270-degree angle denotes a vertical + load as well but along -Y. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> import math + >>> t = Truss() + >>> t.add_node(("A", -4, 0), ("B", 0, 0), ("C", 4, 0), ("D", 8, 0)) + >>> t.add_node(("E", 6, 2/math.sqrt(3))) + >>> t.add_node(("F", 2, 2*math.sqrt(3))) + >>> t.add_node(("G", -2, 2/math.sqrt(3))) + >>> t.add_member(("AB","A","B"), ("BC","B","C"), ("CD","C","D")) + >>> t.add_member(("AG","A","G"), ("GB","G","B"), ("GF","G","F")) + >>> t.add_member(("BF","B","F"), ("FC","F","C"), ("CE","C","E")) + >>> t.add_member(("FE","F","E"), ("DE","D","E")) + >>> t.apply_support(("A","pinned"), ("D","roller")) + >>> t.apply_load(("G", 3, 90), ("E", 3, 90), ("F", 2, 90)) + >>> p = t.draw() + >>> p # doctest: +ELLIPSIS + Plot object containing: + [0]: cartesian line: 1 for x over (1.0, 1.0) + ... + >>> p.show() + """ + if not numpy: + raise ImportError("To use this function numpy module is required") + + x = Symbol('x') + + markers = [] + annotations = [] + rectangles = [] + + node_markers = self._draw_nodes(subs_dict) + markers += node_markers + + member_rectangles = self._draw_members() + rectangles += member_rectangles + + support_markers = self._draw_supports() + markers += support_markers + + load_annotations = self._draw_loads() + annotations += load_annotations + + xmax = -INF + xmin = INF + ymax = -INF + ymin = INF + + for node in self._node_coordinates: + xmax = max(xmax, self._node_coordinates[node][0]) + xmin = min(xmin, self._node_coordinates[node][0]) + ymax = max(ymax, self._node_coordinates[node][1]) + ymin = min(ymin, self._node_coordinates[node][1]) + + lim = max(xmax*1.1-xmin*0.8+1, ymax*1.1-ymin*0.8+1) + + if lim==xmax*1.1-xmin*0.8+1: + sing_plot = plot(1, (x, 1, 1), markers=markers, show=False, annotations=annotations, xlim=(xmin-0.05*lim, xmax*1.1), ylim=(xmin-0.05*lim, xmax*1.1), axis=False, rectangles=rectangles) + + else: + sing_plot = plot(1, (x, 1, 1), markers=markers, show=False, annotations=annotations, xlim=(ymin-0.05*lim, ymax*1.1), ylim=(ymin-0.05*lim, ymax*1.1), axis=False, rectangles=rectangles) + + return sing_plot + + + def _draw_nodes(self, subs_dict): + node_markers = [] + + for node in self._node_coordinates: + if (type(self._node_coordinates[node][0]) in (Symbol, Quantity)): + if self._node_coordinates[node][0] in subs_dict: + self._node_coordinates[node][0] = subs_dict[self._node_coordinates[node][0]] + else: + raise ValueError("provided substituted dictionary is not adequate") + elif (type(self._node_coordinates[node][0]) == Mul): + objects = self._node_coordinates[node][0].as_coeff_Mul() + for object in objects: + if type(object) in (Symbol, Quantity): + if subs_dict==None or object not in subs_dict: + raise ValueError("provided substituted dictionary is not adequate") + else: + self._node_coordinates[node][0] /= object + self._node_coordinates[node][0] *= subs_dict[object] + + if (type(self._node_coordinates[node][1]) in (Symbol, Quantity)): + if self._node_coordinates[node][1] in subs_dict: + self._node_coordinates[node][1] = subs_dict[self._node_coordinates[node][1]] + else: + raise ValueError("provided substituted dictionary is not adequate") + elif (type(self._node_coordinates[node][1]) == Mul): + objects = self._node_coordinates[node][1].as_coeff_Mul() + for object in objects: + if type(object) in (Symbol, Quantity): + if subs_dict==None or object not in subs_dict: + raise ValueError("provided substituted dictionary is not adequate") + else: + self._node_coordinates[node][1] /= object + self._node_coordinates[node][1] *= subs_dict[object] + + for node in self._node_coordinates: + node_markers.append( + { + 'args':[[self._node_coordinates[node][0]], [self._node_coordinates[node][1]]], + 'marker':'o', + 'markersize':5, + 'color':'black' + } + ) + return node_markers + + def _draw_members(self): + + member_rectangles = [] + + xmax = -INF + xmin = INF + ymax = -INF + ymin = INF + + for node in self._node_coordinates: + xmax = max(xmax, self._node_coordinates[node][0]) + xmin = min(xmin, self._node_coordinates[node][0]) + ymax = max(ymax, self._node_coordinates[node][1]) + ymin = min(ymin, self._node_coordinates[node][1]) + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + for member in self._members: + x1 = self._node_coordinates[self._members[member][0]][0] + y1 = self._node_coordinates[self._members[member][0]][1] + x2 = self._node_coordinates[self._members[member][1]][0] + y2 = self._node_coordinates[self._members[member][1]][1] + if x2!=x1 and y2!=y1: + if x2>x1: + member_rectangles.append( + { + 'xy':(x1-0.005*max_diff*cos(pi/4+atan((y2-y1)/(x2-x1)))/2, y1-0.005*max_diff*sin(pi/4+atan((y2-y1)/(x2-x1)))/2), + 'width':sqrt((x1-x2)**2+(y1-y2)**2)+0.005*max_diff/math.sqrt(2), + 'height':0.005*max_diff, + 'angle':180*atan((y2-y1)/(x2-x1))/pi, + 'color':'brown' + } + ) + else: + member_rectangles.append( + { + 'xy':(x2-0.005*max_diff*cos(pi/4+atan((y2-y1)/(x2-x1)))/2, y2-0.005*max_diff*sin(pi/4+atan((y2-y1)/(x2-x1)))/2), + 'width':sqrt((x1-x2)**2+(y1-y2)**2)+0.005*max_diff/math.sqrt(2), + 'height':0.005*max_diff, + 'angle':180*atan((y2-y1)/(x2-x1))/pi, + 'color':'brown' + } + ) + elif y2==y1: + if x2>x1: + member_rectangles.append( + { + 'xy':(x1-0.005*max_diff/2, y1-0.005*max_diff/2), + 'width':sqrt((x1-x2)**2+(y1-y2)**2), + 'height':0.005*max_diff, + 'angle':90*(1-math.copysign(1, x2-x1)), + 'color':'brown' + } + ) + else: + member_rectangles.append( + { + 'xy':(x1-0.005*max_diff/2, y1-0.005*max_diff/2), + 'width':sqrt((x1-x2)**2+(y1-y2)**2), + 'height':-0.005*max_diff, + 'angle':90*(1-math.copysign(1, x2-x1)), + 'color':'brown' + } + ) + else: + if y1abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + for node in self._supports: + if self._supports[node]=='pinned': + support_markers.append( + { + 'args':[ + [self._node_coordinates[node][0]], + [self._node_coordinates[node][1]] + ], + 'marker':6, + 'markersize':15, + 'color':'black', + 'markerfacecolor':'none' + } + ) + support_markers.append( + { + 'args':[ + [self._node_coordinates[node][0]], + [self._node_coordinates[node][1]-0.035*max_diff] + ], + 'marker':'_', + 'markersize':14, + 'color':'black' + } + ) + + elif self._supports[node]=='roller': + support_markers.append( + { + 'args':[ + [self._node_coordinates[node][0]], + [self._node_coordinates[node][1]-0.02*max_diff] + ], + 'marker':'o', + 'markersize':11, + 'color':'black', + 'markerfacecolor':'none' + } + ) + support_markers.append( + { + 'args':[ + [self._node_coordinates[node][0]], + [self._node_coordinates[node][1]-0.0375*max_diff] + ], + 'marker':'_', + 'markersize':14, + 'color':'black' + } + ) + return support_markers + + def _draw_loads(self): + load_annotations = [] + + xmax = -INF + xmin = INF + ymax = -INF + ymin = INF + + for node in self._node_coordinates: + xmax = max(xmax, self._node_coordinates[node][0]) + xmin = min(xmin, self._node_coordinates[node][0]) + ymax = max(ymax, self._node_coordinates[node][1]) + ymin = min(ymin, self._node_coordinates[node][1]) + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin+5 + else: + max_diff = 1.1*ymax-0.8*ymin+5 + + for node in self._loads: + for load in self._loads[node]: + if load[0] in [Symbol('R_'+str(node)+'_x'), Symbol('R_'+str(node)+'_y')]: + continue + x = self._node_coordinates[node][0] + y = self._node_coordinates[node][1] + load_annotations.append( + { + 'text':'', + 'xy':( + x-math.cos(pi*load[1]/180)*(max_diff/100), + y-math.sin(pi*load[1]/180)*(max_diff/100) + ), + 'xytext':( + x-(max_diff/100+abs(xmax-xmin)+abs(ymax-ymin))*math.cos(pi*load[1]/180)/20, + y-(max_diff/100+abs(xmax-xmin)+abs(ymax-ymin))*math.sin(pi*load[1]/180)/20 + ), + 'arrowprops':{'width':1.5, 'headlength':5, 'headwidth':5, 'facecolor':'black'} + } + ) + return load_annotations diff --git a/phivenv/Lib/site-packages/sympy/physics/control/__init__.py b/phivenv/Lib/site-packages/sympy/physics/control/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c4d74895f2e68cb918f00fd7065ca048b32ef06d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/control/__init__.py @@ -0,0 +1,17 @@ +from .lti import (TransferFunction, PIDController, Series, MIMOSeries, Parallel, MIMOParallel, + Feedback, MIMOFeedback, TransferFunctionMatrix, StateSpace, gbt, bilinear, forward_diff, + backward_diff, phase_margin, gain_margin) +from .control_plots import (pole_zero_numerical_data, pole_zero_plot, step_response_numerical_data, + step_response_plot, impulse_response_numerical_data, impulse_response_plot, ramp_response_numerical_data, + ramp_response_plot, bode_magnitude_numerical_data, bode_phase_numerical_data, bode_magnitude_plot, + bode_phase_plot, bode_plot, nyquist_plot_expr, nyquist_plot, nichols_plot_expr, nichols_plot) + +__all__ = ['TransferFunction', 'PIDController', 'Series', 'MIMOSeries', 'Parallel', + 'MIMOParallel', 'Feedback', 'MIMOFeedback', 'TransferFunctionMatrix', 'StateSpace', + 'gbt', 'bilinear', 'forward_diff', 'backward_diff', 'phase_margin', 'gain_margin', + 'pole_zero_numerical_data', 'pole_zero_plot', 'step_response_numerical_data', + 'step_response_plot', 'impulse_response_numerical_data', 'impulse_response_plot', + 'ramp_response_numerical_data', 'ramp_response_plot', + 'bode_magnitude_numerical_data', 'bode_phase_numerical_data', + 'bode_magnitude_plot', 'bode_phase_plot', 'bode_plot', 'nyquist_plot_expr', 'nyquist_plot', + 'nichols_plot_expr', 'nichols_plot'] diff --git a/phivenv/Lib/site-packages/sympy/physics/control/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/control/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..588e7497442204193db3f7237a564b9543e48a24 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/control/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/control/__pycache__/control_plots.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/control/__pycache__/control_plots.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfa6ce1a9be5415a6576fb04531259b4a4cbdda8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/control/__pycache__/control_plots.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/control/control_plots.py b/phivenv/Lib/site-packages/sympy/physics/control/control_plots.py new file mode 100644 index 0000000000000000000000000000000000000000..1a83d3b833a064905619a4d6ba2a74e52ef72afa --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/control/control_plots.py @@ -0,0 +1,1135 @@ +from sympy.core.numbers import I, pi +from sympy.functions.elementary.exponential import (exp, log) +from sympy.polys.partfrac import apart +from sympy.core.symbol import Dummy +from sympy.external import import_module +from sympy.functions import arg, Abs +from sympy.integrals.laplace import _fast_inverse_laplace +from sympy.physics.control.lti import SISOLinearTimeInvariant +from sympy.plotting.series import LineOver1DRangeSeries +from sympy.plotting.plot import plot_parametric +from sympy.polys.domains import ZZ, QQ +from sympy.polys.polytools import Poly +from sympy.printing.latex import latex +from sympy.geometry.polygon import deg + +__all__ = ['pole_zero_numerical_data', 'pole_zero_plot', + 'step_response_numerical_data', 'step_response_plot', + 'impulse_response_numerical_data', 'impulse_response_plot', + 'ramp_response_numerical_data', 'ramp_response_plot', + 'bode_magnitude_numerical_data', 'bode_phase_numerical_data', + 'bode_magnitude_plot', 'bode_phase_plot', 'bode_plot', + 'nyquist_plot_expr', 'nyquist_plot', 'nichols_plot_expr', + 'nichols_plot'] + + +matplotlib = import_module( + 'matplotlib', import_kwargs={'fromlist': ['pyplot']}, + catch=(RuntimeError,)) + +if matplotlib: + plt = matplotlib.pyplot + + +def _check_system(system): + """Function to check whether the dynamical system passed for plots is + compatible or not.""" + if not isinstance(system, SISOLinearTimeInvariant): + raise NotImplementedError("Only SISO LTI systems are currently supported.") + sys = system.to_expr() + len_free_symbols = len(sys.free_symbols) + if len_free_symbols > 1: + raise ValueError("Extra degree of freedom found. Make sure" + " that there are no free symbols in the dynamical system other" + " than the variable of Laplace transform.") + if sys.has(exp): + # Should test that exp is not part of a constant, in which case + # no exception is required, compare exp(s) with s*exp(1) + raise NotImplementedError("Time delay terms are not supported.") + + +def _poly_roots(poly): + """Function to get the roots of a polynomial.""" + def _eval(l): + return [float(i) if i.is_real else complex(i) for i in l] + if poly.domain in (QQ, ZZ): + return _eval(poly.all_roots()) + # XXX: Use all_roots() for irrational coefficients when possible + # See https://github.com/sympy/sympy/issues/22943 + return _eval(poly.nroots()) + + +def pole_zero_numerical_data(system): + """ + Returns the numerical data of poles and zeros of the system. + It is internally used by ``pole_zero_plot`` to get the data + for plotting poles and zeros. Users can use this data to further + analyse the dynamics of the system or plot using a different + backend/plotting-module. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the pole-zero data is to be computed. + + Returns + ======= + + tuple : (zeros, poles) + zeros = Zeros of the system as a list of Python float/complex. + poles = Poles of the system as a list of Python float/complex. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import pole_zero_numerical_data + >>> tf1 = TransferFunction(s**2 + 1, s**4 + 4*s**3 + 6*s**2 + 5*s + 2, s) + >>> pole_zero_numerical_data(tf1) + ([-1j, 1j], [-2.0, -1.0, (-0.5-0.8660254037844386j), (-0.5+0.8660254037844386j)]) + + See Also + ======== + + pole_zero_plot + + """ + _check_system(system) + system = system.doit() # Get the equivalent TransferFunction object. + + num_poly = Poly(system.num, system.var) + den_poly = Poly(system.den, system.var) + + return _poly_roots(num_poly), _poly_roots(den_poly) + + +def pole_zero_plot(system, pole_color='blue', pole_markersize=10, + zero_color='orange', zero_markersize=7, grid=True, show_axes=True, + show=True, **kwargs): + r""" + Returns the Pole-Zero plot (also known as PZ Plot or PZ Map) of a system. + + A Pole-Zero plot is a graphical representation of a system's poles and + zeros. It is plotted on a complex plane, with circular markers representing + the system's zeros and 'x' shaped markers representing the system's poles. + + Parameters + ========== + + system : SISOLinearTimeInvariant type systems + The system for which the pole-zero plot is to be computed. + pole_color : str, tuple, optional + The color of the pole points on the plot. Default color + is blue. The color can be provided as a matplotlib color string, + or a 3-tuple of floats each in the 0-1 range. + pole_markersize : Number, optional + The size of the markers used to mark the poles in the plot. + Default pole markersize is 10. + zero_color : str, tuple, optional + The color of the zero points on the plot. Default color + is orange. The color can be provided as a matplotlib color string, + or a 3-tuple of floats each in the 0-1 range. + zero_markersize : Number, optional + The size of the markers used to mark the zeros in the plot. + Default zero markersize is 7. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import pole_zero_plot + >>> tf1 = TransferFunction(s**2 + 1, s**4 + 4*s**3 + 6*s**2 + 5*s + 2, s) + >>> pole_zero_plot(tf1) # doctest: +SKIP + + See Also + ======== + + pole_zero_numerical_data + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Pole%E2%80%93zero_plot + + """ + zeros, poles = pole_zero_numerical_data(system) + + zero_real = [i.real for i in zeros] + zero_imag = [i.imag for i in zeros] + + pole_real = [i.real for i in poles] + pole_imag = [i.imag for i in poles] + + plt.plot(pole_real, pole_imag, 'x', mfc='none', + markersize=pole_markersize, color=pole_color) + plt.plot(zero_real, zero_imag, 'o', markersize=zero_markersize, + color=zero_color) + plt.xlabel('Real Axis') + plt.ylabel('Imaginary Axis') + plt.title(f'Poles and Zeros of ${latex(system)}$', pad=20) + + if grid: + plt.grid() + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def step_response_numerical_data(system, prec=8, lower_limit=0, + upper_limit=10, **kwargs): + """ + Returns the numerical values of the points in the step response plot + of a SISO continuous-time system. By default, adaptive sampling + is used. If the user wants to instead get an uniformly + sampled response, then ``adaptive`` kwarg should be passed ``False`` + and ``n`` must be passed as additional kwargs. + Refer to the parameters of class :class:`sympy.plotting.series.LineOver1DRangeSeries` + for more details. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the unit step response data is to be computed. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + kwargs : + Additional keyword arguments are passed to the underlying + :class:`sympy.plotting.series.LineOver1DRangeSeries` class. + + Returns + ======= + + tuple : (x, y) + x = Time-axis values of the points in the step response. NumPy array. + y = Amplitude-axis values of the points in the step response. NumPy array. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When ``lower_limit`` parameter is less than 0. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import step_response_numerical_data + >>> tf1 = TransferFunction(s, s**2 + 5*s + 8, s) + >>> step_response_numerical_data(tf1) # doctest: +SKIP + ([0.0, 0.025413462339411542, 0.0484508722725343, ... , 9.670250533855183, 9.844291913708725, 10.0], + [0.0, 0.023844582399907256, 0.042894276802320226, ..., 6.828770759094287e-12, 6.456457160755703e-12]) + + See Also + ======== + + step_response_plot + + """ + if lower_limit < 0: + raise ValueError("Lower limit of time must be greater " + "than or equal to zero.") + _check_system(system) + _x = Dummy("x") + expr = system.to_expr()/(system.var) + expr = apart(expr, system.var, full=True) + _y = _fast_inverse_laplace(expr, system.var, _x).evalf(prec) + return LineOver1DRangeSeries(_y, (_x, lower_limit, upper_limit), + **kwargs).get_points() + + +def step_response_plot(system, color='b', prec=8, lower_limit=0, + upper_limit=10, show_axes=False, grid=True, show=True, **kwargs): + r""" + Returns the unit step response of a continuous-time system. It is + the response of the system when the input signal is a step function. + + Parameters + ========== + + system : SISOLinearTimeInvariant type + The LTI SISO system for which the Step Response is to be computed. + color : str, tuple, optional + The color of the line. Default is Blue. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import step_response_plot + >>> tf1 = TransferFunction(8*s**2 + 18*s + 32, s**3 + 6*s**2 + 14*s + 24, s) + >>> step_response_plot(tf1) # doctest: +SKIP + + See Also + ======== + + impulse_response_plot, ramp_response_plot + + References + ========== + + .. [1] https://www.mathworks.com/help/control/ref/lti.step.html + + """ + x, y = step_response_numerical_data(system, prec=prec, + lower_limit=lower_limit, upper_limit=upper_limit, **kwargs) + plt.plot(x, y, color=color) + plt.xlabel('Time (s)') + plt.ylabel('Amplitude') + plt.title(f'Unit Step Response of ${latex(system)}$', pad=20) + + if grid: + plt.grid() + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def impulse_response_numerical_data(system, prec=8, lower_limit=0, + upper_limit=10, **kwargs): + """ + Returns the numerical values of the points in the impulse response plot + of a SISO continuous-time system. By default, adaptive sampling + is used. If the user wants to instead get an uniformly + sampled response, then ``adaptive`` kwarg should be passed ``False`` + and ``n`` must be passed as additional kwargs. + Refer to the parameters of class :class:`sympy.plotting.series.LineOver1DRangeSeries` + for more details. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the impulse response data is to be computed. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + kwargs : + Additional keyword arguments are passed to the underlying + :class:`sympy.plotting.series.LineOver1DRangeSeries` class. + + Returns + ======= + + tuple : (x, y) + x = Time-axis values of the points in the impulse response. NumPy array. + y = Amplitude-axis values of the points in the impulse response. NumPy array. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When ``lower_limit`` parameter is less than 0. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import impulse_response_numerical_data + >>> tf1 = TransferFunction(s, s**2 + 5*s + 8, s) + >>> impulse_response_numerical_data(tf1) # doctest: +SKIP + ([0.0, 0.06616480200395854,... , 9.854500743565858, 10.0], + [0.9999999799999999, 0.7042848373025861,...,7.170748906965121e-13, -5.1901263495547205e-12]) + + See Also + ======== + + impulse_response_plot + + """ + if lower_limit < 0: + raise ValueError("Lower limit of time must be greater " + "than or equal to zero.") + _check_system(system) + _x = Dummy("x") + expr = system.to_expr() + expr = apart(expr, system.var, full=True) + _y = _fast_inverse_laplace(expr, system.var, _x).evalf(prec) + return LineOver1DRangeSeries(_y, (_x, lower_limit, upper_limit), + **kwargs).get_points() + + +def impulse_response_plot(system, color='b', prec=8, lower_limit=0, + upper_limit=10, show_axes=False, grid=True, show=True, **kwargs): + r""" + Returns the unit impulse response (Input is the Dirac-Delta Function) of a + continuous-time system. + + Parameters + ========== + + system : SISOLinearTimeInvariant type + The LTI SISO system for which the Impulse Response is to be computed. + color : str, tuple, optional + The color of the line. Default is Blue. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import impulse_response_plot + >>> tf1 = TransferFunction(8*s**2 + 18*s + 32, s**3 + 6*s**2 + 14*s + 24, s) + >>> impulse_response_plot(tf1) # doctest: +SKIP + + See Also + ======== + + step_response_plot, ramp_response_plot + + References + ========== + + .. [1] https://www.mathworks.com/help/control/ref/dynamicsystem.impulse.html + + """ + x, y = impulse_response_numerical_data(system, prec=prec, + lower_limit=lower_limit, upper_limit=upper_limit, **kwargs) + plt.plot(x, y, color=color) + plt.xlabel('Time (s)') + plt.ylabel('Amplitude') + plt.title(f'Impulse Response of ${latex(system)}$', pad=20) + + if grid: + plt.grid() + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def ramp_response_numerical_data(system, slope=1, prec=8, + lower_limit=0, upper_limit=10, **kwargs): + """ + Returns the numerical values of the points in the ramp response plot + of a SISO continuous-time system. By default, adaptive sampling + is used. If the user wants to instead get an uniformly + sampled response, then ``adaptive`` kwarg should be passed ``False`` + and ``n`` must be passed as additional kwargs. + Refer to the parameters of class :class:`sympy.plotting.series.LineOver1DRangeSeries` + for more details. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the ramp response data is to be computed. + slope : Number, optional + The slope of the input ramp function. Defaults to 1. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + kwargs : + Additional keyword arguments are passed to the underlying + :class:`sympy.plotting.series.LineOver1DRangeSeries` class. + + Returns + ======= + + tuple : (x, y) + x = Time-axis values of the points in the ramp response plot. NumPy array. + y = Amplitude-axis values of the points in the ramp response plot. NumPy array. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When ``lower_limit`` parameter is less than 0. + + When ``slope`` is negative. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import ramp_response_numerical_data + >>> tf1 = TransferFunction(s, s**2 + 5*s + 8, s) + >>> ramp_response_numerical_data(tf1) # doctest: +SKIP + (([0.0, 0.12166980856813935,..., 9.861246379582118, 10.0], + [1.4504508011325967e-09, 0.006046440489058766,..., 0.12499999999568202, 0.12499999999661349])) + + See Also + ======== + + ramp_response_plot + + """ + if slope < 0: + raise ValueError("Slope must be greater than or equal" + " to zero.") + if lower_limit < 0: + raise ValueError("Lower limit of time must be greater " + "than or equal to zero.") + _check_system(system) + _x = Dummy("x") + expr = (slope*system.to_expr())/((system.var)**2) + expr = apart(expr, system.var, full=True) + _y = _fast_inverse_laplace(expr, system.var, _x).evalf(prec) + return LineOver1DRangeSeries(_y, (_x, lower_limit, upper_limit), + **kwargs).get_points() + + +def ramp_response_plot(system, slope=1, color='b', prec=8, lower_limit=0, + upper_limit=10, show_axes=False, grid=True, show=True, **kwargs): + r""" + Returns the ramp response of a continuous-time system. + + Ramp function is defined as the straight line + passing through origin ($f(x) = mx$). The slope of + the ramp function can be varied by the user and + the default value is 1. + + Parameters + ========== + + system : SISOLinearTimeInvariant type + The LTI SISO system for which the Ramp Response is to be computed. + slope : Number, optional + The slope of the input ramp function. Defaults to 1. + color : str, tuple, optional + The color of the line. Default is Blue. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import ramp_response_plot + >>> tf1 = TransferFunction(s, (s+4)*(s+8), s) + >>> ramp_response_plot(tf1, upper_limit=2) # doctest: +SKIP + + See Also + ======== + + step_response_plot, impulse_response_plot + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Ramp_function + + """ + x, y = ramp_response_numerical_data(system, slope=slope, prec=prec, + lower_limit=lower_limit, upper_limit=upper_limit, **kwargs) + plt.plot(x, y, color=color) + plt.xlabel('Time (s)') + plt.ylabel('Amplitude') + plt.title(f'Ramp Response of ${latex(system)}$ [Slope = {slope}]', pad=20) + + if grid: + plt.grid() + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def bode_magnitude_numerical_data(system, initial_exp=-5, final_exp=5, freq_unit='rad/sec', **kwargs): + """ + Returns the numerical data of the Bode magnitude plot of the system. + It is internally used by ``bode_magnitude_plot`` to get the data + for plotting Bode magnitude plot. Users can use this data to further + analyse the dynamics of the system or plot using a different + backend/plotting-module. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the data is to be computed. + initial_exp : Number, optional + The initial exponent of 10 of the semilog plot. Defaults to -5. + final_exp : Number, optional + The final exponent of 10 of the semilog plot. Defaults to 5. + freq_unit : string, optional + User can choose between ``'rad/sec'`` (radians/second) and ``'Hz'`` (Hertz) as frequency units. + + Returns + ======= + + tuple : (x, y) + x = x-axis values of the Bode magnitude plot. + y = y-axis values of the Bode magnitude plot. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When incorrect frequency units are given as input. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import bode_magnitude_numerical_data + >>> tf1 = TransferFunction(s**2 + 1, s**4 + 4*s**3 + 6*s**2 + 5*s + 2, s) + >>> bode_magnitude_numerical_data(tf1) # doctest: +SKIP + ([1e-05, 1.5148378120533502e-05,..., 68437.36188804005, 100000.0], + [-6.020599914256786, -6.0205999155219505,..., -193.4117304087953, -200.00000000260573]) + + See Also + ======== + + bode_magnitude_plot, bode_phase_numerical_data + + """ + _check_system(system) + expr = system.to_expr() + freq_units = ('rad/sec', 'Hz') + if freq_unit not in freq_units: + raise ValueError('Only "rad/sec" and "Hz" are accepted frequency units.') + + _w = Dummy("w", real=True) + if freq_unit == 'Hz': + repl = I*_w*2*pi + else: + repl = I*_w + w_expr = expr.subs({system.var: repl}) + + mag = 20*log(Abs(w_expr), 10) + + x, y = LineOver1DRangeSeries(mag, + (_w, 10**initial_exp, 10**final_exp), xscale='log', **kwargs).get_points() + + return x, y + + +def bode_magnitude_plot(system, initial_exp=-5, final_exp=5, + color='b', show_axes=False, grid=True, show=True, freq_unit='rad/sec', **kwargs): + r""" + Returns the Bode magnitude plot of a continuous-time system. + + See ``bode_plot`` for all the parameters. + """ + x, y = bode_magnitude_numerical_data(system, initial_exp=initial_exp, + final_exp=final_exp, freq_unit=freq_unit) + plt.plot(x, y, color=color, **kwargs) + plt.xscale('log') + + + plt.xlabel('Frequency (%s) [Log Scale]' % freq_unit) + plt.ylabel('Magnitude (dB)') + plt.title(f'Bode Plot (Magnitude) of ${latex(system)}$', pad=20) + + if grid: + plt.grid(True) + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def bode_phase_numerical_data(system, initial_exp=-5, final_exp=5, freq_unit='rad/sec', phase_unit='rad', phase_unwrap = True, **kwargs): + """ + Returns the numerical data of the Bode phase plot of the system. + It is internally used by ``bode_phase_plot`` to get the data + for plotting Bode phase plot. Users can use this data to further + analyse the dynamics of the system or plot using a different + backend/plotting-module. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the Bode phase plot data is to be computed. + initial_exp : Number, optional + The initial exponent of 10 of the semilog plot. Defaults to -5. + final_exp : Number, optional + The final exponent of 10 of the semilog plot. Defaults to 5. + freq_unit : string, optional + User can choose between ``'rad/sec'`` (radians/second) and '``'Hz'`` (Hertz) as frequency units. + phase_unit : string, optional + User can choose between ``'rad'`` (radians) and ``'deg'`` (degree) as phase units. + phase_unwrap : bool, optional + Set to ``True`` by default. + + Returns + ======= + + tuple : (x, y) + x = x-axis values of the Bode phase plot. + y = y-axis values of the Bode phase plot. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When incorrect frequency or phase units are given as input. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import bode_phase_numerical_data + >>> tf1 = TransferFunction(s**2 + 1, s**4 + 4*s**3 + 6*s**2 + 5*s + 2, s) + >>> bode_phase_numerical_data(tf1) # doctest: +SKIP + ([1e-05, 1.4472354033813751e-05, 2.035581932165858e-05,..., 47577.3248186011, 67884.09326036123, 100000.0], + [-2.5000000000291665e-05, -3.6180885085e-05, -5.08895483066e-05,...,-3.1415085799262523, -3.14155265358979]) + + See Also + ======== + + bode_magnitude_plot, bode_phase_numerical_data + + """ + _check_system(system) + expr = system.to_expr() + freq_units = ('rad/sec', 'Hz') + phase_units = ('rad', 'deg') + if freq_unit not in freq_units: + raise ValueError('Only "rad/sec" and "Hz" are accepted frequency units.') + if phase_unit not in phase_units: + raise ValueError('Only "rad" and "deg" are accepted phase units.') + + _w = Dummy("w", real=True) + if freq_unit == 'Hz': + repl = I*_w*2*pi + else: + repl = I*_w + w_expr = expr.subs({system.var: repl}) + + if phase_unit == 'deg': + phase = arg(w_expr)*180/pi + else: + phase = arg(w_expr) + + x, y = LineOver1DRangeSeries(phase, + (_w, 10**initial_exp, 10**final_exp), xscale='log', **kwargs).get_points() + + half = None + if phase_unwrap: + if(phase_unit == 'rad'): + half = pi + elif(phase_unit == 'deg'): + half = 180 + if half: + unit = 2*half + for i in range(1, len(y)): + diff = y[i] - y[i - 1] + if diff > half: # Jump from -half to half + y[i] = (y[i] - unit) + elif diff < -half: # Jump from half to -half + y[i] = (y[i] + unit) + + return x, y + + +def bode_phase_plot(system, initial_exp=-5, final_exp=5, + color='b', show_axes=False, grid=True, show=True, freq_unit='rad/sec', phase_unit='rad', phase_unwrap=True, **kwargs): + r""" + Returns the Bode phase plot of a continuous-time system. + + See ``bode_plot`` for all the parameters. + """ + x, y = bode_phase_numerical_data(system, initial_exp=initial_exp, + final_exp=final_exp, freq_unit=freq_unit, phase_unit=phase_unit, phase_unwrap=phase_unwrap) + plt.plot(x, y, color=color, **kwargs) + plt.xscale('log') + + plt.xlabel('Frequency (%s) [Log Scale]' % freq_unit) + plt.ylabel('Phase (%s)' % phase_unit) + plt.title(f'Bode Plot (Phase) of ${latex(system)}$', pad=20) + + if grid: + plt.grid(True) + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def bode_plot(system, initial_exp=-5, final_exp=5, + grid=True, show_axes=False, show=True, freq_unit='rad/sec', phase_unit='rad', phase_unwrap=True, **kwargs): + r""" + Returns the Bode phase and magnitude plots of a continuous-time system. + + Parameters + ========== + + system : SISOLinearTimeInvariant type + The LTI SISO system for which the Bode Plot is to be computed. + initial_exp : Number, optional + The initial exponent of 10 of the semilog plot. Defaults to -5. + final_exp : Number, optional + The final exponent of 10 of the semilog plot. Defaults to 5. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + freq_unit : string, optional + User can choose between ``'rad/sec'`` (radians/second) and ``'Hz'`` (Hertz) as frequency units. + phase_unit : string, optional + User can choose between ``'rad'`` (radians) and ``'deg'`` (degree) as phase units. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import bode_plot + >>> tf1 = TransferFunction(1*s**2 + 0.1*s + 7.5, 1*s**4 + 0.12*s**3 + 9*s**2, s) + >>> bode_plot(tf1, initial_exp=0.2, final_exp=0.7) # doctest: +SKIP + + See Also + ======== + + bode_magnitude_plot, bode_phase_plot + + """ + plt.subplot(211) + mag = bode_magnitude_plot(system, initial_exp=initial_exp, final_exp=final_exp, + show=False, grid=grid, show_axes=show_axes, + freq_unit=freq_unit, **kwargs) + mag.title(f'Bode Plot of ${latex(system)}$', pad=20) + mag.xlabel(None) + plt.subplot(212) + bode_phase_plot(system, initial_exp=initial_exp, final_exp=final_exp, + show=False, grid=grid, show_axes=show_axes, freq_unit=freq_unit, phase_unit=phase_unit, phase_unwrap=phase_unwrap, **kwargs).title(None) + + if show: + plt.show() + return + + return plt + + +def nyquist_plot_expr(system): + """Function to get the expression for Nyquist plot.""" + s = system.var + w = Dummy('w', real=True) + repl = I * w + expr = system.to_expr() + w_expr = expr.subs({s: repl}) + w_expr = w_expr.as_real_imag() + real_expr = w_expr[0] + imag_expr = w_expr[1] + return real_expr, imag_expr, w + + +def nichols_plot_expr(system): + """Function to get the expression for Nichols plot.""" + s = system.var + w = Dummy('w', real=True) + sys_expr = system.to_expr() + H_jw = sys_expr.subs(s, I*w) + mag_expr = Abs(H_jw) + mag_dB_expr = 20*log(mag_expr, 10) + phase_expr = arg(H_jw) + phase_deg_expr = deg(phase_expr) + return mag_dB_expr, phase_deg_expr, w + + +def nyquist_plot(system, initial_omega=0.01, final_omega=100, show=True, + color='b', **kwargs): + r""" + Generates the Nyquist plot for a continuous-time system. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The LTI SISO system for which the Nyquist plot is to be generated. + initial_omega : float, optional + The starting frequency value. Defaults to 0.01. + final_omega : float, optional + The ending frequency value. Defaults to 100. + show : bool, optional + If True, the plot is displayed. Default is True. + color : str, optional + The color of the Nyquist plot. Default is 'b' (blue). + grid : bool, optional + If True, grid lines are displayed. Default is False. + **kwargs + Additional keyword arguments for customization. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import nyquist_plot + >>> tf1 = TransferFunction(2*s**2 + 5*s + 1, s**2 + 2*s + 3, s) + >>> nyquist_plot(tf1) # doctest: +SKIP + + See Also + ======== + + nichols_plot, bode_plot + + """ + _check_system(system) + real_expr, imag_expr, w = nyquist_plot_expr(system) + w_values = [(w, initial_omega, final_omega)] + p = plot_parametric( + (real_expr, imag_expr), # The curve + (real_expr, -imag_expr), # Its mirror image + *w_values, + show=False, + line_color=color, + adaptive=True, + title=f'Nyquist Plot of ${latex(system)}$', + xlabel='Real Axis', + ylabel='Imaginary Axis', + size=(6, 5), + kwargs=kwargs) + if show: + p.show() + return + return p + + +def nichols_plot(system, initial_omega=0.01, final_omega=100, show=True, color='b', **kwargs): + r""" + Generates the Nichols plot for a LTI system. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The LTI SISO system for which the Nyquist plot is to be generated. + initial_omega : float, optional + The starting frequency value. Defaults to 0.01. + final_omega : float, optional + The ending frequency value. Defaults to 100. + show : bool, optional + If True, the plot is displayed. Default is True. + color : str, optional + The color of the Nyquist plot. Default is 'b' (blue). + grid : bool, optional + If True, grid lines are displayed. Default is False. + **kwargs + Additional keyword arguments for customization. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import nichols_plot + >>> tf1 = TransferFunction(1.5, s**2+14*s+40.02, s) + >>> nichols_plot(tf1) # doctest: +SKIP + + See Also + ======== + + nyquist_plot, bode_plot + + """ + _check_system(system) + magnitude_dB_expr, phase_deg_expr, w = nichols_plot_expr(system) + w_values = [(w, initial_omega, final_omega)] + p = plot_parametric( + (phase_deg_expr, magnitude_dB_expr), + *w_values, + show=False, + line_color=color, + title=f'Nichols Plot of ${latex(system)}$', + xlabel='Phase [deg]', + ylabel='Magnitude [dB]', + size=(6,5), + kwargs=kwargs) + if show: + p.show() + return + return p diff --git a/phivenv/Lib/site-packages/sympy/physics/control/lti.py b/phivenv/Lib/site-packages/sympy/physics/control/lti.py new file mode 100644 index 0000000000000000000000000000000000000000..480a1ec71d8c4dd07a51d67304a0b6e20a90691e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/control/lti.py @@ -0,0 +1,5001 @@ +from typing import Type +from sympy import Interval, numer, Rational, solveset +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.evalf import EvalfMixin +from sympy.core.expr import Expr +from sympy.core.function import expand +from sympy.core.logic import fuzzy_and +from sympy.core.mul import Mul +from sympy.core.numbers import I, pi, oo +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import Dummy, Symbol +from sympy.functions import Abs +from sympy.core.sympify import sympify, _sympify +from sympy.matrices import Matrix, ImmutableMatrix, ImmutableDenseMatrix, eye, ShapeError, zeros +from sympy.functions.elementary.exponential import (exp, log) +from sympy.matrices.expressions import MatMul, MatAdd +from sympy.polys import Poly, rootof +from sympy.polys.polyroots import roots +from sympy.polys.polytools import (cancel, degree) +from sympy.series import limit +from sympy.utilities.misc import filldedent +from sympy.solvers.ode.systems import linodesolve +from sympy.solvers.solveset import linsolve, linear_eq_to_matrix + +from mpmath.libmp.libmpf import prec_to_dps + +__all__ = ['TransferFunction', 'PIDController', 'Series', 'MIMOSeries', 'Parallel', 'MIMOParallel', + 'Feedback', 'MIMOFeedback', 'TransferFunctionMatrix', 'StateSpace', 'gbt', 'bilinear', 'forward_diff', 'backward_diff', + 'phase_margin', 'gain_margin'] + +def _roots(poly, var): + """ like roots, but works on higher-order polynomials. """ + r = roots(poly, var, multiple=True) + n = degree(poly) + if len(r) != n: + r = [rootof(poly, var, k) for k in range(n)] + return r + +def gbt(tf, sample_per, alpha): + r""" + Returns falling coefficients of H(z) from numerator and denominator. + + Explanation + =========== + + Where H(z) is the corresponding discretized transfer function, + discretized with the generalised bilinear transformation method. + H(z) is obtained from the continuous transfer function H(s) + by substituting $s(z) = \frac{z-1}{T(\alpha z + (1-\alpha))}$ into H(s), where T is the + sample period. + Coefficients are falling, i.e. $H(z) = \frac{az+b}{cz+d}$ is returned + as [a, b], [c, d]. + + Examples + ======== + + >>> from sympy.physics.control.lti import TransferFunction, gbt + >>> from sympy.abc import s, L, R, T + + >>> tf = TransferFunction(1, s*L + R, s) + >>> numZ, denZ = gbt(tf, T, 0.5) + >>> numZ + [T/(2*(L + R*T/2)), T/(2*(L + R*T/2))] + >>> denZ + [1, (-L + R*T/2)/(L + R*T/2)] + + >>> numZ, denZ = gbt(tf, T, 0) + >>> numZ + [T/L] + >>> denZ + [1, (-L + R*T)/L] + + >>> numZ, denZ = gbt(tf, T, 1) + >>> numZ + [T/(L + R*T), 0] + >>> denZ + [1, -L/(L + R*T)] + + >>> numZ, denZ = gbt(tf, T, 0.3) + >>> numZ + [3*T/(10*(L + 3*R*T/10)), 7*T/(10*(L + 3*R*T/10))] + >>> denZ + [1, (-L + 7*R*T/10)/(L + 3*R*T/10)] + + References + ========== + + .. [1] https://www.polyu.edu.hk/ama/profile/gfzhang/Research/ZCC09_IJC.pdf + """ + if not tf.is_SISO: + raise NotImplementedError("Not implemented for MIMO systems.") + + T = sample_per # and sample period T + s = tf.var + z = s # dummy discrete variable z + + np = tf.num.as_poly(s).all_coeffs() + dp = tf.den.as_poly(s).all_coeffs() + alpha = Rational(alpha).limit_denominator(1000) + + # The next line results from multiplying H(z) with z^N/z^N + N = max(len(np), len(dp)) - 1 + num = Add(*[ T**(N-i) * c * (z-1)**i * (alpha * z + 1 - alpha)**(N-i) for c, i in zip(np[::-1], range(len(np))) ]) + den = Add(*[ T**(N-i) * c * (z-1)**i * (alpha * z + 1 - alpha)**(N-i) for c, i in zip(dp[::-1], range(len(dp))) ]) + + num_coefs = num.as_poly(z).all_coeffs() + den_coefs = den.as_poly(z).all_coeffs() + + para = den_coefs[0] + num_coefs = [coef/para for coef in num_coefs] + den_coefs = [coef/para for coef in den_coefs] + + return num_coefs, den_coefs + +def bilinear(tf, sample_per): + r""" + Returns falling coefficients of H(z) from numerator and denominator. + + Explanation + =========== + + Where H(z) is the corresponding discretized transfer function, + discretized with the bilinear transform method. + H(z) is obtained from the continuous transfer function H(s) + by substituting $s(z) = \frac{2}{T}\frac{z-1}{z+1}$ into H(s), where T is the + sample period. + Coefficients are falling, i.e. $H(z) = \frac{az+b}{cz+d}$ is returned + as [a, b], [c, d]. + + Examples + ======== + + >>> from sympy.physics.control.lti import TransferFunction, bilinear + >>> from sympy.abc import s, L, R, T + + >>> tf = TransferFunction(1, s*L + R, s) + >>> numZ, denZ = bilinear(tf, T) + >>> numZ + [T/(2*(L + R*T/2)), T/(2*(L + R*T/2))] + >>> denZ + [1, (-L + R*T/2)/(L + R*T/2)] + """ + return gbt(tf, sample_per, S.Half) + +def forward_diff(tf, sample_per): + r""" + Returns falling coefficients of H(z) from numerator and denominator. + + Explanation + =========== + + Where H(z) is the corresponding discretized transfer function, + discretized with the forward difference transform method. + H(z) is obtained from the continuous transfer function H(s) + by substituting $s(z) = \frac{z-1}{T}$ into H(s), where T is the + sample period. + Coefficients are falling, i.e. $H(z) = \frac{az+b}{cz+d}$ is returned + as [a, b], [c, d]. + + Examples + ======== + + >>> from sympy.physics.control.lti import TransferFunction, forward_diff + >>> from sympy.abc import s, L, R, T + + >>> tf = TransferFunction(1, s*L + R, s) + >>> numZ, denZ = forward_diff(tf, T) + >>> numZ + [T/L] + >>> denZ + [1, (-L + R*T)/L] + """ + return gbt(tf, sample_per, S.Zero) + +def backward_diff(tf, sample_per): + r""" + Returns falling coefficients of H(z) from numerator and denominator. + + Explanation + =========== + + Where H(z) is the corresponding discretized transfer function, + discretized with the backward difference transform method. + H(z) is obtained from the continuous transfer function H(s) + by substituting $s(z) = \frac{z-1}{Tz}$ into H(s), where T is the + sample period. + Coefficients are falling, i.e. $H(z) = \frac{az+b}{cz+d}$ is returned + as [a, b], [c, d]. + + Examples + ======== + + >>> from sympy.physics.control.lti import TransferFunction, backward_diff + >>> from sympy.abc import s, L, R, T + + >>> tf = TransferFunction(1, s*L + R, s) + >>> numZ, denZ = backward_diff(tf, T) + >>> numZ + [T/(L + R*T), 0] + >>> denZ + [1, -L/(L + R*T)] + """ + return gbt(tf, sample_per, S.One) + +def phase_margin(system): + r""" + Returns the phase margin of a continuous time system. + Only applicable to Transfer Functions which can generate valid bode plots. + + Raises + ====== + + NotImplementedError + When time delay terms are present in the system. + + ValueError + When a SISO LTI system is not passed. + + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + Examples + ======== + + >>> from sympy.physics.control import TransferFunction, phase_margin + >>> from sympy.abc import s + + >>> tf = TransferFunction(1, s**3 + 2*s**2 + s, s) + >>> phase_margin(tf) + 180*(-pi + atan((-1 + (-2*18**(1/3)/(9 + sqrt(93))**(1/3) + 12**(1/3)*(9 + sqrt(93))**(1/3))**2/36)/(-12**(1/3)*(9 + sqrt(93))**(1/3)/3 + 2*18**(1/3)/(3*(9 + sqrt(93))**(1/3)))))/pi + 180 + >>> phase_margin(tf).n() + 21.3863897518751 + + >>> tf1 = TransferFunction(s**3, s**2 + 5*s, s) + >>> phase_margin(tf1) + -180 + 180*(atan(sqrt(2)*(-51/10 - sqrt(101)/10)*sqrt(1 + sqrt(101))/(2*(sqrt(101)/2 + 51/2))) + pi)/pi + >>> phase_margin(tf1).n() + -25.1783920627277 + + >>> tf2 = TransferFunction(1, s + 1, s) + >>> phase_margin(tf2) + -180 + + See Also + ======== + + gain_margin + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Phase_margin + + """ + from sympy.functions import arg + + if not isinstance(system, SISOLinearTimeInvariant): + raise ValueError("Margins are only applicable for SISO LTI systems.") + + _w = Dummy("w", real=True) + repl = I*_w + expr = system.to_expr() + len_free_symbols = len(expr.free_symbols) + if expr.has(exp): + raise NotImplementedError("Margins for systems with Time delay terms are not supported.") + elif len_free_symbols > 1: + raise ValueError("Extra degree of freedom found. Make sure" + " that there are no free symbols in the dynamical system other" + " than the variable of Laplace transform.") + + w_expr = expr.subs({system.var: repl}) + + mag = 20*log(Abs(w_expr), 10) + mag_sol = list(solveset(mag, _w, Interval(0, oo, left_open=True))) + + if (len(mag_sol) == 0): + pm = S(-180) + else: + wcp = mag_sol[0] + pm = ((arg(w_expr)*S(180)/pi).subs({_w:wcp}) + S(180)) % 360 + + if(pm >= 180): + pm = pm - 360 + + return pm + +def gain_margin(system): + r""" + Returns the gain margin of a continuous time system. + Only applicable to Transfer Functions which can generate valid bode plots. + + Raises + ====== + + NotImplementedError + When time delay terms are present in the system. + + ValueError + When a SISO LTI system is not passed. + + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + Examples + ======== + + >>> from sympy.physics.control import TransferFunction, gain_margin + >>> from sympy.abc import s + + >>> tf = TransferFunction(1, s**3 + 2*s**2 + s, s) + >>> gain_margin(tf) + 20*log(2)/log(10) + >>> gain_margin(tf).n() + 6.02059991327962 + + >>> tf1 = TransferFunction(s**3, s**2 + 5*s, s) + >>> gain_margin(tf1) + oo + + See Also + ======== + + phase_margin + + References + ========== + + https://en.wikipedia.org/wiki/Bode_plot + + """ + if not isinstance(system, SISOLinearTimeInvariant): + raise ValueError("Margins are only applicable for SISO LTI systems.") + + _w = Dummy("w", real=True) + repl = I*_w + expr = system.to_expr() + len_free_symbols = len(expr.free_symbols) + if expr.has(exp): + raise NotImplementedError("Margins for systems with Time delay terms are not supported.") + elif len_free_symbols > 1: + raise ValueError("Extra degree of freedom found. Make sure" + " that there are no free symbols in the dynamical system other" + " than the variable of Laplace transform.") + + w_expr = expr.subs({system.var: repl}) + + mag = 20*log(Abs(w_expr), 10) + phase = w_expr + phase_sol = list(solveset(numer(phase.as_real_imag()[1].cancel()),_w, Interval(0, oo, left_open = True))) + + if (len(phase_sol) == 0): + gm = oo + else: + wcg = phase_sol[0] + gm = -mag.subs({_w:wcg}) + + return gm + +class LinearTimeInvariant(Basic, EvalfMixin): + """A common class for all the Linear Time-Invariant Dynamical Systems.""" + + _clstype: Type + + # Users should not directly interact with this class. + def __new__(cls, *system, **kwargs): + if cls is LinearTimeInvariant: + raise NotImplementedError('The LTICommon class is not meant to be used directly.') + return super(LinearTimeInvariant, cls).__new__(cls, *system, **kwargs) + + @classmethod + def _check_args(cls, args): + if not args: + raise ValueError("At least 1 argument must be passed.") + if not all(isinstance(arg, cls._clstype) for arg in args): + raise TypeError(f"All arguments must be of type {cls._clstype}.") + var_set = {arg.var for arg in args} + if len(var_set) != 1: + raise ValueError(filldedent(f""" + All transfer functions should use the same complex variable + of the Laplace transform. {len(var_set)} different + values found.""")) + + @property + def is_SISO(self): + """Returns `True` if the passed LTI system is SISO else returns False.""" + return self._is_SISO + + +class SISOLinearTimeInvariant(LinearTimeInvariant): + """A common class for all the SISO Linear Time-Invariant Dynamical Systems.""" + # Users should not directly interact with this class. + + @property + def num_inputs(self): + """Return the number of inputs for SISOLinearTimeInvariant.""" + return 1 + + @property + def num_outputs(self): + """Return the number of outputs for SISOLinearTimeInvariant.""" + return 1 + + _is_SISO = True + + +class MIMOLinearTimeInvariant(LinearTimeInvariant): + """A common class for all the MIMO Linear Time-Invariant Dynamical Systems.""" + # Users should not directly interact with this class. + _is_SISO = False + + +SISOLinearTimeInvariant._clstype = SISOLinearTimeInvariant +MIMOLinearTimeInvariant._clstype = MIMOLinearTimeInvariant + + +def _check_other_SISO(func): + def wrapper(*args, **kwargs): + if not isinstance(args[-1], SISOLinearTimeInvariant): + return NotImplemented + else: + return func(*args, **kwargs) + return wrapper + + +def _check_other_MIMO(func): + def wrapper(*args, **kwargs): + if not isinstance(args[-1], MIMOLinearTimeInvariant): + return NotImplemented + else: + return func(*args, **kwargs) + return wrapper + + +class TransferFunction(SISOLinearTimeInvariant): + r""" + A class for representing LTI (Linear, time-invariant) systems that can be strictly described + by ratio of polynomials in the Laplace transform complex variable. The arguments + are ``num``, ``den``, and ``var``, where ``num`` and ``den`` are numerator and + denominator polynomials of the ``TransferFunction`` respectively, and the third argument is + a complex variable of the Laplace transform used by these polynomials of the transfer function. + ``num`` and ``den`` can be either polynomials or numbers, whereas ``var`` + has to be a :py:class:`~.Symbol`. + + Explanation + =========== + + Generally, a dynamical system representing a physical model can be described in terms of Linear + Ordinary Differential Equations like - + + $b_{m}y^{\left(m\right)}+b_{m-1}y^{\left(m-1\right)}+\dots+b_{1}y^{\left(1\right)}+b_{0}y= + a_{n}x^{\left(n\right)}+a_{n-1}x^{\left(n-1\right)}+\dots+a_{1}x^{\left(1\right)}+a_{0}x$ + + Here, $x$ is the input signal and $y$ is the output signal and superscript on both is the order of derivative + (not exponent). Derivative is taken with respect to the independent variable, $t$. Also, generally $m$ is greater + than $n$. + + It is not feasible to analyse the properties of such systems in their native form therefore, we use + mathematical tools like Laplace transform to get a better perspective. Taking the Laplace transform + of both the sides in the equation (at zero initial conditions), we get - + + $\mathcal{L}[b_{m}y^{\left(m\right)}+b_{m-1}y^{\left(m-1\right)}+\dots+b_{1}y^{\left(1\right)}+b_{0}y]= + \mathcal{L}[a_{n}x^{\left(n\right)}+a_{n-1}x^{\left(n-1\right)}+\dots+a_{1}x^{\left(1\right)}+a_{0}x]$ + + Using the linearity property of Laplace transform and also considering zero initial conditions + (i.e. $y(0^{-}) = 0$, $y'(0^{-}) = 0$ and so on), the equation + above gets translated to - + + $b_{m}\mathcal{L}[y^{\left(m\right)}]+\dots+b_{1}\mathcal{L}[y^{\left(1\right)}]+b_{0}\mathcal{L}[y]= + a_{n}\mathcal{L}[x^{\left(n\right)}]+\dots+a_{1}\mathcal{L}[x^{\left(1\right)}]+a_{0}\mathcal{L}[x]$ + + Now, applying Derivative property of Laplace transform, + + $b_{m}s^{m}\mathcal{L}[y]+\dots+b_{1}s\mathcal{L}[y]+b_{0}\mathcal{L}[y]= + a_{n}s^{n}\mathcal{L}[x]+\dots+a_{1}s\mathcal{L}[x]+a_{0}\mathcal{L}[x]$ + + Here, the superscript on $s$ is **exponent**. Note that the zero initial conditions assumption, mentioned above, is very important + and cannot be ignored otherwise the dynamical system cannot be considered time-independent and the simplified equation above + cannot be reached. + + Collecting $\mathcal{L}[y]$ and $\mathcal{L}[x]$ terms from both the sides and taking the ratio + $\frac{ \mathcal{L}\left\{y\right\} }{ \mathcal{L}\left\{x\right\} }$, we get the typical rational form of transfer + function. + + The numerator of the transfer function is, therefore, the Laplace transform of the output signal + (The signals are represented as functions of time) and similarly, the denominator + of the transfer function is the Laplace transform of the input signal. It is also a convention + to denote the input and output signal's Laplace transform with capital alphabets like shown below. + + $H(s) = \frac{Y(s)}{X(s)} = \frac{ \mathcal{L}\left\{y(t)\right\} }{ \mathcal{L}\left\{x(t)\right\} }$ + + $s$, also known as complex frequency, is a complex variable in the Laplace domain. It corresponds to the + equivalent variable $t$, in the time domain. Transfer functions are sometimes also referred to as the Laplace + transform of the system's impulse response. Transfer function, $H$, is represented as a rational + function in $s$ like, + + $H(s) =\ \frac{a_{n}s^{n}+a_{n-1}s^{n-1}+\dots+a_{1}s+a_{0}}{b_{m}s^{m}+b_{m-1}s^{m-1}+\dots+b_{1}s+b_{0}}$ + + Parameters + ========== + + num : Expr, Number + The numerator polynomial of the transfer function. + den : Expr, Number + The denominator polynomial of the transfer function. + var : Symbol + Complex variable of the Laplace transform used by the + polynomials of the transfer function. + + Raises + ====== + + TypeError + When ``var`` is not a Symbol or when ``num`` or ``den`` is not a + number or a polynomial. + ValueError + When ``den`` is zero. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(s + a, s**2 + s + 1, s) + >>> tf1 + TransferFunction(a + s, s**2 + s + 1, s) + >>> tf1.num + a + s + >>> tf1.den + s**2 + s + 1 + >>> tf1.var + s + >>> tf1.args + (a + s, s**2 + s + 1, s) + + Any complex variable can be used for ``var``. + + >>> tf2 = TransferFunction(a*p**3 - a*p**2 + s*p, p + a**2, p) + >>> tf2 + TransferFunction(a*p**3 - a*p**2 + p*s, a**2 + p, p) + >>> tf3 = TransferFunction((p + 3)*(p - 1), (p - 1)*(p + 5), p) + >>> tf3 + TransferFunction((p - 1)*(p + 3), (p - 1)*(p + 5), p) + + To negate a transfer function the ``-`` operator can be prepended: + + >>> tf4 = TransferFunction(-a + s, p**2 + s, p) + >>> -tf4 + TransferFunction(a - s, p**2 + s, p) + >>> tf5 = TransferFunction(s**4 - 2*s**3 + 5*s + 4, s + 4, s) + >>> -tf5 + TransferFunction(-s**4 + 2*s**3 - 5*s - 4, s + 4, s) + + You can use a float or an integer (or other constants) as numerator and denominator: + + >>> tf6 = TransferFunction(1/2, 4, s) + >>> tf6.num + 0.500000000000000 + >>> tf6.den + 4 + >>> tf6.var + s + >>> tf6.args + (0.5, 4, s) + + You can take the integer power of a transfer function using the ``**`` operator: + + >>> tf7 = TransferFunction(s + a, s - a, s) + >>> tf7**3 + TransferFunction((a + s)**3, (-a + s)**3, s) + >>> tf7**0 + TransferFunction(1, 1, s) + >>> tf8 = TransferFunction(p + 4, p - 3, p) + >>> tf8**-1 + TransferFunction(p - 3, p + 4, p) + + Addition, subtraction, and multiplication of transfer functions can form + unevaluated ``Series`` or ``Parallel`` objects. + + >>> tf9 = TransferFunction(s + 1, s**2 + s + 1, s) + >>> tf10 = TransferFunction(s - p, s + 3, s) + >>> tf11 = TransferFunction(4*s**2 + 2*s - 4, s - 1, s) + >>> tf12 = TransferFunction(1 - s, s**2 + 4, s) + >>> tf9 + tf10 + Parallel(TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(-p + s, s + 3, s)) + >>> tf10 - tf11 + Parallel(TransferFunction(-p + s, s + 3, s), TransferFunction(-4*s**2 - 2*s + 4, s - 1, s)) + >>> tf9 * tf10 + Series(TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(-p + s, s + 3, s)) + >>> tf10 - (tf9 + tf12) + Parallel(TransferFunction(-p + s, s + 3, s), TransferFunction(-s - 1, s**2 + s + 1, s), TransferFunction(s - 1, s**2 + 4, s)) + >>> tf10 - (tf9 * tf12) + Parallel(TransferFunction(-p + s, s + 3, s), Series(TransferFunction(-1, 1, s), TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(1 - s, s**2 + 4, s))) + >>> tf11 * tf10 * tf9 + Series(TransferFunction(4*s**2 + 2*s - 4, s - 1, s), TransferFunction(-p + s, s + 3, s), TransferFunction(s + 1, s**2 + s + 1, s)) + >>> tf9 * tf11 + tf10 * tf12 + Parallel(Series(TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(4*s**2 + 2*s - 4, s - 1, s)), Series(TransferFunction(-p + s, s + 3, s), TransferFunction(1 - s, s**2 + 4, s))) + >>> (tf9 + tf12) * (tf10 + tf11) + Series(Parallel(TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(1 - s, s**2 + 4, s)), Parallel(TransferFunction(-p + s, s + 3, s), TransferFunction(4*s**2 + 2*s - 4, s - 1, s))) + + These unevaluated ``Series`` or ``Parallel`` objects can convert into the + resultant transfer function using ``.doit()`` method or by ``.rewrite(TransferFunction)``. + + >>> ((tf9 + tf10) * tf12).doit() + TransferFunction((1 - s)*((-p + s)*(s**2 + s + 1) + (s + 1)*(s + 3)), (s + 3)*(s**2 + 4)*(s**2 + s + 1), s) + >>> (tf9 * tf10 - tf11 * tf12).rewrite(TransferFunction) + TransferFunction(-(1 - s)*(s + 3)*(s**2 + s + 1)*(4*s**2 + 2*s - 4) + (-p + s)*(s - 1)*(s + 1)*(s**2 + 4), (s - 1)*(s + 3)*(s**2 + 4)*(s**2 + s + 1), s) + + See Also + ======== + + Feedback, Series, Parallel + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Transfer_function + .. [2] https://en.wikipedia.org/wiki/Laplace_transform + + """ + def __new__(cls, num, den, var): + num, den = _sympify(num), _sympify(den) + + if not isinstance(var, Symbol): + raise TypeError("Variable input must be a Symbol.") + + if den == 0: + raise ValueError("TransferFunction cannot have a zero denominator.") + + if (((isinstance(num, (Expr, TransferFunction, Series, Parallel)) and num.has(Symbol)) or num.is_number) and + ((isinstance(den, (Expr, TransferFunction, Series, Parallel)) and den.has(Symbol)) or den.is_number)): + cls.is_StateSpace_object = False + return super(TransferFunction, cls).__new__(cls, num, den, var) + + else: + raise TypeError("Unsupported type for numerator or denominator of TransferFunction.") + + @classmethod + def from_rational_expression(cls, expr, var=None): + r""" + Creates a new ``TransferFunction`` efficiently from a rational expression. + + Parameters + ========== + + expr : Expr, Number + The rational expression representing the ``TransferFunction``. + var : Symbol, optional + Complex variable of the Laplace transform used by the + polynomials of the transfer function. + + Raises + ====== + + ValueError + When ``expr`` is of type ``Number`` and optional parameter ``var`` + is not passed. + + When ``expr`` has more than one variables and an optional parameter + ``var`` is not passed. + ZeroDivisionError + When denominator of ``expr`` is zero or it has ``ComplexInfinity`` + in its numerator. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> expr1 = (s + 5)/(3*s**2 + 2*s + 1) + >>> tf1 = TransferFunction.from_rational_expression(expr1) + >>> tf1 + TransferFunction(s + 5, 3*s**2 + 2*s + 1, s) + >>> expr2 = (a*p**3 - a*p**2 + s*p)/(p + a**2) # Expr with more than one variables + >>> tf2 = TransferFunction.from_rational_expression(expr2, p) + >>> tf2 + TransferFunction(a*p**3 - a*p**2 + p*s, a**2 + p, p) + + In case of conflict between two or more variables in a expression, SymPy will + raise a ``ValueError``, if ``var`` is not passed by the user. + + >>> tf = TransferFunction.from_rational_expression((a + a*s)/(s**2 + s + 1)) + Traceback (most recent call last): + ... + ValueError: Conflicting values found for positional argument `var` ({a, s}). Specify it manually. + + This can be corrected by specifying the ``var`` parameter manually. + + >>> tf = TransferFunction.from_rational_expression((a + a*s)/(s**2 + s + 1), s) + >>> tf + TransferFunction(a*s + a, s**2 + s + 1, s) + + ``var`` also need to be specified when ``expr`` is a ``Number`` + + >>> tf3 = TransferFunction.from_rational_expression(10, s) + >>> tf3 + TransferFunction(10, 1, s) + + """ + expr = _sympify(expr) + if var is None: + _free_symbols = expr.free_symbols + _len_free_symbols = len(_free_symbols) + if _len_free_symbols == 1: + var = list(_free_symbols)[0] + elif _len_free_symbols == 0: + raise ValueError(filldedent(""" + Positional argument `var` not found in the + TransferFunction defined. Specify it manually.""")) + else: + raise ValueError(filldedent(""" + Conflicting values found for positional argument `var` ({}). + Specify it manually.""".format(_free_symbols))) + + _num, _den = expr.as_numer_denom() + if _den == 0 or _num.has(S.ComplexInfinity): + raise ZeroDivisionError("TransferFunction cannot have a zero denominator.") + return cls(_num, _den, var) + + @classmethod + def from_coeff_lists(cls, num_list, den_list, var): + r""" + Creates a new ``TransferFunction`` efficiently from a list of coefficients. + + Parameters + ========== + + num_list : Sequence + Sequence comprising of numerator coefficients. + den_list : Sequence + Sequence comprising of denominator coefficients. + var : Symbol + Complex variable of the Laplace transform used by the + polynomials of the transfer function. + + Raises + ====== + + ZeroDivisionError + When the constructed denominator is zero. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction + >>> num = [1, 0, 2] + >>> den = [3, 2, 2, 1] + >>> tf = TransferFunction.from_coeff_lists(num, den, s) + >>> tf + TransferFunction(s**2 + 2, 3*s**3 + 2*s**2 + 2*s + 1, s) + >>> #Create a Transfer Function with more than one variable + >>> tf1 = TransferFunction.from_coeff_lists([p, 1], [2*p, 0, 4], s) + >>> tf1 + TransferFunction(p*s + 1, 2*p*s**2 + 4, s) + + """ + num_list = num_list[::-1] + den_list = den_list[::-1] + num_var_powers = [var**i for i in range(len(num_list))] + den_var_powers = [var**i for i in range(len(den_list))] + + _num = sum(coeff * var_power for coeff, var_power in zip(num_list, num_var_powers)) + _den = sum(coeff * var_power for coeff, var_power in zip(den_list, den_var_powers)) + + if _den == 0: + raise ZeroDivisionError("TransferFunction cannot have a zero denominator.") + + return cls(_num, _den, var) + + @classmethod + def from_zpk(cls, zeros, poles, gain, var): + r""" + Creates a new ``TransferFunction`` from given zeros, poles and gain. + + Parameters + ========== + + zeros : Sequence + Sequence comprising of zeros of transfer function. + poles : Sequence + Sequence comprising of poles of transfer function. + gain : Number, Symbol, Expression + A scalar value specifying gain of the model. + var : Symbol + Complex variable of the Laplace transform used by the + polynomials of the transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p, k + >>> from sympy.physics.control.lti import TransferFunction + >>> zeros = [1, 2, 3] + >>> poles = [6, 5, 4] + >>> gain = 7 + >>> tf = TransferFunction.from_zpk(zeros, poles, gain, s) + >>> tf + TransferFunction(7*(s - 3)*(s - 2)*(s - 1), (s - 6)*(s - 5)*(s - 4), s) + >>> #Create a Transfer Function with variable poles and zeros + >>> tf1 = TransferFunction.from_zpk([p, k], [p + k, p - k], 2, s) + >>> tf1 + TransferFunction(2*(-k + s)*(-p + s), (-k - p + s)*(k - p + s), s) + >>> #Complex poles or zeros are acceptable + >>> tf2 = TransferFunction.from_zpk([0], [1-1j, 1+1j, 2], -2, s) + >>> tf2 + TransferFunction(-2*s, (s - 2)*(s - 1.0 - 1.0*I)*(s - 1.0 + 1.0*I), s) + + """ + num_poly = 1 + den_poly = 1 + for zero in zeros: + num_poly *= var - zero + for pole in poles: + den_poly *= var - pole + + return cls(gain*num_poly, den_poly, var) + + @property + def num(self): + """ + Returns the numerator polynomial of the transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction + >>> G1 = TransferFunction(s**2 + p*s + 3, s - 4, s) + >>> G1.num + p*s + s**2 + 3 + >>> G2 = TransferFunction((p + 5)*(p - 3), (p - 3)*(p + 1), p) + >>> G2.num + (p - 3)*(p + 5) + + """ + return self.args[0] + + @property + def den(self): + """ + Returns the denominator polynomial of the transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction + >>> G1 = TransferFunction(s + 4, p**3 - 2*p + 4, s) + >>> G1.den + p**3 - 2*p + 4 + >>> G2 = TransferFunction(3, 4, s) + >>> G2.den + 4 + + """ + return self.args[1] + + @property + def var(self): + """ + Returns the complex variable of the Laplace transform used by the polynomials of + the transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G1.var + p + >>> G2 = TransferFunction(0, s - 5, s) + >>> G2.var + s + + """ + return self.args[2] + + def _eval_subs(self, old, new): + arg_num = self.num.subs(old, new) + arg_den = self.den.subs(old, new) + argnew = TransferFunction(arg_num, arg_den, self.var) + return self if old == self.var else argnew + + def _eval_evalf(self, prec): + return TransferFunction( + self.num._eval_evalf(prec), + self.den._eval_evalf(prec), + self.var) + + def _eval_simplify(self, **kwargs): + tf = cancel(Mul(self.num, 1/self.den, evaluate=False), expand=False).as_numer_denom() + num_, den_ = tf[0], tf[1] + return TransferFunction(num_, den_, self.var) + + def _eval_rewrite_as_StateSpace(self, *args): + """ + Returns the equivalent space model of the transfer function model. + The state space model will be returned in the controllable canonical form. + + Unlike the space state to transfer function model conversion, the transfer function + to state space model conversion is not unique. There can be multiple state space + representations of a given transfer function model. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control import TransferFunction, StateSpace + >>> tf = TransferFunction(s**2 + 1, s**3 + 2*s + 10, s) + >>> tf.rewrite(StateSpace) + StateSpace(Matrix([ + [ 0, 1, 0], + [ 0, 0, 1], + [-10, -2, 0]]), Matrix([ + [0], + [0], + [1]]), Matrix([[1, 0, 1]]), Matrix([[0]])) + + """ + if not self.is_proper: + raise ValueError("Transfer Function must be proper.") + + num_poly = Poly(self.num, self.var) + den_poly = Poly(self.den, self.var) + n = den_poly.degree() + + num_coeffs = num_poly.all_coeffs() + den_coeffs = den_poly.all_coeffs() + diff = n - num_poly.degree() + num_coeffs = [0]*diff + num_coeffs + + a = den_coeffs[1:] + a_mat = Matrix([[(-1)*coefficient/den_coeffs[0] for coefficient in reversed(a)]]) + vert = zeros(n-1, 1) + mat = eye(n-1) + A = vert.row_join(mat) + A = A.col_join(a_mat) + + B = zeros(n, 1) + B[n-1] = 1 + + i = n + C = [] + while(i > 0): + C.append(num_coeffs[i] - den_coeffs[i]*num_coeffs[0]) + i -= 1 + C = Matrix([C]) + + D = Matrix([num_coeffs[0]]) + + return StateSpace(A, B, C, D) + + def expand(self): + """ + Returns the transfer function with numerator and denominator + in expanded form. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> G1 = TransferFunction((a - s)**2, (s**2 + a)**2, s) + >>> G1.expand() + TransferFunction(a**2 - 2*a*s + s**2, a**2 + 2*a*s**2 + s**4, s) + >>> G2 = TransferFunction((p + 3*b)*(p - b), (p - b)*(p + 2*b), p) + >>> G2.expand() + TransferFunction(-3*b**2 + 2*b*p + p**2, -2*b**2 + b*p + p**2, p) + + """ + return TransferFunction(expand(self.num), expand(self.den), self.var) + + def dc_gain(self): + """ + Computes the gain of the response as the frequency approaches zero. + + The DC gain is infinite for systems with pure integrators. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(s + 3, s**2 - 9, s) + >>> tf1.dc_gain() + -1/3 + >>> tf2 = TransferFunction(p**2, p - 3 + p**3, p) + >>> tf2.dc_gain() + 0 + >>> tf3 = TransferFunction(a*p**2 - b, s + b, s) + >>> tf3.dc_gain() + (a*p**2 - b)/b + >>> tf4 = TransferFunction(1, s, s) + >>> tf4.dc_gain() + oo + + """ + m = Mul(self.num, Pow(self.den, -1, evaluate=False), evaluate=False) + return limit(m, self.var, 0) + + def poles(self): + """ + Returns the poles of a transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction((p + 3)*(p - 1), (p - 1)*(p + 5), p) + >>> tf1.poles() + [-5, 1] + >>> tf2 = TransferFunction((1 - s)**2, (s**2 + 1)**2, s) + >>> tf2.poles() + [I, I, -I, -I] + >>> tf3 = TransferFunction(s**2, a*s + p, s) + >>> tf3.poles() + [-p/a] + + """ + return _roots(Poly(self.den, self.var), self.var) + + def zeros(self): + """ + Returns the zeros of a transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction((p + 3)*(p - 1), (p - 1)*(p + 5), p) + >>> tf1.zeros() + [-3, 1] + >>> tf2 = TransferFunction((1 - s)**2, (s**2 + 1)**2, s) + >>> tf2.zeros() + [1, 1] + >>> tf3 = TransferFunction(s**2, a*s + p, s) + >>> tf3.zeros() + [0, 0] + + """ + return _roots(Poly(self.num, self.var), self.var) + + def eval_frequency(self, other): + """ + Returns the system response at any point in the real or complex plane. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy import I + >>> tf1 = TransferFunction(1, s**2 + 2*s + 1, s) + >>> omega = 0.1 + >>> tf1.eval_frequency(I*omega) + 1/(0.99 + 0.2*I) + >>> tf2 = TransferFunction(s**2, a*s + p, s) + >>> tf2.eval_frequency(2) + 4/(2*a + p) + >>> tf2.eval_frequency(I*2) + -4/(2*I*a + p) + """ + arg_num = self.num.subs(self.var, other) + arg_den = self.den.subs(self.var, other) + argnew = TransferFunction(arg_num, arg_den, self.var).to_expr() + return argnew.expand() + + def is_stable(self): + """ + Returns True if the transfer function is asymptotically stable; else False. + + This would not check the marginal or conditional stability of the system. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy import symbols + >>> from sympy.physics.control.lti import TransferFunction + >>> q, r = symbols('q, r', negative=True) + >>> tf1 = TransferFunction((1 - s)**2, (s + 1)**2, s) + >>> tf1.is_stable() + True + >>> tf2 = TransferFunction((1 - p)**2, (s**2 + 1)**2, s) + >>> tf2.is_stable() + False + >>> tf3 = TransferFunction(4, q*s - r, s) + >>> tf3.is_stable() + False + >>> tf4 = TransferFunction(p + 1, a*p - s**2, p) + >>> tf4.is_stable() is None # Not enough info about the symbols to determine stability + True + + """ + return fuzzy_and(pole.as_real_imag()[0].is_negative for pole in self.poles()) + + def __add__(self, other): + if hasattr(other, "is_StateSpace_object") and other.is_StateSpace_object: + return Parallel(self, other) + elif isinstance(other, (TransferFunction, Series, Feedback)): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + return Parallel(self, other) + elif isinstance(other, Parallel): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + arg_list = list(other.args) + return Parallel(self, *arg_list) + else: + raise ValueError("TransferFunction cannot be added with {}.". + format(type(other))) + + def __radd__(self, other): + return self + other + + def __sub__(self, other): + if hasattr(other, "is_StateSpace_object") and other.is_StateSpace_object: + return Parallel(self, -other) + elif isinstance(other, (TransferFunction, Series)): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + return Parallel(self, -other) + elif isinstance(other, Parallel): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + arg_list = [-i for i in list(other.args)] + return Parallel(self, *arg_list) + else: + raise ValueError("{} cannot be subtracted from a TransferFunction." + .format(type(other))) + + def __rsub__(self, other): + return -self + other + + def __mul__(self, other): + if hasattr(other, "is_StateSpace_object") and other.is_StateSpace_object: + return Series(self, other) + elif isinstance(other, (TransferFunction, Parallel, Feedback)): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + return Series(self, other) + elif isinstance(other, Series): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + arg_list = list(other.args) + return Series(self, *arg_list) + else: + raise ValueError("TransferFunction cannot be multiplied with {}." + .format(type(other))) + + __rmul__ = __mul__ + + def __truediv__(self, other): + if isinstance(other, TransferFunction): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + return Series(self, TransferFunction(other.den, other.num, self.var)) + elif (isinstance(other, Parallel) and len(other.args + ) == 2 and isinstance(other.args[0], TransferFunction) + and isinstance(other.args[1], (Series, TransferFunction))): + + if not self.var == other.var: + raise ValueError(filldedent(""" + Both TransferFunction and Parallel should use the + same complex variable of the Laplace transform.""")) + if other.args[1] == self: + # plant and controller with unit feedback. + return Feedback(self, other.args[0]) + other_arg_list = list(other.args[1].args) if isinstance( + other.args[1], Series) else other.args[1] + if other_arg_list == other.args[1]: + return Feedback(self, other_arg_list) + elif self in other_arg_list: + other_arg_list.remove(self) + else: + return Feedback(self, Series(*other_arg_list)) + + if len(other_arg_list) == 1: + return Feedback(self, *other_arg_list) + else: + return Feedback(self, Series(*other_arg_list)) + else: + raise ValueError("TransferFunction cannot be divided by {}.". + format(type(other))) + + __rtruediv__ = __truediv__ + + def __pow__(self, p): + p = sympify(p) + if not p.is_Integer: + raise ValueError("Exponent must be an integer.") + if p is S.Zero: + return TransferFunction(1, 1, self.var) + elif p > 0: + num_, den_ = self.num**p, self.den**p + else: + p = abs(p) + num_, den_ = self.den**p, self.num**p + + return TransferFunction(num_, den_, self.var) + + def __neg__(self): + return TransferFunction(-self.num, self.den, self.var) + + @property + def is_proper(self): + """ + Returns True if degree of the numerator polynomial is less than + or equal to degree of the denominator polynomial, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + >>> tf1.is_proper + False + >>> tf2 = TransferFunction(p**2 - 4*p, p**3 + 3*p + 2, p) + >>> tf2.is_proper + True + + """ + return degree(self.num, self.var) <= degree(self.den, self.var) + + @property + def is_strictly_proper(self): + """ + Returns True if degree of the numerator polynomial is strictly less + than degree of the denominator polynomial, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf1.is_strictly_proper + False + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tf2.is_strictly_proper + True + + """ + return degree(self.num, self.var) < degree(self.den, self.var) + + @property + def is_biproper(self): + """ + Returns True if degree of the numerator polynomial is equal to + degree of the denominator polynomial, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf1.is_biproper + True + >>> tf2 = TransferFunction(p**2, p + a, p) + >>> tf2.is_biproper + False + + """ + return degree(self.num, self.var) == degree(self.den, self.var) + + def to_expr(self): + """ + Converts a ``TransferFunction`` object to SymPy Expr. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy import Expr + >>> tf1 = TransferFunction(s, a*s**2 + 1, s) + >>> tf1.to_expr() + s/(a*s**2 + 1) + >>> isinstance(_, Expr) + True + >>> tf2 = TransferFunction(1, (p + 3*b)*(b - p), p) + >>> tf2.to_expr() + 1/((b - p)*(3*b + p)) + >>> tf3 = TransferFunction((s - 2)*(s - 3), (s - 1)*(s - 2)*(s - 3), s) + >>> tf3.to_expr() + ((s - 3)*(s - 2))/(((s - 3)*(s - 2)*(s - 1))) + + """ + + if self.num != 1: + return Mul(self.num, Pow(self.den, -1, evaluate=False), evaluate=False) + else: + return Pow(self.den, -1, evaluate=False) + + +class PIDController(TransferFunction): + r""" + A class for representing PID (Proportional-Integral-Derivative) + controllers in control systems. The PIDController class is a subclass + of TransferFunction, representing the controller's transfer function + in the Laplace domain. The arguments are ``kp``, ``ki``, ``kd``, + ``tf``, and ``var``, where ``kp``, ``ki``, and ``kd`` are the + proportional, integral, and derivative gains respectively.``tf`` + is the derivative filter time constant, which can be used to + filter out the noise and ``var`` is the complex variable used in + the transfer function. + + Parameters + ========== + + kp : Expr, Number + Proportional gain. Defaults to ``Symbol('kp')`` if not specified. + ki : Expr, Number + Integral gain. Defaults to ``Symbol('ki')`` if not specified. + kd : Expr, Number + Derivative gain. Defaults to ``Symbol('kd')`` if not specified. + tf : Expr, Number + Derivative filter time constant. Defaults to ``0`` if not specified. + var : Symbol + The complex frequency variable. Defaults to ``s`` if not specified. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.control.lti import PIDController + >>> kp, ki, kd = symbols('kp ki kd') + >>> p1 = PIDController(kp, ki, kd) + >>> print(p1) + PIDController(kp, ki, kd, 0, s) + >>> p1.doit() + TransferFunction(kd*s**2 + ki + kp*s, s, s) + >>> p1.kp + kp + >>> p1.ki + ki + >>> p1.kd + kd + >>> p1.tf + 0 + >>> p1.var + s + >>> p1.to_expr() + (kd*s**2 + ki + kp*s)/s + + See Also + ======== + + TransferFunction + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/PID_controller + .. [2] https://in.mathworks.com/help/control/ug/proportional-integral-derivative-pid-controllers.html + + """ + def __new__(cls, kp=Symbol('kp'), ki=Symbol('ki'), kd=Symbol('kd'), tf=0, var=Symbol('s')): + kp, ki, kd, tf = _sympify(kp), _sympify(ki), _sympify(kd), _sympify(tf) + num = kp*tf*var**2 + kp*var + ki*tf*var + ki + kd*var**2 + den = tf*var**2 + var + obj = TransferFunction.__new__(cls, num, den, var) + obj._kp, obj._ki, obj._kd, obj._tf = kp, ki, kd, tf + return obj + + def __repr__(self): + return f"PIDController({self.kp}, {self.ki}, {self.kd}, {self.tf}, {self.var})" + + __str__ = __repr__ + + @property + def kp(self): + """ + Returns the Proportional gain (kp) of the PIDController. + """ + return self._kp + + @property + def ki(self): + """ + Returns the Integral gain (ki) of the PIDController. + """ + return self._ki + + @property + def kd(self): + """ + Returns the Derivative gain (kd) of the PIDController. + """ + return self._kd + + @property + def tf(self): + """ + Returns the Derivative filter time constant (tf) of the PIDController. + """ + return self._tf + + def doit(self): + """ + Convert the PIDController into TransferFunction. + """ + return TransferFunction(self.num, self.den, self.var) + + +def _flatten_args(args, _cls): + temp_args = [] + for arg in args: + if isinstance(arg, _cls): + temp_args.extend(arg.args) + else: + temp_args.append(arg) + return tuple(temp_args) + + +def _dummify_args(_arg, var): + dummy_dict = {} + dummy_arg_list = [] + + for arg in _arg: + _s = Dummy() + dummy_dict[_s] = var + dummy_arg = arg.subs({var: _s}) + dummy_arg_list.append(dummy_arg) + + return dummy_arg_list, dummy_dict + + +class Series(SISOLinearTimeInvariant): + r""" + A class for representing a series configuration of SISO systems. + + Parameters + ========== + + args : SISOLinearTimeInvariant + SISO systems in a series configuration. + evaluate : Boolean, Keyword + When passed ``True``, returns the equivalent + ``Series(*args).doit()``. Set to ``False`` by default. + + Raises + ====== + + ValueError + When no argument is passed. + + ``var`` attribute is not same for every system. + TypeError + Any of the passed ``*args`` has unsupported type + + A combination of SISO and MIMO systems is + passed. There should be homogeneity in the + type of systems passed, SISO in this case. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy import Matrix + >>> from sympy.physics.control.lti import TransferFunction, Series, Parallel, StateSpace + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tf3 = TransferFunction(p**2, p + s, s) + >>> S1 = Series(tf1, tf2) + >>> S1 + Series(TransferFunction(a*p**2 + b*s, -p + s, s), TransferFunction(s**3 - 2, s**4 + 5*s + 6, s)) + >>> S1.var + s + >>> S2 = Series(tf2, Parallel(tf3, -tf1)) + >>> S2 + Series(TransferFunction(s**3 - 2, s**4 + 5*s + 6, s), Parallel(TransferFunction(p**2, p + s, s), TransferFunction(-a*p**2 - b*s, -p + s, s))) + >>> S2.var + s + >>> S3 = Series(Parallel(tf1, tf2), Parallel(tf2, tf3)) + >>> S3 + Series(Parallel(TransferFunction(a*p**2 + b*s, -p + s, s), TransferFunction(s**3 - 2, s**4 + 5*s + 6, s)), Parallel(TransferFunction(s**3 - 2, s**4 + 5*s + 6, s), TransferFunction(p**2, p + s, s))) + >>> S3.var + s + + You can get the resultant transfer function by using ``.doit()`` method: + + >>> S3 = Series(tf1, tf2, -tf3) + >>> S3.doit() + TransferFunction(-p**2*(s**3 - 2)*(a*p**2 + b*s), (-p + s)*(p + s)*(s**4 + 5*s + 6), s) + >>> S4 = Series(tf2, Parallel(tf1, -tf3)) + >>> S4.doit() + TransferFunction((s**3 - 2)*(-p**2*(-p + s) + (p + s)*(a*p**2 + b*s)), (-p + s)*(p + s)*(s**4 + 5*s + 6), s) + + You can also connect StateSpace which results in SISO + + >>> A1 = Matrix([[-1]]) + >>> B1 = Matrix([[1]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([1]) + >>> A2 = Matrix([[0]]) + >>> B2 = Matrix([[1]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[0]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> S5 = Series(ss1, ss2) + >>> S5 + Series(StateSpace(Matrix([[-1]]), Matrix([[1]]), Matrix([[-1]]), Matrix([[1]])), StateSpace(Matrix([[0]]), Matrix([[1]]), Matrix([[1]]), Matrix([[0]]))) + >>> S5.doit() + StateSpace(Matrix([ + [-1, 0], + [-1, 0]]), Matrix([ + [1], + [1]]), Matrix([[0, 1]]), Matrix([[0]])) + + Notes + ===== + + All the transfer functions should use the same complex variable + ``var`` of the Laplace transform. + + See Also + ======== + + MIMOSeries, Parallel, TransferFunction, Feedback + + """ + def __new__(cls, *args, evaluate=False): + + args = _flatten_args(args, Series) + # For StateSpace series connection + if args and any(isinstance(arg, StateSpace) or (hasattr(arg, 'is_StateSpace_object') + and arg.is_StateSpace_object)for arg in args): + # Check for SISO + if (args[0].num_inputs == 1) and (args[-1].num_outputs == 1): + # Check the interconnection + for i in range(1, len(args)): + if args[i].num_inputs != args[i-1].num_outputs: + raise ValueError(filldedent("""Systems with incompatible inputs and outputs + cannot be connected in Series.""")) + cls._is_series_StateSpace = True + else: + raise ValueError("To use Series connection for MIMO systems use MIMOSeries instead.") + else: + cls._is_series_StateSpace = False + cls._check_args(args) + + obj = super().__new__(cls, *args) + + return obj.doit() if evaluate else obj + + def __repr__(self): + systems_repr = ', '.join(repr(system) for system in self.args) + return f"Series({systems_repr})" + + __str__ = __repr__ + + @property + def var(self): + """ + Returns the complex variable used by all the transfer functions. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, Series, Parallel + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> Series(G1, G2).var + p + >>> Series(-G3, Parallel(G1, G2)).var + p + + """ + return self.args[0].var + + def doit(self, **hints): + """ + Returns the resultant transfer function or StateSpace obtained after evaluating + the series interconnection. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Series + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> Series(tf2, tf1).doit() + TransferFunction((s**3 - 2)*(a*p**2 + b*s), (-p + s)*(s**4 + 5*s + 6), s) + >>> Series(-tf1, -tf2).doit() + TransferFunction((2 - s**3)*(-a*p**2 - b*s), (-p + s)*(s**4 + 5*s + 6), s) + + Notes + ===== + + If a series connection contains only TransferFunction components, the equivalent system returned + will be a TransferFunction. However, if a StateSpace object is used in any of the arguments, + the output will be a StateSpace object. + + """ + # Check if the system is a StateSpace + if self._is_series_StateSpace: + # Return the equivalent StateSpace model + res = self.args[0] + if not isinstance(res, StateSpace): + res = res.doit().rewrite(StateSpace) + for arg in self.args[1:]: + if not isinstance(arg, StateSpace): + arg = arg.doit().rewrite(StateSpace) + else: + arg = arg.doit() + arg = arg.doit() + res = arg * res + return res + + _num_arg = (arg.doit().num for arg in self.args) + _den_arg = (arg.doit().den for arg in self.args) + res_num = Mul(*_num_arg, evaluate=True) + res_den = Mul(*_den_arg, evaluate=True) + return TransferFunction(res_num, res_den, self.var) + + def _eval_rewrite_as_TransferFunction(self, *args, **kwargs): + if self._is_series_StateSpace: + return self.doit().rewrite(TransferFunction)[0][0] + return self.doit() + + @_check_other_SISO + def __add__(self, other): + + if isinstance(other, Parallel): + arg_list = list(other.args) + return Parallel(self, *arg_list) + + return Parallel(self, other) + + __radd__ = __add__ + + @_check_other_SISO + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + @_check_other_SISO + def __mul__(self, other): + + arg_list = list(self.args) + return Series(*arg_list, other) + + def __truediv__(self, other): + if isinstance(other, TransferFunction): + return Series(*self.args, TransferFunction(other.den, other.num, other.var)) + elif isinstance(other, Series): + tf_self = self.rewrite(TransferFunction) + tf_other = other.rewrite(TransferFunction) + return tf_self / tf_other + elif (isinstance(other, Parallel) and len(other.args) == 2 + and isinstance(other.args[0], TransferFunction) and isinstance(other.args[1], Series)): + + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + self_arg_list = set(self.args) + other_arg_list = set(other.args[1].args) + res = list(self_arg_list ^ other_arg_list) + if len(res) == 0: + return Feedback(self, other.args[0]) + elif len(res) == 1: + return Feedback(self, *res) + else: + return Feedback(self, Series(*res)) + else: + raise ValueError("This transfer function expression is invalid.") + + def __neg__(self): + return Series(TransferFunction(-1, 1, self.var), self) + + def to_expr(self): + """Returns the equivalent ``Expr`` object.""" + return Mul(*(arg.to_expr() for arg in self.args), evaluate=False) + + @property + def is_proper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is less than or equal to degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Series + >>> tf1 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + >>> tf2 = TransferFunction(p**2 - 4*p, p**3 + 3*s + 2, s) + >>> tf3 = TransferFunction(s, s**2 + s + 1, s) + >>> S1 = Series(-tf2, tf1) + >>> S1.is_proper + False + >>> S2 = Series(tf1, tf2, tf3) + >>> S2.is_proper + True + + """ + return self.doit().is_proper + + @property + def is_strictly_proper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is strictly less than degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Series + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**2 + 5*s + 6, s) + >>> tf3 = TransferFunction(1, s**2 + s + 1, s) + >>> S1 = Series(tf1, tf2) + >>> S1.is_strictly_proper + False + >>> S2 = Series(tf1, tf2, tf3) + >>> S2.is_strictly_proper + True + + """ + return self.doit().is_strictly_proper + + @property + def is_biproper(self): + r""" + Returns True if degree of the numerator polynomial of the resultant transfer + function is equal to degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Series + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(p, s**2, s) + >>> tf3 = TransferFunction(s**2, 1, s) + >>> S1 = Series(tf1, -tf2) + >>> S1.is_biproper + False + >>> S2 = Series(tf2, tf3) + >>> S2.is_biproper + True + + """ + return self.doit().is_biproper + + @property + def is_StateSpace_object(self): + return self._is_series_StateSpace + +def _mat_mul_compatible(*args): + """To check whether shapes are compatible for matrix mul.""" + return all(args[i].num_outputs == args[i+1].num_inputs for i in range(len(args)-1)) + + +class MIMOSeries(MIMOLinearTimeInvariant): + r""" + A class for representing a series configuration of MIMO systems. + + Parameters + ========== + + args : MIMOLinearTimeInvariant + MIMO systems in a series configuration. + evaluate : Boolean, Keyword + When passed ``True``, returns the equivalent + ``MIMOSeries(*args).doit()``. Set to ``False`` by default. + + Raises + ====== + + ValueError + When no argument is passed. + + ``var`` attribute is not same for every system. + + ``num_outputs`` of the MIMO system is not equal to the + ``num_inputs`` of its adjacent MIMO system. (Matrix + multiplication constraint, basically) + TypeError + Any of the passed ``*args`` has unsupported type + + A combination of SISO and MIMO systems is + passed. There should be homogeneity in the + type of systems passed, MIMO in this case. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import MIMOSeries, TransferFunctionMatrix, StateSpace + >>> from sympy import Matrix, pprint + >>> mat_a = Matrix([[5*s], [5]]) # 2 Outputs 1 Input + >>> mat_b = Matrix([[5, 1/(6*s**2)]]) # 1 Output 2 Inputs + >>> mat_c = Matrix([[1, s], [5/s, 1]]) # 2 Outputs 2 Inputs + >>> tfm_a = TransferFunctionMatrix.from_Matrix(mat_a, s) + >>> tfm_b = TransferFunctionMatrix.from_Matrix(mat_b, s) + >>> tfm_c = TransferFunctionMatrix.from_Matrix(mat_c, s) + >>> MIMOSeries(tfm_c, tfm_b, tfm_a) + MIMOSeries(TransferFunctionMatrix(((TransferFunction(1, 1, s), TransferFunction(s, 1, s)), (TransferFunction(5, s, s), TransferFunction(1, 1, s)))), TransferFunctionMatrix(((TransferFunction(5, 1, s), TransferFunction(1, 6*s**2, s)),)), TransferFunctionMatrix(((TransferFunction(5*s, 1, s),), (TransferFunction(5, 1, s),)))) + >>> pprint(_, use_unicode=False) # For Better Visualization + [5*s] [1 s] + [---] [5 1 ] [- -] + [ 1 ] [- ----] [1 1] + [ ] *[1 2] *[ ] + [ 5 ] [ 6*s ]{t} [5 1] + [ - ] [- -] + [ 1 ]{t} [s 1]{t} + >>> MIMOSeries(tfm_c, tfm_b, tfm_a).doit() + TransferFunctionMatrix(((TransferFunction(150*s**4 + 25*s, 6*s**3, s), TransferFunction(150*s**4 + 5*s, 6*s**2, s)), (TransferFunction(150*s**3 + 25, 6*s**3, s), TransferFunction(150*s**3 + 5, 6*s**2, s)))) + >>> pprint(_, use_unicode=False) # (2 Inputs -A-> 2 Outputs) -> (2 Inputs -B-> 1 Output) -> (1 Input -C-> 2 Outputs) is equivalent to (2 Inputs -Series Equivalent-> 2 Outputs). + [ 4 4 ] + [150*s + 25*s 150*s + 5*s] + [------------- ------------] + [ 3 2 ] + [ 6*s 6*s ] + [ ] + [ 3 3 ] + [ 150*s + 25 150*s + 5 ] + [ ----------- ---------- ] + [ 3 2 ] + [ 6*s 6*s ]{t} + >>> a1 = Matrix([[4, 1], [2, -3]]) + >>> b1 = Matrix([[5, 2], [-3, -3]]) + >>> c1 = Matrix([[2, -4], [0, 1]]) + >>> d1 = Matrix([[3, 2], [1, -1]]) + >>> a2 = Matrix([[-3, 4, 2], [-1, -3, 0], [2, 5, 3]]) + >>> b2 = Matrix([[1, 4], [-3, -3], [-2, 1]]) + >>> c2 = Matrix([[4, 2, -3], [1, 4, 3]]) + >>> d2 = Matrix([[-2, 4], [0, 1]]) + >>> ss1 = StateSpace(a1, b1, c1, d1) #2 inputs, 2 outputs + >>> ss2 = StateSpace(a2, b2, c2, d2) #2 inputs, 2 outputs + >>> S1 = MIMOSeries(ss1, ss2) #(2 inputs, 2 outputs) -> (2 inputs, 2 outputs) + >>> S1 + MIMOSeries(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), StateSpace(Matrix([ + [-3, 4, 2], + [-1, -3, 0], + [ 2, 5, 3]]), Matrix([ + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [4, 2, -3], + [1, 4, 3]]), Matrix([ + [-2, 4], + [ 0, 1]]))) + >>> S1.doit() + StateSpace(Matrix([ + [ 4, 1, 0, 0, 0], + [ 2, -3, 0, 0, 0], + [ 2, 0, -3, 4, 2], + [-6, 9, -1, -3, 0], + [-4, 9, 2, 5, 3]]), Matrix([ + [ 5, 2], + [ -3, -3], + [ 7, -2], + [-12, -3], + [ -5, -5]]), Matrix([ + [-4, 12, 4, 2, -3], + [ 0, 1, 1, 4, 3]]), Matrix([ + [-2, -8], + [ 1, -1]])) + + Notes + ===== + + All the transfer function matrices should use the same complex variable ``var`` of the Laplace transform. + + ``MIMOSeries(A, B)`` is not equivalent to ``A*B``. It is always in the reverse order, that is ``B*A``. + + See Also + ======== + + Series, MIMOParallel + + """ + def __new__(cls, *args, evaluate=False): + + if args and any(isinstance(arg, StateSpace) or (hasattr(arg, 'is_StateSpace_object') + and arg.is_StateSpace_object) for arg in args): + # Check compatibility + for i in range(1, len(args)): + if args[i].num_inputs != args[i - 1].num_outputs: + raise ValueError(filldedent("""Systems with incompatible inputs and outputs + cannot be connected in MIMOSeries.""")) + obj = super().__new__(cls, *args) + cls._is_series_StateSpace = True + else: + cls._check_args(args) + cls._is_series_StateSpace = False + + if _mat_mul_compatible(*args): + obj = super().__new__(cls, *args) + + else: + raise ValueError(filldedent(""" + Number of input signals do not match the number + of output signals of adjacent systems for some args.""")) + + return obj.doit() if evaluate else obj + + @property + def var(self): + """ + Returns the complex variable used by all the transfer functions. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, MIMOSeries, TransferFunctionMatrix + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> tfm_1 = TransferFunctionMatrix([[G1, G2, G3]]) + >>> tfm_2 = TransferFunctionMatrix([[G1], [G2], [G3]]) + >>> MIMOSeries(tfm_2, tfm_1).var + p + + """ + return self.args[0].var + + @property + def num_inputs(self): + """Returns the number of input signals of the series system.""" + return self.args[0].num_inputs + + @property + def num_outputs(self): + """Returns the number of output signals of the series system.""" + return self.args[-1].num_outputs + + @property + def shape(self): + """Returns the shape of the equivalent MIMO system.""" + return self.num_outputs, self.num_inputs + + @property + def is_StateSpace_object(self): + return self._is_series_StateSpace + + def doit(self, cancel=False, **kwargs): + """ + Returns the resultant obtained after evaluating the MIMO systems arranged + in a series configuration. For TransferFunction systems it returns a TransferFunctionMatrix + and for StateSpace systems it returns the resultant StateSpace system. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, MIMOSeries, TransferFunctionMatrix + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tfm1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf2]]) + >>> tfm2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf1]]) + >>> MIMOSeries(tfm2, tfm1).doit() + TransferFunctionMatrix(((TransferFunction(2*(-p + s)*(s**3 - 2)*(a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)**2*(s**4 + 5*s + 6)**2, s), TransferFunction((-p + s)**2*(s**3 - 2)*(a*p**2 + b*s) + (-p + s)*(a*p**2 + b*s)**2*(s**4 + 5*s + 6), (-p + s)**3*(s**4 + 5*s + 6), s)), (TransferFunction((-p + s)*(s**3 - 2)**2*(s**4 + 5*s + 6) + (s**3 - 2)*(a*p**2 + b*s)*(s**4 + 5*s + 6)**2, (-p + s)*(s**4 + 5*s + 6)**3, s), TransferFunction(2*(s**3 - 2)*(a*p**2 + b*s), (-p + s)*(s**4 + 5*s + 6), s)))) + + """ + if self._is_series_StateSpace: + # Return the equivalent StateSpace model + res = self.args[0] + if not isinstance(res, StateSpace): + res = res.doit().rewrite(StateSpace) + for arg in self.args[1:]: + if not isinstance(arg, StateSpace): + arg = arg.doit().rewrite(StateSpace) + else: + arg = arg.doit() + res = arg * res + return res + + _arg = (arg.doit()._expr_mat for arg in reversed(self.args)) + + if cancel: + res = MatMul(*_arg, evaluate=True) + return TransferFunctionMatrix.from_Matrix(res, self.var) + + _dummy_args, _dummy_dict = _dummify_args(_arg, self.var) + res = MatMul(*_dummy_args, evaluate=True) + temp_tfm = TransferFunctionMatrix.from_Matrix(res, self.var) + return temp_tfm.subs(_dummy_dict) + + def _eval_rewrite_as_TransferFunctionMatrix(self, *args, **kwargs): + if self._is_series_StateSpace: + return self.doit().rewrite(TransferFunction) + return self.doit() + + @_check_other_MIMO + def __add__(self, other): + + if isinstance(other, MIMOParallel): + arg_list = list(other.args) + return MIMOParallel(self, *arg_list) + + return MIMOParallel(self, other) + + __radd__ = __add__ + + @_check_other_MIMO + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + @_check_other_MIMO + def __mul__(self, other): + + if isinstance(other, MIMOSeries): + self_arg_list = list(self.args) + other_arg_list = list(other.args) + return MIMOSeries(*other_arg_list, *self_arg_list) # A*B = MIMOSeries(B, A) + + arg_list = list(self.args) + return MIMOSeries(other, *arg_list) + + def __neg__(self): + arg_list = list(self.args) + arg_list[0] = -arg_list[0] + return MIMOSeries(*arg_list) + + +class Parallel(SISOLinearTimeInvariant): + r""" + A class for representing a parallel configuration of SISO systems. + + Parameters + ========== + + args : SISOLinearTimeInvariant + SISO systems in a parallel arrangement. + evaluate : Boolean, Keyword + When passed ``True``, returns the equivalent + ``Parallel(*args).doit()``. Set to ``False`` by default. + + Raises + ====== + + ValueError + When no argument is passed. + + ``var`` attribute is not same for every system. + TypeError + Any of the passed ``*args`` has unsupported type + + A combination of SISO and MIMO systems is + passed. There should be homogeneity in the + type of systems passed. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel, Series, StateSpace + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tf3 = TransferFunction(p**2, p + s, s) + >>> P1 = Parallel(tf1, tf2) + >>> P1 + Parallel(TransferFunction(a*p**2 + b*s, -p + s, s), TransferFunction(s**3 - 2, s**4 + 5*s + 6, s)) + >>> P1.var + s + >>> P2 = Parallel(tf2, Series(tf3, -tf1)) + >>> P2 + Parallel(TransferFunction(s**3 - 2, s**4 + 5*s + 6, s), Series(TransferFunction(p**2, p + s, s), TransferFunction(-a*p**2 - b*s, -p + s, s))) + >>> P2.var + s + >>> P3 = Parallel(Series(tf1, tf2), Series(tf2, tf3)) + >>> P3 + Parallel(Series(TransferFunction(a*p**2 + b*s, -p + s, s), TransferFunction(s**3 - 2, s**4 + 5*s + 6, s)), Series(TransferFunction(s**3 - 2, s**4 + 5*s + 6, s), TransferFunction(p**2, p + s, s))) + >>> P3.var + s + + You can get the resultant transfer function by using ``.doit()`` method: + + >>> Parallel(tf1, tf2, -tf3).doit() + TransferFunction(-p**2*(-p + s)*(s**4 + 5*s + 6) + (-p + s)*(p + s)*(s**3 - 2) + (p + s)*(a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(p + s)*(s**4 + 5*s + 6), s) + >>> Parallel(tf2, Series(tf1, -tf3)).doit() + TransferFunction(-p**2*(a*p**2 + b*s)*(s**4 + 5*s + 6) + (-p + s)*(p + s)*(s**3 - 2), (-p + s)*(p + s)*(s**4 + 5*s + 6), s) + + Parallel can be used to connect SISO ``StateSpace`` systems together. + + >>> A1 = Matrix([[-1]]) + >>> B1 = Matrix([[1]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([1]) + >>> A2 = Matrix([[0]]) + >>> B2 = Matrix([[1]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[0]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> P4 = Parallel(ss1, ss2) + >>> P4 + Parallel(StateSpace(Matrix([[-1]]), Matrix([[1]]), Matrix([[-1]]), Matrix([[1]])), StateSpace(Matrix([[0]]), Matrix([[1]]), Matrix([[1]]), Matrix([[0]]))) + + ``doit()`` can be used to find ``StateSpace`` equivalent for the system containing ``StateSpace`` objects. + + >>> P4.doit() + StateSpace(Matrix([ + [-1, 0], + [ 0, 0]]), Matrix([ + [1], + [1]]), Matrix([[-1, 1]]), Matrix([[1]])) + >>> P4.rewrite(TransferFunction) + TransferFunction(s*(s + 1) + 1, s*(s + 1), s) + + Notes + ===== + + All the transfer functions should use the same complex variable + ``var`` of the Laplace transform. + + See Also + ======== + + Series, TransferFunction, Feedback + + """ + + def __new__(cls, *args, evaluate=False): + + args = _flatten_args(args, Parallel) + # For StateSpace parallel connection + if args and any(isinstance(arg, StateSpace) or (hasattr(arg, 'is_StateSpace_object') + and arg.is_StateSpace_object) for arg in args): + # Check for SISO + if all(arg.is_SISO for arg in args): + cls._is_parallel_StateSpace = True + else: + raise ValueError("To use Parallel connection for MIMO systems use MIMOParallel instead.") + else: + cls._is_parallel_StateSpace = False + cls._check_args(args) + obj = super().__new__(cls, *args) + + return obj.doit() if evaluate else obj + + def __repr__(self): + systems_repr = ', '.join(repr(system) for system in self.args) + return f"Parallel({systems_repr})" + + __str__ = __repr__ + + @property + def var(self): + """ + Returns the complex variable used by all the transfer functions. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, Parallel, Series + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> Parallel(G1, G2).var + p + >>> Parallel(-G3, Series(G1, G2)).var + p + + """ + return self.args[0].var + + def doit(self, **hints): + """ + Returns the resultant transfer function or state space obtained by + parallel connection of transfer functions or state space objects. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> Parallel(tf2, tf1).doit() + TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s) + >>> Parallel(-tf1, -tf2).doit() + TransferFunction((2 - s**3)*(-p + s) + (-a*p**2 - b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s) + + """ + if self._is_parallel_StateSpace: + # Return the equivalent StateSpace model + res = self.args[0].doit() + if not isinstance(res, StateSpace): + res = res.rewrite(StateSpace) + for arg in self.args[1:]: + if not isinstance(arg, StateSpace): + arg = arg.doit().rewrite(StateSpace) + res += arg + return res + + _arg = (arg.doit().to_expr() for arg in self.args) + res = Add(*_arg).as_numer_denom() + return TransferFunction(*res, self.var) + + def _eval_rewrite_as_TransferFunction(self, *args, **kwargs): + if self._is_parallel_StateSpace: + return self.doit().rewrite(TransferFunction)[0][0] + return self.doit() + + @_check_other_SISO + def __add__(self, other): + + self_arg_list = list(self.args) + return Parallel(*self_arg_list, other) + + __radd__ = __add__ + + @_check_other_SISO + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + @_check_other_SISO + def __mul__(self, other): + + if isinstance(other, Series): + arg_list = list(other.args) + return Series(self, *arg_list) + + return Series(self, other) + + def __neg__(self): + return Series(TransferFunction(-1, 1, self.var), self) + + def to_expr(self): + """Returns the equivalent ``Expr`` object.""" + return Add(*(arg.to_expr() for arg in self.args), evaluate=False) + + @property + def is_proper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is less than or equal to degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel + >>> tf1 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + >>> tf2 = TransferFunction(p**2 - 4*p, p**3 + 3*s + 2, s) + >>> tf3 = TransferFunction(s, s**2 + s + 1, s) + >>> P1 = Parallel(-tf2, tf1) + >>> P1.is_proper + False + >>> P2 = Parallel(tf2, tf3) + >>> P2.is_proper + True + + """ + return self.doit().is_proper + + @property + def is_strictly_proper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is strictly less than degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tf3 = TransferFunction(s, s**2 + s + 1, s) + >>> P1 = Parallel(tf1, tf2) + >>> P1.is_strictly_proper + False + >>> P2 = Parallel(tf2, tf3) + >>> P2.is_strictly_proper + True + + """ + return self.doit().is_strictly_proper + + @property + def is_biproper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is equal to degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(p**2, p + s, s) + >>> tf3 = TransferFunction(s, s**2 + s + 1, s) + >>> P1 = Parallel(tf1, -tf2) + >>> P1.is_biproper + True + >>> P2 = Parallel(tf2, tf3) + >>> P2.is_biproper + False + + """ + return self.doit().is_biproper + + @property + def is_StateSpace_object(self): + return self._is_parallel_StateSpace + + +class MIMOParallel(MIMOLinearTimeInvariant): + r""" + A class for representing a parallel configuration of MIMO systems. + + Parameters + ========== + + args : MIMOLinearTimeInvariant + MIMO Systems in a parallel arrangement. + evaluate : Boolean, Keyword + When passed ``True``, returns the equivalent + ``MIMOParallel(*args).doit()``. Set to ``False`` by default. + + Raises + ====== + + ValueError + When no argument is passed. + + ``var`` attribute is not same for every system. + + All MIMO systems passed do not have same shape. + TypeError + Any of the passed ``*args`` has unsupported type + + A combination of SISO and MIMO systems is + passed. There should be homogeneity in the + type of systems passed, MIMO in this case. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunctionMatrix, MIMOParallel, StateSpace + >>> from sympy import Matrix, pprint + >>> expr_1 = 1/s + >>> expr_2 = s/(s**2-1) + >>> expr_3 = (2 + s)/(s**2 - 1) + >>> expr_4 = 5 + >>> tfm_a = TransferFunctionMatrix.from_Matrix(Matrix([[expr_1, expr_2], [expr_3, expr_4]]), s) + >>> tfm_b = TransferFunctionMatrix.from_Matrix(Matrix([[expr_2, expr_1], [expr_4, expr_3]]), s) + >>> tfm_c = TransferFunctionMatrix.from_Matrix(Matrix([[expr_3, expr_4], [expr_1, expr_2]]), s) + >>> MIMOParallel(tfm_a, tfm_b, tfm_c) + MIMOParallel(TransferFunctionMatrix(((TransferFunction(1, s, s), TransferFunction(s, s**2 - 1, s)), (TransferFunction(s + 2, s**2 - 1, s), TransferFunction(5, 1, s)))), TransferFunctionMatrix(((TransferFunction(s, s**2 - 1, s), TransferFunction(1, s, s)), (TransferFunction(5, 1, s), TransferFunction(s + 2, s**2 - 1, s)))), TransferFunctionMatrix(((TransferFunction(s + 2, s**2 - 1, s), TransferFunction(5, 1, s)), (TransferFunction(1, s, s), TransferFunction(s, s**2 - 1, s))))) + >>> pprint(_, use_unicode=False) # For Better Visualization + [ 1 s ] [ s 1 ] [s + 2 5 ] + [ - ------] [------ - ] [------ - ] + [ s 2 ] [ 2 s ] [ 2 1 ] + [ s - 1] [s - 1 ] [s - 1 ] + [ ] + [ ] + [ ] + [s + 2 5 ] [ 5 s + 2 ] [ 1 s ] + [------ - ] [ - ------] [ - ------] + [ 2 1 ] [ 1 2 ] [ s 2 ] + [s - 1 ]{t} [ s - 1]{t} [ s - 1]{t} + >>> MIMOParallel(tfm_a, tfm_b, tfm_c).doit() + TransferFunctionMatrix(((TransferFunction(s**2 + s*(2*s + 2) - 1, s*(s**2 - 1), s), TransferFunction(2*s**2 + 5*s*(s**2 - 1) - 1, s*(s**2 - 1), s)), (TransferFunction(s**2 + s*(s + 2) + 5*s*(s**2 - 1) - 1, s*(s**2 - 1), s), TransferFunction(5*s**2 + 2*s - 3, s**2 - 1, s)))) + >>> pprint(_, use_unicode=False) + [ 2 2 / 2 \ ] + [ s + s*(2*s + 2) - 1 2*s + 5*s*\s - 1/ - 1] + [ -------------------- -----------------------] + [ / 2 \ / 2 \ ] + [ s*\s - 1/ s*\s - 1/ ] + [ ] + [ 2 / 2 \ 2 ] + [s + s*(s + 2) + 5*s*\s - 1/ - 1 5*s + 2*s - 3 ] + [--------------------------------- -------------- ] + [ / 2 \ 2 ] + [ s*\s - 1/ s - 1 ]{t} + + ``MIMOParallel`` can also be used to connect MIMO ``StateSpace`` systems. + + >>> A1 = Matrix([[4, 1], [2, -3]]) + >>> B1 = Matrix([[5, 2], [-3, -3]]) + >>> C1 = Matrix([[2, -4], [0, 1]]) + >>> D1 = Matrix([[3, 2], [1, -1]]) + >>> A2 = Matrix([[-3, 4, 2], [-1, -3, 0], [2, 5, 3]]) + >>> B2 = Matrix([[1, 4], [-3, -3], [-2, 1]]) + >>> C2 = Matrix([[4, 2, -3], [1, 4, 3]]) + >>> D2 = Matrix([[-2, 4], [0, 1]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> p1 = MIMOParallel(ss1, ss2) + >>> p1 + MIMOParallel(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), StateSpace(Matrix([ + [-3, 4, 2], + [-1, -3, 0], + [ 2, 5, 3]]), Matrix([ + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [4, 2, -3], + [1, 4, 3]]), Matrix([ + [-2, 4], + [ 0, 1]]))) + + ``doit()`` can be used to find ``StateSpace`` equivalent for the system containing ``StateSpace`` objects. + + >>> p1.doit() + StateSpace(Matrix([ + [4, 1, 0, 0, 0], + [2, -3, 0, 0, 0], + [0, 0, -3, 4, 2], + [0, 0, -1, -3, 0], + [0, 0, 2, 5, 3]]), Matrix([ + [ 5, 2], + [-3, -3], + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [2, -4, 4, 2, -3], + [0, 1, 1, 4, 3]]), Matrix([ + [1, 6], + [1, 0]])) + + Notes + ===== + + All the transfer function matrices should use the same complex variable + ``var`` of the Laplace transform. + + See Also + ======== + + Parallel, MIMOSeries + + """ + + def __new__(cls, *args, evaluate=False): + + args = _flatten_args(args, MIMOParallel) + + # For StateSpace Parallel connection + if args and any(isinstance(arg, StateSpace) or (hasattr(arg, 'is_StateSpace_object') + and arg.is_StateSpace_object) for arg in args): + if any(arg.num_inputs != args[0].num_inputs or arg.num_outputs != args[0].num_outputs + for arg in args[1:]): + raise ShapeError("Systems with incompatible inputs and outputs cannot be " + "connected in MIMOParallel.") + cls._is_parallel_StateSpace = True + else: + cls._check_args(args) + if any(arg.shape != args[0].shape for arg in args): + raise TypeError("Shape of all the args is not equal.") + cls._is_parallel_StateSpace = False + obj = super().__new__(cls, *args) + + return obj.doit() if evaluate else obj + + @property + def var(self): + """ + Returns the complex variable used by all the systems. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOParallel + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> G4 = TransferFunction(p**2, p**2 - 1, p) + >>> tfm_a = TransferFunctionMatrix([[G1, G2], [G3, G4]]) + >>> tfm_b = TransferFunctionMatrix([[G2, G1], [G4, G3]]) + >>> MIMOParallel(tfm_a, tfm_b).var + p + + """ + return self.args[0].var + + @property + def num_inputs(self): + """Returns the number of input signals of the parallel system.""" + return self.args[0].num_inputs + + @property + def num_outputs(self): + """Returns the number of output signals of the parallel system.""" + return self.args[0].num_outputs + + @property + def shape(self): + """Returns the shape of the equivalent MIMO system.""" + return self.num_outputs, self.num_inputs + + @property + def is_StateSpace_object(self): + return self._is_parallel_StateSpace + + def doit(self, **hints): + """ + Returns the resultant transfer function matrix or StateSpace obtained after evaluating + the MIMO systems arranged in a parallel configuration. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, MIMOParallel, TransferFunctionMatrix + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + >>> MIMOParallel(tfm_1, tfm_2).doit() + TransferFunctionMatrix(((TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s), TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s)), (TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s), TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s)))) + + """ + if self._is_parallel_StateSpace: + # Return the equivalent StateSpace model. + res = self.args[0] + if not isinstance(res, StateSpace): + res = res.doit().rewrite(StateSpace) + for arg in self.args[1:]: + if not isinstance(arg, StateSpace): + arg = arg.doit().rewrite(StateSpace) + else: + arg = arg.doit() + res += arg + return res + _arg = (arg.doit()._expr_mat for arg in self.args) + res = MatAdd(*_arg, evaluate=True) + return TransferFunctionMatrix.from_Matrix(res, self.var) + + def _eval_rewrite_as_TransferFunctionMatrix(self, *args, **kwargs): + if self._is_parallel_StateSpace: + return self.doit().rewrite(TransferFunction) + return self.doit() + + @_check_other_MIMO + def __add__(self, other): + + self_arg_list = list(self.args) + return MIMOParallel(*self_arg_list, other) + + __radd__ = __add__ + + @_check_other_MIMO + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + @_check_other_MIMO + def __mul__(self, other): + + if isinstance(other, MIMOSeries): + arg_list = list(other.args) + return MIMOSeries(*arg_list, self) + + return MIMOSeries(other, self) + + def __neg__(self): + arg_list = [-arg for arg in list(self.args)] + return MIMOParallel(*arg_list) + + +class Feedback(SISOLinearTimeInvariant): + r""" + A class for representing closed-loop feedback interconnection between two + SISO input/output systems. + + The first argument, ``sys1``, is the feedforward part of the closed-loop + system or in simple words, the dynamical model representing the process + to be controlled. The second argument, ``sys2``, is the feedback system + and controls the fed back signal to ``sys1``. Both ``sys1`` and ``sys2`` + can either be ``Series``, ``StateSpace`` or ``TransferFunction`` objects. + + Parameters + ========== + + sys1 : Series, StateSpace, TransferFunction + The feedforward path system. + sys2 : Series, StateSpace, TransferFunction, optional + The feedback path system (often a feedback controller). + It is the model sitting on the feedback path. + + If not specified explicitly, the sys2 is + assumed to be unit (1.0) transfer function. + sign : int, optional + The sign of feedback. Can either be ``1`` + (for positive feedback) or ``-1`` (for negative feedback). + Default value is `-1`. + + Raises + ====== + + ValueError + When ``sys1`` and ``sys2`` are not using the + same complex variable of the Laplace transform. + + When a combination of ``sys1`` and ``sys2`` yields + zero denominator. + + TypeError + When either ``sys1`` or ``sys2`` is not a ``Series``, ``StateSpace`` or + ``TransferFunction`` object. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import StateSpace, TransferFunction, Feedback + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1 + Feedback(TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s), TransferFunction(5*s - 10, s + 7, s), -1) + >>> F1.var + s + >>> F1.args + (TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s), TransferFunction(5*s - 10, s + 7, s), -1) + + You can get the feedforward and feedback path systems by using ``.sys1`` and ``.sys2`` respectively. + + >>> F1.sys1 + TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> F1.sys2 + TransferFunction(5*s - 10, s + 7, s) + + You can get the resultant closed loop transfer function obtained by negative feedback + interconnection using ``.doit()`` method. + + >>> F1.doit() + TransferFunction((s + 7)*(s**2 - 4*s + 2)*(3*s**2 + 7*s - 3), ((s + 7)*(s**2 - 4*s + 2) + (5*s - 10)*(3*s**2 + 7*s - 3))*(s**2 - 4*s + 2), s) + >>> G = TransferFunction(2*s**2 + 5*s + 1, s**2 + 2*s + 3, s) + >>> C = TransferFunction(5*s + 10, s + 10, s) + >>> F2 = Feedback(G*C, TransferFunction(1, 1, s)) + >>> F2.doit() + TransferFunction((s + 10)*(5*s + 10)*(s**2 + 2*s + 3)*(2*s**2 + 5*s + 1), (s + 10)*((s + 10)*(s**2 + 2*s + 3) + (5*s + 10)*(2*s**2 + 5*s + 1))*(s**2 + 2*s + 3), s) + + To negate a ``Feedback`` object, the ``-`` operator can be prepended: + + >>> -F1 + Feedback(TransferFunction(-3*s**2 - 7*s + 3, s**2 - 4*s + 2, s), TransferFunction(10 - 5*s, s + 7, s), -1) + >>> -F2 + Feedback(Series(TransferFunction(-1, 1, s), TransferFunction(2*s**2 + 5*s + 1, s**2 + 2*s + 3, s), TransferFunction(5*s + 10, s + 10, s)), TransferFunction(-1, 1, s), -1) + + ``Feedback`` can also be used to connect SISO ``StateSpace`` systems together. + + >>> A1 = Matrix([[-1]]) + >>> B1 = Matrix([[1]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([1]) + >>> A2 = Matrix([[0]]) + >>> B2 = Matrix([[1]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[0]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> F3 = Feedback(ss1, ss2) + >>> F3 + Feedback(StateSpace(Matrix([[-1]]), Matrix([[1]]), Matrix([[-1]]), Matrix([[1]])), StateSpace(Matrix([[0]]), Matrix([[1]]), Matrix([[1]]), Matrix([[0]])), -1) + + ``doit()`` can be used to find ``StateSpace`` equivalent for the system containing ``StateSpace`` objects. + + >>> F3.doit() + StateSpace(Matrix([ + [-1, -1], + [-1, -1]]), Matrix([ + [1], + [1]]), Matrix([[-1, -1]]), Matrix([[1]])) + + We can also find the equivalent ``TransferFunction`` by using ``rewrite(TransferFunction)`` method. + + >>> F3.rewrite(TransferFunction) + TransferFunction(s, s + 2, s) + + See Also + ======== + + MIMOFeedback, Series, Parallel + + """ + def __new__(cls, sys1, sys2=None, sign=-1): + if not sys2: + sys2 = TransferFunction(1, 1, sys1.var) + + if not isinstance(sys1, (TransferFunction, Series, StateSpace, Feedback)): + raise TypeError("Unsupported type for `sys1` in Feedback.") + + if not isinstance(sys2, (TransferFunction, Series, StateSpace, Feedback)): + raise TypeError("Unsupported type for `sys2` in Feedback.") + + if not (sys1.num_inputs == sys1.num_outputs == sys2.num_inputs == + sys2.num_outputs == 1): + raise ValueError("""To use Feedback connection for MIMO systems + use MIMOFeedback instead.""") + + if sign not in [-1, 1]: + raise ValueError(filldedent(""" + Unsupported type for feedback. `sign` arg should + either be 1 (positive feedback loop) or -1 + (negative feedback loop).""")) + + if sys1.is_StateSpace_object or sys2.is_StateSpace_object: + cls.is_StateSpace_object = True + else: + if Mul(sys1.to_expr(), sys2.to_expr()).simplify() == sign: + raise ValueError("The equivalent system will have zero denominator.") + if sys1.var != sys2.var: + raise ValueError(filldedent("""Both `sys1` and `sys2` should be using the + same complex variable.""")) + cls.is_StateSpace_object = False + + return super(SISOLinearTimeInvariant, cls).__new__(cls, sys1, sys2, _sympify(sign)) + + def __repr__(self): + return f"Feedback({self.sys1}, {self.sys2}, {self.sign})" + + __str__ = __repr__ + + @property + def sys1(self): + """ + Returns the feedforward system of the feedback interconnection. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1.sys1 + TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> G = TransferFunction(2*s**2 + 5*s + 1, p**2 + 2*p + 3, p) + >>> C = TransferFunction(5*p + 10, p + 10, p) + >>> P = TransferFunction(1 - s, p + 2, p) + >>> F2 = Feedback(TransferFunction(1, 1, p), G*C*P) + >>> F2.sys1 + TransferFunction(1, 1, p) + + """ + return self.args[0] + + @property + def sys2(self): + """ + Returns the feedback controller of the feedback interconnection. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1.sys2 + TransferFunction(5*s - 10, s + 7, s) + >>> G = TransferFunction(2*s**2 + 5*s + 1, p**2 + 2*p + 3, p) + >>> C = TransferFunction(5*p + 10, p + 10, p) + >>> P = TransferFunction(1 - s, p + 2, p) + >>> F2 = Feedback(TransferFunction(1, 1, p), G*C*P) + >>> F2.sys2 + Series(TransferFunction(2*s**2 + 5*s + 1, p**2 + 2*p + 3, p), TransferFunction(5*p + 10, p + 10, p), TransferFunction(1 - s, p + 2, p)) + + """ + return self.args[1] + + @property + def var(self): + """ + Returns the complex variable of the Laplace transform used by all + the transfer functions involved in the feedback interconnection. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1.var + s + >>> G = TransferFunction(2*s**2 + 5*s + 1, p**2 + 2*p + 3, p) + >>> C = TransferFunction(5*p + 10, p + 10, p) + >>> P = TransferFunction(1 - s, p + 2, p) + >>> F2 = Feedback(TransferFunction(1, 1, p), G*C*P) + >>> F2.var + p + + """ + return self.sys1.var + + @property + def sign(self): + """ + Returns the type of MIMO Feedback model. ``1`` + for Positive and ``-1`` for Negative. + """ + return self.args[2] + + @property + def num(self): + """ + Returns the numerator of the closed loop feedback system. + """ + return self.sys1 + + @property + def den(self): + """ + Returns the denominator of the closed loop feedback model. + """ + unit = TransferFunction(1, 1, self.var) + arg_list = list(self.sys1.args) if isinstance(self.sys1, Series) else [self.sys1] + if self.sign == 1: + return Parallel(unit, -Series(self.sys2, *arg_list)) + return Parallel(unit, Series(self.sys2, *arg_list)) + + @property + def sensitivity(self): + """ + Returns the sensitivity function of the feedback loop. + + Sensitivity of a Feedback system is the ratio + of change in the open loop gain to the change in + the closed loop gain. + + .. note:: + This method would not return the complementary + sensitivity function. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> C = TransferFunction(5*p + 10, p + 10, p) + >>> P = TransferFunction(1 - p, p + 2, p) + >>> F_1 = Feedback(P, C) + >>> F_1.sensitivity + 1/((1 - p)*(5*p + 10)/((p + 2)*(p + 10)) + 1) + + """ + + return 1/(1 - self.sign*self.sys1.to_expr()*self.sys2.to_expr()) + + def doit(self, cancel=False, expand=False, **hints): + """ + Returns the resultant transfer function or state space obtained by + feedback connection of transfer functions or state space objects. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy import Matrix + >>> from sympy.physics.control.lti import TransferFunction, Feedback, StateSpace + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1.doit() + TransferFunction((s + 7)*(s**2 - 4*s + 2)*(3*s**2 + 7*s - 3), ((s + 7)*(s**2 - 4*s + 2) + (5*s - 10)*(3*s**2 + 7*s - 3))*(s**2 - 4*s + 2), s) + >>> G = TransferFunction(2*s**2 + 5*s + 1, s**2 + 2*s + 3, s) + >>> F2 = Feedback(G, TransferFunction(1, 1, s)) + >>> F2.doit() + TransferFunction((s**2 + 2*s + 3)*(2*s**2 + 5*s + 1), (s**2 + 2*s + 3)*(3*s**2 + 7*s + 4), s) + + Use kwarg ``expand=True`` to expand the resultant transfer function. + Use ``cancel=True`` to cancel out the common terms in numerator and + denominator. + + >>> F2.doit(cancel=True, expand=True) + TransferFunction(2*s**2 + 5*s + 1, 3*s**2 + 7*s + 4, s) + >>> F2.doit(expand=True) + TransferFunction(2*s**4 + 9*s**3 + 17*s**2 + 17*s + 3, 3*s**4 + 13*s**3 + 27*s**2 + 29*s + 12, s) + + If the connection contain any ``StateSpace`` object then ``doit()`` + will return the equivalent ``StateSpace`` object. + + >>> A1 = Matrix([[-1.5, -2], [1, 0]]) + >>> B1 = Matrix([0.5, 0]) + >>> C1 = Matrix([[0, 1]]) + >>> A2 = Matrix([[0, 1], [-5, -2]]) + >>> B2 = Matrix([0, 3]) + >>> C2 = Matrix([[0, 1]]) + >>> ss1 = StateSpace(A1, B1, C1) + >>> ss2 = StateSpace(A2, B2, C2) + >>> F3 = Feedback(ss1, ss2) + >>> F3.doit() + StateSpace(Matrix([ + [-1.5, -2, 0, -0.5], + [ 1, 0, 0, 0], + [ 0, 0, 0, 1], + [ 0, 3, -5, -2]]), Matrix([ + [0.5], + [ 0], + [ 0], + [ 0]]), Matrix([[0, 1, 0, 0]]), Matrix([[0]])) + + """ + if self.is_StateSpace_object: + sys1_ss = self.sys1.doit().rewrite(StateSpace) + sys2_ss = self.sys2.doit().rewrite(StateSpace) + A1, B1, C1, D1 = sys1_ss.A, sys1_ss.B, sys1_ss.C, sys1_ss.D + A2, B2, C2, D2 = sys2_ss.A, sys2_ss.B, sys2_ss.C, sys2_ss.D + + # Create identity matrices + I_inputs = eye(self.num_inputs) + I_outputs = eye(self.num_outputs) + + # Compute F and its inverse + F = I_inputs - self.sign * D2 * D1 + E = F.inv() + + # Compute intermediate matrices + E_D2 = E * D2 + E_C2 = E * C2 + T1 = I_outputs + self.sign * D1 * E_D2 + T2 = I_inputs + self.sign * E_D2 * D1 + A = Matrix.vstack( + Matrix.hstack(A1 + self.sign * B1 * E_D2 * C1, self.sign * B1 * E_C2), + Matrix.hstack(B2 * T1 * C1, A2 + self.sign * B2 * D1 * E_C2) + ) + B = Matrix.vstack(B1 * T2, B2 * D1 * T2) + C = Matrix.hstack(T1 * C1, self.sign * D1 * E_C2) + D = D1 * T2 + return StateSpace(A, B, C, D) + + arg_list = list(self.sys1.args) if isinstance(self.sys1, Series) else [self.sys1] + # F_n and F_d are resultant TFs of num and den of Feedback. + F_n, unit = self.sys1.doit(), TransferFunction(1, 1, self.sys1.var) + if self.sign == -1: + F_d = Parallel(unit, Series(self.sys2, *arg_list)).doit() + else: + F_d = Parallel(unit, -Series(self.sys2, *arg_list)).doit() + + _resultant_tf = TransferFunction(F_n.num * F_d.den, F_n.den * F_d.num, F_n.var) + + if cancel: + _resultant_tf = _resultant_tf.simplify() + + if expand: + _resultant_tf = _resultant_tf.expand() + + return _resultant_tf + + def _eval_rewrite_as_TransferFunction(self, num, den, sign, **kwargs): + if self.is_StateSpace_object: + return self.doit().rewrite(TransferFunction)[0][0] + return self.doit() + + def to_expr(self): + """ + Converts a ``Feedback`` object to SymPy Expr. + + Examples + ======== + + >>> from sympy.abc import s, a, b + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> from sympy import Expr + >>> tf1 = TransferFunction(a+s, 1, s) + >>> tf2 = TransferFunction(b+s, 1, s) + >>> fd1 = Feedback(tf1, tf2) + >>> fd1.to_expr() + (a + s)/((a + s)*(b + s) + 1) + >>> isinstance(_, Expr) + True + """ + + return self.doit().to_expr() + + def __neg__(self): + return Feedback(-self.sys1, -self.sys2, self.sign) + + +def _is_invertible(a, b, sign): + """ + Checks whether a given pair of MIMO + systems passed is invertible or not. + """ + _mat = eye(a.num_outputs) - sign*(a.doit()._expr_mat)*(b.doit()._expr_mat) + _det = _mat.det() + + return _det != 0 + + +class MIMOFeedback(MIMOLinearTimeInvariant): + r""" + A class for representing closed-loop feedback interconnection between two + MIMO input/output systems. + + Parameters + ========== + + sys1 : MIMOSeries, TransferFunctionMatrix, StateSpace + The MIMO system placed on the feedforward path. + sys2 : MIMOSeries, TransferFunctionMatrix, StateSpace + The system placed on the feedback path + (often a feedback controller). + sign : int, optional + The sign of feedback. Can either be ``1`` + (for positive feedback) or ``-1`` (for negative feedback). + Default value is `-1`. + + Raises + ====== + + ValueError + When ``sys1`` and ``sys2`` are not using the + same complex variable of the Laplace transform. + + Forward path model should have an equal number of inputs/outputs + to the feedback path outputs/inputs. + + When product of ``sys1`` and ``sys2`` is not a square matrix. + + When the equivalent MIMO system is not invertible. + + TypeError + When either ``sys1`` or ``sys2`` is not a ``MIMOSeries``, + ``TransferFunctionMatrix`` or a ``StateSpace`` object. + + Examples + ======== + + >>> from sympy import Matrix, pprint + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import StateSpace, TransferFunctionMatrix, MIMOFeedback + >>> plant_mat = Matrix([[1, 1/s], [0, 1]]) + >>> controller_mat = Matrix([[10, 0], [0, 10]]) # Constant Gain + >>> plant = TransferFunctionMatrix.from_Matrix(plant_mat, s) + >>> controller = TransferFunctionMatrix.from_Matrix(controller_mat, s) + >>> feedback = MIMOFeedback(plant, controller) # Negative Feedback (default) + >>> pprint(feedback, use_unicode=False) + / [1 1] [10 0 ] \-1 [1 1] + | [- -] [-- - ] | [- -] + | [1 s] [1 1 ] | [1 s] + |I + [ ] *[ ] | * [ ] + | [0 1] [0 10] | [0 1] + | [- -] [- --] | [- -] + \ [1 1]{t} [1 1 ]{t}/ [1 1]{t} + + To get the equivalent system matrix, use either ``doit`` or ``rewrite`` method. + + >>> pprint(feedback.doit(), use_unicode=False) + [1 1 ] + [-- -----] + [11 121*s] + [ ] + [0 1 ] + [- -- ] + [1 11 ]{t} + + To negate the ``MIMOFeedback`` object, use ``-`` operator. + + >>> neg_feedback = -feedback + >>> pprint(neg_feedback.doit(), use_unicode=False) + [-1 -1 ] + [--- -----] + [11 121*s] + [ ] + [ 0 -1 ] + [ - --- ] + [ 1 11 ]{t} + + ``MIMOFeedback`` can also be used to connect MIMO ``StateSpace`` systems. + + >>> A1 = Matrix([[4, 1], [2, -3]]) + >>> B1 = Matrix([[5, 2], [-3, -3]]) + >>> C1 = Matrix([[2, -4], [0, 1]]) + >>> D1 = Matrix([[3, 2], [1, -1]]) + >>> A2 = Matrix([[-3, 4, 2], [-1, -3, 0], [2, 5, 3]]) + >>> B2 = Matrix([[1, 4], [-3, -3], [-2, 1]]) + >>> C2 = Matrix([[4, 2, -3], [1, 4, 3]]) + >>> D2 = Matrix([[-2, 4], [0, 1]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> F1 = MIMOFeedback(ss1, ss2) + >>> F1 + MIMOFeedback(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), StateSpace(Matrix([ + [-3, 4, 2], + [-1, -3, 0], + [ 2, 5, 3]]), Matrix([ + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [4, 2, -3], + [1, 4, 3]]), Matrix([ + [-2, 4], + [ 0, 1]])), -1) + + ``doit()`` can be used to find ``StateSpace`` equivalent for the system containing ``StateSpace`` objects. + + >>> F1.doit() + StateSpace(Matrix([ + [ 3, -3/4, -15/4, -37/2, -15], + [ 7/2, -39/8, 9/8, 39/4, 9], + [ 3, -41/4, -45/4, -51/2, -19], + [-9/2, 129/8, 73/8, 171/4, 36], + [-3/2, 47/8, 31/8, 85/4, 18]]), Matrix([ + [-1/4, 19/4], + [ 3/8, -21/8], + [ 1/4, 29/4], + [ 3/8, -93/8], + [ 5/8, -35/8]]), Matrix([ + [ 1, -15/4, -7/4, -21/2, -9], + [1/2, -13/8, -13/8, -19/4, -3]]), Matrix([ + [-1/4, 11/4], + [ 1/8, 9/8]])) + + See Also + ======== + + Feedback, MIMOSeries, MIMOParallel + + """ + def __new__(cls, sys1, sys2, sign=-1): + if not isinstance(sys1, (TransferFunctionMatrix, MIMOSeries, StateSpace)): + raise TypeError("Unsupported type for `sys1` in MIMO Feedback.") + + if not isinstance(sys2, (TransferFunctionMatrix, MIMOSeries, StateSpace)): + raise TypeError("Unsupported type for `sys2` in MIMO Feedback.") + + if sys1.num_inputs != sys2.num_outputs or \ + sys1.num_outputs != sys2.num_inputs: + raise ValueError(filldedent(""" + Product of `sys1` and `sys2` must + yield a square matrix.""")) + + if sign not in (-1, 1): + raise ValueError(filldedent(""" + Unsupported type for feedback. `sign` arg should + either be 1 (positive feedback loop) or -1 + (negative feedback loop).""")) + + if sys1.is_StateSpace_object or sys2.is_StateSpace_object: + cls.is_StateSpace_object = True + else: + if not _is_invertible(sys1, sys2, sign): + raise ValueError("Non-Invertible system inputted.") + cls.is_StateSpace_object = False + + if not cls.is_StateSpace_object and sys1.var != sys2.var: + raise ValueError(filldedent(""" + Both `sys1` and `sys2` should be using the + same complex variable.""")) + + return super().__new__(cls, sys1, sys2, _sympify(sign)) + + @property + def sys1(self): + r""" + Returns the system placed on the feedforward path of the MIMO feedback interconnection. + + Examples + ======== + + >>> from sympy import pprint + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(s**2 + s + 1, s**2 - s + 1, s) + >>> tf2 = TransferFunction(1, s, s) + >>> tf3 = TransferFunction(1, 1, s) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> sys2 = TransferFunctionMatrix([[tf3, tf3], [tf3, tf2]]) + >>> F_1 = MIMOFeedback(sys1, sys2, 1) + >>> F_1.sys1 + TransferFunctionMatrix(((TransferFunction(s**2 + s + 1, s**2 - s + 1, s), TransferFunction(1, s, s)), (TransferFunction(1, s, s), TransferFunction(s**2 + s + 1, s**2 - s + 1, s)))) + >>> pprint(_, use_unicode=False) + [ 2 ] + [s + s + 1 1 ] + [---------- - ] + [ 2 s ] + [s - s + 1 ] + [ ] + [ 2 ] + [ 1 s + s + 1] + [ - ----------] + [ s 2 ] + [ s - s + 1]{t} + + """ + return self.args[0] + + @property + def sys2(self): + r""" + Returns the feedback controller of the MIMO feedback interconnection. + + Examples + ======== + + >>> from sympy import pprint + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(s**2, s**3 - s + 1, s) + >>> tf2 = TransferFunction(1, s, s) + >>> tf3 = TransferFunction(1, 1, s) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> sys2 = TransferFunctionMatrix([[tf1, tf3], [tf3, tf2]]) + >>> F_1 = MIMOFeedback(sys1, sys2) + >>> F_1.sys2 + TransferFunctionMatrix(((TransferFunction(s**2, s**3 - s + 1, s), TransferFunction(1, 1, s)), (TransferFunction(1, 1, s), TransferFunction(1, s, s)))) + >>> pprint(_, use_unicode=False) + [ 2 ] + [ s 1] + [---------- -] + [ 3 1] + [s - s + 1 ] + [ ] + [ 1 1] + [ - -] + [ 1 s]{t} + + """ + return self.args[1] + + @property + def var(self): + r""" + Returns the complex variable of the Laplace transform used by all + the transfer functions involved in the MIMO feedback loop. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(p, 1 - p, p) + >>> tf2 = TransferFunction(1, p, p) + >>> tf3 = TransferFunction(1, 1, p) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> sys2 = TransferFunctionMatrix([[tf1, tf3], [tf3, tf2]]) + >>> F_1 = MIMOFeedback(sys1, sys2, 1) # Positive feedback + >>> F_1.var + p + + """ + return self.sys1.var + + @property + def sign(self): + r""" + Returns the type of feedback interconnection of two models. ``1`` + for Positive and ``-1`` for Negative. + """ + return self.args[2] + + @property + def sensitivity(self): + r""" + Returns the sensitivity function matrix of the feedback loop. + + Sensitivity of a closed-loop system is the ratio of change + in the open loop gain to the change in the closed loop gain. + + .. note:: + This method would not return the complementary + sensitivity function. + + Examples + ======== + + >>> from sympy import pprint + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(p, 1 - p, p) + >>> tf2 = TransferFunction(1, p, p) + >>> tf3 = TransferFunction(1, 1, p) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> sys2 = TransferFunctionMatrix([[tf1, tf3], [tf3, tf2]]) + >>> F_1 = MIMOFeedback(sys1, sys2, 1) # Positive feedback + >>> F_2 = MIMOFeedback(sys1, sys2) # Negative feedback + >>> pprint(F_1.sensitivity, use_unicode=False) + [ 4 3 2 5 4 2 ] + [- p + 3*p - 4*p + 3*p - 1 p - 2*p + 3*p - 3*p + 1 ] + [---------------------------- -----------------------------] + [ 4 3 2 5 4 3 2 ] + [ p + 3*p - 8*p + 8*p - 3 p + 3*p - 8*p + 8*p - 3*p] + [ ] + [ 4 3 2 3 2 ] + [ p - p - p + p 3*p - 6*p + 4*p - 1 ] + [ -------------------------- -------------------------- ] + [ 4 3 2 4 3 2 ] + [ p + 3*p - 8*p + 8*p - 3 p + 3*p - 8*p + 8*p - 3 ] + >>> pprint(F_2.sensitivity, use_unicode=False) + [ 4 3 2 5 4 2 ] + [p - 3*p + 2*p + p - 1 p - 2*p + 3*p - 3*p + 1] + [------------------------ --------------------------] + [ 4 3 5 4 2 ] + [ p - 3*p + 2*p - 1 p - 3*p + 2*p - p ] + [ ] + [ 4 3 2 4 3 ] + [ p - p - p + p 2*p - 3*p + 2*p - 1 ] + [ ------------------- --------------------- ] + [ 4 3 4 3 ] + [ p - 3*p + 2*p - 1 p - 3*p + 2*p - 1 ] + + """ + _sys1_mat = self.sys1.doit()._expr_mat + _sys2_mat = self.sys2.doit()._expr_mat + + return (eye(self.sys1.num_inputs) - \ + self.sign*_sys1_mat*_sys2_mat).inv() + + @property + def num_inputs(self): + """Returns the number of inputs of the system.""" + return self.sys1.num_inputs + + @property + def num_outputs(self): + """Returns the number of outputs of the system.""" + return self.sys1.num_outputs + + def doit(self, cancel=True, expand=False, **hints): + r""" + Returns the resultant transfer function matrix obtained by the + feedback interconnection. + + Examples + ======== + + >>> from sympy import pprint + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(s, 1 - s, s) + >>> tf2 = TransferFunction(1, s, s) + >>> tf3 = TransferFunction(5, 1, s) + >>> tf4 = TransferFunction(s - 1, s, s) + >>> tf5 = TransferFunction(0, 1, s) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + >>> sys2 = TransferFunctionMatrix([[tf3, tf5], [tf5, tf5]]) + >>> F_1 = MIMOFeedback(sys1, sys2, 1) + >>> pprint(F_1, use_unicode=False) + / [ s 1 ] [5 0] \-1 [ s 1 ] + | [----- - ] [- -] | [----- - ] + | [1 - s s ] [1 1] | [1 - s s ] + |I - [ ] *[ ] | * [ ] + | [ 5 s - 1] [0 0] | [ 5 s - 1] + | [ - -----] [- -] | [ - -----] + \ [ 1 s ]{t} [1 1]{t}/ [ 1 s ]{t} + >>> pprint(F_1.doit(), use_unicode=False) + [ -s s - 1 ] + [------- ----------- ] + [6*s - 1 s*(6*s - 1) ] + [ ] + [5*s - 5 (s - 1)*(6*s + 24)] + [------- ------------------] + [6*s - 1 s*(6*s - 1) ]{t} + + If the user wants the resultant ``TransferFunctionMatrix`` object without + canceling the common factors then the ``cancel`` kwarg should be passed ``False``. + + >>> pprint(F_1.doit(cancel=False), use_unicode=False) + [ s*(s - 1) s - 1 ] + [ ----------------- ----------- ] + [ (1 - s)*(6*s - 1) s*(6*s - 1) ] + [ ] + [s*(25*s - 25) + 5*(1 - s)*(6*s - 1) s*(s - 1)*(6*s - 1) + s*(25*s - 25)] + [----------------------------------- -----------------------------------] + [ (1 - s)*(6*s - 1) 2 ] + [ s *(6*s - 1) ]{t} + + If the user wants the expanded form of the resultant transfer function matrix, + the ``expand`` kwarg should be passed as ``True``. + + >>> pprint(F_1.doit(expand=True), use_unicode=False) + [ -s s - 1 ] + [------- -------- ] + [6*s - 1 2 ] + [ 6*s - s ] + [ ] + [ 2 ] + [5*s - 5 6*s + 18*s - 24] + [------- ----------------] + [6*s - 1 2 ] + [ 6*s - s ]{t} + + """ + if self.is_StateSpace_object: + sys1_ss = self.sys1.doit().rewrite(StateSpace) + sys2_ss = self.sys2.doit().rewrite(StateSpace) + A1, B1, C1, D1 = sys1_ss.A, sys1_ss.B, sys1_ss.C, sys1_ss.D + A2, B2, C2, D2 = sys2_ss.A, sys2_ss.B, sys2_ss.C, sys2_ss.D + + # Create identity matrices + I_inputs = eye(self.num_inputs) + I_outputs = eye(self.num_outputs) + + # Compute F and its inverse + F = I_inputs - self.sign * D2 * D1 + E = F.inv() + + # Compute intermediate matrices + E_D2 = E * D2 + E_C2 = E * C2 + T1 = I_outputs + self.sign * D1 * E_D2 + T2 = I_inputs + self.sign * E_D2 * D1 + A = Matrix.vstack( + Matrix.hstack(A1 + self.sign * B1 * E_D2 * C1, self.sign * B1 * E_C2), + Matrix.hstack(B2 * T1 * C1, A2 + self.sign * B2 * D1 * E_C2) + ) + B = Matrix.vstack(B1 * T2, B2 * D1 * T2) + C = Matrix.hstack(T1 * C1, self.sign * D1 * E_C2) + D = D1 * T2 + return StateSpace(A, B, C, D) + + _mat = self.sensitivity * self.sys1.doit()._expr_mat + + _resultant_tfm = _to_TFM(_mat, self.var) + + if cancel: + _resultant_tfm = _resultant_tfm.simplify() + + if expand: + _resultant_tfm = _resultant_tfm.expand() + + return _resultant_tfm + + def _eval_rewrite_as_TransferFunctionMatrix(self, sys1, sys2, sign, **kwargs): + return self.doit() + + def __neg__(self): + return MIMOFeedback(-self.sys1, -self.sys2, self.sign) + + +def _to_TFM(mat, var): + """Private method to convert ImmutableMatrix to TransferFunctionMatrix efficiently""" + to_tf = lambda expr: TransferFunction.from_rational_expression(expr, var) + arg = [[to_tf(expr) for expr in row] for row in mat.tolist()] + return TransferFunctionMatrix(arg) + + +class TransferFunctionMatrix(MIMOLinearTimeInvariant): + r""" + A class for representing the MIMO (multiple-input and multiple-output) + generalization of the SISO (single-input and single-output) transfer function. + + It is a matrix of transfer functions (``TransferFunction``, SISO-``Series`` or SISO-``Parallel``). + There is only one argument, ``arg`` which is also the compulsory argument. + ``arg`` is expected to be strictly of the type list of lists + which holds the transfer functions or reducible to transfer functions. + + Parameters + ========== + + arg : Nested ``List`` (strictly). + Users are expected to input a nested list of ``TransferFunction``, ``Series`` + and/or ``Parallel`` objects. + + Examples + ======== + + .. note:: + ``pprint()`` can be used for better visualization of ``TransferFunctionMatrix`` objects. + + >>> from sympy.abc import s, p, a + >>> from sympy import pprint + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, Series, Parallel + >>> tf_1 = TransferFunction(s + a, s**2 + s + 1, s) + >>> tf_2 = TransferFunction(p**4 - 3*p + 2, s + p, s) + >>> tf_3 = TransferFunction(3, s + 2, s) + >>> tf_4 = TransferFunction(-a + p, 9*s - 9, s) + >>> tfm_1 = TransferFunctionMatrix([[tf_1], [tf_2], [tf_3]]) + >>> tfm_1 + TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(3, s + 2, s),))) + >>> tfm_1.var + s + >>> tfm_1.num_inputs + 1 + >>> tfm_1.num_outputs + 3 + >>> tfm_1.shape + (3, 1) + >>> tfm_1.args + (((TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(3, s + 2, s),)),) + >>> tfm_2 = TransferFunctionMatrix([[tf_1, -tf_3], [tf_2, -tf_1], [tf_3, -tf_2]]) + >>> tfm_2 + TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s), TransferFunction(-3, s + 2, s)), (TransferFunction(p**4 - 3*p + 2, p + s, s), TransferFunction(-a - s, s**2 + s + 1, s)), (TransferFunction(3, s + 2, s), TransferFunction(-p**4 + 3*p - 2, p + s, s)))) + >>> pprint(tfm_2, use_unicode=False) # pretty-printing for better visualization + [ a + s -3 ] + [ ---------- ----- ] + [ 2 s + 2 ] + [ s + s + 1 ] + [ ] + [ 4 ] + [p - 3*p + 2 -a - s ] + [------------ ---------- ] + [ p + s 2 ] + [ s + s + 1 ] + [ ] + [ 4 ] + [ 3 - p + 3*p - 2] + [ ----- --------------] + [ s + 2 p + s ]{t} + + TransferFunctionMatrix can be transposed, if user wants to switch the input and output transfer functions + + >>> tfm_2.transpose() + TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s), TransferFunction(p**4 - 3*p + 2, p + s, s), TransferFunction(3, s + 2, s)), (TransferFunction(-3, s + 2, s), TransferFunction(-a - s, s**2 + s + 1, s), TransferFunction(-p**4 + 3*p - 2, p + s, s)))) + >>> pprint(_, use_unicode=False) + [ 4 ] + [ a + s p - 3*p + 2 3 ] + [---------- ------------ ----- ] + [ 2 p + s s + 2 ] + [s + s + 1 ] + [ ] + [ 4 ] + [ -3 -a - s - p + 3*p - 2] + [ ----- ---------- --------------] + [ s + 2 2 p + s ] + [ s + s + 1 ]{t} + + >>> tf_5 = TransferFunction(5, s, s) + >>> tf_6 = TransferFunction(5*s, (2 + s**2), s) + >>> tf_7 = TransferFunction(5, (s*(2 + s**2)), s) + >>> tf_8 = TransferFunction(5, 1, s) + >>> tfm_3 = TransferFunctionMatrix([[tf_5, tf_6], [tf_7, tf_8]]) + >>> tfm_3 + TransferFunctionMatrix(((TransferFunction(5, s, s), TransferFunction(5*s, s**2 + 2, s)), (TransferFunction(5, s*(s**2 + 2), s), TransferFunction(5, 1, s)))) + >>> pprint(tfm_3, use_unicode=False) + [ 5 5*s ] + [ - ------] + [ s 2 ] + [ s + 2] + [ ] + [ 5 5 ] + [---------- - ] + [ / 2 \ 1 ] + [s*\s + 2/ ]{t} + >>> tfm_3.var + s + >>> tfm_3.shape + (2, 2) + >>> tfm_3.num_outputs + 2 + >>> tfm_3.num_inputs + 2 + >>> tfm_3.args + (((TransferFunction(5, s, s), TransferFunction(5*s, s**2 + 2, s)), (TransferFunction(5, s*(s**2 + 2), s), TransferFunction(5, 1, s))),) + + To access the ``TransferFunction`` at any index in the ``TransferFunctionMatrix``, use the index notation. + + >>> tfm_3[1, 0] # gives the TransferFunction present at 2nd Row and 1st Col. Similar to that in Matrix classes + TransferFunction(5, s*(s**2 + 2), s) + >>> tfm_3[0, 0] # gives the TransferFunction present at 1st Row and 1st Col. + TransferFunction(5, s, s) + >>> tfm_3[:, 0] # gives the first column + TransferFunctionMatrix(((TransferFunction(5, s, s),), (TransferFunction(5, s*(s**2 + 2), s),))) + >>> pprint(_, use_unicode=False) + [ 5 ] + [ - ] + [ s ] + [ ] + [ 5 ] + [----------] + [ / 2 \] + [s*\s + 2/]{t} + >>> tfm_3[0, :] # gives the first row + TransferFunctionMatrix(((TransferFunction(5, s, s), TransferFunction(5*s, s**2 + 2, s)),)) + >>> pprint(_, use_unicode=False) + [5 5*s ] + [- ------] + [s 2 ] + [ s + 2]{t} + + To negate a transfer function matrix, ``-`` operator can be prepended: + + >>> tfm_4 = TransferFunctionMatrix([[tf_2], [-tf_1], [tf_3]]) + >>> -tfm_4 + TransferFunctionMatrix(((TransferFunction(-p**4 + 3*p - 2, p + s, s),), (TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(-3, s + 2, s),))) + >>> tfm_5 = TransferFunctionMatrix([[tf_1, tf_2], [tf_3, -tf_1]]) + >>> -tfm_5 + TransferFunctionMatrix(((TransferFunction(-a - s, s**2 + s + 1, s), TransferFunction(-p**4 + 3*p - 2, p + s, s)), (TransferFunction(-3, s + 2, s), TransferFunction(a + s, s**2 + s + 1, s)))) + + ``subs()`` returns the ``TransferFunctionMatrix`` object with the value substituted in the expression. This will not + mutate your original ``TransferFunctionMatrix``. + + >>> tfm_2.subs(p, 2) # substituting p everywhere in tfm_2 with 2. + TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s), TransferFunction(-3, s + 2, s)), (TransferFunction(12, s + 2, s), TransferFunction(-a - s, s**2 + s + 1, s)), (TransferFunction(3, s + 2, s), TransferFunction(-12, s + 2, s)))) + >>> pprint(_, use_unicode=False) + [ a + s -3 ] + [---------- ----- ] + [ 2 s + 2 ] + [s + s + 1 ] + [ ] + [ 12 -a - s ] + [ ----- ----------] + [ s + 2 2 ] + [ s + s + 1] + [ ] + [ 3 -12 ] + [ ----- ----- ] + [ s + 2 s + 2 ]{t} + >>> pprint(tfm_2, use_unicode=False) # State of tfm_2 is unchanged after substitution + [ a + s -3 ] + [ ---------- ----- ] + [ 2 s + 2 ] + [ s + s + 1 ] + [ ] + [ 4 ] + [p - 3*p + 2 -a - s ] + [------------ ---------- ] + [ p + s 2 ] + [ s + s + 1 ] + [ ] + [ 4 ] + [ 3 - p + 3*p - 2] + [ ----- --------------] + [ s + 2 p + s ]{t} + + ``subs()`` also supports multiple substitutions. + + >>> tfm_2.subs({p: 2, a: 1}) # substituting p with 2 and a with 1 + TransferFunctionMatrix(((TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(-3, s + 2, s)), (TransferFunction(12, s + 2, s), TransferFunction(-s - 1, s**2 + s + 1, s)), (TransferFunction(3, s + 2, s), TransferFunction(-12, s + 2, s)))) + >>> pprint(_, use_unicode=False) + [ s + 1 -3 ] + [---------- ----- ] + [ 2 s + 2 ] + [s + s + 1 ] + [ ] + [ 12 -s - 1 ] + [ ----- ----------] + [ s + 2 2 ] + [ s + s + 1] + [ ] + [ 3 -12 ] + [ ----- ----- ] + [ s + 2 s + 2 ]{t} + + Users can reduce the ``Series`` and ``Parallel`` elements of the matrix to ``TransferFunction`` by using + ``doit()``. + + >>> tfm_6 = TransferFunctionMatrix([[Series(tf_3, tf_4), Parallel(tf_3, tf_4)]]) + >>> tfm_6 + TransferFunctionMatrix(((Series(TransferFunction(3, s + 2, s), TransferFunction(-a + p, 9*s - 9, s)), Parallel(TransferFunction(3, s + 2, s), TransferFunction(-a + p, 9*s - 9, s))),)) + >>> pprint(tfm_6, use_unicode=False) + [-a + p 3 -a + p 3 ] + [-------*----- ------- + -----] + [9*s - 9 s + 2 9*s - 9 s + 2]{t} + >>> tfm_6.doit() + TransferFunctionMatrix(((TransferFunction(-3*a + 3*p, (s + 2)*(9*s - 9), s), TransferFunction(27*s + (-a + p)*(s + 2) - 27, (s + 2)*(9*s - 9), s)),)) + >>> pprint(_, use_unicode=False) + [ -3*a + 3*p 27*s + (-a + p)*(s + 2) - 27] + [----------------- ----------------------------] + [(s + 2)*(9*s - 9) (s + 2)*(9*s - 9) ]{t} + >>> tf_9 = TransferFunction(1, s, s) + >>> tf_10 = TransferFunction(1, s**2, s) + >>> tfm_7 = TransferFunctionMatrix([[Series(tf_9, tf_10), tf_9], [tf_10, Parallel(tf_9, tf_10)]]) + >>> tfm_7 + TransferFunctionMatrix(((Series(TransferFunction(1, s, s), TransferFunction(1, s**2, s)), TransferFunction(1, s, s)), (TransferFunction(1, s**2, s), Parallel(TransferFunction(1, s, s), TransferFunction(1, s**2, s))))) + >>> pprint(tfm_7, use_unicode=False) + [ 1 1 ] + [---- - ] + [ 2 s ] + [s*s ] + [ ] + [ 1 1 1] + [ -- -- + -] + [ 2 2 s] + [ s s ]{t} + >>> tfm_7.doit() + TransferFunctionMatrix(((TransferFunction(1, s**3, s), TransferFunction(1, s, s)), (TransferFunction(1, s**2, s), TransferFunction(s**2 + s, s**3, s)))) + >>> pprint(_, use_unicode=False) + [1 1 ] + [-- - ] + [ 3 s ] + [s ] + [ ] + [ 2 ] + [1 s + s] + [-- ------] + [ 2 3 ] + [s s ]{t} + + Addition, subtraction, and multiplication of transfer function matrices can form + unevaluated ``Series`` or ``Parallel`` objects. + + - For addition and subtraction: + All the transfer function matrices must have the same shape. + + - For multiplication (C = A * B): + The number of inputs of the first transfer function matrix (A) must be equal to the + number of outputs of the second transfer function matrix (B). + + Also, use pretty-printing (``pprint``) to analyse better. + + >>> tfm_8 = TransferFunctionMatrix([[tf_3], [tf_2], [-tf_1]]) + >>> tfm_9 = TransferFunctionMatrix([[-tf_3]]) + >>> tfm_10 = TransferFunctionMatrix([[tf_1], [tf_2], [tf_4]]) + >>> tfm_11 = TransferFunctionMatrix([[tf_4], [-tf_1]]) + >>> tfm_12 = TransferFunctionMatrix([[tf_4, -tf_1, tf_3], [-tf_2, -tf_4, -tf_3]]) + >>> tfm_8 + tfm_10 + MIMOParallel(TransferFunctionMatrix(((TransferFunction(3, s + 2, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a - s, s**2 + s + 1, s),))), TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a + p, 9*s - 9, s),)))) + >>> pprint(_, use_unicode=False) + [ 3 ] [ a + s ] + [ ----- ] [ ---------- ] + [ s + 2 ] [ 2 ] + [ ] [ s + s + 1 ] + [ 4 ] [ ] + [p - 3*p + 2] [ 4 ] + [------------] + [p - 3*p + 2] + [ p + s ] [------------] + [ ] [ p + s ] + [ -a - s ] [ ] + [ ---------- ] [ -a + p ] + [ 2 ] [ ------- ] + [ s + s + 1 ]{t} [ 9*s - 9 ]{t} + >>> -tfm_10 - tfm_8 + MIMOParallel(TransferFunctionMatrix(((TransferFunction(-a - s, s**2 + s + 1, s),), (TransferFunction(-p**4 + 3*p - 2, p + s, s),), (TransferFunction(a - p, 9*s - 9, s),))), TransferFunctionMatrix(((TransferFunction(-3, s + 2, s),), (TransferFunction(-p**4 + 3*p - 2, p + s, s),), (TransferFunction(a + s, s**2 + s + 1, s),)))) + >>> pprint(_, use_unicode=False) + [ -a - s ] [ -3 ] + [ ---------- ] [ ----- ] + [ 2 ] [ s + 2 ] + [ s + s + 1 ] [ ] + [ ] [ 4 ] + [ 4 ] [- p + 3*p - 2] + [- p + 3*p - 2] + [--------------] + [--------------] [ p + s ] + [ p + s ] [ ] + [ ] [ a + s ] + [ a - p ] [ ---------- ] + [ ------- ] [ 2 ] + [ 9*s - 9 ]{t} [ s + s + 1 ]{t} + >>> tfm_12 * tfm_8 + MIMOSeries(TransferFunctionMatrix(((TransferFunction(3, s + 2, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a - s, s**2 + s + 1, s),))), TransferFunctionMatrix(((TransferFunction(-a + p, 9*s - 9, s), TransferFunction(-a - s, s**2 + s + 1, s), TransferFunction(3, s + 2, s)), (TransferFunction(-p**4 + 3*p - 2, p + s, s), TransferFunction(a - p, 9*s - 9, s), TransferFunction(-3, s + 2, s))))) + >>> pprint(_, use_unicode=False) + [ 3 ] + [ ----- ] + [ -a + p -a - s 3 ] [ s + 2 ] + [ ------- ---------- -----] [ ] + [ 9*s - 9 2 s + 2] [ 4 ] + [ s + s + 1 ] [p - 3*p + 2] + [ ] *[------------] + [ 4 ] [ p + s ] + [- p + 3*p - 2 a - p -3 ] [ ] + [-------------- ------- -----] [ -a - s ] + [ p + s 9*s - 9 s + 2]{t} [ ---------- ] + [ 2 ] + [ s + s + 1 ]{t} + >>> tfm_12 * tfm_8 * tfm_9 + MIMOSeries(TransferFunctionMatrix(((TransferFunction(-3, s + 2, s),),)), TransferFunctionMatrix(((TransferFunction(3, s + 2, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a - s, s**2 + s + 1, s),))), TransferFunctionMatrix(((TransferFunction(-a + p, 9*s - 9, s), TransferFunction(-a - s, s**2 + s + 1, s), TransferFunction(3, s + 2, s)), (TransferFunction(-p**4 + 3*p - 2, p + s, s), TransferFunction(a - p, 9*s - 9, s), TransferFunction(-3, s + 2, s))))) + >>> pprint(_, use_unicode=False) + [ 3 ] + [ ----- ] + [ -a + p -a - s 3 ] [ s + 2 ] + [ ------- ---------- -----] [ ] + [ 9*s - 9 2 s + 2] [ 4 ] + [ s + s + 1 ] [p - 3*p + 2] [ -3 ] + [ ] *[------------] *[-----] + [ 4 ] [ p + s ] [s + 2]{t} + [- p + 3*p - 2 a - p -3 ] [ ] + [-------------- ------- -----] [ -a - s ] + [ p + s 9*s - 9 s + 2]{t} [ ---------- ] + [ 2 ] + [ s + s + 1 ]{t} + >>> tfm_10 + tfm_8*tfm_9 + MIMOParallel(TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a + p, 9*s - 9, s),))), MIMOSeries(TransferFunctionMatrix(((TransferFunction(-3, s + 2, s),),)), TransferFunctionMatrix(((TransferFunction(3, s + 2, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a - s, s**2 + s + 1, s),))))) + >>> pprint(_, use_unicode=False) + [ a + s ] [ 3 ] + [ ---------- ] [ ----- ] + [ 2 ] [ s + 2 ] + [ s + s + 1 ] [ ] + [ ] [ 4 ] + [ 4 ] [p - 3*p + 2] [ -3 ] + [p - 3*p + 2] + [------------] *[-----] + [------------] [ p + s ] [s + 2]{t} + [ p + s ] [ ] + [ ] [ -a - s ] + [ -a + p ] [ ---------- ] + [ ------- ] [ 2 ] + [ 9*s - 9 ]{t} [ s + s + 1 ]{t} + + These unevaluated ``Series`` or ``Parallel`` objects can convert into the + resultant transfer function matrix using ``.doit()`` method or by + ``.rewrite(TransferFunctionMatrix)``. + + >>> (-tfm_8 + tfm_10 + tfm_8*tfm_9).doit() + TransferFunctionMatrix(((TransferFunction((a + s)*(s + 2)**3 - 3*(s + 2)**2*(s**2 + s + 1) - 9*(s + 2)*(s**2 + s + 1), (s + 2)**3*(s**2 + s + 1), s),), (TransferFunction((p + s)*(-3*p**4 + 9*p - 6), (p + s)**2*(s + 2), s),), (TransferFunction((-a + p)*(s + 2)*(s**2 + s + 1)**2 + (a + s)*(s + 2)*(9*s - 9)*(s**2 + s + 1) + (3*a + 3*s)*(9*s - 9)*(s**2 + s + 1), (s + 2)*(9*s - 9)*(s**2 + s + 1)**2, s),))) + >>> (-tfm_12 * -tfm_8 * -tfm_9).rewrite(TransferFunctionMatrix) + TransferFunctionMatrix(((TransferFunction(3*(-3*a + 3*p)*(p + s)*(s + 2)*(s**2 + s + 1)**2 + 3*(-3*a - 3*s)*(p + s)*(s + 2)*(9*s - 9)*(s**2 + s + 1) + 3*(a + s)*(s + 2)**2*(9*s - 9)*(-p**4 + 3*p - 2)*(s**2 + s + 1), (p + s)*(s + 2)**3*(9*s - 9)*(s**2 + s + 1)**2, s),), (TransferFunction(3*(-a + p)*(p + s)*(s + 2)**2*(-p**4 + 3*p - 2)*(s**2 + s + 1) + 3*(3*a + 3*s)*(p + s)**2*(s + 2)*(9*s - 9) + 3*(p + s)*(s + 2)*(9*s - 9)*(-3*p**4 + 9*p - 6)*(s**2 + s + 1), (p + s)**2*(s + 2)**3*(9*s - 9)*(s**2 + s + 1), s),))) + + See Also + ======== + + TransferFunction, MIMOSeries, MIMOParallel, Feedback + + """ + def __new__(cls, arg): + + expr_mat_arg = [] + try: + var = arg[0][0].var + except TypeError: + raise ValueError(filldedent(""" + `arg` param in TransferFunctionMatrix should + strictly be a nested list containing TransferFunction + objects.""")) + for row in arg: + temp = [] + for element in row: + if not isinstance(element, SISOLinearTimeInvariant): + raise TypeError(filldedent(""" + Each element is expected to be of + type `SISOLinearTimeInvariant`.""")) + + if var != element.var: + raise ValueError(filldedent(""" + Conflicting value(s) found for `var`. All TransferFunction + instances in TransferFunctionMatrix should use the same + complex variable in Laplace domain.""")) + + temp.append(element.to_expr()) + expr_mat_arg.append(temp) + + if isinstance(arg, (tuple, list, Tuple)): + # Making nested Tuple (sympy.core.containers.Tuple) from nested list or nested Python tuple + arg = Tuple(*(Tuple(*r, sympify=False) for r in arg), sympify=False) + + obj = super(TransferFunctionMatrix, cls).__new__(cls, arg) + obj._expr_mat = ImmutableMatrix(expr_mat_arg) + obj.is_StateSpace_object = False + return obj + + @classmethod + def from_Matrix(cls, matrix, var): + """ + Creates a new ``TransferFunctionMatrix`` efficiently from a SymPy Matrix of ``Expr`` objects. + + Parameters + ========== + + matrix : ``ImmutableMatrix`` having ``Expr``/``Number`` elements. + var : Symbol + Complex variable of the Laplace transform which will be used by the + all the ``TransferFunction`` objects in the ``TransferFunctionMatrix``. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunctionMatrix + >>> from sympy import Matrix, pprint + >>> M = Matrix([[s, 1/s], [1/(s+1), s]]) + >>> M_tf = TransferFunctionMatrix.from_Matrix(M, s) + >>> pprint(M_tf, use_unicode=False) + [ s 1] + [ - -] + [ 1 s] + [ ] + [ 1 s] + [----- -] + [s + 1 1]{t} + >>> M_tf.elem_poles() + [[[], [0]], [[-1], []]] + >>> M_tf.elem_zeros() + [[[0], []], [[], [0]]] + + """ + return _to_TFM(matrix, var) + + @property + def var(self): + """ + Returns the complex variable used by all the transfer functions or + ``Series``/``Parallel`` objects in a transfer function matrix. + + Examples + ======== + + >>> from sympy.abc import p, s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, Series, Parallel + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> G4 = TransferFunction(s + 1, s**2 + s + 1, s) + >>> S1 = Series(G1, G2) + >>> S2 = Series(-G3, Parallel(G2, -G1)) + >>> tfm1 = TransferFunctionMatrix([[G1], [G2], [G3]]) + >>> tfm1.var + p + >>> tfm2 = TransferFunctionMatrix([[-S1, -S2], [S1, S2]]) + >>> tfm2.var + p + >>> tfm3 = TransferFunctionMatrix([[G4]]) + >>> tfm3.var + s + + """ + return self.args[0][0][0].var + + @property + def num_inputs(self): + """ + Returns the number of inputs of the system. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> G1 = TransferFunction(s + 3, s**2 - 3, s) + >>> G2 = TransferFunction(4, s**2, s) + >>> G3 = TransferFunction(p**2 + s**2, p - 3, s) + >>> tfm_1 = TransferFunctionMatrix([[G2, -G1, G3], [-G2, -G1, -G3]]) + >>> tfm_1.num_inputs + 3 + + See Also + ======== + + num_outputs + + """ + return self._expr_mat.shape[1] + + @property + def num_outputs(self): + """ + Returns the number of outputs of the system. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunctionMatrix + >>> from sympy import Matrix + >>> M_1 = Matrix([[s], [1/s]]) + >>> TFM = TransferFunctionMatrix.from_Matrix(M_1, s) + >>> print(TFM) + TransferFunctionMatrix(((TransferFunction(s, 1, s),), (TransferFunction(1, s, s),))) + >>> TFM.num_outputs + 2 + + See Also + ======== + + num_inputs + + """ + return self._expr_mat.shape[0] + + @property + def shape(self): + """ + Returns the shape of the transfer function matrix, that is, ``(# of outputs, # of inputs)``. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> tf1 = TransferFunction(p**2 - 1, s**4 + s**3 - p, p) + >>> tf2 = TransferFunction(1 - p, p**2 - 3*p + 7, p) + >>> tf3 = TransferFunction(3, 4, p) + >>> tfm1 = TransferFunctionMatrix([[tf1, -tf2]]) + >>> tfm1.shape + (1, 2) + >>> tfm2 = TransferFunctionMatrix([[-tf2, tf3], [tf1, -tf1]]) + >>> tfm2.shape + (2, 2) + + """ + return self._expr_mat.shape + + def __neg__(self): + neg = -self._expr_mat + return _to_TFM(neg, self.var) + + @_check_other_MIMO + def __add__(self, other): + + if not isinstance(other, MIMOParallel): + return MIMOParallel(self, other) + other_arg_list = list(other.args) + return MIMOParallel(self, *other_arg_list) + + @_check_other_MIMO + def __sub__(self, other): + return self + (-other) + + @_check_other_MIMO + def __mul__(self, other): + + if not isinstance(other, MIMOSeries): + return MIMOSeries(other, self) + other_arg_list = list(other.args) + return MIMOSeries(*other_arg_list, self) + + def __getitem__(self, key): + trunc = self._expr_mat.__getitem__(key) + if isinstance(trunc, ImmutableMatrix): + return _to_TFM(trunc, self.var) + return TransferFunction.from_rational_expression(trunc, self.var) + + def transpose(self): + """Returns the transpose of the ``TransferFunctionMatrix`` (switched input and output layers).""" + transposed_mat = self._expr_mat.transpose() + return _to_TFM(transposed_mat, self.var) + + def elem_poles(self): + """ + Returns the poles of each element of the ``TransferFunctionMatrix``. + + .. note:: + Actual poles of a MIMO system are NOT the poles of individual elements. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> tf_1 = TransferFunction(3, (s + 1), s) + >>> tf_2 = TransferFunction(s + 6, (s + 1)*(s + 2), s) + >>> tf_3 = TransferFunction(s + 3, s**2 + 3*s + 2, s) + >>> tf_4 = TransferFunction(s + 2, s**2 + 5*s - 10, s) + >>> tfm_1 = TransferFunctionMatrix([[tf_1, tf_2], [tf_3, tf_4]]) + >>> tfm_1 + TransferFunctionMatrix(((TransferFunction(3, s + 1, s), TransferFunction(s + 6, (s + 1)*(s + 2), s)), (TransferFunction(s + 3, s**2 + 3*s + 2, s), TransferFunction(s + 2, s**2 + 5*s - 10, s)))) + >>> tfm_1.elem_poles() + [[[-1], [-2, -1]], [[-2, -1], [-5/2 + sqrt(65)/2, -sqrt(65)/2 - 5/2]]] + + See Also + ======== + + elem_zeros + + """ + return [[element.poles() for element in row] for row in self.doit().args[0]] + + def elem_zeros(self): + """ + Returns the zeros of each element of the ``TransferFunctionMatrix``. + + .. note:: + Actual zeros of a MIMO system are NOT the zeros of individual elements. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> tf_1 = TransferFunction(3, (s + 1), s) + >>> tf_2 = TransferFunction(s + 6, (s + 1)*(s + 2), s) + >>> tf_3 = TransferFunction(s + 3, s**2 + 3*s + 2, s) + >>> tf_4 = TransferFunction(s**2 - 9*s + 20, s**2 + 5*s - 10, s) + >>> tfm_1 = TransferFunctionMatrix([[tf_1, tf_2], [tf_3, tf_4]]) + >>> tfm_1 + TransferFunctionMatrix(((TransferFunction(3, s + 1, s), TransferFunction(s + 6, (s + 1)*(s + 2), s)), (TransferFunction(s + 3, s**2 + 3*s + 2, s), TransferFunction(s**2 - 9*s + 20, s**2 + 5*s - 10, s)))) + >>> tfm_1.elem_zeros() + [[[], [-6]], [[-3], [4, 5]]] + + See Also + ======== + + elem_poles + + """ + return [[element.zeros() for element in row] for row in self.doit().args[0]] + + def eval_frequency(self, other): + """ + Evaluates system response of each transfer function in the ``TransferFunctionMatrix`` at any point in the real or complex plane. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> from sympy import I + >>> tf_1 = TransferFunction(3, (s + 1), s) + >>> tf_2 = TransferFunction(s + 6, (s + 1)*(s + 2), s) + >>> tf_3 = TransferFunction(s + 3, s**2 + 3*s + 2, s) + >>> tf_4 = TransferFunction(s**2 - 9*s + 20, s**2 + 5*s - 10, s) + >>> tfm_1 = TransferFunctionMatrix([[tf_1, tf_2], [tf_3, tf_4]]) + >>> tfm_1 + TransferFunctionMatrix(((TransferFunction(3, s + 1, s), TransferFunction(s + 6, (s + 1)*(s + 2), s)), (TransferFunction(s + 3, s**2 + 3*s + 2, s), TransferFunction(s**2 - 9*s + 20, s**2 + 5*s - 10, s)))) + >>> tfm_1.eval_frequency(2) + Matrix([ + [ 1, 2/3], + [5/12, 3/2]]) + >>> tfm_1.eval_frequency(I*2) + Matrix([ + [ 3/5 - 6*I/5, -I], + [3/20 - 11*I/20, -101/74 + 23*I/74]]) + """ + mat = self._expr_mat.subs(self.var, other) + return mat.expand() + + def _flat(self): + """Returns flattened list of args in TransferFunctionMatrix""" + return [elem for tup in self.args[0] for elem in tup] + + def _eval_evalf(self, prec): + """Calls evalf() on each transfer function in the transfer function matrix""" + dps = prec_to_dps(prec) + mat = self._expr_mat.applyfunc(lambda a: a.evalf(n=dps)) + return _to_TFM(mat, self.var) + + def _eval_simplify(self, **kwargs): + """Simplifies the transfer function matrix""" + simp_mat = self._expr_mat.applyfunc(lambda a: cancel(a, expand=False)) + return _to_TFM(simp_mat, self.var) + + def expand(self, **hints): + """Expands the transfer function matrix""" + expand_mat = self._expr_mat.expand(**hints) + return _to_TFM(expand_mat, self.var) + +class StateSpace(LinearTimeInvariant): + r""" + State space model (ssm) of a linear, time invariant control system. + + Represents the standard state-space model with A, B, C, D as state-space matrices. + This makes the linear control system: + + (1) x'(t) = A * x(t) + B * u(t); x in R^n , u in R^k + (2) y(t) = C * x(t) + D * u(t); y in R^m + + where u(t) is any input signal, y(t) the corresponding output, and x(t) the system's state. + + Parameters + ========== + + A : Matrix + The State matrix of the state space model. + B : Matrix + The Input-to-State matrix of the state space model. + C : Matrix + The State-to-Output matrix of the state space model. + D : Matrix + The Feedthrough matrix of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + + The easiest way to create a StateSpaceModel is via four matrices: + + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> StateSpace(A, B, C, D) + StateSpace(Matrix([ + [1, 2], + [1, 0]]), Matrix([ + [1], + [1]]), Matrix([[0, 1]]), Matrix([[0]])) + + One can use less matrices. The rest will be filled with a minimum of zeros: + + >>> StateSpace(A, B) + StateSpace(Matrix([ + [1, 2], + [1, 0]]), Matrix([ + [1], + [1]]), Matrix([[0, 0]]), Matrix([[0]])) + + See Also + ======== + + TransferFunction, TransferFunctionMatrix + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/State-space_representation + .. [2] https://in.mathworks.com/help/control/ref/ss.html + + """ + def __new__(cls, A=None, B=None, C=None, D=None): + if A is None: + A = zeros(1) + if B is None: + B = zeros(A.rows, 1) + if C is None: + C = zeros(1, A.cols) + if D is None: + D = zeros(C.rows, B.cols) + + A = _sympify(A) + B = _sympify(B) + C = _sympify(C) + D = _sympify(D) + + if (isinstance(A, ImmutableDenseMatrix) and isinstance(B, ImmutableDenseMatrix) and + isinstance(C, ImmutableDenseMatrix) and isinstance(D, ImmutableDenseMatrix)): + # Check State Matrix is square + if A.rows != A.cols: + raise ShapeError("Matrix A must be a square matrix.") + + # Check State and Input matrices have same rows + if A.rows != B.rows: + raise ShapeError("Matrices A and B must have the same number of rows.") + + # Check Output and Feedthrough matrices have same rows + if C.rows != D.rows: + raise ShapeError("Matrices C and D must have the same number of rows.") + + # Check State and Output matrices have same columns + if A.cols != C.cols: + raise ShapeError("Matrices A and C must have the same number of columns.") + + # Check Input and Feedthrough matrices have same columns + if B.cols != D.cols: + raise ShapeError("Matrices B and D must have the same number of columns.") + + obj = super(StateSpace, cls).__new__(cls, A, B, C, D) + obj._A = A + obj._B = B + obj._C = C + obj._D = D + + # Determine if the system is SISO or MIMO + num_outputs = D.rows + num_inputs = D.cols + if num_inputs == 1 and num_outputs == 1: + obj._is_SISO = True + obj._clstype = SISOLinearTimeInvariant + else: + obj._is_SISO = False + obj._clstype = MIMOLinearTimeInvariant + obj.is_StateSpace_object = True + return obj + + else: + raise TypeError("A, B, C and D inputs must all be sympy Matrices.") + + @property + def state_matrix(self): + """ + Returns the state matrix of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.state_matrix + Matrix([ + [1, 2], + [1, 0]]) + + """ + return self._A + + @property + def input_matrix(self): + """ + Returns the input matrix of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.input_matrix + Matrix([ + [1], + [1]]) + + """ + return self._B + + @property + def output_matrix(self): + """ + Returns the output matrix of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.output_matrix + Matrix([[0, 1]]) + + """ + return self._C + + @property + def feedforward_matrix(self): + """ + Returns the feedforward matrix of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.feedforward_matrix + Matrix([[0]]) + + """ + return self._D + + A = state_matrix + B = input_matrix + C = output_matrix + D = feedforward_matrix + + @property + def num_states(self): + """ + Returns the number of states of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.num_states + 2 + + """ + return self._A.rows + + @property + def num_inputs(self): + """ + Returns the number of inputs of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.num_inputs + 1 + + """ + return self._D.cols + + @property + def num_outputs(self): + """ + Returns the number of outputs of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.num_outputs + 1 + + """ + return self._D.rows + + + @property + def shape(self): + """Returns the shape of the equivalent StateSpace system.""" + return self.num_outputs, self.num_inputs + + def dsolve(self, initial_conditions=None, input_vector=None, var=Symbol('t')): + r""" + Returns `y(t)` or output of StateSpace given by the solution of equations: + x'(t) = A * x(t) + B * u(t) + y(t) = C * x(t) + D * u(t) + + Parameters + ============ + + initial_conditions : Matrix + The initial conditions of `x` state vector. If not provided, it defaults to a zero vector. + input_vector : Matrix + The input vector for state space. If not provided, it defaults to a zero vector. + var : Symbol + The symbol representing time. If not provided, it defaults to `t`. + + Examples + ========== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-2, 0], [1, -1]]) + >>> B = Matrix([[1], [0]]) + >>> C = Matrix([[2, 1]]) + >>> ip = Matrix([5]) + >>> i = Matrix([0, 0]) + >>> ss = StateSpace(A, B, C) + >>> ss.dsolve(input_vector=ip, initial_conditions=i).simplify() + Matrix([[15/2 - 5*exp(-t) - 5*exp(-2*t)/2]]) + + If no input is provided it defaults to solving the system with zero initial conditions and zero input. + + >>> ss.dsolve() + Matrix([[0]]) + + References + ========== + .. [1] https://web.mit.edu/2.14/www/Handouts/StateSpaceResponse.pdf + .. [2] https://docs.sympy.org/latest/modules/solvers/ode.html#sympy.solvers.ode.systems.linodesolve + + """ + + if not isinstance(var, Symbol): + raise ValueError("Variable for representing time must be a Symbol.") + if not initial_conditions: + initial_conditions = zeros(self._A.shape[0], 1) + elif initial_conditions.shape != (self._A.shape[0], 1): + raise ShapeError("Initial condition vector should have the same number of " + "rows as the state matrix.") + if not input_vector: + input_vector = zeros(self._B.shape[1], 1) + elif input_vector.shape != (self._B.shape[1], 1): + raise ShapeError("Input vector should have the same number of " + "columns as the input matrix.") + sol = linodesolve(A=self._A, t=var, b=self._B*input_vector, type='type2', doit=True) + mat1 = Matrix(sol) + mat2 = mat1.replace(var, 0) + free1 = self._A.free_symbols | self._B.free_symbols | input_vector.free_symbols + free2 = mat2.free_symbols + # Get all the free symbols form the matrix + dummy_symbols = list(free2-free1) + # Convert the matrix to a Coefficient matrix + r1, r2 = linear_eq_to_matrix(mat2, dummy_symbols) + s = linsolve((r1, initial_conditions+r2)) + res_tuple = next(iter(s)) + for ind, v in enumerate(res_tuple): + mat1 = mat1.replace(dummy_symbols[ind], v) + res = self._C*mat1 + self._D*input_vector + return res + + def _eval_evalf(self, prec): + """ + Returns state space model where numerical expressions are evaluated into floating point numbers. + """ + dps = prec_to_dps(prec) + return StateSpace( + self._A.evalf(n = dps), + self._B.evalf(n = dps), + self._C.evalf(n = dps), + self._D.evalf(n = dps)) + + def _eval_rewrite_as_TransferFunction(self, *args): + """ + Returns the equivalent Transfer Function of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import TransferFunction, StateSpace + >>> A = Matrix([[-5, -1], [3, -1]]) + >>> B = Matrix([2, 5]) + >>> C = Matrix([[1, 2]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.rewrite(TransferFunction) + [[TransferFunction(12*s + 59, s**2 + 6*s + 8, s)]] + + """ + s = Symbol('s') + n = self._A.shape[0] + I = eye(n) + G = self._C*(s*I - self._A).solve(self._B) + self._D + G = G.simplify() + to_tf = lambda expr: TransferFunction.from_rational_expression(expr, s) + tf_mat = [[to_tf(expr) for expr in sublist] for sublist in G.tolist()] + return tf_mat + + def __add__(self, other): + """ + Add two State Space systems (parallel connection). + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A1 = Matrix([[1]]) + >>> B1 = Matrix([[2]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([[-2]]) + >>> A2 = Matrix([[-1]]) + >>> B2 = Matrix([[-2]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[2]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> ss1 + ss2 + StateSpace(Matrix([ + [1, 0], + [0, -1]]), Matrix([ + [ 2], + [-2]]), Matrix([[-1, 1]]), Matrix([[0]])) + + """ + # Check for scalars + if isinstance(other, (int, float, complex, Symbol)): + A = self._A + B = self._B + C = self._C + D = self._D.applyfunc(lambda element: element + other) + + else: + # Check nature of system + if not isinstance(other, StateSpace): + raise ValueError("Addition is only supported for 2 State Space models.") + # Check dimensions of system + elif ((self.num_inputs != other.num_inputs) or (self.num_outputs != other.num_outputs)): + raise ShapeError("Systems with incompatible inputs and outputs cannot be added.") + + m1 = (self._A).row_join(zeros(self._A.shape[0], other._A.shape[-1])) + m2 = zeros(other._A.shape[0], self._A.shape[-1]).row_join(other._A) + + A = m1.col_join(m2) + B = self._B.col_join(other._B) + C = self._C.row_join(other._C) + D = self._D + other._D + + return StateSpace(A, B, C, D) + + def __radd__(self, other): + """ + Right add two State Space systems. + + Examples + ======== + + >>> from sympy.physics.control import StateSpace + >>> s = StateSpace() + >>> 5 + s + StateSpace(Matrix([[0]]), Matrix([[0]]), Matrix([[0]]), Matrix([[5]])) + + """ + return self + other + + def __sub__(self, other): + """ + Subtract two State Space systems. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A1 = Matrix([[1]]) + >>> B1 = Matrix([[2]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([[-2]]) + >>> A2 = Matrix([[-1]]) + >>> B2 = Matrix([[-2]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[2]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> ss1 - ss2 + StateSpace(Matrix([ + [1, 0], + [0, -1]]), Matrix([ + [ 2], + [-2]]), Matrix([[-1, -1]]), Matrix([[-4]])) + + """ + return self + (-other) + + def __rsub__(self, other): + """ + Right subtract two tate Space systems. + + Examples + ======== + + >>> from sympy.physics.control import StateSpace + >>> s = StateSpace() + >>> 5 - s + StateSpace(Matrix([[0]]), Matrix([[0]]), Matrix([[0]]), Matrix([[5]])) + + """ + return other + (-self) + + def __neg__(self): + """ + Returns the negation of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-5, -1], [3, -1]]) + >>> B = Matrix([2, 5]) + >>> C = Matrix([[1, 2]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> -ss + StateSpace(Matrix([ + [-5, -1], + [ 3, -1]]), Matrix([ + [2], + [5]]), Matrix([[-1, -2]]), Matrix([[0]])) + + """ + return StateSpace(self._A, self._B, -self._C, -self._D) + + def __mul__(self, other): + """ + Multiplication of two State Space systems (serial connection). + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-5, -1], [3, -1]]) + >>> B = Matrix([2, 5]) + >>> C = Matrix([[1, 2]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss*5 + StateSpace(Matrix([ + [-5, -1], + [ 3, -1]]), Matrix([ + [2], + [5]]), Matrix([[5, 10]]), Matrix([[0]])) + + """ + # Check for scalars + if isinstance(other, (int, float, complex, Symbol)): + A = self._A + B = self._B + C = self._C.applyfunc(lambda element: element*other) + D = self._D.applyfunc(lambda element: element*other) + + else: + # Check nature of system + if not isinstance(other, StateSpace): + raise ValueError("Multiplication is only supported for 2 State Space models.") + # Check dimensions of system + elif self.num_inputs != other.num_outputs: + raise ShapeError("Systems with incompatible inputs and outputs cannot be multiplied.") + + m1 = (other._A).row_join(zeros(other._A.shape[0], self._A.shape[1])) + m2 = (self._B * other._C).row_join(self._A) + + A = m1.col_join(m2) + B = (other._B).col_join(self._B * other._D) + C = (self._D * other._C).row_join(self._C) + D = self._D * other._D + + return StateSpace(A, B, C, D) + + def __rmul__(self, other): + """ + Right multiply two tate Space systems. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-5, -1], [3, -1]]) + >>> B = Matrix([2, 5]) + >>> C = Matrix([[1, 2]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> 5*ss + StateSpace(Matrix([ + [-5, -1], + [ 3, -1]]), Matrix([ + [10], + [25]]), Matrix([[1, 2]]), Matrix([[0]])) + + """ + if isinstance(other, (int, float, complex, Symbol)): + A = self._A + C = self._C + B = self._B.applyfunc(lambda element: element*other) + D = self._D.applyfunc(lambda element: element*other) + return StateSpace(A, B, C, D) + else: + return self*other + + def __repr__(self): + A_str = self._A.__repr__() + B_str = self._B.__repr__() + C_str = self._C.__repr__() + D_str = self._D.__repr__() + + return f"StateSpace(\n{A_str},\n\n{B_str},\n\n{C_str},\n\n{D_str})" + + + def append(self, other): + """ + Returns the first model appended with the second model. The order is preserved. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A1 = Matrix([[1]]) + >>> B1 = Matrix([[2]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([[-2]]) + >>> A2 = Matrix([[-1]]) + >>> B2 = Matrix([[-2]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[2]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> ss1.append(ss2) + StateSpace(Matrix([ + [1, 0], + [0, -1]]), Matrix([ + [2, 0], + [0, -2]]), Matrix([ + [-1, 0], + [ 0, 1]]), Matrix([ + [-2, 0], + [ 0, 2]])) + + """ + n = self.num_states + other.num_states + m = self.num_inputs + other.num_inputs + p = self.num_outputs + other.num_outputs + + A = zeros(n, n) + B = zeros(n, m) + C = zeros(p, n) + D = zeros(p, m) + + A[:self.num_states, :self.num_states] = self._A + A[self.num_states:, self.num_states:] = other._A + B[:self.num_states, :self.num_inputs] = self._B + B[self.num_states:, self.num_inputs:] = other._B + C[:self.num_outputs, :self.num_states] = self._C + C[self.num_outputs:, self.num_states:] = other._C + D[:self.num_outputs, :self.num_inputs] = self._D + D[self.num_outputs:, self.num_inputs:] = other._D + return StateSpace(A, B, C, D) + + def observability_matrix(self): + """ + Returns the observability matrix of the state space model: + [C, C * A^1, C * A^2, .. , C * A^(n-1)]; A in R^(n x n), C in R^(m x k) + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ob = ss.observability_matrix() + >>> ob + Matrix([ + [0, 1], + [1, 0]]) + + References + ========== + .. [1] https://in.mathworks.com/help/control/ref/statespacemodel.obsv.html + + """ + n = self.num_states + ob = self._C + for i in range(1,n): + ob = ob.col_join(self._C * self._A**i) + + return ob + + def observable_subspace(self): + """ + Returns the observable subspace of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ob_subspace = ss.observable_subspace() + >>> ob_subspace + [Matrix([ + [0], + [1]]), Matrix([ + [1], + [0]])] + + """ + return self.observability_matrix().columnspace() + + def is_observable(self): + """ + Returns if the state space model is observable. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.is_observable() + True + + """ + return self.observability_matrix().rank() == self.num_states + + def controllability_matrix(self): + """ + Returns the controllability matrix of the system: + [B, A * B, A^2 * B, .. , A^(n-1) * B]; A in R^(n x n), B in R^(n x m) + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.controllability_matrix() + Matrix([ + [0.5, -0.75], + [ 0, 0.5]]) + + References + ========== + .. [1] https://in.mathworks.com/help/control/ref/statespacemodel.ctrb.html + + """ + co = self._B + n = self._A.shape[0] + for i in range(1, n): + co = co.row_join(((self._A)**i) * self._B) + + return co + + def controllable_subspace(self): + """ + Returns the controllable subspace of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> co_subspace = ss.controllable_subspace() + >>> co_subspace + [Matrix([ + [0.5], + [ 0]]), Matrix([ + [-0.75], + [ 0.5]])] + + """ + return self.controllability_matrix().columnspace() + + def is_controllable(self): + """ + Returns if the state space model is controllable. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.is_controllable() + True + + """ + return self.controllability_matrix().rank() == self.num_states diff --git a/phivenv/Lib/site-packages/sympy/physics/control/tests/__init__.py b/phivenv/Lib/site-packages/sympy/physics/control/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/physics/control/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/control/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83efdaa6801c552f158e604caeeb0cd481a37dec Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/control/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/control/tests/__pycache__/test_control_plots.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/control/tests/__pycache__/test_control_plots.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2499351d4d1cf0ccccea3496f644512973c9c972 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/control/tests/__pycache__/test_control_plots.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/control/tests/__pycache__/test_lti.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/control/tests/__pycache__/test_lti.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53df16fae926368387dc26e9ea856e3ede644d73 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/control/tests/__pycache__/test_lti.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/control/tests/test_control_plots.py b/phivenv/Lib/site-packages/sympy/physics/control/tests/test_control_plots.py new file mode 100644 index 0000000000000000000000000000000000000000..05836c806f93c4a8ff375efe2b8bd5f993db7502 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/control/tests/test_control_plots.py @@ -0,0 +1,332 @@ +from math import isclose +from sympy.core.numbers import I, all_close +from sympy.core.symbol import Dummy +from sympy.functions.elementary.complexes import (Abs, arg) +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.abc import s, p, a +from sympy import pi +from sympy.external import import_module +from sympy.physics.control.control_plots import \ + (pole_zero_numerical_data, pole_zero_plot, step_response_numerical_data, + step_response_plot, impulse_response_numerical_data, + impulse_response_plot, ramp_response_numerical_data, + ramp_response_plot, bode_magnitude_numerical_data, + bode_phase_numerical_data, bode_plot, nyquist_plot_expr, + nichols_plot_expr) + +from sympy.physics.control.lti import (TransferFunction, + Series, Parallel, TransferFunctionMatrix) +from sympy.testing.pytest import raises, skip + +matplotlib = import_module( + 'matplotlib', import_kwargs={'fromlist': ['pyplot']}, + catch=(RuntimeError,)) + +numpy = import_module('numpy') + +tf1 = TransferFunction(1, p**2 + 0.5*p + 2, p) +tf2 = TransferFunction(p, 6*p**2 + 3*p + 1, p) +tf3 = TransferFunction(p, p**3 - 1, p) +tf4 = TransferFunction(10, p**3, p) +tf5 = TransferFunction(5, s**2 + 2*s + 10, s) +tf6 = TransferFunction(1, 1, s) +tf7 = TransferFunction(4*s*3 + 9*s**2 + 0.1*s + 11, 8*s**6 + 9*s**4 + 11, s) +tf8 = TransferFunction(5, s**2 + (2+I)*s + 10, s) + +ser1 = Series(tf4, TransferFunction(1, p - 5, p)) +ser2 = Series(tf3, TransferFunction(p, p + 2, p)) + +par1 = Parallel(tf1, tf2) + + +def _to_tuple(a, b): + return tuple(a), tuple(b) + +def _trim_tuple(a, b): + a, b = _to_tuple(a, b) + return tuple(a[0: 2] + a[len(a)//2 : len(a)//2 + 1] + a[-2:]), \ + tuple(b[0: 2] + b[len(b)//2 : len(b)//2 + 1] + b[-2:]) + +def y_coordinate_equality(plot_data_func, evalf_func, system): + """Checks whether the y-coordinate value of the plotted + data point is equal to the value of the function at a + particular x.""" + x, y = plot_data_func(system) + x, y = _trim_tuple(x, y) + y_exp = tuple(evalf_func(system, x_i) for x_i in x) + return all(Abs(y_exp_i - y_i) < 1e-8 for y_exp_i, y_i in zip(y_exp, y)) + + +def test_errors(): + if not matplotlib: + skip("Matplotlib not the default backend") + + # Invalid `system` check + tfm = TransferFunctionMatrix([[tf6, tf5], [tf5, tf6]]) + expr = 1/(s**2 - 1) + raises(NotImplementedError, lambda: pole_zero_plot(tfm)) + raises(NotImplementedError, lambda: pole_zero_numerical_data(expr)) + raises(NotImplementedError, lambda: impulse_response_plot(expr)) + raises(NotImplementedError, lambda: impulse_response_numerical_data(tfm)) + raises(NotImplementedError, lambda: step_response_plot(tfm)) + raises(NotImplementedError, lambda: step_response_numerical_data(expr)) + raises(NotImplementedError, lambda: ramp_response_plot(expr)) + raises(NotImplementedError, lambda: ramp_response_numerical_data(tfm)) + raises(NotImplementedError, lambda: bode_plot(tfm)) + + # More than 1 variables + tf_a = TransferFunction(a, s + 1, s) + raises(ValueError, lambda: pole_zero_plot(tf_a)) + raises(ValueError, lambda: pole_zero_numerical_data(tf_a)) + raises(ValueError, lambda: impulse_response_plot(tf_a)) + raises(ValueError, lambda: impulse_response_numerical_data(tf_a)) + raises(ValueError, lambda: step_response_plot(tf_a)) + raises(ValueError, lambda: step_response_numerical_data(tf_a)) + raises(ValueError, lambda: ramp_response_plot(tf_a)) + raises(ValueError, lambda: ramp_response_numerical_data(tf_a)) + raises(ValueError, lambda: bode_plot(tf_a)) + + # lower_limit > 0 for response plots + raises(ValueError, lambda: impulse_response_plot(tf1, lower_limit=-1)) + raises(ValueError, lambda: step_response_plot(tf1, lower_limit=-0.1)) + raises(ValueError, lambda: ramp_response_plot(tf1, lower_limit=-4/3)) + + # slope in ramp_response_plot() is negative + raises(ValueError, lambda: ramp_response_plot(tf1, slope=-0.1)) + + # incorrect frequency or phase unit + raises(ValueError, lambda: bode_plot(tf1,freq_unit = 'hz')) + raises(ValueError, lambda: bode_plot(tf1,phase_unit = 'degree')) + + +def test_pole_zero(): + + def pz_tester(sys, expected_value): + _z, _p = pole_zero_numerical_data(sys) + z_check = all_close(_z, expected_value[0]) + p_check = all_close(_p, expected_value[1]) + return p_check and z_check + + exp1 = [[], [-0.24999999999999994-1.3919410907075054j, -0.24999999999999994+1.3919410907075054j]] + exp2 = [[0.0], [-0.25-0.3227486121839514j, -0.25+0.3227486121839514j]] + exp3 = [[0.0], [0.9999999999999998+0j, -0.5000000000000004-0.8660254037844395j, + -0.5000000000000004+0.8660254037844395j]] + exp4 = [[], [0.0, 0.0, 0.0, 5.0]] + exp5 = [[-5.645751311064592, -0.5000000000000008, -0.3542486889354093], + [-0.24999999999999986-0.322748612183951348j, + -0.2499999999999998+0.32274861218395134j, + -0.24999999999999986-1.3919410907075052j, + -0.2499999999999998+1.3919410907075052j]] + exp6 = [[], [-1.1641600331447917-3.545808351896439j, + -0.8358399668552097+2.5458083518964383j]] + + assert pz_tester(tf1, exp1) + assert pz_tester(tf2, exp2) + assert pz_tester(tf3, exp3) + assert pz_tester(ser1, exp4) + assert pz_tester(par1, exp5) + assert pz_tester(tf8, exp6) + + +def test_bode(): + if not numpy: + skip("NumPy is required for this test") + + def bode_phase_evalf(system, point): + expr = system.to_expr() + _w = Dummy("w", real=True) + w_expr = expr.subs({system.var: I*_w}) + return arg(w_expr).subs({_w: point}).evalf() + + def bode_mag_evalf(system, point): + expr = system.to_expr() + _w = Dummy("w", real=True) + w_expr = expr.subs({system.var: I*_w}) + return 20*log(Abs(w_expr), 10).subs({_w: point}).evalf() + + def test_bode_data(sys): + return y_coordinate_equality(bode_magnitude_numerical_data, bode_mag_evalf, sys) \ + and y_coordinate_equality(bode_phase_numerical_data, bode_phase_evalf, sys) + + assert test_bode_data(tf1) + assert test_bode_data(tf2) + assert test_bode_data(tf3) + assert test_bode_data(tf4) + assert test_bode_data(tf5) + + +def check_point_accuracy(a, b): + return all(isclose(*_, rel_tol=1e-1, abs_tol=1e-6 + ) for _ in zip(a, b)) + + +def test_impulse_response(): + if not numpy: + skip("NumPy is required for this test") + + def impulse_res_tester(sys, expected_value): + x, y = _to_tuple(*impulse_response_numerical_data(sys, + adaptive=False, n=10)) + x_check = check_point_accuracy(x, expected_value[0]) + y_check = check_point_accuracy(y, expected_value[1]) + return x_check and y_check + + exp1 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (0.0, 0.544019738507865, 0.01993849743234938, -0.31140243360893216, -0.022852779906491996, 0.1778306498155759, + 0.01962941084328499, -0.1013115194573652, -0.014975541213105696, 0.0575789724730714)) + exp2 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (0.1666666675, 0.08389223412935855, + 0.02338051973475047, -0.014966807776379383, -0.034645954223054234, -0.040560075735512804, + -0.037658628907103885, -0.030149507719590022, -0.021162090730736834, -0.012721292737437523)) + exp3 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (4.369893391586999e-09, 1.1750333000630964, + 3.2922404058312473, 9.432290008148343, 28.37098083007151, 86.18577464367974, 261.90356653762115, + 795.6538758627842, 2416.9920942096983, 7342.159505206647)) + exp4 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (0.0, 6.17283950617284, 24.69135802469136, + 55.555555555555564, 98.76543209876544, 154.320987654321, 222.22222222222226, 302.46913580246917, + 395.0617283950618, 500.0)) + exp5 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (0.0, -0.10455606138085417, + 0.06757671513476461, -0.03234567568833768, 0.013582514927757873, -0.005273419510705473, + 0.0019364083003354075, -0.000680070134067832, 0.00022969845960406913, -7.476094359583917e-05)) + exp6 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (-6.016699583000218e-09, 0.35039802056107394, 3.3728423827689884, 12.119846079276684, + 25.86101014293389, 29.352480635282088, -30.49475907497664, -273.8717189554019, -863.2381702029659, + -1747.0262164682233)) + exp7 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, + 4.444444444444445, 5.555555555555555, 6.666666666666667, 7.777777777777779, + 8.88888888888889, 10.0), (0.0, 18.934638095560974, 5346.93244680907, 1384609.8718249386, + 358161126.65801865, 92645770015.70108, 23964739753087.42, 6198974342083139.0, 1.603492601616059e+18, + 4.147764422869658e+20)) + + assert impulse_res_tester(tf1, exp1) + assert impulse_res_tester(tf2, exp2) + assert impulse_res_tester(tf3, exp3) + assert impulse_res_tester(tf4, exp4) + assert impulse_res_tester(tf5, exp5) + assert impulse_res_tester(tf7, exp6) + assert impulse_res_tester(ser1, exp7) + + +def test_step_response(): + if not numpy: + skip("NumPy is required for this test") + + def step_res_tester(sys, expected_value): + x, y = _to_tuple(*step_response_numerical_data(sys, + adaptive=False, n=10)) + x_check = check_point_accuracy(x, expected_value[0]) + y_check = check_point_accuracy(y, expected_value[1]) + return x_check and y_check + + exp1 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (-1.9193285738516863e-08, 0.42283495488246126, 0.7840485977945262, 0.5546841805655717, + 0.33903033806932087, 0.4627251747410237, 0.5909907598988051, 0.5247213989553071, + 0.4486997874319281, 0.4839358435839171)) + exp2 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (0.0, 0.13728409095645816, 0.19474559355325086, 0.1974909129243011, 0.16841657696573073, + 0.12559777736159378, 0.08153828016664713, 0.04360471317348958, 0.015072994568868221, + -0.003636420058445484)) + exp3 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (0.0, 0.6314542141914303, 2.9356520038101035, 9.37731009663807, 28.452300356688376, + 86.25721933273988, 261.9236645044672, 795.6435410577224, 2416.9786984578764, 7342.154119725917)) + exp4 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (0.0, 2.286236899862826, 18.28989519890261, 61.72839629629631, 146.31916159122088, 285.7796124828532, + 493.8271703703705, 784.1792566529494, 1170.553292729767, 1666.6667)) + exp5 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (-3.999999997894577e-09, 0.6720357068882895, 0.4429938256137113, 0.5182010838004518, + 0.4944139147159695, 0.5016379853883338, 0.4995466896527733, 0.5001154784851325, + 0.49997448824584123, 0.5000039745919259)) + exp6 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (-1.5433688493882158e-09, 0.3428705539937336, 1.1253619102202777, 3.1849962651016517, + 9.47532757182671, 28.727231099148135, 87.29426924860557, 265.2138681048606, 805.6636260007757, + 2447.387582370878)) + + assert step_res_tester(tf1, exp1) + assert step_res_tester(tf2, exp2) + assert step_res_tester(tf3, exp3) + assert step_res_tester(tf4, exp4) + assert step_res_tester(tf5, exp5) + assert step_res_tester(ser2, exp6) + + +def test_ramp_response(): + if not numpy: + skip("NumPy is required for this test") + + def ramp_res_tester(sys, num_points, expected_value, slope=1): + x, y = _to_tuple(*ramp_response_numerical_data(sys, + slope=slope, adaptive=False, n=num_points)) + x_check = check_point_accuracy(x, expected_value[0]) + y_check = check_point_accuracy(y, expected_value[1]) + return x_check and y_check + + exp1 = ((0.0, 2.0, 4.0, 6.0, 8.0, 10.0), (0.0, 0.7324667795033895, 1.9909720978650398, + 2.7956587704217783, 3.9224897567931514, 4.85022655284895)) + exp2 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, + 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), + (2.4360213402019326e-08, 0.10175320182493253, 0.33057612497658406, 0.5967937263298935, + 0.8431511866718248, 1.0398805391471613, 1.1776043125035738, 1.2600994825747305, 1.2981042689274653, + 1.304684417610106)) + exp3 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (-3.9329040468771836e-08, + 0.34686634635794555, 2.9998828170537903, 12.33303690737476, 40.993913948137795, 127.84145222317912, + 391.41713691996, 1192.0006858708389, 3623.9808672503405, 11011.728034546572)) + exp4 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (0.0, 1.9051973784484078, 30.483158055174524, + 154.32098765432104, 487.7305288827924, 1190.7483615302544, 2469.1358024691367, 4574.3789056546275, + 7803.688462124678, 12500.0)) + exp5 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (0.0, 3.8844361856975635, 9.141792069209865, + 14.096349157657231, 19.09783068994694, 24.10179770390321, 29.09907319114121, 34.10040420185154, + 39.09983919254265, 44.10006013058409)) + exp6 = ((0.0, 1.1111111111111112, 2.2222222222222223, 3.3333333333333335, 4.444444444444445, 5.555555555555555, + 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0), (0.0, 1.1111111111111112, 2.2222222222222223, + 3.3333333333333335, 4.444444444444445, 5.555555555555555, 6.666666666666667, 7.777777777777779, 8.88888888888889, 10.0)) + + assert ramp_res_tester(tf1, 6, exp1) + assert ramp_res_tester(tf2, 10, exp2, 1.2) + assert ramp_res_tester(tf3, 10, exp3, 1.5) + assert ramp_res_tester(tf4, 10, exp4, 3) + assert ramp_res_tester(tf5, 10, exp5, 9) + assert ramp_res_tester(tf6, 10, exp6) + + +def test_nyquist_plot_expr(): + r1, i1, w1 = nyquist_plot_expr(tf1) + r2, i2, w2 = nyquist_plot_expr(tf2) + r3, i3, w3 = nyquist_plot_expr(tf3) + r4, i4, w4 = nyquist_plot_expr(tf4) + assert r1 == (2 - w1**2)/(0.25*w1**2 + (2 - w1**2)**2) + assert i1 == -0.5*w1/(0.25*w1**2 + (2 - w1**2)**2) + assert r2 == 3*w2**2/(9*w2**2 + (1 - 6*w2**2)**2) + assert i2 == w2*(1 - 6*w2**2)/(9*w2**2 + (1 - 6*w2**2)**2) + assert r3 == -w3**4/(w3**6 + 1) + assert i3 == -w3/(w3**6 + 1) + assert r4 == 0 + assert i4 == 10/w4**3 + + +def test_nichols_expr(): + m1, p1, w1 = nichols_plot_expr(tf1) + m2, p2, w2 = nichols_plot_expr(tf2) + m3, p3, w3 = nichols_plot_expr(tf3) + m4, p4, w4 = nichols_plot_expr(tf4) + assert m1 == 20*log(1/sqrt(w1**4 - 3.75*w1**2 + 4))/log(10) + assert p1 == 180*arg(1/(-w1**2 + 0.5*w1*I + 2))/pi + assert m2 == 20*log(Abs(w2)/sqrt(36*w2**4 - 3*w2**2 + 1))/log(10) + assert p2 == 180*arg(w2*I/(-6*w2**2 + 3*w2*I + 1))/pi + assert m3 == 20*log(Abs(w3)/sqrt(w3**6 + 1))/log(10) + assert p3 == 180*arg(-w3*I/(w3**3*I + 1))/pi + assert m4 == 20*log(10/(w4**2*Abs(w4)))/log(10) + assert p4 == 180*arg(I/w4**3)/pi diff --git a/phivenv/Lib/site-packages/sympy/physics/control/tests/test_lti.py b/phivenv/Lib/site-packages/sympy/physics/control/tests/test_lti.py new file mode 100644 index 0000000000000000000000000000000000000000..a78a4c9b893d11f5e9e94705637080e2a722796a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/control/tests/test_lti.py @@ -0,0 +1,2273 @@ +from sympy.core.add import Add +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import (I, pi, Rational, oo) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import atan +from sympy.matrices.dense import eye +from sympy.physics.control.lti import SISOLinearTimeInvariant +from sympy.polys.polytools import factor +from sympy.polys.rootoftools import CRootOf +from sympy.simplify.simplify import simplify +from sympy.core.containers import Tuple +from sympy.matrices import ImmutableMatrix, Matrix, ShapeError +from sympy.functions.elementary.trigonometric import sin, cos +from sympy.physics.control import (TransferFunction, PIDController, Series, Parallel, + Feedback, TransferFunctionMatrix, MIMOSeries, MIMOParallel, MIMOFeedback, + StateSpace, gbt, bilinear, forward_diff, backward_diff, phase_margin, gain_margin) +from sympy.testing.pytest import raises + +a, x, b, c, s, g, d, p, k, tau, zeta, wn, T = symbols('a, x, b, c, s, g, d, p, k,\ + tau, zeta, wn, T') +a0, a1, a2, a3, b0, b1, b2, b3, c0, c1, c2, c3, d0, d1, d2, d3 = symbols('a0:4,\ + b0:4, c0:4, d0:4') +TF1 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) +TF2 = TransferFunction(k, 1, s) +TF3 = TransferFunction(a2*p - s, a2*s + p, s) + + +def test_TransferFunction_construction(): + tf = TransferFunction(s + 1, s**2 + s + 1, s) + assert tf.num == (s + 1) + assert tf.den == (s**2 + s + 1) + assert tf.args == (s + 1, s**2 + s + 1, s) + + tf1 = TransferFunction(s + 4, s - 5, s) + assert tf1.num == (s + 4) + assert tf1.den == (s - 5) + assert tf1.args == (s + 4, s - 5, s) + + # using different polynomial variables. + tf2 = TransferFunction(p + 3, p**2 - 9, p) + assert tf2.num == (p + 3) + assert tf2.den == (p**2 - 9) + assert tf2.args == (p + 3, p**2 - 9, p) + + tf3 = TransferFunction(p**3 + 5*p**2 + 4, p**4 + 3*p + 1, p) + assert tf3.args == (p**3 + 5*p**2 + 4, p**4 + 3*p + 1, p) + + # no pole-zero cancellation on its own. + tf4 = TransferFunction((s + 3)*(s - 1), (s - 1)*(s + 5), s) + assert tf4.den == (s - 1)*(s + 5) + assert tf4.args == ((s + 3)*(s - 1), (s - 1)*(s + 5), s) + + tf4_ = TransferFunction(p + 2, p + 2, p) + assert tf4_.args == (p + 2, p + 2, p) + + tf5 = TransferFunction(s - 1, 4 - p, s) + assert tf5.args == (s - 1, 4 - p, s) + + tf5_ = TransferFunction(s - 1, s - 1, s) + assert tf5_.args == (s - 1, s - 1, s) + + tf6 = TransferFunction(5, 6, s) + assert tf6.num == 5 + assert tf6.den == 6 + assert tf6.args == (5, 6, s) + + tf6_ = TransferFunction(1/2, 4, s) + assert tf6_.num == 0.5 + assert tf6_.den == 4 + assert tf6_.args == (0.500000000000000, 4, s) + + tf7 = TransferFunction(3*s**2 + 2*p + 4*s, 8*p**2 + 7*s, s) + tf8 = TransferFunction(3*s**2 + 2*p + 4*s, 8*p**2 + 7*s, p) + assert not tf7 == tf8 + + tf7_ = TransferFunction(a0*s + a1*s**2 + a2*s**3, b0*p - b1*s, s) + tf8_ = TransferFunction(a0*s + a1*s**2 + a2*s**3, b0*p - b1*s, s) + assert tf7_ == tf8_ + assert -(-tf7_) == tf7_ == -(-(-(-tf7_))) + + tf9 = TransferFunction(a*s**3 + b*s**2 + g*s + d, d*p + g*p**2 + g*s, s) + assert tf9.args == (a*s**3 + b*s**2 + d + g*s, d*p + g*p**2 + g*s, s) + + tf10 = TransferFunction(p**3 + d, g*s**2 + d*s + a, p) + tf10_ = TransferFunction(p**3 + d, g*s**2 + d*s + a, p) + assert tf10.args == (d + p**3, a + d*s + g*s**2, p) + assert tf10_ == tf10 + + tf11 = TransferFunction(a1*s + a0, b2*s**2 + b1*s + b0, s) + assert tf11.num == (a0 + a1*s) + assert tf11.den == (b0 + b1*s + b2*s**2) + assert tf11.args == (a0 + a1*s, b0 + b1*s + b2*s**2, s) + + # when just the numerator is 0, leave the denominator alone. + tf12 = TransferFunction(0, p**2 - p + 1, p) + assert tf12.args == (0, p**2 - p + 1, p) + + tf13 = TransferFunction(0, 1, s) + assert tf13.args == (0, 1, s) + + # float exponents + tf14 = TransferFunction(a0*s**0.5 + a2*s**0.6 - a1, a1*p**(-8.7), s) + assert tf14.args == (a0*s**0.5 - a1 + a2*s**0.6, a1*p**(-8.7), s) + + tf15 = TransferFunction(a2**2*p**(1/4) + a1*s**(-4/5), a0*s - p, p) + assert tf15.args == (a1*s**(-0.8) + a2**2*p**0.25, a0*s - p, p) + + omega_o, k_p, k_o, k_i = symbols('omega_o, k_p, k_o, k_i') + tf18 = TransferFunction((k_p + k_o*s + k_i/s), s**2 + 2*omega_o*s + omega_o**2, s) + assert tf18.num == k_i/s + k_o*s + k_p + assert tf18.args == (k_i/s + k_o*s + k_p, omega_o**2 + 2*omega_o*s + s**2, s) + + # ValueError when denominator is zero. + raises(ValueError, lambda: TransferFunction(4, 0, s)) + raises(ValueError, lambda: TransferFunction(s, 0, s)) + raises(ValueError, lambda: TransferFunction(0, 0, s)) + + raises(TypeError, lambda: TransferFunction(Matrix([1, 2, 3]), s, s)) + + raises(TypeError, lambda: TransferFunction(s**2 + 2*s - 1, s + 3, 3)) + raises(TypeError, lambda: TransferFunction(p + 1, 5 - p, 4)) + raises(TypeError, lambda: TransferFunction(3, 4, 8)) + + +def test_TransferFunction_functions(): + # classmethod from_rational_expression + expr_1 = Mul(0, Pow(s, -1, evaluate=False), evaluate=False) + expr_2 = s/0 + expr_3 = (p*s**2 + 5*s)/(s + 1)**3 + expr_4 = 6 + expr_5 = ((2 + 3*s)*(5 + 2*s))/((9 + 3*s)*(5 + 2*s**2)) + expr_6 = (9*s**4 + 4*s**2 + 8)/((s + 1)*(s + 9)) + tf = TransferFunction(s + 1, s**2 + 2, s) + delay = exp(-s/tau) + expr_7 = delay*tf.to_expr() + H1 = TransferFunction.from_rational_expression(expr_7, s) + H2 = TransferFunction(s + 1, (s**2 + 2)*exp(s/tau), s) + expr_8 = Add(2, 3*s/(s**2 + 1), evaluate=False) + + assert TransferFunction.from_rational_expression(expr_1) == TransferFunction(0, s, s) + raises(ZeroDivisionError, lambda: TransferFunction.from_rational_expression(expr_2)) + raises(ValueError, lambda: TransferFunction.from_rational_expression(expr_3)) + assert TransferFunction.from_rational_expression(expr_3, s) == TransferFunction((p*s**2 + 5*s), (s + 1)**3, s) + assert TransferFunction.from_rational_expression(expr_3, p) == TransferFunction((p*s**2 + 5*s), (s + 1)**3, p) + raises(ValueError, lambda: TransferFunction.from_rational_expression(expr_4)) + assert TransferFunction.from_rational_expression(expr_4, s) == TransferFunction(6, 1, s) + assert TransferFunction.from_rational_expression(expr_5, s) == \ + TransferFunction((2 + 3*s)*(5 + 2*s), (9 + 3*s)*(5 + 2*s**2), s) + assert TransferFunction.from_rational_expression(expr_6, s) == \ + TransferFunction((9*s**4 + 4*s**2 + 8), (s + 1)*(s + 9), s) + assert H1 == H2 + assert TransferFunction.from_rational_expression(expr_8, s) == \ + TransferFunction(2*s**2 + 3*s + 2, s**2 + 1, s) + + # classmethod from_coeff_lists + tf1 = TransferFunction.from_coeff_lists([1, 2], [3, 4, 5], s) + num2 = [p**2, 2*p] + den2 = [p**3, p + 1, 4] + tf2 = TransferFunction.from_coeff_lists(num2, den2, s) + num3 = [1, 2, 3] + den3 = [0, 0] + + assert tf1 == TransferFunction(s + 2, 3*s**2 + 4*s + 5, s) + assert tf2 == TransferFunction(p**2*s + 2*p, p**3*s**2 + s*(p + 1) + 4, s) + raises(ZeroDivisionError, lambda: TransferFunction.from_coeff_lists(num3, den3, s)) + + # classmethod from_zpk + zeros = [4] + poles = [-1+2j, -1-2j] + gain = 3 + tf1 = TransferFunction.from_zpk(zeros, poles, gain, s) + + assert tf1 == TransferFunction(3*s - 12, (s + 1.0 - 2.0*I)*(s + 1.0 + 2.0*I), s) + + # explicitly cancel poles and zeros. + tf0 = TransferFunction(s**5 + s**3 + s, s - s**2, s) + a = TransferFunction(-(s**4 + s**2 + 1), s - 1, s) + assert tf0.simplify() == simplify(tf0) == a + + tf1 = TransferFunction((p + 3)*(p - 1), (p - 1)*(p + 5), p) + b = TransferFunction(p + 3, p + 5, p) + assert tf1.simplify() == simplify(tf1) == b + + # expand the numerator and the denominator. + G1 = TransferFunction((1 - s)**2, (s**2 + 1)**2, s) + G2 = TransferFunction(1, -3, p) + c = (a2*s**p + a1*s**s + a0*p**p)*(p**s + s**p) + d = (b0*s**s + b1*p**s)*(b2*s*p + p**p) + e = a0*p**p*p**s + a0*p**p*s**p + a1*p**s*s**s + a1*s**p*s**s + a2*p**s*s**p + a2*s**(2*p) + f = b0*b2*p*s*s**s + b0*p**p*s**s + b1*b2*p*p**s*s + b1*p**p*p**s + g = a1*a2*s*s**p + a1*p*s + a2*b1*p*s*s**p + b1*p**2*s + G3 = TransferFunction(c, d, s) + G4 = TransferFunction(a0*s**s - b0*p**p, (a1*s + b1*s*p)*(a2*s**p + p), p) + + assert G1.expand() == TransferFunction(s**2 - 2*s + 1, s**4 + 2*s**2 + 1, s) + assert tf1.expand() == TransferFunction(p**2 + 2*p - 3, p**2 + 4*p - 5, p) + assert G2.expand() == G2 + assert G3.expand() == TransferFunction(e, f, s) + assert G4.expand() == TransferFunction(a0*s**s - b0*p**p, g, p) + + # purely symbolic polynomials. + p1 = a1*s + a0 + p2 = b2*s**2 + b1*s + b0 + SP1 = TransferFunction(p1, p2, s) + expect1 = TransferFunction(2.0*s + 1.0, 5.0*s**2 + 4.0*s + 3.0, s) + expect1_ = TransferFunction(2*s + 1, 5*s**2 + 4*s + 3, s) + assert SP1.subs({a0: 1, a1: 2, b0: 3, b1: 4, b2: 5}) == expect1_ + assert SP1.subs({a0: 1, a1: 2, b0: 3, b1: 4, b2: 5}).evalf() == expect1 + assert expect1_.evalf() == expect1 + + c1, d0, d1, d2 = symbols('c1, d0:3') + p3, p4 = c1*p, d2*p**3 + d1*p**2 - d0 + SP2 = TransferFunction(p3, p4, p) + expect2 = TransferFunction(2.0*p, 5.0*p**3 + 2.0*p**2 - 3.0, p) + expect2_ = TransferFunction(2*p, 5*p**3 + 2*p**2 - 3, p) + assert SP2.subs({c1: 2, d0: 3, d1: 2, d2: 5}) == expect2_ + assert SP2.subs({c1: 2, d0: 3, d1: 2, d2: 5}).evalf() == expect2 + assert expect2_.evalf() == expect2 + + SP3 = TransferFunction(a0*p**3 + a1*s**2 - b0*s + b1, a1*s + p, s) + expect3 = TransferFunction(2.0*p**3 + 4.0*s**2 - s + 5.0, p + 4.0*s, s) + expect3_ = TransferFunction(2*p**3 + 4*s**2 - s + 5, p + 4*s, s) + assert SP3.subs({a0: 2, a1: 4, b0: 1, b1: 5}) == expect3_ + assert SP3.subs({a0: 2, a1: 4, b0: 1, b1: 5}).evalf() == expect3 + assert expect3_.evalf() == expect3 + + SP4 = TransferFunction(s - a1*p**3, a0*s + p, p) + expect4 = TransferFunction(7.0*p**3 + s, p - s, p) + expect4_ = TransferFunction(7*p**3 + s, p - s, p) + assert SP4.subs({a0: -1, a1: -7}) == expect4_ + assert SP4.subs({a0: -1, a1: -7}).evalf() == expect4 + assert expect4_.evalf() == expect4 + + # evaluate the transfer function at particular frequencies. + assert tf1.eval_frequency(wn) == wn**2/(wn**2 + 4*wn - 5) + 2*wn/(wn**2 + 4*wn - 5) - 3/(wn**2 + 4*wn - 5) + assert G1.eval_frequency(1 + I) == S(3)/25 + S(4)*I/25 + assert G4.eval_frequency(S(5)/3) == \ + a0*s**s/(a1*a2*s**(S(8)/3) + S(5)*a1*s/3 + 5*a2*b1*s**(S(8)/3)/3 + S(25)*b1*s/9) - 5*3**(S(1)/3)*5**(S(2)/3)*b0/(9*a1*a2*s**(S(8)/3) + 15*a1*s + 15*a2*b1*s**(S(8)/3) + 25*b1*s) + + # Low-frequency (or DC) gain. + assert tf0.dc_gain() == 1 + assert tf1.dc_gain() == Rational(3, 5) + assert SP2.dc_gain() == 0 + assert expect4.dc_gain() == -1 + assert expect2_.dc_gain() == 0 + assert TransferFunction(1, s, s).dc_gain() == oo + + # Poles of a transfer function. + tf_ = TransferFunction(x**3 - k, k, x) + _tf = TransferFunction(k, x**4 - k, x) + TF_ = TransferFunction(x**2, x**10 + x + x**2, x) + _TF = TransferFunction(x**10 + x + x**2, x**2, x) + assert G1.poles() == [I, I, -I, -I] + assert G2.poles() == [] + assert tf1.poles() == [-5, 1] + assert expect4_.poles() == [s] + assert SP4.poles() == [-a0*s] + assert expect3.poles() == [-0.25*p] + assert str(expect2.poles()) == str([0.729001428685125, -0.564500714342563 - 0.710198984796332*I, -0.564500714342563 + 0.710198984796332*I]) + assert str(expect1.poles()) == str([-0.4 - 0.66332495807108*I, -0.4 + 0.66332495807108*I]) + assert _tf.poles() == [k**(Rational(1, 4)), -k**(Rational(1, 4)), I*k**(Rational(1, 4)), -I*k**(Rational(1, 4))] + assert TF_.poles() == [CRootOf(x**9 + x + 1, 0), 0, CRootOf(x**9 + x + 1, 1), CRootOf(x**9 + x + 1, 2), + CRootOf(x**9 + x + 1, 3), CRootOf(x**9 + x + 1, 4), CRootOf(x**9 + x + 1, 5), CRootOf(x**9 + x + 1, 6), + CRootOf(x**9 + x + 1, 7), CRootOf(x**9 + x + 1, 8)] + raises(NotImplementedError, lambda: TransferFunction(x**2, a0*x**10 + x + x**2, x).poles()) + + # Stability of a transfer function. + q, r = symbols('q, r', negative=True) + t = symbols('t', positive=True) + TF_ = TransferFunction(s**2 + a0 - a1*p, q*s - r, s) + stable_tf = TransferFunction(s**2 + a0 - a1*p, q*s - 1, s) + stable_tf_ = TransferFunction(s**2 + a0 - a1*p, q*s - t, s) + + assert G1.is_stable() is False + assert G2.is_stable() is True + assert tf1.is_stable() is False # as one pole is +ve, and the other is -ve. + assert expect2.is_stable() is False + assert expect1.is_stable() is True + assert stable_tf.is_stable() is True + assert stable_tf_.is_stable() is True + assert TF_.is_stable() is False + assert expect4_.is_stable() is None # no assumption provided for the only pole 's'. + assert SP4.is_stable() is None + + # Zeros of a transfer function. + assert G1.zeros() == [1, 1] + assert G2.zeros() == [] + assert tf1.zeros() == [-3, 1] + assert expect4_.zeros() == [7**(Rational(2, 3))*(-s)**(Rational(1, 3))/7, -7**(Rational(2, 3))*(-s)**(Rational(1, 3))/14 - + sqrt(3)*7**(Rational(2, 3))*I*(-s)**(Rational(1, 3))/14, -7**(Rational(2, 3))*(-s)**(Rational(1, 3))/14 + sqrt(3)*7**(Rational(2, 3))*I*(-s)**(Rational(1, 3))/14] + assert SP4.zeros() == [(s/a1)**(Rational(1, 3)), -(s/a1)**(Rational(1, 3))/2 - sqrt(3)*I*(s/a1)**(Rational(1, 3))/2, + -(s/a1)**(Rational(1, 3))/2 + sqrt(3)*I*(s/a1)**(Rational(1, 3))/2] + assert str(expect3.zeros()) == str([0.125 - 1.11102430216445*sqrt(-0.405063291139241*p**3 - 1.0), + 1.11102430216445*sqrt(-0.405063291139241*p**3 - 1.0) + 0.125]) + assert tf_.zeros() == [k**(Rational(1, 3)), -k**(Rational(1, 3))/2 - sqrt(3)*I*k**(Rational(1, 3))/2, + -k**(Rational(1, 3))/2 + sqrt(3)*I*k**(Rational(1, 3))/2] + assert _TF.zeros() == [CRootOf(x**9 + x + 1, 0), 0, CRootOf(x**9 + x + 1, 1), CRootOf(x**9 + x + 1, 2), + CRootOf(x**9 + x + 1, 3), CRootOf(x**9 + x + 1, 4), CRootOf(x**9 + x + 1, 5), CRootOf(x**9 + x + 1, 6), + CRootOf(x**9 + x + 1, 7), CRootOf(x**9 + x + 1, 8)] + raises(NotImplementedError, lambda: TransferFunction(a0*x**10 + x + x**2, x**2, x).zeros()) + + # negation of TF. + tf2 = TransferFunction(s + 3, s**2 - s**3 + 9, s) + tf3 = TransferFunction(-3*p + 3, 1 - p, p) + assert -tf2 == TransferFunction(-s - 3, s**2 - s**3 + 9, s) + assert -tf3 == TransferFunction(3*p - 3, 1 - p, p) + + # taking power of a TF. + tf4 = TransferFunction(p + 4, p - 3, p) + tf5 = TransferFunction(s**2 + 1, 1 - s, s) + expect2 = TransferFunction((s**2 + 1)**3, (1 - s)**3, s) + expect1 = TransferFunction((p + 4)**2, (p - 3)**2, p) + assert (tf4*tf4).doit() == tf4**2 == pow(tf4, 2) == expect1 + assert (tf5*tf5*tf5).doit() == tf5**3 == pow(tf5, 3) == expect2 + assert tf5**0 == pow(tf5, 0) == TransferFunction(1, 1, s) + assert Series(tf4).doit()**-1 == tf4**-1 == pow(tf4, -1) == TransferFunction(p - 3, p + 4, p) + assert (tf5*tf5).doit()**-1 == tf5**-2 == pow(tf5, -2) == TransferFunction((1 - s)**2, (s**2 + 1)**2, s) + + raises(ValueError, lambda: tf4**(s**2 + s - 1)) + raises(ValueError, lambda: tf5**s) + raises(ValueError, lambda: tf4**tf5) + + # SymPy's own functions. + tf = TransferFunction(s - 1, s**2 - 2*s + 1, s) + tf6 = TransferFunction(s + p, p**2 - 5, s) + assert factor(tf) == TransferFunction(s - 1, (s - 1)**2, s) + assert tf.num.subs(s, 2) == tf.den.subs(s, 2) == 1 + # subs & xreplace + assert tf.subs(s, 2) == TransferFunction(s - 1, s**2 - 2*s + 1, s) + assert tf6.subs(p, 3) == TransferFunction(s + 3, 4, s) + assert tf3.xreplace({p: s}) == TransferFunction(-3*s + 3, 1 - s, s) + raises(TypeError, lambda: tf3.xreplace({p: exp(2)})) + assert tf3.subs(p, exp(2)) == tf3 + + tf7 = TransferFunction(a0*s**p + a1*p**s, a2*p - s, s) + assert tf7.xreplace({s: k}) == TransferFunction(a0*k**p + a1*p**k, a2*p - k, k) + assert tf7.subs(s, k) == TransferFunction(a0*s**p + a1*p**s, a2*p - s, s) + + # Conversion to Expr with to_expr() + tf8 = TransferFunction(a0*s**5 + 5*s**2 + 3, s**6 - 3, s) + tf9 = TransferFunction((5 + s), (5 + s)*(6 + s), s) + tf10 = TransferFunction(0, 1, s) + tf11 = TransferFunction(1, 1, s) + assert tf8.to_expr() == Mul((a0*s**5 + 5*s**2 + 3), Pow((s**6 - 3), -1, evaluate=False), evaluate=False) + assert tf9.to_expr() == Mul((s + 5), Pow((5 + s)*(6 + s), -1, evaluate=False), evaluate=False) + assert tf10.to_expr() == Mul(S(0), Pow(1, -1, evaluate=False), evaluate=False) + assert tf11.to_expr() == Pow(1, -1, evaluate=False) + + +def test_TransferFunction_addition_and_subtraction(): + tf1 = TransferFunction(s + 6, s - 5, s) + tf2 = TransferFunction(s + 3, s + 1, s) + tf3 = TransferFunction(s + 1, s**2 + s + 1, s) + tf4 = TransferFunction(p, 2 - p, p) + + # addition + assert tf1 + tf2 == Parallel(tf1, tf2) + assert tf3 + tf1 == Parallel(tf3, tf1) + assert -tf1 + tf2 + tf3 == Parallel(-tf1, tf2, tf3) + assert tf1 + (tf2 + tf3) == Parallel(tf1, tf2, tf3) + + c = symbols("c", commutative=False) + raises(ValueError, lambda: tf1 + Matrix([1, 2, 3])) + raises(ValueError, lambda: tf2 + c) + raises(ValueError, lambda: tf3 + tf4) + raises(ValueError, lambda: tf1 + (s - 1)) + raises(ValueError, lambda: tf1 + 8) + raises(ValueError, lambda: (1 - p**3) + tf1) + + # subtraction + assert tf1 - tf2 == Parallel(tf1, -tf2) + assert tf3 - tf2 == Parallel(tf3, -tf2) + assert -tf1 - tf3 == Parallel(-tf1, -tf3) + assert tf1 - tf2 + tf3 == Parallel(tf1, -tf2, tf3) + + raises(ValueError, lambda: tf1 - Matrix([1, 2, 3])) + raises(ValueError, lambda: tf3 - tf4) + raises(ValueError, lambda: tf1 - (s - 1)) + raises(ValueError, lambda: tf1 - 8) + raises(ValueError, lambda: (s + 5) - tf2) + raises(ValueError, lambda: (1 + p**4) - tf1) + + +def test_TransferFunction_multiplication_and_division(): + G1 = TransferFunction(s + 3, -s**3 + 9, s) + G2 = TransferFunction(s + 1, s - 5, s) + G3 = TransferFunction(p, p**4 - 6, p) + G4 = TransferFunction(p + 4, p - 5, p) + G5 = TransferFunction(s + 6, s - 5, s) + G6 = TransferFunction(s + 3, s + 1, s) + G7 = TransferFunction(1, 1, s) + + # multiplication + assert G1*G2 == Series(G1, G2) + assert -G1*G5 == Series(-G1, G5) + assert -G2*G5*-G6 == Series(-G2, G5, -G6) + assert -G1*-G2*-G5*-G6 == Series(-G1, -G2, -G5, -G6) + assert G3*G4 == Series(G3, G4) + assert (G1*G2)*-(G5*G6) == \ + Series(G1, G2, TransferFunction(-1, 1, s), Series(G5, G6)) + assert G1*G2*(G5 + G6) == Series(G1, G2, Parallel(G5, G6)) + + # division - See ``test_Feedback_functions()`` for division by Parallel objects. + assert G5/G6 == Series(G5, pow(G6, -1)) + assert -G3/G4 == Series(-G3, pow(G4, -1)) + assert (G5*G6)/G7 == Series(G5, G6, pow(G7, -1)) + + c = symbols("c", commutative=False) + raises(ValueError, lambda: G3 * Matrix([1, 2, 3])) + raises(ValueError, lambda: G1 * c) + raises(ValueError, lambda: G3 * G5) + raises(ValueError, lambda: G5 * (s - 1)) + raises(ValueError, lambda: 9 * G5) + + raises(ValueError, lambda: G3 / Matrix([1, 2, 3])) + raises(ValueError, lambda: G6 / 0) + raises(ValueError, lambda: G3 / G5) + raises(ValueError, lambda: G5 / 2) + raises(ValueError, lambda: G5 / s**2) + raises(ValueError, lambda: (s - 4*s**2) / G2) + raises(ValueError, lambda: 0 / G4) + raises(ValueError, lambda: G7 / (1 + G6)) + raises(ValueError, lambda: G7 / (G5 * G6)) + raises(ValueError, lambda: G7 / (G7 + (G5 + G6))) + + +def test_TransferFunction_is_proper(): + omega_o, zeta, tau = symbols('omega_o, zeta, tau') + G1 = TransferFunction(omega_o**2, s**2 + p*omega_o*zeta*s + omega_o**2, omega_o) + G2 = TransferFunction(tau - s**3, tau + p**4, tau) + G3 = TransferFunction(a*b*s**3 + s**2 - a*p + s, b - s*p**2, p) + G4 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + assert G1.is_proper + assert G2.is_proper + assert G3.is_proper + assert not G4.is_proper + + +def test_TransferFunction_is_strictly_proper(): + omega_o, zeta, tau = symbols('omega_o, zeta, tau') + tf1 = TransferFunction(omega_o**2, s**2 + p*omega_o*zeta*s + omega_o**2, omega_o) + tf2 = TransferFunction(tau - s**3, tau + p**4, tau) + tf3 = TransferFunction(a*b*s**3 + s**2 - a*p + s, b - s*p**2, p) + tf4 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + assert not tf1.is_strictly_proper + assert not tf2.is_strictly_proper + assert tf3.is_strictly_proper + assert not tf4.is_strictly_proper + + +def test_TransferFunction_is_biproper(): + tau, omega_o, zeta = symbols('tau, omega_o, zeta') + tf1 = TransferFunction(omega_o**2, s**2 + p*omega_o*zeta*s + omega_o**2, omega_o) + tf2 = TransferFunction(tau - s**3, tau + p**4, tau) + tf3 = TransferFunction(a*b*s**3 + s**2 - a*p + s, b - s*p**2, p) + tf4 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + assert tf1.is_biproper + assert tf2.is_biproper + assert not tf3.is_biproper + assert not tf4.is_biproper + + +def test_PIDController(): + kp, ki, kd, tf = symbols("kp ki kd tf") + p1 = PIDController(kp, ki, kd, tf) + p2 = PIDController() + + # Type Checking + assert isinstance(p1, PIDController) + assert isinstance(p1, TransferFunction) + + # Properties checking + assert p1 == PIDController(kp, ki, kd, tf, s) + assert p2 == PIDController(kp, ki, kd, 0, s) + assert p1.num == kd*s**2 + ki*s*tf + ki + kp*s**2*tf + kp*s + assert p1.den == s**2*tf + s + assert p1.var == s + assert p1.kp == kp + assert p1.ki == ki + assert p1.kd == kd + assert p1.tf == tf + + # Functionality checking + assert p1.doit() == TransferFunction(kd*s**2 + ki*s*tf + ki + kp*s**2*tf + kp*s, s**2*tf + s, s) + assert p1.is_proper == True + assert p1.is_biproper == True + assert p1.is_strictly_proper == False + assert p2.doit() == TransferFunction(kd*s**2 + ki + kp*s, s, s) + + # Using PIDController with TransferFunction + tf1 = TransferFunction(s, s + 1, s) + par1 = Parallel(p1, tf1) + ser1 = Series(p1, tf1) + fed1 = Feedback(p1, tf1) + assert par1 == Parallel(PIDController(kp, ki, kd, tf, s), TransferFunction(s, s + 1, s)) + assert ser1 == Series(PIDController(kp, ki, kd, tf, s), TransferFunction(s, s + 1, s)) + assert fed1 == Feedback(PIDController(kp, ki, kd, tf, s), TransferFunction(s, s + 1, s)) + assert par1.doit() == TransferFunction(s*(s**2*tf + s) + (s + 1)*(kd*s**2 + ki*s*tf + ki + kp*s**2*tf + kp*s), + (s + 1)*(s**2*tf + s), s) + assert ser1.doit() == TransferFunction(s*(kd*s**2 + ki*s*tf + ki + kp*s**2*tf + kp*s), + (s + 1)*(s**2*tf + s), s) + assert fed1.doit() == TransferFunction((s + 1)*(s**2*tf + s)*(kd*s**2 + ki*s*tf + ki + kp*s**2*tf + kp*s), + (s*(kd*s**2 + ki*s*tf + ki + kp*s**2*tf + kp*s) + (s + 1)*(s**2*tf + s))*(s**2*tf + s), s) + + +def test_Series_construction(): + tf = TransferFunction(a0*s**3 + a1*s**2 - a2*s, b0*p**4 + b1*p**3 - b2*s*p, s) + tf2 = TransferFunction(a2*p - s, a2*s + p, s) + tf3 = TransferFunction(a0*p + p**a1 - s, p, p) + tf4 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + inp = Function('X_d')(s) + out = Function('X')(s) + + s0 = Series(tf, tf2) + assert s0.args == (tf, tf2) + assert s0.var == s + + s1 = Series(Parallel(tf, -tf2), tf2) + assert s1.args == (Parallel(tf, -tf2), tf2) + assert s1.var == s + + tf3_ = TransferFunction(inp, 1, s) + tf4_ = TransferFunction(-out, 1, s) + s2 = Series(tf, Parallel(tf3_, tf4_), tf2) + assert s2.args == (tf, Parallel(tf3_, tf4_), tf2) + + s3 = Series(tf, tf2, tf4) + assert s3.args == (tf, tf2, tf4) + + s4 = Series(tf3_, tf4_) + assert s4.args == (tf3_, tf4_) + assert s4.var == s + + s6 = Series(tf2, tf4, Parallel(tf2, -tf), tf4) + assert s6.args == (tf2, tf4, Parallel(tf2, -tf), tf4) + + s7 = Series(tf, tf2) + assert s0 == s7 + assert not s0 == s2 + + raises(ValueError, lambda: Series(tf, tf3)) + raises(ValueError, lambda: Series(tf, tf2, tf3, tf4)) + raises(ValueError, lambda: Series(-tf3, tf2)) + raises(TypeError, lambda: Series(2, tf, tf4)) + raises(TypeError, lambda: Series(s**2 + p*s, tf3, tf2)) + raises(TypeError, lambda: Series(tf3, Matrix([1, 2, 3, 4]))) + + +def test_MIMOSeries_construction(): + tf_1 = TransferFunction(a0*s**3 + a1*s**2 - a2*s, b0*p**4 + b1*p**3 - b2*s*p, s) + tf_2 = TransferFunction(a2*p - s, a2*s + p, s) + tf_3 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + + tfm_1 = TransferFunctionMatrix([[tf_1, tf_2, tf_3], [-tf_3, -tf_2, tf_1]]) + tfm_2 = TransferFunctionMatrix([[-tf_2], [-tf_2], [-tf_3]]) + tfm_3 = TransferFunctionMatrix([[-tf_3]]) + tfm_4 = TransferFunctionMatrix([[TF3], [TF2], [-TF1]]) + tfm_5 = TransferFunctionMatrix.from_Matrix(Matrix([1/p]), p) + + s8 = MIMOSeries(tfm_2, tfm_1) + assert s8.args == (tfm_2, tfm_1) + assert s8.var == s + assert s8.shape == (s8.num_outputs, s8.num_inputs) == (2, 1) + + s9 = MIMOSeries(tfm_3, tfm_2, tfm_1) + assert s9.args == (tfm_3, tfm_2, tfm_1) + assert s9.var == s + assert s9.shape == (s9.num_outputs, s9.num_inputs) == (2, 1) + + s11 = MIMOSeries(tfm_3, MIMOParallel(-tfm_2, -tfm_4), tfm_1) + assert s11.args == (tfm_3, MIMOParallel(-tfm_2, -tfm_4), tfm_1) + assert s11.shape == (s11.num_outputs, s11.num_inputs) == (2, 1) + + # arg cannot be empty tuple. + raises(ValueError, lambda: MIMOSeries()) + + # arg cannot contain SISO as well as MIMO systems. + raises(TypeError, lambda: MIMOSeries(tfm_1, tf_1)) + + # for all the adjacent transfer function matrices: + # no. of inputs of first TFM must be equal to the no. of outputs of the second TFM. + raises(ValueError, lambda: MIMOSeries(tfm_1, tfm_2, -tfm_1)) + + # all the TFMs must use the same complex variable. + raises(ValueError, lambda: MIMOSeries(tfm_3, tfm_5)) + + # Number or expression not allowed in the arguments. + raises(TypeError, lambda: MIMOSeries(2, tfm_2, tfm_3)) + raises(TypeError, lambda: MIMOSeries(s**2 + p*s, -tfm_2, tfm_3)) + raises(TypeError, lambda: MIMOSeries(Matrix([1/p]), tfm_3)) + + +def test_Series_functions(): + tf1 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + tf2 = TransferFunction(k, 1, s) + tf3 = TransferFunction(a2*p - s, a2*s + p, s) + tf4 = TransferFunction(a0*p + p**a1 - s, p, p) + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + + assert tf1*tf2*tf3 == Series(tf1, tf2, tf3) == Series(Series(tf1, tf2), tf3) \ + == Series(tf1, Series(tf2, tf3)) + assert tf1*(tf2 + tf3) == Series(tf1, Parallel(tf2, tf3)) + assert tf1*tf2 + tf5 == Parallel(Series(tf1, tf2), tf5) + assert tf1*tf2 - tf5 == Parallel(Series(tf1, tf2), -tf5) + assert tf1*tf2 + tf3 + tf5 == Parallel(Series(tf1, tf2), tf3, tf5) + assert tf1*tf2 - tf3 - tf5 == Parallel(Series(tf1, tf2), -tf3, -tf5) + assert tf1*tf2 - tf3 + tf5 == Parallel(Series(tf1, tf2), -tf3, tf5) + assert tf1*tf2 + tf3*tf5 == Parallel(Series(tf1, tf2), Series(tf3, tf5)) + assert tf1*tf2 - tf3*tf5 == Parallel(Series(tf1, tf2), Series(TransferFunction(-1, 1, s), Series(tf3, tf5))) + assert tf2*tf3*(tf2 - tf1)*tf3 == Series(tf2, tf3, Parallel(tf2, -tf1), tf3) + assert -tf1*tf2 == Series(-tf1, tf2) + assert -(tf1*tf2) == Series(TransferFunction(-1, 1, s), Series(tf1, tf2)) + raises(ValueError, lambda: tf1*tf2*tf4) + raises(ValueError, lambda: tf1*(tf2 - tf4)) + raises(ValueError, lambda: tf3*Matrix([1, 2, 3])) + + # evaluate=True -> doit() + assert Series(tf1, tf2, evaluate=True) == Series(tf1, tf2).doit() == \ + TransferFunction(k, s**2 + 2*s*wn*zeta + wn**2, s) + assert Series(tf1, tf2, Parallel(tf1, -tf3), evaluate=True) == Series(tf1, tf2, Parallel(tf1, -tf3)).doit() == \ + TransferFunction(k*(a2*s + p + (-a2*p + s)*(s**2 + 2*s*wn*zeta + wn**2)), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2)**2, s) + assert Series(tf2, tf1, -tf3, evaluate=True) == Series(tf2, tf1, -tf3).doit() == \ + TransferFunction(k*(-a2*p + s), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert not Series(tf1, -tf2, evaluate=False) == Series(tf1, -tf2).doit() + + assert Series(Parallel(tf1, tf2), Parallel(tf2, -tf3)).doit() == \ + TransferFunction((k*(s**2 + 2*s*wn*zeta + wn**2) + 1)*(-a2*p + k*(a2*s + p) + s), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Series(-tf1, -tf2, -tf3).doit() == \ + TransferFunction(k*(-a2*p + s), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert -Series(tf1, tf2, tf3).doit() == \ + TransferFunction(-k*(a2*p - s), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Series(tf2, tf3, Parallel(tf2, -tf1), tf3).doit() == \ + TransferFunction(k*(a2*p - s)**2*(k*(s**2 + 2*s*wn*zeta + wn**2) - 1), (a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2), s) + + assert Series(tf1, tf2).rewrite(TransferFunction) == TransferFunction(k, s**2 + 2*s*wn*zeta + wn**2, s) + assert Series(tf2, tf1, -tf3).rewrite(TransferFunction) == \ + TransferFunction(k*(-a2*p + s), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + + S1 = Series(Parallel(tf1, tf2), Parallel(tf2, -tf3)) + assert S1.is_proper + assert not S1.is_strictly_proper + assert S1.is_biproper + + S2 = Series(tf1, tf2, tf3) + assert S2.is_proper + assert S2.is_strictly_proper + assert not S2.is_biproper + + S3 = Series(tf1, -tf2, Parallel(tf1, -tf3)) + assert S3.is_proper + assert S3.is_strictly_proper + assert not S3.is_biproper + + +def test_MIMOSeries_functions(): + tfm1 = TransferFunctionMatrix([[TF1, TF2, TF3], [-TF3, -TF2, TF1]]) + tfm2 = TransferFunctionMatrix([[-TF1], [-TF2], [-TF3]]) + tfm3 = TransferFunctionMatrix([[-TF1]]) + tfm4 = TransferFunctionMatrix([[-TF2, -TF3], [-TF1, TF2]]) + tfm5 = TransferFunctionMatrix([[TF2, -TF2], [-TF3, -TF2]]) + tfm6 = TransferFunctionMatrix([[-TF3], [TF1]]) + tfm7 = TransferFunctionMatrix([[TF1], [-TF2]]) + + assert tfm1*tfm2 + tfm6 == MIMOParallel(MIMOSeries(tfm2, tfm1), tfm6) + assert tfm1*tfm2 + tfm7 + tfm6 == MIMOParallel(MIMOSeries(tfm2, tfm1), tfm7, tfm6) + assert tfm1*tfm2 - tfm6 - tfm7 == MIMOParallel(MIMOSeries(tfm2, tfm1), -tfm6, -tfm7) + assert tfm4*tfm5 + (tfm4 - tfm5) == MIMOParallel(MIMOSeries(tfm5, tfm4), tfm4, -tfm5) + assert tfm4*-tfm6 + (-tfm4*tfm6) == MIMOParallel(MIMOSeries(-tfm6, tfm4), MIMOSeries(tfm6, -tfm4)) + + raises(ValueError, lambda: tfm1*tfm2 + TF1) + raises(TypeError, lambda: tfm1*tfm2 + a0) + raises(TypeError, lambda: tfm4*tfm6 - (s - 1)) + raises(TypeError, lambda: tfm4*-tfm6 - 8) + raises(TypeError, lambda: (-1 + p**5) + tfm1*tfm2) + + # Shape criteria. + + raises(TypeError, lambda: -tfm1*tfm2 + tfm4) + raises(TypeError, lambda: tfm1*tfm2 - tfm4 + tfm5) + raises(TypeError, lambda: tfm1*tfm2 - tfm4*tfm5) + + assert tfm1*tfm2*-tfm3 == MIMOSeries(-tfm3, tfm2, tfm1) + assert (tfm1*-tfm2)*tfm3 == MIMOSeries(tfm3, -tfm2, tfm1) + + # Multiplication of a Series object with a SISO TF not allowed. + + raises(ValueError, lambda: tfm4*tfm5*TF1) + raises(TypeError, lambda: tfm4*tfm5*a1) + raises(TypeError, lambda: tfm4*-tfm5*(s - 2)) + raises(TypeError, lambda: tfm5*tfm4*9) + raises(TypeError, lambda: (-p**3 + 1)*tfm5*tfm4) + + # Transfer function matrix in the arguments. + assert (MIMOSeries(tfm2, tfm1, evaluate=True) == MIMOSeries(tfm2, tfm1).doit() + == TransferFunctionMatrix(((TransferFunction(-k**2*(a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**2 + (-a2*p + s)*(a2*p - s)*(s**2 + 2*s*wn*zeta + wn**2)**2 - (a2*s + p)**2, + (a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**2, s),), + (TransferFunction(k**2*(a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**2 + (-a2*p + s)*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) + (a2*p - s)*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), + (a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**2, s),)))) + + # doit() should not cancel poles and zeros. + mat_1 = Matrix([[1/(1+s), (1+s)/(1+s**2+2*s)**3]]) + mat_2 = Matrix([[(1+s)], [(1+s**2+2*s)**3/(1+s)]]) + tm_1, tm_2 = TransferFunctionMatrix.from_Matrix(mat_1, s), TransferFunctionMatrix.from_Matrix(mat_2, s) + assert (MIMOSeries(tm_2, tm_1).doit() + == TransferFunctionMatrix(((TransferFunction(2*(s + 1)**2*(s**2 + 2*s + 1)**3, (s + 1)**2*(s**2 + 2*s + 1)**3, s),),))) + assert MIMOSeries(tm_2, tm_1).doit().simplify() == TransferFunctionMatrix(((TransferFunction(2, 1, s),),)) + + # calling doit() will expand the internal Series and Parallel objects. + assert (MIMOSeries(-tfm3, -tfm2, tfm1, evaluate=True) + == MIMOSeries(-tfm3, -tfm2, tfm1).doit() + == TransferFunctionMatrix(((TransferFunction(k**2*(a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**2 + (a2*p - s)**2*(s**2 + 2*s*wn*zeta + wn**2)**2 + (a2*s + p)**2, + (a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**3, s),), + (TransferFunction(-k**2*(a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**2 + (-a2*p + s)*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) + (a2*p - s)*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), + (a2*s + p)**2*(s**2 + 2*s*wn*zeta + wn**2)**3, s),)))) + assert (MIMOSeries(MIMOParallel(tfm4, tfm5), tfm5, evaluate=True) + == MIMOSeries(MIMOParallel(tfm4, tfm5), tfm5).doit() + == TransferFunctionMatrix(((TransferFunction(-k*(-a2*s - p + (-a2*p + s)*(s**2 + 2*s*wn*zeta + wn**2)), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s), TransferFunction(k*(-a2*p - \ + k*(a2*s + p) + s), a2*s + p, s)), (TransferFunction(-k*(-a2*s - p + (-a2*p + s)*(s**2 + 2*s*wn*zeta + wn**2)), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s), \ + TransferFunction((-a2*p + s)*(-a2*p - k*(a2*s + p) + s), (a2*s + p)**2, s)))) == MIMOSeries(MIMOParallel(tfm4, tfm5), tfm5).rewrite(TransferFunctionMatrix)) + + +def test_Parallel_construction(): + tf = TransferFunction(a0*s**3 + a1*s**2 - a2*s, b0*p**4 + b1*p**3 - b2*s*p, s) + tf2 = TransferFunction(a2*p - s, a2*s + p, s) + tf3 = TransferFunction(a0*p + p**a1 - s, p, p) + tf4 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + inp = Function('X_d')(s) + out = Function('X')(s) + + p0 = Parallel(tf, tf2) + assert p0.args == (tf, tf2) + assert p0.var == s + + p1 = Parallel(Series(tf, -tf2), tf2) + assert p1.args == (Series(tf, -tf2), tf2) + assert p1.var == s + + tf3_ = TransferFunction(inp, 1, s) + tf4_ = TransferFunction(-out, 1, s) + p2 = Parallel(tf, Series(tf3_, -tf4_), tf2) + assert p2.args == (tf, Series(tf3_, -tf4_), tf2) + + p3 = Parallel(tf, tf2, tf4) + assert p3.args == (tf, tf2, tf4) + + p4 = Parallel(tf3_, tf4_) + assert p4.args == (tf3_, tf4_) + assert p4.var == s + + p5 = Parallel(tf, tf2) + assert p0 == p5 + assert not p0 == p1 + + p6 = Parallel(tf2, tf4, Series(tf2, -tf4)) + assert p6.args == (tf2, tf4, Series(tf2, -tf4)) + + p7 = Parallel(tf2, tf4, Series(tf2, -tf), tf4) + assert p7.args == (tf2, tf4, Series(tf2, -tf), tf4) + + raises(ValueError, lambda: Parallel(tf, tf3)) + raises(ValueError, lambda: Parallel(tf, tf2, tf3, tf4)) + raises(ValueError, lambda: Parallel(-tf3, tf4)) + raises(TypeError, lambda: Parallel(2, tf, tf4)) + raises(TypeError, lambda: Parallel(s**2 + p*s, tf3, tf2)) + raises(TypeError, lambda: Parallel(tf3, Matrix([1, 2, 3, 4]))) + + +def test_MIMOParallel_construction(): + tfm1 = TransferFunctionMatrix([[TF1], [TF2], [TF3]]) + tfm2 = TransferFunctionMatrix([[-TF3], [TF2], [TF1]]) + tfm3 = TransferFunctionMatrix([[TF1]]) + tfm4 = TransferFunctionMatrix([[TF2], [TF1], [TF3]]) + tfm5 = TransferFunctionMatrix([[TF1, TF2], [TF2, TF1]]) + tfm6 = TransferFunctionMatrix([[TF2, TF1], [TF1, TF2]]) + tfm7 = TransferFunctionMatrix.from_Matrix(Matrix([[1/p]]), p) + + p8 = MIMOParallel(tfm1, tfm2) + assert p8.args == (tfm1, tfm2) + assert p8.var == s + assert p8.shape == (p8.num_outputs, p8.num_inputs) == (3, 1) + + p9 = MIMOParallel(MIMOSeries(tfm3, tfm1), tfm2) + assert p9.args == (MIMOSeries(tfm3, tfm1), tfm2) + assert p9.var == s + assert p9.shape == (p9.num_outputs, p9.num_inputs) == (3, 1) + + p10 = MIMOParallel(tfm1, MIMOSeries(tfm3, tfm4), tfm2) + assert p10.args == (tfm1, MIMOSeries(tfm3, tfm4), tfm2) + assert p10.var == s + assert p10.shape == (p10.num_outputs, p10.num_inputs) == (3, 1) + + p11 = MIMOParallel(tfm2, tfm1, tfm4) + assert p11.args == (tfm2, tfm1, tfm4) + assert p11.shape == (p11.num_outputs, p11.num_inputs) == (3, 1) + + p12 = MIMOParallel(tfm6, tfm5) + assert p12.args == (tfm6, tfm5) + assert p12.shape == (p12.num_outputs, p12.num_inputs) == (2, 2) + + p13 = MIMOParallel(tfm2, tfm4, MIMOSeries(-tfm3, tfm4), -tfm4) + assert p13.args == (tfm2, tfm4, MIMOSeries(-tfm3, tfm4), -tfm4) + assert p13.shape == (p13.num_outputs, p13.num_inputs) == (3, 1) + + # arg cannot be empty tuple. + raises(TypeError, lambda: MIMOParallel(())) + + # arg cannot contain SISO as well as MIMO systems. + raises(TypeError, lambda: MIMOParallel(tfm1, tfm2, TF1)) + + # all TFMs must have same shapes. + raises(TypeError, lambda: MIMOParallel(tfm1, tfm3, tfm4)) + + # all TFMs must be using the same complex variable. + raises(ValueError, lambda: MIMOParallel(tfm3, tfm7)) + + # Number or expression not allowed in the arguments. + raises(TypeError, lambda: MIMOParallel(2, tfm1, tfm4)) + raises(TypeError, lambda: MIMOParallel(s**2 + p*s, -tfm4, tfm2)) + + +def test_Parallel_functions(): + tf1 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + tf2 = TransferFunction(k, 1, s) + tf3 = TransferFunction(a2*p - s, a2*s + p, s) + tf4 = TransferFunction(a0*p + p**a1 - s, p, p) + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + + assert tf1 + tf2 + tf3 == Parallel(tf1, tf2, tf3) + assert tf1 + tf2 + tf3 + tf5 == Parallel(tf1, tf2, tf3, tf5) + assert tf1 + tf2 - tf3 - tf5 == Parallel(tf1, tf2, -tf3, -tf5) + assert tf1 + tf2*tf3 == Parallel(tf1, Series(tf2, tf3)) + assert tf1 - tf2*tf3 == Parallel(tf1, -Series(tf2,tf3)) + assert -tf1 - tf2 == Parallel(-tf1, -tf2) + assert -(tf1 + tf2) == Series(TransferFunction(-1, 1, s), Parallel(tf1, tf2)) + assert (tf2 + tf3)*tf1 == Series(Parallel(tf2, tf3), tf1) + assert (tf1 + tf2)*(tf3*tf5) == Series(Parallel(tf1, tf2), tf3, tf5) + assert -(tf2 + tf3)*-tf5 == Series(TransferFunction(-1, 1, s), Parallel(tf2, tf3), -tf5) + assert tf2 + tf3 + tf2*tf1 + tf5 == Parallel(tf2, tf3, Series(tf2, tf1), tf5) + assert tf2 + tf3 + tf2*tf1 - tf3 == Parallel(tf2, tf3, Series(tf2, tf1), -tf3) + assert (tf1 + tf2 + tf5)*(tf3 + tf5) == Series(Parallel(tf1, tf2, tf5), Parallel(tf3, tf5)) + raises(ValueError, lambda: tf1 + tf2 + tf4) + raises(ValueError, lambda: tf1 - tf2*tf4) + raises(ValueError, lambda: tf3 + Matrix([1, 2, 3])) + + # evaluate=True -> doit() + assert Parallel(tf1, tf2, evaluate=True) == Parallel(tf1, tf2).doit() == \ + TransferFunction(k*(s**2 + 2*s*wn*zeta + wn**2) + 1, s**2 + 2*s*wn*zeta + wn**2, s) + assert Parallel(tf1, tf2, Series(-tf1, tf3), evaluate=True) == \ + Parallel(tf1, tf2, Series(-tf1, tf3)).doit() == TransferFunction(k*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2)**2 + \ + (-a2*p + s)*(s**2 + 2*s*wn*zeta + wn**2) + (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), (a2*s + p)*(s**2 + \ + 2*s*wn*zeta + wn**2)**2, s) + assert Parallel(tf2, tf1, -tf3, evaluate=True) == Parallel(tf2, tf1, -tf3).doit() == \ + TransferFunction(a2*s + k*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) + p + (-a2*p + s)*(s**2 + 2*s*wn*zeta + wn**2) \ + , (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert not Parallel(tf1, -tf2, evaluate=False) == Parallel(tf1, -tf2).doit() + + assert Parallel(Series(tf1, tf2), Series(tf2, tf3)).doit() == \ + TransferFunction(k*(a2*p - s)*(s**2 + 2*s*wn*zeta + wn**2) + k*(a2*s + p), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Parallel(-tf1, -tf2, -tf3).doit() == \ + TransferFunction(-a2*s - k*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) - p + (-a2*p + s)*(s**2 + 2*s*wn*zeta + wn**2), \ + (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert -Parallel(tf1, tf2, tf3).doit() == \ + TransferFunction(-a2*s - k*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) - p - (a2*p - s)*(s**2 + 2*s*wn*zeta + wn**2), \ + (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Parallel(tf2, tf3, Series(tf2, -tf1), tf3).doit() == \ + TransferFunction(k*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) - k*(a2*s + p) + (2*a2*p - 2*s)*(s**2 + 2*s*wn*zeta \ + + wn**2), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + + assert Parallel(tf1, tf2).rewrite(TransferFunction) == \ + TransferFunction(k*(s**2 + 2*s*wn*zeta + wn**2) + 1, s**2 + 2*s*wn*zeta + wn**2, s) + assert Parallel(tf2, tf1, -tf3).rewrite(TransferFunction) == \ + TransferFunction(a2*s + k*(a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2) + p + (-a2*p + s)*(s**2 + 2*s*wn*zeta + \ + wn**2), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + + assert Parallel(tf1, Parallel(tf2, tf3)) == Parallel(tf1, tf2, tf3) == Parallel(Parallel(tf1, tf2), tf3) + + P1 = Parallel(Series(tf1, tf2), Series(tf2, tf3)) + assert P1.is_proper + assert not P1.is_strictly_proper + assert P1.is_biproper + + P2 = Parallel(tf1, -tf2, -tf3) + assert P2.is_proper + assert not P2.is_strictly_proper + assert P2.is_biproper + + P3 = Parallel(tf1, -tf2, Series(tf1, tf3)) + assert P3.is_proper + assert not P3.is_strictly_proper + assert P3.is_biproper + + +def test_MIMOParallel_functions(): + tf4 = TransferFunction(a0*p + p**a1 - s, p, p) + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + + tfm1 = TransferFunctionMatrix([[TF1], [TF2], [TF3]]) + tfm2 = TransferFunctionMatrix([[-TF2], [tf5], [-TF1]]) + tfm3 = TransferFunctionMatrix([[tf5], [-tf5], [TF2]]) + tfm4 = TransferFunctionMatrix([[TF2, -tf5], [TF1, tf5]]) + tfm5 = TransferFunctionMatrix([[TF1, TF2], [TF3, -tf5]]) + tfm6 = TransferFunctionMatrix([[-TF2]]) + tfm7 = TransferFunctionMatrix([[tf4], [-tf4], [tf4]]) + + assert tfm1 + tfm2 + tfm3 == MIMOParallel(tfm1, tfm2, tfm3) == MIMOParallel(MIMOParallel(tfm1, tfm2), tfm3) + assert tfm2 - tfm1 - tfm3 == MIMOParallel(tfm2, -tfm1, -tfm3) + assert tfm2 - tfm3 + (-tfm1*tfm6*-tfm6) == MIMOParallel(tfm2, -tfm3, MIMOSeries(-tfm6, tfm6, -tfm1)) + assert tfm1 + tfm1 - (-tfm1*tfm6) == MIMOParallel(tfm1, tfm1, -MIMOSeries(tfm6, -tfm1)) + assert tfm2 - tfm3 - tfm1 + tfm2 == MIMOParallel(tfm2, -tfm3, -tfm1, tfm2) + assert tfm1 + tfm2 - tfm3 - tfm1 == MIMOParallel(tfm1, tfm2, -tfm3, -tfm1) + raises(ValueError, lambda: tfm1 + tfm2 + TF2) + raises(TypeError, lambda: tfm1 - tfm2 - a1) + raises(TypeError, lambda: tfm2 - tfm3 - (s - 1)) + raises(TypeError, lambda: -tfm3 - tfm2 - 9) + raises(TypeError, lambda: (1 - p**3) - tfm3 - tfm2) + # All TFMs must use the same complex var. tfm7 uses 'p'. + raises(ValueError, lambda: tfm3 - tfm2 - tfm7) + raises(ValueError, lambda: tfm2 - tfm1 + tfm7) + # (tfm1 +/- tfm2) has (3, 1) shape while tfm4 has (2, 2) shape. + raises(TypeError, lambda: tfm1 + tfm2 + tfm4) + raises(TypeError, lambda: (tfm1 - tfm2) - tfm4) + + assert (tfm1 + tfm2)*tfm6 == MIMOSeries(tfm6, MIMOParallel(tfm1, tfm2)) + assert (tfm2 - tfm3)*tfm6*-tfm6 == MIMOSeries(-tfm6, tfm6, MIMOParallel(tfm2, -tfm3)) + assert (tfm2 - tfm1 - tfm3)*(tfm6 + tfm6) == MIMOSeries(MIMOParallel(tfm6, tfm6), MIMOParallel(tfm2, -tfm1, -tfm3)) + raises(ValueError, lambda: (tfm4 + tfm5)*TF1) + raises(TypeError, lambda: (tfm2 - tfm3)*a2) + raises(TypeError, lambda: (tfm3 + tfm2)*(s - 6)) + raises(TypeError, lambda: (tfm1 + tfm2 + tfm3)*0) + raises(TypeError, lambda: (1 - p**3)*(tfm1 + tfm3)) + + # (tfm3 - tfm2) has (3, 1) shape while tfm4*tfm5 has (2, 2) shape. + raises(ValueError, lambda: (tfm3 - tfm2)*tfm4*tfm5) + # (tfm1 - tfm2) has (3, 1) shape while tfm5 has (2, 2) shape. + raises(ValueError, lambda: (tfm1 - tfm2)*tfm5) + + # TFM in the arguments. + assert (MIMOParallel(tfm1, tfm2, evaluate=True) == MIMOParallel(tfm1, tfm2).doit() + == MIMOParallel(tfm1, tfm2).rewrite(TransferFunctionMatrix) + == TransferFunctionMatrix(((TransferFunction(-k*(s**2 + 2*s*wn*zeta + wn**2) + 1, s**2 + 2*s*wn*zeta + wn**2, s),), \ + (TransferFunction(-a0 + a1*s**2 + a2*s + k*(a0 + s), a0 + s, s),), (TransferFunction(-a2*s - p + (a2*p - s)* \ + (s**2 + 2*s*wn*zeta + wn**2), (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s),)))) + + +def test_Feedback_construction(): + tf1 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + tf2 = TransferFunction(k, 1, s) + tf3 = TransferFunction(a2*p - s, a2*s + p, s) + tf4 = TransferFunction(a0*p + p**a1 - s, p, p) + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + tf6 = TransferFunction(s - p, p + s, p) + + f1 = Feedback(TransferFunction(1, 1, s), tf1*tf2*tf3) + assert f1.args == (TransferFunction(1, 1, s), Series(tf1, tf2, tf3), -1) + assert f1.sys1 == TransferFunction(1, 1, s) + assert f1.sys2 == Series(tf1, tf2, tf3) + assert f1.var == s + + f2 = Feedback(tf1, tf2*tf3) + assert f2.args == (tf1, Series(tf2, tf3), -1) + assert f2.sys1 == tf1 + assert f2.sys2 == Series(tf2, tf3) + assert f2.var == s + + f3 = Feedback(tf1*tf2, tf5) + assert f3.args == (Series(tf1, tf2), tf5, -1) + assert f3.sys1 == Series(tf1, tf2) + + f4 = Feedback(tf4, tf6) + assert f4.args == (tf4, tf6, -1) + assert f4.sys1 == tf4 + assert f4.var == p + + f5 = Feedback(tf5, TransferFunction(1, 1, s)) + assert f5.args == (tf5, TransferFunction(1, 1, s), -1) + assert f5.var == s + assert f5 == Feedback(tf5) # When sys2 is not passed explicitly, it is assumed to be unit tf. + + f6 = Feedback(TransferFunction(1, 1, p), tf4) + assert f6.args == (TransferFunction(1, 1, p), tf4, -1) + assert f6.var == p + + f7 = -Feedback(tf4*tf6, TransferFunction(1, 1, p)) + assert f7.args == (Series(TransferFunction(-1, 1, p), Series(tf4, tf6)), -TransferFunction(1, 1, p), -1) + assert f7.sys1 == Series(TransferFunction(-1, 1, p), Series(tf4, tf6)) + + # denominator can't be a Parallel instance + raises(TypeError, lambda: Feedback(tf1, tf2 + tf3)) + raises(TypeError, lambda: Feedback(tf1, Matrix([1, 2, 3]))) + raises(TypeError, lambda: Feedback(TransferFunction(1, 1, s), s - 1)) + raises(TypeError, lambda: Feedback(1, 1)) + # raises(ValueError, lambda: Feedback(TransferFunction(1, 1, s), TransferFunction(1, 1, s))) + raises(ValueError, lambda: Feedback(tf2, tf4*tf5)) + raises(ValueError, lambda: Feedback(tf2, tf1, 1.5)) # `sign` can only be -1 or 1 + raises(ValueError, lambda: Feedback(tf1, -tf1**-1)) # denominator can't be zero + raises(ValueError, lambda: Feedback(tf4, tf5)) # Both systems should use the same `var` + + +def test_Feedback_functions(): + tf = TransferFunction(1, 1, s) + tf1 = TransferFunction(1, s**2 + 2*zeta*wn*s + wn**2, s) + tf2 = TransferFunction(k, 1, s) + tf3 = TransferFunction(a2*p - s, a2*s + p, s) + tf4 = TransferFunction(a0*p + p**a1 - s, p, p) + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + tf6 = TransferFunction(s - p, p + s, p) + + assert (tf1*tf2*tf3 / tf3*tf5) == Series(tf1, tf2, tf3, pow(tf3, -1), tf5) + assert (tf1*tf2*tf3) / (tf3*tf5) == Series((tf1*tf2*tf3).doit(), pow((tf3*tf5).doit(),-1)) + assert tf / (tf + tf1) == Feedback(tf, tf1) + assert tf / (tf + tf1*tf2*tf3) == Feedback(tf, tf1*tf2*tf3) + assert tf1 / (tf + tf1*tf2*tf3) == Feedback(tf1, tf2*tf3) + assert (tf1*tf2) / (tf + tf1*tf2) == Feedback(tf1*tf2, tf) + assert (tf1*tf2) / (tf + tf1*tf2*tf5) == Feedback(tf1*tf2, tf5) + assert (tf1*tf2) / (tf + tf1*tf2*tf5*tf3) in (Feedback(tf1*tf2, tf5*tf3), Feedback(tf1*tf2, tf3*tf5)) + assert tf4 / (TransferFunction(1, 1, p) + tf4*tf6) == Feedback(tf4, tf6) + assert tf5 / (tf + tf5) == Feedback(tf5, tf) + + raises(TypeError, lambda: tf1*tf2*tf3 / (1 + tf1*tf2*tf3)) + raises(ValueError, lambda: tf2*tf3 / (tf + tf2*tf3*tf4)) + + assert Feedback(tf, tf1*tf2*tf3).doit() == \ + TransferFunction((a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), k*(a2*p - s) + \ + (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Feedback(tf, tf1*tf2*tf3).sensitivity == \ + 1/(k*(a2*p - s)/((a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2)) + 1) + assert Feedback(tf1, tf2*tf3).doit() == \ + TransferFunction((a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2), (k*(a2*p - s) + \ + (a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2))*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Feedback(tf1, tf2*tf3).sensitivity == \ + 1/(k*(a2*p - s)/((a2*s + p)*(s**2 + 2*s*wn*zeta + wn**2)) + 1) + assert Feedback(tf1*tf2, tf5).doit() == \ + TransferFunction(k*(a0 + s)*(s**2 + 2*s*wn*zeta + wn**2), (k*(-a0 + a1*s**2 + a2*s) + \ + (a0 + s)*(s**2 + 2*s*wn*zeta + wn**2))*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Feedback(tf1*tf2, tf5, 1).sensitivity == \ + 1/(-k*(-a0 + a1*s**2 + a2*s)/((a0 + s)*(s**2 + 2*s*wn*zeta + wn**2)) + 1) + assert Feedback(tf4, tf6).doit() == \ + TransferFunction(p*(p + s)*(a0*p + p**a1 - s), p*(p*(p + s) + (-p + s)*(a0*p + p**a1 - s)), p) + assert -Feedback(tf4*tf6, TransferFunction(1, 1, p)).doit() == \ + TransferFunction(-p*(-p + s)*(p + s)*(a0*p + p**a1 - s), p*(p + s)*(p*(p + s) + (-p + s)*(a0*p + p**a1 - s)), p) + assert Feedback(tf, tf).doit() == TransferFunction(1, 2, s) + + assert Feedback(tf1, tf2*tf5).rewrite(TransferFunction) == \ + TransferFunction((a0 + s)*(s**2 + 2*s*wn*zeta + wn**2), (k*(-a0 + a1*s**2 + a2*s) + \ + (a0 + s)*(s**2 + 2*s*wn*zeta + wn**2))*(s**2 + 2*s*wn*zeta + wn**2), s) + assert Feedback(TransferFunction(1, 1, p), tf4).rewrite(TransferFunction) == \ + TransferFunction(p, a0*p + p + p**a1 - s, p) + + +def test_Feedback_with_Series(): + # Solves issue https://github.com/sympy/sympy/issues/26161 + tf1 = TransferFunction(s+1, 1, s) + tf2 = TransferFunction(s+2, 1, s) + fd1 = Feedback(tf1, tf2, -1) # Negative Feedback system + fd2 = Feedback(tf1, tf2, 1) # Positive Feedback system + unit = TransferFunction(1, 1, s) + + # Checking the type + assert isinstance(fd1, SISOLinearTimeInvariant) + assert isinstance(fd1, Feedback) + + # Testing the numerator and denominator + assert fd1.num == tf1 + assert fd2.num == tf1 + assert fd1.den == Parallel(unit, Series(tf2, tf1)) + assert fd2.den == Parallel(unit, -Series(tf2, tf1)) + + # Testing the Series and Parallel Combination with Feedback and TransferFunction + s1 = Series(tf1, fd1) + p1 = Parallel(tf1, fd1) + assert tf1 * fd1 == s1 + assert tf1 + fd1 == p1 + assert s1.doit() == TransferFunction((s + 1)**2, (s + 1)*(s + 2) + 1, s) + assert p1.doit() == TransferFunction(s + (s + 1)*((s + 1)*(s + 2) + 1) + 1, (s + 1)*(s + 2) + 1, s) + + # Testing the use of Feedback and TransferFunction with Feedback + fd3 = Feedback(tf1*fd1, tf2, -1) + assert fd3 == Feedback(Series(tf1, fd1), tf2) + assert fd3.num == tf1 * fd1 + assert fd3.den == Parallel(unit, Series(tf2, Series(tf1, fd1))) + + # Testing the use of Feedback and TransferFunction with TransferFunction + tf3 = TransferFunction(tf1*fd1, tf2, s) + assert tf3 == TransferFunction(Series(tf1, fd1), tf2, s) + assert tf3.num == tf1*fd1 + + +def test_issue_26161(): + # Issue https://github.com/sympy/sympy/issues/26161 + Ib, Is, m, h, l2, l1 = symbols('I_b, I_s, m, h, l2, l1', + real=True, nonnegative=True) + KD, KP, v = symbols('K_D, K_P, v', real=True) + + tau1_sq = (Ib + m * h ** 2) / m / g / h + tau2 = l2 / v + tau3 = v / (l1 + l2) + K = v ** 2 / g / (l1 + l2) + + Gtheta = TransferFunction(-K * (tau2 * s + 1), tau1_sq * s ** 2 - 1, s) + Gdelta = TransferFunction(1, Is * s ** 2 + c * s, s) + Gpsi = TransferFunction(1, tau3 * s, s) + Dcont = TransferFunction(KD * s, 1, s) + PIcont = TransferFunction(KP, s, s) + Gunity = TransferFunction(1, 1, s) + + Ginner = Feedback(Dcont * Gdelta, Gtheta) + Gouter = Feedback(PIcont * Ginner * Gpsi, Gunity) + assert Gouter == Feedback(Series(PIcont, Series(Ginner, Gpsi)), Gunity) + assert Gouter.num == Series(PIcont, Series(Ginner, Gpsi)) + assert Gouter.den == Parallel(Gunity, Series(Gunity, Series(PIcont, Series(Ginner, Gpsi)))) + expr = (KD*KP*g*s**3*v**2*(l1 + l2)*(Is*s**2 + c*s)**2*(-g*h*m + s**2*(Ib + h**2*m))*(-KD*g*h*m*s*v**2*(l2*s + v) + \ + g*v*(l1 + l2)*(Is*s**2 + c*s)*(-g*h*m + s**2*(Ib + h**2*m))))/((s**2*v*(Is*s**2 + c*s)*(-KD*g*h*m*s*v**2* \ + (l2*s + v) + g*v*(l1 + l2)*(Is*s**2 + c*s)*(-g*h*m + s**2*(Ib + h**2*m)))*(KD*KP*g*s*v*(l1 + l2)**2* \ + (Is*s**2 + c*s)*(-g*h*m + s**2*(Ib + h**2*m)) + s**2*v*(Is*s**2 + c*s)*(-KD*g*h*m*s*v**2*(l2*s + v) + \ + g*v*(l1 + l2)*(Is*s**2 + c*s)*(-g*h*m + s**2*(Ib + h**2*m))))/(l1 + l2))) + + assert (Gouter.to_expr() - expr).simplify() == 0 + + +def test_MIMOFeedback_construction(): + tf1 = TransferFunction(1, s, s) + tf2 = TransferFunction(s, s**3 - 1, s) + tf3 = TransferFunction(s, s + 1, s) + tf4 = TransferFunction(s, s**2 + 1, s) + + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf3], [tf4, tf1]]) + tfm_3 = TransferFunctionMatrix([[tf3, tf4], [tf1, tf2]]) + + f1 = MIMOFeedback(tfm_1, tfm_2) + assert f1.args == (tfm_1, tfm_2, -1) + assert f1.sys1 == tfm_1 + assert f1.sys2 == tfm_2 + assert f1.var == s + assert f1.sign == -1 + assert -(-f1) == f1 + + f2 = MIMOFeedback(tfm_2, tfm_1, 1) + assert f2.args == (tfm_2, tfm_1, 1) + assert f2.sys1 == tfm_2 + assert f2.sys2 == tfm_1 + assert f2.var == s + assert f2.sign == 1 + + f3 = MIMOFeedback(tfm_1, MIMOSeries(tfm_3, tfm_2)) + assert f3.args == (tfm_1, MIMOSeries(tfm_3, tfm_2), -1) + assert f3.sys1 == tfm_1 + assert f3.sys2 == MIMOSeries(tfm_3, tfm_2) + assert f3.var == s + assert f3.sign == -1 + + mat = Matrix([[1, 1/s], [0, 1]]) + sys1 = controller = TransferFunctionMatrix.from_Matrix(mat, s) + f4 = MIMOFeedback(sys1, controller) + assert f4.args == (sys1, controller, -1) + assert f4.sys1 == f4.sys2 == sys1 + + +def test_MIMOFeedback_errors(): + tf1 = TransferFunction(1, s, s) + tf2 = TransferFunction(s, s**3 - 1, s) + tf3 = TransferFunction(s, s - 1, s) + tf4 = TransferFunction(s, s**2 + 1, s) + tf5 = TransferFunction(1, 1, s) + tf6 = TransferFunction(-1, s - 1, s) + + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf3], [tf4, tf1]]) + tfm_3 = TransferFunctionMatrix.from_Matrix(eye(2), var=s) + tfm_4 = TransferFunctionMatrix([[tf1, tf5], [tf5, tf5]]) + tfm_5 = TransferFunctionMatrix([[-tf3, tf3], [tf3, tf6]]) + # tfm_4 is inverse of tfm_5. Therefore tfm_5*tfm_4 = I + tfm_6 = TransferFunctionMatrix([[-tf3]]) + tfm_7 = TransferFunctionMatrix([[tf3, tf4]]) + + # Unsupported Types + raises(TypeError, lambda: MIMOFeedback(tf1, tf2)) + raises(TypeError, lambda: MIMOFeedback(MIMOParallel(tfm_1, tfm_2), tfm_3)) + # Shape Errors + raises(ValueError, lambda: MIMOFeedback(tfm_1, tfm_6, 1)) + raises(ValueError, lambda: MIMOFeedback(tfm_7, tfm_7)) + # sign not 1/-1 + raises(ValueError, lambda: MIMOFeedback(tfm_1, tfm_2, -2)) + # Non-Invertible Systems + raises(ValueError, lambda: MIMOFeedback(tfm_5, tfm_4, 1)) + raises(ValueError, lambda: MIMOFeedback(tfm_4, -tfm_5)) + raises(ValueError, lambda: MIMOFeedback(tfm_3, tfm_3, 1)) + # Variable not same in both the systems + tfm_8 = TransferFunctionMatrix.from_Matrix(eye(2), var=p) + raises(ValueError, lambda: MIMOFeedback(tfm_1, tfm_8, 1)) + + +def test_MIMOFeedback_functions(): + tf1 = TransferFunction(1, s, s) + tf2 = TransferFunction(s, s - 1, s) + tf3 = TransferFunction(1, 1, s) + tf4 = TransferFunction(-1, s - 1, s) + + tfm_1 = TransferFunctionMatrix.from_Matrix(eye(2), var=s) + tfm_2 = TransferFunctionMatrix([[tf1, tf3], [tf3, tf3]]) + tfm_3 = TransferFunctionMatrix([[-tf2, tf2], [tf2, tf4]]) + tfm_4 = TransferFunctionMatrix([[tf1, tf2], [-tf2, tf1]]) + + # sensitivity, doit(), rewrite() + F_1 = MIMOFeedback(tfm_2, tfm_3) + F_2 = MIMOFeedback(tfm_2, MIMOSeries(tfm_4, -tfm_1), 1) + + assert F_1.sensitivity == Matrix([[S.Half, 0], [0, S.Half]]) + assert F_2.sensitivity == Matrix([[(-2*s**4 + s**2)/(s**2 - s + 1), + (2*s**3 - s**2)/(s**2 - s + 1)], [-s**2, s]]) + + assert F_1.doit() == \ + TransferFunctionMatrix(((TransferFunction(1, 2*s, s), + TransferFunction(1, 2, s)), (TransferFunction(1, 2, s), + TransferFunction(1, 2, s)))) == F_1.rewrite(TransferFunctionMatrix) + assert F_2.doit(cancel=False, expand=True) == \ + TransferFunctionMatrix(((TransferFunction(-s**5 + 2*s**4 - 2*s**3 + s**2, s**5 - 2*s**4 + 3*s**3 - 2*s**2 + s, s), + TransferFunction(-2*s**4 + 2*s**3, s**2 - s + 1, s)), (TransferFunction(0, 1, s), TransferFunction(-s**2 + s, 1, s)))) + assert F_2.doit(cancel=False) == \ + TransferFunctionMatrix(((TransferFunction(s*(2*s**3 - s**2)*(s**2 - s + 1) + \ + (-2*s**4 + s**2)*(s**2 - s + 1), s*(s**2 - s + 1)**2, s), TransferFunction(-2*s**4 + 2*s**3, s**2 - s + 1, s)), + (TransferFunction(0, 1, s), TransferFunction(-s**2 + s, 1, s)))) + assert F_2.doit() == \ + TransferFunctionMatrix(((TransferFunction(s*(-2*s**2 + s*(2*s - 1) + 1), s**2 - s + 1, s), + TransferFunction(-2*s**3*(s - 1), s**2 - s + 1, s)), (TransferFunction(0, 1, s), TransferFunction(s*(1 - s), 1, s)))) + assert F_2.doit(expand=True) == \ + TransferFunctionMatrix(((TransferFunction(-s**2 + s, s**2 - s + 1, s), TransferFunction(-2*s**4 + 2*s**3, s**2 - s + 1, s)), + (TransferFunction(0, 1, s), TransferFunction(-s**2 + s, 1, s)))) + + assert -(F_1.doit()) == (-F_1).doit() # First negating then calculating vs calculating then negating. + + +def test_TransferFunctionMatrix_construction(): + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + tf4 = TransferFunction(a0*p + p**a1 - s, p, p) + + tfm3_ = TransferFunctionMatrix([[-TF3]]) + assert tfm3_.shape == (tfm3_.num_outputs, tfm3_.num_inputs) == (1, 1) + assert tfm3_.args == Tuple(Tuple(Tuple(-TF3))) + assert tfm3_.var == s + + tfm5 = TransferFunctionMatrix([[TF1, -TF2], [TF3, tf5]]) + assert tfm5.shape == (tfm5.num_outputs, tfm5.num_inputs) == (2, 2) + assert tfm5.args == Tuple(Tuple(Tuple(TF1, -TF2), Tuple(TF3, tf5))) + assert tfm5.var == s + + tfm7 = TransferFunctionMatrix([[TF1, TF2], [TF3, -tf5], [-tf5, TF2]]) + assert tfm7.shape == (tfm7.num_outputs, tfm7.num_inputs) == (3, 2) + assert tfm7.args == Tuple(Tuple(Tuple(TF1, TF2), Tuple(TF3, -tf5), Tuple(-tf5, TF2))) + assert tfm7.var == s + + # all transfer functions will use the same complex variable. tf4 uses 'p'. + raises(ValueError, lambda: TransferFunctionMatrix([[TF1], [TF2], [tf4]])) + raises(ValueError, lambda: TransferFunctionMatrix([[TF1, tf4], [TF3, tf5]])) + + # length of all the lists in the TFM should be equal. + raises(ValueError, lambda: TransferFunctionMatrix([[TF1], [TF3, tf5]])) + raises(ValueError, lambda: TransferFunctionMatrix([[TF1, TF3], [tf5]])) + + # lists should only support transfer functions in them. + raises(TypeError, lambda: TransferFunctionMatrix([[TF1, TF2], [TF3, Matrix([1, 2])]])) + raises(TypeError, lambda: TransferFunctionMatrix([[TF1, Matrix([1, 2])], [TF3, TF2]])) + + # `arg` should strictly be nested list of TransferFunction + raises(ValueError, lambda: TransferFunctionMatrix([TF1, TF2, tf5])) + raises(ValueError, lambda: TransferFunctionMatrix([TF1])) + +def test_TransferFunctionMatrix_functions(): + tf5 = TransferFunction(a1*s**2 + a2*s - a0, s + a0, s) + + # Classmethod (from_matrix) + + mat_1 = ImmutableMatrix([ + [s*(s + 1)*(s - 3)/(s**4 + 1), 2], + [p, p*(s + 1)/(s*(s**1 + 1))] + ]) + mat_2 = ImmutableMatrix([[(2*s + 1)/(s**2 - 9)]]) + mat_3 = ImmutableMatrix([[1, 2], [3, 4]]) + assert TransferFunctionMatrix.from_Matrix(mat_1, s) == \ + TransferFunctionMatrix([[TransferFunction(s*(s - 3)*(s + 1), s**4 + 1, s), TransferFunction(2, 1, s)], + [TransferFunction(p, 1, s), TransferFunction(p, s, s)]]) + assert TransferFunctionMatrix.from_Matrix(mat_2, s) == \ + TransferFunctionMatrix([[TransferFunction(2*s + 1, s**2 - 9, s)]]) + assert TransferFunctionMatrix.from_Matrix(mat_3, p) == \ + TransferFunctionMatrix([[TransferFunction(1, 1, p), TransferFunction(2, 1, p)], + [TransferFunction(3, 1, p), TransferFunction(4, 1, p)]]) + + # Negating a TFM + + tfm1 = TransferFunctionMatrix([[TF1], [TF2]]) + assert -tfm1 == TransferFunctionMatrix([[-TF1], [-TF2]]) + + tfm2 = TransferFunctionMatrix([[TF1, TF2, TF3], [tf5, -TF1, -TF3]]) + assert -tfm2 == TransferFunctionMatrix([[-TF1, -TF2, -TF3], [-tf5, TF1, TF3]]) + + # subs() + + H_1 = TransferFunctionMatrix.from_Matrix(mat_1, s) + H_2 = TransferFunctionMatrix([[TransferFunction(a*p*s, k*s**2, s), TransferFunction(p*s, k*(s**2 - a), s)]]) + assert H_1.subs(p, 1) == TransferFunctionMatrix([[TransferFunction(s*(s - 3)*(s + 1), s**4 + 1, s), TransferFunction(2, 1, s)], [TransferFunction(1, 1, s), TransferFunction(1, s, s)]]) + assert H_1.subs({p: 1}) == TransferFunctionMatrix([[TransferFunction(s*(s - 3)*(s + 1), s**4 + 1, s), TransferFunction(2, 1, s)], [TransferFunction(1, 1, s), TransferFunction(1, s, s)]]) + assert H_1.subs({p: 1, s: 1}) == TransferFunctionMatrix([[TransferFunction(s*(s - 3)*(s + 1), s**4 + 1, s), TransferFunction(2, 1, s)], [TransferFunction(1, 1, s), TransferFunction(1, s, s)]]) # This should ignore `s` as it is `var` + assert H_2.subs(p, 2) == TransferFunctionMatrix([[TransferFunction(2*a*s, k*s**2, s), TransferFunction(2*s, k*(-a + s**2), s)]]) + assert H_2.subs(k, 1) == TransferFunctionMatrix([[TransferFunction(a*p*s, s**2, s), TransferFunction(p*s, -a + s**2, s)]]) + assert H_2.subs(a, 0) == TransferFunctionMatrix([[TransferFunction(0, k*s**2, s), TransferFunction(p*s, k*s**2, s)]]) + assert H_2.subs({p: 1, k: 1, a: a0}) == TransferFunctionMatrix([[TransferFunction(a0*s, s**2, s), TransferFunction(s, -a0 + s**2, s)]]) + + # eval_frequency() + assert H_2.eval_frequency(S(1)/2 + I) == Matrix([[2*a*p/(5*k) - 4*I*a*p/(5*k), I*p/(-a*k - 3*k/4 + I*k) + p/(-2*a*k - 3*k/2 + 2*I*k)]]) + + # transpose() + + assert H_1.transpose() == TransferFunctionMatrix([[TransferFunction(s*(s - 3)*(s + 1), s**4 + 1, s), TransferFunction(p, 1, s)], [TransferFunction(2, 1, s), TransferFunction(p, s, s)]]) + assert H_2.transpose() == TransferFunctionMatrix([[TransferFunction(a*p*s, k*s**2, s)], [TransferFunction(p*s, k*(-a + s**2), s)]]) + assert H_1.transpose().transpose() == H_1 + assert H_2.transpose().transpose() == H_2 + + # elem_poles() + + assert H_1.elem_poles() == [[[-sqrt(2)/2 - sqrt(2)*I/2, -sqrt(2)/2 + sqrt(2)*I/2, sqrt(2)/2 - sqrt(2)*I/2, sqrt(2)/2 + sqrt(2)*I/2], []], + [[], [0]]] + assert H_2.elem_poles() == [[[0, 0], [sqrt(a), -sqrt(a)]]] + assert tfm2.elem_poles() == [[[wn*(-zeta + sqrt((zeta - 1)*(zeta + 1))), wn*(-zeta - sqrt((zeta - 1)*(zeta + 1)))], [], [-p/a2]], + [[-a0], [wn*(-zeta + sqrt((zeta - 1)*(zeta + 1))), wn*(-zeta - sqrt((zeta - 1)*(zeta + 1)))], [-p/a2]]] + + # elem_zeros() + + assert H_1.elem_zeros() == [[[-1, 0, 3], []], [[], []]] + assert H_2.elem_zeros() == [[[0], [0]]] + assert tfm2.elem_zeros() == [[[], [], [a2*p]], + [[-a2/(2*a1) - sqrt(4*a0*a1 + a2**2)/(2*a1), -a2/(2*a1) + sqrt(4*a0*a1 + a2**2)/(2*a1)], [], [a2*p]]] + + # doit() + + H_3 = TransferFunctionMatrix([[Series(TransferFunction(1, s**3 - 3, s), TransferFunction(s**2 - 2*s + 5, 1, s), TransferFunction(1, s, s))]]) + H_4 = TransferFunctionMatrix([[Parallel(TransferFunction(s**3 - 3, 4*s**4 - s**2 - 2*s + 5, s), TransferFunction(4 - s**3, 4*s**4 - s**2 - 2*s + 5, s))]]) + + assert H_3.doit() == TransferFunctionMatrix([[TransferFunction(s**2 - 2*s + 5, s*(s**3 - 3), s)]]) + assert H_4.doit() == TransferFunctionMatrix([[TransferFunction(1, 4*s**4 - s**2 - 2*s + 5, s)]]) + + # _flat() + + assert H_1._flat() == [TransferFunction(s*(s - 3)*(s + 1), s**4 + 1, s), TransferFunction(2, 1, s), TransferFunction(p, 1, s), TransferFunction(p, s, s)] + assert H_2._flat() == [TransferFunction(a*p*s, k*s**2, s), TransferFunction(p*s, k*(-a + s**2), s)] + assert H_3._flat() == [Series(TransferFunction(1, s**3 - 3, s), TransferFunction(s**2 - 2*s + 5, 1, s), TransferFunction(1, s, s))] + assert H_4._flat() == [Parallel(TransferFunction(s**3 - 3, 4*s**4 - s**2 - 2*s + 5, s), TransferFunction(4 - s**3, 4*s**4 - s**2 - 2*s + 5, s))] + + # evalf() + + assert H_1.evalf() == \ + TransferFunctionMatrix(((TransferFunction(s*(s - 3.0)*(s + 1.0), s**4 + 1.0, s), TransferFunction(2.0, 1, s)), (TransferFunction(1.0*p, 1, s), TransferFunction(p, s, s)))) + assert H_2.subs({a:3.141, p:2.88, k:2}).evalf() == \ + TransferFunctionMatrix(((TransferFunction(4.5230399999999999494093572138808667659759521484375, s, s), + TransferFunction(2.87999999999999989341858963598497211933135986328125*s, 2.0*s**2 - 6.282000000000000028421709430404007434844970703125, s)),)) + + # simplify() + + H_5 = TransferFunctionMatrix([[TransferFunction(s**5 + s**3 + s, s - s**2, s), + TransferFunction((s + 3)*(s - 1), (s - 1)*(s + 5), s)]]) + + assert H_5.simplify() == simplify(H_5) == \ + TransferFunctionMatrix(((TransferFunction(-s**4 - s**2 - 1, s - 1, s), TransferFunction(s + 3, s + 5, s)),)) + + # expand() + + assert (H_1.expand() + == TransferFunctionMatrix(((TransferFunction(s**3 - 2*s**2 - 3*s, s**4 + 1, s), TransferFunction(2, 1, s)), + (TransferFunction(p, 1, s), TransferFunction(p, s, s))))) + assert H_5.expand() == \ + TransferFunctionMatrix(((TransferFunction(s**5 + s**3 + s, -s**2 + s, s), TransferFunction(s**2 + 2*s - 3, s**2 + 4*s - 5, s)),)) + +def test_TransferFunction_gbt(): + # simple transfer function, e.g. ohms law + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = gbt(tf, T, 0.5) + # discretized transfer function with coefs from tf.gbt() + tf_test_bilinear = TransferFunction(s * numZ[0] + numZ[1], s * denZ[0] + denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(s * T/(2*(a + b*T/2)) + T/(2*(a + b*T/2)), s + (-a + b*T/2)/(a + b*T/2), s) + + assert S.Zero == (tf_test_bilinear.simplify()-tf_test_manual.simplify()).simplify().num + + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = gbt(tf, T, 0) + # discretized transfer function with coefs from tf.gbt() + tf_test_forward = TransferFunction(numZ[0], s*denZ[0]+denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(T/a, s + (-a + b*T)/a, s) + + assert S.Zero == (tf_test_forward.simplify()-tf_test_manual.simplify()).simplify().num + + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = gbt(tf, T, 1) + # discretized transfer function with coefs from tf.gbt() + tf_test_backward = TransferFunction(s*numZ[0], s*denZ[0]+denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(s * T/(a + b*T), s - a/(a + b*T), s) + + assert S.Zero == (tf_test_backward.simplify()-tf_test_manual.simplify()).simplify().num + + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = gbt(tf, T, 0.3) + # discretized transfer function with coefs from tf.gbt() + tf_test_gbt = TransferFunction(s*numZ[0]+numZ[1], s*denZ[0]+denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(s*3*T/(10*(a + 3*b*T/10)) + 7*T/(10*(a + 3*b*T/10)), s + (-a + 7*b*T/10)/(a + 3*b*T/10), s) + + assert S.Zero == (tf_test_gbt.simplify()-tf_test_manual.simplify()).simplify().num + +def test_TransferFunction_bilinear(): + # simple transfer function, e.g. ohms law + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = bilinear(tf, T) + # discretized transfer function with coefs from tf.bilinear() + tf_test_bilinear = TransferFunction(s*numZ[0]+numZ[1], s*denZ[0]+denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(s * T/(2*(a + b*T/2)) + T/(2*(a + b*T/2)), s + (-a + b*T/2)/(a + b*T/2), s) + + assert S.Zero == (tf_test_bilinear.simplify()-tf_test_manual.simplify()).simplify().num + +def test_TransferFunction_forward_diff(): + # simple transfer function, e.g. ohms law + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = forward_diff(tf, T) + # discretized transfer function with coefs from tf.forward_diff() + tf_test_forward = TransferFunction(numZ[0], s*denZ[0]+denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(T/a, s + (-a + b*T)/a, s) + + assert S.Zero == (tf_test_forward.simplify()-tf_test_manual.simplify()).simplify().num + +def test_TransferFunction_backward_diff(): + # simple transfer function, e.g. ohms law + tf = TransferFunction(1, a*s+b, s) + numZ, denZ = backward_diff(tf, T) + # discretized transfer function with coefs from tf.backward_diff() + tf_test_backward = TransferFunction(s*numZ[0]+numZ[1], s*denZ[0]+denZ[1], s) + # corresponding tf with manually calculated coefs + tf_test_manual = TransferFunction(s * T/(a + b*T), s - a/(a + b*T), s) + + assert S.Zero == (tf_test_backward.simplify()-tf_test_manual.simplify()).simplify().num + +def test_TransferFunction_phase_margin(): + # Test for phase margin + tf1 = TransferFunction(10, p**3 + 1, p) + tf2 = TransferFunction(s**2, 10, s) + tf3 = TransferFunction(1, a*s+b, s) + tf4 = TransferFunction((s + 1)*exp(s/tau), s**2 + 2, s) + tf_m = TransferFunctionMatrix([[tf2],[tf3]]) + + assert phase_margin(tf1) == -180 + 180*atan(3*sqrt(11))/pi + assert phase_margin(tf2) == 0 + + raises(NotImplementedError, lambda: phase_margin(tf4)) + raises(ValueError, lambda: phase_margin(tf3)) + raises(ValueError, lambda: phase_margin(MIMOSeries(tf_m))) + +def test_TransferFunction_gain_margin(): + # Test for gain margin + tf1 = TransferFunction(s**2, 5*(s+1)*(s-5)*(s-10), s) + tf2 = TransferFunction(s**2 + 2*s + 1, 1, s) + tf3 = TransferFunction(1, a*s+b, s) + tf4 = TransferFunction((s + 1)*exp(s/tau), s**2 + 2, s) + tf_m = TransferFunctionMatrix([[tf2],[tf3]]) + + assert gain_margin(tf1) == -20*log(S(7)/540)/log(10) + assert gain_margin(tf2) == oo + + raises(NotImplementedError, lambda: gain_margin(tf4)) + raises(ValueError, lambda: gain_margin(tf3)) + raises(ValueError, lambda: gain_margin(MIMOSeries(tf_m))) + + +def test_StateSpace_construction(): + # using different numbers for a SISO system. + A1 = Matrix([[0, 1], [1, 0]]) + B1 = Matrix([1, 0]) + C1 = Matrix([[0, 1]]) + D1 = Matrix([0]) + ss1 = StateSpace(A1, B1, C1, D1) + + assert ss1.state_matrix == Matrix([[0, 1], [1, 0]]) + assert ss1.input_matrix == Matrix([1, 0]) + assert ss1.output_matrix == Matrix([[0, 1]]) + assert ss1.feedforward_matrix == Matrix([0]) + assert ss1.args == (Matrix([[0, 1], [1, 0]]), Matrix([[1], [0]]), Matrix([[0, 1]]), Matrix([[0]])) + + # using different symbols for a SISO system. + ss2 = StateSpace(Matrix([a0]), Matrix([a1]), + Matrix([a2]), Matrix([a3])) + + assert ss2.state_matrix == Matrix([[a0]]) + assert ss2.input_matrix == Matrix([[a1]]) + assert ss2.output_matrix == Matrix([[a2]]) + assert ss2.feedforward_matrix == Matrix([[a3]]) + assert ss2.args == (Matrix([[a0]]), Matrix([[a1]]), Matrix([[a2]]), Matrix([[a3]])) + + # using different numbers for a MIMO system. + ss3 = StateSpace(Matrix([[-1.5, -2], [1, 0]]), + Matrix([[0.5, 0], [0, 1]]), + Matrix([[0, 1], [0, 2]]), + Matrix([[2, 2], [1, 1]])) + + assert ss3.state_matrix == Matrix([[-1.5, -2], [1, 0]]) + assert ss3.input_matrix == Matrix([[0.5, 0], [0, 1]]) + assert ss3.output_matrix == Matrix([[0, 1], [0, 2]]) + assert ss3.feedforward_matrix == Matrix([[2, 2], [1, 1]]) + assert ss3.args == (Matrix([[-1.5, -2], + [1, 0]]), + Matrix([[0.5, 0], + [0, 1]]), + Matrix([[0, 1], + [0, 2]]), + Matrix([[2, 2], + [1, 1]])) + + # using different symbols for a MIMO system. + A4 = Matrix([[a0, a1], [a2, a3]]) + B4 = Matrix([[b0, b1], [b2, b3]]) + C4 = Matrix([[c0, c1], [c2, c3]]) + D4 = Matrix([[d0, d1], [d2, d3]]) + ss4 = StateSpace(A4, B4, C4, D4) + + assert ss4.state_matrix == Matrix([[a0, a1], [a2, a3]]) + assert ss4.input_matrix == Matrix([[b0, b1], [b2, b3]]) + assert ss4.output_matrix == Matrix([[c0, c1], [c2, c3]]) + assert ss4.feedforward_matrix == Matrix([[d0, d1], [d2, d3]]) + assert ss4.args == (Matrix([[a0, a1], + [a2, a3]]), + Matrix([[b0, b1], + [b2, b3]]), + Matrix([[c0, c1], + [c2, c3]]), + Matrix([[d0, d1], + [d2, d3]])) + + # using less matrices. Rest will be filled with a minimum of zeros. + ss5 = StateSpace() + assert ss5.args == (Matrix([[0]]), Matrix([[0]]), Matrix([[0]]), Matrix([[0]])) + + A6 = Matrix([[0, 1], [1, 0]]) + B6 = Matrix([1, 1]) + ss6 = StateSpace(A6, B6) + + assert ss6.state_matrix == Matrix([[0, 1], [1, 0]]) + assert ss6.input_matrix == Matrix([1, 1]) + assert ss6.output_matrix == Matrix([[0, 0]]) + assert ss6.feedforward_matrix == Matrix([[0]]) + assert ss6.args == (Matrix([[0, 1], + [1, 0]]), + Matrix([[1], + [1]]), + Matrix([[0, 0]]), + Matrix([[0]])) + + # Check if the system is SISO or MIMO. + # If system is not SISO, then it is definitely MIMO. + + assert ss1.is_SISO == True + assert ss2.is_SISO == True + assert ss3.is_SISO == False + assert ss4.is_SISO == False + assert ss5.is_SISO == True + assert ss6.is_SISO == True + + # ShapeError if matrices do not fit. + raises(ShapeError, lambda: StateSpace(Matrix([s, (s+1)**2]), Matrix([s+1]), + Matrix([s**2 - 1]), Matrix([2*s]))) + raises(ShapeError, lambda: StateSpace(Matrix([s]), Matrix([s+1, s**3 + 1]), + Matrix([s**2 - 1]), Matrix([2*s]))) + raises(ShapeError, lambda: StateSpace(Matrix([s]), Matrix([s+1]), + Matrix([[s**2 - 1], [s**2 + 2*s + 1]]), Matrix([2*s]))) + raises(ShapeError, lambda: StateSpace(Matrix([[-s, -s], [s, 0]]), + Matrix([[s/2, 0], [0, s]]), + Matrix([[0, s]]), + Matrix([[2*s, 2*s], [s, s]]))) + + # TypeError if arguments are not sympy matrices. + raises(TypeError, lambda: StateSpace(s**2, s+1, 2*s, 1)) + raises(TypeError, lambda: StateSpace(Matrix([2, 0.5]), Matrix([-1]), + Matrix([1]), 0)) +def test_StateSpace_add(): + A1 = Matrix([[4, 1],[2, -3]]) + B1 = Matrix([[5, 2],[-3, -3]]) + C1 = Matrix([[2, -4],[0, 1]]) + D1 = Matrix([[3, 2],[1, -1]]) + ss1 = StateSpace(A1, B1, C1, D1) + + A2 = Matrix([[-3, 4, 2],[-1, -3, 0],[2, 5, 3]]) + B2 = Matrix([[1, 4],[-3, -3],[-2, 1]]) + C2 = Matrix([[4, 2, -3],[1, 4, 3]]) + D2 = Matrix([[-2, 4],[0, 1]]) + ss2 = StateSpace(A2, B2, C2, D2) + ss3 = StateSpace() + ss4 = StateSpace(Matrix([1]), Matrix([2]), Matrix([3]), Matrix([4])) + + expected_add = \ + StateSpace( + Matrix([ + [4, 1, 0, 0, 0], + [2, -3, 0, 0, 0], + [0, 0, -3, 4, 2], + [0, 0, -1, -3, 0], + [0, 0, 2, 5, 3]]), + Matrix([ + [ 5, 2], + [-3, -3], + [ 1, 4], + [-3, -3], + [-2, 1]]), + Matrix([ + [2, -4, 4, 2, -3], + [0, 1, 1, 4, 3]]), + Matrix([ + [1, 6], + [1, 0]])) + + expected_mul = \ + StateSpace( + Matrix([ + [ -3, 4, 2, 0, 0], + [ -1, -3, 0, 0, 0], + [ 2, 5, 3, 0, 0], + [ 22, 18, -9, 4, 1], + [-15, -18, 0, 2, -3]]), + Matrix([ + [ 1, 4], + [ -3, -3], + [ -2, 1], + [-10, 22], + [ 6, -15]]), + Matrix([ + [14, 14, -3, 2, -4], + [ 3, -2, -6, 0, 1]]), + Matrix([ + [-6, 14], + [-2, 3]])) + + assert ss1 + ss2 == expected_add + assert ss1*ss2 == expected_mul + assert ss3 + 1/2 == StateSpace(Matrix([[0]]), Matrix([[0]]), Matrix([[0]]), Matrix([[0.5]])) + assert ss4*1.5 == StateSpace(Matrix([[1]]), Matrix([[2]]), Matrix([[4.5]]), Matrix([[6.0]])) + assert 1.5*ss4 == StateSpace(Matrix([[1]]), Matrix([[3.0]]), Matrix([[3]]), Matrix([[6.0]])) + raises(ShapeError, lambda: ss1 + ss3) + raises(ShapeError, lambda: ss2*ss4) + +def test_StateSpace_negation(): + A = Matrix([[a0, a1], [a2, a3]]) + B = Matrix([[b0, b1], [b2, b3]]) + C = Matrix([[c0, c1], [c1, c2], [c2, c3]]) + D = Matrix([[d0, d1], [d1, d2], [d2, d3]]) + SS = StateSpace(A, B, C, D) + SS_neg = -SS + + state_mat = Matrix([[-1, 1], [1, -1]]) + input_mat = Matrix([1, -1]) + output_mat = Matrix([[-1, 1]]) + feedforward_mat = Matrix([1]) + system = StateSpace(state_mat, input_mat, output_mat, feedforward_mat) + + assert SS_neg == \ + StateSpace(Matrix([[a0, a1], + [a2, a3]]), + Matrix([[b0, b1], + [b2, b3]]), + Matrix([[-c0, -c1], + [-c1, -c2], + [-c2, -c3]]), + Matrix([[-d0, -d1], + [-d1, -d2], + [-d2, -d3]])) + assert -system == \ + StateSpace(Matrix([[-1, 1], + [ 1, -1]]), + Matrix([[ 1],[-1]]), + Matrix([[1, -1]]), + Matrix([[-1]])) + assert -SS_neg == SS + assert -(-(-(-system))) == system + +def test_SymPy_substitution_functions(): + # subs + ss1 = StateSpace(Matrix([s]), Matrix([(s + 1)**2]), Matrix([s**2 - 1]), Matrix([2*s])) + ss2 = StateSpace(Matrix([s + p]), Matrix([(s + 1)*(p - 1)]), Matrix([p**3 - s**3]), Matrix([s - p])) + + assert ss1.subs({s:5}) == StateSpace(Matrix([[5]]), Matrix([[36]]), Matrix([[24]]), Matrix([[10]])) + assert ss2.subs({p:1}) == StateSpace(Matrix([[s + 1]]), Matrix([[0]]), Matrix([[1 - s**3]]), Matrix([[s - 1]])) + + # xreplace + assert ss1.xreplace({s:p}) == \ + StateSpace(Matrix([[p]]), Matrix([[(p + 1)**2]]), Matrix([[p**2 - 1]]), Matrix([[2*p]])) + assert ss2.xreplace({s:a, p:b}) == \ + StateSpace(Matrix([[a + b]]), Matrix([[(a + 1)*(b - 1)]]), Matrix([[-a**3 + b**3]]), Matrix([[a - b]])) + + # evalf + p1 = a1*s + a0 + p2 = b2*s**2 + b1*s + b0 + G = StateSpace(Matrix([p1]), Matrix([p2])) + expect = StateSpace(Matrix([[2*s + 1]]), Matrix([[5*s**2 + 4*s + 3]]), Matrix([[0]]), Matrix([[0]])) + expect_ = StateSpace(Matrix([[2.0*s + 1.0]]), Matrix([[5.0*s**2 + 4.0*s + 3.0]]), Matrix([[0]]), Matrix([[0]])) + assert G.subs({a0: 1, a1: 2, b0: 3, b1: 4, b2: 5}) == expect + assert G.subs({a0: 1, a1: 2, b0: 3, b1: 4, b2: 5}).evalf() == expect_ + assert expect.evalf() == expect_ + +def test_conversion(): + # StateSpace to TransferFunction for SISO + A1 = Matrix([[-5, -1], [3, -1]]) + B1 = Matrix([2, 5]) + C1 = Matrix([[1, 2]]) + D1 = Matrix([0]) + H1 = StateSpace(A1, B1, C1, D1) + H3 = StateSpace(Matrix([[a0, a1], [a2, a3]]), B = Matrix([[b1], [b2]]), C = Matrix([[c1, c2]])) + tm1 = H1.rewrite(TransferFunction) + tm2 = (-H1).rewrite(TransferFunction) + + tf1 = tm1[0][0] + tf2 = tm2[0][0] + + assert tf1 == TransferFunction(12*s + 59, s**2 + 6*s + 8, s) + assert tf2.num == -tf1.num + assert tf2.den == tf1.den + + # StateSpace to TransferFunction for MIMO + A2 = Matrix([[-1.5, -2, 3], [1, 0, 1], [2, 1, 1]]) + B2 = Matrix([[0.5, 0, 1], [0, 1, 2], [2, 2, 3]]) + C2 = Matrix([[0, 1, 0], [0, 2, 1], [1, 0, 2]]) + D2 = Matrix([[2, 2, 0], [1, 1, 1], [3, 2, 1]]) + H2 = StateSpace(A2, B2, C2, D2) + tm3 = H2.rewrite(TransferFunction) + + # outputs for input i obtained at Index i-1. Consider input 1 + assert tm3[0][0] == TransferFunction(2.0*s**3 + 1.0*s**2 - 10.5*s + 4.5, 1.0*s**3 + 0.5*s**2 - 6.5*s - 2.5, s) + assert tm3[0][1] == TransferFunction(2.0*s**3 + 2.0*s**2 - 10.5*s - 3.5, 1.0*s**3 + 0.5*s**2 - 6.5*s - 2.5, s) + assert tm3[0][2] == TransferFunction(2.0*s**2 + 5.0*s - 0.5, 1.0*s**3 + 0.5*s**2 - 6.5*s - 2.5, s) + assert H3.rewrite(TransferFunction) == [[TransferFunction(-c1*(a1*b2 - a3*b1 + b1*s) - c2*(-a0*b2 + a2*b1 + b2*s), + -a0*a3 + a0*s + a1*a2 + a3*s - s**2, s)]] + # TransferFunction to StateSpace + SS = TF1.rewrite(StateSpace) + assert SS == \ + StateSpace(Matrix([[ 0, 1], + [-wn**2, -2*wn*zeta]]), + Matrix([[0], + [1]]), + Matrix([[1, 0]]), + Matrix([[0]])) + assert SS.rewrite(TransferFunction)[0][0] == TF1 + + # Transfer function has to be proper + raises(ValueError, lambda: TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s).rewrite(StateSpace)) + + +def test_StateSpace_dsolve(): + # https://web.mit.edu/2.14/www/Handouts/StateSpaceResponse.pdf + # https://lpsa.swarthmore.edu/Transient/TransMethSS.html + A1 = Matrix([[0, 1], [-2, -3]]) + B1 = Matrix([[0], [1]]) + C1 = Matrix([[1, -1]]) + D1 = Matrix([0]) + I1 = Matrix([[1], [2]]) + t = symbols('t') + ss1 = StateSpace(A1, B1, C1, D1) + + # Zero input and Zero initial conditions + assert ss1.dsolve() == Matrix([[0]]) + assert ss1.dsolve(initial_conditions=I1) == Matrix([[8*exp(-t) - 9*exp(-2*t)]]) + + A2 = Matrix([[-2, 0], [1, -1]]) + C2 = eye(2,2) + I2 = Matrix([2, 3]) + ss2 = StateSpace(A=A2, C=C2) + assert ss2.dsolve(initial_conditions=I2) == Matrix([[2*exp(-2*t)], [5*exp(-t) - 2*exp(-2*t)]]) + + A3 = Matrix([[-1, 1], [-4, -4]]) + B3 = Matrix([[0], [4]]) + C3 = Matrix([[0, 1]]) + D3 = Matrix([0]) + U3 = Matrix([10]) + ss3 = StateSpace(A3, B3, C3, D3) + op = ss3.dsolve(input_vector=U3, var=t) + assert str(op.simplify().expand().evalf()[0]) == str(5.0 + 20.7880460155075*exp(-5*t/2)*sin(sqrt(7)*t/2) + - 5.0*exp(-5*t/2)*cos(sqrt(7)*t/2)) + + # Test with Heaviside as input + A4 = Matrix([[-1, 1], [-4, -4]]) + B4 = Matrix([[0], [4]]) + C4 = Matrix([[0, 1]]) + U4 = Matrix([[10*Heaviside(t)]]) + ss4 = StateSpace(A4, B4, C4) + op4 = str(ss4.dsolve(var=t, input_vector=U4)[0].simplify().expand().evalf()) + assert op4 == str(5.0*Heaviside(t) + 20.7880460155075*exp(-5*t/2)*sin(sqrt(7)*t/2)*Heaviside(t) + - 5.0*exp(-5*t/2)*cos(sqrt(7)*t/2)*Heaviside(t)) + + # Test with Symbolic Matrices + m, a, x0 = symbols('m a x_0') + A5 = Matrix([[0, 1], [0, 0]]) + B5 = Matrix([[0], [1 / m]]) + C5 = Matrix([[1, 0]]) + I5 = Matrix([[x0], [0]]) + U5 = Matrix([[exp(-a * t)]]) + ss5 = StateSpace(A5, B5, C5) + op5 = ss5.dsolve(initial_conditions=I5, input_vector=U5, var=t).simplify() + assert op5[0].args[0][0] == x0 + t/(a*m) - 1/(a**2*m) + exp(-a*t)/(a**2*m) + a11, a12, a21, a22, b1, b2, c1, c2, i1, i2 = symbols('a_11 a_12 a_21 a_22 b_1 b_2 c_1 c_2 i_1 i_2') + A6 = Matrix([[a11, a12], [a21, a22]]) + B6 = Matrix([b1, b2]) + C6 = Matrix([[c1, c2]]) + I6 = Matrix([i1, i2]) + ss6 = StateSpace(A6, B6, C6) + expr6 = ss6.dsolve(initial_conditions=I6)[0] + expr6 = expr6.subs([(a11, 0), (a12, 1), (a21, -2), (a22, -3), (b1, 0), (b2, 1), (c1, 1), (c2, -1), (i1, 1), (i2, 2)]) + assert expr6 == 8*exp(-t) - 9*exp(-2*t) + + +def test_StateSpace_functions(): + # https://in.mathworks.com/help/control/ref/statespacemodel.obsv.html + + A_mat = Matrix([[-1.5, -2], [1, 0]]) + B_mat = Matrix([0.5, 0]) + C_mat = Matrix([[0, 1]]) + D_mat = Matrix([1]) + SS1 = StateSpace(A_mat, B_mat, C_mat, D_mat) + SS2 = StateSpace(Matrix([[1, 1], [4, -2]]),Matrix([[0, 1], [0, 2]]),Matrix([[-1, 1], [1, -1]])) + SS3 = StateSpace(Matrix([[1, 1], [4, -2]]),Matrix([[1, -1], [1, -1]])) + SS4 = StateSpace(Matrix([[a0, a1], [a2, a3]]), Matrix([[b1], [b2]]), Matrix([[c1, c2]])) + + # Observability + assert SS1.is_observable() == True + assert SS2.is_observable() == False + assert SS1.observability_matrix() == Matrix([[0, 1], [1, 0]]) + assert SS2.observability_matrix() == Matrix([[-1, 1], [ 1, -1], [ 3, -3], [-3, 3]]) + assert SS1.observable_subspace() == [Matrix([[0], [1]]), Matrix([[1], [0]])] + assert SS2.observable_subspace() == [Matrix([[-1], [ 1], [ 3], [-3]])] + Qo = SS4.observability_matrix().subs([(a0, 0), (a1, -6), (a2, 1), (a3, -5), (c1, 0), (c2, 1)]) + assert Qo == Matrix([[0, 1], [1, -5]]) + + # Controllability + assert SS1.is_controllable() == True + assert SS3.is_controllable() == False + assert SS1.controllability_matrix() == Matrix([[0.5, -0.75], [ 0, 0.5]]) + assert SS3.controllability_matrix() == Matrix([[1, -1, 2, -2], [1, -1, 2, -2]]) + assert SS1.controllable_subspace() == [Matrix([[0.5], [ 0]]), Matrix([[-0.75], [ 0.5]])] + assert SS3.controllable_subspace() == [Matrix([[1], [1]])] + assert SS4.controllable_subspace() == [Matrix([ + [b1], + [b2]]), Matrix([ + [a0*b1 + a1*b2], + [a2*b1 + a3*b2]])] + Qc = SS4.controllability_matrix().subs([(a0, 0), (a1, 1), (a2, -6), (a3, -5), (b1, 0), (b2, 1)]) + assert Qc == Matrix([[0, 1], [1, -5]]) + + # Append + A1 = Matrix([[0, 1], [1, 0]]) + B1 = Matrix([[0], [1]]) + C1 = Matrix([[0, 1]]) + D1 = Matrix([[0]]) + ss1 = StateSpace(A1, B1, C1, D1) + ss2 = StateSpace(Matrix([[1, 0], [0, 1]]), Matrix([[1], [0]]), Matrix([[1, 0]]), Matrix([[1]])) + ss3 = ss1.append(ss2) + ss4 = SS4.append(ss1) + + assert ss3.num_states == ss1.num_states + ss2.num_states + assert ss3.num_inputs == ss1.num_inputs + ss2.num_inputs + assert ss3.num_outputs == ss1.num_outputs + ss2.num_outputs + assert ss3.state_matrix == Matrix([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]) + assert ss3.input_matrix == Matrix([[0, 0], [1, 0], [0, 1], [0, 0]]) + assert ss3.output_matrix == Matrix([[0, 1, 0, 0], [0, 0, 1, 0]]) + assert ss3.feedforward_matrix == Matrix([[0, 0], [0, 1]]) + + # Using symbolic matrices + assert ss4.num_states == SS4.num_states + ss1.num_states + assert ss4.num_inputs == SS4.num_inputs + ss1.num_inputs + assert ss4.num_outputs == SS4.num_outputs + ss1.num_outputs + assert ss4.state_matrix == Matrix([[a0, a1, 0, 0], [a2, a3, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) + assert ss4.input_matrix == Matrix([[b1, 0], [b2, 0], [0, 0], [0, 1]]) + assert ss4.output_matrix == Matrix([[c1, c2, 0, 0], [0, 0, 0, 1]]) + assert ss4.feedforward_matrix == Matrix([[0, 0], [0, 0]]) + + +def test_StateSpace_series(): + # For SISO Systems + a1 = Matrix([[0, 1], [1, 0]]) + b1 = Matrix([[0], [1]]) + c1 = Matrix([[0, 1]]) + d1 = Matrix([[0]]) + a2 = Matrix([[1, 0], [0, 1]]) + b2 = Matrix([[1], [0]]) + c2 = Matrix([[1, 0]]) + d2 = Matrix([[1]]) + + ss1 = StateSpace(a1, b1, c1, d1) + ss2 = StateSpace(a2, b2, c2, d2) + tf1 = TransferFunction(s, s+1, s) + ser1 = Series(ss1, ss2) + assert ser1 == Series(StateSpace(Matrix([ + [0, 1], + [1, 0]]), Matrix([ + [0], + [1]]), Matrix([[0, 1]]), Matrix([[0]])), StateSpace(Matrix([ + [1, 0], + [0, 1]]), Matrix([ + [1], + [0]]), Matrix([[1, 0]]), Matrix([[1]]))) + assert ser1.doit() == StateSpace( + Matrix([ + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 1, 1, 0], + [0, 0, 0, 1]]), + Matrix([ + [0], + [1], + [0], + [0]]), + Matrix([[0, 1, 1, 0]]), + Matrix([[0]])) + + assert ser1.num_inputs == 1 + assert ser1.num_outputs == 1 + assert ser1.rewrite(TransferFunction) == TransferFunction(s**2, s**3 - s**2 - s + 1, s) + ser2 = Series(ss1) + ser3 = Series(ser2, ss2) + assert ser3.doit() == ser1.doit() + + # TransferFunction interconnection with StateSpace + ser_tf = Series(tf1, ss1) + assert ser_tf == Series(TransferFunction(s, s + 1, s), StateSpace(Matrix([ + [0, 1], + [1, 0]]), Matrix([ + [0], + [1]]), Matrix([[0, 1]]), Matrix([[0]]))) + assert ser_tf.doit() == StateSpace( + Matrix([ + [-1, 0, 0], + [0, 0, 1], + [-1, 1, 0]]), + Matrix([ + [1], + [0], + [1]]), + Matrix([[0, 0, 1]]), + Matrix([[0]])) + assert ser_tf.rewrite(TransferFunction) == TransferFunction(s**2, s**3 + s**2 - s - 1, s) + + # For MIMO Systems + a3 = Matrix([[4, 1], [2, -3]]) + b3 = Matrix([[5, 2], [-3, -3]]) + c3 = Matrix([[2, -4], [0, 1]]) + d3 = Matrix([[3, 2], [1, -1]]) + a4 = Matrix([[-3, 4, 2], [-1, -3, 0], [2, 5, 3]]) + b4 = Matrix([[1, 4], [-3, -3], [-2, 1]]) + c4 = Matrix([[4, 2, -3], [1, 4, 3]]) + d4 = Matrix([[-2, 4], [0, 1]]) + ss3 = StateSpace(a3, b3, c3, d3) + ss4 = StateSpace(a4, b4, c4, d4) + ser4 = MIMOSeries(ss3, ss4) + assert ser4 == MIMOSeries(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), StateSpace(Matrix([ + [-3, 4, 2], + [-1, -3, 0], + [ 2, 5, 3]]), Matrix([ + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [4, 2, -3], + [1, 4, 3]]), Matrix([ + [-2, 4], + [ 0, 1]]))) + assert ser4.doit() == StateSpace( + Matrix([ + [4, 1, 0, 0, 0], + [2, -3, 0, 0, 0], + [2, 0, -3, 4, 2], + [-6, 9, -1, -3, 0], + [-4, 9, 2, 5, 3]]), + Matrix([ + [5, 2], + [-3, -3], + [7, -2], + [-12, -3], + [-5, -5]]), + Matrix([ + [-4, 12, 4, 2, -3], + [0, 1, 1, 4, 3]]), + Matrix([ + [-2, -8], + [1, -1]])) + assert ser4.num_inputs == ss3.num_inputs + assert ser4.num_outputs == ss4.num_outputs + ser5 = MIMOSeries(ss3) + ser6 = MIMOSeries(ser5, ss4) + assert ser6.doit() == ser4.doit() + assert ser6.rewrite(TransferFunctionMatrix) == ser4.rewrite(TransferFunctionMatrix) + tf2 = TransferFunction(1, s, s) + tf3 = TransferFunction(1, s+1, s) + tf4 = TransferFunction(s, s+2, s) + tfm = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + ser6 = MIMOSeries(ss3, tfm) + assert ser6 == MIMOSeries(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), TransferFunctionMatrix(( + (TransferFunction(s, s + 1, s), TransferFunction(1, s, s)), + (TransferFunction(1, s + 1, s), TransferFunction(s, s + 2, s))))) + + +def test_StateSpace_parallel(): + # For SISO system + a1 = Matrix([[0, 1], [1, 0]]) + b1 = Matrix([[0], [1]]) + c1 = Matrix([[0, 1]]) + d1 = Matrix([[0]]) + a2 = Matrix([[1, 0], [0, 1]]) + b2 = Matrix([[1], [0]]) + c2 = Matrix([[1, 0]]) + d2 = Matrix([[1]]) + ss1 = StateSpace(a1, b1, c1, d1) + ss2 = StateSpace(a2, b2, c2, d2) + p1 = Parallel(ss1, ss2) + assert p1 == Parallel(StateSpace(Matrix([[0, 1], [1, 0]]), Matrix([[0], [1]]), Matrix([[0, 1]]), Matrix([[0]])), + StateSpace(Matrix([[1, 0],[0, 1]]), Matrix([[1],[0]]), Matrix([[1, 0]]), Matrix([[1]]))) + assert p1.doit() == StateSpace(Matrix([ + [0, 1, 0, 0], + [1, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 1]]), + Matrix([ + [0], + [1], + [1], + [0]]), + Matrix([[0, 1, 1, 0]]), + Matrix([[1]])) + assert p1.rewrite(TransferFunction) == TransferFunction(s*(s + 2), s**2 - 1, s) + + # Connecting StateSpace with TransferFunction + tf1 = TransferFunction(s, s+1, s) + p2 = Parallel(ss1, tf1) + assert p2 == Parallel(StateSpace(Matrix([ + [0, 1], + [1, 0]]), Matrix([ + [0], + [1]]), Matrix([[0, 1]]), Matrix([[0]])), TransferFunction(s, s + 1, s)) + assert p2.doit() == StateSpace( + Matrix([ + [0, 1, 0], + [1, 0, 0], + [0, 0, -1]]), + Matrix([ + [0], + [1], + [1]]), + Matrix([[0, 1, -1]]), + Matrix([[1]])) + assert p2.rewrite(TransferFunction) == TransferFunction(s**2, s**2 - 1, s) + + # For MIMO + a3 = Matrix([[4, 1], [2, -3]]) + b3 = Matrix([[5, 2], [-3, -3]]) + c3 = Matrix([[2, -4], [0, 1]]) + d3 = Matrix([[3, 2], [1, -1]]) + a4 = Matrix([[-3, 4, 2], [-1, -3, 0], [2, 5, 3]]) + b4 = Matrix([[1, 4], [-3, -3], [-2, 1]]) + c4 = Matrix([[4, 2, -3], [1, 4, 3]]) + d4 = Matrix([[-2, 4], [0, 1]]) + ss3 = StateSpace(a3, b3, c3, d3) + ss4 = StateSpace(a4, b4, c4, d4) + p3 = MIMOParallel(ss3, ss4) + assert p3 == MIMOParallel(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), StateSpace(Matrix([ + [-3, 4, 2], + [-1, -3, 0], + [ 2, 5, 3]]), Matrix([ + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [4, 2, -3], + [1, 4, 3]]), Matrix([ + [-2, 4], + [ 0, 1]]))) + assert p3.doit() == StateSpace(Matrix([ + [4, 1, 0, 0, 0], + [2, -3, 0, 0, 0], + [0, 0, -3, 4, 2], + [0, 0, -1, -3, 0], + [0, 0, 2, 5, 3]]), + Matrix([ + [5, 2], + [-3, -3], + [1, 4], + [-3, -3], + [-2, 1]]), + Matrix([ + [2, -4, 4, 2, -3], + [0, 1, 1, 4, 3]]), + Matrix([ + [1, 6], + [1, 0]])) + + # Using StateSpace with MIMOParallel. + tf2 = TransferFunction(1, s, s) + tf3 = TransferFunction(1, s + 1, s) + tf4 = TransferFunction(s, s + 2, s) + tfm = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + p4 = MIMOParallel(tfm, ss3) + assert p4 == MIMOParallel(TransferFunctionMatrix(( + (TransferFunction(s, s + 1, s), TransferFunction(1, s, s)), + (TransferFunction(1, s + 1, s), TransferFunction(s, s + 2, s)))), + StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]]))) + + +def test_StateSpace_feedback(): + # For SISO + a1 = Matrix([[0, 1], [1, 0]]) + b1 = Matrix([[0], [1]]) + c1 = Matrix([[0, 1]]) + d1 = Matrix([[0]]) + a2 = Matrix([[1, 0], [0, 1]]) + b2 = Matrix([[1], [0]]) + c2 = Matrix([[1, 0]]) + d2 = Matrix([[1]]) + ss1 = StateSpace(a1, b1, c1, d1) + ss2 = StateSpace(a2, b2, c2, d2) + fd1 = Feedback(ss1, ss2) + + # Negative feedback + assert fd1 == Feedback(StateSpace(Matrix([[0, 1], [1, 0]]), Matrix([[0], [1]]), Matrix([[0, 1]]), Matrix([[0]])), + StateSpace(Matrix([[1, 0],[0, 1]]), Matrix([[1],[0]]), Matrix([[1, 0]]), Matrix([[1]])), -1) + assert fd1.doit() == StateSpace(Matrix([ + [0, 1, 0, 0], + [1, -1, -1, 0], + [0, 1, 1, 0], + [0, 0, 0, 1]]), Matrix([ + [0], + [1], + [0], + [0]]), Matrix( + [[0, 1, 0, 0]]), Matrix( + [[0]])) + assert fd1.rewrite(TransferFunction) == TransferFunction(s*(s - 1), s**3 - s + 1, s) + + # Positive Feedback + fd2 = Feedback(ss1, ss2, 1) + assert fd2.doit() == StateSpace(Matrix([ + [0, 1, 0, 0], + [1, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 1]]), Matrix([ + [0], + [1], + [0], + [0]]), Matrix( + [[0, 1, 0, 0]]), Matrix( + [[0]])) + assert fd2.rewrite(TransferFunction) == TransferFunction(s*(s - 1), s**3 - 2*s**2 - s + 1, s) + + # Connection with TransferFunction + tf1 = TransferFunction(s, s+1, s) + fd3 = Feedback(ss1, tf1) + assert fd3 == Feedback(StateSpace(Matrix([ + [0, 1], + [1, 0]]), Matrix([ + [0], + [1]]), Matrix([[0, 1]]), Matrix([[0]])), + TransferFunction(s, s + 1, s), -1) + assert fd3.doit() == StateSpace (Matrix([ + [0, 1, 0], + [1, -1, 1], + [0, 1, -1]]), Matrix([ + [0], + [1], + [0]]), Matrix( + [[0, 1, 0]]), Matrix( + [[0]])) + + # For MIMO + a3 = Matrix([[4, 1], [2, -3]]) + b3 = Matrix([[5, 2], [-3, -3]]) + c3 = Matrix([[2, -4], [0, 1]]) + d3 = Matrix([[3, 2], [1, -1]]) + a4 = Matrix([[-3, 4, 2], [-1, -3, 0], [2, 5, 3]]) + b4 = Matrix([[1, 4], [-3, -3], [-2, 1]]) + c4 = Matrix([[4, 2, -3], [1, 4, 3]]) + d4 = Matrix([[-2, 4], [0, 1]]) + ss3 = StateSpace(a3, b3, c3, d3) + ss4 = StateSpace(a4, b4, c4, d4) + + # Negative Feedback + fd4 = MIMOFeedback(ss3, ss4) + assert fd4 == MIMOFeedback(StateSpace(Matrix([ + [4, 1], + [2, -3]]), Matrix([ + [ 5, 2], + [-3, -3]]), Matrix([ + [2, -4], + [0, 1]]), Matrix([ + [3, 2], + [1, -1]])), StateSpace(Matrix([ + [-3, 4, 2], + [-1, -3, 0], + [ 2, 5, 3]]), Matrix([ + [ 1, 4], + [-3, -3], + [-2, 1]]), Matrix([ + [4, 2, -3], + [1, 4, 3]]), Matrix([ + [-2, 4], + [ 0, 1]])), -1) + assert fd4.doit() == StateSpace(Matrix([ + [Rational(3), Rational(-3, 4), Rational(-15, 4), Rational(-37, 2), Rational(-15)], + [Rational(7, 2), Rational(-39, 8), Rational(9, 8), Rational(39, 4), Rational(9)], + [Rational(3), Rational(-41, 4), Rational(-45, 4), Rational(-51, 2), Rational(-19)], + [Rational(-9, 2), Rational(129, 8), Rational(73, 8), Rational(171, 4), Rational(36)], + [Rational(-3, 2), Rational(47, 8), Rational(31, 8), Rational(85, 4), Rational(18)]]), Matrix([ + [Rational(-1, 4), Rational(19, 4)], + [Rational(3, 8), Rational(-21, 8)], + [Rational(1, 4), Rational(29, 4)], + [Rational(3, 8), Rational(-93, 8)], + [Rational(5, 8), Rational(-35, 8)]]), Matrix([ + [Rational(1), Rational(-15, 4), Rational(-7, 4), Rational(-21, 2), Rational(-9)], + [Rational(1, 2), Rational(-13, 8), Rational(-13, 8), Rational(-19, 4), Rational(-3)]]), Matrix([ + [Rational(-1, 4), Rational(11, 4)], + [Rational(1, 8), Rational(9, 8)]])) + + # Positive Feedback + fd5 = MIMOFeedback(ss3, ss4, 1) + assert fd5.doit() == StateSpace(Matrix([ + [Rational(4, 7), Rational(62, 7), Rational(1), Rational(-8), Rational(-69, 7)], + [Rational(32, 7), Rational(-135, 14), Rational(-3, 2), Rational(3), Rational(36, 7)], + [Rational(-10, 7), Rational(41, 7), Rational(-4), Rational(-12), Rational(-97, 7)], + [Rational(12, 7), Rational(-111, 14), Rational(-5, 2), Rational(18), Rational(171, 7)], + [Rational(2, 7), Rational(-29, 14), Rational(-1, 2), Rational(10), Rational(81, 7)]]), Matrix([ + [Rational(6, 7), Rational(-17, 7)], + [Rational(-9, 14), Rational(15, 14)], + [Rational(6, 7), Rational(-31, 7)], + [Rational(-27, 14), Rational(87, 14)], + [Rational(-15, 14), Rational(25, 14)]]), Matrix([ + [Rational(-2, 7), Rational(11, 7), Rational(1), Rational(-4), Rational(-39, 7)], + [Rational(-2, 7), Rational(15, 14), Rational(-1, 2), Rational(-3), Rational(-18, 7)]]), Matrix([ + [Rational(4, 7), Rational(-9, 7)], + [Rational(1, 14), Rational(-11, 14)]])) diff --git a/phivenv/Lib/site-packages/sympy/physics/hep/__init__.py b/phivenv/Lib/site-packages/sympy/physics/hep/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/physics/hep/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/hep/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80448a9312cdf9aea31724709bef05004dd20639 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/hep/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/hep/__pycache__/gamma_matrices.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/hep/__pycache__/gamma_matrices.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..404881e0ed5af8081b2fbbab1f5b8a0adf42b006 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/hep/__pycache__/gamma_matrices.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/hep/gamma_matrices.py b/phivenv/Lib/site-packages/sympy/physics/hep/gamma_matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..40c3d0754438902f304d01c2df354dd09f9ea257 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/hep/gamma_matrices.py @@ -0,0 +1,716 @@ +""" + Module to handle gamma matrices expressed as tensor objects. + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, LorentzIndex + >>> from sympy.tensor.tensor import tensor_indices + >>> i = tensor_indices('i', LorentzIndex) + >>> G(i) + GammaMatrix(i) + + Note that there is already an instance of GammaMatrixHead in four dimensions: + GammaMatrix, which is simply declare as + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix + >>> from sympy.tensor.tensor import tensor_indices + >>> i = tensor_indices('i', LorentzIndex) + >>> GammaMatrix(i) + GammaMatrix(i) + + To access the metric tensor + + >>> LorentzIndex.metric + metric(LorentzIndex,LorentzIndex) + +""" +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.matrices.dense import eye +from sympy.matrices.expressions.trace import trace +from sympy.tensor.tensor import TensorIndexType, TensorIndex,\ + TensMul, TensAdd, tensor_mul, Tensor, TensorHead, TensorSymmetry + + +# DiracSpinorIndex = TensorIndexType('DiracSpinorIndex', dim=4, dummy_name="S") + + +LorentzIndex = TensorIndexType('LorentzIndex', dim=4, dummy_name="L") + + +GammaMatrix = TensorHead("GammaMatrix", [LorentzIndex], + TensorSymmetry.no_symmetry(1), comm=None) + + +def extract_type_tens(expression, component): + """ + Extract from a ``TensExpr`` all tensors with `component`. + + Returns two tensor expressions: + + * the first contains all ``Tensor`` of having `component`. + * the second contains all remaining. + + + """ + if isinstance(expression, Tensor): + sp = [expression] + elif isinstance(expression, TensMul): + sp = expression.args + else: + raise ValueError('wrong type') + + # Collect all gamma matrices of the same dimension + new_expr = S.One + residual_expr = S.One + for i in sp: + if isinstance(i, Tensor) and i.component == component: + new_expr *= i + else: + residual_expr *= i + return new_expr, residual_expr + + +def simplify_gamma_expression(expression): + extracted_expr, residual_expr = extract_type_tens(expression, GammaMatrix) + res_expr = _simplify_single_line(extracted_expr) + return res_expr * residual_expr + + +def simplify_gpgp(ex, sort=True): + """ + simplify products ``G(i)*p(-i)*G(j)*p(-j) -> p(i)*p(-i)`` + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, \ + LorentzIndex, simplify_gpgp + >>> from sympy.tensor.tensor import tensor_indices, tensor_heads + >>> p, q = tensor_heads('p, q', [LorentzIndex]) + >>> i0,i1,i2,i3,i4,i5 = tensor_indices('i0:6', LorentzIndex) + >>> ps = p(i0)*G(-i0) + >>> qs = q(i0)*G(-i0) + >>> simplify_gpgp(ps*qs*qs) + GammaMatrix(-L_0)*p(L_0)*q(L_1)*q(-L_1) + """ + def _simplify_gpgp(ex): + components = ex.components + a = [] + comp_map = [] + for i, comp in enumerate(components): + comp_map.extend([i]*comp.rank) + dum = [(i[0], i[1], comp_map[i[0]], comp_map[i[1]]) for i in ex.dum] + for i in range(len(components)): + if components[i] != GammaMatrix: + continue + for dx in dum: + if dx[2] == i: + p_pos1 = dx[3] + elif dx[3] == i: + p_pos1 = dx[2] + else: + continue + comp1 = components[p_pos1] + if comp1.comm == 0 and comp1.rank == 1: + a.append((i, p_pos1)) + if not a: + return ex + elim = set() + tv = [] + hit = True + coeff = S.One + ta = None + while hit: + hit = False + for i, ai in enumerate(a[:-1]): + if ai[0] in elim: + continue + if ai[0] != a[i + 1][0] - 1: + continue + if components[ai[1]] != components[a[i + 1][1]]: + continue + elim.add(ai[0]) + elim.add(ai[1]) + elim.add(a[i + 1][0]) + elim.add(a[i + 1][1]) + if not ta: + ta = ex.split() + mu = TensorIndex('mu', LorentzIndex) + hit = True + if i == 0: + coeff = ex.coeff + tx = components[ai[1]](mu)*components[ai[1]](-mu) + if len(a) == 2: + tx *= 4 # eye(4) + tv.append(tx) + break + + if tv: + a = [x for j, x in enumerate(ta) if j not in elim] + a.extend(tv) + t = tensor_mul(*a)*coeff + # t = t.replace(lambda x: x.is_Matrix, lambda x: 1) + return t + else: + return ex + + if sort: + ex = ex.sorted_components() + # this would be better off with pattern matching + while 1: + t = _simplify_gpgp(ex) + if t != ex: + ex = t + else: + return t + + +def gamma_trace(t): + """ + trace of a single line of gamma matrices + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, \ + gamma_trace, LorentzIndex + >>> from sympy.tensor.tensor import tensor_indices, tensor_heads + >>> p, q = tensor_heads('p, q', [LorentzIndex]) + >>> i0,i1,i2,i3,i4,i5 = tensor_indices('i0:6', LorentzIndex) + >>> ps = p(i0)*G(-i0) + >>> qs = q(i0)*G(-i0) + >>> gamma_trace(G(i0)*G(i1)) + 4*metric(i0, i1) + >>> gamma_trace(ps*ps) - 4*p(i0)*p(-i0) + 0 + >>> gamma_trace(ps*qs + ps*ps) - 4*p(i0)*p(-i0) - 4*p(i0)*q(-i0) + 0 + + """ + if isinstance(t, TensAdd): + res = TensAdd(*[gamma_trace(x) for x in t.args]) + return res + t = _simplify_single_line(t) + res = _trace_single_line(t) + return res + + +def _simplify_single_line(expression): + """ + Simplify single-line product of gamma matrices. + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, \ + LorentzIndex, _simplify_single_line + >>> from sympy.tensor.tensor import tensor_indices, TensorHead + >>> p = TensorHead('p', [LorentzIndex]) + >>> i0,i1 = tensor_indices('i0:2', LorentzIndex) + >>> _simplify_single_line(G(i0)*G(i1)*p(-i1)*G(-i0)) + 2*G(i0)*p(-i0) + 0 + + """ + t1, t2 = extract_type_tens(expression, GammaMatrix) + if t1 != 1: + t1 = kahane_simplify(t1) + res = t1*t2 + return res + + +def _trace_single_line(t): + """ + Evaluate the trace of a single gamma matrix line inside a ``TensExpr``. + + Notes + ===== + + If there are ``DiracSpinorIndex.auto_left`` and ``DiracSpinorIndex.auto_right`` + indices trace over them; otherwise traces are not implied (explain) + + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, \ + LorentzIndex, _trace_single_line + >>> from sympy.tensor.tensor import tensor_indices, TensorHead + >>> p = TensorHead('p', [LorentzIndex]) + >>> i0,i1,i2,i3,i4,i5 = tensor_indices('i0:6', LorentzIndex) + >>> _trace_single_line(G(i0)*G(i1)) + 4*metric(i0, i1) + >>> _trace_single_line(G(i0)*p(-i0)*G(i1)*p(-i1)) - 4*p(i0)*p(-i0) + 0 + + """ + def _trace_single_line1(t): + t = t.sorted_components() + components = t.components + ncomps = len(components) + g = LorentzIndex.metric + # gamma matirices are in a[i:j] + hit = 0 + for i in range(ncomps): + if components[i] == GammaMatrix: + hit = 1 + break + + for j in range(i + hit, ncomps): + if components[j] != GammaMatrix: + break + else: + j = ncomps + numG = j - i + if numG == 0: + tcoeff = t.coeff + return t.nocoeff if tcoeff else t + if numG % 2 == 1: + return TensMul.from_data(S.Zero, [], [], []) + elif numG > 4: + # find the open matrix indices and connect them: + a = t.split() + ind1 = a[i].get_indices()[0] + ind2 = a[i + 1].get_indices()[0] + aa = a[:i] + a[i + 2:] + t1 = tensor_mul(*aa)*g(ind1, ind2) + t1 = t1.contract_metric(g) + args = [t1] + sign = 1 + for k in range(i + 2, j): + sign = -sign + ind2 = a[k].get_indices()[0] + aa = a[:i] + a[i + 1:k] + a[k + 1:] + t2 = sign*tensor_mul(*aa)*g(ind1, ind2) + t2 = t2.contract_metric(g) + t2 = simplify_gpgp(t2, False) + args.append(t2) + t3 = TensAdd(*args) + t3 = _trace_single_line(t3) + return t3 + else: + a = t.split() + t1 = _gamma_trace1(*a[i:j]) + a2 = a[:i] + a[j:] + t2 = tensor_mul(*a2) + t3 = t1*t2 + if not t3: + return t3 + t3 = t3.contract_metric(g) + return t3 + + t = t.expand() + if isinstance(t, TensAdd): + a = [_trace_single_line1(x)*x.coeff for x in t.args] + return TensAdd(*a) + elif isinstance(t, (Tensor, TensMul)): + r = t.coeff*_trace_single_line1(t) + return r + else: + return trace(t) + + +def _gamma_trace1(*a): + gctr = 4 # FIXME specific for d=4 + g = LorentzIndex.metric + if not a: + return gctr + n = len(a) + if n%2 == 1: + #return TensMul.from_data(S.Zero, [], [], []) + return S.Zero + if n == 2: + ind0 = a[0].get_indices()[0] + ind1 = a[1].get_indices()[0] + return gctr*g(ind0, ind1) + if n == 4: + ind0 = a[0].get_indices()[0] + ind1 = a[1].get_indices()[0] + ind2 = a[2].get_indices()[0] + ind3 = a[3].get_indices()[0] + + return gctr*(g(ind0, ind1)*g(ind2, ind3) - \ + g(ind0, ind2)*g(ind1, ind3) + g(ind0, ind3)*g(ind1, ind2)) + + +def kahane_simplify(expression): + r""" + This function cancels contracted elements in a product of four + dimensional gamma matrices, resulting in an expression equal to the given + one, without the contracted gamma matrices. + + Parameters + ========== + + `expression` the tensor expression containing the gamma matrices to simplify. + + Notes + ===== + + If spinor indices are given, the matrices must be given in + the order given in the product. + + Algorithm + ========= + + The idea behind the algorithm is to use some well-known identities, + i.e., for contractions enclosing an even number of `\gamma` matrices + + `\gamma^\mu \gamma_{a_1} \cdots \gamma_{a_{2N}} \gamma_\mu = 2 (\gamma_{a_{2N}} \gamma_{a_1} \cdots \gamma_{a_{2N-1}} + \gamma_{a_{2N-1}} \cdots \gamma_{a_1} \gamma_{a_{2N}} )` + + for an odd number of `\gamma` matrices + + `\gamma^\mu \gamma_{a_1} \cdots \gamma_{a_{2N+1}} \gamma_\mu = -2 \gamma_{a_{2N+1}} \gamma_{a_{2N}} \cdots \gamma_{a_{1}}` + + Instead of repeatedly applying these identities to cancel out all contracted indices, + it is possible to recognize the links that would result from such an operation, + the problem is thus reduced to a simple rearrangement of free gamma matrices. + + Examples + ======== + + When using, always remember that the original expression coefficient + has to be handled separately + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, LorentzIndex + >>> from sympy.physics.hep.gamma_matrices import kahane_simplify + >>> from sympy.tensor.tensor import tensor_indices + >>> i0, i1, i2 = tensor_indices('i0:3', LorentzIndex) + >>> ta = G(i0)*G(-i0) + >>> kahane_simplify(ta) + Matrix([ + [4, 0, 0, 0], + [0, 4, 0, 0], + [0, 0, 4, 0], + [0, 0, 0, 4]]) + >>> tb = G(i0)*G(i1)*G(-i0) + >>> kahane_simplify(tb) + -2*GammaMatrix(i1) + >>> t = G(i0)*G(-i0) + >>> kahane_simplify(t) + Matrix([ + [4, 0, 0, 0], + [0, 4, 0, 0], + [0, 0, 4, 0], + [0, 0, 0, 4]]) + >>> t = G(i0)*G(-i0) + >>> kahane_simplify(t) + Matrix([ + [4, 0, 0, 0], + [0, 4, 0, 0], + [0, 0, 4, 0], + [0, 0, 0, 4]]) + + If there are no contractions, the same expression is returned + + >>> tc = G(i0)*G(i1) + >>> kahane_simplify(tc) + GammaMatrix(i0)*GammaMatrix(i1) + + References + ========== + + [1] Algorithm for Reducing Contracted Products of gamma Matrices, + Joseph Kahane, Journal of Mathematical Physics, Vol. 9, No. 10, October 1968. + """ + + if isinstance(expression, Mul): + return expression + if isinstance(expression, TensAdd): + return TensAdd(*[kahane_simplify(arg) for arg in expression.args]) + + if isinstance(expression, Tensor): + return expression + + assert isinstance(expression, TensMul) + + gammas = expression.args + + for gamma in gammas: + assert gamma.component == GammaMatrix + + free = expression.free + # spinor_free = [_ for _ in expression.free_in_args if _[1] != 0] + + # if len(spinor_free) == 2: + # spinor_free.sort(key=lambda x: x[2]) + # assert spinor_free[0][1] == 1 and spinor_free[-1][1] == 2 + # assert spinor_free[0][2] == 0 + # elif spinor_free: + # raise ValueError('spinor indices do not match') + + dum = [] + for dum_pair in expression.dum: + if expression.index_types[dum_pair[0]] == LorentzIndex: + dum.append((dum_pair[0], dum_pair[1])) + + dum = sorted(dum) + + if len(dum) == 0: # or GammaMatrixHead: + # no contractions in `expression`, just return it. + return expression + + # find the `first_dum_pos`, i.e. the position of the first contracted + # gamma matrix, Kahane's algorithm as described in his paper requires the + # gamma matrix expression to start with a contracted gamma matrix, this is + # a workaround which ignores possible initial free indices, and re-adds + # them later. + + first_dum_pos = min(map(min, dum)) + + # for p1, p2, a1, a2 in expression.dum_in_args: + # if p1 != 0 or p2 != 0: + # # only Lorentz indices, skip Dirac indices: + # continue + # first_dum_pos = min(p1, p2) + # break + + total_number = len(free) + len(dum)*2 + number_of_contractions = len(dum) + + free_pos = [None]*total_number + for i in free: + free_pos[i[1]] = i[0] + + # `index_is_free` is a list of booleans, to identify index position + # and whether that index is free or dummy. + index_is_free = [False]*total_number + + for i, indx in enumerate(free): + index_is_free[indx[1]] = True + + # `links` is a dictionary containing the graph described in Kahane's paper, + # to every key correspond one or two values, representing the linked indices. + # All values in `links` are integers, negative numbers are used in the case + # where it is necessary to insert gamma matrices between free indices, in + # order to make Kahane's algorithm work (see paper). + links = {i: [] for i in range(first_dum_pos, total_number)} + + # `cum_sign` is a step variable to mark the sign of every index, see paper. + cum_sign = -1 + # `cum_sign_list` keeps storage for all `cum_sign` (every index). + cum_sign_list = [None]*total_number + block_free_count = 0 + + # multiply `resulting_coeff` by the coefficient parameter, the rest + # of the algorithm ignores a scalar coefficient. + resulting_coeff = S.One + + # initialize a list of lists of indices. The outer list will contain all + # additive tensor expressions, while the inner list will contain the + # free indices (rearranged according to the algorithm). + resulting_indices = [[]] + + # start to count the `connected_components`, which together with the number + # of contractions, determines a -1 or +1 factor to be multiplied. + connected_components = 1 + + # First loop: here we fill `cum_sign_list`, and draw the links + # among consecutive indices (they are stored in `links`). Links among + # non-consecutive indices will be drawn later. + for i, is_free in enumerate(index_is_free): + # if `expression` starts with free indices, they are ignored here; + # they are later added as they are to the beginning of all + # `resulting_indices` list of lists of indices. + if i < first_dum_pos: + continue + + if is_free: + block_free_count += 1 + # if previous index was free as well, draw an arch in `links`. + if block_free_count > 1: + links[i - 1].append(i) + links[i].append(i - 1) + else: + # Change the sign of the index (`cum_sign`) if the number of free + # indices preceding it is even. + cum_sign *= 1 if (block_free_count % 2) else -1 + if block_free_count == 0 and i != first_dum_pos: + # check if there are two consecutive dummy indices: + # in this case create virtual indices with negative position, + # these "virtual" indices represent the insertion of two + # gamma^0 matrices to separate consecutive dummy indices, as + # Kahane's algorithm requires dummy indices to be separated by + # free indices. The product of two gamma^0 matrices is unity, + # so the new expression being examined is the same as the + # original one. + if cum_sign == -1: + links[-1-i] = [-1-i+1] + links[-1-i+1] = [-1-i] + if (i - cum_sign) in links: + if i != first_dum_pos: + links[i].append(i - cum_sign) + if block_free_count != 0: + if i - cum_sign < len(index_is_free): + if index_is_free[i - cum_sign]: + links[i - cum_sign].append(i) + block_free_count = 0 + + cum_sign_list[i] = cum_sign + + # The previous loop has only created links between consecutive free indices, + # it is necessary to properly create links among dummy (contracted) indices, + # according to the rules described in Kahane's paper. There is only one exception + # to Kahane's rules: the negative indices, which handle the case of some + # consecutive free indices (Kahane's paper just describes dummy indices + # separated by free indices, hinting that free indices can be added without + # altering the expression result). + for i in dum: + # get the positions of the two contracted indices: + pos1 = i[0] + pos2 = i[1] + + # create Kahane's upper links, i.e. the upper arcs between dummy + # (i.e. contracted) indices: + links[pos1].append(pos2) + links[pos2].append(pos1) + + # create Kahane's lower links, this corresponds to the arcs below + # the line described in the paper: + + # first we move `pos1` and `pos2` according to the sign of the indices: + linkpos1 = pos1 + cum_sign_list[pos1] + linkpos2 = pos2 + cum_sign_list[pos2] + + # otherwise, perform some checks before creating the lower arcs: + + # make sure we are not exceeding the total number of indices: + if linkpos1 >= total_number: + continue + if linkpos2 >= total_number: + continue + + # make sure we are not below the first dummy index in `expression`: + if linkpos1 < first_dum_pos: + continue + if linkpos2 < first_dum_pos: + continue + + # check if the previous loop created "virtual" indices between dummy + # indices, in such a case relink `linkpos1` and `linkpos2`: + if (-1-linkpos1) in links: + linkpos1 = -1-linkpos1 + if (-1-linkpos2) in links: + linkpos2 = -1-linkpos2 + + # move only if not next to free index: + if linkpos1 >= 0 and not index_is_free[linkpos1]: + linkpos1 = pos1 + + if linkpos2 >=0 and not index_is_free[linkpos2]: + linkpos2 = pos2 + + # create the lower arcs: + if linkpos2 not in links[linkpos1]: + links[linkpos1].append(linkpos2) + if linkpos1 not in links[linkpos2]: + links[linkpos2].append(linkpos1) + + # This loop starts from the `first_dum_pos` index (first dummy index) + # walks through the graph deleting the visited indices from `links`, + # it adds a gamma matrix for every free index in encounters, while it + # completely ignores dummy indices and virtual indices. + pointer = first_dum_pos + previous_pointer = 0 + while True: + if pointer in links: + next_ones = links.pop(pointer) + else: + break + + if previous_pointer in next_ones: + next_ones.remove(previous_pointer) + + previous_pointer = pointer + + if next_ones: + pointer = next_ones[0] + else: + break + + if pointer == previous_pointer: + break + if pointer >=0 and free_pos[pointer] is not None: + for ri in resulting_indices: + ri.append(free_pos[pointer]) + + # The following loop removes the remaining connected components in `links`. + # If there are free indices inside a connected component, it gives a + # contribution to the resulting expression given by the factor + # `gamma_a gamma_b ... gamma_z + gamma_z ... gamma_b gamma_a`, in Kahanes's + # paper represented as {gamma_a, gamma_b, ... , gamma_z}, + # virtual indices are ignored. The variable `connected_components` is + # increased by one for every connected component this loop encounters. + + # If the connected component has virtual and dummy indices only + # (no free indices), it contributes to `resulting_indices` by a factor of two. + # The multiplication by two is a result of the + # factor {gamma^0, gamma^0} = 2 I, as it appears in Kahane's paper. + # Note: curly brackets are meant as in the paper, as a generalized + # multi-element anticommutator! + + while links: + connected_components += 1 + pointer = min(links.keys()) + previous_pointer = pointer + # the inner loop erases the visited indices from `links`, and it adds + # all free indices to `prepend_indices` list, virtual indices are + # ignored. + prepend_indices = [] + while True: + if pointer in links: + next_ones = links.pop(pointer) + else: + break + + if previous_pointer in next_ones: + if len(next_ones) > 1: + next_ones.remove(previous_pointer) + + previous_pointer = pointer + + if next_ones: + pointer = next_ones[0] + + if pointer >= first_dum_pos and free_pos[pointer] is not None: + prepend_indices.insert(0, free_pos[pointer]) + # if `prepend_indices` is void, it means there are no free indices + # in the loop (and it can be shown that there must be a virtual index), + # loops of virtual indices only contribute by a factor of two: + if len(prepend_indices) == 0: + resulting_coeff *= 2 + # otherwise, add the free indices in `prepend_indices` to + # the `resulting_indices`: + else: + expr1 = prepend_indices + expr2 = list(reversed(prepend_indices)) + resulting_indices = [expri + ri for ri in resulting_indices for expri in (expr1, expr2)] + + # sign correction, as described in Kahane's paper: + resulting_coeff *= -1 if (number_of_contractions - connected_components + 1) % 2 else 1 + # power of two factor, as described in Kahane's paper: + resulting_coeff *= 2**(number_of_contractions) + + # If `first_dum_pos` is not zero, it means that there are trailing free gamma + # matrices in front of `expression`, so multiply by them: + resulting_indices = [ free_pos[0:first_dum_pos] + ri for ri in resulting_indices ] + + resulting_expr = S.Zero + for i in resulting_indices: + temp_expr = S.One + for j in i: + temp_expr *= GammaMatrix(j) + resulting_expr += temp_expr + + t = resulting_coeff * resulting_expr + t1 = None + if isinstance(t, TensAdd): + t1 = t.args[0] + elif isinstance(t, TensMul): + t1 = t + if t1: + pass + else: + t = eye(4)*t + return t diff --git a/phivenv/Lib/site-packages/sympy/physics/hep/tests/__init__.py b/phivenv/Lib/site-packages/sympy/physics/hep/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/physics/hep/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/hep/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e17a4dc67e2fc732b296e1395fc1a9e216153a0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/hep/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/hep/tests/__pycache__/test_gamma_matrices.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/hep/tests/__pycache__/test_gamma_matrices.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f12e59d9b5c75ebd60733ec48e78c470babb8f35 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/hep/tests/__pycache__/test_gamma_matrices.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/hep/tests/test_gamma_matrices.py b/phivenv/Lib/site-packages/sympy/physics/hep/tests/test_gamma_matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..1552cf0d19be222ba249a7e32c65c8c3abc54ac2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/hep/tests/test_gamma_matrices.py @@ -0,0 +1,427 @@ +from sympy.matrices.dense import eye, Matrix +from sympy.tensor.tensor import tensor_indices, TensorHead, tensor_heads, \ + TensExpr, canon_bp +from sympy.physics.hep.gamma_matrices import GammaMatrix as G, LorentzIndex, \ + kahane_simplify, gamma_trace, _simplify_single_line, simplify_gamma_expression +from sympy import Symbol + + +def _is_tensor_eq(arg1, arg2): + arg1 = canon_bp(arg1) + arg2 = canon_bp(arg2) + if isinstance(arg1, TensExpr): + return arg1.equals(arg2) + elif isinstance(arg2, TensExpr): + return arg2.equals(arg1) + return arg1 == arg2 + +def execute_gamma_simplify_tests_for_function(tfunc, D): + """ + Perform tests to check if sfunc is able to simplify gamma matrix expressions. + + Parameters + ========== + + `sfunc` a function to simplify a `TIDS`, shall return the simplified `TIDS`. + `D` the number of dimension (in most cases `D=4`). + + """ + + mu, nu, rho, sigma = tensor_indices("mu, nu, rho, sigma", LorentzIndex) + a1, a2, a3, a4, a5, a6 = tensor_indices("a1:7", LorentzIndex) + mu11, mu12, mu21, mu31, mu32, mu41, mu51, mu52 = tensor_indices("mu11, mu12, mu21, mu31, mu32, mu41, mu51, mu52", LorentzIndex) + mu61, mu71, mu72 = tensor_indices("mu61, mu71, mu72", LorentzIndex) + m0, m1, m2, m3, m4, m5, m6 = tensor_indices("m0:7", LorentzIndex) + + def g(xx, yy): + return (G(xx)*G(yy) + G(yy)*G(xx))/2 + + # Some examples taken from Kahane's paper, 4 dim only: + if D == 4: + t = (G(a1)*G(mu11)*G(a2)*G(mu21)*G(-a1)*G(mu31)*G(-a2)) + assert _is_tensor_eq(tfunc(t), -4*G(mu11)*G(mu31)*G(mu21) - 4*G(mu31)*G(mu11)*G(mu21)) + + t = (G(a1)*G(mu11)*G(mu12)*\ + G(a2)*G(mu21)*\ + G(a3)*G(mu31)*G(mu32)*\ + G(a4)*G(mu41)*\ + G(-a2)*G(mu51)*G(mu52)*\ + G(-a1)*G(mu61)*\ + G(-a3)*G(mu71)*G(mu72)*\ + G(-a4)) + assert _is_tensor_eq(tfunc(t), \ + 16*G(mu31)*G(mu32)*G(mu72)*G(mu71)*G(mu11)*G(mu52)*G(mu51)*G(mu12)*G(mu61)*G(mu21)*G(mu41) + 16*G(mu31)*G(mu32)*G(mu72)*G(mu71)*G(mu12)*G(mu51)*G(mu52)*G(mu11)*G(mu61)*G(mu21)*G(mu41) + 16*G(mu71)*G(mu72)*G(mu32)*G(mu31)*G(mu11)*G(mu52)*G(mu51)*G(mu12)*G(mu61)*G(mu21)*G(mu41) + 16*G(mu71)*G(mu72)*G(mu32)*G(mu31)*G(mu12)*G(mu51)*G(mu52)*G(mu11)*G(mu61)*G(mu21)*G(mu41)) + + # Fully Lorentz-contracted expressions, these return scalars: + + def add_delta(ne): + return ne * eye(4) # DiracSpinorIndex.delta(DiracSpinorIndex.auto_left, -DiracSpinorIndex.auto_right) + + t = (G(mu)*G(-mu)) + ts = add_delta(D) + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(mu)*G(nu)*G(-mu)*G(-nu)) + ts = add_delta(2*D - D**2) # -8 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(mu)*G(nu)*G(-nu)*G(-mu)) + ts = add_delta(D**2) # 16 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(mu)*G(nu)*G(-rho)*G(-nu)*G(-mu)*G(rho)) + ts = add_delta(4*D - 4*D**2 + D**3) # 16 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(mu)*G(nu)*G(rho)*G(-rho)*G(-nu)*G(-mu)) + ts = add_delta(D**3) # 64 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(a1)*G(a2)*G(a3)*G(a4)*G(-a3)*G(-a1)*G(-a2)*G(-a4)) + ts = add_delta(-8*D + 16*D**2 - 8*D**3 + D**4) # -32 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(-mu)*G(-nu)*G(-rho)*G(-sigma)*G(nu)*G(mu)*G(sigma)*G(rho)) + ts = add_delta(-16*D + 24*D**2 - 8*D**3 + D**4) # 64 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(-mu)*G(nu)*G(-rho)*G(sigma)*G(rho)*G(-nu)*G(mu)*G(-sigma)) + ts = add_delta(8*D - 12*D**2 + 6*D**3 - D**4) # -32 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(a1)*G(a2)*G(a3)*G(a4)*G(a5)*G(-a3)*G(-a2)*G(-a1)*G(-a5)*G(-a4)) + ts = add_delta(64*D - 112*D**2 + 60*D**3 - 12*D**4 + D**5) # 256 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(a1)*G(a2)*G(a3)*G(a4)*G(a5)*G(-a3)*G(-a1)*G(-a2)*G(-a4)*G(-a5)) + ts = add_delta(64*D - 120*D**2 + 72*D**3 - 16*D**4 + D**5) # -128 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(a1)*G(a2)*G(a3)*G(a4)*G(a5)*G(a6)*G(-a3)*G(-a2)*G(-a1)*G(-a6)*G(-a5)*G(-a4)) + ts = add_delta(416*D - 816*D**2 + 528*D**3 - 144*D**4 + 18*D**5 - D**6) # -128 + assert _is_tensor_eq(tfunc(t), ts) + + t = (G(a1)*G(a2)*G(a3)*G(a4)*G(a5)*G(a6)*G(-a2)*G(-a3)*G(-a1)*G(-a6)*G(-a4)*G(-a5)) + ts = add_delta(416*D - 848*D**2 + 584*D**3 - 172*D**4 + 22*D**5 - D**6) # -128 + assert _is_tensor_eq(tfunc(t), ts) + + # Expressions with free indices: + + t = (G(mu)*G(nu)*G(rho)*G(sigma)*G(-mu)) + assert _is_tensor_eq(tfunc(t), (-2*G(sigma)*G(rho)*G(nu) + (4-D)*G(nu)*G(rho)*G(sigma))) + + t = (G(mu)*G(nu)*G(-mu)) + assert _is_tensor_eq(tfunc(t), (2-D)*G(nu)) + + t = (G(mu)*G(nu)*G(rho)*G(-mu)) + assert _is_tensor_eq(tfunc(t), 2*G(nu)*G(rho) + 2*G(rho)*G(nu) - (4-D)*G(nu)*G(rho)) + + t = 2*G(m2)*G(m0)*G(m1)*G(-m0)*G(-m1) + st = tfunc(t) + assert _is_tensor_eq(st, (D*(-2*D + 4))*G(m2)) + + t = G(m2)*G(m0)*G(m1)*G(-m0)*G(-m2) + st = tfunc(t) + assert _is_tensor_eq(st, ((-D + 2)**2)*G(m1)) + + t = G(m0)*G(m1)*G(m2)*G(m3)*G(-m1) + st = tfunc(t) + assert _is_tensor_eq(st, (D - 4)*G(m0)*G(m2)*G(m3) + 4*G(m0)*g(m2, m3)) + + t = G(m0)*G(m1)*G(m2)*G(m3)*G(-m1)*G(-m0) + st = tfunc(t) + assert _is_tensor_eq(st, ((D - 4)**2)*G(m2)*G(m3) + (8*D - 16)*g(m2, m3)) + + t = G(m2)*G(m0)*G(m1)*G(-m2)*G(-m0) + st = tfunc(t) + assert _is_tensor_eq(st, ((-D + 2)*(D - 4) + 4)*G(m1)) + + t = G(m3)*G(m1)*G(m0)*G(m2)*G(-m3)*G(-m0)*G(-m2) + st = tfunc(t) + assert _is_tensor_eq(st, (-4*D + (-D + 2)**2*(D - 4) + 8)*G(m1)) + + t = 2*G(m0)*G(m1)*G(m2)*G(m3)*G(-m0) + st = tfunc(t) + assert _is_tensor_eq(st, ((-2*D + 8)*G(m1)*G(m2)*G(m3) - 4*G(m3)*G(m2)*G(m1))) + + t = G(m5)*G(m0)*G(m1)*G(m4)*G(m2)*G(-m4)*G(m3)*G(-m0) + st = tfunc(t) + assert _is_tensor_eq(st, (((-D + 2)*(-D + 4))*G(m5)*G(m1)*G(m2)*G(m3) + (2*D - 4)*G(m5)*G(m3)*G(m2)*G(m1))) + + t = -G(m0)*G(m1)*G(m2)*G(m3)*G(-m0)*G(m4) + st = tfunc(t) + assert _is_tensor_eq(st, ((D - 4)*G(m1)*G(m2)*G(m3)*G(m4) + 2*G(m3)*G(m2)*G(m1)*G(m4))) + + t = G(-m5)*G(m0)*G(m1)*G(m2)*G(m3)*G(m4)*G(-m0)*G(m5) + st = tfunc(t) + + result1 = ((-D + 4)**2 + 4)*G(m1)*G(m2)*G(m3)*G(m4) +\ + (4*D - 16)*G(m3)*G(m2)*G(m1)*G(m4) + (4*D - 16)*G(m4)*G(m1)*G(m2)*G(m3)\ + + 4*G(m2)*G(m1)*G(m4)*G(m3) + 4*G(m3)*G(m4)*G(m1)*G(m2) +\ + 4*G(m4)*G(m3)*G(m2)*G(m1) + + # Kahane's algorithm yields this result, which is equivalent to `result1` + # in four dimensions, but is not automatically recognized as equal: + result2 = 8*G(m1)*G(m2)*G(m3)*G(m4) + 8*G(m4)*G(m3)*G(m2)*G(m1) + + if D == 4: + assert _is_tensor_eq(st, (result1)) or _is_tensor_eq(st, (result2)) + else: + assert _is_tensor_eq(st, (result1)) + + # and a few very simple cases, with no contracted indices: + + t = G(m0) + st = tfunc(t) + assert _is_tensor_eq(st, t) + + t = -7*G(m0) + st = tfunc(t) + assert _is_tensor_eq(st, t) + + t = 224*G(m0)*G(m1)*G(-m2)*G(m3) + st = tfunc(t) + assert _is_tensor_eq(st, t) + + +def test_kahane_algorithm(): + # Wrap this function to convert to and from TIDS: + + def tfunc(e): + return _simplify_single_line(e) + + execute_gamma_simplify_tests_for_function(tfunc, D=4) + + +def test_kahane_simplify1(): + i0,i1,i2,i3,i4,i5,i6,i7,i8,i9,i10,i11,i12,i13,i14,i15 = tensor_indices('i0:16', LorentzIndex) + mu, nu, rho, sigma = tensor_indices("mu, nu, rho, sigma", LorentzIndex) + D = 4 + t = G(i0)*G(i1) + r = kahane_simplify(t) + assert r.equals(t) + + t = G(i0)*G(i1)*G(-i0) + r = kahane_simplify(t) + assert r.equals(-2*G(i1)) + t = G(i0)*G(i1)*G(-i0) + r = kahane_simplify(t) + assert r.equals(-2*G(i1)) + + t = G(i0)*G(i1) + r = kahane_simplify(t) + assert r.equals(t) + t = G(i0)*G(i1) + r = kahane_simplify(t) + assert r.equals(t) + t = G(i0)*G(-i0) + r = kahane_simplify(t) + assert r.equals(4*eye(4)) + t = G(i0)*G(-i0) + r = kahane_simplify(t) + assert r.equals(4*eye(4)) + t = G(i0)*G(-i0) + r = kahane_simplify(t) + assert r.equals(4*eye(4)) + t = G(i0)*G(i1)*G(-i0) + r = kahane_simplify(t) + assert r.equals(-2*G(i1)) + t = G(i0)*G(i1)*G(-i0)*G(-i1) + r = kahane_simplify(t) + assert r.equals((2*D - D**2)*eye(4)) + t = G(i0)*G(i1)*G(-i0)*G(-i1) + r = kahane_simplify(t) + assert r.equals((2*D - D**2)*eye(4)) + t = G(i0)*G(-i0)*G(i1)*G(-i1) + r = kahane_simplify(t) + assert r.equals(16*eye(4)) + t = (G(mu)*G(nu)*G(-nu)*G(-mu)) + r = kahane_simplify(t) + assert r.equals(D**2*eye(4)) + t = (G(mu)*G(nu)*G(-nu)*G(-mu)) + r = kahane_simplify(t) + assert r.equals(D**2*eye(4)) + t = (G(mu)*G(nu)*G(-nu)*G(-mu)) + r = kahane_simplify(t) + assert r.equals(D**2*eye(4)) + t = (G(mu)*G(nu)*G(-rho)*G(-nu)*G(-mu)*G(rho)) + r = kahane_simplify(t) + assert r.equals((4*D - 4*D**2 + D**3)*eye(4)) + t = (G(-mu)*G(-nu)*G(-rho)*G(-sigma)*G(nu)*G(mu)*G(sigma)*G(rho)) + r = kahane_simplify(t) + assert r.equals((-16*D + 24*D**2 - 8*D**3 + D**4)*eye(4)) + t = (G(-mu)*G(nu)*G(-rho)*G(sigma)*G(rho)*G(-nu)*G(mu)*G(-sigma)) + r = kahane_simplify(t) + assert r.equals((8*D - 12*D**2 + 6*D**3 - D**4)*eye(4)) + + # Expressions with free indices: + t = (G(mu)*G(nu)*G(rho)*G(sigma)*G(-mu)) + r = kahane_simplify(t) + assert r.equals(-2*G(sigma)*G(rho)*G(nu)) + t = (G(mu)*G(-mu)*G(rho)*G(sigma)) + r = kahane_simplify(t) + assert r.equals(4*G(rho)*G(sigma)) + t = (G(rho)*G(sigma)*G(mu)*G(-mu)) + r = kahane_simplify(t) + assert r.equals(4*G(rho)*G(sigma)) + +def test_gamma_matrix_class(): + i, j, k = tensor_indices('i,j,k', LorentzIndex) + + # define another type of TensorHead to see if exprs are correctly handled: + A = TensorHead('A', [LorentzIndex]) + + t = A(k)*G(i)*G(-i) + ts = simplify_gamma_expression(t) + assert _is_tensor_eq(ts, Matrix([ + [4, 0, 0, 0], + [0, 4, 0, 0], + [0, 0, 4, 0], + [0, 0, 0, 4]])*A(k)) + + t = G(i)*A(k)*G(j) + ts = simplify_gamma_expression(t) + assert _is_tensor_eq(ts, A(k)*G(i)*G(j)) + + execute_gamma_simplify_tests_for_function(simplify_gamma_expression, D=4) + + +def test_gamma_matrix_trace(): + g = LorentzIndex.metric + + m0, m1, m2, m3, m4, m5, m6 = tensor_indices('m0:7', LorentzIndex) + n0, n1, n2, n3, n4, n5 = tensor_indices('n0:6', LorentzIndex) + + # working in D=4 dimensions + D = 4 + + # traces of odd number of gamma matrices are zero: + t = G(m0) + t1 = gamma_trace(t) + assert t1.equals(0) + + t = G(m0)*G(m1)*G(m2) + t1 = gamma_trace(t) + assert t1.equals(0) + + t = G(m0)*G(m1)*G(-m0) + t1 = gamma_trace(t) + assert t1.equals(0) + + t = G(m0)*G(m1)*G(m2)*G(m3)*G(m4) + t1 = gamma_trace(t) + assert t1.equals(0) + + # traces without internal contractions: + t = G(m0)*G(m1) + t1 = gamma_trace(t) + assert _is_tensor_eq(t1, 4*g(m0, m1)) + + t = G(m0)*G(m1)*G(m2)*G(m3) + t1 = gamma_trace(t) + t2 = -4*g(m0, m2)*g(m1, m3) + 4*g(m0, m1)*g(m2, m3) + 4*g(m0, m3)*g(m1, m2) + assert _is_tensor_eq(t1, t2) + + t = G(m0)*G(m1)*G(m2)*G(m3)*G(m4)*G(m5) + t1 = gamma_trace(t) + t2 = t1*g(-m0, -m5) + t2 = t2.contract_metric(g) + assert _is_tensor_eq(t2, D*gamma_trace(G(m1)*G(m2)*G(m3)*G(m4))) + + # traces of expressions with internal contractions: + t = G(m0)*G(-m0) + t1 = gamma_trace(t) + assert t1.equals(4*D) + + t = G(m0)*G(m1)*G(-m0)*G(-m1) + t1 = gamma_trace(t) + assert t1.equals(8*D - 4*D**2) + + t = G(m0)*G(m1)*G(m2)*G(m3)*G(m4)*G(-m0) + t1 = gamma_trace(t) + t2 = (-4*D)*g(m1, m3)*g(m2, m4) + (4*D)*g(m1, m2)*g(m3, m4) + \ + (4*D)*g(m1, m4)*g(m2, m3) + assert _is_tensor_eq(t1, t2) + + t = G(-m5)*G(m0)*G(m1)*G(m2)*G(m3)*G(m4)*G(-m0)*G(m5) + t1 = gamma_trace(t) + t2 = (32*D + 4*(-D + 4)**2 - 64)*(g(m1, m2)*g(m3, m4) - \ + g(m1, m3)*g(m2, m4) + g(m1, m4)*g(m2, m3)) + assert _is_tensor_eq(t1, t2) + + t = G(m0)*G(m1)*G(-m0)*G(m3) + t1 = gamma_trace(t) + assert t1.equals((-4*D + 8)*g(m1, m3)) + +# p, q = S1('p,q') +# ps = p(m0)*G(-m0) +# qs = q(m0)*G(-m0) +# t = ps*qs*ps*qs +# t1 = gamma_trace(t) +# assert t1 == 8*p(m0)*q(-m0)*p(m1)*q(-m1) - 4*p(m0)*p(-m0)*q(m1)*q(-m1) + + t = G(m0)*G(m1)*G(m2)*G(m3)*G(m4)*G(m5)*G(-m0)*G(-m1)*G(-m2)*G(-m3)*G(-m4)*G(-m5) + t1 = gamma_trace(t) + assert t1.equals(-4*D**6 + 120*D**5 - 1040*D**4 + 3360*D**3 - 4480*D**2 + 2048*D) + + t = G(m0)*G(m1)*G(n1)*G(m2)*G(n2)*G(m3)*G(m4)*G(-n2)*G(-n1)*G(-m0)*G(-m1)*G(-m2)*G(-m3)*G(-m4) + t1 = gamma_trace(t) + tresu = -7168*D + 16768*D**2 - 14400*D**3 + 5920*D**4 - 1232*D**5 + 120*D**6 - 4*D**7 + assert t1.equals(tresu) + + # checked with Mathematica + # In[1]:= <>> from sympy.physics.hydrogen import R_nl + >>> from sympy.abc import r, Z + >>> R_nl(1, 0, r, Z) + 2*sqrt(Z**3)*exp(-Z*r) + >>> R_nl(2, 0, r, Z) + sqrt(2)*(-Z*r + 2)*sqrt(Z**3)*exp(-Z*r/2)/4 + >>> R_nl(2, 1, r, Z) + sqrt(6)*Z*r*sqrt(Z**3)*exp(-Z*r/2)/12 + + For Hydrogen atom, you can just use the default value of Z=1: + + >>> R_nl(1, 0, r) + 2*exp(-r) + >>> R_nl(2, 0, r) + sqrt(2)*(2 - r)*exp(-r/2)/4 + >>> R_nl(3, 0, r) + 2*sqrt(3)*(2*r**2/9 - 2*r + 3)*exp(-r/3)/27 + + For Silver atom, you would use Z=47: + + >>> R_nl(1, 0, r, Z=47) + 94*sqrt(47)*exp(-47*r) + >>> R_nl(2, 0, r, Z=47) + 47*sqrt(94)*(2 - 47*r)*exp(-47*r/2)/4 + >>> R_nl(3, 0, r, Z=47) + 94*sqrt(141)*(4418*r**2/9 - 94*r + 3)*exp(-47*r/3)/27 + + The normalization of the radial wavefunction is: + + >>> from sympy import integrate, oo + >>> integrate(R_nl(1, 0, r)**2 * r**2, (r, 0, oo)) + 1 + >>> integrate(R_nl(2, 0, r)**2 * r**2, (r, 0, oo)) + 1 + >>> integrate(R_nl(2, 1, r)**2 * r**2, (r, 0, oo)) + 1 + + It holds for any atomic number: + + >>> integrate(R_nl(1, 0, r, Z=2)**2 * r**2, (r, 0, oo)) + 1 + >>> integrate(R_nl(2, 0, r, Z=3)**2 * r**2, (r, 0, oo)) + 1 + >>> integrate(R_nl(2, 1, r, Z=4)**2 * r**2, (r, 0, oo)) + 1 + + """ + # sympify arguments + n, l, r, Z = map(S, [n, l, r, Z]) + # radial quantum number + n_r = n - l - 1 + # rescaled "r" + a = 1/Z # Bohr radius + r0 = 2 * r / (n * a) + # normalization coefficient + C = sqrt((S(2)/(n*a))**3 * factorial(n_r) / (2*n*factorial(n + l))) + # This is an equivalent normalization coefficient, that can be found in + # some books. Both coefficients seem to be the same fast: + # C = S(2)/n**2 * sqrt(1/a**3 * factorial(n_r) / (factorial(n+l))) + return C * r0**l * assoc_laguerre(n_r, 2*l + 1, r0).expand() * exp(-r0/2) + + +def Psi_nlm(n, l, m, r, phi, theta, Z=1): + """ + Returns the Hydrogen wave function psi_{nlm}. It's the product of + the radial wavefunction R_{nl} and the spherical harmonic Y_{l}^{m}. + + Parameters + ========== + + n : integer + Principal Quantum Number which is + an integer with possible values as 1, 2, 3, 4,... + l : integer + ``l`` is the Angular Momentum Quantum Number with + values ranging from 0 to ``n-1``. + m : integer + ``m`` is the Magnetic Quantum Number with values + ranging from ``-l`` to ``l``. + r : + radial coordinate + phi : + azimuthal angle + theta : + polar angle + Z : + atomic number (1 for Hydrogen, 2 for Helium, ...) + + Everything is in Hartree atomic units. + + Examples + ======== + + >>> from sympy.physics.hydrogen import Psi_nlm + >>> from sympy import Symbol + >>> r=Symbol("r", positive=True) + >>> phi=Symbol("phi", real=True) + >>> theta=Symbol("theta", real=True) + >>> Z=Symbol("Z", positive=True, integer=True, nonzero=True) + >>> Psi_nlm(1,0,0,r,phi,theta,Z) + Z**(3/2)*exp(-Z*r)/sqrt(pi) + >>> Psi_nlm(2,1,1,r,phi,theta,Z) + -Z**(5/2)*r*exp(I*phi)*exp(-Z*r/2)*sin(theta)/(8*sqrt(pi)) + + Integrating the absolute square of a hydrogen wavefunction psi_{nlm} + over the whole space leads 1. + + The normalization of the hydrogen wavefunctions Psi_nlm is: + + >>> from sympy import integrate, conjugate, pi, oo, sin + >>> wf=Psi_nlm(2,1,1,r,phi,theta,Z) + >>> abs_sqrd=wf*conjugate(wf) + >>> jacobi=r**2*sin(theta) + >>> integrate(abs_sqrd*jacobi, (r,0,oo), (phi,0,2*pi), (theta,0,pi)) + 1 + """ + + # sympify arguments + n, l, m, r, phi, theta, Z = map(S, [n, l, m, r, phi, theta, Z]) + # check if values for n,l,m make physically sense + if n.is_integer and n < 1: + raise ValueError("'n' must be positive integer") + if l.is_integer and not (n > l): + raise ValueError("'n' must be greater than 'l'") + if m.is_integer and not (abs(m) <= l): + raise ValueError("|'m'| must be less or equal 'l'") + # return the hydrogen wave function + return R_nl(n, l, r, Z)*Ynm(l, m, theta, phi).expand(func=True) + + +def E_nl(n, Z=1): + """ + Returns the energy of the state (n, l) in Hartree atomic units. + + The energy does not depend on "l". + + Parameters + ========== + + n : integer + Principal Quantum Number which is + an integer with possible values as 1, 2, 3, 4,... + Z : + Atomic number (1 for Hydrogen, 2 for Helium, ...) + + Examples + ======== + + >>> from sympy.physics.hydrogen import E_nl + >>> from sympy.abc import n, Z + >>> E_nl(n, Z) + -Z**2/(2*n**2) + >>> E_nl(1) + -1/2 + >>> E_nl(2) + -1/8 + >>> E_nl(3) + -1/18 + >>> E_nl(3, 47) + -2209/18 + + """ + n, Z = S(n), S(Z) + if n.is_integer and (n < 1): + raise ValueError("'n' must be positive integer") + return -Z**2/(2*n**2) + + +def E_nl_dirac(n, l, spin_up=True, Z=1, c=Float("137.035999037")): + """ + Returns the relativistic energy of the state (n, l, spin) in Hartree atomic + units. + + The energy is calculated from the Dirac equation. The rest mass energy is + *not* included. + + Parameters + ========== + + n : integer + Principal Quantum Number which is + an integer with possible values as 1, 2, 3, 4,... + l : integer + ``l`` is the Angular Momentum Quantum Number with + values ranging from 0 to ``n-1``. + spin_up : + True if the electron spin is up (default), otherwise down + Z : + Atomic number (1 for Hydrogen, 2 for Helium, ...) + c : + Speed of light in atomic units. Default value is 137.035999037, + taken from https://arxiv.org/abs/1012.3627 + + Examples + ======== + + >>> from sympy.physics.hydrogen import E_nl_dirac + >>> E_nl_dirac(1, 0) + -0.500006656595360 + + >>> E_nl_dirac(2, 0) + -0.125002080189006 + >>> E_nl_dirac(2, 1) + -0.125000416028342 + >>> E_nl_dirac(2, 1, False) + -0.125002080189006 + + >>> E_nl_dirac(3, 0) + -0.0555562951740285 + >>> E_nl_dirac(3, 1) + -0.0555558020932949 + >>> E_nl_dirac(3, 1, False) + -0.0555562951740285 + >>> E_nl_dirac(3, 2) + -0.0555556377366884 + >>> E_nl_dirac(3, 2, False) + -0.0555558020932949 + + """ + n, l, Z, c = map(S, [n, l, Z, c]) + if not (l >= 0): + raise ValueError("'l' must be positive or zero") + if not (n > l): + raise ValueError("'n' must be greater than 'l'") + if (l == 0 and spin_up is False): + raise ValueError("Spin must be up for l==0.") + # skappa is sign*kappa, where sign contains the correct sign + if spin_up: + skappa = -l - 1 + else: + skappa = -l + beta = sqrt(skappa**2 - Z**2/c**2) + return c**2/sqrt(1 + Z**2/(n + skappa + beta)**2/c**2) - c**2 diff --git a/phivenv/Lib/site-packages/sympy/physics/matrices.py b/phivenv/Lib/site-packages/sympy/physics/matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..d91466220d63956053b91bd76b948ee677e7c191 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/matrices.py @@ -0,0 +1,176 @@ +"""Known matrices related to physics""" + +from sympy.core.numbers import I +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.utilities.decorator import deprecated + + +def msigma(i): + r"""Returns a Pauli matrix `\sigma_i` with `i=1,2,3`. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Pauli_matrices + + Examples + ======== + + >>> from sympy.physics.matrices import msigma + >>> msigma(1) + Matrix([ + [0, 1], + [1, 0]]) + """ + if i == 1: + mat = ( + (0, 1), + (1, 0) + ) + elif i == 2: + mat = ( + (0, -I), + (I, 0) + ) + elif i == 3: + mat = ( + (1, 0), + (0, -1) + ) + else: + raise IndexError("Invalid Pauli index") + return Matrix(mat) + + +def pat_matrix(m, dx, dy, dz): + """Returns the Parallel Axis Theorem matrix to translate the inertia + matrix a distance of `(dx, dy, dz)` for a body of mass m. + + Examples + ======== + + To translate a body having a mass of 2 units a distance of 1 unit along + the `x`-axis we get: + + >>> from sympy.physics.matrices import pat_matrix + >>> pat_matrix(2, 1, 0, 0) + Matrix([ + [0, 0, 0], + [0, 2, 0], + [0, 0, 2]]) + + """ + dxdy = -dx*dy + dydz = -dy*dz + dzdx = -dz*dx + dxdx = dx**2 + dydy = dy**2 + dzdz = dz**2 + mat = ((dydy + dzdz, dxdy, dzdx), + (dxdy, dxdx + dzdz, dydz), + (dzdx, dydz, dydy + dxdx)) + return m*Matrix(mat) + + +def mgamma(mu, lower=False): + r"""Returns a Dirac gamma matrix `\gamma^\mu` in the standard + (Dirac) representation. + + Explanation + =========== + + If you want `\gamma_\mu`, use ``gamma(mu, True)``. + + We use a convention: + + `\gamma^5 = i \cdot \gamma^0 \cdot \gamma^1 \cdot \gamma^2 \cdot \gamma^3` + + `\gamma_5 = i \cdot \gamma_0 \cdot \gamma_1 \cdot \gamma_2 \cdot \gamma_3 = - \gamma^5` + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Gamma_matrices + + Examples + ======== + + >>> from sympy.physics.matrices import mgamma + >>> mgamma(1) + Matrix([ + [ 0, 0, 0, 1], + [ 0, 0, 1, 0], + [ 0, -1, 0, 0], + [-1, 0, 0, 0]]) + """ + if mu not in (0, 1, 2, 3, 5): + raise IndexError("Invalid Dirac index") + if mu == 0: + mat = ( + (1, 0, 0, 0), + (0, 1, 0, 0), + (0, 0, -1, 0), + (0, 0, 0, -1) + ) + elif mu == 1: + mat = ( + (0, 0, 0, 1), + (0, 0, 1, 0), + (0, -1, 0, 0), + (-1, 0, 0, 0) + ) + elif mu == 2: + mat = ( + (0, 0, 0, -I), + (0, 0, I, 0), + (0, I, 0, 0), + (-I, 0, 0, 0) + ) + elif mu == 3: + mat = ( + (0, 0, 1, 0), + (0, 0, 0, -1), + (-1, 0, 0, 0), + (0, 1, 0, 0) + ) + elif mu == 5: + mat = ( + (0, 0, 1, 0), + (0, 0, 0, 1), + (1, 0, 0, 0), + (0, 1, 0, 0) + ) + m = Matrix(mat) + if lower: + if mu in (1, 2, 3, 5): + m = -m + return m + +#Minkowski tensor using the convention (+,-,-,-) used in the Quantum Field +#Theory +minkowski_tensor = Matrix( ( + (1, 0, 0, 0), + (0, -1, 0, 0), + (0, 0, -1, 0), + (0, 0, 0, -1) +)) + + +@deprecated( + """ + The sympy.physics.matrices.mdft method is deprecated. Use + sympy.DFT(n).as_explicit() instead. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-physics-mdft", +) +def mdft(n): + r""" + .. deprecated:: 1.9 + + Use DFT from sympy.matrices.expressions.fourier instead. + + To get identical behavior to ``mdft(n)``, use ``DFT(n).as_explicit()``. + """ + from sympy.matrices.expressions.fourier import DFT + return DFT(n).as_mutable() diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__init__.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..afd8c071a2af4fd201d5b2371594b19e4a68edda --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/__init__.py @@ -0,0 +1,90 @@ +__all__ = [ + 'vector', + + 'CoordinateSym', 'ReferenceFrame', 'Dyadic', 'Vector', 'Point', 'cross', + 'dot', 'express', 'time_derivative', 'outer', 'kinematic_equations', + 'get_motion_params', 'partial_velocity', 'dynamicsymbols', 'vprint', + 'vsstrrepr', 'vsprint', 'vpprint', 'vlatex', 'init_vprinting', 'curl', + 'divergence', 'gradient', 'is_conservative', 'is_solenoidal', + 'scalar_potential', 'scalar_potential_difference', + + 'KanesMethod', + + 'RigidBody', + + 'linear_momentum', 'angular_momentum', 'kinetic_energy', 'potential_energy', + 'Lagrangian', 'mechanics_printing', 'mprint', 'msprint', 'mpprint', + 'mlatex', 'msubs', 'find_dynamicsymbols', + + 'inertia', 'inertia_of_point_mass', 'Inertia', + + 'Force', 'Torque', + + 'Particle', + + 'LagrangesMethod', + + 'Linearizer', + + 'Body', + + 'SymbolicSystem', 'System', + + 'PinJoint', 'PrismaticJoint', 'CylindricalJoint', 'PlanarJoint', + 'SphericalJoint', 'WeldJoint', + + 'JointsMethod', + + 'WrappingCylinder', 'WrappingGeometryBase', 'WrappingSphere', + + 'PathwayBase', 'LinearPathway', 'ObstacleSetPathway', 'WrappingPathway', + + 'ActuatorBase', 'ForceActuator', 'LinearDamper', 'LinearSpring', + 'TorqueActuator', 'DuffingSpring', 'CoulombKineticFriction', +] + +from sympy.physics import vector + +from sympy.physics.vector import (CoordinateSym, ReferenceFrame, Dyadic, Vector, Point, + cross, dot, express, time_derivative, outer, kinematic_equations, + get_motion_params, partial_velocity, dynamicsymbols, vprint, + vsstrrepr, vsprint, vpprint, vlatex, init_vprinting, curl, divergence, + gradient, is_conservative, is_solenoidal, scalar_potential, + scalar_potential_difference) + +from .kane import KanesMethod + +from .rigidbody import RigidBody + +from .functions import (linear_momentum, angular_momentum, kinetic_energy, + potential_energy, Lagrangian, mechanics_printing, + mprint, msprint, mpprint, mlatex, msubs, + find_dynamicsymbols) + +from .inertia import inertia, inertia_of_point_mass, Inertia + +from .loads import Force, Torque + +from .particle import Particle + +from .lagrange import LagrangesMethod + +from .linearize import Linearizer + +from .body import Body + +from .system import SymbolicSystem, System + +from .jointsmethod import JointsMethod + +from .joint import (PinJoint, PrismaticJoint, CylindricalJoint, PlanarJoint, + SphericalJoint, WeldJoint) + +from .wrapping_geometry import (WrappingCylinder, WrappingGeometryBase, + WrappingSphere) + +from .pathway import (PathwayBase, LinearPathway, ObstacleSetPathway, + WrappingPathway) + +from .actuator import (ActuatorBase, ForceActuator, LinearDamper, LinearSpring, + TorqueActuator, DuffingSpring, CoulombKineticFriction) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45f2bd0556566370763162166f87fbd35d0f7758 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/actuator.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/actuator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4978071cbc83af3044ccb333b5f817d6841bfa34 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/actuator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/body.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/body.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88c214da62a6300a23d16b7e2ae052b44cf807ed Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/body.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/body_base.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/body_base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c0a347ab09b99d99bd2f36ea798607b6902d195 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/body_base.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/functions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24131bfa7c2a246d62dbd165d51035b9f8d4b3f9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/inertia.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/inertia.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4208b90ad94cfa1d1280c03f46875665611cfd94 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/inertia.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/joint.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/joint.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a54bfc8fb560de2388805e6c32c9cd79f773ec7a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/joint.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/jointsmethod.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/jointsmethod.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30441eda9adaa4c3fdf386110cb43060676a7852 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/jointsmethod.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/kane.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/kane.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a93f0e0124526e056824582f1ae3a47609647a8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/kane.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/lagrange.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/lagrange.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9be4a02064b6df368c680ee815ba330261a0bedd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/lagrange.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/linearize.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/linearize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..beeb455db9066411e9729a4e8ea17dcf9066e636 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/linearize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/loads.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/loads.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe66e89d81c16586af79d0659f240220c0bf2388 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/loads.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/method.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/method.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dd68d46f2395179058f32c84360c2d0b16d64f1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/method.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/models.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/models.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..940a5a3c190a3b00c3f59a7bd134a1ef5f8155cf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/models.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/particle.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/particle.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95c769ce2a580e66aea1b4e1ea87487a3c07db6b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/particle.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/pathway.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/pathway.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e94ce2473ad4e85051634d9b6784501de5a2ff6b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/pathway.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/rigidbody.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/rigidbody.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb29f375afea801f41354e4f6f2ce33888163b8c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/rigidbody.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/system.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/system.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d883a159b7698be0e1cb9cbab6fd4267513709c6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/system.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/wrapping_geometry.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/wrapping_geometry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45c78979d24b241a94ec2fa488aa4ed3ce1171ed Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/__pycache__/wrapping_geometry.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/actuator.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/actuator.py new file mode 100644 index 0000000000000000000000000000000000000000..625b3e55019e7545c6dfed073d388acba91a324c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/actuator.py @@ -0,0 +1,1147 @@ +"""Implementations of actuators for linked force and torque application.""" + +from abc import ABC, abstractmethod + +from sympy import S, sympify, exp, sign +from sympy.physics.mechanics.joint import PinJoint +from sympy.physics.mechanics.loads import Torque +from sympy.physics.mechanics.pathway import PathwayBase +from sympy.physics.mechanics.rigidbody import RigidBody +from sympy.physics.vector import ReferenceFrame, Vector + + +__all__ = [ + 'ActuatorBase', + 'ForceActuator', + 'LinearDamper', + 'LinearSpring', + 'TorqueActuator', + 'DuffingSpring', + 'CoulombKineticFriction', +] + + +class ActuatorBase(ABC): + """Abstract base class for all actuator classes to inherit from. + + Notes + ===== + + Instances of this class cannot be directly instantiated by users. However, + it can be used to created custom actuator types through subclassing. + + """ + + def __init__(self): + """Initializer for ``ActuatorBase``.""" + pass + + @abstractmethod + def to_loads(self): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + """ + pass + + def __repr__(self): + """Default representation of an actuator.""" + return f'{self.__class__.__name__}()' + + +class ForceActuator(ActuatorBase): + """Force-producing actuator. + + Explanation + =========== + + A ``ForceActuator`` is an actuator that produces a (expansile) force along + its length. + + A force actuator uses a pathway instance to determine the direction and + number of forces that it applies to a system. Consider the simplest case + where a ``LinearPathway`` instance is used. This pathway is made up of two + points that can move relative to each other, and results in a pair of equal + and opposite forces acting on the endpoints. If the positive time-varying + Euclidean distance between the two points is defined, then the "extension + velocity" is the time derivative of this distance. The extension velocity + is positive when the two points are moving away from each other and + negative when moving closer to each other. The direction for the force + acting on either point is determined by constructing a unit vector directed + from the other point to this point. This establishes a sign convention such + that a positive force magnitude tends to push the points apart, this is the + meaning of "expansile" in this context. The following diagram shows the + positive force sense and the distance between the points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + To construct an actuator, an expression (or symbol) must be supplied to + represent the force it can produce, alongside a pathway specifying its line + of action. Let's also create a global reference frame and spatially fix one + of the points in it while setting the other to be positioned such that it + can freely move in the frame's x direction specified by the coordinate + ``q``. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (ForceActuator, LinearPathway, + ... Point, ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> force = symbols('F') + >>> pA, pB = Point('pA'), Point('pB') + >>> pA.set_vel(N, 0) + >>> pB.set_pos(pA, q*N.x) + >>> pB.pos_from(pA) + q(t)*N.x + >>> linear_pathway = LinearPathway(pA, pB) + >>> actuator = ForceActuator(force, linear_pathway) + >>> actuator + ForceActuator(F, LinearPathway(pA, pB)) + + Parameters + ========== + + force : Expr + The scalar expression defining the (expansile) force that the actuator + produces. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + + """ + + def __init__(self, force, pathway): + """Initializer for ``ForceActuator``. + + Parameters + ========== + + force : Expr + The scalar expression defining the (expansile) force that the + actuator produces. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of + a concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + + """ + self.force = force + self.pathway = pathway + + @property + def force(self): + """The magnitude of the force produced by the actuator.""" + return self._force + + @force.setter + def force(self, force): + if hasattr(self, '_force'): + msg = ( + f'Can\'t set attribute `force` to {repr(force)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + self._force = sympify(force, strict=True) + + @property + def pathway(self): + """The ``Pathway`` defining the actuator's line of action.""" + return self._pathway + + @pathway.setter + def pathway(self, pathway): + if hasattr(self, '_pathway'): + msg = ( + f'Can\'t set attribute `pathway` to {repr(pathway)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + if not isinstance(pathway, PathwayBase): + msg = ( + f'Value {repr(pathway)} passed to `pathway` was of type ' + f'{type(pathway)}, must be {PathwayBase}.' + ) + raise TypeError(msg) + self._pathway = pathway + + def to_loads(self): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced by a force + actuator that follows a linear pathway. In this example we'll assume + that the force actuator is being used to model a simple linear spring. + First, create a linear pathway between two points separated by the + coordinate ``q`` in the ``x`` direction of the global frame ``N``. + + >>> from sympy.physics.mechanics import (LinearPathway, Point, + ... ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> q = dynamicsymbols('q') + >>> N = ReferenceFrame('N') + >>> pA, pB = Point('pA'), Point('pB') + >>> pB.set_pos(pA, q*N.x) + >>> pathway = LinearPathway(pA, pB) + + Now create a symbol ``k`` to describe the spring's stiffness and + instantiate a force actuator that produces a (contractile) force + proportional to both the spring's stiffness and the pathway's length. + Note that actuator classes use the sign convention that expansile + forces are positive, so for a spring to produce a contractile force the + spring force needs to be calculated as the negative for the stiffness + multiplied by the length. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import ForceActuator + >>> stiffness = symbols('k') + >>> spring_force = -stiffness*pathway.length + >>> spring = ForceActuator(spring_force, pathway) + + The forces produced by the spring can be generated in the list of loads + form that ``KanesMethod`` (and other equations of motion methods) + requires by calling the ``to_loads`` method. + + >>> spring.to_loads() + [(pA, k*q(t)*N.x), (pB, - k*q(t)*N.x)] + + A simple linear damper can be modeled in a similar way. Create another + symbol ``c`` to describe the dampers damping coefficient. This time + instantiate a force actuator that produces a force proportional to both + the damper's damping coefficient and the pathway's extension velocity. + Note that the damping force is negative as it acts in the opposite + direction to which the damper is changing in length. + + >>> damping_coefficient = symbols('c') + >>> damping_force = -damping_coefficient*pathway.extension_velocity + >>> damper = ForceActuator(damping_force, pathway) + + Again, the forces produces by the damper can be generated by calling + the ``to_loads`` method. + + >>> damper.to_loads() + [(pA, c*Derivative(q(t), t)*N.x), (pB, - c*Derivative(q(t), t)*N.x)] + + """ + return self.pathway.to_loads(self.force) + + def __repr__(self): + """Representation of a ``ForceActuator``.""" + return f'{self.__class__.__name__}({self.force}, {self.pathway})' + + +class LinearSpring(ForceActuator): + """A spring with its spring force as a linear function of its length. + + Explanation + =========== + + Note that the "linear" in the name ``LinearSpring`` refers to the fact that + the spring force is a linear function of the springs length. I.e. for a + linear spring with stiffness ``k``, distance between its ends of ``x``, and + an equilibrium length of ``0``, the spring force will be ``-k*x``, which is + a linear function in ``x``. To create a spring that follows a linear, or + straight, pathway between its two ends, a ``LinearPathway`` instance needs + to be passed to the ``pathway`` parameter. + + A ``LinearSpring`` is a subclass of ``ForceActuator`` and so follows the + same sign conventions for length, extension velocity, and the direction of + the forces it applies to its points of attachment on bodies. The sign + convention for the direction of forces is such that, for the case where a + linear spring is instantiated with a ``LinearPathway`` instance as its + pathway, they act to push the two ends of the spring away from one another. + Because springs produces a contractile force and acts to pull the two ends + together towards the equilibrium length when stretched, the scalar portion + of the forces on the endpoint are negative in order to flip the sign of the + forces on the endpoints when converted into vector quantities. The + following diagram shows the positive force sense and the distance between + the points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + To construct a linear spring, an expression (or symbol) must be supplied to + represent the stiffness (spring constant) of the spring, alongside a + pathway specifying its line of action. Let's also create a global reference + frame and spatially fix one of the points in it while setting the other to + be positioned such that it can freely move in the frame's x direction + specified by the coordinate ``q``. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (LinearPathway, LinearSpring, + ... Point, ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> stiffness = symbols('k') + >>> pA, pB = Point('pA'), Point('pB') + >>> pA.set_vel(N, 0) + >>> pB.set_pos(pA, q*N.x) + >>> pB.pos_from(pA) + q(t)*N.x + >>> linear_pathway = LinearPathway(pA, pB) + >>> spring = LinearSpring(stiffness, linear_pathway) + >>> spring + LinearSpring(k, LinearPathway(pA, pB)) + + This spring will produce a force that is proportional to both its stiffness + and the pathway's length. Note that this force is negative as SymPy's sign + convention for actuators is that negative forces are contractile. + + >>> spring.force + -k*sqrt(q(t)**2) + + To create a linear spring with a non-zero equilibrium length, an expression + (or symbol) can be passed to the ``equilibrium_length`` parameter on + construction on a ``LinearSpring`` instance. Let's create a symbol ``l`` + to denote a non-zero equilibrium length and create another linear spring. + + >>> l = symbols('l') + >>> spring = LinearSpring(stiffness, linear_pathway, equilibrium_length=l) + >>> spring + LinearSpring(k, LinearPathway(pA, pB), equilibrium_length=l) + + The spring force of this new spring is again proportional to both its + stiffness and the pathway's length. However, the spring will not produce + any force when ``q(t)`` equals ``l``. Note that the force will become + expansile when ``q(t)`` is less than ``l``, as expected. + + >>> spring.force + -k*(-l + sqrt(q(t)**2)) + + Parameters + ========== + + stiffness : Expr + The spring constant. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + equilibrium_length : Expr, optional + The length at which the spring is in equilibrium, i.e. it produces no + force. The default value is 0, i.e. the spring force is a linear + function of the pathway's length with no constant offset. + + See Also + ======== + + ForceActuator: force-producing actuator (superclass of ``LinearSpring``). + LinearPathway: straight-line pathway between a pair of points. + + """ + + def __init__(self, stiffness, pathway, equilibrium_length=S.Zero): + """Initializer for ``LinearSpring``. + + Parameters + ========== + + stiffness : Expr + The spring constant. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of + a concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + equilibrium_length : Expr, optional + The length at which the spring is in equilibrium, i.e. it produces + no force. The default value is 0, i.e. the spring force is a linear + function of the pathway's length with no constant offset. + + """ + self.stiffness = stiffness + self.pathway = pathway + self.equilibrium_length = equilibrium_length + + @property + def force(self): + """The spring force produced by the linear spring.""" + return -self.stiffness*(self.pathway.length - self.equilibrium_length) + + @force.setter + def force(self, force): + raise AttributeError('Can\'t set computed attribute `force`.') + + @property + def stiffness(self): + """The spring constant for the linear spring.""" + return self._stiffness + + @stiffness.setter + def stiffness(self, stiffness): + if hasattr(self, '_stiffness'): + msg = ( + f'Can\'t set attribute `stiffness` to {repr(stiffness)} as it ' + f'is immutable.' + ) + raise AttributeError(msg) + self._stiffness = sympify(stiffness, strict=True) + + @property + def equilibrium_length(self): + """The length of the spring at which it produces no force.""" + return self._equilibrium_length + + @equilibrium_length.setter + def equilibrium_length(self, equilibrium_length): + if hasattr(self, '_equilibrium_length'): + msg = ( + f'Can\'t set attribute `equilibrium_length` to ' + f'{repr(equilibrium_length)} as it is immutable.' + ) + raise AttributeError(msg) + self._equilibrium_length = sympify(equilibrium_length, strict=True) + + def __repr__(self): + """Representation of a ``LinearSpring``.""" + string = f'{self.__class__.__name__}({self.stiffness}, {self.pathway}' + if self.equilibrium_length == S.Zero: + string += ')' + else: + string += f', equilibrium_length={self.equilibrium_length})' + return string + + +class LinearDamper(ForceActuator): + """A damper whose force is a linear function of its extension velocity. + + Explanation + =========== + + Note that the "linear" in the name ``LinearDamper`` refers to the fact that + the damping force is a linear function of the damper's rate of change in + its length. I.e. for a linear damper with damping ``c`` and extension + velocity ``v``, the damping force will be ``-c*v``, which is a linear + function in ``v``. To create a damper that follows a linear, or straight, + pathway between its two ends, a ``LinearPathway`` instance needs to be + passed to the ``pathway`` parameter. + + A ``LinearDamper`` is a subclass of ``ForceActuator`` and so follows the + same sign conventions for length, extension velocity, and the direction of + the forces it applies to its points of attachment on bodies. The sign + convention for the direction of forces is such that, for the case where a + linear damper is instantiated with a ``LinearPathway`` instance as its + pathway, they act to push the two ends of the damper away from one another. + Because dampers produce a force that opposes the direction of change in + length, when extension velocity is positive the scalar portions of the + forces applied at the two endpoints are negative in order to flip the sign + of the forces on the endpoints wen converted into vector quantities. When + extension velocity is negative (i.e. when the damper is shortening), the + scalar portions of the fofces applied are also negative so that the signs + cancel producing forces on the endpoints that are in the same direction as + the positive sign convention for the forces at the endpoints of the pathway + (i.e. they act to push the endpoints away from one another). The following + diagram shows the positive force sense and the distance between the + points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + To construct a linear damper, an expression (or symbol) must be supplied to + represent the damping coefficient of the damper (we'll use the symbol + ``c``), alongside a pathway specifying its line of action. Let's also + create a global reference frame and spatially fix one of the points in it + while setting the other to be positioned such that it can freely move in + the frame's x direction specified by the coordinate ``q``. The velocity + that the two points move away from one another can be specified by the + coordinate ``u`` where ``u`` is the first time derivative of ``q`` + (i.e., ``u = Derivative(q(t), t)``). + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (LinearDamper, LinearPathway, + ... Point, ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> damping = symbols('c') + >>> pA, pB = Point('pA'), Point('pB') + >>> pA.set_vel(N, 0) + >>> pB.set_pos(pA, q*N.x) + >>> pB.pos_from(pA) + q(t)*N.x + >>> pB.vel(N) + Derivative(q(t), t)*N.x + >>> linear_pathway = LinearPathway(pA, pB) + >>> damper = LinearDamper(damping, linear_pathway) + >>> damper + LinearDamper(c, LinearPathway(pA, pB)) + + This damper will produce a force that is proportional to both its damping + coefficient and the pathway's extension length. Note that this force is + negative as SymPy's sign convention for actuators is that negative forces + are contractile and the damping force of the damper will oppose the + direction of length change. + + >>> damper.force + -c*sqrt(q(t)**2)*Derivative(q(t), t)/q(t) + + Parameters + ========== + + damping : Expr + The damping constant. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + + See Also + ======== + + ForceActuator: force-producing actuator (superclass of ``LinearDamper``). + LinearPathway: straight-line pathway between a pair of points. + + """ + + def __init__(self, damping, pathway): + """Initializer for ``LinearDamper``. + + Parameters + ========== + + damping : Expr + The damping constant. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of + a concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + + """ + self.damping = damping + self.pathway = pathway + + @property + def force(self): + """The damping force produced by the linear damper.""" + return -self.damping*self.pathway.extension_velocity + + @force.setter + def force(self, force): + raise AttributeError('Can\'t set computed attribute `force`.') + + @property + def damping(self): + """The damping constant for the linear damper.""" + return self._damping + + @damping.setter + def damping(self, damping): + if hasattr(self, '_damping'): + msg = ( + f'Can\'t set attribute `damping` to {repr(damping)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + self._damping = sympify(damping, strict=True) + + def __repr__(self): + """Representation of a ``LinearDamper``.""" + return f'{self.__class__.__name__}({self.damping}, {self.pathway})' + + +class TorqueActuator(ActuatorBase): + """Torque-producing actuator. + + Explanation + =========== + + A ``TorqueActuator`` is an actuator that produces a pair of equal and + opposite torques on a pair of bodies. + + Examples + ======== + + To construct a torque actuator, an expression (or symbol) must be supplied + to represent the torque it can produce, alongside a vector specifying the + axis about which the torque will act, and a pair of frames on which the + torque will act. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (ReferenceFrame, RigidBody, + ... TorqueActuator) + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> torque = symbols('T') + >>> axis = N.z + >>> parent = RigidBody('parent', frame=N) + >>> child = RigidBody('child', frame=A) + >>> bodies = (child, parent) + >>> actuator = TorqueActuator(torque, axis, *bodies) + >>> actuator + TorqueActuator(T, axis=N.z, target_frame=A, reaction_frame=N) + + Note that because torques actually act on frames, not bodies, + ``TorqueActuator`` will extract the frame associated with a ``RigidBody`` + when one is passed instead of a ``ReferenceFrame``. + + Parameters + ========== + + torque : Expr + The scalar expression defining the torque that the actuator produces. + axis : Vector + The axis about which the actuator applies torques. + target_frame : ReferenceFrame | RigidBody + The primary frame on which the actuator will apply the torque. + reaction_frame : ReferenceFrame | RigidBody | None + The secondary frame on which the actuator will apply the torque. Note + that the (equal and opposite) reaction torque is applied to this frame. + + """ + + def __init__(self, torque, axis, target_frame, reaction_frame=None): + """Initializer for ``TorqueActuator``. + + Parameters + ========== + + torque : Expr + The scalar expression defining the torque that the actuator + produces. + axis : Vector + The axis about which the actuator applies torques. + target_frame : ReferenceFrame | RigidBody + The primary frame on which the actuator will apply the torque. + reaction_frame : ReferenceFrame | RigidBody | None + The secondary frame on which the actuator will apply the torque. + Note that the (equal and opposite) reaction torque is applied to + this frame. + + """ + self.torque = torque + self.axis = axis + self.target_frame = target_frame + self.reaction_frame = reaction_frame + + @classmethod + def at_pin_joint(cls, torque, pin_joint): + """Alternate constructor to instantiate from a ``PinJoint`` instance. + + Examples + ======== + + To create a pin joint the ``PinJoint`` class requires a name, parent + body, and child body to be passed to its constructor. It is also + possible to control the joint axis using the ``joint_axis`` keyword + argument. In this example let's use the parent body's reference frame's + z-axis as the joint axis. + + >>> from sympy.physics.mechanics import (PinJoint, ReferenceFrame, + ... RigidBody, TorqueActuator) + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> parent = RigidBody('parent', frame=N) + >>> child = RigidBody('child', frame=A) + >>> pin_joint = PinJoint( + ... 'pin', + ... parent, + ... child, + ... joint_axis=N.z, + ... ) + + Let's also create a symbol ``T`` that will represent the torque applied + by the torque actuator. + + >>> from sympy import symbols + >>> torque = symbols('T') + + To create the torque actuator from the ``torque`` and ``pin_joint`` + variables previously instantiated, these can be passed to the alternate + constructor class method ``at_pin_joint`` of the ``TorqueActuator`` + class. It should be noted that a positive torque will cause a positive + displacement of the joint coordinate or that the torque is applied on + the child body with a reaction torque on the parent. + + >>> actuator = TorqueActuator.at_pin_joint(torque, pin_joint) + >>> actuator + TorqueActuator(T, axis=N.z, target_frame=A, reaction_frame=N) + + Parameters + ========== + + torque : Expr + The scalar expression defining the torque that the actuator + produces. + pin_joint : PinJoint + The pin joint, and by association the parent and child bodies, on + which the torque actuator will act. The pair of bodies acted upon + by the torque actuator are the parent and child bodies of the pin + joint, with the child acting as the reaction body. The pin joint's + axis is used as the axis about which the torque actuator will apply + its torque. + + """ + if not isinstance(pin_joint, PinJoint): + msg = ( + f'Value {repr(pin_joint)} passed to `pin_joint` was of type ' + f'{type(pin_joint)}, must be {PinJoint}.' + ) + raise TypeError(msg) + return cls( + torque, + pin_joint.joint_axis, + pin_joint.child_interframe, + pin_joint.parent_interframe, + ) + + @property + def torque(self): + """The magnitude of the torque produced by the actuator.""" + return self._torque + + @torque.setter + def torque(self, torque): + if hasattr(self, '_torque'): + msg = ( + f'Can\'t set attribute `torque` to {repr(torque)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + self._torque = sympify(torque, strict=True) + + @property + def axis(self): + """The axis about which the torque acts.""" + return self._axis + + @axis.setter + def axis(self, axis): + if hasattr(self, '_axis'): + msg = ( + f'Can\'t set attribute `axis` to {repr(axis)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + if not isinstance(axis, Vector): + msg = ( + f'Value {repr(axis)} passed to `axis` was of type ' + f'{type(axis)}, must be {Vector}.' + ) + raise TypeError(msg) + self._axis = axis + + @property + def target_frame(self): + """The primary reference frames on which the torque will act.""" + return self._target_frame + + @target_frame.setter + def target_frame(self, target_frame): + if hasattr(self, '_target_frame'): + msg = ( + f'Can\'t set attribute `target_frame` to {repr(target_frame)} ' + f'as it is immutable.' + ) + raise AttributeError(msg) + if isinstance(target_frame, RigidBody): + target_frame = target_frame.frame + elif not isinstance(target_frame, ReferenceFrame): + msg = ( + f'Value {repr(target_frame)} passed to `target_frame` was of ' + f'type {type(target_frame)}, must be {ReferenceFrame}.' + ) + raise TypeError(msg) + self._target_frame = target_frame + + @property + def reaction_frame(self): + """The primary reference frames on which the torque will act.""" + return self._reaction_frame + + @reaction_frame.setter + def reaction_frame(self, reaction_frame): + if hasattr(self, '_reaction_frame'): + msg = ( + f'Can\'t set attribute `reaction_frame` to ' + f'{repr(reaction_frame)} as it is immutable.' + ) + raise AttributeError(msg) + if isinstance(reaction_frame, RigidBody): + reaction_frame = reaction_frame.frame + elif ( + not isinstance(reaction_frame, ReferenceFrame) + and reaction_frame is not None + ): + msg = ( + f'Value {repr(reaction_frame)} passed to `reaction_frame` was ' + f'of type {type(reaction_frame)}, must be {ReferenceFrame}.' + ) + raise TypeError(msg) + self._reaction_frame = reaction_frame + + def to_loads(self): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced by a torque + actuator that acts on a pair of bodies attached by a pin joint. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (PinJoint, ReferenceFrame, + ... RigidBody, TorqueActuator) + >>> torque = symbols('T') + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> parent = RigidBody('parent', frame=N) + >>> child = RigidBody('child', frame=A) + >>> pin_joint = PinJoint( + ... 'pin', + ... parent, + ... child, + ... joint_axis=N.z, + ... ) + >>> actuator = TorqueActuator.at_pin_joint(torque, pin_joint) + + The forces produces by the damper can be generated by calling the + ``to_loads`` method. + + >>> actuator.to_loads() + [(A, T*N.z), (N, - T*N.z)] + + Alternatively, if a torque actuator is created without a reaction frame + then the loads returned by the ``to_loads`` method will contain just + the single load acting on the target frame. + + >>> actuator = TorqueActuator(torque, N.z, N) + >>> actuator.to_loads() + [(N, T*N.z)] + + """ + loads = [ + Torque(self.target_frame, self.torque*self.axis), + ] + if self.reaction_frame is not None: + loads.append(Torque(self.reaction_frame, -self.torque*self.axis)) + return loads + + def __repr__(self): + """Representation of a ``TorqueActuator``.""" + string = ( + f'{self.__class__.__name__}({self.torque}, axis={self.axis}, ' + f'target_frame={self.target_frame}' + ) + if self.reaction_frame is not None: + string += f', reaction_frame={self.reaction_frame})' + else: + string += ')' + return string + + +class DuffingSpring(ForceActuator): + """A nonlinear spring based on the Duffing equation. + + Explanation + =========== + + Here, ``DuffingSpring`` represents the force exerted by a nonlinear spring based on the Duffing equation: + F = -beta*x-alpha*x**3, where x is the displacement from the equilibrium position, beta is the linear spring constant, + and alpha is the coefficient for the nonlinear cubic term. + + Parameters + ========== + + linear_stiffness : Expr + The linear stiffness coefficient (beta). + nonlinear_stiffness : Expr + The nonlinear stiffness coefficient (alpha). + pathway : PathwayBase + The pathway that the actuator follows. + equilibrium_length : Expr, optional + The length at which the spring is in equilibrium (x). + """ + + def __init__(self, linear_stiffness, nonlinear_stiffness, pathway, equilibrium_length=S.Zero): + self.linear_stiffness = sympify(linear_stiffness, strict=True) + self.nonlinear_stiffness = sympify(nonlinear_stiffness, strict=True) + self.equilibrium_length = sympify(equilibrium_length, strict=True) + + if not isinstance(pathway, PathwayBase): + raise TypeError("pathway must be an instance of PathwayBase.") + self._pathway = pathway + + @property + def linear_stiffness(self): + return self._linear_stiffness + + @linear_stiffness.setter + def linear_stiffness(self, linear_stiffness): + if hasattr(self, '_linear_stiffness'): + msg = ( + f'Can\'t set attribute `linear_stiffness` to ' + f'{repr(linear_stiffness)} as it is immutable.' + ) + raise AttributeError(msg) + self._linear_stiffness = sympify(linear_stiffness, strict=True) + + @property + def nonlinear_stiffness(self): + return self._nonlinear_stiffness + + @nonlinear_stiffness.setter + def nonlinear_stiffness(self, nonlinear_stiffness): + if hasattr(self, '_nonlinear_stiffness'): + msg = ( + f'Can\'t set attribute `nonlinear_stiffness` to ' + f'{repr(nonlinear_stiffness)} as it is immutable.' + ) + raise AttributeError(msg) + self._nonlinear_stiffness = sympify(nonlinear_stiffness, strict=True) + + @property + def pathway(self): + return self._pathway + + @pathway.setter + def pathway(self, pathway): + if hasattr(self, '_pathway'): + msg = ( + f'Can\'t set attribute `pathway` to {repr(pathway)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + if not isinstance(pathway, PathwayBase): + msg = ( + f'Value {repr(pathway)} passed to `pathway` was of type ' + f'{type(pathway)}, must be {PathwayBase}.' + ) + raise TypeError(msg) + self._pathway = pathway + + @property + def equilibrium_length(self): + return self._equilibrium_length + + @equilibrium_length.setter + def equilibrium_length(self, equilibrium_length): + if hasattr(self, '_equilibrium_length'): + msg = ( + f'Can\'t set attribute `equilibrium_length` to ' + f'{repr(equilibrium_length)} as it is immutable.' + ) + raise AttributeError(msg) + self._equilibrium_length = sympify(equilibrium_length, strict=True) + + @property + def force(self): + """The force produced by the Duffing spring.""" + displacement = self.pathway.length - self.equilibrium_length + return -self.linear_stiffness * displacement - self.nonlinear_stiffness * displacement**3 + + @force.setter + def force(self, force): + if hasattr(self, '_force'): + msg = ( + f'Can\'t set attribute `force` to {repr(force)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + self._force = sympify(force, strict=True) + + def __repr__(self): + return (f"{self.__class__.__name__}(" + f"{self.linear_stiffness}, {self.nonlinear_stiffness}, {self.pathway}, " + f"equilibrium_length={self.equilibrium_length})") + +class CoulombKineticFriction(ForceActuator): + r"""Coulomb kinetic friction with Stribeck and viscous effects. + + Explanation + =========== + + This represents a Coulomb kinetic friction with the Stribeck and viscous effect, + described by the function: + + .. math:: + F = (\mu_k f_n + (\mu_s - \mu_k) f_n e^{-(\frac{v}{v_s})^2}) \text{sign}(v) + \sigma v + + where :math:`\mu_k` is the coefficient of kinetic friction, :math:`\mu_s` is the + coefficient of static friction, :math:`f_n` is the normal force, :math:`v` is the + relative velocity, :math:`v_s` is the Stribeck friction coefficient, and + :math:`\sigma` is the viscous friction constant. + + The default friction force is :math:`F = \mu_k f_n`. + When specified, the actuator includes: + + - Stribeck effect: :math:`(\mu_s - \mu_k) f_n e^{-(\frac{v}{v_s})^2}` + - Viscous effect: :math:`\sigma v` + + Notes + ===== + + The actuator makes the following assumptions: + + - The actuator assumes relative motion is non-zero. + - The normal force is assumed to be a non-negative scalar. + - The resultant friction force is opposite to the velocity direction. + - Each point in the pathway is fixed within separate objects that are sliding relative to each other. In other words, these two points are fixed in the mutually sliding objects. + + This actuator has been tested for straightforward motions, like a block sliding + on a surface. + + The friction force is defined to always oppose the direction of relative velocity :math:`v`. + Specifically: + + - The default Coulomb friction force :math:`\mu_k f_n \text{sign}(v)` is opposite to :math:`v`. + - The Stribeck effect :math:`(\mu_s - \mu_k) f_n e^{-(\frac{v}{v_s})^2} \text{sign}(v)` is also opposite to :math:`v`. + - The viscous friction term :math:`\sigma v` is opposite to :math:`v`. + + Examples + ======== + + The below example shows how to generate the loads produced by a Coulomb kinetic + friction actuator in a mass-spring system with friction. + + >>> import sympy as sm + >>> from sympy.physics.mechanics import (dynamicsymbols, ReferenceFrame, Point, + ... LinearPathway, CoulombKineticFriction, LinearSpring, KanesMethod, Particle) + + >>> x, v = dynamicsymbols('x, v', real=True) + >>> m, g, k, mu_k, mu_s, v_s, sigma = sm.symbols('m, g, k, mu_k, mu_s, v_s, sigma') + + >>> N = ReferenceFrame('N') + >>> O, P = Point('O'), Point('P') + >>> O.set_vel(N, 0) + >>> P.set_pos(O, x*N.x) + + >>> pathway = LinearPathway(O, P) + >>> friction = CoulombKineticFriction(mu_k, m*g, pathway, v_s=v_s, sigma=sigma, mu_s=mu_k) + >>> spring = LinearSpring(k, pathway) + >>> block = Particle('block', point=P, mass=m) + + >>> kane = KanesMethod(N, (x,), (v,), kd_eqs=(x.diff() - v,)) + >>> friction.to_loads() + [(O, (g*m*mu_k*sign(sign(x(t))*Derivative(x(t), t)) + sigma*sign(x(t))*Derivative(x(t), t))*x(t)/Abs(x(t))*N.x), (P, (-g*m*mu_k*sign(sign(x(t))*Derivative(x(t), t)) - sigma*sign(x(t))*Derivative(x(t), t))*x(t)/Abs(x(t))*N.x)] + >>> loads = friction.to_loads() + spring.to_loads() + >>> fr, frstar = kane.kanes_equations([block], loads) + >>> eom = fr + frstar + >>> eom + Matrix([[-k*x(t) - m*Derivative(v(t), t) + (-g*m*mu_k*sign(v(t)*sign(x(t))) - sigma*v(t)*sign(x(t)))*x(t)/Abs(x(t))]]) + + Parameters + ========== + + f_n : sympifiable + The normal force between the surfaces. It should always be a non-negative scalar. + mu_k : sympifiable + The coefficient of kinetic friction. + pathway : PathwayBase + The pathway that the actuator follows. + v_s : sympifiable, optional + The Stribeck friction coefficient. + sigma : sympifiable, optional + The viscous friction coefficient. + mu_s : sympifiable, optional + The coefficient of static friction. Defaults to mu_k, meaning the Stribeck effect evaluates to 0 by default. + + References + ========== + + .. [Moore2022] https://moorepants.github.io/learn-multibody-dynamics/loads.html#friction. + .. [Flores2023] Paulo Flores, Jorge Ambrosio, Hamid M. Lankarani, + "Contact-impact events with friction in multibody dynamics: Back to basics", + Mechanism and Machine Theory, vol. 184, 2023. https://doi.org/10.1016/j.mechmachtheory.2023.105305. + .. [Rogner2017] I. Rogner, "Friction modelling for robotic applications with planar motion", + Chalmers University of Technology, Department of Electrical Engineering, 2017. + + """ + + def __init__(self, mu_k, f_n, pathway, *, v_s=None, sigma=None, mu_s=None): + self._mu_k = sympify(mu_k, strict=True) if mu_k is not None else 1 + self._mu_s = sympify(mu_s, strict=True) if mu_s is not None else self._mu_k + self._f_n = sympify(f_n, strict=True) + self._sigma = sympify(sigma, strict=True) if sigma is not None else 0 + self._v_s = sympify(v_s, strict=True) if v_s is not None or v_s == 0 else 0.01 + self.pathway = pathway + + @property + def mu_k(self): + """The coefficient of kinetic friction.""" + return self._mu_k + + @property + def mu_s(self): + """The coefficient of static friction.""" + return self._mu_s + + @property + def f_n(self): + """The normal force between the surfaces.""" + return self._f_n + + @property + def sigma(self): + """The viscous friction coefficient.""" + return self._sigma + + @property + def v_s(self): + """The Stribeck friction coefficient.""" + return self._v_s + + @property + def force(self): + v = self.pathway.extension_velocity + f_c = self.mu_k * self.f_n + f_max = self.mu_s * self.f_n + stribeck_term = (f_max - f_c) * exp(-(v / self.v_s)**2) if self.v_s is not None else 0 + viscous_term = self.sigma * v if self.sigma is not None else 0 + return (f_c + stribeck_term) * -sign(v) - viscous_term + + @force.setter + def force(self, force): + raise AttributeError('Can\'t set computed attribute `force`.') + + def __repr__(self): + return (f'{self.__class__.__name__}({self._mu_k}, {self._mu_s} ' + f'{self._f_n}, {self.pathway}, {self._v_s}, ' + f'{self._sigma})') diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/body.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/body.py new file mode 100644 index 0000000000000000000000000000000000000000..efc367158bbf51e7d9929318ac9286ba5c3fb3ac --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/body.py @@ -0,0 +1,710 @@ +from sympy import Symbol +from sympy.physics.vector import Point, Vector, ReferenceFrame, Dyadic +from sympy.physics.mechanics import RigidBody, Particle, Inertia +from sympy.physics.mechanics.body_base import BodyBase +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['Body'] + + +# XXX: We use type:ignore because the classes RigidBody and Particle have +# inconsistent parallel axis methods that take different numbers of arguments. +class Body(RigidBody, Particle): # type: ignore + """ + Body is a common representation of either a RigidBody or a Particle SymPy + object depending on what is passed in during initialization. If a mass is + passed in and central_inertia is left as None, the Particle object is + created. Otherwise a RigidBody object will be created. + + .. deprecated:: 1.13 + The Body class is deprecated. Its functionality is captured by + :class:`~.RigidBody` and :class:`~.Particle`. + + Explanation + =========== + + The attributes that Body possesses will be the same as a Particle instance + or a Rigid Body instance depending on which was created. Additional + attributes are listed below. + + Attributes + ========== + + name : string + The body's name + masscenter : Point + The point which represents the center of mass of the rigid body + frame : ReferenceFrame + The reference frame which the body is fixed in + mass : Sympifyable + The body's mass + inertia : (Dyadic, Point) + The body's inertia around its center of mass. This attribute is specific + to the rigid body form of Body and is left undefined for the Particle + form + loads : iterable + This list contains information on the different loads acting on the + Body. Forces are listed as a (point, vector) tuple and torques are + listed as (reference frame, vector) tuples. + + Parameters + ========== + + name : String + Defines the name of the body. It is used as the base for defining + body specific properties. + masscenter : Point, optional + A point that represents the center of mass of the body or particle. + If no point is given, a point is generated. + mass : Sympifyable, optional + A Sympifyable object which represents the mass of the body. If no + mass is passed, one is generated. + frame : ReferenceFrame, optional + The ReferenceFrame that represents the reference frame of the body. + If no frame is given, a frame is generated. + central_inertia : Dyadic, optional + Central inertia dyadic of the body. If none is passed while creating + RigidBody, a default inertia is generated. + + Examples + ======== + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + + Default behaviour. This results in the creation of a RigidBody object for + which the mass, mass center, frame and inertia attributes are given default + values. :: + + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... body = Body('name_of_body') + + This next example demonstrates the code required to specify all of the + values of the Body object. Note this will also create a RigidBody version of + the Body object. :: + + >>> from sympy import Symbol + >>> from sympy.physics.mechanics import ReferenceFrame, Point, inertia + >>> from sympy.physics.mechanics import Body + >>> mass = Symbol('mass') + >>> masscenter = Point('masscenter') + >>> frame = ReferenceFrame('frame') + >>> ixx = Symbol('ixx') + >>> body_inertia = inertia(frame, ixx, 0, 0) + >>> with ignore_warnings(DeprecationWarning): + ... body = Body('name_of_body', masscenter, mass, frame, body_inertia) + + The minimal code required to create a Particle version of the Body object + involves simply passing in a name and a mass. :: + + >>> from sympy import Symbol + >>> from sympy.physics.mechanics import Body + >>> mass = Symbol('mass') + >>> with ignore_warnings(DeprecationWarning): + ... body = Body('name_of_body', mass=mass) + + The Particle version of the Body object can also receive a masscenter point + and a reference frame, just not an inertia. + """ + + def __init__(self, name, masscenter=None, mass=None, frame=None, + central_inertia=None): + sympy_deprecation_warning( + """ + Support for the Body class has been removed, as its functionality is + fully captured by RigidBody and Particle. + """, + deprecated_since_version="1.13", + active_deprecations_target="deprecated-mechanics-body-class" + ) + + self._loads = [] + + if frame is None: + frame = ReferenceFrame(name + '_frame') + + if masscenter is None: + masscenter = Point(name + '_masscenter') + + if central_inertia is None and mass is None: + ixx = Symbol(name + '_ixx') + iyy = Symbol(name + '_iyy') + izz = Symbol(name + '_izz') + izx = Symbol(name + '_izx') + ixy = Symbol(name + '_ixy') + iyz = Symbol(name + '_iyz') + _inertia = Inertia.from_inertia_scalars(masscenter, frame, ixx, iyy, + izz, ixy, iyz, izx) + else: + _inertia = (central_inertia, masscenter) + + if mass is None: + _mass = Symbol(name + '_mass') + else: + _mass = mass + + masscenter.set_vel(frame, 0) + + # If user passes masscenter and mass then a particle is created + # otherwise a rigidbody. As a result a body may or may not have inertia. + # Note: BodyBase.__init__ is used to prevent problems with super() calls in + # Particle and RigidBody arising due to multiple inheritance. + if central_inertia is None and mass is not None: + BodyBase.__init__(self, name, masscenter, _mass) + self.frame = frame + self._central_inertia = Dyadic(0) + else: + BodyBase.__init__(self, name, masscenter, _mass) + self.frame = frame + self.inertia = _inertia + + def __repr__(self): + if self.is_rigidbody: + return RigidBody.__repr__(self) + return Particle.__repr__(self) + + @property + def loads(self): + return self._loads + + @property + def x(self): + """The basis Vector for the Body, in the x direction.""" + return self.frame.x + + @property + def y(self): + """The basis Vector for the Body, in the y direction.""" + return self.frame.y + + @property + def z(self): + """The basis Vector for the Body, in the z direction.""" + return self.frame.z + + @property + def inertia(self): + """The body's inertia about a point; stored as (Dyadic, Point).""" + if self.is_rigidbody: + return RigidBody.inertia.fget(self) + return (self.central_inertia, self.masscenter) + + @inertia.setter + def inertia(self, I): + RigidBody.inertia.fset(self, I) + + @property + def is_rigidbody(self): + if hasattr(self, '_inertia'): + return True + return False + + def kinetic_energy(self, frame): + """Kinetic energy of the body. + + Parameters + ========== + + frame : ReferenceFrame or Body + The Body's angular velocity and the velocity of it's mass + center are typically defined with respect to an inertial frame but + any relevant frame in which the velocities are known can be supplied. + + Examples + ======== + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body, ReferenceFrame, Point + >>> from sympy import symbols + >>> m, v, r, omega = symbols('m v r omega') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> with ignore_warnings(DeprecationWarning): + ... P = Body('P', masscenter=O, mass=m) + >>> P.masscenter.set_vel(N, v * N.y) + >>> P.kinetic_energy(N) + m*v**2/2 + + >>> N = ReferenceFrame('N') + >>> b = ReferenceFrame('b') + >>> b.set_ang_vel(N, omega * b.x) + >>> P = Point('P') + >>> P.set_vel(N, v * N.x) + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B', masscenter=P, frame=b) + >>> B.kinetic_energy(N) + B_ixx*omega**2/2 + B_mass*v**2/2 + + See Also + ======== + + sympy.physics.mechanics : Particle, RigidBody + + """ + if isinstance(frame, Body): + frame = Body.frame + if self.is_rigidbody: + return RigidBody(self.name, self.masscenter, self.frame, self.mass, + (self.central_inertia, self.masscenter)).kinetic_energy(frame) + return Particle(self.name, self.masscenter, self.mass).kinetic_energy(frame) + + def apply_force(self, force, point=None, reaction_body=None, reaction_point=None): + """Add force to the body(s). + + Explanation + =========== + + Applies the force on self or equal and opposite forces on + self and other body if both are given on the desired point on the bodies. + The force applied on other body is taken opposite of self, i.e, -force. + + Parameters + ========== + + force: Vector + The force to be applied. + point: Point, optional + The point on self on which force is applied. + By default self's masscenter. + reaction_body: Body, optional + Second body on which equal and opposite force + is to be applied. + reaction_point : Point, optional + The point on other body on which equal and opposite + force is applied. By default masscenter of other body. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Body, Point, dynamicsymbols + >>> m, g = symbols('m g') + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B') + >>> force1 = m*g*B.z + >>> B.apply_force(force1) #Applying force on B's masscenter + >>> B.loads + [(B_masscenter, g*m*B_frame.z)] + + We can also remove some part of force from any point on the body by + adding the opposite force to the body on that point. + + >>> f1, f2 = dynamicsymbols('f1 f2') + >>> P = Point('P') #Considering point P on body B + >>> B.apply_force(f1*B.x + f2*B.y, P) + >>> B.loads + [(B_masscenter, g*m*B_frame.z), (P, f1(t)*B_frame.x + f2(t)*B_frame.y)] + + Let's remove f1 from point P on body B. + + >>> B.apply_force(-f1*B.x, P) + >>> B.loads + [(B_masscenter, g*m*B_frame.z), (P, f2(t)*B_frame.y)] + + To further demonstrate the use of ``apply_force`` attribute, + consider two bodies connected through a spring. + + >>> from sympy.physics.mechanics import Body, dynamicsymbols + >>> with ignore_warnings(DeprecationWarning): + ... N = Body('N') #Newtonion Frame + >>> x = dynamicsymbols('x') + >>> with ignore_warnings(DeprecationWarning): + ... B1 = Body('B1') + ... B2 = Body('B2') + >>> spring_force = x*N.x + + Now let's apply equal and opposite spring force to the bodies. + + >>> P1 = Point('P1') + >>> P2 = Point('P2') + >>> B1.apply_force(spring_force, point=P1, reaction_body=B2, reaction_point=P2) + + We can check the loads(forces) applied to bodies now. + + >>> B1.loads + [(P1, x(t)*N_frame.x)] + >>> B2.loads + [(P2, - x(t)*N_frame.x)] + + Notes + ===== + + If a new force is applied to a body on a point which already has some + force applied on it, then the new force is added to the already applied + force on that point. + + """ + + if not isinstance(point, Point): + if point is None: + point = self.masscenter # masscenter + else: + raise TypeError("Force must be applied to a point on the body.") + if not isinstance(force, Vector): + raise TypeError("Force must be a vector.") + + if reaction_body is not None: + reaction_body.apply_force(-force, point=reaction_point) + + for load in self._loads: + if point in load: + force += load[1] + self._loads.remove(load) + break + + self._loads.append((point, force)) + + def apply_torque(self, torque, reaction_body=None): + """Add torque to the body(s). + + Explanation + =========== + + Applies the torque on self or equal and opposite torques on + self and other body if both are given. + The torque applied on other body is taken opposite of self, + i.e, -torque. + + Parameters + ========== + + torque: Vector + The torque to be applied. + reaction_body: Body, optional + Second body on which equal and opposite torque + is to be applied. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Body, dynamicsymbols + >>> t = symbols('t') + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B') + >>> torque1 = t*B.z + >>> B.apply_torque(torque1) + >>> B.loads + [(B_frame, t*B_frame.z)] + + We can also remove some part of torque from the body by + adding the opposite torque to the body. + + >>> t1, t2 = dynamicsymbols('t1 t2') + >>> B.apply_torque(t1*B.x + t2*B.y) + >>> B.loads + [(B_frame, t1(t)*B_frame.x + t2(t)*B_frame.y + t*B_frame.z)] + + Let's remove t1 from Body B. + + >>> B.apply_torque(-t1*B.x) + >>> B.loads + [(B_frame, t2(t)*B_frame.y + t*B_frame.z)] + + To further demonstrate the use, let us consider two bodies such that + a torque `T` is acting on one body, and `-T` on the other. + + >>> from sympy.physics.mechanics import Body, dynamicsymbols + >>> with ignore_warnings(DeprecationWarning): + ... N = Body('N') #Newtonion frame + ... B1 = Body('B1') + ... B2 = Body('B2') + >>> v = dynamicsymbols('v') + >>> T = v*N.y #Torque + + Now let's apply equal and opposite torque to the bodies. + + >>> B1.apply_torque(T, B2) + + We can check the loads (torques) applied to bodies now. + + >>> B1.loads + [(B1_frame, v(t)*N_frame.y)] + >>> B2.loads + [(B2_frame, - v(t)*N_frame.y)] + + Notes + ===== + + If a new torque is applied on body which already has some torque applied on it, + then the new torque is added to the previous torque about the body's frame. + + """ + + if not isinstance(torque, Vector): + raise TypeError("A Vector must be supplied to add torque.") + + if reaction_body is not None: + reaction_body.apply_torque(-torque) + + for load in self._loads: + if self.frame in load: + torque += load[1] + self._loads.remove(load) + break + self._loads.append((self.frame, torque)) + + def clear_loads(self): + """ + Clears the Body's loads list. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B') + >>> force = B.x + B.y + >>> B.apply_force(force) + >>> B.loads + [(B_masscenter, B_frame.x + B_frame.y)] + >>> B.clear_loads() + >>> B.loads + [] + + """ + + self._loads = [] + + def remove_load(self, about=None): + """ + Remove load about a point or frame. + + Parameters + ========== + + about : Point or ReferenceFrame, optional + The point about which force is applied, + and is to be removed. + If about is None, then the torque about + self's frame is removed. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body, Point + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B') + >>> P = Point('P') + >>> f1 = B.x + >>> f2 = B.y + >>> B.apply_force(f1) + >>> B.apply_force(f2, P) + >>> B.loads + [(B_masscenter, B_frame.x), (P, B_frame.y)] + + >>> B.remove_load(P) + >>> B.loads + [(B_masscenter, B_frame.x)] + + """ + + if about is not None: + if not isinstance(about, Point): + raise TypeError('Load is applied about Point or ReferenceFrame.') + else: + about = self.frame + + for load in self._loads: + if about in load: + self._loads.remove(load) + break + + def masscenter_vel(self, body): + """ + Returns the velocity of the mass center with respect to the provided + rigid body or reference frame. + + Parameters + ========== + + body: Body or ReferenceFrame + The rigid body or reference frame to calculate the velocity in. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... A = Body('A') + ... B = Body('B') + >>> A.masscenter.set_vel(B.frame, 5*B.frame.x) + >>> A.masscenter_vel(B) + 5*B_frame.x + >>> A.masscenter_vel(B.frame) + 5*B_frame.x + + """ + + if isinstance(body, ReferenceFrame): + frame=body + elif isinstance(body, Body): + frame = body.frame + return self.masscenter.vel(frame) + + def ang_vel_in(self, body): + """ + Returns this body's angular velocity with respect to the provided + rigid body or reference frame. + + Parameters + ========== + + body: Body or ReferenceFrame + The rigid body or reference frame to calculate the angular velocity in. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body, ReferenceFrame + >>> with ignore_warnings(DeprecationWarning): + ... A = Body('A') + >>> N = ReferenceFrame('N') + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B', frame=N) + >>> A.frame.set_ang_vel(N, 5*N.x) + >>> A.ang_vel_in(B) + 5*N.x + >>> A.ang_vel_in(N) + 5*N.x + + """ + + if isinstance(body, ReferenceFrame): + frame=body + elif isinstance(body, Body): + frame = body.frame + return self.frame.ang_vel_in(frame) + + def dcm(self, body): + """ + Returns the direction cosine matrix of this body relative to the + provided rigid body or reference frame. + + Parameters + ========== + + body: Body or ReferenceFrame + The rigid body or reference frame to calculate the dcm. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... A = Body('A') + ... B = Body('B') + >>> A.frame.orient_axis(B.frame, B.frame.x, 5) + >>> A.dcm(B) + Matrix([ + [1, 0, 0], + [0, cos(5), sin(5)], + [0, -sin(5), cos(5)]]) + >>> A.dcm(B.frame) + Matrix([ + [1, 0, 0], + [0, cos(5), sin(5)], + [0, -sin(5), cos(5)]]) + + """ + + if isinstance(body, ReferenceFrame): + frame=body + elif isinstance(body, Body): + frame = body.frame + return self.frame.dcm(frame) + + def parallel_axis(self, point, frame=None): + """Returns the inertia dyadic of the body with respect to another + point. + + Parameters + ========== + + point : sympy.physics.vector.Point + The point to express the inertia dyadic about. + frame : sympy.physics.vector.ReferenceFrame + The reference frame used to construct the dyadic. + + Returns + ======= + + inertia : sympy.physics.vector.Dyadic + The inertia dyadic of the rigid body expressed about the provided + point. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... A = Body('A') + >>> P = A.masscenter.locatenew('point', 3 * A.x + 5 * A.y) + >>> A.parallel_axis(P).to_matrix(A.frame) + Matrix([ + [A_ixx + 25*A_mass, A_ixy - 15*A_mass, A_izx], + [A_ixy - 15*A_mass, A_iyy + 9*A_mass, A_iyz], + [ A_izx, A_iyz, A_izz + 34*A_mass]]) + + """ + if self.is_rigidbody: + return RigidBody.parallel_axis(self, point, frame) + return Particle.parallel_axis(self, point, frame) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/body_base.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/body_base.py new file mode 100644 index 0000000000000000000000000000000000000000..d2546faf685f579d2aea10ed7f139a4beced7dd0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/body_base.py @@ -0,0 +1,94 @@ +from abc import ABC, abstractmethod +from sympy import Symbol, sympify +from sympy.physics.vector import Point + +__all__ = ['BodyBase'] + + +class BodyBase(ABC): + """Abstract class for body type objects.""" + def __init__(self, name, masscenter=None, mass=None): + # Note: If frame=None, no auto-generated frame is created, because a + # Particle does not need to have a frame by default. + if not isinstance(name, str): + raise TypeError('Supply a valid name.') + self._name = name + if mass is None: + mass = Symbol(f'{name}_mass') + if masscenter is None: + masscenter = Point(f'{name}_masscenter') + self.mass = mass + self.masscenter = masscenter + self.potential_energy = 0 + self.points = [] + + def __str__(self): + return self.name + + def __repr__(self): + return (f'{self.__class__.__name__}({repr(self.name)}, masscenter=' + f'{repr(self.masscenter)}, mass={repr(self.mass)})') + + @property + def name(self): + """The name of the body.""" + return self._name + + @property + def masscenter(self): + """The body's center of mass.""" + return self._masscenter + + @masscenter.setter + def masscenter(self, point): + if not isinstance(point, Point): + raise TypeError("The body's center of mass must be a Point object.") + self._masscenter = point + + @property + def mass(self): + """The body's mass.""" + return self._mass + + @mass.setter + def mass(self, mass): + self._mass = sympify(mass) + + @property + def potential_energy(self): + """The potential energy of the body. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point + >>> from sympy import symbols + >>> m, g, h = symbols('m g h') + >>> O = Point('O') + >>> P = Particle('P', O, m) + >>> P.potential_energy = m * g * h + >>> P.potential_energy + g*h*m + + """ + return self._potential_energy + + @potential_energy.setter + def potential_energy(self, scalar): + self._potential_energy = sympify(scalar) + + @abstractmethod + def kinetic_energy(self, frame): + pass + + @abstractmethod + def linear_momentum(self, frame): + pass + + @abstractmethod + def angular_momentum(self, point, frame): + pass + + @abstractmethod + def parallel_axis(self, point, frame): + pass diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/functions.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..42abe2b7fe608b4602cdab518f209b446b2dbe03 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/functions.py @@ -0,0 +1,735 @@ +from sympy.utilities import dict_merge +from sympy.utilities.iterables import iterable +from sympy.physics.vector import (Dyadic, Vector, ReferenceFrame, + Point, dynamicsymbols) +from sympy.physics.vector.printing import (vprint, vsprint, vpprint, vlatex, + init_vprinting) +from sympy.physics.mechanics.particle import Particle +from sympy.physics.mechanics.rigidbody import RigidBody +from sympy.simplify.simplify import simplify +from sympy import Matrix, Mul, Derivative, sin, cos, tan, S +from sympy.core.function import AppliedUndef +from sympy.physics.mechanics.inertia import (inertia as _inertia, + inertia_of_point_mass as _inertia_of_point_mass) +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['linear_momentum', + 'angular_momentum', + 'kinetic_energy', + 'potential_energy', + 'Lagrangian', + 'mechanics_printing', + 'mprint', + 'msprint', + 'mpprint', + 'mlatex', + 'msubs', + 'find_dynamicsymbols'] + +# These are functions that we've moved and renamed during extracting the +# basic vector calculus code from the mechanics packages. + +mprint = vprint +msprint = vsprint +mpprint = vpprint +mlatex = vlatex + + +def mechanics_printing(**kwargs): + """ + Initializes time derivative printing for all SymPy objects in + mechanics module. + """ + + init_vprinting(**kwargs) + +mechanics_printing.__doc__ = init_vprinting.__doc__ + + +def inertia(frame, ixx, iyy, izz, ixy=0, iyz=0, izx=0): + sympy_deprecation_warning( + """ + The inertia function has been moved. + Import it from "sympy.physics.mechanics". + """, + deprecated_since_version="1.13", + active_deprecations_target="moved-mechanics-functions" + ) + return _inertia(frame, ixx, iyy, izz, ixy, iyz, izx) + + +def inertia_of_point_mass(mass, pos_vec, frame): + sympy_deprecation_warning( + """ + The inertia_of_point_mass function has been moved. + Import it from "sympy.physics.mechanics". + """, + deprecated_since_version="1.13", + active_deprecations_target="moved-mechanics-functions" + ) + return _inertia_of_point_mass(mass, pos_vec, frame) + + +def linear_momentum(frame, *body): + """Linear momentum of the system. + + Explanation + =========== + + This function returns the linear momentum of a system of Particle's and/or + RigidBody's. The linear momentum of a system is equal to the vector sum of + the linear momentum of its constituents. Consider a system, S, comprised of + a rigid body, A, and a particle, P. The linear momentum of the system, L, + is equal to the vector sum of the linear momentum of the particle, L1, and + the linear momentum of the rigid body, L2, i.e. + + L = L1 + L2 + + Parameters + ========== + + frame : ReferenceFrame + The frame in which linear momentum is desired. + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose linear momentum is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, linear_momentum + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, 10 * N.x) + >>> Pa = Particle('Pa', P, 1) + >>> Ac = Point('Ac') + >>> Ac.set_vel(N, 25 * N.y) + >>> I = outer(N.x, N.x) + >>> A = RigidBody('A', Ac, N, 20, (I, Ac)) + >>> linear_momentum(N, A, Pa) + 10*N.x + 500*N.y + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Please specify a valid ReferenceFrame') + else: + linear_momentum_sys = Vector(0) + for e in body: + if isinstance(e, (RigidBody, Particle)): + linear_momentum_sys += e.linear_momentum(frame) + else: + raise TypeError('*body must have only Particle or RigidBody') + return linear_momentum_sys + + +def angular_momentum(point, frame, *body): + """Angular momentum of a system. + + Explanation + =========== + + This function returns the angular momentum of a system of Particle's and/or + RigidBody's. The angular momentum of such a system is equal to the vector + sum of the angular momentum of its constituents. Consider a system, S, + comprised of a rigid body, A, and a particle, P. The angular momentum of + the system, H, is equal to the vector sum of the angular momentum of the + particle, H1, and the angular momentum of the rigid body, H2, i.e. + + H = H1 + H2 + + Parameters + ========== + + point : Point + The point about which angular momentum of the system is desired. + frame : ReferenceFrame + The frame in which angular momentum is desired. + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose angular momentum is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, angular_momentum + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> O.set_vel(N, 0 * N.x) + >>> P = O.locatenew('P', 1 * N.x) + >>> P.set_vel(N, 10 * N.x) + >>> Pa = Particle('Pa', P, 1) + >>> Ac = O.locatenew('Ac', 2 * N.y) + >>> Ac.set_vel(N, 5 * N.y) + >>> a = ReferenceFrame('a') + >>> a.set_ang_vel(N, 10 * N.z) + >>> I = outer(N.z, N.z) + >>> A = RigidBody('A', Ac, a, 20, (I, Ac)) + >>> angular_momentum(O, N, Pa, A) + 10*N.z + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Please enter a valid ReferenceFrame') + if not isinstance(point, Point): + raise TypeError('Please specify a valid Point') + else: + angular_momentum_sys = Vector(0) + for e in body: + if isinstance(e, (RigidBody, Particle)): + angular_momentum_sys += e.angular_momentum(point, frame) + else: + raise TypeError('*body must have only Particle or RigidBody') + return angular_momentum_sys + + +def kinetic_energy(frame, *body): + """Kinetic energy of a multibody system. + + Explanation + =========== + + This function returns the kinetic energy of a system of Particle's and/or + RigidBody's. The kinetic energy of such a system is equal to the sum of + the kinetic energies of its constituents. Consider a system, S, comprising + a rigid body, A, and a particle, P. The kinetic energy of the system, T, + is equal to the vector sum of the kinetic energy of the particle, T1, and + the kinetic energy of the rigid body, T2, i.e. + + T = T1 + T2 + + Kinetic energy is a scalar. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which the velocity or angular velocity of the body is + defined. + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose kinetic energy is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, kinetic_energy + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> O.set_vel(N, 0 * N.x) + >>> P = O.locatenew('P', 1 * N.x) + >>> P.set_vel(N, 10 * N.x) + >>> Pa = Particle('Pa', P, 1) + >>> Ac = O.locatenew('Ac', 2 * N.y) + >>> Ac.set_vel(N, 5 * N.y) + >>> a = ReferenceFrame('a') + >>> a.set_ang_vel(N, 10 * N.z) + >>> I = outer(N.z, N.z) + >>> A = RigidBody('A', Ac, a, 20, (I, Ac)) + >>> kinetic_energy(N, Pa, A) + 350 + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Please enter a valid ReferenceFrame') + ke_sys = S.Zero + for e in body: + if isinstance(e, (RigidBody, Particle)): + ke_sys += e.kinetic_energy(frame) + else: + raise TypeError('*body must have only Particle or RigidBody') + return ke_sys + + +def potential_energy(*body): + """Potential energy of a multibody system. + + Explanation + =========== + + This function returns the potential energy of a system of Particle's and/or + RigidBody's. The potential energy of such a system is equal to the sum of + the potential energy of its constituents. Consider a system, S, comprising + a rigid body, A, and a particle, P. The potential energy of the system, V, + is equal to the vector sum of the potential energy of the particle, V1, and + the potential energy of the rigid body, V2, i.e. + + V = V1 + V2 + + Potential energy is a scalar. + + Parameters + ========== + + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose potential energy is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, potential_energy + >>> from sympy import symbols + >>> M, m, g, h = symbols('M m g h') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> O.set_vel(N, 0 * N.x) + >>> P = O.locatenew('P', 1 * N.x) + >>> Pa = Particle('Pa', P, m) + >>> Ac = O.locatenew('Ac', 2 * N.y) + >>> a = ReferenceFrame('a') + >>> I = outer(N.z, N.z) + >>> A = RigidBody('A', Ac, a, M, (I, Ac)) + >>> Pa.potential_energy = m * g * h + >>> A.potential_energy = M * g * h + >>> potential_energy(Pa, A) + M*g*h + g*h*m + + """ + + pe_sys = S.Zero + for e in body: + if isinstance(e, (RigidBody, Particle)): + pe_sys += e.potential_energy + else: + raise TypeError('*body must have only Particle or RigidBody') + return pe_sys + + +def gravity(acceleration, *bodies): + from sympy.physics.mechanics.loads import gravity as _gravity + sympy_deprecation_warning( + """ + The gravity function has been moved. + Import it from "sympy.physics.mechanics.loads". + """, + deprecated_since_version="1.13", + active_deprecations_target="moved-mechanics-functions" + ) + return _gravity(acceleration, *bodies) + + +def center_of_mass(point, *bodies): + """ + Returns the position vector from the given point to the center of mass + of the given bodies(particles or rigidbodies). + + Example + ======= + + >>> from sympy import symbols, S + >>> from sympy.physics.vector import Point + >>> from sympy.physics.mechanics import Particle, ReferenceFrame, RigidBody, outer + >>> from sympy.physics.mechanics.functions import center_of_mass + >>> a = ReferenceFrame('a') + >>> m = symbols('m', real=True) + >>> p1 = Particle('p1', Point('p1_pt'), S(1)) + >>> p2 = Particle('p2', Point('p2_pt'), S(2)) + >>> p3 = Particle('p3', Point('p3_pt'), S(3)) + >>> p4 = Particle('p4', Point('p4_pt'), m) + >>> b_f = ReferenceFrame('b_f') + >>> b_cm = Point('b_cm') + >>> mb = symbols('mb') + >>> b = RigidBody('b', b_cm, b_f, mb, (outer(b_f.x, b_f.x), b_cm)) + >>> p2.point.set_pos(p1.point, a.x) + >>> p3.point.set_pos(p1.point, a.x + a.y) + >>> p4.point.set_pos(p1.point, a.y) + >>> b.masscenter.set_pos(p1.point, a.y + a.z) + >>> point_o=Point('o') + >>> point_o.set_pos(p1.point, center_of_mass(p1.point, p1, p2, p3, p4, b)) + >>> expr = 5/(m + mb + 6)*a.x + (m + mb + 3)/(m + mb + 6)*a.y + mb/(m + mb + 6)*a.z + >>> point_o.pos_from(p1.point) + 5/(m + mb + 6)*a.x + (m + mb + 3)/(m + mb + 6)*a.y + mb/(m + mb + 6)*a.z + + """ + if not bodies: + raise TypeError("No bodies(instances of Particle or Rigidbody) were passed.") + + total_mass = 0 + vec = Vector(0) + for i in bodies: + total_mass += i.mass + + masscenter = getattr(i, 'masscenter', None) + if masscenter is None: + masscenter = i.point + vec += i.mass*masscenter.pos_from(point) + + return vec/total_mass + + +def Lagrangian(frame, *body): + """Lagrangian of a multibody system. + + Explanation + =========== + + This function returns the Lagrangian of a system of Particle's and/or + RigidBody's. The Lagrangian of such a system is equal to the difference + between the kinetic energies and potential energies of its constituents. If + T and V are the kinetic and potential energies of a system then it's + Lagrangian, L, is defined as + + L = T - V + + The Lagrangian is a scalar. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which the velocity or angular velocity of the body is + defined to determine the kinetic energy. + + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose Lagrangian is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, Lagrangian + >>> from sympy import symbols + >>> M, m, g, h = symbols('M m g h') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> O.set_vel(N, 0 * N.x) + >>> P = O.locatenew('P', 1 * N.x) + >>> P.set_vel(N, 10 * N.x) + >>> Pa = Particle('Pa', P, 1) + >>> Ac = O.locatenew('Ac', 2 * N.y) + >>> Ac.set_vel(N, 5 * N.y) + >>> a = ReferenceFrame('a') + >>> a.set_ang_vel(N, 10 * N.z) + >>> I = outer(N.z, N.z) + >>> A = RigidBody('A', Ac, a, 20, (I, Ac)) + >>> Pa.potential_energy = m * g * h + >>> A.potential_energy = M * g * h + >>> Lagrangian(N, Pa, A) + -M*g*h - g*h*m + 350 + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Please supply a valid ReferenceFrame') + for e in body: + if not isinstance(e, (RigidBody, Particle)): + raise TypeError('*body must have only Particle or RigidBody') + return kinetic_energy(frame, *body) - potential_energy(*body) + + +def find_dynamicsymbols(expression, exclude=None, reference_frame=None): + """Find all dynamicsymbols in expression. + + Explanation + =========== + + If the optional ``exclude`` kwarg is used, only dynamicsymbols + not in the iterable ``exclude`` are returned. + If we intend to apply this function on a vector, the optional + ``reference_frame`` is also used to inform about the corresponding frame + with respect to which the dynamic symbols of the given vector is to be + determined. + + Parameters + ========== + + expression : SymPy expression + + exclude : iterable of dynamicsymbols, optional + + reference_frame : ReferenceFrame, optional + The frame with respect to which the dynamic symbols of the + given vector is to be determined. + + Examples + ======== + + >>> from sympy.physics.mechanics import dynamicsymbols, find_dynamicsymbols + >>> from sympy.physics.mechanics import ReferenceFrame + >>> x, y = dynamicsymbols('x, y') + >>> expr = x + x.diff()*y + >>> find_dynamicsymbols(expr) + {x(t), y(t), Derivative(x(t), t)} + >>> find_dynamicsymbols(expr, exclude=[x, y]) + {Derivative(x(t), t)} + >>> a, b, c = dynamicsymbols('a, b, c') + >>> A = ReferenceFrame('A') + >>> v = a * A.x + b * A.y + c * A.z + >>> find_dynamicsymbols(v, reference_frame=A) + {a(t), b(t), c(t)} + + """ + t_set = {dynamicsymbols._t} + if exclude: + if iterable(exclude): + exclude_set = set(exclude) + else: + raise TypeError("exclude kwarg must be iterable") + else: + exclude_set = set() + if isinstance(expression, Vector): + if reference_frame is None: + raise ValueError("You must provide reference_frame when passing a " + "vector expression, got %s." % reference_frame) + else: + expression = expression.to_matrix(reference_frame) + return {i for i in expression.atoms(AppliedUndef, Derivative) if + i.free_symbols == t_set} - exclude_set + + +def msubs(expr, *sub_dicts, smart=False, **kwargs): + """A custom subs for use on expressions derived in physics.mechanics. + + Traverses the expression tree once, performing the subs found in sub_dicts. + Terms inside ``Derivative`` expressions are ignored: + + Examples + ======== + + >>> from sympy.physics.mechanics import dynamicsymbols, msubs + >>> x = dynamicsymbols('x') + >>> msubs(x.diff() + x, {x: 1}) + Derivative(x(t), t) + 1 + + Note that sub_dicts can be a single dictionary, or several dictionaries: + + >>> x, y, z = dynamicsymbols('x, y, z') + >>> sub1 = {x: 1, y: 2} + >>> sub2 = {z: 3, x.diff(): 4} + >>> msubs(x.diff() + x + y + z, sub1, sub2) + 10 + + If smart=True (default False), also checks for conditions that may result + in ``nan``, but if simplified would yield a valid expression. For example: + + >>> from sympy import sin, tan + >>> (sin(x)/tan(x)).subs(x, 0) + nan + >>> msubs(sin(x)/tan(x), {x: 0}, smart=True) + 1 + + It does this by first replacing all ``tan`` with ``sin/cos``. Then each + node is traversed. If the node is a fraction, subs is first evaluated on + the denominator. If this results in 0, simplification of the entire + fraction is attempted. Using this selective simplification, only + subexpressions that result in 1/0 are targeted, resulting in faster + performance. + + """ + + sub_dict = dict_merge(*sub_dicts) + if smart: + func = _smart_subs + elif hasattr(expr, 'msubs'): + return expr.msubs(sub_dict) + else: + func = lambda expr, sub_dict: _crawl(expr, _sub_func, sub_dict) + if isinstance(expr, (Matrix, Vector, Dyadic)): + return expr.applyfunc(lambda x: func(x, sub_dict)) + else: + return func(expr, sub_dict) + + +def _crawl(expr, func, *args, **kwargs): + """Crawl the expression tree, and apply func to every node.""" + val = func(expr, *args, **kwargs) + if val is not None: + return val + new_args = (_crawl(arg, func, *args, **kwargs) for arg in expr.args) + return expr.func(*new_args) + + +def _sub_func(expr, sub_dict): + """Perform direct matching substitution, ignoring derivatives.""" + if expr in sub_dict: + return sub_dict[expr] + elif not expr.args or expr.is_Derivative: + return expr + + +def _tan_repl_func(expr): + """Replace tan with sin/cos.""" + if isinstance(expr, tan): + return sin(*expr.args) / cos(*expr.args) + elif not expr.args or expr.is_Derivative: + return expr + + +def _smart_subs(expr, sub_dict): + """Performs subs, checking for conditions that may result in `nan` or + `oo`, and attempts to simplify them out. + + The expression tree is traversed twice, and the following steps are + performed on each expression node: + - First traverse: + Replace all `tan` with `sin/cos`. + - Second traverse: + If node is a fraction, check if the denominator evaluates to 0. + If so, attempt to simplify it out. Then if node is in sub_dict, + sub in the corresponding value. + + """ + expr = _crawl(expr, _tan_repl_func) + + def _recurser(expr, sub_dict): + # Decompose the expression into num, den + num, den = _fraction_decomp(expr) + if den != 1: + # If there is a non trivial denominator, we need to handle it + denom_subbed = _recurser(den, sub_dict) + if denom_subbed.evalf() == 0: + # If denom is 0 after this, attempt to simplify the bad expr + expr = simplify(expr) + else: + # Expression won't result in nan, find numerator + num_subbed = _recurser(num, sub_dict) + return num_subbed / denom_subbed + # We have to crawl the tree manually, because `expr` may have been + # modified in the simplify step. First, perform subs as normal: + val = _sub_func(expr, sub_dict) + if val is not None: + return val + new_args = (_recurser(arg, sub_dict) for arg in expr.args) + return expr.func(*new_args) + return _recurser(expr, sub_dict) + + +def _fraction_decomp(expr): + """Return num, den such that expr = num/den.""" + if not isinstance(expr, Mul): + return expr, 1 + num = [] + den = [] + for a in expr.args: + if a.is_Pow and a.args[1] < 0: + den.append(1 / a) + else: + num.append(a) + if not den: + return expr, 1 + num = Mul(*num) + den = Mul(*den) + return num, den + + +def _f_list_parser(fl, ref_frame): + """Parses the provided forcelist composed of items + of the form (obj, force). + Returns a tuple containing: + vel_list: The velocity (ang_vel for Frames, vel for Points) in + the provided reference frame. + f_list: The forces. + + Used internally in the KanesMethod and LagrangesMethod classes. + + """ + def flist_iter(): + for pair in fl: + obj, force = pair + if isinstance(obj, ReferenceFrame): + yield obj.ang_vel_in(ref_frame), force + elif isinstance(obj, Point): + yield obj.vel(ref_frame), force + else: + raise TypeError('First entry in each forcelist pair must ' + 'be a point or frame.') + + if not fl: + vel_list, f_list = (), () + else: + unzip = lambda l: list(zip(*l)) if l[0] else [(), ()] + vel_list, f_list = unzip(list(flist_iter())) + return vel_list, f_list + + +def _validate_coordinates(coordinates=None, speeds=None, check_duplicates=True, + is_dynamicsymbols=True, u_auxiliary=None): + """Validate the generalized coordinates and generalized speeds. + + Parameters + ========== + coordinates : iterable, optional + Generalized coordinates to be validated. + speeds : iterable, optional + Generalized speeds to be validated. + check_duplicates : bool, optional + Checks if there are duplicates in the generalized coordinates and + generalized speeds. If so it will raise a ValueError. The default is + True. + is_dynamicsymbols : iterable, optional + Checks if all the generalized coordinates and generalized speeds are + dynamicsymbols. If any is not a dynamicsymbol, a ValueError will be + raised. The default is True. + u_auxiliary : iterable, optional + Auxiliary generalized speeds to be validated. + + """ + t_set = {dynamicsymbols._t} + # Convert input to iterables + if coordinates is None: + coordinates = [] + elif not iterable(coordinates): + coordinates = [coordinates] + if speeds is None: + speeds = [] + elif not iterable(speeds): + speeds = [speeds] + if u_auxiliary is None: + u_auxiliary = [] + elif not iterable(u_auxiliary): + u_auxiliary = [u_auxiliary] + + msgs = [] + if check_duplicates: # Check for duplicates + seen = set() + coord_duplicates = {x for x in coordinates if x in seen or seen.add(x)} + seen = set() + speed_duplicates = {x for x in speeds if x in seen or seen.add(x)} + seen = set() + aux_duplicates = {x for x in u_auxiliary if x in seen or seen.add(x)} + overlap_coords = set(coordinates).intersection(speeds) + overlap_aux = set(coordinates).union(speeds).intersection(u_auxiliary) + if coord_duplicates: + msgs.append(f'The generalized coordinates {coord_duplicates} are ' + f'duplicated, all generalized coordinates should be ' + f'unique.') + if speed_duplicates: + msgs.append(f'The generalized speeds {speed_duplicates} are ' + f'duplicated, all generalized speeds should be unique.') + if aux_duplicates: + msgs.append(f'The auxiliary speeds {aux_duplicates} are duplicated,' + f' all auxiliary speeds should be unique.') + if overlap_coords: + msgs.append(f'{overlap_coords} are defined as both generalized ' + f'coordinates and generalized speeds.') + if overlap_aux: + msgs.append(f'The auxiliary speeds {overlap_aux} are also defined ' + f'as generalized coordinates or generalized speeds.') + if is_dynamicsymbols: # Check whether all coordinates are dynamicsymbols + for coordinate in coordinates: + if not (isinstance(coordinate, (AppliedUndef, Derivative)) and + coordinate.free_symbols == t_set): + msgs.append(f'Generalized coordinate "{coordinate}" is not a ' + f'dynamicsymbol.') + for speed in speeds: + if not (isinstance(speed, (AppliedUndef, Derivative)) and + speed.free_symbols == t_set): + msgs.append( + f'Generalized speed "{speed}" is not a dynamicsymbol.') + for aux in u_auxiliary: + if not (isinstance(aux, (AppliedUndef, Derivative)) and + aux.free_symbols == t_set): + msgs.append( + f'Auxiliary speed "{aux}" is not a dynamicsymbol.') + if msgs: + raise ValueError('\n'.join(msgs)) + + +def _parse_linear_solver(linear_solver): + """Helper function to retrieve a specified linear solver.""" + if callable(linear_solver): + return linear_solver + return lambda A, b: Matrix.solve(A, b, method=linear_solver) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/inertia.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/inertia.py new file mode 100644 index 0000000000000000000000000000000000000000..683c1f630f3cedb82d02a9c5ba2309ae438b7fff --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/inertia.py @@ -0,0 +1,199 @@ +from sympy import sympify +from sympy.physics.vector import Point, Dyadic, ReferenceFrame, outer +from collections import namedtuple + +__all__ = ['inertia', 'inertia_of_point_mass', 'Inertia'] + + +def inertia(frame, ixx, iyy, izz, ixy=0, iyz=0, izx=0): + """Simple way to create inertia Dyadic object. + + Explanation + =========== + + Creates an inertia Dyadic based on the given tensor values and a body-fixed + reference frame. + + Parameters + ========== + + frame : ReferenceFrame + The frame the inertia is defined in. + ixx : Sympifyable + The xx element in the inertia dyadic. + iyy : Sympifyable + The yy element in the inertia dyadic. + izz : Sympifyable + The zz element in the inertia dyadic. + ixy : Sympifyable + The xy element in the inertia dyadic. + iyz : Sympifyable + The yz element in the inertia dyadic. + izx : Sympifyable + The zx element in the inertia dyadic. + + Examples + ======== + + >>> from sympy.physics.mechanics import ReferenceFrame, inertia + >>> N = ReferenceFrame('N') + >>> inertia(N, 1, 2, 3) + (N.x|N.x) + 2*(N.y|N.y) + 3*(N.z|N.z) + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Need to define the inertia in a frame') + ixx, iyy, izz = sympify(ixx), sympify(iyy), sympify(izz) + ixy, iyz, izx = sympify(ixy), sympify(iyz), sympify(izx) + return (ixx*outer(frame.x, frame.x) + ixy*outer(frame.x, frame.y) + + izx*outer(frame.x, frame.z) + ixy*outer(frame.y, frame.x) + + iyy*outer(frame.y, frame.y) + iyz*outer(frame.y, frame.z) + + izx*outer(frame.z, frame.x) + iyz*outer(frame.z, frame.y) + + izz*outer(frame.z, frame.z)) + + +def inertia_of_point_mass(mass, pos_vec, frame): + """Inertia dyadic of a point mass relative to point O. + + Parameters + ========== + + mass : Sympifyable + Mass of the point mass + pos_vec : Vector + Position from point O to point mass + frame : ReferenceFrame + Reference frame to express the dyadic in + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import ReferenceFrame, inertia_of_point_mass + >>> N = ReferenceFrame('N') + >>> r, m = symbols('r m') + >>> px = r * N.x + >>> inertia_of_point_mass(m, px, N) + m*r**2*(N.y|N.y) + m*r**2*(N.z|N.z) + + """ + + return mass*( + (outer(frame.x, frame.x) + + outer(frame.y, frame.y) + + outer(frame.z, frame.z)) * + (pos_vec.dot(pos_vec)) - outer(pos_vec, pos_vec)) + + +class Inertia(namedtuple('Inertia', ['dyadic', 'point'])): + """Inertia object consisting of a Dyadic and a Point of reference. + + Explanation + =========== + + This is a simple class to store the Point and Dyadic, belonging to an + inertia. + + Attributes + ========== + + dyadic : Dyadic + The dyadic of the inertia. + point : Point + The reference point of the inertia. + + Examples + ======== + + >>> from sympy.physics.mechanics import ReferenceFrame, Point, Inertia + >>> N = ReferenceFrame('N') + >>> Po = Point('Po') + >>> Inertia(N.x.outer(N.x) + N.y.outer(N.y) + N.z.outer(N.z), Po) + ((N.x|N.x) + (N.y|N.y) + (N.z|N.z), Po) + + In the example above the Dyadic was created manually, one can however also + use the ``inertia`` function for this or the class method ``from_tensor`` as + shown below. + + >>> Inertia.from_inertia_scalars(Po, N, 1, 1, 1) + ((N.x|N.x) + (N.y|N.y) + (N.z|N.z), Po) + + """ + __slots__ = () + + def __new__(cls, dyadic, point): + # Switch order if given in the wrong order + if isinstance(dyadic, Point) and isinstance(point, Dyadic): + point, dyadic = dyadic, point + if not isinstance(point, Point): + raise TypeError('Reference point should be of type Point') + if not isinstance(dyadic, Dyadic): + raise TypeError('Inertia value should be expressed as a Dyadic') + return super().__new__(cls, dyadic, point) + + @classmethod + def from_inertia_scalars(cls, point, frame, ixx, iyy, izz, ixy=0, iyz=0, + izx=0): + """Simple way to create an Inertia object based on the tensor values. + + Explanation + =========== + + This class method uses the :func`~.inertia` to create the Dyadic based + on the tensor values. + + Parameters + ========== + + point : Point + The reference point of the inertia. + frame : ReferenceFrame + The frame the inertia is defined in. + ixx : Sympifyable + The xx element in the inertia dyadic. + iyy : Sympifyable + The yy element in the inertia dyadic. + izz : Sympifyable + The zz element in the inertia dyadic. + ixy : Sympifyable + The xy element in the inertia dyadic. + iyz : Sympifyable + The yz element in the inertia dyadic. + izx : Sympifyable + The zx element in the inertia dyadic. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import ReferenceFrame, Point, Inertia + >>> ixx, iyy, izz, ixy, iyz, izx = symbols('ixx iyy izz ixy iyz izx') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> I = Inertia.from_inertia_scalars(P, N, ixx, iyy, izz, ixy, iyz, izx) + + The tensor values can easily be seen when converting the dyadic to a + matrix. + + >>> I.dyadic.to_matrix(N) + Matrix([ + [ixx, ixy, izx], + [ixy, iyy, iyz], + [izx, iyz, izz]]) + + """ + return cls(inertia(frame, ixx, iyy, izz, ixy, iyz, izx), point) + + def __add__(self, other): + raise TypeError(f"unsupported operand type(s) for +: " + f"'{self.__class__.__name__}' and " + f"'{other.__class__.__name__}'") + + def __mul__(self, other): + raise TypeError(f"unsupported operand type(s) for *: " + f"'{self.__class__.__name__}' and " + f"'{other.__class__.__name__}'") + + __radd__ = __add__ + __rmul__ = __mul__ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/joint.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/joint.py new file mode 100644 index 0000000000000000000000000000000000000000..6f3fe661532cff6bf8dda4ab4383fc09f75e9e44 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/joint.py @@ -0,0 +1,2188 @@ +# coding=utf-8 + +from abc import ABC, abstractmethod + +from sympy import pi, Derivative, Matrix +from sympy.core.function import AppliedUndef +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.functions import _validate_coordinates +from sympy.physics.vector import (Vector, dynamicsymbols, cross, Point, + ReferenceFrame) +from sympy.utilities.iterables import iterable +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['Joint', 'PinJoint', 'PrismaticJoint', 'CylindricalJoint', + 'PlanarJoint', 'SphericalJoint', 'WeldJoint'] + + +class Joint(ABC): + """Abstract base class for all specific joints. + + Explanation + =========== + + A joint subtracts degrees of freedom from a body. This is the base class + for all specific joints and holds all common methods acting as an interface + for all joints. Custom joint can be created by inheriting Joint class and + defining all abstract functions. + + The abstract methods are: + + - ``_generate_coordinates`` + - ``_generate_speeds`` + - ``_orient_frames`` + - ``_set_angular_velocity`` + - ``_set_linear_velocity`` + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + coordinates : iterable of dynamicsymbols, optional + Generalized coordinates of the joint. + speeds : iterable of dynamicsymbols, optional + Generalized speeds of joint. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the parent body which aligns with an axis fixed in the + child body. The default is the x axis of parent's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + child_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the child body which aligns with an axis fixed in the + parent body. The default is the x axis of child's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + parent_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by parent_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + child_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by child_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_axis : Vector + The axis fixed in the parent frame that represents the joint. + child_axis : Vector + The axis fixed in the child frame that represents the joint. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + + Notes + ===== + + When providing a vector as the intermediate frame, a new intermediate frame + is created which aligns its X axis with the provided vector. This is done + with a single fixed rotation about a rotation axis. This rotation axis is + determined by taking the cross product of the ``body.x`` axis with the + provided vector. In the case where the provided vector is in the ``-body.x`` + direction, the rotation is done about the ``body.y`` axis. + + """ + + def __init__(self, name, parent, child, coordinates=None, speeds=None, + parent_point=None, child_point=None, parent_interframe=None, + child_interframe=None, parent_axis=None, child_axis=None, + parent_joint_pos=None, child_joint_pos=None): + + if not isinstance(name, str): + raise TypeError('Supply a valid name.') + self._name = name + + if not isinstance(parent, BodyBase): + raise TypeError('Parent must be a body.') + self._parent = parent + + if not isinstance(child, BodyBase): + raise TypeError('Child must be a body.') + self._child = child + + if parent_axis is not None or child_axis is not None: + sympy_deprecation_warning( + """ + The parent_axis and child_axis arguments for the Joint classes + are deprecated. Instead use parent_interframe, child_interframe. + """, + deprecated_since_version="1.12", + active_deprecations_target="deprecated-mechanics-joint-axis", + stacklevel=4 + ) + if parent_interframe is None: + parent_interframe = parent_axis + if child_interframe is None: + child_interframe = child_axis + + # Set parent and child frame attributes + if hasattr(self._parent, 'frame'): + self._parent_frame = self._parent.frame + else: + if isinstance(parent_interframe, ReferenceFrame): + self._parent_frame = parent_interframe + else: + self._parent_frame = ReferenceFrame( + f'{self.name}_{self._parent.name}_frame') + if hasattr(self._child, 'frame'): + self._child_frame = self._child.frame + else: + if isinstance(child_interframe, ReferenceFrame): + self._child_frame = child_interframe + else: + self._child_frame = ReferenceFrame( + f'{self.name}_{self._child.name}_frame') + + self._parent_interframe = self._locate_joint_frame( + self._parent, parent_interframe, self._parent_frame) + self._child_interframe = self._locate_joint_frame( + self._child, child_interframe, self._child_frame) + self._parent_axis = self._axis(parent_axis, self._parent_frame) + self._child_axis = self._axis(child_axis, self._child_frame) + + if parent_joint_pos is not None or child_joint_pos is not None: + sympy_deprecation_warning( + """ + The parent_joint_pos and child_joint_pos arguments for the Joint + classes are deprecated. Instead use parent_point and child_point. + """, + deprecated_since_version="1.12", + active_deprecations_target="deprecated-mechanics-joint-pos", + stacklevel=4 + ) + if parent_point is None: + parent_point = parent_joint_pos + if child_point is None: + child_point = child_joint_pos + self._parent_point = self._locate_joint_pos( + self._parent, parent_point, self._parent_frame) + self._child_point = self._locate_joint_pos( + self._child, child_point, self._child_frame) + + self._coordinates = self._generate_coordinates(coordinates) + self._speeds = self._generate_speeds(speeds) + _validate_coordinates(self.coordinates, self.speeds) + self._kdes = self._generate_kdes() + + self._orient_frames() + self._set_angular_velocity() + self._set_linear_velocity() + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + @property + def name(self): + """Name of the joint.""" + return self._name + + @property + def parent(self): + """Parent body of Joint.""" + return self._parent + + @property + def child(self): + """Child body of Joint.""" + return self._child + + @property + def coordinates(self): + """Matrix of the joint's generalized coordinates.""" + return self._coordinates + + @property + def speeds(self): + """Matrix of the joint's generalized speeds.""" + return self._speeds + + @property + def kdes(self): + """Kinematical differential equations of the joint.""" + return self._kdes + + @property + def parent_axis(self): + """The axis of parent frame.""" + # Will be removed with `deprecated-mechanics-joint-axis` + return self._parent_axis + + @property + def child_axis(self): + """The axis of child frame.""" + # Will be removed with `deprecated-mechanics-joint-axis` + return self._child_axis + + @property + def parent_point(self): + """Attachment point where the joint is fixed to the parent body.""" + return self._parent_point + + @property + def child_point(self): + """Attachment point where the joint is fixed to the child body.""" + return self._child_point + + @property + def parent_interframe(self): + return self._parent_interframe + + @property + def child_interframe(self): + return self._child_interframe + + @abstractmethod + def _generate_coordinates(self, coordinates): + """Generate Matrix of the joint's generalized coordinates.""" + pass + + @abstractmethod + def _generate_speeds(self, speeds): + """Generate Matrix of the joint's generalized speeds.""" + pass + + @abstractmethod + def _orient_frames(self): + """Orient frames as per the joint.""" + pass + + @abstractmethod + def _set_angular_velocity(self): + """Set angular velocity of the joint related frames.""" + pass + + @abstractmethod + def _set_linear_velocity(self): + """Set velocity of related points to the joint.""" + pass + + @staticmethod + def _to_vector(matrix, frame): + """Converts a matrix to a vector in the given frame.""" + return Vector([(matrix, frame)]) + + @staticmethod + def _axis(ax, *frames): + """Check whether an axis is fixed in one of the frames.""" + if ax is None: + ax = frames[0].x + return ax + if not isinstance(ax, Vector): + raise TypeError("Axis must be a Vector.") + ref_frame = None # Find a body in which the axis can be expressed + for frame in frames: + try: + ax.to_matrix(frame) + except ValueError: + pass + else: + ref_frame = frame + break + if ref_frame is None: + raise ValueError("Axis cannot be expressed in one of the body's " + "frames.") + if not ax.dt(ref_frame) == 0: + raise ValueError('Axis cannot be time-varying when viewed from the ' + 'associated body.') + return ax + + @staticmethod + def _choose_rotation_axis(frame, axis): + components = axis.to_matrix(frame) + x, y, z = components[0], components[1], components[2] + + if x != 0: + if y != 0: + if z != 0: + return cross(axis, frame.x) + if z != 0: + return frame.y + return frame.z + else: + if y != 0: + return frame.x + return frame.y + + @staticmethod + def _create_aligned_interframe(frame, align_axis, frame_axis=None, + frame_name=None): + """ + Returns an intermediate frame, where the ``frame_axis`` defined in + ``frame`` is aligned with ``axis``. By default this means that the X + axis will be aligned with ``axis``. + + Parameters + ========== + + frame : BodyBase or ReferenceFrame + The body or reference frame with respect to which the intermediate + frame is oriented. + align_axis : Vector + The vector with respect to which the intermediate frame will be + aligned. + frame_axis : Vector + The vector of the frame which should get aligned with ``axis``. The + default is the X axis of the frame. + frame_name : string + Name of the to be created intermediate frame. The default adds + "_int_frame" to the name of ``frame``. + + Example + ======= + + An intermediate frame, where the X axis of the parent becomes aligned + with ``parent.y + parent.z`` can be created as follows: + + >>> from sympy.physics.mechanics.joint import Joint + >>> from sympy.physics.mechanics import RigidBody + >>> parent = RigidBody('parent') + >>> parent_interframe = Joint._create_aligned_interframe( + ... parent, parent.y + parent.z) + >>> parent_interframe + parent_int_frame + >>> parent.frame.dcm(parent_interframe) + Matrix([ + [ 0, -sqrt(2)/2, -sqrt(2)/2], + [sqrt(2)/2, 1/2, -1/2], + [sqrt(2)/2, -1/2, 1/2]]) + >>> (parent.y + parent.z).express(parent_interframe) + sqrt(2)*parent_int_frame.x + + Notes + ===== + + The direction cosine matrix between the given frame and intermediate + frame is formed using a simple rotation about an axis that is normal to + both ``align_axis`` and ``frame_axis``. In general, the normal axis is + formed by crossing the ``frame_axis`` with the ``align_axis``. The + exception is if the axes are parallel with opposite directions, in which + case the rotation vector is chosen using the rules in the following + table with the vectors expressed in the given frame: + + .. list-table:: + :header-rows: 1 + + * - ``align_axis`` + - ``frame_axis`` + - ``rotation_axis`` + * - ``-x`` + - ``x`` + - ``z`` + * - ``-y`` + - ``y`` + - ``x`` + * - ``-z`` + - ``z`` + - ``y`` + * - ``-x-y`` + - ``x+y`` + - ``z`` + * - ``-y-z`` + - ``y+z`` + - ``x`` + * - ``-x-z`` + - ``x+z`` + - ``y`` + * - ``-x-y-z`` + - ``x+y+z`` + - ``(x+y+z) × x`` + + """ + if isinstance(frame, BodyBase): + frame = frame.frame + if frame_axis is None: + frame_axis = frame.x + if frame_name is None: + if frame.name[-6:] == '_frame': + frame_name = f'{frame.name[:-6]}_int_frame' + else: + frame_name = f'{frame.name}_int_frame' + angle = frame_axis.angle_between(align_axis) + rotation_axis = cross(frame_axis, align_axis) + if rotation_axis == Vector(0) and angle == 0: + return frame + if angle == pi: + rotation_axis = Joint._choose_rotation_axis(frame, align_axis) + + int_frame = ReferenceFrame(frame_name) + int_frame.orient_axis(frame, rotation_axis, angle) + int_frame.set_ang_vel(frame, 0 * rotation_axis) + return int_frame + + def _generate_kdes(self): + """Generate kinematical differential equations.""" + kdes = [] + t = dynamicsymbols._t + for i in range(len(self.coordinates)): + kdes.append(-self.coordinates[i].diff(t) + self.speeds[i]) + return Matrix(kdes) + + def _locate_joint_pos(self, body, joint_pos, body_frame=None): + """Returns the attachment point of a body.""" + if body_frame is None: + body_frame = body.frame + if joint_pos is None: + return body.masscenter + if not isinstance(joint_pos, (Point, Vector)): + raise TypeError('Attachment point must be a Point or Vector.') + if isinstance(joint_pos, Vector): + point_name = f'{self.name}_{body.name}_joint' + joint_pos = body.masscenter.locatenew(point_name, joint_pos) + if not joint_pos.pos_from(body.masscenter).dt(body_frame) == 0: + raise ValueError('Attachment point must be fixed to the associated ' + 'body.') + return joint_pos + + def _locate_joint_frame(self, body, interframe, body_frame=None): + """Returns the attachment frame of a body.""" + if body_frame is None: + body_frame = body.frame + if interframe is None: + return body_frame + if isinstance(interframe, Vector): + interframe = Joint._create_aligned_interframe( + body_frame, interframe, + frame_name=f'{self.name}_{body.name}_int_frame') + elif not isinstance(interframe, ReferenceFrame): + raise TypeError('Interframe must be a ReferenceFrame.') + if not interframe.ang_vel_in(body_frame) == 0: + raise ValueError(f'Interframe {interframe} is not fixed to body ' + f'{body}.') + body.masscenter.set_vel(interframe, 0) # Fixate interframe to body + return interframe + + def _fill_coordinate_list(self, coordinates, n_coords, label='q', offset=0, + number_single=False): + """Helper method for _generate_coordinates and _generate_speeds. + + Parameters + ========== + + coordinates : iterable + Iterable of coordinates or speeds that have been provided. + n_coords : Integer + Number of coordinates that should be returned. + label : String, optional + Coordinate type either 'q' (coordinates) or 'u' (speeds). The + Default is 'q'. + offset : Integer + Count offset when creating new dynamicsymbols. The default is 0. + number_single : Boolean + Boolean whether if n_coords == 1, number should still be used. The + default is False. + + """ + + def create_symbol(number): + if n_coords == 1 and not number_single: + return dynamicsymbols(f'{label}_{self.name}') + return dynamicsymbols(f'{label}{number}_{self.name}') + + name = 'generalized coordinate' if label == 'q' else 'generalized speed' + generated_coordinates = [] + if coordinates is None: + coordinates = [] + elif not iterable(coordinates): + coordinates = [coordinates] + if not (len(coordinates) == 0 or len(coordinates) == n_coords): + raise ValueError(f'Expected {n_coords} {name}s, instead got ' + f'{len(coordinates)} {name}s.') + # Supports more iterables, also Matrix + for i, coord in enumerate(coordinates): + if coord is None: + generated_coordinates.append(create_symbol(i + offset)) + elif isinstance(coord, (AppliedUndef, Derivative)): + generated_coordinates.append(coord) + else: + raise TypeError(f'The {name} {coord} should have been a ' + f'dynamicsymbol.') + for i in range(len(coordinates) + offset, n_coords + offset): + generated_coordinates.append(create_symbol(i)) + return Matrix(generated_coordinates) + + +class PinJoint(Joint): + """Pin (Revolute) Joint. + + .. raw:: html + :file: ../../../doc/src/explanation/modules/physics/mechanics/PinJoint.svg + + Explanation + =========== + + A pin joint is defined such that the joint rotation axis is fixed in both + the child and parent and the location of the joint is relative to the mass + center of each body. The child rotates an angle, θ, from the parent about + the rotation axis and has a simple angular speed, ω, relative to the + parent. The direction cosine matrix between the child interframe and + parent interframe is formed using a simple rotation about the joint axis. + The page on the joints framework gives a more detailed explanation of the + intermediate frames. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + coordinates : dynamicsymbol, optional + Generalized coordinates of the joint. + speeds : dynamicsymbol, optional + Generalized speeds of joint. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the parent body which aligns with an axis fixed in the + child body. The default is the x axis of parent's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + child_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the child body which aligns with an axis fixed in the + parent body. The default is the x axis of child's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + joint_axis : Vector + The axis about which the rotation occurs. Note that the components + of this axis are the same in the parent_interframe and child_interframe. + parent_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by parent_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + child_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by child_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. The default value is + ``dynamicsymbols(f'q_{joint.name}')``. + speeds : Matrix + Matrix of the joint's generalized speeds. The default value is + ``dynamicsymbols(f'u_{joint.name}')``. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_axis : Vector + The axis fixed in the parent frame that represents the joint. + child_axis : Vector + The axis fixed in the child frame that represents the joint. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + joint_axis : Vector + The axis about which the rotation occurs. Note that the components of + this axis are the same in the parent_interframe and child_interframe. + kdes : Matrix + Kinematical differential equations of the joint. + + Examples + ========= + + A single pin joint is created from two bodies and has the following basic + attributes: + + >>> from sympy.physics.mechanics import RigidBody, PinJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = PinJoint('PC', parent, child) + >>> joint + PinJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.parent_axis + P_frame.x + >>> joint.child_axis + C_frame.x + >>> joint.coordinates + Matrix([[q_PC(t)]]) + >>> joint.speeds + Matrix([[u_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame) + u_PC(t)*P_frame.x + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, cos(q_PC(t)), sin(q_PC(t))], + [0, -sin(q_PC(t)), cos(q_PC(t))]]) + >>> joint.child_point.pos_from(joint.parent_point) + 0 + + To further demonstrate the use of the pin joint, the kinematics of simple + double pendulum that rotates about the Z axis of each connected body can be + created as follows. + + >>> from sympy import symbols, trigsimp + >>> from sympy.physics.mechanics import RigidBody, PinJoint + >>> l1, l2 = symbols('l1 l2') + + First create bodies to represent the fixed ceiling and one to represent + each pendulum bob. + + >>> ceiling = RigidBody('C') + >>> upper_bob = RigidBody('U') + >>> lower_bob = RigidBody('L') + + The first joint will connect the upper bob to the ceiling by a distance of + ``l1`` and the joint axis will be about the Z axis for each body. + + >>> ceiling_joint = PinJoint('P1', ceiling, upper_bob, + ... child_point=-l1*upper_bob.frame.x, + ... joint_axis=ceiling.frame.z) + + The second joint will connect the lower bob to the upper bob by a distance + of ``l2`` and the joint axis will also be about the Z axis for each body. + + >>> pendulum_joint = PinJoint('P2', upper_bob, lower_bob, + ... child_point=-l2*lower_bob.frame.x, + ... joint_axis=upper_bob.frame.z) + + Once the joints are established the kinematics of the connected bodies can + be accessed. First the direction cosine matrices of pendulum link relative + to the ceiling are found: + + >>> upper_bob.frame.dcm(ceiling.frame) + Matrix([ + [ cos(q_P1(t)), sin(q_P1(t)), 0], + [-sin(q_P1(t)), cos(q_P1(t)), 0], + [ 0, 0, 1]]) + >>> trigsimp(lower_bob.frame.dcm(ceiling.frame)) + Matrix([ + [ cos(q_P1(t) + q_P2(t)), sin(q_P1(t) + q_P2(t)), 0], + [-sin(q_P1(t) + q_P2(t)), cos(q_P1(t) + q_P2(t)), 0], + [ 0, 0, 1]]) + + The position of the lower bob's masscenter is found with: + + >>> lower_bob.masscenter.pos_from(ceiling.masscenter) + l1*U_frame.x + l2*L_frame.x + + The angular velocities of the two pendulum links can be computed with + respect to the ceiling. + + >>> upper_bob.frame.ang_vel_in(ceiling.frame) + u_P1(t)*C_frame.z + >>> lower_bob.frame.ang_vel_in(ceiling.frame) + u_P1(t)*C_frame.z + u_P2(t)*U_frame.z + + And finally, the linear velocities of the two pendulum bobs can be computed + with respect to the ceiling. + + >>> upper_bob.masscenter.vel(ceiling.frame) + l1*u_P1(t)*U_frame.y + >>> lower_bob.masscenter.vel(ceiling.frame) + l1*u_P1(t)*U_frame.y + l2*(u_P1(t) + u_P2(t))*L_frame.y + + """ + + def __init__(self, name, parent, child, coordinates=None, speeds=None, + parent_point=None, child_point=None, parent_interframe=None, + child_interframe=None, parent_axis=None, child_axis=None, + joint_axis=None, parent_joint_pos=None, child_joint_pos=None): + + self._joint_axis = joint_axis + super().__init__(name, parent, child, coordinates, speeds, parent_point, + child_point, parent_interframe, child_interframe, + parent_axis, child_axis, parent_joint_pos, + child_joint_pos) + + def __str__(self): + return (f'PinJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + @property + def joint_axis(self): + """Axis about which the child rotates with respect to the parent.""" + return self._joint_axis + + def _generate_coordinates(self, coordinate): + return self._fill_coordinate_list(coordinate, 1, 'q') + + def _generate_speeds(self, speed): + return self._fill_coordinate_list(speed, 1, 'u') + + def _orient_frames(self): + self._joint_axis = self._axis(self.joint_axis, self.parent_interframe) + self.child_interframe.orient_axis( + self.parent_interframe, self.joint_axis, self.coordinates[0]) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel(self.parent_interframe, self.speeds[ + 0] * self.joint_axis.normalize()) + + def _set_linear_velocity(self): + self.child_point.set_pos(self.parent_point, 0) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child.masscenter.v2pt_theory(self.parent_point, + self._parent_frame, self._child_frame) + + +class PrismaticJoint(Joint): + """Prismatic (Sliding) Joint. + + .. image:: PrismaticJoint.svg + + Explanation + =========== + + It is defined such that the child body translates with respect to the parent + body along the body-fixed joint axis. The location of the joint is defined + by two points, one in each body, which coincide when the generalized + coordinate is zero. The direction cosine matrix between the + parent_interframe and child_interframe is the identity matrix. Therefore, + the direction cosine matrix between the parent and child frames is fully + defined by the definition of the intermediate frames. The page on the joints + framework gives a more detailed explanation of the intermediate frames. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + coordinates : dynamicsymbol, optional + Generalized coordinates of the joint. The default value is + ``dynamicsymbols(f'q_{joint.name}')``. + speeds : dynamicsymbol, optional + Generalized speeds of joint. The default value is + ``dynamicsymbols(f'u_{joint.name}')``. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the parent body which aligns with an axis fixed in the + child body. The default is the x axis of parent's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + child_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the child body which aligns with an axis fixed in the + parent body. The default is the x axis of child's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + joint_axis : Vector + The axis along which the translation occurs. Note that the components + of this axis are the same in the parent_interframe and child_interframe. + parent_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by parent_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + child_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by child_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_axis : Vector + The axis fixed in the parent frame that represents the joint. + child_axis : Vector + The axis fixed in the child frame that represents the joint. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + + Examples + ========= + + A single prismatic joint is created from two bodies and has the following + basic attributes: + + >>> from sympy.physics.mechanics import RigidBody, PrismaticJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = PrismaticJoint('PC', parent, child) + >>> joint + PrismaticJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.parent_axis + P_frame.x + >>> joint.child_axis + C_frame.x + >>> joint.coordinates + Matrix([[q_PC(t)]]) + >>> joint.speeds + Matrix([[u_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame) + 0 + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> joint.child_point.pos_from(joint.parent_point) + q_PC(t)*P_frame.x + + To further demonstrate the use of the prismatic joint, the kinematics of two + masses sliding, one moving relative to a fixed body and the other relative + to the moving body. about the X axis of each connected body can be created + as follows. + + >>> from sympy.physics.mechanics import PrismaticJoint, RigidBody + + First create bodies to represent the fixed ceiling and one to represent + a particle. + + >>> wall = RigidBody('W') + >>> Part1 = RigidBody('P1') + >>> Part2 = RigidBody('P2') + + The first joint will connect the particle to the ceiling and the + joint axis will be about the X axis for each body. + + >>> J1 = PrismaticJoint('J1', wall, Part1) + + The second joint will connect the second particle to the first particle + and the joint axis will also be about the X axis for each body. + + >>> J2 = PrismaticJoint('J2', Part1, Part2) + + Once the joint is established the kinematics of the connected bodies can + be accessed. First the direction cosine matrices of Part relative + to the ceiling are found: + + >>> Part1.frame.dcm(wall.frame) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + >>> Part2.frame.dcm(wall.frame) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + The position of the particles' masscenter is found with: + + >>> Part1.masscenter.pos_from(wall.masscenter) + q_J1(t)*W_frame.x + + >>> Part2.masscenter.pos_from(wall.masscenter) + q_J1(t)*W_frame.x + q_J2(t)*P1_frame.x + + The angular velocities of the two particle links can be computed with + respect to the ceiling. + + >>> Part1.frame.ang_vel_in(wall.frame) + 0 + + >>> Part2.frame.ang_vel_in(wall.frame) + 0 + + And finally, the linear velocities of the two particles can be computed + with respect to the ceiling. + + >>> Part1.masscenter.vel(wall.frame) + u_J1(t)*W_frame.x + + >>> Part2.masscenter.vel(wall.frame) + u_J1(t)*W_frame.x + Derivative(q_J2(t), t)*P1_frame.x + + """ + + def __init__(self, name, parent, child, coordinates=None, speeds=None, + parent_point=None, child_point=None, parent_interframe=None, + child_interframe=None, parent_axis=None, child_axis=None, + joint_axis=None, parent_joint_pos=None, child_joint_pos=None): + + self._joint_axis = joint_axis + super().__init__(name, parent, child, coordinates, speeds, parent_point, + child_point, parent_interframe, child_interframe, + parent_axis, child_axis, parent_joint_pos, + child_joint_pos) + + def __str__(self): + return (f'PrismaticJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + @property + def joint_axis(self): + """Axis along which the child translates with respect to the parent.""" + return self._joint_axis + + def _generate_coordinates(self, coordinate): + return self._fill_coordinate_list(coordinate, 1, 'q') + + def _generate_speeds(self, speed): + return self._fill_coordinate_list(speed, 1, 'u') + + def _orient_frames(self): + self._joint_axis = self._axis(self.joint_axis, self.parent_interframe) + self.child_interframe.orient_axis( + self.parent_interframe, self.joint_axis, 0) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel(self.parent_interframe, 0) + + def _set_linear_velocity(self): + axis = self.joint_axis.normalize() + self.child_point.set_pos(self.parent_point, self.coordinates[0] * axis) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child_point.set_vel(self._parent_frame, self.speeds[0] * axis) + self.child.masscenter.set_vel(self._parent_frame, self.speeds[0] * axis) + + +class CylindricalJoint(Joint): + """Cylindrical Joint. + + .. image:: CylindricalJoint.svg + :align: center + :width: 600 + + Explanation + =========== + + A cylindrical joint is defined such that the child body both rotates about + and translates along the body-fixed joint axis with respect to the parent + body. The joint axis is both the rotation axis and translation axis. The + location of the joint is defined by two points, one in each body, which + coincide when the generalized coordinate corresponding to the translation is + zero. The direction cosine matrix between the child interframe and parent + interframe is formed using a simple rotation about the joint axis. The page + on the joints framework gives a more detailed explanation of the + intermediate frames. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + rotation_coordinate : dynamicsymbol, optional + Generalized coordinate corresponding to the rotation angle. The default + value is ``dynamicsymbols(f'q0_{joint.name}')``. + translation_coordinate : dynamicsymbol, optional + Generalized coordinate corresponding to the translation distance. The + default value is ``dynamicsymbols(f'q1_{joint.name}')``. + rotation_speed : dynamicsymbol, optional + Generalized speed corresponding to the angular velocity. The default + value is ``dynamicsymbols(f'u0_{joint.name}')``. + translation_speed : dynamicsymbol, optional + Generalized speed corresponding to the translation velocity. The default + value is ``dynamicsymbols(f'u1_{joint.name}')``. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + joint_axis : Vector, optional + The rotation as well as translation axis. Note that the components of + this axis are the same in the parent_interframe and child_interframe. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + rotation_coordinate : dynamicsymbol + Generalized coordinate corresponding to the rotation angle. + translation_coordinate : dynamicsymbol + Generalized coordinate corresponding to the translation distance. + rotation_speed : dynamicsymbol + Generalized speed corresponding to the angular velocity. + translation_speed : dynamicsymbol + Generalized speed corresponding to the translation velocity. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + joint_axis : Vector + The axis of rotation and translation. + + Examples + ========= + + A single cylindrical joint is created between two bodies and has the + following basic attributes: + + >>> from sympy.physics.mechanics import RigidBody, CylindricalJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = CylindricalJoint('PC', parent, child) + >>> joint + CylindricalJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.parent_axis + P_frame.x + >>> joint.child_axis + C_frame.x + >>> joint.coordinates + Matrix([ + [q0_PC(t)], + [q1_PC(t)]]) + >>> joint.speeds + Matrix([ + [u0_PC(t)], + [u1_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame) + u0_PC(t)*P_frame.x + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, cos(q0_PC(t)), sin(q0_PC(t))], + [0, -sin(q0_PC(t)), cos(q0_PC(t))]]) + >>> joint.child_point.pos_from(joint.parent_point) + q1_PC(t)*P_frame.x + >>> child.masscenter.vel(parent.frame) + u1_PC(t)*P_frame.x + + To further demonstrate the use of the cylindrical joint, the kinematics of + two cylindrical joints perpendicular to each other can be created as follows. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import RigidBody, CylindricalJoint + >>> r, l, w = symbols('r l w') + + First create bodies to represent the fixed floor with a fixed pole on it. + The second body represents a freely moving tube around that pole. The third + body represents a solid flag freely translating along and rotating around + the Y axis of the tube. + + >>> floor = RigidBody('floor') + >>> tube = RigidBody('tube') + >>> flag = RigidBody('flag') + + The first joint will connect the first tube to the floor with it translating + along and rotating around the Z axis of both bodies. + + >>> floor_joint = CylindricalJoint('C1', floor, tube, joint_axis=floor.z) + + The second joint will connect the tube perpendicular to the flag along the Y + axis of both the tube and the flag, with the joint located at a distance + ``r`` from the tube's center of mass and a combination of the distances + ``l`` and ``w`` from the flag's center of mass. + + >>> flag_joint = CylindricalJoint('C2', tube, flag, + ... parent_point=r * tube.y, + ... child_point=-w * flag.y + l * flag.z, + ... joint_axis=tube.y) + + Once the joints are established the kinematics of the connected bodies can + be accessed. First the direction cosine matrices of both the body and the + flag relative to the floor are found: + + >>> tube.frame.dcm(floor.frame) + Matrix([ + [ cos(q0_C1(t)), sin(q0_C1(t)), 0], + [-sin(q0_C1(t)), cos(q0_C1(t)), 0], + [ 0, 0, 1]]) + >>> flag.frame.dcm(floor.frame) + Matrix([ + [cos(q0_C1(t))*cos(q0_C2(t)), sin(q0_C1(t))*cos(q0_C2(t)), -sin(q0_C2(t))], + [ -sin(q0_C1(t)), cos(q0_C1(t)), 0], + [sin(q0_C2(t))*cos(q0_C1(t)), sin(q0_C1(t))*sin(q0_C2(t)), cos(q0_C2(t))]]) + + The position of the flag's center of mass is found with: + + >>> flag.masscenter.pos_from(floor.masscenter) + q1_C1(t)*floor_frame.z + (r + q1_C2(t))*tube_frame.y + w*flag_frame.y - l*flag_frame.z + + The angular velocities of the two tubes can be computed with respect to the + floor. + + >>> tube.frame.ang_vel_in(floor.frame) + u0_C1(t)*floor_frame.z + >>> flag.frame.ang_vel_in(floor.frame) + u0_C1(t)*floor_frame.z + u0_C2(t)*tube_frame.y + + Finally, the linear velocities of the two tube centers of mass can be + computed with respect to the floor, while expressed in the tube's frame. + + >>> tube.masscenter.vel(floor.frame).to_matrix(tube.frame) + Matrix([ + [ 0], + [ 0], + [u1_C1(t)]]) + >>> flag.masscenter.vel(floor.frame).to_matrix(tube.frame).simplify() + Matrix([ + [-l*u0_C2(t)*cos(q0_C2(t)) - r*u0_C1(t) - w*u0_C1(t) - q1_C2(t)*u0_C1(t)], + [ -l*u0_C1(t)*sin(q0_C2(t)) + Derivative(q1_C2(t), t)], + [ l*u0_C2(t)*sin(q0_C2(t)) + u1_C1(t)]]) + + """ + + def __init__(self, name, parent, child, rotation_coordinate=None, + translation_coordinate=None, rotation_speed=None, + translation_speed=None, parent_point=None, child_point=None, + parent_interframe=None, child_interframe=None, + joint_axis=None): + self._joint_axis = joint_axis + coordinates = (rotation_coordinate, translation_coordinate) + speeds = (rotation_speed, translation_speed) + super().__init__(name, parent, child, coordinates, speeds, + parent_point, child_point, + parent_interframe=parent_interframe, + child_interframe=child_interframe) + + def __str__(self): + return (f'CylindricalJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + @property + def joint_axis(self): + """Axis about and along which the rotation and translation occurs.""" + return self._joint_axis + + @property + def rotation_coordinate(self): + """Generalized coordinate corresponding to the rotation angle.""" + return self.coordinates[0] + + @property + def translation_coordinate(self): + """Generalized coordinate corresponding to the translation distance.""" + return self.coordinates[1] + + @property + def rotation_speed(self): + """Generalized speed corresponding to the angular velocity.""" + return self.speeds[0] + + @property + def translation_speed(self): + """Generalized speed corresponding to the translation velocity.""" + return self.speeds[1] + + def _generate_coordinates(self, coordinates): + return self._fill_coordinate_list(coordinates, 2, 'q') + + def _generate_speeds(self, speeds): + return self._fill_coordinate_list(speeds, 2, 'u') + + def _orient_frames(self): + self._joint_axis = self._axis(self.joint_axis, self.parent_interframe) + self.child_interframe.orient_axis( + self.parent_interframe, self.joint_axis, self.rotation_coordinate) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel( + self.parent_interframe, + self.rotation_speed * self.joint_axis.normalize()) + + def _set_linear_velocity(self): + self.child_point.set_pos( + self.parent_point, + self.translation_coordinate * self.joint_axis.normalize()) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child_point.set_vel( + self._parent_frame, + self.translation_speed * self.joint_axis.normalize()) + self.child.masscenter.v2pt_theory(self.child_point, self._parent_frame, + self.child_interframe) + + +class PlanarJoint(Joint): + """Planar Joint. + + .. raw:: html + :file: ../../../doc/src/modules/physics/mechanics/api/PlanarJoint.svg + + Explanation + =========== + + A planar joint is defined such that the child body translates over a fixed + plane of the parent body as well as rotate about the rotation axis, which + is perpendicular to that plane. The origin of this plane is the + ``parent_point`` and the plane is spanned by two nonparallel planar vectors. + The location of the ``child_point`` is based on the planar vectors + ($\\vec{v}_1$, $\\vec{v}_2$) and generalized coordinates ($q_1$, $q_2$), + i.e. $\\vec{r} = q_1 \\hat{v}_1 + q_2 \\hat{v}_2$. The direction cosine + matrix between the ``child_interframe`` and ``parent_interframe`` is formed + using a simple rotation ($q_0$) about the rotation axis. + + In order to simplify the definition of the ``PlanarJoint``, the + ``rotation_axis`` and ``planar_vectors`` are set to be the unit vectors of + the ``parent_interframe`` according to the table below. This ensures that + you can only define these vectors by creating a separate frame and supplying + that as the interframe. If you however would only like to supply the normals + of the plane with respect to the parent and child bodies, then you can also + supply those to the ``parent_interframe`` and ``child_interframe`` + arguments. An example of both of these cases is in the examples section + below and the page on the joints framework provides a more detailed + explanation of the intermediate frames. + + .. list-table:: + + * - ``rotation_axis`` + - ``parent_interframe.x`` + * - ``planar_vectors[0]`` + - ``parent_interframe.y`` + * - ``planar_vectors[1]`` + - ``parent_interframe.z`` + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + rotation_coordinate : dynamicsymbol, optional + Generalized coordinate corresponding to the rotation angle. The default + value is ``dynamicsymbols(f'q0_{joint.name}')``. + planar_coordinates : iterable of dynamicsymbols, optional + Two generalized coordinates used for the planar translation. The default + value is ``dynamicsymbols(f'q1_{joint.name} q2_{joint.name}')``. + rotation_speed : dynamicsymbol, optional + Generalized speed corresponding to the angular velocity. The default + value is ``dynamicsymbols(f'u0_{joint.name}')``. + planar_speeds : dynamicsymbols, optional + Two generalized speeds used for the planar translation velocity. The + default value is ``dynamicsymbols(f'u1_{joint.name} u2_{joint.name}')``. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + rotation_coordinate : dynamicsymbol + Generalized coordinate corresponding to the rotation angle. + planar_coordinates : Matrix + Two generalized coordinates used for the planar translation. + rotation_speed : dynamicsymbol + Generalized speed corresponding to the angular velocity. + planar_speeds : Matrix + Two generalized speeds used for the planar translation velocity. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + rotation_axis : Vector + The axis about which the rotation occurs. + planar_vectors : list + The vectors that describe the planar translation directions. + + Examples + ========= + + A single planar joint is created between two bodies and has the following + basic attributes: + + >>> from sympy.physics.mechanics import RigidBody, PlanarJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = PlanarJoint('PC', parent, child) + >>> joint + PlanarJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.rotation_axis + P_frame.x + >>> joint.planar_vectors + [P_frame.y, P_frame.z] + >>> joint.rotation_coordinate + q0_PC(t) + >>> joint.planar_coordinates + Matrix([ + [q1_PC(t)], + [q2_PC(t)]]) + >>> joint.coordinates + Matrix([ + [q0_PC(t)], + [q1_PC(t)], + [q2_PC(t)]]) + >>> joint.rotation_speed + u0_PC(t) + >>> joint.planar_speeds + Matrix([ + [u1_PC(t)], + [u2_PC(t)]]) + >>> joint.speeds + Matrix([ + [u0_PC(t)], + [u1_PC(t)], + [u2_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame) + u0_PC(t)*P_frame.x + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, cos(q0_PC(t)), sin(q0_PC(t))], + [0, -sin(q0_PC(t)), cos(q0_PC(t))]]) + >>> joint.child_point.pos_from(joint.parent_point) + q1_PC(t)*P_frame.y + q2_PC(t)*P_frame.z + >>> child.masscenter.vel(parent.frame) + u1_PC(t)*P_frame.y + u2_PC(t)*P_frame.z + + To further demonstrate the use of the planar joint, the kinematics of a + block sliding on a slope, can be created as follows. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import PlanarJoint, RigidBody, ReferenceFrame + >>> a, d, h = symbols('a d h') + + First create bodies to represent the slope and the block. + + >>> ground = RigidBody('G') + >>> block = RigidBody('B') + + To define the slope you can either define the plane by specifying the + ``planar_vectors`` or/and the ``rotation_axis``. However it is advisable to + create a rotated intermediate frame, so that the ``parent_vectors`` and + ``rotation_axis`` will be the unit vectors of this intermediate frame. + + >>> slope = ReferenceFrame('A') + >>> slope.orient_axis(ground.frame, ground.y, a) + + The planar joint can be created using these bodies and intermediate frame. + We can specify the origin of the slope to be ``d`` above the slope's center + of mass and the block's center of mass to be a distance ``h`` above the + slope's surface. Note that we can specify the normal of the plane using the + rotation axis argument. + + >>> joint = PlanarJoint('PC', ground, block, parent_point=d * ground.x, + ... child_point=-h * block.x, parent_interframe=slope) + + Once the joint is established the kinematics of the bodies can be accessed. + First the ``rotation_axis``, which is normal to the plane and the + ``plane_vectors``, can be found. + + >>> joint.rotation_axis + A.x + >>> joint.planar_vectors + [A.y, A.z] + + The direction cosine matrix of the block with respect to the ground can be + found with: + + >>> block.frame.dcm(ground.frame) + Matrix([ + [ cos(a), 0, -sin(a)], + [sin(a)*sin(q0_PC(t)), cos(q0_PC(t)), sin(q0_PC(t))*cos(a)], + [sin(a)*cos(q0_PC(t)), -sin(q0_PC(t)), cos(a)*cos(q0_PC(t))]]) + + The angular velocity of the block can be computed with respect to the + ground. + + >>> block.frame.ang_vel_in(ground.frame) + u0_PC(t)*A.x + + The position of the block's center of mass can be found with: + + >>> block.masscenter.pos_from(ground.masscenter) + d*G_frame.x + h*B_frame.x + q1_PC(t)*A.y + q2_PC(t)*A.z + + Finally, the linear velocity of the block's center of mass can be + computed with respect to the ground. + + >>> block.masscenter.vel(ground.frame) + u1_PC(t)*A.y + u2_PC(t)*A.z + + In some cases it could be your preference to only define the normals of the + plane with respect to both bodies. This can most easily be done by supplying + vectors to the ``interframe`` arguments. What will happen in this case is + that an interframe will be created with its ``x`` axis aligned with the + provided vector. For a further explanation of how this is done see the notes + of the ``Joint`` class. In the code below, the above example (with the block + on the slope) is recreated by supplying vectors to the interframe arguments. + Note that the previously described option is however more computationally + efficient, because the algorithm now has to compute the rotation angle + between the provided vector and the 'x' axis. + + >>> from sympy import symbols, cos, sin + >>> from sympy.physics.mechanics import PlanarJoint, RigidBody + >>> a, d, h = symbols('a d h') + >>> ground = RigidBody('G') + >>> block = RigidBody('B') + >>> joint = PlanarJoint( + ... 'PC', ground, block, parent_point=d * ground.x, + ... child_point=-h * block.x, child_interframe=block.x, + ... parent_interframe=cos(a) * ground.x + sin(a) * ground.z) + >>> block.frame.dcm(ground.frame).simplify() + Matrix([ + [ cos(a), 0, sin(a)], + [-sin(a)*sin(q0_PC(t)), cos(q0_PC(t)), sin(q0_PC(t))*cos(a)], + [-sin(a)*cos(q0_PC(t)), -sin(q0_PC(t)), cos(a)*cos(q0_PC(t))]]) + + """ + + def __init__(self, name, parent, child, rotation_coordinate=None, + planar_coordinates=None, rotation_speed=None, + planar_speeds=None, parent_point=None, child_point=None, + parent_interframe=None, child_interframe=None): + # A ready to merge implementation of setting the planar_vectors and + # rotation_axis was added and removed in PR #24046 + coordinates = (rotation_coordinate, planar_coordinates) + speeds = (rotation_speed, planar_speeds) + super().__init__(name, parent, child, coordinates, speeds, + parent_point, child_point, + parent_interframe=parent_interframe, + child_interframe=child_interframe) + + def __str__(self): + return (f'PlanarJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + @property + def rotation_coordinate(self): + """Generalized coordinate corresponding to the rotation angle.""" + return self.coordinates[0] + + @property + def planar_coordinates(self): + """Two generalized coordinates used for the planar translation.""" + return self.coordinates[1:, 0] + + @property + def rotation_speed(self): + """Generalized speed corresponding to the angular velocity.""" + return self.speeds[0] + + @property + def planar_speeds(self): + """Two generalized speeds used for the planar translation velocity.""" + return self.speeds[1:, 0] + + @property + def rotation_axis(self): + """The axis about which the rotation occurs.""" + return self.parent_interframe.x + + @property + def planar_vectors(self): + """The vectors that describe the planar translation directions.""" + return [self.parent_interframe.y, self.parent_interframe.z] + + def _generate_coordinates(self, coordinates): + rotation_speed = self._fill_coordinate_list(coordinates[0], 1, 'q', + number_single=True) + planar_speeds = self._fill_coordinate_list(coordinates[1], 2, 'q', 1) + return rotation_speed.col_join(planar_speeds) + + def _generate_speeds(self, speeds): + rotation_speed = self._fill_coordinate_list(speeds[0], 1, 'u', + number_single=True) + planar_speeds = self._fill_coordinate_list(speeds[1], 2, 'u', 1) + return rotation_speed.col_join(planar_speeds) + + def _orient_frames(self): + self.child_interframe.orient_axis( + self.parent_interframe, self.rotation_axis, + self.rotation_coordinate) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel( + self.parent_interframe, + self.rotation_speed * self.rotation_axis) + + def _set_linear_velocity(self): + self.child_point.set_pos( + self.parent_point, + self.planar_coordinates[0] * self.planar_vectors[0] + + self.planar_coordinates[1] * self.planar_vectors[1]) + self.parent_point.set_vel(self.parent_interframe, 0) + self.child_point.set_vel(self.child_interframe, 0) + self.child_point.set_vel( + self._parent_frame, self.planar_speeds[0] * self.planar_vectors[0] + + self.planar_speeds[1] * self.planar_vectors[1]) + self.child.masscenter.v2pt_theory(self.child_point, self._parent_frame, + self._child_frame) + + +class SphericalJoint(Joint): + """Spherical (Ball-and-Socket) Joint. + + .. image:: SphericalJoint.svg + :align: center + :width: 600 + + Explanation + =========== + + A spherical joint is defined such that the child body is free to rotate in + any direction, without allowing a translation of the ``child_point``. As can + also be seen in the image, the ``parent_point`` and ``child_point`` are + fixed on top of each other, i.e. the ``joint_point``. This rotation is + defined using the :func:`parent_interframe.orient(child_interframe, + rot_type, amounts, rot_order) + ` method. The default + rotation consists of three relative rotations, i.e. body-fixed rotations. + Based on the direction cosine matrix following from these rotations, the + angular velocity is computed based on the generalized coordinates and + generalized speeds. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + coordinates: iterable of dynamicsymbols, optional + Generalized coordinates of the joint. + speeds : iterable of dynamicsymbols, optional + Generalized speeds of joint. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + rot_type : str, optional + The method used to generate the direction cosine matrix. Supported + methods are: + + - ``'Body'``: three successive rotations about new intermediate axes, + also called "Euler and Tait-Bryan angles" + - ``'Space'``: three successive rotations about the parent frames' unit + vectors + + The default method is ``'Body'``. + amounts : + Expressions defining the rotation angles or direction cosine matrix. + These must match the ``rot_type``. See examples below for details. The + input types are: + + - ``'Body'``: 3-tuple of expressions, symbols, or functions + - ``'Space'``: 3-tuple of expressions, symbols, or functions + + The default amounts are the given ``coordinates``. + rot_order : str or int, optional + If applicable, the order of the successive of rotations. The string + ``'123'`` and integer ``123`` are equivalent, for example. Required for + ``'Body'`` and ``'Space'``. The default value is ``123``. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + + Examples + ========= + + A single spherical joint is created from two bodies and has the following + basic attributes: + + >>> from sympy.physics.mechanics import RigidBody, SphericalJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = SphericalJoint('PC', parent, child) + >>> joint + SphericalJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.parent_interframe + P_frame + >>> joint.child_interframe + C_frame + >>> joint.coordinates + Matrix([ + [q0_PC(t)], + [q1_PC(t)], + [q2_PC(t)]]) + >>> joint.speeds + Matrix([ + [u0_PC(t)], + [u1_PC(t)], + [u2_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame).to_matrix(child.frame) + Matrix([ + [ u0_PC(t)*cos(q1_PC(t))*cos(q2_PC(t)) + u1_PC(t)*sin(q2_PC(t))], + [-u0_PC(t)*sin(q2_PC(t))*cos(q1_PC(t)) + u1_PC(t)*cos(q2_PC(t))], + [ u0_PC(t)*sin(q1_PC(t)) + u2_PC(t)]]) + >>> child.frame.x.to_matrix(parent.frame) + Matrix([ + [ cos(q1_PC(t))*cos(q2_PC(t))], + [sin(q0_PC(t))*sin(q1_PC(t))*cos(q2_PC(t)) + sin(q2_PC(t))*cos(q0_PC(t))], + [sin(q0_PC(t))*sin(q2_PC(t)) - sin(q1_PC(t))*cos(q0_PC(t))*cos(q2_PC(t))]]) + >>> joint.child_point.pos_from(joint.parent_point) + 0 + + To further demonstrate the use of the spherical joint, the kinematics of a + spherical joint with a ZXZ rotation can be created as follows. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import RigidBody, SphericalJoint + >>> l1 = symbols('l1') + + First create bodies to represent the fixed floor and a pendulum bob. + + >>> floor = RigidBody('F') + >>> bob = RigidBody('B') + + The joint will connect the bob to the floor, with the joint located at a + distance of ``l1`` from the child's center of mass and the rotation set to a + body-fixed ZXZ rotation. + + >>> joint = SphericalJoint('S', floor, bob, child_point=l1 * bob.y, + ... rot_type='body', rot_order='ZXZ') + + Now that the joint is established, the kinematics of the connected body can + be accessed. + + The position of the bob's masscenter is found with: + + >>> bob.masscenter.pos_from(floor.masscenter) + - l1*B_frame.y + + The angular velocities of the pendulum link can be computed with respect to + the floor. + + >>> bob.frame.ang_vel_in(floor.frame).to_matrix( + ... floor.frame).simplify() + Matrix([ + [u1_S(t)*cos(q0_S(t)) + u2_S(t)*sin(q0_S(t))*sin(q1_S(t))], + [u1_S(t)*sin(q0_S(t)) - u2_S(t)*sin(q1_S(t))*cos(q0_S(t))], + [ u0_S(t) + u2_S(t)*cos(q1_S(t))]]) + + Finally, the linear velocity of the bob's center of mass can be computed. + + >>> bob.masscenter.vel(floor.frame).to_matrix(bob.frame) + Matrix([ + [ l1*(u0_S(t)*cos(q1_S(t)) + u2_S(t))], + [ 0], + [-l1*(u0_S(t)*sin(q1_S(t))*sin(q2_S(t)) + u1_S(t)*cos(q2_S(t)))]]) + + """ + def __init__(self, name, parent, child, coordinates=None, speeds=None, + parent_point=None, child_point=None, parent_interframe=None, + child_interframe=None, rot_type='BODY', amounts=None, + rot_order=123): + self._rot_type = rot_type + self._amounts = amounts + self._rot_order = rot_order + super().__init__(name, parent, child, coordinates, speeds, + parent_point, child_point, + parent_interframe=parent_interframe, + child_interframe=child_interframe) + + def __str__(self): + return (f'SphericalJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + def _generate_coordinates(self, coordinates): + return self._fill_coordinate_list(coordinates, 3, 'q') + + def _generate_speeds(self, speeds): + return self._fill_coordinate_list(speeds, len(self.coordinates), 'u') + + def _orient_frames(self): + supported_rot_types = ('BODY', 'SPACE') + if self._rot_type.upper() not in supported_rot_types: + raise NotImplementedError( + f'Rotation type "{self._rot_type}" is not implemented. ' + f'Implemented rotation types are: {supported_rot_types}') + amounts = self.coordinates if self._amounts is None else self._amounts + self.child_interframe.orient(self.parent_interframe, self._rot_type, + amounts, self._rot_order) + + def _set_angular_velocity(self): + t = dynamicsymbols._t + vel = self.child_interframe.ang_vel_in(self.parent_interframe).xreplace( + {q.diff(t): u for q, u in zip(self.coordinates, self.speeds)} + ) + self.child_interframe.set_ang_vel(self.parent_interframe, vel) + + def _set_linear_velocity(self): + self.child_point.set_pos(self.parent_point, 0) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child.masscenter.v2pt_theory(self.parent_point, self._parent_frame, + self._child_frame) + + +class WeldJoint(Joint): + """Weld Joint. + + .. raw:: html + :file: ../../../doc/src/modules/physics/mechanics/api/WeldJoint.svg + + Explanation + =========== + + A weld joint is defined such that there is no relative motion between the + child and parent bodies. The direction cosine matrix between the attachment + frame (``parent_interframe`` and ``child_interframe``) is the identity + matrix and the attachment points (``parent_point`` and ``child_point``) are + coincident. The page on the joints framework gives a more detailed + explanation of the intermediate frames. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody + The parent body of joint. + child : Particle or RigidBody + The child body of joint. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody + The joint's parent body. + child : Particle or RigidBody + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. The default value is + ``dynamicsymbols(f'q_{joint.name}')``. + speeds : Matrix + Matrix of the joint's generalized speeds. The default value is + ``dynamicsymbols(f'u_{joint.name}')``. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + + Examples + ========= + + A single weld joint is created from two bodies and has the following basic + attributes: + + >>> from sympy.physics.mechanics import RigidBody, WeldJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = WeldJoint('PC', parent, child) + >>> joint + WeldJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.coordinates + Matrix(0, 0, []) + >>> joint.speeds + Matrix(0, 0, []) + >>> child.frame.ang_vel_in(parent.frame) + 0 + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> joint.child_point.pos_from(joint.parent_point) + 0 + + To further demonstrate the use of the weld joint, two relatively-fixed + bodies rotated by a quarter turn about the Y axis can be created as follows: + + >>> from sympy import symbols, pi + >>> from sympy.physics.mechanics import ReferenceFrame, RigidBody, WeldJoint + >>> l1, l2 = symbols('l1 l2') + + First create the bodies to represent the parent and rotated child body. + + >>> parent = RigidBody('P') + >>> child = RigidBody('C') + + Next the intermediate frame specifying the fixed rotation with respect to + the parent can be created. + + >>> rotated_frame = ReferenceFrame('Pr') + >>> rotated_frame.orient_axis(parent.frame, parent.y, pi / 2) + + The weld between the parent body and child body is located at a distance + ``l1`` from the parent's center of mass in the X direction and ``l2`` from + the child's center of mass in the child's negative X direction. + + >>> weld = WeldJoint('weld', parent, child, parent_point=l1 * parent.x, + ... child_point=-l2 * child.x, + ... parent_interframe=rotated_frame) + + Now that the joint has been established, the kinematics of the bodies can be + accessed. The direction cosine matrix of the child body with respect to the + parent can be found: + + >>> child.frame.dcm(parent.frame) + Matrix([ + [0, 0, -1], + [0, 1, 0], + [1, 0, 0]]) + + As can also been seen from the direction cosine matrix, the parent X axis is + aligned with the child's Z axis: + >>> parent.x == child.z + True + + The position of the child's center of mass with respect to the parent's + center of mass can be found with: + + >>> child.masscenter.pos_from(parent.masscenter) + l1*P_frame.x + l2*C_frame.x + + The angular velocity of the child with respect to the parent is 0 as one + would expect. + + >>> child.frame.ang_vel_in(parent.frame) + 0 + + """ + + def __init__(self, name, parent, child, parent_point=None, child_point=None, + parent_interframe=None, child_interframe=None): + super().__init__(name, parent, child, [], [], parent_point, + child_point, parent_interframe=parent_interframe, + child_interframe=child_interframe) + self._kdes = Matrix(1, 0, []).T # Removes stackability problems #10770 + + def __str__(self): + return (f'WeldJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + def _generate_coordinates(self, coordinate): + return Matrix() + + def _generate_speeds(self, speed): + return Matrix() + + def _orient_frames(self): + self.child_interframe.orient_axis(self.parent_interframe, + self.parent_interframe.x, 0) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel(self.parent_interframe, 0) + + def _set_linear_velocity(self): + self.child_point.set_pos(self.parent_point, 0) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child.masscenter.set_vel(self._parent_frame, 0) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/jointsmethod.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/jointsmethod.py new file mode 100644 index 0000000000000000000000000000000000000000..df7bd56360072feb57a65e5f78c2d116f0d4842d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/jointsmethod.py @@ -0,0 +1,318 @@ +from sympy.physics.mechanics import (Body, Lagrangian, KanesMethod, LagrangesMethod, + RigidBody, Particle) +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.method import _Methods +from sympy import Matrix +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['JointsMethod'] + + +class JointsMethod(_Methods): + """Method for formulating the equations of motion using a set of interconnected bodies with joints. + + .. deprecated:: 1.13 + The JointsMethod class is deprecated. Its functionality has been + replaced by the new :class:`~.System` class. + + Parameters + ========== + + newtonion : Body or ReferenceFrame + The newtonion(inertial) frame. + *joints : Joint + The joints in the system + + Attributes + ========== + + q, u : iterable + Iterable of the generalized coordinates and speeds + bodies : iterable + Iterable of Body objects in the system. + loads : iterable + Iterable of (Point, vector) or (ReferenceFrame, vector) tuples + describing the forces on the system. + mass_matrix : Matrix, shape(n, n) + The system's mass matrix + forcing : Matrix, shape(n, 1) + The system's forcing vector + mass_matrix_full : Matrix, shape(2*n, 2*n) + The "mass matrix" for the u's and q's + forcing_full : Matrix, shape(2*n, 1) + The "forcing vector" for the u's and q's + method : KanesMethod or Lagrange's method + Method's object. + kdes : iterable + Iterable of kde in they system. + + Examples + ======== + + As Body and JointsMethod have been deprecated, the following examples are + for illustrative purposes only. The functionality of Body is fully captured + by :class:`~.RigidBody` and :class:`~.Particle` and the functionality of + JointsMethod is fully captured by :class:`~.System`. To ignore the + deprecation warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Body, JointsMethod, PrismaticJoint + >>> from sympy.physics.vector import dynamicsymbols + >>> c, k = symbols('c k') + >>> x, v = dynamicsymbols('x v') + >>> with ignore_warnings(DeprecationWarning): + ... wall = Body('W') + ... body = Body('B') + >>> J = PrismaticJoint('J', wall, body, coordinates=x, speeds=v) + >>> wall.apply_force(c*v*wall.x, reaction_body=body) + >>> wall.apply_force(k*x*wall.x, reaction_body=body) + >>> with ignore_warnings(DeprecationWarning): + ... method = JointsMethod(wall, J) + >>> method.form_eoms() + Matrix([[-B_mass*Derivative(v(t), t) - c*v(t) - k*x(t)]]) + >>> M = method.mass_matrix_full + >>> F = method.forcing_full + >>> rhs = M.LUsolve(F) + >>> rhs + Matrix([ + [ v(t)], + [(-c*v(t) - k*x(t))/B_mass]]) + + Notes + ===== + + ``JointsMethod`` currently only works with systems that do not have any + configuration or motion constraints. + + """ + + def __init__(self, newtonion, *joints): + sympy_deprecation_warning( + """ + The JointsMethod class is deprecated. + Its functionality has been replaced by the new System class. + """, + deprecated_since_version="1.13", + active_deprecations_target="deprecated-mechanics-jointsmethod" + ) + if isinstance(newtonion, BodyBase): + self.frame = newtonion.frame + else: + self.frame = newtonion + + self._joints = joints + self._bodies = self._generate_bodylist() + self._loads = self._generate_loadlist() + self._q = self._generate_q() + self._u = self._generate_u() + self._kdes = self._generate_kdes() + + self._method = None + + @property + def bodies(self): + """List of bodies in they system.""" + return self._bodies + + @property + def loads(self): + """List of loads on the system.""" + return self._loads + + @property + def q(self): + """List of the generalized coordinates.""" + return self._q + + @property + def u(self): + """List of the generalized speeds.""" + return self._u + + @property + def kdes(self): + """List of the generalized coordinates.""" + return self._kdes + + @property + def forcing_full(self): + """The "forcing vector" for the u's and q's.""" + return self.method.forcing_full + + @property + def mass_matrix_full(self): + """The "mass matrix" for the u's and q's.""" + return self.method.mass_matrix_full + + @property + def mass_matrix(self): + """The system's mass matrix.""" + return self.method.mass_matrix + + @property + def forcing(self): + """The system's forcing vector.""" + return self.method.forcing + + @property + def method(self): + """Object of method used to form equations of systems.""" + return self._method + + def _generate_bodylist(self): + bodies = [] + for joint in self._joints: + if joint.child not in bodies: + bodies.append(joint.child) + if joint.parent not in bodies: + bodies.append(joint.parent) + return bodies + + def _generate_loadlist(self): + load_list = [] + for body in self.bodies: + if isinstance(body, Body): + load_list.extend(body.loads) + return load_list + + def _generate_q(self): + q_ind = [] + for joint in self._joints: + for coordinate in joint.coordinates: + if coordinate in q_ind: + raise ValueError('Coordinates of joints should be unique.') + q_ind.append(coordinate) + return Matrix(q_ind) + + def _generate_u(self): + u_ind = [] + for joint in self._joints: + for speed in joint.speeds: + if speed in u_ind: + raise ValueError('Speeds of joints should be unique.') + u_ind.append(speed) + return Matrix(u_ind) + + def _generate_kdes(self): + kd_ind = Matrix(1, 0, []).T + for joint in self._joints: + kd_ind = kd_ind.col_join(joint.kdes) + return kd_ind + + def _convert_bodies(self): + # Convert `Body` to `Particle` and `RigidBody` + bodylist = [] + for body in self.bodies: + if not isinstance(body, Body): + bodylist.append(body) + continue + if body.is_rigidbody: + rb = RigidBody(body.name, body.masscenter, body.frame, body.mass, + (body.central_inertia, body.masscenter)) + rb.potential_energy = body.potential_energy + bodylist.append(rb) + else: + part = Particle(body.name, body.masscenter, body.mass) + part.potential_energy = body.potential_energy + bodylist.append(part) + return bodylist + + def form_eoms(self, method=KanesMethod): + """Method to form system's equation of motions. + + Parameters + ========== + + method : Class + Class name of method. + + Returns + ======== + + Matrix + Vector of equations of motions. + + Examples + ======== + + As Body and JointsMethod have been deprecated, the following examples + are for illustrative purposes only. The functionality of Body is fully + captured by :class:`~.RigidBody` and :class:`~.Particle` and the + functionality of JointsMethod is fully captured by :class:`~.System`. To + ignore the deprecation warning we can use the ignore_warnings context + manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + >>> from sympy import S, symbols + >>> from sympy.physics.mechanics import LagrangesMethod, dynamicsymbols, Body + >>> from sympy.physics.mechanics import PrismaticJoint, JointsMethod + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> m, k, b = symbols('m k b') + >>> with ignore_warnings(DeprecationWarning): + ... wall = Body('W') + ... part = Body('P', mass=m) + >>> part.potential_energy = k * q**2 / S(2) + >>> J = PrismaticJoint('J', wall, part, coordinates=q, speeds=qd) + >>> wall.apply_force(b * qd * wall.x, reaction_body=part) + >>> with ignore_warnings(DeprecationWarning): + ... method = JointsMethod(wall, J) + >>> method.form_eoms(LagrangesMethod) + Matrix([[b*Derivative(q(t), t) + k*q(t) + m*Derivative(q(t), (t, 2))]]) + + We can also solve for the states using the 'rhs' method. + + >>> method.rhs() + Matrix([ + [ Derivative(q(t), t)], + [(-b*Derivative(q(t), t) - k*q(t))/m]]) + + """ + + bodylist = self._convert_bodies() + if issubclass(method, LagrangesMethod): #LagrangesMethod or similar + L = Lagrangian(self.frame, *bodylist) + self._method = method(L, self.q, self.loads, bodylist, self.frame) + else: #KanesMethod or similar + self._method = method(self.frame, q_ind=self.q, u_ind=self.u, kd_eqs=self.kdes, + forcelist=self.loads, bodies=bodylist) + soln = self.method._form_eoms() + return soln + + def rhs(self, inv_method=None): + """Returns equations that can be solved numerically. + + Parameters + ========== + + inv_method : str + The specific sympy inverse matrix calculation method to use. For a + list of valid methods, see + :meth:`~sympy.matrices.matrixbase.MatrixBase.inv` + + Returns + ======== + + Matrix + Numerically solvable equations. + + See Also + ======== + + sympy.physics.mechanics.kane.KanesMethod.rhs: + KanesMethod's rhs function. + sympy.physics.mechanics.lagrange.LagrangesMethod.rhs: + LagrangesMethod's rhs function. + + """ + + return self.method.rhs(inv_method=inv_method) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/kane.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/kane.py new file mode 100644 index 0000000000000000000000000000000000000000..805587a4fe9d7696f45c5815ee5406b103150698 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/kane.py @@ -0,0 +1,859 @@ +from sympy import zeros, Matrix, diff, eye, linear_eq_to_matrix +from sympy.core.sorting import default_sort_key +from sympy.physics.vector import (ReferenceFrame, dynamicsymbols, + partial_velocity) +from sympy.physics.mechanics.method import _Methods +from sympy.physics.mechanics.particle import Particle +from sympy.physics.mechanics.rigidbody import RigidBody +from sympy.physics.mechanics.functions import (msubs, find_dynamicsymbols, + _f_list_parser, + _validate_coordinates, + _parse_linear_solver) +from sympy.physics.mechanics.linearize import Linearizer +from sympy.utilities.iterables import iterable + + +__all__ = ['KanesMethod'] + + +class KanesMethod(_Methods): + r"""Kane's method object. + + Explanation + =========== + + This object is used to do the "book-keeping" as you go through and form + equations of motion in the way Kane presents in: + Kane, T., Levinson, D. Dynamics Theory and Applications. 1985 McGraw-Hill + + The attributes are for equations in the form [M] udot = forcing. + + Attributes + ========== + + q, u : Matrix + Matrices of the generalized coordinates and speeds + bodies : iterable + Iterable of Particle and RigidBody objects in the system. + loads : iterable + Iterable of (Point, vector) or (ReferenceFrame, vector) tuples + describing the forces on the system. + auxiliary_eqs : Matrix + If applicable, the set of auxiliary Kane's + equations used to solve for non-contributing + forces. + mass_matrix : Matrix + The system's dynamics mass matrix: [k_d; k_dnh] + forcing : Matrix + The system's dynamics forcing vector: -[f_d; f_dnh] + mass_matrix_kin : Matrix + The "mass matrix" for kinematic differential equations: k_kqdot + forcing_kin : Matrix + The forcing vector for kinematic differential equations: -(k_ku*u + f_k) + mass_matrix_full : Matrix + The "mass matrix" for the u's and q's with dynamics and kinematics + forcing_full : Matrix + The "forcing vector" for the u's and q's with dynamics and kinematics + + Parameters + ========== + + frame : ReferenceFrame + The inertial reference frame for the system. + q_ind : iterable of dynamicsymbols + Independent generalized coordinates. + u_ind : iterable of dynamicsymbols + Independent generalized speeds. + kd_eqs : iterable of Expr, optional + Kinematic differential equations, which linearly relate the generalized + speeds to the time-derivatives of the generalized coordinates. + q_dependent : iterable of dynamicsymbols, optional + Dependent generalized coordinates. + configuration_constraints : iterable of Expr, optional + Constraints on the system's configuration, i.e. holonomic constraints. + u_dependent : iterable of dynamicsymbols, optional + Dependent generalized speeds. + velocity_constraints : iterable of Expr, optional + Constraints on the system's velocity, i.e. the combination of the + nonholonomic constraints and the time-derivative of the holonomic + constraints. + acceleration_constraints : iterable of Expr, optional + Constraints on the system's acceleration, by default these are the + time-derivative of the velocity constraints. + u_auxiliary : iterable of dynamicsymbols, optional + Auxiliary generalized speeds. + bodies : iterable of Particle and/or RigidBody, optional + The particles and rigid bodies in the system. + forcelist : iterable of tuple[Point | ReferenceFrame, Vector], optional + Forces and torques applied on the system. + explicit_kinematics : bool + Boolean whether the mass matrices and forcing vectors should use the + explicit form (default) or implicit form for kinematics. + See the notes for more details. + kd_eqs_solver : str, callable + Method used to solve the kinematic differential equations. If a string + is supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``f(A, rhs)``, where it solves the + equations and returns the solution. The default utilizes LU solve. See + the notes for more information. + constraint_solver : str, callable + Method used to solve the velocity constraints. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``f(A, rhs)``, where it solves the + equations and returns the solution. The default utilizes LU solve. See + the notes for more information. + + Notes + ===== + + The mass matrices and forcing vectors related to kinematic equations + are given in the explicit form by default. In other words, the kinematic + mass matrix is $\mathbf{k_{k\dot{q}}} = \mathbf{I}$. + In order to get the implicit form of those matrices/vectors, you can set the + ``explicit_kinematics`` attribute to ``False``. So $\mathbf{k_{k\dot{q}}}$ + is not necessarily an identity matrix. This can provide more compact + equations for non-simple kinematics. + + Two linear solvers can be supplied to ``KanesMethod``: one for solving the + kinematic differential equations and one to solve the velocity constraints. + Both of these sets of equations can be expressed as a linear system ``Ax = rhs``, + which have to be solved in order to obtain the equations of motion. + + The default solver ``'LU'``, which stands for LU solve, results relatively low + number of operations. The weakness of this method is that it can result in zero + division errors. + + If zero divisions are encountered, a possible solver which may solve the problem + is ``"CRAMER"``. This method uses Cramer's rule to solve the system. This method + is slower and results in more operations than the default solver. However it only + uses a single division by default per entry of the solution. + + While a valid list of solvers can be found at + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`, it is also possible to supply a + `callable`. This way it is possible to use a different solver routine. If the + kinematic differential equations are not too complex it can be worth it to simplify + the solution by using ``lambda A, b: simplify(Matrix.LUsolve(A, b))``. Another + option solver one may use is :func:`sympy.solvers.solveset.linsolve`. This can be + done using `lambda A, b: tuple(linsolve((A, b)))[0]`, where we select the first + solution as our system should have only one unique solution. + + Examples + ======== + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + In this example, we first need to do the kinematics. + This involves creating generalized speeds and coordinates and their + derivatives. + Then we create a point and set its velocity in a frame. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import dynamicsymbols, ReferenceFrame + >>> from sympy.physics.mechanics import Point, Particle, KanesMethod + >>> q, u = dynamicsymbols('q u') + >>> qd, ud = dynamicsymbols('q u', 1) + >>> m, c, k = symbols('m c k') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, u * N.x) + + Next we need to arrange/store information in the way that KanesMethod + requires. The kinematic differential equations should be an iterable of + expressions. A list of forces/torques must be constructed, where each entry + in the list is a (Point, Vector) or (ReferenceFrame, Vector) tuple, where + the Vectors represent the Force or Torque. + Next a particle needs to be created, and it needs to have a point and mass + assigned to it. + Finally, a list of all bodies and particles needs to be created. + + >>> kd = [qd - u] + >>> FL = [(P, (-k * q - c * u) * N.x)] + >>> pa = Particle('pa', P, m) + >>> BL = [pa] + + Finally we can generate the equations of motion. + First we create the KanesMethod object and supply an inertial frame, + coordinates, generalized speeds, and the kinematic differential equations. + Additional quantities such as configuration and motion constraints, + dependent coordinates and speeds, and auxiliary speeds are also supplied + here (see the online documentation). + Next we form FR* and FR to complete: Fr + Fr* = 0. + We have the equations of motion at this point. + It makes sense to rearrange them though, so we calculate the mass matrix and + the forcing terms, for E.o.M. in the form: [MM] udot = forcing, where MM is + the mass matrix, udot is a vector of the time derivatives of the + generalized speeds, and forcing is a vector representing "forcing" terms. + + >>> KM = KanesMethod(N, q_ind=[q], u_ind=[u], kd_eqs=kd) + >>> (fr, frstar) = KM.kanes_equations(BL, FL) + >>> MM = KM.mass_matrix + >>> forcing = KM.forcing + >>> rhs = MM.inv() * forcing + >>> rhs + Matrix([[(-c*u(t) - k*q(t))/m]]) + >>> KM.linearize(A_and_B=True)[0] + Matrix([ + [ 0, 1], + [-k/m, -c/m]]) + + Please look at the documentation pages for more information on how to + perform linearization and how to deal with dependent coordinates & speeds, + and how do deal with bringing non-contributing forces into evidence. + + """ + + def __init__(self, frame, q_ind, u_ind, kd_eqs=None, q_dependent=None, + configuration_constraints=None, u_dependent=None, + velocity_constraints=None, acceleration_constraints=None, + u_auxiliary=None, bodies=None, forcelist=None, + explicit_kinematics=True, kd_eqs_solver='LU', + constraint_solver='LU'): + + """Please read the online documentation. """ + if not q_ind: + q_ind = [dynamicsymbols('dummy_q')] + kd_eqs = [dynamicsymbols('dummy_kd')] + + if not isinstance(frame, ReferenceFrame): + raise TypeError('An inertial ReferenceFrame must be supplied') + self._inertial = frame + + self._fr = None + self._frstar = None + + self._forcelist = forcelist + self._bodylist = bodies + + self.explicit_kinematics = explicit_kinematics + self._constraint_solver = constraint_solver + self._initialize_vectors(q_ind, q_dependent, u_ind, u_dependent, + u_auxiliary) + _validate_coordinates(self.q, self.u) + self._initialize_kindiffeq_matrices(kd_eqs, kd_eqs_solver) + self._initialize_constraint_matrices( + configuration_constraints, velocity_constraints, + acceleration_constraints, constraint_solver) + + def _initialize_vectors(self, q_ind, q_dep, u_ind, u_dep, u_aux): + """Initialize the coordinate and speed vectors.""" + + none_handler = lambda x: Matrix(x) if x else Matrix() + + # Initialize generalized coordinates + q_dep = none_handler(q_dep) + if not iterable(q_ind): + raise TypeError('Generalized coordinates must be an iterable.') + if not iterable(q_dep): + raise TypeError('Dependent coordinates must be an iterable.') + q_ind = Matrix(q_ind) + self._qdep = q_dep + self._q = Matrix([q_ind, q_dep]) + self._qdot = self.q.diff(dynamicsymbols._t) + + # Initialize generalized speeds + u_dep = none_handler(u_dep) + if not iterable(u_ind): + raise TypeError('Generalized speeds must be an iterable.') + if not iterable(u_dep): + raise TypeError('Dependent speeds must be an iterable.') + u_ind = Matrix(u_ind) + self._udep = u_dep + self._u = Matrix([u_ind, u_dep]) + self._udot = self.u.diff(dynamicsymbols._t) + self._uaux = none_handler(u_aux) + + def _initialize_constraint_matrices(self, config, vel, acc, linear_solver='LU'): + """Initializes constraint matrices.""" + linear_solver = _parse_linear_solver(linear_solver) + # Define vector dimensions + o = len(self.u) + m = len(self._udep) + p = o - m + none_handler = lambda x: Matrix(x) if x else Matrix() + + # Initialize configuration constraints + config = none_handler(config) + if len(self._qdep) != len(config): + raise ValueError('There must be an equal number of dependent ' + 'coordinates and configuration constraints.') + self._f_h = none_handler(config) + + # Initialize velocity and acceleration constraints + vel = none_handler(vel) + acc = none_handler(acc) + if len(vel) != m: + raise ValueError('There must be an equal number of dependent ' + 'speeds and velocity constraints.') + if acc and (len(acc) != m): + raise ValueError('There must be an equal number of dependent ' + 'speeds and acceleration constraints.') + if vel: + + # When calling kanes_equations, another class instance will be + # created if auxiliary u's are present. In this case, the + # computation of kinetic differential equation matrices will be + # skipped as this was computed during the original KanesMethod + # object, and the qd_u_map will not be available. + if self._qdot_u_map is not None: + vel = msubs(vel, self._qdot_u_map) + self._k_nh, f_nh_neg = linear_eq_to_matrix(vel, self.u[:]) + self._f_nh = -f_nh_neg + + # If no acceleration constraints given, calculate them. + if not acc: + _f_dnh = (self._k_nh.diff(dynamicsymbols._t) * self.u + + self._f_nh.diff(dynamicsymbols._t)) + if self._qdot_u_map is not None: + _f_dnh = msubs(_f_dnh, self._qdot_u_map) + self._f_dnh = _f_dnh + self._k_dnh = self._k_nh + else: + if self._qdot_u_map is not None: + acc = msubs(acc, self._qdot_u_map) + + self._k_dnh, f_dnh_neg = linear_eq_to_matrix(acc, self._udot[:]) + self._f_dnh = -f_dnh_neg + # Form of non-holonomic constraints is B*u + C = 0. + # We partition B into independent and dependent columns: + # Ars is then -B_dep.inv() * B_ind, and it relates dependent speeds + # to independent speeds as: udep = Ars*uind, neglecting the C term. + B_ind = self._k_nh[:, :p] + B_dep = self._k_nh[:, p:o] + self._Ars = -linear_solver(B_dep, B_ind) + else: + self._f_nh = Matrix() + self._k_nh = Matrix() + self._f_dnh = Matrix() + self._k_dnh = Matrix() + self._Ars = Matrix() + + def _initialize_kindiffeq_matrices(self, kdeqs, linear_solver='LU'): + """Initialize the kinematic differential equation matrices. + + Parameters + ========== + kdeqs : sequence of sympy expressions + Kinematic differential equations in the form of f(u,q',q,t) where + f() = 0. The equations have to be linear in the time-derivatives of + the generalized coordinates and in the generalized speeds. + + """ + linear_solver = _parse_linear_solver(linear_solver) + if kdeqs: + if len(self.q) != len(kdeqs): + raise ValueError('There must be an equal number of kinematic ' + 'differential equations and coordinates.') + + u = self.u + qdot = self._qdot + + kdeqs = Matrix(kdeqs) + + u_zero = dict.fromkeys(u, 0) + uaux_zero = dict.fromkeys(self._uaux, 0) + qdot_zero = dict.fromkeys(qdot, 0) + + # Extract the linear coefficient matrices as per the following + # equation: + # + # k_ku(q,t)*u(t) + k_kqdot(q,t)*q'(t) + f_k(q,t) = 0 + # + k_ku = kdeqs.jacobian(u) + k_kqdot = kdeqs.jacobian(qdot) + f_k = kdeqs.xreplace(u_zero).xreplace(qdot_zero) + + # The kinematic differential equations should be linear in both q' + # and u so check for u and q' in the components. + dy_syms = find_dynamicsymbols(k_ku.row_join(k_kqdot).row_join(f_k)) + nonlin_vars = [vari for vari in u[:] + qdot[:] if vari in dy_syms] + if nonlin_vars: + msg = ('The provided kinematic differential equations are ' + 'nonlinear in {}. They must be linear in the ' + 'generalized speeds and derivatives of the generalized ' + 'coordinates.') + raise ValueError(msg.format(nonlin_vars)) + + self._f_k_implicit = f_k.xreplace(uaux_zero) + self._k_ku_implicit = k_ku.xreplace(uaux_zero) + self._k_kqdot_implicit = k_kqdot + + # Solve for q'(t) such that the coefficient matrices are now in + # this form: + # + # k_kqdot^-1*k_ku*u(t) + I*q'(t) + k_kqdot^-1*f_k = 0 + # + # NOTE : Solving the kinematic differential equations here is not + # necessary and prevents the equations from being provided in fully + # implicit form. + f_k_explicit = linear_solver(k_kqdot, f_k) + k_ku_explicit = linear_solver(k_kqdot, k_ku) + self._qdot_u_map = dict(zip(qdot, -(k_ku_explicit*u + f_k_explicit))) + + self._f_k = f_k_explicit.xreplace(uaux_zero) + self._k_ku = k_ku_explicit.xreplace(uaux_zero) + self._k_kqdot = eye(len(qdot)) + + else: + self._qdot_u_map = None + self._f_k_implicit = self._f_k = Matrix() + self._k_ku_implicit = self._k_ku = Matrix() + self._k_kqdot_implicit = self._k_kqdot = Matrix() + + def _form_fr(self, fl): + """Form the generalized active force.""" + if fl is not None and (len(fl) == 0 or not iterable(fl)): + raise ValueError('Force pairs must be supplied in an ' + 'non-empty iterable or None.') + + N = self._inertial + # pull out relevant velocities for constructing partial velocities + vel_list, f_list = _f_list_parser(fl, N) + vel_list = [msubs(i, self._qdot_u_map) for i in vel_list] + f_list = [msubs(i, self._qdot_u_map) for i in f_list] + + # Fill Fr with dot product of partial velocities and forces + o = len(self.u) + b = len(f_list) + FR = zeros(o, 1) + partials = partial_velocity(vel_list, self.u, N) + for i in range(o): + FR[i] = sum(partials[j][i].dot(f_list[j]) for j in range(b)) + + # In case there are dependent speeds + if self._udep: + p = o - len(self._udep) + FRtilde = FR[:p, 0] + FRold = FR[p:o, 0] + FRtilde += self._Ars.T * FRold + FR = FRtilde + + self._forcelist = fl + self._fr = FR + return FR + + def _form_frstar(self, bl): + """Form the generalized inertia force.""" + + if not iterable(bl): + raise TypeError('Bodies must be supplied in an iterable.') + + t = dynamicsymbols._t + N = self._inertial + # Dicts setting things to zero + udot_zero = dict.fromkeys(self._udot, 0) + uaux_zero = dict.fromkeys(self._uaux, 0) + uauxdot = [diff(i, t) for i in self._uaux] + uauxdot_zero = dict.fromkeys(uauxdot, 0) + # Dictionary of q' and q'' to u and u' + q_ddot_u_map = {k.diff(t): v.diff(t).xreplace( + self._qdot_u_map) for (k, v) in self._qdot_u_map.items()} + q_ddot_u_map.update(self._qdot_u_map) + + # Fill up the list of partials: format is a list with num elements + # equal to number of entries in body list. Each of these elements is a + # list - either of length 1 for the translational components of + # particles or of length 2 for the translational and rotational + # components of rigid bodies. The inner most list is the list of + # partial velocities. + def get_partial_velocity(body): + if isinstance(body, RigidBody): + vlist = [body.masscenter.vel(N), body.frame.ang_vel_in(N)] + elif isinstance(body, Particle): + vlist = [body.point.vel(N),] + else: + raise TypeError('The body list may only contain either ' + 'RigidBody or Particle as list elements.') + v = [msubs(vel, self._qdot_u_map) for vel in vlist] + return partial_velocity(v, self.u, N) + partials = [get_partial_velocity(body) for body in bl] + + # Compute fr_star in two components: + # fr_star = -(MM*u' + nonMM) + o = len(self.u) + MM = zeros(o, o) + nonMM = zeros(o, 1) + zero_uaux = lambda expr: msubs(expr, uaux_zero) + zero_udot_uaux = lambda expr: msubs(msubs(expr, udot_zero), uaux_zero) + for i, body in enumerate(bl): + if isinstance(body, RigidBody): + M = zero_uaux(body.mass) + I = zero_uaux(body.central_inertia) + vel = zero_uaux(body.masscenter.vel(N)) + omega = zero_uaux(body.frame.ang_vel_in(N)) + acc = zero_udot_uaux(body.masscenter.acc(N)) + inertial_force = (M.diff(t) * vel + M * acc) + inertial_torque = zero_uaux((I.dt(body.frame).dot(omega)) + + msubs(I.dot(body.frame.ang_acc_in(N)), udot_zero) + + (omega.cross(I.dot(omega)))) + for j in range(o): + tmp_vel = zero_uaux(partials[i][0][j]) + tmp_ang = zero_uaux(I.dot(partials[i][1][j])) + for k in range(o): + # translational + MM[j, k] += M*tmp_vel.dot(partials[i][0][k]) + # rotational + MM[j, k] += tmp_ang.dot(partials[i][1][k]) + nonMM[j] += inertial_force.dot(partials[i][0][j]) + nonMM[j] += inertial_torque.dot(partials[i][1][j]) + else: + M = zero_uaux(body.mass) + vel = zero_uaux(body.point.vel(N)) + acc = zero_udot_uaux(body.point.acc(N)) + inertial_force = (M.diff(t) * vel + M * acc) + for j in range(o): + temp = zero_uaux(partials[i][0][j]) + for k in range(o): + MM[j, k] += M*temp.dot(partials[i][0][k]) + nonMM[j] += inertial_force.dot(partials[i][0][j]) + # Compose fr_star out of MM and nonMM + MM = zero_uaux(msubs(MM, q_ddot_u_map)) + nonMM = msubs(msubs(nonMM, q_ddot_u_map), + udot_zero, uauxdot_zero, uaux_zero) + fr_star = -(MM * msubs(Matrix(self._udot), uauxdot_zero) + nonMM) + + # If there are dependent speeds, we need to find fr_star_tilde + if self._udep: + p = o - len(self._udep) + fr_star_ind = fr_star[:p, 0] + fr_star_dep = fr_star[p:o, 0] + fr_star = fr_star_ind + (self._Ars.T * fr_star_dep) + # Apply the same to MM + MMi = MM[:p, :] + MMd = MM[p:o, :] + MM = MMi + (self._Ars.T * MMd) + # Apply the same to nonMM + nonMM = nonMM[:p, :] + (self._Ars.T * nonMM[p:o, :]) + + self._bodylist = bl + self._frstar = fr_star + self._k_d = MM + self._f_d = -(self._fr - nonMM) + return fr_star + + def to_linearizer(self, linear_solver='LU'): + """Returns an instance of the Linearizer class, initiated from the + data in the KanesMethod class. This may be more desirable than using + the linearize class method, as the Linearizer object will allow more + efficient recalculation (i.e. about varying operating points). + + Parameters + ========== + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + + Returns + ======= + Linearizer + An instantiated + :class:`sympy.physics.mechanics.linearize.Linearizer`. + + """ + + if (self._fr is None) or (self._frstar is None): + raise ValueError('Need to compute Fr, Fr* first.') + + # Get required equation components. The Kane's method class breaks + # these into pieces. Need to reassemble + f_c = self._f_h + if self._f_nh and self._k_nh: + f_v = self._f_nh + self._k_nh*Matrix(self.u) + else: + f_v = Matrix() + if self._f_dnh and self._k_dnh: + f_a = self._f_dnh + self._k_dnh*Matrix(self._udot) + else: + f_a = Matrix() + # Dicts to sub to zero, for splitting up expressions + u_zero = dict.fromkeys(self.u, 0) + ud_zero = dict.fromkeys(self._udot, 0) + qd_zero = dict.fromkeys(self._qdot, 0) + qd_u_zero = dict.fromkeys(Matrix([self._qdot, self.u]), 0) + # Break the kinematic differential eqs apart into f_0 and f_1 + f_0 = msubs(self._f_k, u_zero) + self._k_kqdot*Matrix(self._qdot) + f_1 = msubs(self._f_k, qd_zero) + self._k_ku*Matrix(self.u) + # Break the dynamic differential eqs into f_2 and f_3 + f_2 = msubs(self._frstar, qd_u_zero) + f_3 = msubs(self._frstar, ud_zero) + self._fr + f_4 = zeros(len(f_2), 1) + + # Get the required vector components + q = self.q + u = self.u + if self._qdep: + q_i = q[:-len(self._qdep)] + else: + q_i = q + q_d = self._qdep + if self._udep: + u_i = u[:-len(self._udep)] + else: + u_i = u + u_d = self._udep + + # Form dictionary to set auxiliary speeds & their derivatives to 0. + uaux = self._uaux + uauxdot = uaux.diff(dynamicsymbols._t) + uaux_zero = dict.fromkeys(Matrix([uaux, uauxdot]), 0) + + # Checking for dynamic symbols outside the dynamic differential + # equations; throws error if there is. + sym_list = set(Matrix([q, self._qdot, u, self._udot, uaux, uauxdot])) + if any(find_dynamicsymbols(i, sym_list) for i in [self._k_kqdot, + self._k_ku, self._f_k, self._k_dnh, self._f_dnh, self._k_d]): + raise ValueError('Cannot have dynamicsymbols outside dynamic \ + forcing vector.') + + # Find all other dynamic symbols, forming the forcing vector r. + # Sort r to make it canonical. + r = list(find_dynamicsymbols(msubs(self._f_d, uaux_zero), sym_list)) + r.sort(key=default_sort_key) + + # Check for any derivatives of variables in r that are also found in r. + for i in r: + if diff(i, dynamicsymbols._t) in r: + raise ValueError('Cannot have derivatives of specified \ + quantities when linearizing forcing terms.') + return Linearizer(f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a, q, u, q_i, + q_d, u_i, u_d, r, linear_solver=linear_solver) + + # TODO : Remove `new_method` after 1.1 has been released. + def linearize(self, *, new_method=None, linear_solver='LU', **kwargs): + """ Linearize the equations of motion about a symbolic operating point. + + Parameters + ========== + new_method + Deprecated, does nothing and will be removed. + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + **kwargs + Extra keyword arguments are passed to + :meth:`sympy.physics.mechanics.linearize.Linearizer.linearize`. + + Explanation + =========== + + If kwarg A_and_B is False (default), returns M, A, B, r for the + linearized form, M*[q', u']^T = A*[q_ind, u_ind]^T + B*r. + + If kwarg A_and_B is True, returns A, B, r for the linearized form + dx = A*x + B*r, where x = [q_ind, u_ind]^T. Note that this is + computationally intensive if there are many symbolic parameters. For + this reason, it may be more desirable to use the default A_and_B=False, + returning M, A, and B. Values may then be substituted in to these + matrices, and the state space form found as + A = P.T*M.inv()*A, B = P.T*M.inv()*B, where P = Linearizer.perm_mat. + + In both cases, r is found as all dynamicsymbols in the equations of + motion that are not part of q, u, q', or u'. They are sorted in + canonical form. + + The operating points may be also entered using the ``op_point`` kwarg. + This takes a dictionary of {symbol: value}, or a an iterable of such + dictionaries. The values may be numeric or symbolic. The more values + you can specify beforehand, the faster this computation will run. + + For more documentation, please see the ``Linearizer`` class. + + """ + + linearizer = self.to_linearizer(linear_solver=linear_solver) + result = linearizer.linearize(**kwargs) + return result + (linearizer.r,) + + def kanes_equations(self, bodies=None, loads=None): + """ Method to form Kane's equations, Fr + Fr* = 0. + + Explanation + =========== + + Returns (Fr, Fr*). In the case where auxiliary generalized speeds are + present (say, s auxiliary speeds, o generalized speeds, and m motion + constraints) the length of the returned vectors will be o - m + s in + length. The first o - m equations will be the constrained Kane's + equations, then the s auxiliary Kane's equations. These auxiliary + equations can be accessed with the auxiliary_eqs property. + + Parameters + ========== + + bodies : iterable + An iterable of all RigidBody's and Particle's in the system. + A system must have at least one body. + loads : iterable + Takes in an iterable of (Particle, Vector) or (ReferenceFrame, Vector) + tuples which represent the force at a point or torque on a frame. + Must be either a non-empty iterable of tuples or None which corresponds + to a system with no constraints. + """ + if bodies is None: + bodies = self.bodies + if loads is None and self._forcelist is not None: + loads = self._forcelist + if loads == []: + loads = None + if not self._k_kqdot: + raise AttributeError('Create an instance of KanesMethod with ' + 'kinematic differential equations to use this method.') + fr = self._form_fr(loads) + frstar = self._form_frstar(bodies) + if self._uaux: + if not self._udep: + km = KanesMethod(self._inertial, self.q, self._uaux, + u_auxiliary=self._uaux, constraint_solver=self._constraint_solver) + else: + km = KanesMethod(self._inertial, self.q, self._uaux, + u_auxiliary=self._uaux, u_dependent=self._udep, + velocity_constraints=(self._k_nh * self.u + + self._f_nh), + acceleration_constraints=(self._k_dnh * self._udot + + self._f_dnh), + constraint_solver=self._constraint_solver + ) + km._qdot_u_map = self._qdot_u_map + self._km = km + fraux = km._form_fr(loads) + frstaraux = km._form_frstar(bodies) + self._aux_eq = fraux + frstaraux + self._fr = fr.col_join(fraux) + self._frstar = frstar.col_join(frstaraux) + return (self._fr, self._frstar) + + def _form_eoms(self): + fr, frstar = self.kanes_equations(self.bodylist, self.forcelist) + return fr + frstar + + def rhs(self, inv_method=None): + """Returns the system's equations of motion in first order form. The + output is the right hand side of:: + + x' = |q'| =: f(q, u, r, p, t) + |u'| + + The right hand side is what is needed by most numerical ODE + integrators. + + Parameters + ========== + + inv_method : str + The specific sympy inverse matrix calculation method to use. For a + list of valid methods, see + :meth:`~sympy.matrices.matrixbase.MatrixBase.inv` + + """ + rhs = zeros(len(self.q) + len(self.u), 1) + kdes = self.kindiffdict() + for i, q_i in enumerate(self.q): + rhs[i] = kdes[q_i.diff()] + + if inv_method is None: + rhs[len(self.q):, 0] = self.mass_matrix.LUsolve(self.forcing) + else: + rhs[len(self.q):, 0] = (self.mass_matrix.inv(inv_method, + try_block_diag=True) * + self.forcing) + + return rhs + + def kindiffdict(self): + """Returns a dictionary mapping q' to u.""" + if not self._qdot_u_map: + raise AttributeError('Create an instance of KanesMethod with ' + 'kinematic differential equations to use this method.') + return self._qdot_u_map + + @property + def auxiliary_eqs(self): + """A matrix containing the auxiliary equations.""" + if not self._fr or not self._frstar: + raise ValueError('Need to compute Fr, Fr* first.') + if not self._uaux: + raise ValueError('No auxiliary speeds have been declared.') + return self._aux_eq + + @property + def mass_matrix_kin(self): + r"""The kinematic "mass matrix" $\mathbf{k_{k\dot{q}}}$ of the system.""" + return self._k_kqdot if self.explicit_kinematics else self._k_kqdot_implicit + + @property + def forcing_kin(self): + """The kinematic "forcing vector" of the system.""" + if self.explicit_kinematics: + return -(self._k_ku * Matrix(self.u) + self._f_k) + else: + return -(self._k_ku_implicit * Matrix(self.u) + self._f_k_implicit) + + @property + def mass_matrix(self): + """The mass matrix of the system.""" + if not self._fr or not self._frstar: + raise ValueError('Need to compute Fr, Fr* first.') + return Matrix([self._k_d, self._k_dnh]) + + @property + def forcing(self): + """The forcing vector of the system.""" + if not self._fr or not self._frstar: + raise ValueError('Need to compute Fr, Fr* first.') + return -Matrix([self._f_d, self._f_dnh]) + + @property + def mass_matrix_full(self): + """The mass matrix of the system, augmented by the kinematic + differential equations in explicit or implicit form.""" + if not self._fr or not self._frstar: + raise ValueError('Need to compute Fr, Fr* first.') + o, n = len(self.u), len(self.q) + return (self.mass_matrix_kin.row_join(zeros(n, o))).col_join( + zeros(o, n).row_join(self.mass_matrix)) + + @property + def forcing_full(self): + """The forcing vector of the system, augmented by the kinematic + differential equations in explicit or implicit form.""" + return Matrix([self.forcing_kin, self.forcing]) + + @property + def q(self): + return self._q + + @property + def u(self): + return self._u + + @property + def bodylist(self): + return self._bodylist + + @property + def forcelist(self): + return self._forcelist + + @property + def bodies(self): + return self._bodylist + + @property + def loads(self): + return self._forcelist diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/lagrange.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/lagrange.py new file mode 100644 index 0000000000000000000000000000000000000000..282176a404f77762abc3ee8c6a575519b2de1f02 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/lagrange.py @@ -0,0 +1,512 @@ +from sympy import diff, zeros, Matrix, eye, sympify +from sympy.core.sorting import default_sort_key +from sympy.physics.vector import dynamicsymbols, ReferenceFrame +from sympy.physics.mechanics.method import _Methods +from sympy.physics.mechanics.functions import ( + find_dynamicsymbols, msubs, _f_list_parser, _validate_coordinates) +from sympy.physics.mechanics.linearize import Linearizer +from sympy.utilities.iterables import iterable + +__all__ = ['LagrangesMethod'] + + +class LagrangesMethod(_Methods): + """Lagrange's method object. + + Explanation + =========== + + This object generates the equations of motion in a two step procedure. The + first step involves the initialization of LagrangesMethod by supplying the + Lagrangian and the generalized coordinates, at the bare minimum. If there + are any constraint equations, they can be supplied as keyword arguments. + The Lagrange multipliers are automatically generated and are equal in + number to the constraint equations. Similarly any non-conservative forces + can be supplied in an iterable (as described below and also shown in the + example) along with a ReferenceFrame. This is also discussed further in the + __init__ method. + + Attributes + ========== + + q, u : Matrix + Matrices of the generalized coordinates and speeds + loads : iterable + Iterable of (Point, vector) or (ReferenceFrame, vector) tuples + describing the forces on the system. + bodies : iterable + Iterable containing the rigid bodies and particles of the system. + mass_matrix : Matrix + The system's mass matrix + forcing : Matrix + The system's forcing vector + mass_matrix_full : Matrix + The "mass matrix" for the qdot's, qdoubledot's, and the + lagrange multipliers (lam) + forcing_full : Matrix + The forcing vector for the qdot's, qdoubledot's and + lagrange multipliers (lam) + + Examples + ======== + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + In this example, we first need to do the kinematics. + This involves creating generalized coordinates and their derivatives. + Then we create a point and set its velocity in a frame. + + >>> from sympy.physics.mechanics import LagrangesMethod, Lagrangian + >>> from sympy.physics.mechanics import ReferenceFrame, Particle, Point + >>> from sympy.physics.mechanics import dynamicsymbols + >>> from sympy import symbols + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> m, k, b = symbols('m k b') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, qd * N.x) + + We need to then prepare the information as required by LagrangesMethod to + generate equations of motion. + First we create the Particle, which has a point attached to it. + Following this the lagrangian is created from the kinetic and potential + energies. + Then, an iterable of nonconservative forces/torques must be constructed, + where each item is a (Point, Vector) or (ReferenceFrame, Vector) tuple, + with the Vectors representing the nonconservative forces or torques. + + >>> Pa = Particle('Pa', P, m) + >>> Pa.potential_energy = k * q**2 / 2.0 + >>> L = Lagrangian(N, Pa) + >>> fl = [(P, -b * qd * N.x)] + + Finally we can generate the equations of motion. + First we create the LagrangesMethod object. To do this one must supply + the Lagrangian, and the generalized coordinates. The constraint equations, + the forcelist, and the inertial frame may also be provided, if relevant. + Next we generate Lagrange's equations of motion, such that: + Lagrange's equations of motion = 0. + We have the equations of motion at this point. + + >>> l = LagrangesMethod(L, [q], forcelist = fl, frame = N) + >>> print(l.form_lagranges_equations()) + Matrix([[b*Derivative(q(t), t) + 1.0*k*q(t) + m*Derivative(q(t), (t, 2))]]) + + We can also solve for the states using the 'rhs' method. + + >>> print(l.rhs()) + Matrix([[Derivative(q(t), t)], [(-b*Derivative(q(t), t) - 1.0*k*q(t))/m]]) + + Please refer to the docstrings on each method for more details. + """ + + def __init__(self, Lagrangian, qs, forcelist=None, bodies=None, frame=None, + hol_coneqs=None, nonhol_coneqs=None): + """Supply the following for the initialization of LagrangesMethod. + + Lagrangian : Sympifyable + + qs : array_like + The generalized coordinates + + hol_coneqs : array_like, optional + The holonomic constraint equations + + nonhol_coneqs : array_like, optional + The nonholonomic constraint equations + + forcelist : iterable, optional + Takes an iterable of (Point, Vector) or (ReferenceFrame, Vector) + tuples which represent the force at a point or torque on a frame. + This feature is primarily to account for the nonconservative forces + and/or moments. + + bodies : iterable, optional + Takes an iterable containing the rigid bodies and particles of the + system. + + frame : ReferenceFrame, optional + Supply the inertial frame. This is used to determine the + generalized forces due to non-conservative forces. + """ + + self._L = Matrix([sympify(Lagrangian)]) + self.eom = None + self._m_cd = Matrix() # Mass Matrix of differentiated coneqs + self._m_d = Matrix() # Mass Matrix of dynamic equations + self._f_cd = Matrix() # Forcing part of the diff coneqs + self._f_d = Matrix() # Forcing part of the dynamic equations + self.lam_coeffs = Matrix() # The coeffecients of the multipliers + + forcelist = forcelist if forcelist else [] + if not iterable(forcelist): + raise TypeError('Force pairs must be supplied in an iterable.') + self._forcelist = forcelist + if frame and not isinstance(frame, ReferenceFrame): + raise TypeError('frame must be a valid ReferenceFrame') + self._bodies = bodies + self.inertial = frame + + self.lam_vec = Matrix() + + self._term1 = Matrix() + self._term2 = Matrix() + self._term3 = Matrix() + self._term4 = Matrix() + + # Creating the qs, qdots and qdoubledots + if not iterable(qs): + raise TypeError('Generalized coordinates must be an iterable') + self._q = Matrix(qs) + self._qdots = self.q.diff(dynamicsymbols._t) + self._qdoubledots = self._qdots.diff(dynamicsymbols._t) + _validate_coordinates(self.q) + + mat_build = lambda x: Matrix(x) if x else Matrix() + hol_coneqs = mat_build(hol_coneqs) + nonhol_coneqs = mat_build(nonhol_coneqs) + self.coneqs = Matrix([hol_coneqs.diff(dynamicsymbols._t), + nonhol_coneqs]) + self._hol_coneqs = hol_coneqs + + def form_lagranges_equations(self): + """Method to form Lagrange's equations of motion. + + Returns a vector of equations of motion using Lagrange's equations of + the second kind. + """ + + qds = self._qdots + qdd_zero = dict.fromkeys(self._qdoubledots, 0) + n = len(self.q) + + # Internally we represent the EOM as four terms: + # EOM = term1 - term2 - term3 - term4 = 0 + + # First term + self._term1 = self._L.jacobian(qds) + self._term1 = self._term1.diff(dynamicsymbols._t).T + + # Second term + self._term2 = self._L.jacobian(self.q).T + + # Third term + if self.coneqs: + coneqs = self.coneqs + m = len(coneqs) + # Creating the multipliers + self.lam_vec = Matrix(dynamicsymbols('lam1:' + str(m + 1))) + self.lam_coeffs = -coneqs.jacobian(qds) + self._term3 = self.lam_coeffs.T * self.lam_vec + # Extracting the coeffecients of the qdds from the diff coneqs + diffconeqs = coneqs.diff(dynamicsymbols._t) + self._m_cd = diffconeqs.jacobian(self._qdoubledots) + # The remaining terms i.e. the 'forcing' terms in diff coneqs + self._f_cd = -diffconeqs.subs(qdd_zero) + else: + self._term3 = zeros(n, 1) + + # Fourth term + if self.forcelist: + N = self.inertial + self._term4 = zeros(n, 1) + for i, qd in enumerate(qds): + flist = zip(*_f_list_parser(self.forcelist, N)) + self._term4[i] = sum(v.diff(qd, N).dot(f) for (v, f) in flist) + else: + self._term4 = zeros(n, 1) + + # Form the dynamic mass and forcing matrices + without_lam = self._term1 - self._term2 - self._term4 + self._m_d = without_lam.jacobian(self._qdoubledots) + self._f_d = -without_lam.subs(qdd_zero) + + # Form the EOM + self.eom = without_lam - self._term3 + return self.eom + + def _form_eoms(self): + return self.form_lagranges_equations() + + @property + def mass_matrix(self): + """Returns the mass matrix, which is augmented by the Lagrange + multipliers, if necessary. + + Explanation + =========== + + If the system is described by 'n' generalized coordinates and there are + no constraint equations then an n X n matrix is returned. + + If there are 'n' generalized coordinates and 'm' constraint equations + have been supplied during initialization then an n X (n+m) matrix is + returned. The (n + m - 1)th and (n + m)th columns contain the + coefficients of the Lagrange multipliers. + """ + + if self.eom is None: + raise ValueError('Need to compute the equations of motion first') + if self.coneqs: + return (self._m_d).row_join(self.lam_coeffs.T) + else: + return self._m_d + + @property + def mass_matrix_full(self): + """Augments the coefficients of qdots to the mass_matrix.""" + + if self.eom is None: + raise ValueError('Need to compute the equations of motion first') + n = len(self.q) + m = len(self.coneqs) + row1 = eye(n).row_join(zeros(n, n + m)) + row2 = zeros(n, n).row_join(self.mass_matrix) + if self.coneqs: + row3 = zeros(m, n).row_join(self._m_cd).row_join(zeros(m, m)) + return row1.col_join(row2).col_join(row3) + else: + return row1.col_join(row2) + + @property + def forcing(self): + """Returns the forcing vector from 'lagranges_equations' method.""" + + if self.eom is None: + raise ValueError('Need to compute the equations of motion first') + return self._f_d + + @property + def forcing_full(self): + """Augments qdots to the forcing vector above.""" + + if self.eom is None: + raise ValueError('Need to compute the equations of motion first') + if self.coneqs: + return self._qdots.col_join(self.forcing).col_join(self._f_cd) + else: + return self._qdots.col_join(self.forcing) + + def to_linearizer(self, q_ind=None, qd_ind=None, q_dep=None, qd_dep=None, + linear_solver='LU'): + """Returns an instance of the Linearizer class, initiated from the data + in the LagrangesMethod class. This may be more desirable than using the + linearize class method, as the Linearizer object will allow more + efficient recalculation (i.e. about varying operating points). + + Parameters + ========== + + q_ind, qd_ind : array_like, optional + The independent generalized coordinates and speeds. + q_dep, qd_dep : array_like, optional + The dependent generalized coordinates and speeds. + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + + Returns + ======= + Linearizer + An instantiated + :class:`sympy.physics.mechanics.linearize.Linearizer`. + + """ + + # Compose vectors + t = dynamicsymbols._t + q = self.q + u = self._qdots + ud = u.diff(t) + # Get vector of lagrange multipliers + lams = self.lam_vec + + mat_build = lambda x: Matrix(x) if x else Matrix() + q_i = mat_build(q_ind) + q_d = mat_build(q_dep) + u_i = mat_build(qd_ind) + u_d = mat_build(qd_dep) + + # Compose general form equations + f_c = self._hol_coneqs + f_v = self.coneqs + f_a = f_v.diff(t) + f_0 = u + f_1 = -u + f_2 = self._term1 + f_3 = -(self._term2 + self._term4) + f_4 = -self._term3 + + # Check that there are an appropriate number of independent and + # dependent coordinates + if len(q_d) != len(f_c) or len(u_d) != len(f_v): + raise ValueError(("Must supply {:} dependent coordinates, and " + + "{:} dependent speeds").format(len(f_c), len(f_v))) + if set(Matrix([q_i, q_d])) != set(q): + raise ValueError("Must partition q into q_ind and q_dep, with " + + "no extra or missing symbols.") + if set(Matrix([u_i, u_d])) != set(u): + raise ValueError("Must partition qd into qd_ind and qd_dep, " + + "with no extra or missing symbols.") + + # Find all other dynamic symbols, forming the forcing vector r. + # Sort r to make it canonical. + insyms = set(Matrix([q, u, ud, lams])) + r = list(find_dynamicsymbols(f_3, insyms)) + r.sort(key=default_sort_key) + # Check for any derivatives of variables in r that are also found in r. + for i in r: + if diff(i, dynamicsymbols._t) in r: + raise ValueError('Cannot have derivatives of specified \ + quantities when linearizing forcing terms.') + + return Linearizer(f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a, q, u, q_i, + q_d, u_i, u_d, r, lams, linear_solver=linear_solver) + + def linearize(self, q_ind=None, qd_ind=None, q_dep=None, qd_dep=None, + linear_solver='LU', **kwargs): + """Linearize the equations of motion about a symbolic operating point. + + Parameters + ========== + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + **kwargs + Extra keyword arguments are passed to + :meth:`sympy.physics.mechanics.linearize.Linearizer.linearize`. + + Explanation + =========== + + If kwarg A_and_B is False (default), returns M, A, B, r for the + linearized form, M*[q', u']^T = A*[q_ind, u_ind]^T + B*r. + + If kwarg A_and_B is True, returns A, B, r for the linearized form + dx = A*x + B*r, where x = [q_ind, u_ind]^T. Note that this is + computationally intensive if there are many symbolic parameters. For + this reason, it may be more desirable to use the default A_and_B=False, + returning M, A, and B. Values may then be substituted in to these + matrices, and the state space form found as + A = P.T*M.inv()*A, B = P.T*M.inv()*B, where P = Linearizer.perm_mat. + + In both cases, r is found as all dynamicsymbols in the equations of + motion that are not part of q, u, q', or u'. They are sorted in + canonical form. + + The operating points may be also entered using the ``op_point`` kwarg. + This takes a dictionary of {symbol: value}, or a an iterable of such + dictionaries. The values may be numeric or symbolic. The more values + you can specify beforehand, the faster this computation will run. + + For more documentation, please see the ``Linearizer`` class.""" + + linearizer = self.to_linearizer(q_ind, qd_ind, q_dep, qd_dep, + linear_solver=linear_solver) + result = linearizer.linearize(**kwargs) + return result + (linearizer.r,) + + def solve_multipliers(self, op_point=None, sol_type='dict'): + """Solves for the values of the lagrange multipliers symbolically at + the specified operating point. + + Parameters + ========== + + op_point : dict or iterable of dicts, optional + Point at which to solve at. The operating point is specified as + a dictionary or iterable of dictionaries of {symbol: value}. The + value may be numeric or symbolic itself. + + sol_type : str, optional + Solution return type. Valid options are: + - 'dict': A dict of {symbol : value} (default) + - 'Matrix': An ordered column matrix of the solution + """ + + # Determine number of multipliers + k = len(self.lam_vec) + if k == 0: + raise ValueError("System has no lagrange multipliers to solve for.") + # Compose dict of operating conditions + if isinstance(op_point, dict): + op_point_dict = op_point + elif iterable(op_point): + op_point_dict = {} + for op in op_point: + op_point_dict.update(op) + elif op_point is None: + op_point_dict = {} + else: + raise TypeError("op_point must be either a dictionary or an " + "iterable of dictionaries.") + # Compose the system to be solved + mass_matrix = self.mass_matrix.col_join(-self.lam_coeffs.row_join( + zeros(k, k))) + force_matrix = self.forcing.col_join(self._f_cd) + # Sub in the operating point + mass_matrix = msubs(mass_matrix, op_point_dict) + force_matrix = msubs(force_matrix, op_point_dict) + # Solve for the multipliers + sol_list = mass_matrix.LUsolve(-force_matrix)[-k:] + if sol_type == 'dict': + return dict(zip(self.lam_vec, sol_list)) + elif sol_type == 'Matrix': + return Matrix(sol_list) + else: + raise ValueError("Unknown sol_type {:}.".format(sol_type)) + + def rhs(self, inv_method=None, **kwargs): + """Returns equations that can be solved numerically. + + Parameters + ========== + + inv_method : str + The specific sympy inverse matrix calculation method to use. For a + list of valid methods, see + :meth:`~sympy.matrices.matrixbase.MatrixBase.inv` + """ + + if inv_method is None: + self._rhs = self.mass_matrix_full.LUsolve(self.forcing_full) + else: + self._rhs = (self.mass_matrix_full.inv(inv_method, + try_block_diag=True) * self.forcing_full) + return self._rhs + + @property + def q(self): + return self._q + + @property + def u(self): + return self._qdots + + @property + def bodies(self): + return self._bodies + + @property + def forcelist(self): + return self._forcelist + + @property + def loads(self): + return self._forcelist diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/linearize.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/linearize.py new file mode 100644 index 0000000000000000000000000000000000000000..b94ddb865a7236a5ac6f1a41ba96679eb8b2cd8f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/linearize.py @@ -0,0 +1,474 @@ +__all__ = ['Linearizer'] + +from sympy import Matrix, eye, zeros +from sympy.core.symbol import Dummy +from sympy.utilities.iterables import flatten +from sympy.physics.vector import dynamicsymbols +from sympy.physics.mechanics.functions import msubs, _parse_linear_solver + +from collections import namedtuple +from collections.abc import Iterable + + +class Linearizer: + """This object holds the general model form for a dynamic system. This + model is used for computing the linearized form of the system, while + properly dealing with constraints leading to dependent coordinates and + speeds. The notation and method is described in [1]_. + + Attributes + ========== + + f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a : Matrix + Matrices holding the general system form. + q, u, r : Matrix + Matrices holding the generalized coordinates, speeds, and + input vectors. + q_i, u_i : Matrix + Matrices of the independent generalized coordinates and speeds. + q_d, u_d : Matrix + Matrices of the dependent generalized coordinates and speeds. + perm_mat : Matrix + Permutation matrix such that [q_ind, u_ind]^T = perm_mat*[q, u]^T + + References + ========== + + .. [1] D. L. Peterson, G. Gede, and M. Hubbard, "Symbolic linearization of + equations of motion of constrained multibody systems," Multibody + Syst Dyn, vol. 33, no. 2, pp. 143-161, Feb. 2015, doi: + 10.1007/s11044-014-9436-5. + + """ + + def __init__(self, f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a, q, u, q_i=None, + q_d=None, u_i=None, u_d=None, r=None, lams=None, + linear_solver='LU'): + """ + Parameters + ========== + + f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a : array_like + System of equations holding the general system form. + Supply empty array or Matrix if the parameter + does not exist. + q : array_like + The generalized coordinates. + u : array_like + The generalized speeds + q_i, u_i : array_like, optional + The independent generalized coordinates and speeds. + q_d, u_d : array_like, optional + The dependent generalized coordinates and speeds. + r : array_like, optional + The input variables. + lams : array_like, optional + The lagrange multipliers + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + + """ + self.linear_solver = _parse_linear_solver(linear_solver) + + # Generalized equation form + self.f_0 = Matrix(f_0) + self.f_1 = Matrix(f_1) + self.f_2 = Matrix(f_2) + self.f_3 = Matrix(f_3) + self.f_4 = Matrix(f_4) + self.f_c = Matrix(f_c) + self.f_v = Matrix(f_v) + self.f_a = Matrix(f_a) + + # Generalized equation variables + self.q = Matrix(q) + self.u = Matrix(u) + none_handler = lambda x: Matrix(x) if x else Matrix() + self.q_i = none_handler(q_i) + self.q_d = none_handler(q_d) + self.u_i = none_handler(u_i) + self.u_d = none_handler(u_d) + self.r = none_handler(r) + self.lams = none_handler(lams) + + # Derivatives of generalized equation variables + self._qd = self.q.diff(dynamicsymbols._t) + self._ud = self.u.diff(dynamicsymbols._t) + # If the user doesn't actually use generalized variables, and the + # qd and u vectors have any intersecting variables, this can cause + # problems. We'll fix this with some hackery, and Dummy variables + dup_vars = set(self._qd).intersection(self.u) + self._qd_dup = Matrix([var if var not in dup_vars else Dummy() for var + in self._qd]) + + # Derive dimension terms + l = len(self.f_c) + m = len(self.f_v) + n = len(self.q) + o = len(self.u) + s = len(self.r) + k = len(self.lams) + dims = namedtuple('dims', ['l', 'm', 'n', 'o', 's', 'k']) + self._dims = dims(l, m, n, o, s, k) + + self._Pq = None + self._Pqi = None + self._Pqd = None + self._Pu = None + self._Pui = None + self._Pud = None + self._C_0 = None + self._C_1 = None + self._C_2 = None + self.perm_mat = None + + self._setup_done = False + + def _setup(self): + # Calculations here only need to be run once. They are moved out of + # the __init__ method to increase the speed of Linearizer creation. + self._form_permutation_matrices() + self._form_block_matrices() + self._form_coefficient_matrices() + self._setup_done = True + + def _form_permutation_matrices(self): + """Form the permutation matrices Pq and Pu.""" + + # Extract dimension variables + l, m, n, o, s, k = self._dims + # Compute permutation matrices + if n != 0: + self._Pq = permutation_matrix(self.q, Matrix([self.q_i, self.q_d])) + if l > 0: + self._Pqi = self._Pq[:, :-l] + self._Pqd = self._Pq[:, -l:] + else: + self._Pqi = self._Pq + self._Pqd = Matrix() + if o != 0: + self._Pu = permutation_matrix(self.u, Matrix([self.u_i, self.u_d])) + if m > 0: + self._Pui = self._Pu[:, :-m] + self._Pud = self._Pu[:, -m:] + else: + self._Pui = self._Pu + self._Pud = Matrix() + # Compute combination permutation matrix for computing A and B + P_col1 = Matrix([self._Pqi, zeros(o + k, n - l)]) + P_col2 = Matrix([zeros(n, o - m), self._Pui, zeros(k, o - m)]) + if P_col1: + if P_col2: + self.perm_mat = P_col1.row_join(P_col2) + else: + self.perm_mat = P_col1 + else: + self.perm_mat = P_col2 + + def _form_coefficient_matrices(self): + """Form the coefficient matrices C_0, C_1, and C_2.""" + + # Extract dimension variables + l, m, n, o, s, k = self._dims + # Build up the coefficient matrices C_0, C_1, and C_2 + # If there are configuration constraints (l > 0), form C_0 as normal. + # If not, C_0 is I_(nxn). Note that this works even if n=0 + if l > 0: + f_c_jac_q = self.f_c.jacobian(self.q) + self._C_0 = (eye(n) - self._Pqd * + self.linear_solver(f_c_jac_q*self._Pqd, + f_c_jac_q))*self._Pqi + else: + self._C_0 = eye(n) + # If there are motion constraints (m > 0), form C_1 and C_2 as normal. + # If not, C_1 is 0, and C_2 is I_(oxo). Note that this works even if + # o = 0. + if m > 0: + f_v_jac_u = self.f_v.jacobian(self.u) + temp = f_v_jac_u * self._Pud + if n != 0: + f_v_jac_q = self.f_v.jacobian(self.q) + self._C_1 = -self._Pud * self.linear_solver(temp, f_v_jac_q) + else: + self._C_1 = zeros(o, n) + self._C_2 = (eye(o) - self._Pud * + self.linear_solver(temp, f_v_jac_u))*self._Pui + else: + self._C_1 = zeros(o, n) + self._C_2 = eye(o) + + def _form_block_matrices(self): + """Form the block matrices for composing M, A, and B.""" + + # Extract dimension variables + l, m, n, o, s, k = self._dims + # Block Matrix Definitions. These are only defined if under certain + # conditions. If undefined, an empty matrix is used instead + if n != 0: + self._M_qq = self.f_0.jacobian(self._qd) + self._A_qq = -(self.f_0 + self.f_1).jacobian(self.q) + else: + self._M_qq = Matrix() + self._A_qq = Matrix() + if n != 0 and m != 0: + self._M_uqc = self.f_a.jacobian(self._qd_dup) + self._A_uqc = -self.f_a.jacobian(self.q) + else: + self._M_uqc = Matrix() + self._A_uqc = Matrix() + if n != 0 and o - m + k != 0: + self._M_uqd = self.f_3.jacobian(self._qd_dup) + self._A_uqd = -(self.f_2 + self.f_3 + self.f_4).jacobian(self.q) + else: + self._M_uqd = Matrix() + self._A_uqd = Matrix() + if o != 0 and m != 0: + self._M_uuc = self.f_a.jacobian(self._ud) + self._A_uuc = -self.f_a.jacobian(self.u) + else: + self._M_uuc = Matrix() + self._A_uuc = Matrix() + if o != 0 and o - m + k != 0: + self._M_uud = self.f_2.jacobian(self._ud) + self._A_uud = -(self.f_2 + self.f_3).jacobian(self.u) + else: + self._M_uud = Matrix() + self._A_uud = Matrix() + if o != 0 and n != 0: + self._A_qu = -self.f_1.jacobian(self.u) + else: + self._A_qu = Matrix() + if k != 0 and o - m + k != 0: + self._M_uld = self.f_4.jacobian(self.lams) + else: + self._M_uld = Matrix() + if s != 0 and o - m + k != 0: + self._B_u = -self.f_3.jacobian(self.r) + else: + self._B_u = Matrix() + + def linearize(self, op_point=None, A_and_B=False, simplify=False): + """Linearize the system about the operating point. Note that + q_op, u_op, qd_op, ud_op must satisfy the equations of motion. + These may be either symbolic or numeric. + + Parameters + ========== + op_point : dict or iterable of dicts, optional + Dictionary or iterable of dictionaries containing the operating + point conditions for all or a subset of the generalized + coordinates, generalized speeds, and time derivatives of the + generalized speeds. These will be substituted into the linearized + system before the linearization is complete. Leave set to ``None`` + if you want the operating point to be an arbitrary set of symbols. + Note that any reduction in symbols (whether substituted for numbers + or expressions with a common parameter) will result in faster + runtime. + A_and_B : bool, optional + If A_and_B=False (default), (M, A, B) is returned and of + A_and_B=True, (A, B) is returned. See below. + simplify : bool, optional + Determines if returned values are simplified before return. + For large expressions this may be time consuming. Default is False. + + Returns + ======= + M, A, B : Matrices, ``A_and_B=False`` + Matrices from the implicit form: + ``[M]*[q', u']^T = [A]*[q_ind, u_ind]^T + [B]*r`` + A, B : Matrices, ``A_and_B=True`` + Matrices from the explicit form: + ``[q_ind', u_ind']^T = [A]*[q_ind, u_ind]^T + [B]*r`` + + Notes + ===== + + Note that the process of solving with A_and_B=True is computationally + intensive if there are many symbolic parameters. For this reason, it + may be more desirable to use the default A_and_B=False, returning M, A, + and B. More values may then be substituted in to these matrices later + on. The state space form can then be found as A = P.T*M.LUsolve(A), B = + P.T*M.LUsolve(B), where P = Linearizer.perm_mat. + + """ + + # Run the setup if needed: + if not self._setup_done: + self._setup() + + # Compose dict of operating conditions + if isinstance(op_point, dict): + op_point_dict = op_point + elif isinstance(op_point, Iterable): + op_point_dict = {} + for op in op_point: + op_point_dict.update(op) + else: + op_point_dict = {} + + # Extract dimension variables + l, m, n, o, s, k = self._dims + + # Rename terms to shorten expressions + M_qq = self._M_qq + M_uqc = self._M_uqc + M_uqd = self._M_uqd + M_uuc = self._M_uuc + M_uud = self._M_uud + M_uld = self._M_uld + A_qq = self._A_qq + A_uqc = self._A_uqc + A_uqd = self._A_uqd + A_qu = self._A_qu + A_uuc = self._A_uuc + A_uud = self._A_uud + B_u = self._B_u + C_0 = self._C_0 + C_1 = self._C_1 + C_2 = self._C_2 + + # Build up Mass Matrix + # |M_qq 0_nxo 0_nxk| + # M = |M_uqc M_uuc 0_mxk| + # |M_uqd M_uud M_uld| + if o != 0: + col2 = Matrix([zeros(n, o), M_uuc, M_uud]) + if k != 0: + col3 = Matrix([zeros(n + m, k), M_uld]) + if n != 0: + col1 = Matrix([M_qq, M_uqc, M_uqd]) + if o != 0 and k != 0: + M = col1.row_join(col2).row_join(col3) + elif o != 0: + M = col1.row_join(col2) + else: + M = col1 + elif k != 0: + M = col2.row_join(col3) + else: + M = col2 + M_eq = msubs(M, op_point_dict) + + # Build up state coefficient matrix A + # |(A_qq + A_qu*C_1)*C_0 A_qu*C_2| + # A = |(A_uqc + A_uuc*C_1)*C_0 A_uuc*C_2| + # |(A_uqd + A_uud*C_1)*C_0 A_uud*C_2| + # Col 1 is only defined if n != 0 + if n != 0: + r1c1 = A_qq + if o != 0: + r1c1 += (A_qu * C_1) + r1c1 = r1c1 * C_0 + if m != 0: + r2c1 = A_uqc + if o != 0: + r2c1 += (A_uuc * C_1) + r2c1 = r2c1 * C_0 + else: + r2c1 = Matrix() + if o - m + k != 0: + r3c1 = A_uqd + if o != 0: + r3c1 += (A_uud * C_1) + r3c1 = r3c1 * C_0 + else: + r3c1 = Matrix() + col1 = Matrix([r1c1, r2c1, r3c1]) + else: + col1 = Matrix() + # Col 2 is only defined if o != 0 + if o != 0: + if n != 0: + r1c2 = A_qu * C_2 + else: + r1c2 = Matrix() + if m != 0: + r2c2 = A_uuc * C_2 + else: + r2c2 = Matrix() + if o - m + k != 0: + r3c2 = A_uud * C_2 + else: + r3c2 = Matrix() + col2 = Matrix([r1c2, r2c2, r3c2]) + else: + col2 = Matrix() + if col1: + if col2: + Amat = col1.row_join(col2) + else: + Amat = col1 + else: + Amat = col2 + Amat_eq = msubs(Amat, op_point_dict) + + # Build up the B matrix if there are forcing variables + # |0_(n + m)xs| + # B = |B_u | + if s != 0 and o - m + k != 0: + Bmat = zeros(n + m, s).col_join(B_u) + Bmat_eq = msubs(Bmat, op_point_dict) + else: + Bmat_eq = Matrix() + + # kwarg A_and_B indicates to return A, B for forming the equation + # dx = [A]x + [B]r, where x = [q_indnd, u_indnd]^T, + if A_and_B: + A_cont = self.perm_mat.T * self.linear_solver(M_eq, Amat_eq) + if Bmat_eq: + B_cont = self.perm_mat.T * self.linear_solver(M_eq, Bmat_eq) + else: + # Bmat = Matrix([]), so no need to sub + B_cont = Bmat_eq + if simplify: + A_cont.simplify() + B_cont.simplify() + return A_cont, B_cont + # Otherwise return M, A, B for forming the equation + # [M]dx = [A]x + [B]r, where x = [q, u]^T + else: + if simplify: + M_eq.simplify() + Amat_eq.simplify() + Bmat_eq.simplify() + return M_eq, Amat_eq, Bmat_eq + + +def permutation_matrix(orig_vec, per_vec): + """Compute the permutation matrix to change order of + orig_vec into order of per_vec. + + Parameters + ========== + + orig_vec : array_like + Symbols in original ordering. + per_vec : array_like + Symbols in new ordering. + + Returns + ======= + + p_matrix : Matrix + Permutation matrix such that orig_vec == (p_matrix * per_vec). + """ + if not isinstance(orig_vec, (list, tuple)): + orig_vec = flatten(orig_vec) + if not isinstance(per_vec, (list, tuple)): + per_vec = flatten(per_vec) + if set(orig_vec) != set(per_vec): + raise ValueError("orig_vec and per_vec must be the same length, " + "and contain the same symbols.") + ind_list = [orig_vec.index(i) for i in per_vec] + p_matrix = zeros(len(orig_vec)) + for i, j in enumerate(ind_list): + p_matrix[i, j] = 1 + return p_matrix diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/loads.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/loads.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9db763ffd6f99905e9d17fdc07f4171de4801b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/loads.py @@ -0,0 +1,177 @@ +from abc import ABC +from collections import namedtuple +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.vector import Vector, ReferenceFrame, Point + +__all__ = ['LoadBase', 'Force', 'Torque'] + + +class LoadBase(ABC, namedtuple('LoadBase', ['location', 'vector'])): + """Abstract base class for the various loading types.""" + + def __add__(self, other): + raise TypeError(f"unsupported operand type(s) for +: " + f"'{self.__class__.__name__}' and " + f"'{other.__class__.__name__}'") + + def __mul__(self, other): + raise TypeError(f"unsupported operand type(s) for *: " + f"'{self.__class__.__name__}' and " + f"'{other.__class__.__name__}'") + + __radd__ = __add__ + __rmul__ = __mul__ + + +class Force(LoadBase): + """Force acting upon a point. + + Explanation + =========== + + A force is a vector that is bound to a line of action. This class stores + both a point, which lies on the line of action, and the vector. A tuple can + also be used, with the location as the first entry and the vector as second + entry. + + Examples + ======== + + A force of magnitude 2 along N.x acting on a point Po can be created as + follows: + + >>> from sympy.physics.mechanics import Point, ReferenceFrame, Force + >>> N = ReferenceFrame('N') + >>> Po = Point('Po') + >>> Force(Po, 2 * N.x) + (Po, 2*N.x) + + If a body is supplied, then the center of mass of that body is used. + + >>> from sympy.physics.mechanics import Particle + >>> P = Particle('P', point=Po) + >>> Force(P, 2 * N.x) + (Po, 2*N.x) + + """ + + def __new__(cls, point, force): + if isinstance(point, BodyBase): + point = point.masscenter + if not isinstance(point, Point): + raise TypeError('Force location should be a Point.') + if not isinstance(force, Vector): + raise TypeError('Force vector should be a Vector.') + return super().__new__(cls, point, force) + + def __repr__(self): + return (f'{self.__class__.__name__}(point={self.point}, ' + f'force={self.force})') + + @property + def point(self): + return self.location + + @property + def force(self): + return self.vector + + +class Torque(LoadBase): + """Torque acting upon a frame. + + Explanation + =========== + + A torque is a free vector that is acting on a reference frame, which is + associated with a rigid body. This class stores both the frame and the + vector. A tuple can also be used, with the location as the first item and + the vector as second item. + + Examples + ======== + + A torque of magnitude 2 about N.x acting on a frame N can be created as + follows: + + >>> from sympy.physics.mechanics import ReferenceFrame, Torque + >>> N = ReferenceFrame('N') + >>> Torque(N, 2 * N.x) + (N, 2*N.x) + + If a body is supplied, then the frame fixed to that body is used. + + >>> from sympy.physics.mechanics import RigidBody + >>> rb = RigidBody('rb', frame=N) + >>> Torque(rb, 2 * N.x) + (N, 2*N.x) + + """ + + def __new__(cls, frame, torque): + if isinstance(frame, BodyBase): + frame = frame.frame + if not isinstance(frame, ReferenceFrame): + raise TypeError('Torque location should be a ReferenceFrame.') + if not isinstance(torque, Vector): + raise TypeError('Torque vector should be a Vector.') + return super().__new__(cls, frame, torque) + + def __repr__(self): + return (f'{self.__class__.__name__}(frame={self.frame}, ' + f'torque={self.torque})') + + @property + def frame(self): + return self.location + + @property + def torque(self): + return self.vector + + +def gravity(acceleration, *bodies): + """ + Returns a list of gravity forces given the acceleration + due to gravity and any number of particles or rigidbodies. + + Example + ======= + + >>> from sympy.physics.mechanics import ReferenceFrame, Particle, RigidBody + >>> from sympy.physics.mechanics.loads import gravity + >>> from sympy import symbols + >>> N = ReferenceFrame('N') + >>> g = symbols('g') + >>> P = Particle('P') + >>> B = RigidBody('B') + >>> gravity(g*N.y, P, B) + [(P_masscenter, P_mass*g*N.y), + (B_masscenter, B_mass*g*N.y)] + + """ + + gravity_force = [] + for body in bodies: + if not isinstance(body, BodyBase): + raise TypeError(f'{type(body)} is not a body type') + gravity_force.append(Force(body.masscenter, body.mass * acceleration)) + return gravity_force + + +def _parse_load(load): + """Helper function to parse loads and convert tuples to load objects.""" + if isinstance(load, LoadBase): + return load + elif isinstance(load, tuple): + if len(load) != 2: + raise ValueError(f'Load {load} should have a length of 2.') + if isinstance(load[0], Point): + return Force(load[0], load[1]) + elif isinstance(load[0], ReferenceFrame): + return Torque(load[0], load[1]) + else: + raise ValueError(f'Load not recognized. The load location {load[0]}' + f' should either be a Point or a ReferenceFrame.') + raise TypeError(f'Load type {type(load)} not recognized as a load. It ' + f'should be a Force, Torque or tuple.') diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/method.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/method.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2c4a5f388e56e37bd9ecdf6daffc08ffa51070 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/method.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod + +class _Methods(ABC): + """Abstract Base Class for all methods.""" + + @abstractmethod + def q(self): + pass + + @abstractmethod + def u(self): + pass + + @abstractmethod + def bodies(self): + pass + + @abstractmethod + def loads(self): + pass + + @abstractmethod + def mass_matrix(self): + pass + + @abstractmethod + def forcing(self): + pass + + @abstractmethod + def mass_matrix_full(self): + pass + + @abstractmethod + def forcing_full(self): + pass + + def _form_eoms(self): + raise NotImplementedError("Subclasses must implement this.") diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/models.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/models.py new file mode 100644 index 0000000000000000000000000000000000000000..a89b929ffd540a07787f6f94714850b348c90781 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/models.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python +"""This module contains some sample symbolic models used for testing and +examples.""" + +# Internal imports +from sympy.core import backend as sm +import sympy.physics.mechanics as me + + +def multi_mass_spring_damper(n=1, apply_gravity=False, + apply_external_forces=False): + r"""Returns a system containing the symbolic equations of motion and + associated variables for a simple multi-degree of freedom point mass, + spring, damper system with optional gravitational and external + specified forces. For example, a two mass system under the influence of + gravity and external forces looks like: + + :: + + ---------------- + | | | | g + \ | | | V + k0 / --- c0 | + | | | x0, v0 + --------- V + | m0 | ----- + --------- | + | | | | + \ v | | | + k1 / f0 --- c1 | + | | | x1, v1 + --------- V + | m1 | ----- + --------- + | f1 + V + + Parameters + ========== + + n : integer + The number of masses in the serial chain. + apply_gravity : boolean + If true, gravity will be applied to each mass. + apply_external_forces : boolean + If true, a time varying external force will be applied to each mass. + + Returns + ======= + + kane : sympy.physics.mechanics.kane.KanesMethod + A KanesMethod object. + + """ + + mass = sm.symbols('m:{}'.format(n)) + stiffness = sm.symbols('k:{}'.format(n)) + damping = sm.symbols('c:{}'.format(n)) + + acceleration_due_to_gravity = sm.symbols('g') + + coordinates = me.dynamicsymbols('x:{}'.format(n)) + speeds = me.dynamicsymbols('v:{}'.format(n)) + specifieds = me.dynamicsymbols('f:{}'.format(n)) + + ceiling = me.ReferenceFrame('N') + origin = me.Point('origin') + origin.set_vel(ceiling, 0) + + points = [origin] + kinematic_equations = [] + particles = [] + forces = [] + + for i in range(n): + + center = points[-1].locatenew('center{}'.format(i), + coordinates[i] * ceiling.x) + center.set_vel(ceiling, points[-1].vel(ceiling) + + speeds[i] * ceiling.x) + points.append(center) + + block = me.Particle('block{}'.format(i), center, mass[i]) + + kinematic_equations.append(speeds[i] - coordinates[i].diff()) + + total_force = (-stiffness[i] * coordinates[i] - + damping[i] * speeds[i]) + try: + total_force += (stiffness[i + 1] * coordinates[i + 1] + + damping[i + 1] * speeds[i + 1]) + except IndexError: # no force from below on last mass + pass + + if apply_gravity: + total_force += mass[i] * acceleration_due_to_gravity + + if apply_external_forces: + total_force += specifieds[i] + + forces.append((center, total_force * ceiling.x)) + + particles.append(block) + + kane = me.KanesMethod(ceiling, q_ind=coordinates, u_ind=speeds, + kd_eqs=kinematic_equations) + kane.kanes_equations(particles, forces) + + return kane + + +def n_link_pendulum_on_cart(n=1, cart_force=True, joint_torques=False): + r"""Returns the system containing the symbolic first order equations of + motion for a 2D n-link pendulum on a sliding cart under the influence of + gravity. + + :: + + | + o y v + \ 0 ^ g + \ | + --\-|---- + | \| | + F-> | o --|---> x + | | + --------- + o o + + Parameters + ========== + + n : integer + The number of links in the pendulum. + cart_force : boolean, default=True + If true an external specified lateral force is applied to the cart. + joint_torques : boolean, default=False + If true joint torques will be added as specified inputs at each + joint. + + Returns + ======= + + kane : sympy.physics.mechanics.kane.KanesMethod + A KanesMethod object. + + Notes + ===== + + The degrees of freedom of the system are n + 1, i.e. one for each + pendulum link and one for the lateral motion of the cart. + + M x' = F, where x = [u0, ..., un+1, q0, ..., qn+1] + + The joint angles are all defined relative to the ground where the x axis + defines the ground line and the y axis points up. The joint torques are + applied between each adjacent link and the between the cart and the + lower link where a positive torque corresponds to positive angle. + + """ + if n <= 0: + raise ValueError('The number of links must be a positive integer.') + + q = me.dynamicsymbols('q:{}'.format(n + 1)) + u = me.dynamicsymbols('u:{}'.format(n + 1)) + + if joint_torques is True: + T = me.dynamicsymbols('T1:{}'.format(n + 1)) + + m = sm.symbols('m:{}'.format(n + 1)) + l = sm.symbols('l:{}'.format(n)) + g, t = sm.symbols('g t') + + I = me.ReferenceFrame('I') + O = me.Point('O') + O.set_vel(I, 0) + + P0 = me.Point('P0') + P0.set_pos(O, q[0] * I.x) + P0.set_vel(I, u[0] * I.x) + Pa0 = me.Particle('Pa0', P0, m[0]) + + frames = [I] + points = [P0] + particles = [Pa0] + forces = [(P0, -m[0] * g * I.y)] + kindiffs = [q[0].diff(t) - u[0]] + + if cart_force is True or joint_torques is True: + specified = [] + else: + specified = None + + for i in range(n): + Bi = I.orientnew('B{}'.format(i), 'Axis', [q[i + 1], I.z]) + Bi.set_ang_vel(I, u[i + 1] * I.z) + frames.append(Bi) + + Pi = points[-1].locatenew('P{}'.format(i + 1), l[i] * Bi.y) + Pi.v2pt_theory(points[-1], I, Bi) + points.append(Pi) + + Pai = me.Particle('Pa' + str(i + 1), Pi, m[i + 1]) + particles.append(Pai) + + forces.append((Pi, -m[i + 1] * g * I.y)) + + if joint_torques is True: + + specified.append(T[i]) + + if i == 0: + forces.append((I, -T[i] * I.z)) + + if i == n - 1: + forces.append((Bi, T[i] * I.z)) + else: + forces.append((Bi, T[i] * I.z - T[i + 1] * I.z)) + + kindiffs.append(q[i + 1].diff(t) - u[i + 1]) + + if cart_force is True: + F = me.dynamicsymbols('F') + forces.append((P0, F * I.x)) + specified.append(F) + + kane = me.KanesMethod(I, q_ind=q, u_ind=u, kd_eqs=kindiffs) + kane.kanes_equations(particles, forces) + + return kane diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/particle.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/particle.py new file mode 100644 index 0000000000000000000000000000000000000000..5d49d4f811b8d1c7fff16c71991f5e01da6ded02 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/particle.py @@ -0,0 +1,209 @@ +from sympy import S +from sympy.physics.vector import cross, dot +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.inertia import inertia_of_point_mass +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['Particle'] + + +class Particle(BodyBase): + """A particle. + + Explanation + =========== + + Particles have a non-zero mass and lack spatial extension; they take up no + space. + + Values need to be supplied on initialization, but can be changed later. + + Parameters + ========== + + name : str + Name of particle + point : Point + A physics/mechanics Point which represents the position, velocity, and + acceleration of this Particle + mass : Sympifyable + A SymPy expression representing the Particle's mass + potential_energy : Sympifyable + The potential energy of the Particle. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point + >>> from sympy import Symbol + >>> po = Point('po') + >>> m = Symbol('m') + >>> pa = Particle('pa', po, m) + >>> # Or you could change these later + >>> pa.mass = m + >>> pa.point = po + + """ + point = BodyBase.masscenter + + def __init__(self, name, point=None, mass=None): + super().__init__(name, point, mass) + + def linear_momentum(self, frame): + """Linear momentum of the particle. + + Explanation + =========== + + The linear momentum L, of a particle P, with respect to frame N is + given by: + + L = m * v + + where m is the mass of the particle, and v is the velocity of the + particle in the frame N. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which linear momentum is desired. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point, ReferenceFrame + >>> from sympy.physics.mechanics import dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> m, v = dynamicsymbols('m v') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> A = Particle('A', P, m) + >>> P.set_vel(N, v * N.x) + >>> A.linear_momentum(N) + m*v*N.x + + """ + + return self.mass * self.point.vel(frame) + + def angular_momentum(self, point, frame): + """Angular momentum of the particle about the point. + + Explanation + =========== + + The angular momentum H, about some point O of a particle, P, is given + by: + + ``H = cross(r, m * v)`` + + where r is the position vector from point O to the particle P, m is + the mass of the particle, and v is the velocity of the particle in + the inertial frame, N. + + Parameters + ========== + + point : Point + The point about which angular momentum of the particle is desired. + + frame : ReferenceFrame + The frame in which angular momentum is desired. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point, ReferenceFrame + >>> from sympy.physics.mechanics import dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> m, v, r = dynamicsymbols('m v r') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> A = O.locatenew('A', r * N.x) + >>> P = Particle('P', A, m) + >>> P.point.set_vel(N, v * N.y) + >>> P.angular_momentum(O, N) + m*r*v*N.z + + """ + + return cross(self.point.pos_from(point), + self.mass * self.point.vel(frame)) + + def kinetic_energy(self, frame): + """Kinetic energy of the particle. + + Explanation + =========== + + The kinetic energy, T, of a particle, P, is given by: + + ``T = 1/2 (dot(m * v, v))`` + + where m is the mass of particle P, and v is the velocity of the + particle in the supplied ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The Particle's velocity is typically defined with respect to + an inertial frame but any relevant frame in which the velocity is + known can be supplied. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point, ReferenceFrame + >>> from sympy import symbols + >>> m, v, r = symbols('m v r') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> P = Particle('P', O, m) + >>> P.point.set_vel(N, v * N.y) + >>> P.kinetic_energy(N) + m*v**2/2 + + """ + + return S.Half * self.mass * dot(self.point.vel(frame), + self.point.vel(frame)) + + def set_potential_energy(self, scalar): + sympy_deprecation_warning( + """ +The sympy.physics.mechanics.Particle.set_potential_energy() +method is deprecated. Instead use + + P.potential_energy = scalar + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-set-potential-energy", + ) + self.potential_energy = scalar + + def parallel_axis(self, point, frame): + """Returns an inertia dyadic of the particle with respect to another + point and frame. + + Parameters + ========== + + point : sympy.physics.vector.Point + The point to express the inertia dyadic about. + frame : sympy.physics.vector.ReferenceFrame + The reference frame used to construct the dyadic. + + Returns + ======= + + inertia : sympy.physics.vector.Dyadic + The inertia dyadic of the particle expressed about the provided + point and frame. + + """ + return inertia_of_point_mass(self.mass, self.point.pos_from(point), + frame) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/pathway.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/pathway.py new file mode 100644 index 0000000000000000000000000000000000000000..b86ba85b1d9d1434c51de3fd7cc429442fdbedb0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/pathway.py @@ -0,0 +1,688 @@ +"""Implementations of pathways for use by actuators.""" + +from abc import ABC, abstractmethod + +from sympy.core.singleton import S +from sympy.physics.mechanics.loads import Force +from sympy.physics.mechanics.wrapping_geometry import WrappingGeometryBase +from sympy.physics.vector import Point, dynamicsymbols + + +__all__ = ['PathwayBase', 'LinearPathway', 'ObstacleSetPathway', + 'WrappingPathway'] + + +class PathwayBase(ABC): + """Abstract base class for all pathway classes to inherit from. + + Notes + ===== + + Instances of this class cannot be directly instantiated by users. However, + it can be used to created custom pathway types through subclassing. + + """ + + def __init__(self, *attachments): + """Initializer for ``PathwayBase``.""" + self.attachments = attachments + + @property + def attachments(self): + """The pair of points defining a pathway's ends.""" + return self._attachments + + @attachments.setter + def attachments(self, attachments): + if hasattr(self, '_attachments'): + msg = ( + f'Can\'t set attribute `attachments` to {repr(attachments)} ' + f'as it is immutable.' + ) + raise AttributeError(msg) + if len(attachments) != 2: + msg = ( + f'Value {repr(attachments)} passed to `attachments` was an ' + f'iterable of length {len(attachments)}, must be an iterable ' + f'of length 2.' + ) + raise ValueError(msg) + for i, point in enumerate(attachments): + if not isinstance(point, Point): + msg = ( + f'Value {repr(point)} passed to `attachments` at index ' + f'{i} was of type {type(point)}, must be {Point}.' + ) + raise TypeError(msg) + self._attachments = tuple(attachments) + + @property + @abstractmethod + def length(self): + """An expression representing the pathway's length.""" + pass + + @property + @abstractmethod + def extension_velocity(self): + """An expression representing the pathway's extension velocity.""" + pass + + @abstractmethod + def to_loads(self, force): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + """ + pass + + def __repr__(self): + """Default representation of a pathway.""" + attachments = ', '.join(str(a) for a in self.attachments) + return f'{self.__class__.__name__}({attachments})' + + +class LinearPathway(PathwayBase): + """Linear pathway between a pair of attachment points. + + Explanation + =========== + + A linear pathway forms a straight-line segment between two points and is + the simplest pathway that can be formed. It will not interact with any + other objects in the system, i.e. a ``LinearPathway`` will intersect other + objects to ensure that the path between its two ends (its attachments) is + the shortest possible. + + A linear pathway is made up of two points that can move relative to each + other, and a pair of equal and opposite forces acting on the points. If the + positive time-varying Euclidean distance between the two points is defined, + then the "extension velocity" is the time derivative of this distance. The + extension velocity is positive when the two points are moving away from + each other and negative when moving closer to each other. The direction for + the force acting on either point is determined by constructing a unit + vector directed from the other point to this point. This establishes a sign + convention such that a positive force magnitude tends to push the points + apart. The following diagram shows the positive force sense and the + distance between the points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + >>> from sympy.physics.mechanics import LinearPathway + + To construct a pathway, two points are required to be passed to the + ``attachments`` parameter as a ``tuple``. + + >>> from sympy.physics.mechanics import Point + >>> pA, pB = Point('pA'), Point('pB') + >>> linear_pathway = LinearPathway(pA, pB) + >>> linear_pathway + LinearPathway(pA, pB) + + The pathway created above isn't very interesting without the positions and + velocities of its attachment points being described. Without this its not + possible to describe how the pathway moves, i.e. its length or its + extension velocity. + + >>> from sympy.physics.mechanics import ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> pB.set_pos(pA, q*N.x) + >>> pB.pos_from(pA) + q(t)*N.x + + A pathway's length can be accessed via its ``length`` attribute. + + >>> linear_pathway.length + sqrt(q(t)**2) + + Note how what appears to be an overly-complex expression is returned. This + is actually required as it ensures that a pathway's length is always + positive. + + A pathway's extension velocity can be accessed similarly via its + ``extension_velocity`` attribute. + + >>> linear_pathway.extension_velocity + sqrt(q(t)**2)*Derivative(q(t), t)/q(t) + + Parameters + ========== + + attachments : tuple[Point, Point] + Pair of ``Point`` objects between which the linear pathway spans. + Constructor expects two points to be passed, e.g. + ``LinearPathway(Point('pA'), Point('pB'))``. More or fewer points will + cause an error to be thrown. + + """ + + def __init__(self, *attachments): + """Initializer for ``LinearPathway``. + + Parameters + ========== + + attachments : Point + Pair of ``Point`` objects between which the linear pathway spans. + Constructor expects two points to be passed, e.g. + ``LinearPathway(Point('pA'), Point('pB'))``. More or fewer points + will cause an error to be thrown. + + """ + super().__init__(*attachments) + + @property + def length(self): + """Exact analytical expression for the pathway's length.""" + return _point_pair_length(*self.attachments) + + @property + def extension_velocity(self): + """Exact analytical expression for the pathway's extension velocity.""" + return _point_pair_extension_velocity(*self.attachments) + + def to_loads(self, force): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced in a linear + actuator that produces an expansile force ``F``. First, create a linear + actuator between two points separated by the coordinate ``q`` in the + ``x`` direction of the global frame ``N``. + + >>> from sympy.physics.mechanics import (LinearPathway, Point, + ... ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> q = dynamicsymbols('q') + >>> N = ReferenceFrame('N') + >>> pA, pB = Point('pA'), Point('pB') + >>> pB.set_pos(pA, q*N.x) + >>> linear_pathway = LinearPathway(pA, pB) + + Now create a symbol ``F`` to describe the magnitude of the (expansile) + force that will be produced along the pathway. The list of loads that + ``KanesMethod`` requires can be produced by calling the pathway's + ``to_loads`` method with ``F`` passed as the only argument. + + >>> from sympy import symbols + >>> F = symbols('F') + >>> linear_pathway.to_loads(F) + [(pA, - F*q(t)/sqrt(q(t)**2)*N.x), (pB, F*q(t)/sqrt(q(t)**2)*N.x)] + + Parameters + ========== + + force : Expr + Magnitude of the force acting along the length of the pathway. As + per the sign conventions for the pathway length, pathway extension + velocity, and pair of point forces, if this ``Expr`` is positive + then the force will act to push the pair of points away from one + another (it is expansile). + + """ + relative_position = _point_pair_relative_position(*self.attachments) + loads = [ + Force(self.attachments[0], -force*relative_position/self.length), + Force(self.attachments[-1], force*relative_position/self.length), + ] + return loads + + +class ObstacleSetPathway(PathwayBase): + """Obstacle-set pathway between a set of attachment points. + + Explanation + =========== + + An obstacle-set pathway forms a series of straight-line segment between + pairs of consecutive points in a set of points. It is similar to multiple + linear pathways joined end-to-end. It will not interact with any other + objects in the system, i.e. an ``ObstacleSetPathway`` will intersect other + objects to ensure that the path between its pairs of points (its + attachments) is the shortest possible. + + Examples + ======== + + To construct an obstacle-set pathway, three or more points are required to + be passed to the ``attachments`` parameter as a ``tuple``. + + >>> from sympy.physics.mechanics import ObstacleSetPathway, Point + >>> pA, pB, pC, pD = Point('pA'), Point('pB'), Point('pC'), Point('pD') + >>> obstacle_set_pathway = ObstacleSetPathway(pA, pB, pC, pD) + >>> obstacle_set_pathway + ObstacleSetPathway(pA, pB, pC, pD) + + The pathway created above isn't very interesting without the positions and + velocities of its attachment points being described. Without this its not + possible to describe how the pathway moves, i.e. its length or its + extension velocity. + + >>> from sympy import cos, sin + >>> from sympy.physics.mechanics import ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> pO = Point('pO') + >>> pA.set_pos(pO, N.y) + >>> pB.set_pos(pO, -N.x) + >>> pC.set_pos(pA, cos(q) * N.x - (sin(q) + 1) * N.y) + >>> pD.set_pos(pA, sin(q) * N.x + (cos(q) - 1) * N.y) + >>> pB.pos_from(pA) + - N.x - N.y + >>> pC.pos_from(pA) + cos(q(t))*N.x + (-sin(q(t)) - 1)*N.y + >>> pD.pos_from(pA) + sin(q(t))*N.x + (cos(q(t)) - 1)*N.y + + A pathway's length can be accessed via its ``length`` attribute. + + >>> obstacle_set_pathway.length.simplify() + sqrt(2)*(sqrt(cos(q(t)) + 1) + 2) + + A pathway's extension velocity can be accessed similarly via its + ``extension_velocity`` attribute. + + >>> obstacle_set_pathway.extension_velocity.simplify() + -sqrt(2)*sin(q(t))*Derivative(q(t), t)/(2*sqrt(cos(q(t)) + 1)) + + Parameters + ========== + + attachments : tuple[Point, ...] + The set of ``Point`` objects that define the segmented obstacle-set + pathway. + + """ + + def __init__(self, *attachments): + """Initializer for ``ObstacleSetPathway``. + + Parameters + ========== + + attachments : tuple[Point, ...] + The set of ``Point`` objects that define the segmented obstacle-set + pathway. + + """ + super().__init__(*attachments) + + @property + def attachments(self): + """The set of points defining a pathway's segmented path.""" + return self._attachments + + @attachments.setter + def attachments(self, attachments): + if hasattr(self, '_attachments'): + msg = ( + f'Can\'t set attribute `attachments` to {repr(attachments)} ' + f'as it is immutable.' + ) + raise AttributeError(msg) + if len(attachments) <= 2: + msg = ( + f'Value {repr(attachments)} passed to `attachments` was an ' + f'iterable of length {len(attachments)}, must be an iterable ' + f'of length 3 or greater.' + ) + raise ValueError(msg) + for i, point in enumerate(attachments): + if not isinstance(point, Point): + msg = ( + f'Value {repr(point)} passed to `attachments` at index ' + f'{i} was of type {type(point)}, must be {Point}.' + ) + raise TypeError(msg) + self._attachments = tuple(attachments) + + @property + def length(self): + """Exact analytical expression for the pathway's length.""" + length = S.Zero + attachment_pairs = zip(self.attachments[:-1], self.attachments[1:]) + for attachment_pair in attachment_pairs: + length += _point_pair_length(*attachment_pair) + return length + + @property + def extension_velocity(self): + """Exact analytical expression for the pathway's extension velocity.""" + extension_velocity = S.Zero + attachment_pairs = zip(self.attachments[:-1], self.attachments[1:]) + for attachment_pair in attachment_pairs: + extension_velocity += _point_pair_extension_velocity(*attachment_pair) + return extension_velocity + + def to_loads(self, force): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced in an + actuator that follows an obstacle-set pathway between four points and + produces an expansile force ``F``. First, create a pair of reference + frames, ``A`` and ``B``, in which the four points ``pA``, ``pB``, + ``pC``, and ``pD`` will be located. The first two points in frame ``A`` + and the second two in frame ``B``. Frame ``B`` will also be oriented + such that it relates to ``A`` via a rotation of ``q`` about an axis + ``N.z`` in a global frame (``N.z``, ``A.z``, and ``B.z`` are parallel). + + >>> from sympy.physics.mechanics import (ObstacleSetPathway, Point, + ... ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> q = dynamicsymbols('q') + >>> N = ReferenceFrame('N') + >>> N = ReferenceFrame('N') + >>> A = N.orientnew('A', 'axis', (0, N.x)) + >>> B = A.orientnew('B', 'axis', (q, N.z)) + >>> pO = Point('pO') + >>> pA, pB, pC, pD = Point('pA'), Point('pB'), Point('pC'), Point('pD') + >>> pA.set_pos(pO, A.x) + >>> pB.set_pos(pO, -A.y) + >>> pC.set_pos(pO, B.y) + >>> pD.set_pos(pO, B.x) + >>> obstacle_set_pathway = ObstacleSetPathway(pA, pB, pC, pD) + + Now create a symbol ``F`` to describe the magnitude of the (expansile) + force that will be produced along the pathway. The list of loads that + ``KanesMethod`` requires can be produced by calling the pathway's + ``to_loads`` method with ``F`` passed as the only argument. + + >>> from sympy import Symbol + >>> F = Symbol('F') + >>> obstacle_set_pathway.to_loads(F) + [(pA, sqrt(2)*F/2*A.x + sqrt(2)*F/2*A.y), + (pB, - sqrt(2)*F/2*A.x - sqrt(2)*F/2*A.y), + (pB, - F/sqrt(2*cos(q(t)) + 2)*A.y - F/sqrt(2*cos(q(t)) + 2)*B.y), + (pC, F/sqrt(2*cos(q(t)) + 2)*A.y + F/sqrt(2*cos(q(t)) + 2)*B.y), + (pC, - sqrt(2)*F/2*B.x + sqrt(2)*F/2*B.y), + (pD, sqrt(2)*F/2*B.x - sqrt(2)*F/2*B.y)] + + Parameters + ========== + + force : Expr + The force acting along the length of the pathway. It is assumed + that this ``Expr`` represents an expansile force. + + """ + loads = [] + attachment_pairs = zip(self.attachments[:-1], self.attachments[1:]) + for attachment_pair in attachment_pairs: + relative_position = _point_pair_relative_position(*attachment_pair) + length = _point_pair_length(*attachment_pair) + loads.extend([ + Force(attachment_pair[0], -force*relative_position/length), + Force(attachment_pair[1], force*relative_position/length), + ]) + return loads + + +class WrappingPathway(PathwayBase): + """Pathway that wraps a geometry object. + + Explanation + =========== + + A wrapping pathway interacts with a geometry object and forms a path that + wraps smoothly along its surface. The wrapping pathway along the geometry + object will be the geodesic that the geometry object defines based on the + two points. It will not interact with any other objects in the system, i.e. + a ``WrappingPathway`` will intersect other objects to ensure that the path + between its two ends (its attachments) is the shortest possible. + + To explain the sign conventions used for pathway length, extension + velocity, and direction of applied forces, we can ignore the geometry with + which the wrapping pathway interacts. A wrapping pathway is made up of two + points that can move relative to each other, and a pair of equal and + opposite forces acting on the points. If the positive time-varying + Euclidean distance between the two points is defined, then the "extension + velocity" is the time derivative of this distance. The extension velocity + is positive when the two points are moving away from each other and + negative when moving closer to each other. The direction for the force + acting on either point is determined by constructing a unit vector directed + from the other point to this point. This establishes a sign convention such + that a positive force magnitude tends to push the points apart. The + following diagram shows the positive force sense and the distance between + the points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + >>> from sympy.physics.mechanics import WrappingPathway + + To construct a wrapping pathway, like other pathways, a pair of points must + be passed, followed by an instance of a wrapping geometry class as a + keyword argument. We'll use a cylinder with radius ``r`` and its axis + parallel to ``N.x`` passing through a point ``pO``. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Point, ReferenceFrame, WrappingCylinder + >>> r = symbols('r') + >>> N = ReferenceFrame('N') + >>> pA, pB, pO = Point('pA'), Point('pB'), Point('pO') + >>> cylinder = WrappingCylinder(r, pO, N.x) + >>> wrapping_pathway = WrappingPathway(pA, pB, cylinder) + >>> wrapping_pathway + WrappingPathway(pA, pB, geometry=WrappingCylinder(radius=r, point=pO, + axis=N.x)) + + Parameters + ========== + + attachment_1 : Point + First of the pair of ``Point`` objects between which the wrapping + pathway spans. + attachment_2 : Point + Second of the pair of ``Point`` objects between which the wrapping + pathway spans. + geometry : WrappingGeometryBase + Geometry about which the pathway wraps. + + """ + + def __init__(self, attachment_1, attachment_2, geometry): + """Initializer for ``WrappingPathway``. + + Parameters + ========== + + attachment_1 : Point + First of the pair of ``Point`` objects between which the wrapping + pathway spans. + attachment_2 : Point + Second of the pair of ``Point`` objects between which the wrapping + pathway spans. + geometry : WrappingGeometryBase + Geometry about which the pathway wraps. + The geometry about which the pathway wraps. + + """ + super().__init__(attachment_1, attachment_2) + self.geometry = geometry + + @property + def geometry(self): + """Geometry around which the pathway wraps.""" + return self._geometry + + @geometry.setter + def geometry(self, geometry): + if hasattr(self, '_geometry'): + msg = ( + f'Can\'t set attribute `geometry` to {repr(geometry)} as it ' + f'is immutable.' + ) + raise AttributeError(msg) + if not isinstance(geometry, WrappingGeometryBase): + msg = ( + f'Value {repr(geometry)} passed to `geometry` was of type ' + f'{type(geometry)}, must be {WrappingGeometryBase}.' + ) + raise TypeError(msg) + self._geometry = geometry + + @property + def length(self): + """Exact analytical expression for the pathway's length.""" + return self.geometry.geodesic_length(*self.attachments) + + @property + def extension_velocity(self): + """Exact analytical expression for the pathway's extension velocity.""" + return self.length.diff(dynamicsymbols._t) + + def to_loads(self, force): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced in an + actuator that produces an expansile force ``F`` while wrapping around a + cylinder. First, create a cylinder with radius ``r`` and an axis + parallel to the ``N.z`` direction of the global frame ``N`` that also + passes through a point ``pO``. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (Point, ReferenceFrame, + ... WrappingCylinder) + >>> N = ReferenceFrame('N') + >>> r = symbols('r', positive=True) + >>> pO = Point('pO') + >>> cylinder = WrappingCylinder(r, pO, N.z) + + Create the pathway of the actuator using the ``WrappingPathway`` class, + defined to span between two points ``pA`` and ``pB``. Both points lie + on the surface of the cylinder and the location of ``pB`` is defined + relative to ``pA`` by the dynamics symbol ``q``. + + >>> from sympy import cos, sin + >>> from sympy.physics.mechanics import WrappingPathway, dynamicsymbols + >>> q = dynamicsymbols('q') + >>> pA = Point('pA') + >>> pB = Point('pB') + >>> pA.set_pos(pO, r*N.x) + >>> pB.set_pos(pO, r*(cos(q)*N.x + sin(q)*N.y)) + >>> pB.pos_from(pA) + (r*cos(q(t)) - r)*N.x + r*sin(q(t))*N.y + >>> pathway = WrappingPathway(pA, pB, cylinder) + + Now create a symbol ``F`` to describe the magnitude of the (expansile) + force that will be produced along the pathway. The list of loads that + ``KanesMethod`` requires can be produced by calling the pathway's + ``to_loads`` method with ``F`` passed as the only argument. + + >>> F = symbols('F') + >>> loads = pathway.to_loads(F) + >>> [load.__class__(load.location, load.vector.simplify()) for load in loads] + [(pA, F*N.y), (pB, F*sin(q(t))*N.x - F*cos(q(t))*N.y), + (pO, - F*sin(q(t))*N.x + F*(cos(q(t)) - 1)*N.y)] + + Parameters + ========== + + force : Expr + Magnitude of the force acting along the length of the pathway. It + is assumed that this ``Expr`` represents an expansile force. + + """ + pA, pB = self.attachments + pO = self.geometry.point + pA_force, pB_force = self.geometry.geodesic_end_vectors(pA, pB) + pO_force = -(pA_force + pB_force) + + loads = [ + Force(pA, force * pA_force), + Force(pB, force * pB_force), + Force(pO, force * pO_force), + ] + return loads + + def __repr__(self): + """Representation of a ``WrappingPathway``.""" + attachments = ', '.join(str(a) for a in self.attachments) + return ( + f'{self.__class__.__name__}({attachments}, ' + f'geometry={self.geometry})' + ) + + +def _point_pair_relative_position(point_1, point_2): + """The relative position between a pair of points.""" + return point_2.pos_from(point_1) + + +def _point_pair_length(point_1, point_2): + """The length of the direct linear path between two points.""" + return _point_pair_relative_position(point_1, point_2).magnitude() + + +def _point_pair_extension_velocity(point_1, point_2): + """The extension velocity of the direct linear path between two points.""" + return _point_pair_length(point_1, point_2).diff(dynamicsymbols._t) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/rigidbody.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/rigidbody.py new file mode 100644 index 0000000000000000000000000000000000000000..7cc61ff468f7f26d98209a48ca59ffa12a570490 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/rigidbody.py @@ -0,0 +1,314 @@ +from sympy import Symbol, S +from sympy.physics.vector import ReferenceFrame, Dyadic, Point, dot +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.inertia import inertia_of_point_mass, Inertia +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['RigidBody'] + + +class RigidBody(BodyBase): + """An idealized rigid body. + + Explanation + =========== + + This is essentially a container which holds the various components which + describe a rigid body: a name, mass, center of mass, reference frame, and + inertia. + + All of these need to be supplied on creation, but can be changed + afterwards. + + Attributes + ========== + + name : string + The body's name. + masscenter : Point + The point which represents the center of mass of the rigid body. + frame : ReferenceFrame + The ReferenceFrame which the rigid body is fixed in. + mass : Sympifyable + The body's mass. + inertia : (Dyadic, Point) + The body's inertia about a point; stored in a tuple as shown above. + potential_energy : Sympifyable + The potential energy of the RigidBody. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.mechanics import ReferenceFrame, Point, RigidBody + >>> from sympy.physics.mechanics import outer + >>> m = Symbol('m') + >>> A = ReferenceFrame('A') + >>> P = Point('P') + >>> I = outer (A.x, A.x) + >>> inertia_tuple = (I, P) + >>> B = RigidBody('B', P, A, m, inertia_tuple) + >>> # Or you could change them afterwards + >>> m2 = Symbol('m2') + >>> B.mass = m2 + + """ + + def __init__(self, name, masscenter=None, frame=None, mass=None, + inertia=None): + super().__init__(name, masscenter, mass) + if frame is None: + frame = ReferenceFrame(f'{name}_frame') + self.frame = frame + if inertia is None: + ixx = Symbol(f'{name}_ixx') + iyy = Symbol(f'{name}_iyy') + izz = Symbol(f'{name}_izz') + izx = Symbol(f'{name}_izx') + ixy = Symbol(f'{name}_ixy') + iyz = Symbol(f'{name}_iyz') + inertia = Inertia.from_inertia_scalars(self.masscenter, self.frame, + ixx, iyy, izz, ixy, iyz, izx) + self.inertia = inertia + + def __repr__(self): + return (f'{self.__class__.__name__}({repr(self.name)}, masscenter=' + f'{repr(self.masscenter)}, frame={repr(self.frame)}, mass=' + f'{repr(self.mass)}, inertia={repr(self.inertia)})') + + @property + def frame(self): + """The ReferenceFrame fixed to the body.""" + return self._frame + + @frame.setter + def frame(self, F): + if not isinstance(F, ReferenceFrame): + raise TypeError("RigidBody frame must be a ReferenceFrame object.") + self._frame = F + + @property + def x(self): + """The basis Vector for the body, in the x direction. """ + return self.frame.x + + @property + def y(self): + """The basis Vector for the body, in the y direction. """ + return self.frame.y + + @property + def z(self): + """The basis Vector for the body, in the z direction. """ + return self.frame.z + + @property + def inertia(self): + """The body's inertia about a point; stored as (Dyadic, Point).""" + return self._inertia + + @inertia.setter + def inertia(self, I): + # check if I is of the form (Dyadic, Point) + if len(I) != 2 or not isinstance(I[0], Dyadic) or not isinstance(I[1], Point): + raise TypeError("RigidBody inertia must be a tuple of the form (Dyadic, Point).") + + self._inertia = Inertia(I[0], I[1]) + # have I S/O, want I S/S* + # I S/O = I S/S* + I S*/O; I S/S* = I S/O - I S*/O + # I_S/S* = I_S/O - I_S*/O + I_Ss_O = inertia_of_point_mass(self.mass, + self.masscenter.pos_from(I[1]), + self.frame) + self._central_inertia = I[0] - I_Ss_O + + @property + def central_inertia(self): + """The body's central inertia dyadic.""" + return self._central_inertia + + @central_inertia.setter + def central_inertia(self, I): + if not isinstance(I, Dyadic): + raise TypeError("RigidBody inertia must be a Dyadic object.") + self.inertia = Inertia(I, self.masscenter) + + def linear_momentum(self, frame): + """ Linear momentum of the rigid body. + + Explanation + =========== + + The linear momentum L, of a rigid body B, with respect to frame N is + given by: + + ``L = m * v`` + + where m is the mass of the rigid body, and v is the velocity of the mass + center of B in the frame N. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which linear momentum is desired. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, ReferenceFrame, outer + >>> from sympy.physics.mechanics import RigidBody, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> m, v = dynamicsymbols('m v') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, v * N.x) + >>> I = outer (N.x, N.x) + >>> Inertia_tuple = (I, P) + >>> B = RigidBody('B', P, N, m, Inertia_tuple) + >>> B.linear_momentum(N) + m*v*N.x + + """ + + return self.mass * self.masscenter.vel(frame) + + def angular_momentum(self, point, frame): + """Returns the angular momentum of the rigid body about a point in the + given frame. + + Explanation + =========== + + The angular momentum H of a rigid body B about some point O in a frame N + is given by: + + ``H = dot(I, w) + cross(r, m * v)`` + + where I and m are the central inertia dyadic and mass of rigid body B, w + is the angular velocity of body B in the frame N, r is the position + vector from point O to the mass center of B, and v is the velocity of + the mass center in the frame N. + + Parameters + ========== + + point : Point + The point about which angular momentum is desired. + frame : ReferenceFrame + The frame in which angular momentum is desired. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, ReferenceFrame, outer + >>> from sympy.physics.mechanics import RigidBody, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> m, v, r, omega = dynamicsymbols('m v r omega') + >>> N = ReferenceFrame('N') + >>> b = ReferenceFrame('b') + >>> b.set_ang_vel(N, omega * b.x) + >>> P = Point('P') + >>> P.set_vel(N, 1 * N.x) + >>> I = outer(b.x, b.x) + >>> B = RigidBody('B', P, b, m, (I, P)) + >>> B.angular_momentum(P, N) + omega*b.x + + """ + I = self.central_inertia + w = self.frame.ang_vel_in(frame) + m = self.mass + r = self.masscenter.pos_from(point) + v = self.masscenter.vel(frame) + + return I.dot(w) + r.cross(m * v) + + def kinetic_energy(self, frame): + """Kinetic energy of the rigid body. + + Explanation + =========== + + The kinetic energy, T, of a rigid body, B, is given by: + + ``T = 1/2 * (dot(dot(I, w), w) + dot(m * v, v))`` + + where I and m are the central inertia dyadic and mass of rigid body B + respectively, w is the body's angular velocity, and v is the velocity of + the body's mass center in the supplied ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The RigidBody's angular velocity and the velocity of it's mass + center are typically defined with respect to an inertial frame but + any relevant frame in which the velocities are known can be + supplied. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, ReferenceFrame, outer + >>> from sympy.physics.mechanics import RigidBody + >>> from sympy import symbols + >>> m, v, r, omega = symbols('m v r omega') + >>> N = ReferenceFrame('N') + >>> b = ReferenceFrame('b') + >>> b.set_ang_vel(N, omega * b.x) + >>> P = Point('P') + >>> P.set_vel(N, v * N.x) + >>> I = outer (b.x, b.x) + >>> inertia_tuple = (I, P) + >>> B = RigidBody('B', P, b, m, inertia_tuple) + >>> B.kinetic_energy(N) + m*v**2/2 + omega**2/2 + + """ + + rotational_KE = S.Half * dot( + self.frame.ang_vel_in(frame), + dot(self.central_inertia, self.frame.ang_vel_in(frame))) + translational_KE = S.Half * self.mass * dot(self.masscenter.vel(frame), + self.masscenter.vel(frame)) + return rotational_KE + translational_KE + + def set_potential_energy(self, scalar): + sympy_deprecation_warning( + """ +The sympy.physics.mechanics.RigidBody.set_potential_energy() +method is deprecated. Instead use + + B.potential_energy = scalar + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-set-potential-energy", + ) + self.potential_energy = scalar + + def parallel_axis(self, point, frame=None): + """Returns the inertia dyadic of the body with respect to another point. + + Parameters + ========== + + point : sympy.physics.vector.Point + The point to express the inertia dyadic about. + frame : sympy.physics.vector.ReferenceFrame + The reference frame used to construct the dyadic. + + Returns + ======= + + inertia : sympy.physics.vector.Dyadic + The inertia dyadic of the rigid body expressed about the provided + point. + + """ + if frame is None: + frame = self.frame + return self.central_inertia + inertia_of_point_mass( + self.mass, self.masscenter.pos_from(point), frame) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/system.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/system.py new file mode 100644 index 0000000000000000000000000000000000000000..c8e0657d7da54ca5aaad9b37b816235641968470 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/system.py @@ -0,0 +1,1553 @@ +from functools import wraps + +from sympy.core.basic import Basic +from sympy.matrices.immutable import ImmutableMatrix +from sympy.matrices.dense import Matrix, eye, zeros +from sympy.core.containers import OrderedSet +from sympy.physics.mechanics.actuator import ActuatorBase +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.functions import ( + Lagrangian, _validate_coordinates, find_dynamicsymbols) +from sympy.physics.mechanics.joint import Joint +from sympy.physics.mechanics.kane import KanesMethod +from sympy.physics.mechanics.lagrange import LagrangesMethod +from sympy.physics.mechanics.loads import _parse_load, gravity +from sympy.physics.mechanics.method import _Methods +from sympy.physics.mechanics.particle import Particle +from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols +from sympy.utilities.iterables import iterable +from sympy.utilities.misc import filldedent + +__all__ = ['SymbolicSystem', 'System'] + + +def _reset_eom_method(method): + """Decorator to reset the eom_method if a property is changed.""" + + @wraps(method) + def wrapper(self, *args, **kwargs): + self._eom_method = None + return method(self, *args, **kwargs) + + return wrapper + + +class System(_Methods): + """Class to define a multibody system and form its equations of motion. + + Explanation + =========== + + A ``System`` instance stores the different objects associated with a model, + including bodies, joints, constraints, and other relevant information. With + all the relationships between components defined, the ``System`` can be used + to form the equations of motion using a backend, such as ``KanesMethod``. + The ``System`` has been designed to be compatible with third-party + libraries for greater flexibility and integration with other tools. + + Attributes + ========== + + frame : ReferenceFrame + Inertial reference frame of the system. + fixed_point : Point + A fixed point in the inertial reference frame. + x : Vector + Unit vector fixed in the inertial reference frame. + y : Vector + Unit vector fixed in the inertial reference frame. + z : Vector + Unit vector fixed in the inertial reference frame. + q : ImmutableMatrix + Matrix of all the generalized coordinates, i.e. the independent + generalized coordinates stacked upon the dependent. + u : ImmutableMatrix + Matrix of all the generalized speeds, i.e. the independent generealized + speeds stacked upon the dependent. + q_ind : ImmutableMatrix + Matrix of the independent generalized coordinates. + q_dep : ImmutableMatrix + Matrix of the dependent generalized coordinates. + u_ind : ImmutableMatrix + Matrix of the independent generalized speeds. + u_dep : ImmutableMatrix + Matrix of the dependent generalized speeds. + u_aux : ImmutableMatrix + Matrix of auxiliary generalized speeds. + kdes : ImmutableMatrix + Matrix of the kinematical differential equations as expressions equated + to the zero matrix. + bodies : tuple of BodyBase subclasses + Tuple of all bodies that make up the system. + joints : tuple of Joint + Tuple of all joints that connect bodies in the system. + loads : tuple of LoadBase subclasses + Tuple of all loads that have been applied to the system. + actuators : tuple of ActuatorBase subclasses + Tuple of all actuators present in the system. + holonomic_constraints : ImmutableMatrix + Matrix with the holonomic constraints as expressions equated to the zero + matrix. + nonholonomic_constraints : ImmutableMatrix + Matrix with the nonholonomic constraints as expressions equated to the + zero matrix. + velocity_constraints : ImmutableMatrix + Matrix with the velocity constraints as expressions equated to the zero + matrix. These are by default derived as the time derivatives of the + holonomic constraints extended with the nonholonomic constraints. + eom_method : subclass of KanesMethod or LagrangesMethod + Backend for forming the equations of motion. + + Examples + ======== + + In the example below a cart with a pendulum is created. The cart moves along + the x axis of the rail and the pendulum rotates about the z axis. The length + of the pendulum is ``l`` with the pendulum represented as a particle. To + move the cart a time dependent force ``F`` is applied to the cart. + + We first need to import some functions and create some of our variables. + + >>> from sympy import symbols, simplify + >>> from sympy.physics.mechanics import ( + ... mechanics_printing, dynamicsymbols, RigidBody, Particle, + ... ReferenceFrame, PrismaticJoint, PinJoint, System) + >>> mechanics_printing(pretty_print=False) + >>> g, l = symbols('g l') + >>> F = dynamicsymbols('F') + + The next step is to create bodies. It is also useful to create a frame for + locating the particle with respect to the pin joint later on, as a particle + does not have a body-fixed frame. + + >>> rail = RigidBody('rail') + >>> cart = RigidBody('cart') + >>> bob = Particle('bob') + >>> bob_frame = ReferenceFrame('bob_frame') + + Initialize the system, with the rail as the Newtonian reference. The body is + also automatically added to the system. + + >>> system = System.from_newtonian(rail) + >>> print(system.bodies[0]) + rail + + Create the joints, while immediately also adding them to the system. + + >>> system.add_joints( + ... PrismaticJoint('slider', rail, cart, joint_axis=rail.x), + ... PinJoint('pin', cart, bob, joint_axis=cart.z, + ... child_interframe=bob_frame, + ... child_point=l * bob_frame.y) + ... ) + >>> system.joints + (PrismaticJoint: slider parent: rail child: cart, + PinJoint: pin parent: cart child: bob) + + While adding the joints, the associated generalized coordinates, generalized + speeds, kinematic differential equations and bodies are also added to the + system. + + >>> system.q + Matrix([ + [q_slider], + [ q_pin]]) + >>> system.u + Matrix([ + [u_slider], + [ u_pin]]) + >>> system.kdes + Matrix([ + [u_slider - q_slider'], + [ u_pin - q_pin']]) + >>> [body.name for body in system.bodies] + ['rail', 'cart', 'bob'] + + With the kinematics established, we can now apply gravity and the cart force + ``F``. + + >>> system.apply_uniform_gravity(-g * system.y) + >>> system.add_loads((cart.masscenter, F * rail.x)) + >>> system.loads + ((rail_masscenter, - g*rail_mass*rail_frame.y), + (cart_masscenter, - cart_mass*g*rail_frame.y), + (bob_masscenter, - bob_mass*g*rail_frame.y), + (cart_masscenter, F*rail_frame.x)) + + With the entire system defined, we can now form the equations of motion. + Before forming the equations of motion, one can also run some checks that + will try to identify some common errors. + + >>> system.validate_system() + >>> system.form_eoms() + Matrix([ + [bob_mass*l*u_pin**2*sin(q_pin) - bob_mass*l*cos(q_pin)*u_pin' + - (bob_mass + cart_mass)*u_slider' + F], + [ -bob_mass*g*l*sin(q_pin) - bob_mass*l**2*u_pin' + - bob_mass*l*cos(q_pin)*u_slider']]) + >>> simplify(system.mass_matrix) + Matrix([ + [ bob_mass + cart_mass, bob_mass*l*cos(q_pin)], + [bob_mass*l*cos(q_pin), bob_mass*l**2]]) + >>> system.forcing + Matrix([ + [bob_mass*l*u_pin**2*sin(q_pin) + F], + [ -bob_mass*g*l*sin(q_pin)]]) + + The complexity of the above example can be increased if we add a constraint + to prevent the particle from moving in the horizontal (x) direction. This + can be done by adding a holonomic constraint. After which we should also + redefine what our (in)dependent generalized coordinates and speeds are. + + >>> system.add_holonomic_constraints( + ... bob.masscenter.pos_from(rail.masscenter).dot(system.x) + ... ) + >>> system.q_ind = system.get_joint('pin').coordinates + >>> system.q_dep = system.get_joint('slider').coordinates + >>> system.u_ind = system.get_joint('pin').speeds + >>> system.u_dep = system.get_joint('slider').speeds + + With the updated system the equations of motion can be formed again. + + >>> system.validate_system() + >>> system.form_eoms() + Matrix([[-bob_mass*g*l*sin(q_pin) + - bob_mass*l**2*u_pin' + - bob_mass*l*cos(q_pin)*u_slider' + - l*(bob_mass*l*u_pin**2*sin(q_pin) + - bob_mass*l*cos(q_pin)*u_pin' + - (bob_mass + cart_mass)*u_slider')*cos(q_pin) + - l*F*cos(q_pin)]]) + >>> simplify(system.mass_matrix) + Matrix([ + [bob_mass*l**2*sin(q_pin)**2, -cart_mass*l*cos(q_pin)], + [ l*cos(q_pin), 1]]) + >>> simplify(system.forcing) + Matrix([ + [-l*(bob_mass*g*sin(q_pin) + bob_mass*l*u_pin**2*sin(2*q_pin)/2 + + F*cos(q_pin))], + [ + l*u_pin**2*sin(q_pin)]]) + + """ + + def __init__(self, frame=None, fixed_point=None): + """Initialize the system. + + Parameters + ========== + + frame : ReferenceFrame, optional + The inertial frame of the system. If none is supplied, a new frame + will be created. + fixed_point : Point, optional + A fixed point in the inertial reference frame. If none is supplied, + a new fixed_point will be created. + + """ + if frame is None: + frame = ReferenceFrame('inertial_frame') + elif not isinstance(frame, ReferenceFrame): + raise TypeError('Frame must be an instance of ReferenceFrame.') + self._frame = frame + if fixed_point is None: + fixed_point = Point('inertial_point') + elif not isinstance(fixed_point, Point): + raise TypeError('Fixed point must be an instance of Point.') + self._fixed_point = fixed_point + self._fixed_point.set_vel(self._frame, 0) + self._q_ind = ImmutableMatrix(1, 0, []).T + self._q_dep = ImmutableMatrix(1, 0, []).T + self._u_ind = ImmutableMatrix(1, 0, []).T + self._u_dep = ImmutableMatrix(1, 0, []).T + self._u_aux = ImmutableMatrix(1, 0, []).T + self._kdes = ImmutableMatrix(1, 0, []).T + self._hol_coneqs = ImmutableMatrix(1, 0, []).T + self._nonhol_coneqs = ImmutableMatrix(1, 0, []).T + self._vel_constrs = None + self._bodies = [] + self._joints = [] + self._loads = [] + self._actuators = [] + self._eom_method = None + + @classmethod + def from_newtonian(cls, newtonian): + """Constructs the system with respect to a Newtonian body.""" + if isinstance(newtonian, Particle): + raise TypeError('A Particle has no frame so cannot act as ' + 'the Newtonian.') + system = cls(frame=newtonian.frame, fixed_point=newtonian.masscenter) + system.add_bodies(newtonian) + return system + + @property + def fixed_point(self): + """Fixed point in the inertial reference frame.""" + return self._fixed_point + + @property + def frame(self): + """Inertial reference frame of the system.""" + return self._frame + + @property + def x(self): + """Unit vector fixed in the inertial reference frame.""" + return self._frame.x + + @property + def y(self): + """Unit vector fixed in the inertial reference frame.""" + return self._frame.y + + @property + def z(self): + """Unit vector fixed in the inertial reference frame.""" + return self._frame.z + + @property + def bodies(self): + """Tuple of all bodies that have been added to the system.""" + return tuple(self._bodies) + + @bodies.setter + @_reset_eom_method + def bodies(self, bodies): + bodies = self._objects_to_list(bodies) + self._check_objects(bodies, [], BodyBase, 'Bodies', 'bodies') + self._bodies = bodies + + @property + def joints(self): + """Tuple of all joints that have been added to the system.""" + return tuple(self._joints) + + @joints.setter + @_reset_eom_method + def joints(self, joints): + joints = self._objects_to_list(joints) + self._check_objects(joints, [], Joint, 'Joints', 'joints') + self._joints = [] + self.add_joints(*joints) + + @property + def loads(self): + """Tuple of loads that have been applied on the system.""" + return tuple(self._loads) + + @loads.setter + @_reset_eom_method + def loads(self, loads): + loads = self._objects_to_list(loads) + self._loads = [_parse_load(load) for load in loads] + + @property + def actuators(self): + """Tuple of actuators present in the system.""" + return tuple(self._actuators) + + @actuators.setter + @_reset_eom_method + def actuators(self, actuators): + actuators = self._objects_to_list(actuators) + self._check_objects(actuators, [], ActuatorBase, 'Actuators', + 'actuators') + self._actuators = actuators + + @property + def q(self): + """Matrix of all the generalized coordinates with the independent + stacked upon the dependent.""" + return self._q_ind.col_join(self._q_dep) + + @property + def u(self): + """Matrix of all the generalized speeds with the independent stacked + upon the dependent.""" + return self._u_ind.col_join(self._u_dep) + + @property + def q_ind(self): + """Matrix of the independent generalized coordinates.""" + return self._q_ind + + @q_ind.setter + @_reset_eom_method + def q_ind(self, q_ind): + self._q_ind, self._q_dep = self._parse_coordinates( + self._objects_to_list(q_ind), True, [], self.q_dep, 'coordinates') + + @property + def q_dep(self): + """Matrix of the dependent generalized coordinates.""" + return self._q_dep + + @q_dep.setter + @_reset_eom_method + def q_dep(self, q_dep): + self._q_ind, self._q_dep = self._parse_coordinates( + self._objects_to_list(q_dep), False, self.q_ind, [], 'coordinates') + + @property + def u_ind(self): + """Matrix of the independent generalized speeds.""" + return self._u_ind + + @u_ind.setter + @_reset_eom_method + def u_ind(self, u_ind): + self._u_ind, self._u_dep = self._parse_coordinates( + self._objects_to_list(u_ind), True, [], self.u_dep, 'speeds') + + @property + def u_dep(self): + """Matrix of the dependent generalized speeds.""" + return self._u_dep + + @u_dep.setter + @_reset_eom_method + def u_dep(self, u_dep): + self._u_ind, self._u_dep = self._parse_coordinates( + self._objects_to_list(u_dep), False, self.u_ind, [], 'speeds') + + @property + def u_aux(self): + """Matrix of auxiliary generalized speeds.""" + return self._u_aux + + @u_aux.setter + @_reset_eom_method + def u_aux(self, u_aux): + self._u_aux = self._parse_coordinates( + self._objects_to_list(u_aux), True, [], [], 'u_auxiliary')[0] + + @property + def kdes(self): + """Kinematical differential equations as expressions equated to the zero + matrix. These equations describe the coupling between the generalized + coordinates and the generalized speeds.""" + return self._kdes + + @kdes.setter + @_reset_eom_method + def kdes(self, kdes): + kdes = self._objects_to_list(kdes) + self._kdes = self._parse_expressions( + kdes, [], 'kinematic differential equations') + + @property + def holonomic_constraints(self): + """Matrix with the holonomic constraints as expressions equated to the + zero matrix.""" + return self._hol_coneqs + + @holonomic_constraints.setter + @_reset_eom_method + def holonomic_constraints(self, constraints): + constraints = self._objects_to_list(constraints) + self._hol_coneqs = self._parse_expressions( + constraints, [], 'holonomic constraints') + + @property + def nonholonomic_constraints(self): + """Matrix with the nonholonomic constraints as expressions equated to + the zero matrix.""" + return self._nonhol_coneqs + + @nonholonomic_constraints.setter + @_reset_eom_method + def nonholonomic_constraints(self, constraints): + constraints = self._objects_to_list(constraints) + self._nonhol_coneqs = self._parse_expressions( + constraints, [], 'nonholonomic constraints') + + @property + def velocity_constraints(self): + """Matrix with the velocity constraints as expressions equated to the + zero matrix. The velocity constraints are by default derived from the + holonomic and nonholonomic constraints unless they are explicitly set. + """ + if self._vel_constrs is None: + return self.holonomic_constraints.diff(dynamicsymbols._t).col_join( + self.nonholonomic_constraints) + return self._vel_constrs + + @velocity_constraints.setter + @_reset_eom_method + def velocity_constraints(self, constraints): + if constraints is None: + self._vel_constrs = None + return + constraints = self._objects_to_list(constraints) + self._vel_constrs = self._parse_expressions( + constraints, [], 'velocity constraints') + + @property + def eom_method(self): + """Backend for forming the equations of motion.""" + return self._eom_method + + @staticmethod + def _objects_to_list(lst): + """Helper to convert passed objects to a list.""" + if not iterable(lst): # Only one object + return [lst] + return list(lst[:]) # converts Matrix and tuple to flattened list + + @staticmethod + def _check_objects(objects, obj_lst, expected_type, obj_name, type_name): + """Helper to check the objects that are being added to the system. + + Explanation + =========== + This method checks that the objects that are being added to the system + are of the correct type and have not already been added. If any of the + objects are not of the correct type or have already been added, then + an error is raised. + + Parameters + ========== + objects : iterable + The objects that would be added to the system. + obj_lst : list + The list of objects that are already in the system. + expected_type : type + The type that the objects should be. + obj_name : str + The name of the category of objects. This string is used to + formulate the error message for the user. + type_name : str + The name of the type that the objects should be. This string is used + to formulate the error message for the user. + + """ + seen = set(obj_lst) + duplicates = set() + wrong_types = set() + for obj in objects: + if not isinstance(obj, expected_type): + wrong_types.add(obj) + if obj in seen: + duplicates.add(obj) + else: + seen.add(obj) + if wrong_types: + raise TypeError(f'{obj_name} {wrong_types} are not {type_name}.') + if duplicates: + raise ValueError(f'{obj_name} {duplicates} have already been added ' + f'to the system.') + + def _parse_coordinates(self, new_coords, independent, old_coords_ind, + old_coords_dep, coord_type='coordinates'): + """Helper to parse coordinates and speeds.""" + # Construct lists of the independent and dependent coordinates + coords_ind, coords_dep = old_coords_ind[:], old_coords_dep[:] + if not iterable(independent): + independent = [independent] * len(new_coords) + for coord, indep in zip(new_coords, independent): + if indep: + coords_ind.append(coord) + else: + coords_dep.append(coord) + # Check types and duplicates + current = {'coordinates': self.q_ind[:] + self.q_dep[:], + 'speeds': self.u_ind[:] + self.u_dep[:], + 'u_auxiliary': self._u_aux[:], + coord_type: coords_ind + coords_dep} + _validate_coordinates(**current) + return (ImmutableMatrix(1, len(coords_ind), coords_ind).T, + ImmutableMatrix(1, len(coords_dep), coords_dep).T) + + @staticmethod + def _parse_expressions(new_expressions, old_expressions, name, + check_negatives=False): + """Helper to parse expressions like constraints.""" + old_expressions = old_expressions[:] + new_expressions = list(new_expressions) # Converts a possible tuple + if check_negatives: + check_exprs = old_expressions + [-expr for expr in old_expressions] + else: + check_exprs = old_expressions + System._check_objects(new_expressions, check_exprs, Basic, name, + 'expressions') + for expr in new_expressions: + if expr == 0: + raise ValueError(f'Parsed {name} are zero.') + return ImmutableMatrix(1, len(old_expressions) + len(new_expressions), + old_expressions + new_expressions).T + + @_reset_eom_method + def add_coordinates(self, *coordinates, independent=True): + """Add generalized coordinate(s) to the system. + + Parameters + ========== + + *coordinates : dynamicsymbols + One or more generalized coordinates to be added to the system. + independent : bool or list of bool, optional + Boolean whether a coordinate is dependent or independent. The + default is True, so the coordinates are added as independent by + default. + + """ + self._q_ind, self._q_dep = self._parse_coordinates( + coordinates, independent, self.q_ind, self.q_dep, 'coordinates') + + @_reset_eom_method + def add_speeds(self, *speeds, independent=True): + """Add generalized speed(s) to the system. + + Parameters + ========== + + *speeds : dynamicsymbols + One or more generalized speeds to be added to the system. + independent : bool or list of bool, optional + Boolean whether a speed is dependent or independent. The default is + True, so the speeds are added as independent by default. + + """ + self._u_ind, self._u_dep = self._parse_coordinates( + speeds, independent, self.u_ind, self.u_dep, 'speeds') + + @_reset_eom_method + def add_auxiliary_speeds(self, *speeds): + """Add auxiliary speed(s) to the system. + + Parameters + ========== + + *speeds : dynamicsymbols + One or more auxiliary speeds to be added to the system. + + """ + self._u_aux = self._parse_coordinates( + speeds, True, self._u_aux, [], 'u_auxiliary')[0] + + @_reset_eom_method + def add_kdes(self, *kdes): + """Add kinematic differential equation(s) to the system. + + Parameters + ========== + + *kdes : Expr + One or more kinematic differential equations. + + """ + self._kdes = self._parse_expressions( + kdes, self.kdes, 'kinematic differential equations', + check_negatives=True) + + @_reset_eom_method + def add_holonomic_constraints(self, *constraints): + """Add holonomic constraint(s) to the system. + + Parameters + ========== + + *constraints : Expr + One or more holonomic constraints, which are expressions that should + be zero. + + """ + self._hol_coneqs = self._parse_expressions( + constraints, self._hol_coneqs, 'holonomic constraints', + check_negatives=True) + + @_reset_eom_method + def add_nonholonomic_constraints(self, *constraints): + """Add nonholonomic constraint(s) to the system. + + Parameters + ========== + + *constraints : Expr + One or more nonholonomic constraints, which are expressions that + should be zero. + + """ + self._nonhol_coneqs = self._parse_expressions( + constraints, self._nonhol_coneqs, 'nonholonomic constraints', + check_negatives=True) + + @_reset_eom_method + def add_bodies(self, *bodies): + """Add body(ies) to the system. + + Parameters + ========== + + bodies : Particle or RigidBody + One or more bodies. + + """ + self._check_objects(bodies, self.bodies, BodyBase, 'Bodies', 'bodies') + self._bodies.extend(bodies) + + @_reset_eom_method + def add_loads(self, *loads): + """Add load(s) to the system. + + Parameters + ========== + + *loads : Force or Torque + One or more loads. + + """ + loads = [_parse_load(load) for load in loads] # Checks the loads + self._loads.extend(loads) + + @_reset_eom_method + def apply_uniform_gravity(self, acceleration): + """Apply uniform gravity to all bodies in the system by adding loads. + + Parameters + ========== + + acceleration : Vector + The acceleration due to gravity. + + """ + self.add_loads(*gravity(acceleration, *self.bodies)) + + @_reset_eom_method + def add_actuators(self, *actuators): + """Add actuator(s) to the system. + + Parameters + ========== + + *actuators : subclass of ActuatorBase + One or more actuators. + + """ + self._check_objects(actuators, self.actuators, ActuatorBase, + 'Actuators', 'actuators') + self._actuators.extend(actuators) + + @_reset_eom_method + def add_joints(self, *joints): + """Add joint(s) to the system. + + Explanation + =========== + + This methods adds one or more joints to the system including its + associated objects, i.e. generalized coordinates, generalized speeds, + kinematic differential equations and the bodies. + + Parameters + ========== + + *joints : subclass of Joint + One or more joints. + + Notes + ===== + + For the generalized coordinates, generalized speeds and bodies it is + checked whether they are already known by the system instance. If they + are, then they are not added. The kinematic differential equations are + however always added to the system, so you should not also manually add + those on beforehand. + + """ + self._check_objects(joints, self.joints, Joint, 'Joints', 'joints') + self._joints.extend(joints) + coordinates, speeds, kdes, bodies = (OrderedSet() for _ in range(4)) + for joint in joints: + coordinates.update(joint.coordinates) + speeds.update(joint.speeds) + kdes.update(joint.kdes) + bodies.update((joint.parent, joint.child)) + coordinates = coordinates.difference(self.q) + speeds = speeds.difference(self.u) + kdes = kdes.difference(self.kdes[:] + (-self.kdes)[:]) + bodies = bodies.difference(self.bodies) + self.add_coordinates(*tuple(coordinates)) + self.add_speeds(*tuple(speeds)) + self.add_kdes(*(kde for kde in tuple(kdes) if not kde == 0)) + self.add_bodies(*tuple(bodies)) + + def get_body(self, name): + """Retrieve a body from the system by name. + + Parameters + ========== + + name : str + The name of the body to retrieve. + + Returns + ======= + + RigidBody or Particle + The body with the given name, or None if no such body exists. + + """ + for body in self._bodies: + if body.name == name: + return body + + def get_joint(self, name): + """Retrieve a joint from the system by name. + + Parameters + ========== + + name : str + The name of the joint to retrieve. + + Returns + ======= + + subclass of Joint + The joint with the given name, or None if no such joint exists. + + """ + for joint in self._joints: + if joint.name == name: + return joint + + def _form_eoms(self): + return self.form_eoms() + + def form_eoms(self, eom_method=KanesMethod, **kwargs): + """Form the equations of motion of the system. + + Parameters + ========== + + eom_method : subclass of KanesMethod or LagrangesMethod + Backend class to be used for forming the equations of motion. The + default is ``KanesMethod``. + + Returns + ======== + + ImmutableMatrix + Vector of equations of motions. + + Examples + ======== + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + >>> from sympy import S, symbols + >>> from sympy.physics.mechanics import ( + ... LagrangesMethod, dynamicsymbols, PrismaticJoint, Particle, + ... RigidBody, System) + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> m, k, b = symbols('m k b') + >>> wall = RigidBody('W') + >>> system = System.from_newtonian(wall) + >>> bob = Particle('P', mass=m) + >>> bob.potential_energy = S.Half * k * q**2 + >>> system.add_joints(PrismaticJoint('J', wall, bob, q, qd)) + >>> system.add_loads((bob.masscenter, b * qd * system.x)) + >>> system.form_eoms(LagrangesMethod) + Matrix([[-b*Derivative(q(t), t) + k*q(t) + m*Derivative(q(t), (t, 2))]]) + + We can also solve for the states using the 'rhs' method. + + >>> system.rhs() + Matrix([ + [ Derivative(q(t), t)], + [(b*Derivative(q(t), t) - k*q(t))/m]]) + + """ + # KanesMethod does not accept empty iterables + loads = self.loads + tuple( + load for act in self.actuators for load in act.to_loads()) + loads = loads if loads else None + if issubclass(eom_method, KanesMethod): + disallowed_kwargs = { + "frame", "q_ind", "u_ind", "kd_eqs", "q_dependent", + "u_dependent", "u_auxiliary", "configuration_constraints", + "velocity_constraints", "forcelist", "bodies"} + wrong_kwargs = disallowed_kwargs.intersection(kwargs) + if wrong_kwargs: + raise ValueError( + f"The following keyword arguments are not allowed to be " + f"overwritten in {eom_method.__name__}: {wrong_kwargs}.") + kwargs = {"frame": self.frame, "q_ind": self.q_ind, + "u_ind": self.u_ind, "kd_eqs": self.kdes, + "q_dependent": self.q_dep, "u_dependent": self.u_dep, + "configuration_constraints": self.holonomic_constraints, + "velocity_constraints": self.velocity_constraints, + "u_auxiliary": self.u_aux, + "forcelist": loads, "bodies": self.bodies, + "explicit_kinematics": False, **kwargs} + self._eom_method = eom_method(**kwargs) + elif issubclass(eom_method, LagrangesMethod): + disallowed_kwargs = { + "frame", "qs", "forcelist", "bodies", "hol_coneqs", + "nonhol_coneqs", "Lagrangian"} + wrong_kwargs = disallowed_kwargs.intersection(kwargs) + if wrong_kwargs: + raise ValueError( + f"The following keyword arguments are not allowed to be " + f"overwritten in {eom_method.__name__}: {wrong_kwargs}.") + kwargs = {"frame": self.frame, "qs": self.q, "forcelist": loads, + "bodies": self.bodies, + "hol_coneqs": self.holonomic_constraints, + "nonhol_coneqs": self.nonholonomic_constraints, **kwargs} + if "Lagrangian" not in kwargs: + kwargs["Lagrangian"] = Lagrangian(kwargs["frame"], + *kwargs["bodies"]) + self._eom_method = eom_method(**kwargs) + else: + raise NotImplementedError(f'{eom_method} has not been implemented.') + return self.eom_method._form_eoms() + + def rhs(self, inv_method=None): + """Compute the equations of motion in the explicit form. + + Parameters + ========== + + inv_method : str + The specific sympy inverse matrix calculation method to use. For a + list of valid methods, see + :meth:`~sympy.matrices.matrixbase.MatrixBase.inv` + + Returns + ======== + + ImmutableMatrix + Equations of motion in the explicit form. + + See Also + ======== + + sympy.physics.mechanics.kane.KanesMethod.rhs: + KanesMethod's ``rhs`` function. + sympy.physics.mechanics.lagrange.LagrangesMethod.rhs: + LagrangesMethod's ``rhs`` function. + + """ + return self.eom_method.rhs(inv_method=inv_method) + + @property + def mass_matrix(self): + r"""The mass matrix of the system. + + Explanation + =========== + + The mass matrix $M_d$ and the forcing vector $f_d$ of a system describe + the system's dynamics according to the following equations: + + .. math:: + M_d \dot{u} = f_d + + where $\dot{u}$ is the time derivative of the generalized speeds. + + """ + return self.eom_method.mass_matrix + + @property + def mass_matrix_full(self): + r"""The mass matrix of the system, augmented by the kinematic + differential equations in explicit or implicit form. + + Explanation + =========== + + The full mass matrix $M_m$ and the full forcing vector $f_m$ of a system + describe the dynamics and kinematics according to the following + equation: + + .. math:: + M_m \dot{x} = f_m + + where $x$ is the state vector stacking $q$ and $u$. + + """ + return self.eom_method.mass_matrix_full + + @property + def forcing(self): + """The forcing vector of the system.""" + return self.eom_method.forcing + + @property + def forcing_full(self): + """The forcing vector of the system, augmented by the kinematic + differential equations in explicit or implicit form.""" + return self.eom_method.forcing_full + + def validate_system(self, eom_method=KanesMethod, check_duplicates=False): + """Validates the system using some basic checks. + + Explanation + =========== + + This method validates the system based on the following checks: + + - The number of dependent generalized coordinates should equal the + number of holonomic constraints. + - All generalized coordinates defined by the joints should also be known + to the system. + - If ``KanesMethod`` is used as a ``eom_method``: + - All generalized speeds and kinematic differential equations + defined by the joints should also be known to the system. + - The number of dependent generalized speeds should equal the number + of velocity constraints. + - The number of generalized coordinates should be less than or equal + to the number of generalized speeds. + - The number of generalized coordinates should equal the number of + kinematic differential equations. + - If ``LagrangesMethod`` is used as ``eom_method``: + - There should not be any generalized speeds that are not + derivatives of the generalized coordinates (this includes the + generalized speeds defined by the joints). + + Parameters + ========== + + eom_method : subclass of KanesMethod or LagrangesMethod + Backend class that will be used for forming the equations of motion. + There are different checks for the different backends. The default + is ``KanesMethod``. + check_duplicates : bool + Boolean whether the system should be checked for duplicate + definitions. The default is False, because duplicates are already + checked when adding objects to the system. + + Notes + ===== + + This method is not guaranteed to be backwards compatible as it may + improve over time. The method can become both more and less strict in + certain areas. However a well-defined system should always pass all + these tests. + + """ + msgs = [] + # Save some data in variables + n_hc = self.holonomic_constraints.shape[0] + n_vc = self.velocity_constraints.shape[0] + n_q_dep, n_u_dep = self.q_dep.shape[0], self.u_dep.shape[0] + q_set, u_set = set(self.q), set(self.u) + n_q, n_u = len(q_set), len(u_set) + # Check number of holonomic constraints + if n_q_dep != n_hc: + msgs.append(filldedent(f""" + The number of dependent generalized coordinates {n_q_dep} should be + equal to the number of holonomic constraints {n_hc}.""")) + # Check if all joint coordinates and speeds are present + missing_q = set() + for joint in self.joints: + missing_q.update(set(joint.coordinates).difference(q_set)) + if missing_q: + msgs.append(filldedent(f""" + The generalized coordinates {missing_q} used in joints are not added + to the system.""")) + # Method dependent checks + if issubclass(eom_method, KanesMethod): + n_kdes = len(self.kdes) + missing_kdes, missing_u = set(), set() + for joint in self.joints: + missing_u.update(set(joint.speeds).difference(u_set)) + missing_kdes.update(set(joint.kdes).difference( + self.kdes[:] + (-self.kdes)[:])) + if missing_u: + msgs.append(filldedent(f""" + The generalized speeds {missing_u} used in joints are not added + to the system.""")) + if missing_kdes: + msgs.append(filldedent(f""" + The kinematic differential equations {missing_kdes} used in + joints are not added to the system.""")) + if n_u_dep != n_vc: + msgs.append(filldedent(f""" + The number of dependent generalized speeds {n_u_dep} should be + equal to the number of velocity constraints {n_vc}.""")) + if n_q > n_u: + msgs.append(filldedent(f""" + The number of generalized coordinates {n_q} should be less than + or equal to the number of generalized speeds {n_u}.""")) + if n_u != n_kdes: + msgs.append(filldedent(f""" + The number of generalized speeds {n_u} should be equal to the + number of kinematic differential equations {n_kdes}.""")) + elif issubclass(eom_method, LagrangesMethod): + not_qdots = set(self.u).difference(self.q.diff(dynamicsymbols._t)) + for joint in self.joints: + not_qdots.update(set( + joint.speeds).difference(self.q.diff(dynamicsymbols._t))) + if not_qdots: + msgs.append(filldedent(f""" + The generalized speeds {not_qdots} are not supported by this + method. Only derivatives of the generalized coordinates are + supported. If these symbols are used in your expressions, then + this will result in wrong equations of motion.""")) + if self.u_aux: + msgs.append(filldedent(f""" + This method does not support auxiliary speeds. If these symbols + are used in your expressions, then this will result in wrong + equations of motion. The auxiliary speeds are {self.u_aux}.""")) + else: + raise NotImplementedError(f'{eom_method} has not been implemented.') + if check_duplicates: # Should be redundant + duplicates_to_check = [('generalized coordinates', self.q), + ('generalized speeds', self.u), + ('auxiliary speeds', self.u_aux), + ('bodies', self.bodies), + ('joints', self.joints)] + for name, lst in duplicates_to_check: + seen = set() + duplicates = {x for x in lst if x in seen or seen.add(x)} + if duplicates: + msgs.append(filldedent(f""" + The {name} {duplicates} exist multiple times within the + system.""")) + if msgs: + raise ValueError('\n'.join(msgs)) + + +class SymbolicSystem: + """SymbolicSystem is a class that contains all the information about a + system in a symbolic format such as the equations of motions and the bodies + and loads in the system. + + There are three ways that the equations of motion can be described for + Symbolic System: + + + [1] Explicit form where the kinematics and dynamics are combined + x' = F_1(x, t, r, p) + + [2] Implicit form where the kinematics and dynamics are combined + M_2(x, p) x' = F_2(x, t, r, p) + + [3] Implicit form where the kinematics and dynamics are separate + M_3(q, p) u' = F_3(q, u, t, r, p) + q' = G(q, u, t, r, p) + + where + + x : states, e.g. [q, u] + t : time + r : specified (exogenous) inputs + p : constants + q : generalized coordinates + u : generalized speeds + F_1 : right hand side of the combined equations in explicit form + F_2 : right hand side of the combined equations in implicit form + F_3 : right hand side of the dynamical equations in implicit form + M_2 : mass matrix of the combined equations in implicit form + M_3 : mass matrix of the dynamical equations in implicit form + G : right hand side of the kinematical differential equations + + Parameters + ========== + + coord_states : ordered iterable of functions of time + This input will either be a collection of the coordinates or states + of the system depending on whether or not the speeds are also + given. If speeds are specified this input will be assumed to + be the coordinates otherwise this input will be assumed to + be the states. + + right_hand_side : Matrix + This variable is the right hand side of the equations of motion in + any of the forms. The specific form will be assumed depending on + whether a mass matrix or coordinate derivatives are given. + + speeds : ordered iterable of functions of time, optional + This is a collection of the generalized speeds of the system. If + given it will be assumed that the first argument (coord_states) + will represent the generalized coordinates of the system. + + mass_matrix : Matrix, optional + The matrix of the implicit forms of the equations of motion (forms + [2] and [3]). The distinction between the forms is determined by + whether or not the coordinate derivatives are passed in. If + they are given form [3] will be assumed otherwise form [2] is + assumed. + + coordinate_derivatives : Matrix, optional + The right hand side of the kinematical equations in explicit form. + If given it will be assumed that the equations of motion are being + entered in form [3]. + + alg_con : Iterable, optional + The indexes of the rows in the equations of motion that contain + algebraic constraints instead of differential equations. If the + equations are input in form [3], it will be assumed the indexes are + referencing the mass_matrix/right_hand_side combination and not the + coordinate_derivatives. + + output_eqns : Dictionary, optional + Any output equations that are desired to be tracked are stored in a + dictionary where the key corresponds to the name given for the + specific equation and the value is the equation itself in symbolic + form + + coord_idxs : Iterable, optional + If coord_states corresponds to the states rather than the + coordinates this variable will tell SymbolicSystem which indexes of + the states correspond to generalized coordinates. + + speed_idxs : Iterable, optional + If coord_states corresponds to the states rather than the + coordinates this variable will tell SymbolicSystem which indexes of + the states correspond to generalized speeds. + + bodies : iterable of Body/Rigidbody objects, optional + Iterable containing the bodies of the system + + loads : iterable of load instances (described below), optional + Iterable containing the loads of the system where forces are given + by (point of application, force vector) and torques are given by + (reference frame acting upon, torque vector). Ex [(point, force), + (ref_frame, torque)] + + Attributes + ========== + + coordinates : Matrix, shape(n, 1) + This is a matrix containing the generalized coordinates of the system + + speeds : Matrix, shape(m, 1) + This is a matrix containing the generalized speeds of the system + + states : Matrix, shape(o, 1) + This is a matrix containing the state variables of the system + + alg_con : List + This list contains the indices of the algebraic constraints in the + combined equations of motion. The presence of these constraints + requires that a DAE solver be used instead of an ODE solver. + If the system is given in form [3] the alg_con variable will be + adjusted such that it is a representation of the combined kinematics + and dynamics thus make sure it always matches the mass matrix + entered. + + dyn_implicit_mat : Matrix, shape(m, m) + This is the M matrix in form [3] of the equations of motion (the mass + matrix or generalized inertia matrix of the dynamical equations of + motion in implicit form). + + dyn_implicit_rhs : Matrix, shape(m, 1) + This is the F vector in form [3] of the equations of motion (the right + hand side of the dynamical equations of motion in implicit form). + + comb_implicit_mat : Matrix, shape(o, o) + This is the M matrix in form [2] of the equations of motion. + This matrix contains a block diagonal structure where the top + left block (the first rows) represent the matrix in the + implicit form of the kinematical equations and the bottom right + block (the last rows) represent the matrix in the implicit form + of the dynamical equations. + + comb_implicit_rhs : Matrix, shape(o, 1) + This is the F vector in form [2] of the equations of motion. The top + part of the vector represents the right hand side of the implicit form + of the kinemaical equations and the bottom of the vector represents the + right hand side of the implicit form of the dynamical equations of + motion. + + comb_explicit_rhs : Matrix, shape(o, 1) + This vector represents the right hand side of the combined equations of + motion in explicit form (form [1] from above). + + kin_explicit_rhs : Matrix, shape(m, 1) + This is the right hand side of the explicit form of the kinematical + equations of motion as can be seen in form [3] (the G matrix). + + output_eqns : Dictionary + If output equations were given they are stored in a dictionary where + the key corresponds to the name given for the specific equation and + the value is the equation itself in symbolic form + + bodies : Tuple + If the bodies in the system were given they are stored in a tuple for + future access + + loads : Tuple + If the loads in the system were given they are stored in a tuple for + future access. This includes forces and torques where forces are given + by (point of application, force vector) and torques are given by + (reference frame acted upon, torque vector). + + Example + ======= + + As a simple example, the dynamics of a simple pendulum will be input into a + SymbolicSystem object manually. First some imports will be needed and then + symbols will be set up for the length of the pendulum (l), mass at the end + of the pendulum (m), and a constant for gravity (g). :: + + >>> from sympy import Matrix, sin, symbols + >>> from sympy.physics.mechanics import dynamicsymbols, SymbolicSystem + >>> l, m, g = symbols('l m g') + + The system will be defined by an angle of theta from the vertical and a + generalized speed of omega will be used where omega = theta_dot. :: + + >>> theta, omega = dynamicsymbols('theta omega') + + Now the equations of motion are ready to be formed and passed to the + SymbolicSystem object. :: + + >>> kin_explicit_rhs = Matrix([omega]) + >>> dyn_implicit_mat = Matrix([l**2 * m]) + >>> dyn_implicit_rhs = Matrix([-g * l * m * sin(theta)]) + >>> symsystem = SymbolicSystem([theta], dyn_implicit_rhs, [omega], + ... dyn_implicit_mat) + + Notes + ===== + + m : number of generalized speeds + n : number of generalized coordinates + o : number of states + + """ + + def __init__(self, coord_states, right_hand_side, speeds=None, + mass_matrix=None, coordinate_derivatives=None, alg_con=None, + output_eqns={}, coord_idxs=None, speed_idxs=None, bodies=None, + loads=None): + """Initializes a SymbolicSystem object""" + + # Extract information on speeds, coordinates and states + if speeds is None: + self._states = Matrix(coord_states) + + if coord_idxs is None: + self._coordinates = None + else: + coords = [coord_states[i] for i in coord_idxs] + self._coordinates = Matrix(coords) + + if speed_idxs is None: + self._speeds = None + else: + speeds_inter = [coord_states[i] for i in speed_idxs] + self._speeds = Matrix(speeds_inter) + else: + self._coordinates = Matrix(coord_states) + self._speeds = Matrix(speeds) + self._states = self._coordinates.col_join(self._speeds) + + # Extract equations of motion form + if coordinate_derivatives is not None: + self._kin_explicit_rhs = coordinate_derivatives + self._dyn_implicit_rhs = right_hand_side + self._dyn_implicit_mat = mass_matrix + self._comb_implicit_rhs = None + self._comb_implicit_mat = None + self._comb_explicit_rhs = None + elif mass_matrix is not None: + self._kin_explicit_rhs = None + self._dyn_implicit_rhs = None + self._dyn_implicit_mat = None + self._comb_implicit_rhs = right_hand_side + self._comb_implicit_mat = mass_matrix + self._comb_explicit_rhs = None + else: + self._kin_explicit_rhs = None + self._dyn_implicit_rhs = None + self._dyn_implicit_mat = None + self._comb_implicit_rhs = None + self._comb_implicit_mat = None + self._comb_explicit_rhs = right_hand_side + + # Set the remainder of the inputs as instance attributes + if alg_con is not None and coordinate_derivatives is not None: + alg_con = [i + len(coordinate_derivatives) for i in alg_con] + self._alg_con = alg_con + self.output_eqns = output_eqns + + # Change the body and loads iterables to tuples if they are not tuples + # already + if not isinstance(bodies, tuple) and bodies is not None: + bodies = tuple(bodies) + if not isinstance(loads, tuple) and loads is not None: + loads = tuple(loads) + self._bodies = bodies + self._loads = loads + + @property + def coordinates(self): + """Returns the column matrix of the generalized coordinates""" + if self._coordinates is None: + raise AttributeError("The coordinates were not specified.") + else: + return self._coordinates + + @property + def speeds(self): + """Returns the column matrix of generalized speeds""" + if self._speeds is None: + raise AttributeError("The speeds were not specified.") + else: + return self._speeds + + @property + def states(self): + """Returns the column matrix of the state variables""" + return self._states + + @property + def alg_con(self): + """Returns a list with the indices of the rows containing algebraic + constraints in the combined form of the equations of motion""" + return self._alg_con + + @property + def dyn_implicit_mat(self): + """Returns the matrix, M, corresponding to the dynamic equations in + implicit form, M x' = F, where the kinematical equations are not + included""" + if self._dyn_implicit_mat is None: + raise AttributeError("dyn_implicit_mat is not specified for " + "equations of motion form [1] or [2].") + else: + return self._dyn_implicit_mat + + @property + def dyn_implicit_rhs(self): + """Returns the column matrix, F, corresponding to the dynamic equations + in implicit form, M x' = F, where the kinematical equations are not + included""" + if self._dyn_implicit_rhs is None: + raise AttributeError("dyn_implicit_rhs is not specified for " + "equations of motion form [1] or [2].") + else: + return self._dyn_implicit_rhs + + @property + def comb_implicit_mat(self): + """Returns the matrix, M, corresponding to the equations of motion in + implicit form (form [2]), M x' = F, where the kinematical equations are + included""" + if self._comb_implicit_mat is None: + if self._dyn_implicit_mat is not None: + num_kin_eqns = len(self._kin_explicit_rhs) + num_dyn_eqns = len(self._dyn_implicit_rhs) + zeros1 = zeros(num_kin_eqns, num_dyn_eqns) + zeros2 = zeros(num_dyn_eqns, num_kin_eqns) + inter1 = eye(num_kin_eqns).row_join(zeros1) + inter2 = zeros2.row_join(self._dyn_implicit_mat) + self._comb_implicit_mat = inter1.col_join(inter2) + return self._comb_implicit_mat + else: + raise AttributeError("comb_implicit_mat is not specified for " + "equations of motion form [1].") + else: + return self._comb_implicit_mat + + @property + def comb_implicit_rhs(self): + """Returns the column matrix, F, corresponding to the equations of + motion in implicit form (form [2]), M x' = F, where the kinematical + equations are included""" + if self._comb_implicit_rhs is None: + if self._dyn_implicit_rhs is not None: + kin_inter = self._kin_explicit_rhs + dyn_inter = self._dyn_implicit_rhs + self._comb_implicit_rhs = kin_inter.col_join(dyn_inter) + return self._comb_implicit_rhs + else: + raise AttributeError("comb_implicit_mat is not specified for " + "equations of motion in form [1].") + else: + return self._comb_implicit_rhs + + def compute_explicit_form(self): + """If the explicit right hand side of the combined equations of motion + is to provided upon initialization, this method will calculate it. This + calculation can potentially take awhile to compute.""" + if self._comb_explicit_rhs is not None: + raise AttributeError("comb_explicit_rhs is already formed.") + + inter1 = getattr(self, 'kin_explicit_rhs', None) + if inter1 is not None: + inter2 = self._dyn_implicit_mat.LUsolve(self._dyn_implicit_rhs) + out = inter1.col_join(inter2) + else: + out = self._comb_implicit_mat.LUsolve(self._comb_implicit_rhs) + + self._comb_explicit_rhs = out + + @property + def comb_explicit_rhs(self): + """Returns the right hand side of the equations of motion in explicit + form, x' = F, where the kinematical equations are included""" + if self._comb_explicit_rhs is None: + raise AttributeError("Please run .combute_explicit_form before " + "attempting to access comb_explicit_rhs.") + else: + return self._comb_explicit_rhs + + @property + def kin_explicit_rhs(self): + """Returns the right hand side of the kinematical equations in explicit + form, q' = G""" + if self._kin_explicit_rhs is None: + raise AttributeError("kin_explicit_rhs is not specified for " + "equations of motion form [1] or [2].") + else: + return self._kin_explicit_rhs + + def dynamic_symbols(self): + """Returns a column matrix containing all of the symbols in the system + that depend on time""" + # Create a list of all of the expressions in the equations of motion + if self._comb_explicit_rhs is None: + eom_expressions = (self.comb_implicit_mat[:] + + self.comb_implicit_rhs[:]) + else: + eom_expressions = (self._comb_explicit_rhs[:]) + + functions_of_time = set() + for expr in eom_expressions: + functions_of_time = functions_of_time.union( + find_dynamicsymbols(expr)) + functions_of_time = functions_of_time.union(self._states) + + return tuple(functions_of_time) + + def constant_symbols(self): + """Returns a column matrix containing all of the symbols in the system + that do not depend on time""" + # Create a list of all of the expressions in the equations of motion + if self._comb_explicit_rhs is None: + eom_expressions = (self.comb_implicit_mat[:] + + self.comb_implicit_rhs[:]) + else: + eom_expressions = (self._comb_explicit_rhs[:]) + + constants = set() + for expr in eom_expressions: + constants = constants.union(expr.free_symbols) + constants.remove(dynamicsymbols._t) + + return tuple(constants) + + @property + def bodies(self): + """Returns the bodies in the system""" + if self._bodies is None: + raise AttributeError("bodies were not specified for the system.") + else: + return self._bodies + + @property + def loads(self): + """Returns the loads in the system""" + if self._loads is None: + raise AttributeError("loads were not specified for the system.") + else: + return self._loads diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__init__.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37c3a8882af97610a6f27fafbfb96da8024e7328 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_actuator.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_actuator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91306bec4b48da9b03c503ec9334fe83fc028ccf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_actuator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_body.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_body.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b29d19bd47edd22aff87b4408697cd8860d80d21 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_body.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_functions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d32949d1baf14a432c148773e5574bdcdaf34ef7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_inertia.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_inertia.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8295132546a47832f6a8e4af847d746cb87303b4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_inertia.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_joint.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_joint.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88685a21bfb9ecf9b5dcc5bccd8c9e89d3c4230d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_joint.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_jointsmethod.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_jointsmethod.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c5518c9e8568205085eb02fa4f19691a05b8d38 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_jointsmethod.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27e66dc8ee3321a222ceb2c4f69a4c59880d642b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane2.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81c16b47d7b66b36d53b9458c768d525de531ba0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane2.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane3.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2aff1132ed1d90f99338e4659fbb4724a1cee8b0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane3.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane4.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane4.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6442b2ff5803be72fc3109b5995211611917909e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane4.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane5.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane5.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..224e6406624cfe82cc1e970f8d244e42661e07b0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_kane5.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_lagrange.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_lagrange.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cf5c95d13c6f1871268c9f2bab1bdca1ab93f6a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_lagrange.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_lagrange2.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_lagrange2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d51a6bc08b4d7f53b27a79c3a01881b3a59263fc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_lagrange2.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_linearity_of_velocity_constraints.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_linearity_of_velocity_constraints.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64c99f2a5e127558f5fca37bed3ba5d524ae0be2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_linearity_of_velocity_constraints.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_linearize.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_linearize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba20081d1102e13f2a97c468f9ee0cc57ba20796 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_linearize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_loads.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_loads.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd63262e8d8bbe327ba5eb0eebe1abcd8fc22f7e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_loads.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_method.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_method.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95c33d95fecbd79ace661152fde35d94054f1305 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_method.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_models.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_models.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43a6a8704078f1608d0a460828cfac19a6c4b53f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_models.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_particle.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_particle.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..007d000b7fc6980ef1ef1e63b2c368726c309634 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_particle.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_pathway.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_pathway.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c918b3097581854441798b8d5fa6abe0c47ebd8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_pathway.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_rigidbody.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_rigidbody.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a47a6fc5ec6e184bc041af6f31eea9e5b9c5b1d9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_rigidbody.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_system.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_system.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54c221a0c8e078d16d8ac6a02328b53b244c244b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_system.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_system_class.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_system_class.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ce6bb8e33b4730e99576b0aee0816a1cf8126f1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_system_class.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_wrapping_geometry.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_wrapping_geometry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..292940a39a5a928437d25bcea465dd7930e0e81a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/__pycache__/test_wrapping_geometry.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_actuator.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_actuator.py new file mode 100644 index 0000000000000000000000000000000000000000..5d69bccfe7a86d7555242a64923d77cb52cade88 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_actuator.py @@ -0,0 +1,1084 @@ +"""Tests for the ``sympy.physics.mechanics.actuator.py`` module.""" + +import pytest + +from sympy import ( + S, + Matrix, + Symbol, + SympifyError, + sqrt, + Abs, + symbols, + exp, + sign, +) +from sympy.physics.mechanics import ( + ActuatorBase, + Force, + ForceActuator, + KanesMethod, + LinearDamper, + LinearPathway, + LinearSpring, + Particle, + PinJoint, + Point, + ReferenceFrame, + RigidBody, + TorqueActuator, + Vector, + dynamicsymbols, + DuffingSpring, + CoulombKineticFriction, +) + +from sympy.core.expr import Expr as ExprType + +target = RigidBody('target') +reaction = RigidBody('reaction') + + +class TestForceActuator: + + @pytest.fixture(autouse=True) + def _linear_pathway_fixture(self): + self.force = Symbol('F') + self.pA = Point('pA') + self.pB = Point('pB') + self.pathway = LinearPathway(self.pA, self.pB) + self.q1 = dynamicsymbols('q1') + self.q2 = dynamicsymbols('q2') + self.q3 = dynamicsymbols('q3') + self.q1d = dynamicsymbols('q1', 1) + self.q2d = dynamicsymbols('q2', 1) + self.q3d = dynamicsymbols('q3', 1) + self.N = ReferenceFrame('N') + + def test_is_actuator_base_subclass(self): + assert issubclass(ForceActuator, ActuatorBase) + + @pytest.mark.parametrize( + 'force, expected_force', + [ + (1, S.One), + (S.One, S.One), + (Symbol('F'), Symbol('F')), + (dynamicsymbols('F'), dynamicsymbols('F')), + (Symbol('F')**2 + Symbol('F'), Symbol('F')**2 + Symbol('F')), + ] + ) + def test_valid_constructor_force(self, force, expected_force): + instance = ForceActuator(force, self.pathway) + assert isinstance(instance, ForceActuator) + assert hasattr(instance, 'force') + assert isinstance(instance.force, ExprType) + assert instance.force == expected_force + + @pytest.mark.parametrize('force', [None, 'F']) + def test_invalid_constructor_force_not_sympifyable(self, force): + with pytest.raises(SympifyError): + _ = ForceActuator(force, self.pathway) + + @pytest.mark.parametrize( + 'pathway', + [ + LinearPathway(Point('pA'), Point('pB')), + ] + ) + def test_valid_constructor_pathway(self, pathway): + instance = ForceActuator(self.force, pathway) + assert isinstance(instance, ForceActuator) + assert hasattr(instance, 'pathway') + assert isinstance(instance.pathway, LinearPathway) + assert instance.pathway == pathway + + def test_invalid_constructor_pathway_not_pathway_base(self): + with pytest.raises(TypeError): + _ = ForceActuator(self.force, None) + + @pytest.mark.parametrize( + 'property_name, fixture_attr_name', + [ + ('force', 'force'), + ('pathway', 'pathway'), + ] + ) + def test_properties_are_immutable(self, property_name, fixture_attr_name): + instance = ForceActuator(self.force, self.pathway) + value = getattr(self, fixture_attr_name) + with pytest.raises(AttributeError): + setattr(instance, property_name, value) + + def test_repr(self): + actuator = ForceActuator(self.force, self.pathway) + expected = "ForceActuator(F, LinearPathway(pA, pB))" + assert repr(actuator) == expected + + def test_to_loads_static_pathway(self): + self.pB.set_pos(self.pA, 2*self.N.x) + actuator = ForceActuator(self.force, self.pathway) + expected = [ + (self.pA, - self.force*self.N.x), + (self.pB, self.force*self.N.x), + ] + assert actuator.to_loads() == expected + + def test_to_loads_2D_pathway(self): + self.pB.set_pos(self.pA, 2*self.q1*self.N.x) + actuator = ForceActuator(self.force, self.pathway) + expected = [ + (self.pA, - self.force*(self.q1/sqrt(self.q1**2))*self.N.x), + (self.pB, self.force*(self.q1/sqrt(self.q1**2))*self.N.x), + ] + assert actuator.to_loads() == expected + + def test_to_loads_3D_pathway(self): + self.pB.set_pos( + self.pA, + self.q1*self.N.x - self.q2*self.N.y + 2*self.q3*self.N.z, + ) + actuator = ForceActuator(self.force, self.pathway) + length = sqrt(self.q1**2 + self.q2**2 + 4*self.q3**2) + pO_force = ( + - self.force*self.q1*self.N.x/length + + self.force*self.q2*self.N.y/length + - 2*self.force*self.q3*self.N.z/length + ) + pI_force = ( + self.force*self.q1*self.N.x/length + - self.force*self.q2*self.N.y/length + + 2*self.force*self.q3*self.N.z/length + ) + expected = [ + (self.pA, pO_force), + (self.pB, pI_force), + ] + assert actuator.to_loads() == expected + + +class TestLinearSpring: + + @pytest.fixture(autouse=True) + def _linear_spring_fixture(self): + self.stiffness = Symbol('k') + self.l = Symbol('l') + self.pA = Point('pA') + self.pB = Point('pB') + self.pathway = LinearPathway(self.pA, self.pB) + self.q = dynamicsymbols('q') + self.N = ReferenceFrame('N') + + def test_is_force_actuator_subclass(self): + assert issubclass(LinearSpring, ForceActuator) + + def test_is_actuator_base_subclass(self): + assert issubclass(LinearSpring, ActuatorBase) + + @pytest.mark.parametrize( + ( + 'stiffness, ' + 'expected_stiffness, ' + 'equilibrium_length, ' + 'expected_equilibrium_length, ' + 'force' + ), + [ + ( + 1, + S.One, + 0, + S.Zero, + -sqrt(dynamicsymbols('q')**2), + ), + ( + Symbol('k'), + Symbol('k'), + 0, + S.Zero, + -Symbol('k')*sqrt(dynamicsymbols('q')**2), + ), + ( + Symbol('k'), + Symbol('k'), + S.Zero, + S.Zero, + -Symbol('k')*sqrt(dynamicsymbols('q')**2), + ), + ( + Symbol('k'), + Symbol('k'), + Symbol('l'), + Symbol('l'), + -Symbol('k')*(sqrt(dynamicsymbols('q')**2) - Symbol('l')), + ), + ] + ) + def test_valid_constructor( + self, + stiffness, + expected_stiffness, + equilibrium_length, + expected_equilibrium_length, + force, + ): + self.pB.set_pos(self.pA, self.q*self.N.x) + spring = LinearSpring(stiffness, self.pathway, equilibrium_length) + + assert isinstance(spring, LinearSpring) + + assert hasattr(spring, 'stiffness') + assert isinstance(spring.stiffness, ExprType) + assert spring.stiffness == expected_stiffness + + assert hasattr(spring, 'pathway') + assert isinstance(spring.pathway, LinearPathway) + assert spring.pathway == self.pathway + + assert hasattr(spring, 'equilibrium_length') + assert isinstance(spring.equilibrium_length, ExprType) + assert spring.equilibrium_length == expected_equilibrium_length + + assert hasattr(spring, 'force') + assert isinstance(spring.force, ExprType) + assert spring.force == force + + @pytest.mark.parametrize('stiffness', [None, 'k']) + def test_invalid_constructor_stiffness_not_sympifyable(self, stiffness): + with pytest.raises(SympifyError): + _ = LinearSpring(stiffness, self.pathway, self.l) + + def test_invalid_constructor_pathway_not_pathway_base(self): + with pytest.raises(TypeError): + _ = LinearSpring(self.stiffness, None, self.l) + + @pytest.mark.parametrize('equilibrium_length', [None, 'l']) + def test_invalid_constructor_equilibrium_length_not_sympifyable( + self, + equilibrium_length, + ): + with pytest.raises(SympifyError): + _ = LinearSpring(self.stiffness, self.pathway, equilibrium_length) + + @pytest.mark.parametrize( + 'property_name, fixture_attr_name', + [ + ('stiffness', 'stiffness'), + ('pathway', 'pathway'), + ('equilibrium_length', 'l'), + ] + ) + def test_properties_are_immutable(self, property_name, fixture_attr_name): + spring = LinearSpring(self.stiffness, self.pathway, self.l) + value = getattr(self, fixture_attr_name) + with pytest.raises(AttributeError): + setattr(spring, property_name, value) + + @pytest.mark.parametrize( + 'equilibrium_length, expected', + [ + (S.Zero, 'LinearSpring(k, LinearPathway(pA, pB))'), + ( + Symbol('l'), + 'LinearSpring(k, LinearPathway(pA, pB), equilibrium_length=l)', + ), + ] + ) + def test_repr(self, equilibrium_length, expected): + self.pB.set_pos(self.pA, self.q*self.N.x) + spring = LinearSpring(self.stiffness, self.pathway, equilibrium_length) + assert repr(spring) == expected + + def test_to_loads(self): + self.pB.set_pos(self.pA, self.q*self.N.x) + spring = LinearSpring(self.stiffness, self.pathway, self.l) + normal = self.q/sqrt(self.q**2)*self.N.x + pA_force = self.stiffness*(sqrt(self.q**2) - self.l)*normal + pB_force = -self.stiffness*(sqrt(self.q**2) - self.l)*normal + expected = [Force(self.pA, pA_force), Force(self.pB, pB_force)] + loads = spring.to_loads() + + for load, (point, vector) in zip(loads, expected): + assert isinstance(load, Force) + assert load.point == point + assert (load.vector - vector).simplify() == 0 + + +class TestLinearDamper: + + @pytest.fixture(autouse=True) + def _linear_damper_fixture(self): + self.damping = Symbol('c') + self.l = Symbol('l') + self.pA = Point('pA') + self.pB = Point('pB') + self.pathway = LinearPathway(self.pA, self.pB) + self.q = dynamicsymbols('q') + self.dq = dynamicsymbols('q', 1) + self.u = dynamicsymbols('u') + self.N = ReferenceFrame('N') + + def test_is_force_actuator_subclass(self): + assert issubclass(LinearDamper, ForceActuator) + + def test_is_actuator_base_subclass(self): + assert issubclass(LinearDamper, ActuatorBase) + + def test_valid_constructor(self): + self.pB.set_pos(self.pA, self.q*self.N.x) + damper = LinearDamper(self.damping, self.pathway) + + assert isinstance(damper, LinearDamper) + + assert hasattr(damper, 'damping') + assert isinstance(damper.damping, ExprType) + assert damper.damping == self.damping + + assert hasattr(damper, 'pathway') + assert isinstance(damper.pathway, LinearPathway) + assert damper.pathway == self.pathway + + def test_valid_constructor_force(self): + self.pB.set_pos(self.pA, self.q*self.N.x) + damper = LinearDamper(self.damping, self.pathway) + + expected_force = -self.damping*sqrt(self.q**2)*self.dq/self.q + assert hasattr(damper, 'force') + assert isinstance(damper.force, ExprType) + assert damper.force == expected_force + + @pytest.mark.parametrize('damping', [None, 'c']) + def test_invalid_constructor_damping_not_sympifyable(self, damping): + with pytest.raises(SympifyError): + _ = LinearDamper(damping, self.pathway) + + def test_invalid_constructor_pathway_not_pathway_base(self): + with pytest.raises(TypeError): + _ = LinearDamper(self.damping, None) + + @pytest.mark.parametrize( + 'property_name, fixture_attr_name', + [ + ('damping', 'damping'), + ('pathway', 'pathway'), + ] + ) + def test_properties_are_immutable(self, property_name, fixture_attr_name): + damper = LinearDamper(self.damping, self.pathway) + value = getattr(self, fixture_attr_name) + with pytest.raises(AttributeError): + setattr(damper, property_name, value) + + def test_repr(self): + self.pB.set_pos(self.pA, self.q*self.N.x) + damper = LinearDamper(self.damping, self.pathway) + expected = 'LinearDamper(c, LinearPathway(pA, pB))' + assert repr(damper) == expected + + def test_to_loads(self): + self.pB.set_pos(self.pA, self.q*self.N.x) + damper = LinearDamper(self.damping, self.pathway) + direction = self.q**2/self.q**2*self.N.x + pA_force = self.damping*self.dq*direction + pB_force = -self.damping*self.dq*direction + expected = [Force(self.pA, pA_force), Force(self.pB, pB_force)] + assert damper.to_loads() == expected + + +class TestForcedMassSpringDamperModel(): + r"""A single degree of freedom translational forced mass-spring-damper. + + Notes + ===== + + This system is well known to have the governing equation: + + .. math:: + m \ddot{x} = F - k x - c \dot{x} + + where $F$ is an externally applied force, $m$ is the mass of the particle + to which the spring and damper are attached, $k$ is the spring's stiffness, + $c$ is the dampers damping coefficient, and $x$ is the generalized + coordinate representing the system's single (translational) degree of + freedom. + + """ + + @pytest.fixture(autouse=True) + def _force_mass_spring_damper_model_fixture(self): + self.m = Symbol('m') + self.k = Symbol('k') + self.c = Symbol('c') + self.F = Symbol('F') + + self.q = dynamicsymbols('q') + self.dq = dynamicsymbols('q', 1) + self.u = dynamicsymbols('u') + + self.frame = ReferenceFrame('N') + self.origin = Point('pO') + self.origin.set_vel(self.frame, 0) + + self.attachment = Point('pA') + self.attachment.set_pos(self.origin, self.q*self.frame.x) + + self.mass = Particle('mass', self.attachment, self.m) + self.pathway = LinearPathway(self.origin, self.attachment) + + self.kanes_method = KanesMethod( + self.frame, + q_ind=[self.q], + u_ind=[self.u], + kd_eqs=[self.dq - self.u], + ) + self.bodies = [self.mass] + + self.mass_matrix = Matrix([[self.m]]) + self.forcing = Matrix([[self.F - self.c*self.u - self.k*self.q]]) + + def test_force_acuator(self): + stiffness = -self.k*self.pathway.length + spring = ForceActuator(stiffness, self.pathway) + damping = -self.c*self.pathway.extension_velocity + damper = ForceActuator(damping, self.pathway) + + loads = [ + (self.attachment, self.F*self.frame.x), + *spring.to_loads(), + *damper.to_loads(), + ] + self.kanes_method.kanes_equations(self.bodies, loads) + + assert self.kanes_method.mass_matrix == self.mass_matrix + assert self.kanes_method.forcing == self.forcing + + def test_linear_spring_linear_damper(self): + spring = LinearSpring(self.k, self.pathway) + damper = LinearDamper(self.c, self.pathway) + + loads = [ + (self.attachment, self.F*self.frame.x), + *spring.to_loads(), + *damper.to_loads(), + ] + self.kanes_method.kanes_equations(self.bodies, loads) + + assert self.kanes_method.mass_matrix == self.mass_matrix + assert self.kanes_method.forcing == self.forcing + + +class TestTorqueActuator: + + @pytest.fixture(autouse=True) + def _torque_actuator_fixture(self): + self.torque = Symbol('T') + self.N = ReferenceFrame('N') + self.A = ReferenceFrame('A') + self.axis = self.N.z + self.target = RigidBody('target', frame=self.N) + self.reaction = RigidBody('reaction', frame=self.A) + + def test_is_actuator_base_subclass(self): + assert issubclass(TorqueActuator, ActuatorBase) + + @pytest.mark.parametrize( + 'torque', + [ + Symbol('T'), + dynamicsymbols('T'), + Symbol('T')**2 + Symbol('T'), + ] + ) + @pytest.mark.parametrize( + 'target_frame, reaction_frame', + [ + (target.frame, reaction.frame), + (target, reaction.frame), + (target.frame, reaction), + (target, reaction), + ] + ) + def test_valid_constructor_with_reaction( + self, + torque, + target_frame, + reaction_frame, + ): + instance = TorqueActuator( + torque, + self.axis, + target_frame, + reaction_frame, + ) + assert isinstance(instance, TorqueActuator) + + assert hasattr(instance, 'torque') + assert isinstance(instance.torque, ExprType) + assert instance.torque == torque + + assert hasattr(instance, 'axis') + assert isinstance(instance.axis, Vector) + assert instance.axis == self.axis + + assert hasattr(instance, 'target_frame') + assert isinstance(instance.target_frame, ReferenceFrame) + assert instance.target_frame == target.frame + + assert hasattr(instance, 'reaction_frame') + assert isinstance(instance.reaction_frame, ReferenceFrame) + assert instance.reaction_frame == reaction.frame + + @pytest.mark.parametrize( + 'torque', + [ + Symbol('T'), + dynamicsymbols('T'), + Symbol('T')**2 + Symbol('T'), + ] + ) + @pytest.mark.parametrize('target_frame', [target.frame, target]) + def test_valid_constructor_without_reaction(self, torque, target_frame): + instance = TorqueActuator(torque, self.axis, target_frame) + assert isinstance(instance, TorqueActuator) + + assert hasattr(instance, 'torque') + assert isinstance(instance.torque, ExprType) + assert instance.torque == torque + + assert hasattr(instance, 'axis') + assert isinstance(instance.axis, Vector) + assert instance.axis == self.axis + + assert hasattr(instance, 'target_frame') + assert isinstance(instance.target_frame, ReferenceFrame) + assert instance.target_frame == target.frame + + assert hasattr(instance, 'reaction_frame') + assert instance.reaction_frame is None + + @pytest.mark.parametrize('torque', [None, 'T']) + def test_invalid_constructor_torque_not_sympifyable(self, torque): + with pytest.raises(SympifyError): + _ = TorqueActuator(torque, self.axis, self.target) + + @pytest.mark.parametrize('axis', [Symbol('a'), dynamicsymbols('a')]) + def test_invalid_constructor_axis_not_vector(self, axis): + with pytest.raises(TypeError): + _ = TorqueActuator(self.torque, axis, self.target, self.reaction) + + @pytest.mark.parametrize( + 'frames', + [ + (None, ReferenceFrame('child')), + (ReferenceFrame('parent'), True), + (None, RigidBody('child')), + (RigidBody('parent'), True), + ] + ) + def test_invalid_constructor_frames_not_frame(self, frames): + with pytest.raises(TypeError): + _ = TorqueActuator(self.torque, self.axis, *frames) + + @pytest.mark.parametrize( + 'property_name, fixture_attr_name', + [ + ('torque', 'torque'), + ('axis', 'axis'), + ('target_frame', 'target'), + ('reaction_frame', 'reaction'), + ] + ) + def test_properties_are_immutable(self, property_name, fixture_attr_name): + actuator = TorqueActuator( + self.torque, + self.axis, + self.target, + self.reaction, + ) + value = getattr(self, fixture_attr_name) + with pytest.raises(AttributeError): + setattr(actuator, property_name, value) + + def test_repr_without_reaction(self): + actuator = TorqueActuator(self.torque, self.axis, self.target) + expected = 'TorqueActuator(T, axis=N.z, target_frame=N)' + assert repr(actuator) == expected + + def test_repr_with_reaction(self): + actuator = TorqueActuator( + self.torque, + self.axis, + self.target, + self.reaction, + ) + expected = 'TorqueActuator(T, axis=N.z, target_frame=N, reaction_frame=A)' + assert repr(actuator) == expected + + def test_at_pin_joint_constructor(self): + pin_joint = PinJoint( + 'pin', + self.target, + self.reaction, + coordinates=dynamicsymbols('q'), + speeds=dynamicsymbols('u'), + parent_interframe=self.N, + joint_axis=self.axis, + ) + instance = TorqueActuator.at_pin_joint(self.torque, pin_joint) + assert isinstance(instance, TorqueActuator) + + assert hasattr(instance, 'torque') + assert isinstance(instance.torque, ExprType) + assert instance.torque == self.torque + + assert hasattr(instance, 'axis') + assert isinstance(instance.axis, Vector) + assert instance.axis == self.axis + + assert hasattr(instance, 'target_frame') + assert isinstance(instance.target_frame, ReferenceFrame) + assert instance.target_frame == self.A + + assert hasattr(instance, 'reaction_frame') + assert isinstance(instance.reaction_frame, ReferenceFrame) + assert instance.reaction_frame == self.N + + def test_at_pin_joint_pin_joint_not_pin_joint_invalid(self): + with pytest.raises(TypeError): + _ = TorqueActuator.at_pin_joint(self.torque, Symbol('pin')) + + def test_to_loads_without_reaction(self): + actuator = TorqueActuator(self.torque, self.axis, self.target) + expected = [ + (self.N, self.torque*self.axis), + ] + assert actuator.to_loads() == expected + + def test_to_loads_with_reaction(self): + actuator = TorqueActuator( + self.torque, + self.axis, + self.target, + self.reaction, + ) + expected = [ + (self.N, self.torque*self.axis), + (self.A, - self.torque*self.axis), + ] + assert actuator.to_loads() == expected + + +class NonSympifyable: + pass + + +class TestDuffingSpring: + @pytest.fixture(autouse=True) + # Set up common variables that will be used in multiple tests + def _duffing_spring_fixture(self): + self.linear_stiffness = Symbol('beta') + self.nonlinear_stiffness = Symbol('alpha') + self.equilibrium_length = Symbol('l') + self.pA = Point('pA') + self.pB = Point('pB') + self.pathway = LinearPathway(self.pA, self.pB) + self.q = dynamicsymbols('q') + self.N = ReferenceFrame('N') + + # Simples tests to check that DuffingSpring is a subclass of ForceActuator and ActuatorBase + def test_is_force_actuator_subclass(self): + assert issubclass(DuffingSpring, ForceActuator) + + def test_is_actuator_base_subclass(self): + assert issubclass(DuffingSpring, ActuatorBase) + + @pytest.mark.parametrize( + # Create parametrized tests that allows running the same test function multiple times with different sets of arguments + ( + 'linear_stiffness, ' + 'expected_linear_stiffness, ' + 'nonlinear_stiffness, ' + 'expected_nonlinear_stiffness, ' + 'equilibrium_length, ' + 'expected_equilibrium_length, ' + 'force' + ), + [ + ( + 1, + S.One, + 1, + S.One, + 0, + S.Zero, + -sqrt(dynamicsymbols('q')**2)-(sqrt(dynamicsymbols('q')**2))**3, + ), + ( + Symbol('beta'), + Symbol('beta'), + Symbol('alpha'), + Symbol('alpha'), + 0, + S.Zero, + -Symbol('beta')*sqrt(dynamicsymbols('q')**2)-Symbol('alpha')*(sqrt(dynamicsymbols('q')**2))**3, + ), + ( + Symbol('beta'), + Symbol('beta'), + Symbol('alpha'), + Symbol('alpha'), + S.Zero, + S.Zero, + -Symbol('beta')*sqrt(dynamicsymbols('q')**2)-Symbol('alpha')*(sqrt(dynamicsymbols('q')**2))**3, + ), + ( + Symbol('beta'), + Symbol('beta'), + Symbol('alpha'), + Symbol('alpha'), + Symbol('l'), + Symbol('l'), + -Symbol('beta') * (sqrt(dynamicsymbols('q')**2) - Symbol('l')) - Symbol('alpha') * (sqrt(dynamicsymbols('q')**2) - Symbol('l'))**3, + ), + ] + ) + + # Check if DuffingSpring correctly initializes its attributes + # It tests various combinations of linear & nonlinear stiffness, equilibriun length, and the resulting force expression + def test_valid_constructor( + self, + linear_stiffness, + expected_linear_stiffness, + nonlinear_stiffness, + expected_nonlinear_stiffness, + equilibrium_length, + expected_equilibrium_length, + force, + ): + self.pB.set_pos(self.pA, self.q*self.N.x) + spring = DuffingSpring(linear_stiffness, nonlinear_stiffness, self.pathway, equilibrium_length) + + assert isinstance(spring, DuffingSpring) + + assert hasattr(spring, 'linear_stiffness') + assert isinstance(spring.linear_stiffness, ExprType) + assert spring.linear_stiffness == expected_linear_stiffness + + assert hasattr(spring, 'nonlinear_stiffness') + assert isinstance(spring.nonlinear_stiffness, ExprType) + assert spring.nonlinear_stiffness == expected_nonlinear_stiffness + + assert hasattr(spring, 'pathway') + assert isinstance(spring.pathway, LinearPathway) + assert spring.pathway == self.pathway + + assert hasattr(spring, 'equilibrium_length') + assert isinstance(spring.equilibrium_length, ExprType) + assert spring.equilibrium_length == expected_equilibrium_length + + assert hasattr(spring, 'force') + assert isinstance(spring.force, ExprType) + assert spring.force == force + + @pytest.mark.parametrize('linear_stiffness', [None, NonSympifyable()]) + def test_invalid_constructor_linear_stiffness_not_sympifyable(self, linear_stiffness): + with pytest.raises(SympifyError): + _ = DuffingSpring(linear_stiffness, self.nonlinear_stiffness, self.pathway, self.equilibrium_length) + + @pytest.mark.parametrize('nonlinear_stiffness', [None, NonSympifyable()]) + def test_invalid_constructor_nonlinear_stiffness_not_sympifyable(self, nonlinear_stiffness): + with pytest.raises(SympifyError): + _ = DuffingSpring(self.linear_stiffness, nonlinear_stiffness, self.pathway, self.equilibrium_length) + + def test_invalid_constructor_pathway_not_pathway_base(self): + with pytest.raises(TypeError): + _ = DuffingSpring(self.linear_stiffness, self.nonlinear_stiffness, NonSympifyable(), self.equilibrium_length) + + @pytest.mark.parametrize('equilibrium_length', [None, NonSympifyable()]) + def test_invalid_constructor_equilibrium_length_not_sympifyable(self, equilibrium_length): + with pytest.raises(SympifyError): + _ = DuffingSpring(self.linear_stiffness, self.nonlinear_stiffness, self.pathway, equilibrium_length) + + @pytest.mark.parametrize( + 'property_name, fixture_attr_name', + [ + ('linear_stiffness', 'linear_stiffness'), + ('nonlinear_stiffness', 'nonlinear_stiffness'), + ('pathway', 'pathway'), + ('equilibrium_length', 'equilibrium_length') + ] + ) + # Check if certain properties of DuffingSpring object are immutable after initialization + # Ensure that once DuffingSpring is created, its key properties cannot be changed + def test_properties_are_immutable(self, property_name, fixture_attr_name): + spring = DuffingSpring(self.linear_stiffness, self.nonlinear_stiffness, self.pathway, self.equilibrium_length) + with pytest.raises(AttributeError): + setattr(spring, property_name, getattr(self, fixture_attr_name)) + + @pytest.mark.parametrize( + 'equilibrium_length, expected', + [ + (0, 'DuffingSpring(beta, alpha, LinearPathway(pA, pB), equilibrium_length=0)'), + (Symbol('l'), 'DuffingSpring(beta, alpha, LinearPathway(pA, pB), equilibrium_length=l)'), + ] + ) + # Check the __repr__ method of DuffingSpring class + # Check if the actual string representation of DuffingSpring instance matches the expected string for each provided parameter values + def test_repr(self, equilibrium_length, expected): + spring = DuffingSpring(self.linear_stiffness, self.nonlinear_stiffness, self.pathway, equilibrium_length) + assert repr(spring) == expected + + def test_to_loads(self): + self.pB.set_pos(self.pA, self.q*self.N.x) + spring = DuffingSpring(self.linear_stiffness, self.nonlinear_stiffness, self.pathway, self.equilibrium_length) + + # Calculate the displacement from the equilibrium length + displacement = self.q - self.equilibrium_length + + # Make sure this matches the computation in DuffingSpring class + force = -self.linear_stiffness * displacement - self.nonlinear_stiffness * displacement**3 + + # The expected loads on pA and pB due to the spring + expected_loads = [Force(self.pA, force * self.N.x), Force(self.pB, -force * self.N.x)] + + # Compare expected loads to what is returned from DuffingSpring.to_loads() + calculated_loads = spring.to_loads() + for calculated, expected in zip(calculated_loads, expected_loads): + assert calculated.point == expected.point + for dim in self.N: # Assuming self.N is the reference frame + calculated_component = calculated.vector.dot(dim) + expected_component = expected.vector.dot(dim) + # Substitute all symbols with numeric values + substitutions = {self.q: 1, Symbol('l'): 1, Symbol('alpha'): 1, Symbol('beta'): 1} # Add other necessary symbols as needed + diff = (calculated_component - expected_component).subs(substitutions).evalf() + # Check if the absolute value of the difference is below a threshold + assert Abs(diff) < 1e-9, f"The forces do not match. Difference: {diff}" + +class TestCoulombKineticFriction: + @pytest.fixture(autouse=True) + def _block_on_surface(self): + """A block sliding on a surface. + + Notes + ===== + This test validates the correctness of the CoulombKineticFriction by simulating + a block sliding on a surface with the Coulomb kinetic friction force. + The test covers scenarios with both positive and negative velocities. + + """ + + # Mass, gravity constant, friction coefficient, coefficient of Stribeck friction, viscous_coefficient + self.m, self.g, self.mu_k, self.mu_s, self.v_s, self.sigma, self.F = symbols('m g mu_k mu_s v_s sigma F', real=True) + + def test_block_on_surface_default(self): + # General Case + q = dynamicsymbols('q') + + N = ReferenceFrame('N') + O = Point('O') + P = O.locatenew('P', q * N.x) + O.set_vel(N, 0) + P.set_vel(N, q.diff() * N.x) + + pathway = LinearPathway(O, P) + friction = CoulombKineticFriction(self.mu_k, self.m * self.g, pathway) + expected_general = [Force(point=O, force=self.g * self.m * self.mu_k * q * sign(sqrt(q**2) * q.diff()/q)/sqrt(q**2) * N.x), + Force(point=P, force=-self.g * self.m * self.mu_k * q * sign(sqrt(q**2) * q.diff()/q)/sqrt(q**2) * N.x)] + + assert friction.to_loads() == expected_general + + # Positive + q = dynamicsymbols('q', positive=True) + + N = ReferenceFrame('N') + O = Point('O') + P = O.locatenew('P', q * N.x) + O.set_vel(N, 0) + P.set_vel(N, q.diff() * N.x) + + pathway = LinearPathway(O, P) + friction = CoulombKineticFriction(self.mu_k, self.m * self.g, pathway) + expected_positive = [Force(point=O, force=self.g * self.m * self.mu_k * sign(q.diff()) * N.x), + Force(point=P, force=-self.g * self.m * self.mu_k * sign(q.diff()) * N.x)] + + assert friction.to_loads() == expected_positive + + # Negative + q = dynamicsymbols('q', positive=False) + + N = ReferenceFrame('N') + O = Point('O') + P = O.locatenew('P', q * N.x) + O.set_vel(N, 0) + P.set_vel(N, q.diff() * N.x) + + pathway = LinearPathway(O, P) + friction = CoulombKineticFriction(self.mu_k, self.m * self.g, pathway) + expected_negative = [Force(point=O, force=self.g * self.m * self.mu_k * q * sign(sqrt(q**2) * q.diff()/q)/sqrt(q**2)*N.x), + Force(point=P, force=-self.g * self.m * self.mu_k * q * sign(sqrt(q**2) * q.diff()/q)/sqrt(q**2)*N.x)] + + assert friction.to_loads() == expected_negative + + def test_block_on_surface_viscous(self): + # General Case + q = dynamicsymbols('q') + + N = ReferenceFrame('N') + O = Point('O') + P = O.locatenew('P', q * N.x) + O.set_vel(N, 0) + P.set_vel(N, q.diff() * N.x) + + pathway = LinearPathway(O, P) + friction = CoulombKineticFriction(self.mu_k, self.m * self.g, pathway, sigma=self.sigma) + expected_general = [Force(point=O, force=(self.g * self.m * self.mu_k * sign(sqrt(q**2) * q.diff()/q) + self.sigma * sqrt(q**2) * q.diff()/q) * q/sqrt(q**2) * N.x), + Force(point=P, force=(-self.g * self.m * self.mu_k * sign(sqrt(q**2) * q.diff()/q) - self.sigma * sqrt(q**2) * q.diff()/q) * q/sqrt(q**2) * N.x)] + + assert friction.to_loads() == expected_general + + # Positive + q = dynamicsymbols('q', positive=True) + + N = ReferenceFrame('N') + O = Point('O') + P = O.locatenew('P', q * N.x) + O.set_vel(N, 0) + P.set_vel(N, q.diff() * N.x) + + pathway = LinearPathway(O, P) + friction = CoulombKineticFriction(self.mu_k, self.m * self.g, pathway, sigma=self.sigma) + expected_positive = [Force(point=O, force=(self.g * self.m * self.mu_k * sign(q.diff()) + self.sigma * q.diff()) * N.x), + Force(point=P, force=(-self.g * self.m * self.mu_k * sign(q.diff()) - self.sigma * q.diff()) * N.x)] + + assert friction.to_loads() == expected_positive + + # Negative + q = dynamicsymbols('q', positive=False) + + N = ReferenceFrame('N') + O = Point('O') + P = O.locatenew('P', q * N.x) + O.set_vel(N, 0) + P.set_vel(N, q.diff() * N.x) + + pathway = LinearPathway(O, P) + friction = CoulombKineticFriction(self.mu_k, self.m * self.g, pathway, sigma=self.sigma) + expected_negative = [Force(point=O, force=(self.g * self.m * self.mu_k * sign(sqrt(q**2) * q.diff()/q) + self.sigma * sqrt(q**2) * q.diff()/q) * q/sqrt(q**2) * N.x), + Force(point=P, force=(-self.g * self.m * self.mu_k * sign(sqrt(q**2) * q.diff()/q) - self.sigma * sqrt(q**2) * q.diff()/q) * q/sqrt(q**2) * N.x)] + + assert friction.to_loads() == expected_negative + + def test_block_on_surface_stribeck(self): + # General Case + q = dynamicsymbols('q') + + N = ReferenceFrame('N') + O = Point('O') + P = O.locatenew('P', q * N.x) + O.set_vel(N, 0) + P.set_vel(N, q.diff() * N.x) + + pathway = LinearPathway(O, P) + friction = CoulombKineticFriction(self.mu_k, self.m * self.g, pathway, v_s=self.v_s, mu_s=self.mu_s) + expected_general = [Force(point=O, force=(self.g * self.m * self.mu_k + (-self.g * self.m * self.mu_k + self.g * self.m * self.mu_s) * exp(-q.diff()**2/self.v_s**2)) * q * sign(sqrt(q**2) * q.diff()/q)/sqrt(q**2) * N.x), + Force(point=P, force=- (self.g * self.m * self.mu_k + (-self.g * self.m * self.mu_k + self.g * self.m * self.mu_s) * exp(-q.diff()**2/self.v_s**2)) * q * sign(sqrt(q**2) * q.diff()/q)/sqrt(q**2) * N.x)] + + assert friction.to_loads() == expected_general + + # Positive + q = dynamicsymbols('q', positive=True) + + N = ReferenceFrame('N') + O = Point('O') + P = O.locatenew('P', q * N.x) + O.set_vel(N, 0) + P.set_vel(N, q.diff() * N.x) + + pathway = LinearPathway(O, P) + friction = CoulombKineticFriction(self.mu_k, self.m * self.g, pathway, v_s=self.v_s, mu_s=self.mu_s) + expected_positive = [Force(point=O, force=(self.g * self.m * self.mu_k + (-self.g * self.m * self.mu_k + self.g * self.m * self.mu_s) * exp(-q.diff()**2/self.v_s**2)) * sign(q.diff()) * N.x), + Force(point=P, force=- (self.g * self.m * self.mu_k + (-self.g * self.m * self.mu_k + self.g * self.m * self.mu_s) * exp(-q.diff()**2/self.v_s**2)) * sign(q.diff()) * N.x)] + + assert friction.to_loads() == expected_positive + + # Negative + q = dynamicsymbols('q', positive=False) + + N = ReferenceFrame('N') + O = Point('O') + P = O.locatenew('P', q * N.x) + O.set_vel(N, 0) + P.set_vel(N, q.diff() * N.x) + + pathway = LinearPathway(O, P) + friction = CoulombKineticFriction(self.mu_k, self.m * self.g, pathway, v_s=self.v_s, mu_s=self.mu_s) + expected_negative = [Force(point=O, force=(self.g * self.m * self.mu_k + (-self.g * self.m * self.mu_k + self.g * self.m * self.mu_s) * exp(-q.diff()**2/self.v_s**2)) * q * sign(sqrt(q**2) * q.diff()/q)/sqrt(q**2) * N.x), + Force(point=P, force=- (self.g * self.m * self.mu_k + (-self.g * self.m * self.mu_k + self.g * self.m * self.mu_s) * exp(-q.diff()**2/self.v_s**2)) * q * sign(sqrt(q**2) * q.diff()/q)/sqrt(q**2) * N.x)] + + assert friction.to_loads() == expected_negative + + def test_block_on_surface_all(self): + # General Case + q = dynamicsymbols('q') + + N = ReferenceFrame('N') + O = Point('O') + P = O.locatenew('P', q * N.x) + O.set_vel(N, 0) + P.set_vel(N, q.diff() * N.x) + + pathway = LinearPathway(O, P) + friction = CoulombKineticFriction(self.mu_k, self.m * self.g, pathway, v_s=self.v_s, sigma=self.sigma, mu_s=self.mu_s) + expected_general = [Force(point=O, force=(self.sigma * sqrt(q**2) * q.diff()/q + (self.g * self.m * self.mu_k + (-self.g * self.m * self.mu_k + self.g * self.m * self.mu_s) * exp(-q.diff()**2/self.v_s**2)) * sign(sqrt(q**2) * q.diff()/q)) * q/sqrt(q**2) * N.x), + Force(point=P, force=(-self.sigma * sqrt(q**2) * q.diff()/q - (self.g * self.m * self.mu_k + (-self.g * self.m * self.mu_k + self.g * self.m * self.mu_s) * exp(-q.diff()**2/self.v_s**2)) * sign(sqrt(q**2) * q.diff()/q)) * q/sqrt(q**2) * N.x)] + + assert friction.to_loads() == expected_general + + # Positive + q = dynamicsymbols('q', positive=True) + + N = ReferenceFrame('N') + O = Point('O') + P = O.locatenew('P', q * N.x) + O.set_vel(N, 0) + P.set_vel(N, q.diff() * N.x) + + pathway = LinearPathway(O, P) + friction = CoulombKineticFriction(self.mu_k, self.m * self.g, pathway, v_s=self.v_s, sigma=self.sigma, mu_s=self.mu_s) + expected_positive = [Force(point=O, force=(self.sigma * q.diff() + (self.g * self.m * self.mu_k + (-self.g * self.m * self.mu_k + self.g * self.m * self.mu_s) * exp(-q.diff()**2/self.v_s**2)) * sign(q.diff())) * N.x), + Force(point=P, force=(-self.sigma * q.diff() - (self.g * self.m * self.mu_k + (-self.g * self.m * self.mu_k + self.g * self.m * self.mu_s) * exp(-q.diff()**2/self.v_s**2)) * sign(q.diff())) * N.x)] + + assert friction.to_loads() == expected_positive + + # Negative + q = dynamicsymbols('q', positive=False) + + N = ReferenceFrame('N') + O = Point('O') + P = O.locatenew('P', q * N.x) + O.set_vel(N, 0) + P.set_vel(N, q.diff() * N.x) + + pathway = LinearPathway(O, P) + friction = CoulombKineticFriction(self.mu_k, self.m * self.g, pathway, v_s=self.v_s, sigma=self.sigma, mu_s=self.mu_s) + expected_negative = [Force(point=O, force=(self.sigma * sqrt(q**2) * q.diff()/q + (self.g * self.m * self.mu_k + (-self.g * self.m * self.mu_k + self.g * self.m * self.mu_s) * exp(-q.diff()**2/self.v_s**2)) * sign(sqrt(q**2) * q.diff()/q)) * q/sqrt(q**2) * N.x), + Force(point=P, force=(-self.sigma * sqrt(q**2) * q.diff()/q - (self.g * self.m * self.mu_k + (-self.g * self.m * self.mu_k + self.g * self.m * self.mu_s) * exp(-q.diff()**2/self.v_s**2)) * sign(sqrt(q**2) * q.diff()/q)) * q/sqrt(q**2) * N.x)] + + assert friction.to_loads() == expected_negative + + def test_normal_force_zero(self): + q = dynamicsymbols('q') + + N = ReferenceFrame('N') + O = Point('O') + P = O.locatenew('P', q * N.x) + O.set_vel(N, 0) + P.set_vel(N, q.diff() * N.x) + + pathway = LinearPathway(O, P) + friction = CoulombKineticFriction( + self.mu_k, + 0, + pathway + ) + assert friction.force == 0 diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_body.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_body.py new file mode 100644 index 0000000000000000000000000000000000000000..2d59d747400652a0cbb081f4afc5ae4ebaa4db85 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_body.py @@ -0,0 +1,340 @@ +from sympy import (Symbol, symbols, sin, cos, Matrix, zeros, + simplify) +from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols, Dyadic +from sympy.physics.mechanics import inertia, Body +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +def test_default(): + with warns_deprecated_sympy(): + body = Body('body') + assert body.name == 'body' + assert body.loads == [] + point = Point('body_masscenter') + point.set_vel(body.frame, 0) + com = body.masscenter + frame = body.frame + assert com.vel(frame) == point.vel(frame) + assert body.mass == Symbol('body_mass') + ixx, iyy, izz = symbols('body_ixx body_iyy body_izz') + ixy, iyz, izx = symbols('body_ixy body_iyz body_izx') + assert body.inertia == (inertia(body.frame, ixx, iyy, izz, ixy, iyz, izx), + body.masscenter) + + +def test_custom_rigid_body(): + # Body with RigidBody. + rigidbody_masscenter = Point('rigidbody_masscenter') + rigidbody_mass = Symbol('rigidbody_mass') + rigidbody_frame = ReferenceFrame('rigidbody_frame') + body_inertia = inertia(rigidbody_frame, 1, 0, 0) + with warns_deprecated_sympy(): + rigid_body = Body('rigidbody_body', rigidbody_masscenter, + rigidbody_mass, rigidbody_frame, body_inertia) + com = rigid_body.masscenter + frame = rigid_body.frame + rigidbody_masscenter.set_vel(rigidbody_frame, 0) + assert com.vel(frame) == rigidbody_masscenter.vel(frame) + assert com.pos_from(com) == rigidbody_masscenter.pos_from(com) + + assert rigid_body.mass == rigidbody_mass + assert rigid_body.inertia == (body_inertia, rigidbody_masscenter) + + assert rigid_body.is_rigidbody + + assert hasattr(rigid_body, 'masscenter') + assert hasattr(rigid_body, 'mass') + assert hasattr(rigid_body, 'frame') + assert hasattr(rigid_body, 'inertia') + + +def test_particle_body(): + # Body with Particle + particle_masscenter = Point('particle_masscenter') + particle_mass = Symbol('particle_mass') + particle_frame = ReferenceFrame('particle_frame') + with warns_deprecated_sympy(): + particle_body = Body('particle_body', particle_masscenter, + particle_mass, particle_frame) + com = particle_body.masscenter + frame = particle_body.frame + particle_masscenter.set_vel(particle_frame, 0) + assert com.vel(frame) == particle_masscenter.vel(frame) + assert com.pos_from(com) == particle_masscenter.pos_from(com) + + assert particle_body.mass == particle_mass + assert not hasattr(particle_body, "_inertia") + assert hasattr(particle_body, 'frame') + assert hasattr(particle_body, 'masscenter') + assert hasattr(particle_body, 'mass') + assert particle_body.inertia == (Dyadic(0), particle_body.masscenter) + assert particle_body.central_inertia == Dyadic(0) + assert not particle_body.is_rigidbody + + particle_body.central_inertia = inertia(particle_frame, 1, 1, 1) + assert particle_body.central_inertia == inertia(particle_frame, 1, 1, 1) + assert particle_body.is_rigidbody + + with warns_deprecated_sympy(): + particle_body = Body('particle_body', mass=particle_mass) + assert not particle_body.is_rigidbody + point = particle_body.masscenter.locatenew('point', particle_body.x) + point_inertia = particle_mass * inertia(particle_body.frame, 0, 1, 1) + particle_body.inertia = (point_inertia, point) + assert particle_body.inertia == (point_inertia, point) + assert particle_body.central_inertia == Dyadic(0) + assert particle_body.is_rigidbody + + +def test_particle_body_add_force(): + # Body with Particle + particle_masscenter = Point('particle_masscenter') + particle_mass = Symbol('particle_mass') + particle_frame = ReferenceFrame('particle_frame') + with warns_deprecated_sympy(): + particle_body = Body('particle_body', particle_masscenter, + particle_mass, particle_frame) + + a = Symbol('a') + force_vector = a * particle_body.frame.x + particle_body.apply_force(force_vector, particle_body.masscenter) + assert len(particle_body.loads) == 1 + point = particle_body.masscenter.locatenew( + particle_body._name + '_point0', 0) + point.set_vel(particle_body.frame, 0) + force_point = particle_body.loads[0][0] + + frame = particle_body.frame + assert force_point.vel(frame) == point.vel(frame) + assert force_point.pos_from(force_point) == point.pos_from(force_point) + + assert particle_body.loads[0][1] == force_vector + + +def test_body_add_force(): + # Body with RigidBody. + rigidbody_masscenter = Point('rigidbody_masscenter') + rigidbody_mass = Symbol('rigidbody_mass') + rigidbody_frame = ReferenceFrame('rigidbody_frame') + body_inertia = inertia(rigidbody_frame, 1, 0, 0) + with warns_deprecated_sympy(): + rigid_body = Body('rigidbody_body', rigidbody_masscenter, + rigidbody_mass, rigidbody_frame, body_inertia) + + l = Symbol('l') + Fa = Symbol('Fa') + point = rigid_body.masscenter.locatenew( + 'rigidbody_body_point0', + l * rigid_body.frame.x) + point.set_vel(rigid_body.frame, 0) + force_vector = Fa * rigid_body.frame.z + # apply_force with point + rigid_body.apply_force(force_vector, point) + assert len(rigid_body.loads) == 1 + force_point = rigid_body.loads[0][0] + frame = rigid_body.frame + assert force_point.vel(frame) == point.vel(frame) + assert force_point.pos_from(force_point) == point.pos_from(force_point) + assert rigid_body.loads[0][1] == force_vector + # apply_force without point + rigid_body.apply_force(force_vector) + assert len(rigid_body.loads) == 2 + assert rigid_body.loads[1][1] == force_vector + # passing something else than point + raises(TypeError, lambda: rigid_body.apply_force(force_vector, 0)) + raises(TypeError, lambda: rigid_body.apply_force(0)) + +def test_body_add_torque(): + with warns_deprecated_sympy(): + body = Body('body') + torque_vector = body.frame.x + body.apply_torque(torque_vector) + + assert len(body.loads) == 1 + assert body.loads[0] == (body.frame, torque_vector) + raises(TypeError, lambda: body.apply_torque(0)) + +def test_body_masscenter_vel(): + with warns_deprecated_sympy(): + A = Body('A') + N = ReferenceFrame('N') + with warns_deprecated_sympy(): + B = Body('B', frame=N) + A.masscenter.set_vel(N, N.z) + assert A.masscenter_vel(B) == N.z + assert A.masscenter_vel(N) == N.z + +def test_body_ang_vel(): + with warns_deprecated_sympy(): + A = Body('A') + N = ReferenceFrame('N') + with warns_deprecated_sympy(): + B = Body('B', frame=N) + A.frame.set_ang_vel(N, N.y) + assert A.ang_vel_in(B) == N.y + assert B.ang_vel_in(A) == -N.y + assert A.ang_vel_in(N) == N.y + +def test_body_dcm(): + with warns_deprecated_sympy(): + A = Body('A') + B = Body('B') + A.frame.orient_axis(B.frame, B.frame.z, 10) + assert A.dcm(B) == Matrix([[cos(10), sin(10), 0], [-sin(10), cos(10), 0], [0, 0, 1]]) + assert A.dcm(B.frame) == Matrix([[cos(10), sin(10), 0], [-sin(10), cos(10), 0], [0, 0, 1]]) + +def test_body_axis(): + N = ReferenceFrame('N') + with warns_deprecated_sympy(): + B = Body('B', frame=N) + assert B.x == N.x + assert B.y == N.y + assert B.z == N.z + +def test_apply_force_multiple_one_point(): + a, b = symbols('a b') + P = Point('P') + with warns_deprecated_sympy(): + B = Body('B') + f1 = a*B.x + f2 = b*B.y + B.apply_force(f1, P) + assert B.loads == [(P, f1)] + B.apply_force(f2, P) + assert B.loads == [(P, f1+f2)] + +def test_apply_force(): + f, g = symbols('f g') + q, x, v1, v2 = dynamicsymbols('q x v1 v2') + P1 = Point('P1') + P2 = Point('P2') + with warns_deprecated_sympy(): + B1 = Body('B1') + B2 = Body('B2') + N = ReferenceFrame('N') + + P1.set_vel(B1.frame, v1*B1.x) + P2.set_vel(B2.frame, v2*B2.x) + force = f*q*N.z # time varying force + + B1.apply_force(force, P1, B2, P2) #applying equal and opposite force on moving points + assert B1.loads == [(P1, force)] + assert B2.loads == [(P2, -force)] + + g1 = B1.mass*g*N.y + g2 = B2.mass*g*N.y + + B1.apply_force(g1) #applying gravity on B1 masscenter + B2.apply_force(g2) #applying gravity on B2 masscenter + + assert B1.loads == [(P1,force), (B1.masscenter, g1)] + assert B2.loads == [(P2, -force), (B2.masscenter, g2)] + + force2 = x*N.x + + B1.apply_force(force2, reaction_body=B2) #Applying time varying force on masscenter + + assert B1.loads == [(P1, force), (B1.masscenter, force2+g1)] + assert B2.loads == [(P2, -force), (B2.masscenter, -force2+g2)] + +def test_apply_torque(): + t = symbols('t') + q = dynamicsymbols('q') + with warns_deprecated_sympy(): + B1 = Body('B1') + B2 = Body('B2') + N = ReferenceFrame('N') + torque = t*q*N.x + + B1.apply_torque(torque, B2) #Applying equal and opposite torque + assert B1.loads == [(B1.frame, torque)] + assert B2.loads == [(B2.frame, -torque)] + + torque2 = t*N.y + B1.apply_torque(torque2) + assert B1.loads == [(B1.frame, torque+torque2)] + +def test_clear_load(): + a = symbols('a') + P = Point('P') + with warns_deprecated_sympy(): + B = Body('B') + force = a*B.z + B.apply_force(force, P) + assert B.loads == [(P, force)] + B.clear_loads() + assert B.loads == [] + +def test_remove_load(): + P1 = Point('P1') + P2 = Point('P2') + with warns_deprecated_sympy(): + B = Body('B') + f1 = B.x + f2 = B.y + B.apply_force(f1, P1) + B.apply_force(f2, P2) + assert B.loads == [(P1, f1), (P2, f2)] + B.remove_load(P2) + assert B.loads == [(P1, f1)] + B.apply_torque(f1.cross(f2)) + assert B.loads == [(P1, f1), (B.frame, f1.cross(f2))] + B.remove_load() + assert B.loads == [(P1, f1)] + +def test_apply_loads_on_multi_degree_freedom_holonomic_system(): + """Example based on: https://pydy.readthedocs.io/en/latest/examples/multidof-holonomic.html""" + with warns_deprecated_sympy(): + W = Body('W') #Wall + B = Body('B') #Block + P = Body('P') #Pendulum + b = Body('b') #bob + q1, q2 = dynamicsymbols('q1 q2') #generalized coordinates + k, c, g, kT = symbols('k c g kT') #constants + F, T = dynamicsymbols('F T') #Specified forces + + #Applying forces + B.apply_force(F*W.x) + W.apply_force(k*q1*W.x, reaction_body=B) #Spring force + W.apply_force(c*q1.diff()*W.x, reaction_body=B) #dampner + P.apply_force(P.mass*g*W.y) + b.apply_force(b.mass*g*W.y) + + #Applying torques + P.apply_torque(kT*q2*W.z, reaction_body=b) + P.apply_torque(T*W.z) + + assert B.loads == [(B.masscenter, (F - k*q1 - c*q1.diff())*W.x)] + assert P.loads == [(P.masscenter, P.mass*g*W.y), (P.frame, (T + kT*q2)*W.z)] + assert b.loads == [(b.masscenter, b.mass*g*W.y), (b.frame, -kT*q2*W.z)] + assert W.loads == [(W.masscenter, (c*q1.diff() + k*q1)*W.x)] + + +def test_parallel_axis(): + N = ReferenceFrame('N') + m, Ix, Iy, Iz, a, b = symbols('m, I_x, I_y, I_z, a, b') + Io = inertia(N, Ix, Iy, Iz) + # Test RigidBody + o = Point('o') + p = o.locatenew('p', a * N.x + b * N.y) + with warns_deprecated_sympy(): + R = Body('R', masscenter=o, frame=N, mass=m, central_inertia=Io) + Ip = R.parallel_axis(p) + Ip_expected = inertia(N, Ix + m * b**2, Iy + m * a**2, + Iz + m * (a**2 + b**2), ixy=-m * a * b) + assert Ip == Ip_expected + # Reference frame from which the parallel axis is viewed should not matter + A = ReferenceFrame('A') + A.orient_axis(N, N.z, 1) + assert simplify( + (R.parallel_axis(p, A) - Ip_expected).to_matrix(A)) == zeros(3, 3) + # Test Particle + o = Point('o') + p = o.locatenew('p', a * N.x + b * N.y) + with warns_deprecated_sympy(): + P = Body('P', masscenter=o, mass=m, frame=N) + Ip = P.parallel_axis(p, N) + Ip_expected = inertia(N, m * b ** 2, m * a ** 2, m * (a ** 2 + b ** 2), + ixy=-m * a * b) + assert not P.is_rigidbody + assert Ip == Ip_expected diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_functions.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..bae6b19b2807dca1632942bd3717e29d214eb269 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_functions.py @@ -0,0 +1,262 @@ +from sympy import sin, cos, tan, pi, symbols, Matrix, S, Function +from sympy.physics.mechanics import (Particle, Point, ReferenceFrame, + RigidBody) +from sympy.physics.mechanics import (angular_momentum, dynamicsymbols, + kinetic_energy, linear_momentum, + outer, potential_energy, msubs, + find_dynamicsymbols, Lagrangian) + +from sympy.physics.mechanics.functions import ( + center_of_mass, _validate_coordinates, _parse_linear_solver) +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +q1, q2, q3, q4, q5 = symbols('q1 q2 q3 q4 q5') +N = ReferenceFrame('N') +A = N.orientnew('A', 'Axis', [q1, N.z]) +B = A.orientnew('B', 'Axis', [q2, A.x]) +C = B.orientnew('C', 'Axis', [q3, B.y]) + + +def test_linear_momentum(): + N = ReferenceFrame('N') + Ac = Point('Ac') + Ac.set_vel(N, 25 * N.y) + I = outer(N.x, N.x) + A = RigidBody('A', Ac, N, 20, (I, Ac)) + P = Point('P') + Pa = Particle('Pa', P, 1) + Pa.point.set_vel(N, 10 * N.x) + raises(TypeError, lambda: linear_momentum(A, A, Pa)) + raises(TypeError, lambda: linear_momentum(N, N, Pa)) + assert linear_momentum(N, A, Pa) == 10 * N.x + 500 * N.y + + +def test_angular_momentum_and_linear_momentum(): + """A rod with length 2l, centroidal inertia I, and mass M along with a + particle of mass m fixed to the end of the rod rotate with an angular rate + of omega about point O which is fixed to the non-particle end of the rod. + The rod's reference frame is A and the inertial frame is N.""" + m, M, l, I = symbols('m, M, l, I') + omega = dynamicsymbols('omega') + N = ReferenceFrame('N') + a = ReferenceFrame('a') + O = Point('O') + Ac = O.locatenew('Ac', l * N.x) + P = Ac.locatenew('P', l * N.x) + O.set_vel(N, 0 * N.x) + a.set_ang_vel(N, omega * N.z) + Ac.v2pt_theory(O, N, a) + P.v2pt_theory(O, N, a) + Pa = Particle('Pa', P, m) + A = RigidBody('A', Ac, a, M, (I * outer(N.z, N.z), Ac)) + expected = 2 * m * omega * l * N.y + M * l * omega * N.y + assert linear_momentum(N, A, Pa) == expected + raises(TypeError, lambda: angular_momentum(N, N, A, Pa)) + raises(TypeError, lambda: angular_momentum(O, O, A, Pa)) + raises(TypeError, lambda: angular_momentum(O, N, O, Pa)) + expected = (I + M * l**2 + 4 * m * l**2) * omega * N.z + assert angular_momentum(O, N, A, Pa) == expected + + +def test_kinetic_energy(): + m, M, l1 = symbols('m M l1') + omega = dynamicsymbols('omega') + N = ReferenceFrame('N') + O = Point('O') + O.set_vel(N, 0 * N.x) + Ac = O.locatenew('Ac', l1 * N.x) + P = Ac.locatenew('P', l1 * N.x) + a = ReferenceFrame('a') + a.set_ang_vel(N, omega * N.z) + Ac.v2pt_theory(O, N, a) + P.v2pt_theory(O, N, a) + Pa = Particle('Pa', P, m) + I = outer(N.z, N.z) + A = RigidBody('A', Ac, a, M, (I, Ac)) + raises(TypeError, lambda: kinetic_energy(Pa, Pa, A)) + raises(TypeError, lambda: kinetic_energy(N, N, A)) + assert 0 == (kinetic_energy(N, Pa, A) - (M*l1**2*omega**2/2 + + 2*l1**2*m*omega**2 + omega**2/2)).expand() + + +def test_potential_energy(): + m, M, l1, g, h, H = symbols('m M l1 g h H') + omega = dynamicsymbols('omega') + N = ReferenceFrame('N') + O = Point('O') + O.set_vel(N, 0 * N.x) + Ac = O.locatenew('Ac', l1 * N.x) + P = Ac.locatenew('P', l1 * N.x) + a = ReferenceFrame('a') + a.set_ang_vel(N, omega * N.z) + Ac.v2pt_theory(O, N, a) + P.v2pt_theory(O, N, a) + Pa = Particle('Pa', P, m) + I = outer(N.z, N.z) + A = RigidBody('A', Ac, a, M, (I, Ac)) + Pa.potential_energy = m * g * h + A.potential_energy = M * g * H + assert potential_energy(A, Pa) == m * g * h + M * g * H + + +def test_Lagrangian(): + M, m, g, h = symbols('M m g h') + N = ReferenceFrame('N') + O = Point('O') + O.set_vel(N, 0 * N.x) + P = O.locatenew('P', 1 * N.x) + P.set_vel(N, 10 * N.x) + Pa = Particle('Pa', P, 1) + Ac = O.locatenew('Ac', 2 * N.y) + Ac.set_vel(N, 5 * N.y) + a = ReferenceFrame('a') + a.set_ang_vel(N, 10 * N.z) + I = outer(N.z, N.z) + A = RigidBody('A', Ac, a, 20, (I, Ac)) + Pa.potential_energy = m * g * h + A.potential_energy = M * g * h + raises(TypeError, lambda: Lagrangian(A, A, Pa)) + raises(TypeError, lambda: Lagrangian(N, N, Pa)) + + +def test_msubs(): + a, b = symbols('a, b') + x, y, z = dynamicsymbols('x, y, z') + # Test simple substitution + expr = Matrix([[a*x + b, x*y.diff() + y], + [x.diff().diff(), z + sin(z.diff())]]) + sol = Matrix([[a + b, y], + [x.diff().diff(), 1]]) + sd = {x: 1, z: 1, z.diff(): 0, y.diff(): 0} + assert msubs(expr, sd) == sol + # Test smart substitution + expr = cos(x + y)*tan(x + y) + b*x.diff() + sd = {x: 0, y: pi/2, x.diff(): 1} + assert msubs(expr, sd, smart=True) == b + 1 + N = ReferenceFrame('N') + v = x*N.x + y*N.y + d = x*(N.x|N.x) + y*(N.y|N.y) + v_sol = 1*N.y + d_sol = 1*(N.y|N.y) + sd = {x: 0, y: 1} + assert msubs(v, sd) == v_sol + assert msubs(d, sd) == d_sol + + +def test_find_dynamicsymbols(): + a, b = symbols('a, b') + x, y, z = dynamicsymbols('x, y, z') + expr = Matrix([[a*x + b, x*y.diff() + y], + [x.diff().diff(), z + sin(z.diff())]]) + # Test finding all dynamicsymbols + sol = {x, y.diff(), y, x.diff().diff(), z, z.diff()} + assert find_dynamicsymbols(expr) == sol + # Test finding all but those in sym_list + exclude_list = [x, y, z] + sol = {y.diff(), x.diff().diff(), z.diff()} + assert find_dynamicsymbols(expr, exclude=exclude_list) == sol + # Test finding all dynamicsymbols in a vector with a given reference frame + d, e, f = dynamicsymbols('d, e, f') + A = ReferenceFrame('A') + v = d * A.x + e * A.y + f * A.z + sol = {d, e, f} + assert find_dynamicsymbols(v, reference_frame=A) == sol + # Test if a ValueError is raised on supplying only a vector as input + raises(ValueError, lambda: find_dynamicsymbols(v)) + + +# This function tests the center_of_mass() function +# that was added in PR #14758 to compute the center of +# mass of a system of bodies. +def test_center_of_mass(): + a = ReferenceFrame('a') + m = symbols('m', real=True) + p1 = Particle('p1', Point('p1_pt'), S.One) + p2 = Particle('p2', Point('p2_pt'), S(2)) + p3 = Particle('p3', Point('p3_pt'), S(3)) + p4 = Particle('p4', Point('p4_pt'), m) + b_f = ReferenceFrame('b_f') + b_cm = Point('b_cm') + mb = symbols('mb') + b = RigidBody('b', b_cm, b_f, mb, (outer(b_f.x, b_f.x), b_cm)) + p2.point.set_pos(p1.point, a.x) + p3.point.set_pos(p1.point, a.x + a.y) + p4.point.set_pos(p1.point, a.y) + b.masscenter.set_pos(p1.point, a.y + a.z) + point_o=Point('o') + point_o.set_pos(p1.point, center_of_mass(p1.point, p1, p2, p3, p4, b)) + expr = 5/(m + mb + 6)*a.x + (m + mb + 3)/(m + mb + 6)*a.y + mb/(m + mb + 6)*a.z + assert point_o.pos_from(p1.point)-expr == 0 + + +def test_validate_coordinates(): + q1, q2, q3, u1, u2, u3, ua1, ua2, ua3 = dynamicsymbols('q1:4 u1:4 ua1:4') + s1, s2, s3 = symbols('s1:4') + # Test normal + _validate_coordinates([q1, q2, q3], [u1, u2, u3], + u_auxiliary=[ua1, ua2, ua3]) + # Test not equal number of coordinates and speeds + _validate_coordinates([q1, q2]) + _validate_coordinates([q1, q2], [u1]) + _validate_coordinates(speeds=[u1, u2]) + # Test duplicate + _validate_coordinates([q1, q2, q2], [u1, u2, u3], check_duplicates=False) + raises(ValueError, lambda: _validate_coordinates( + [q1, q2, q2], [u1, u2, u3])) + _validate_coordinates([q1, q2, q3], [u1, u2, u2], check_duplicates=False) + raises(ValueError, lambda: _validate_coordinates( + [q1, q2, q3], [u1, u2, u2], check_duplicates=True)) + raises(ValueError, lambda: _validate_coordinates( + [q1, q2, q3], [q1, u2, u3], check_duplicates=True)) + _validate_coordinates([q1, q2, q3], [u1, u2, u3], check_duplicates=False, + u_auxiliary=[u1, ua2, ua2]) + raises(ValueError, lambda: _validate_coordinates( + [q1, q2, q3], [u1, u2, u3], u_auxiliary=[u1, ua2, ua3])) + raises(ValueError, lambda: _validate_coordinates( + [q1, q2, q3], [u1, u2, u3], u_auxiliary=[q1, ua2, ua3])) + raises(ValueError, lambda: _validate_coordinates( + [q1, q2, q3], [u1, u2, u3], u_auxiliary=[ua1, ua2, ua2])) + # Test is_dynamicsymbols + _validate_coordinates([q1 + q2, q3], is_dynamicsymbols=False) + raises(ValueError, lambda: _validate_coordinates([q1 + q2, q3])) + _validate_coordinates([s1, q1, q2], [0, u1, u2], is_dynamicsymbols=False) + raises(ValueError, lambda: _validate_coordinates( + [s1, q1, q2], [0, u1, u2], is_dynamicsymbols=True)) + _validate_coordinates([s1 + s2 + s3, q1], [0, u1], is_dynamicsymbols=False) + raises(ValueError, lambda: _validate_coordinates( + [s1 + s2 + s3, q1], [0, u1], is_dynamicsymbols=True)) + _validate_coordinates(u_auxiliary=[s1, ua1], is_dynamicsymbols=False) + raises(ValueError, lambda: _validate_coordinates(u_auxiliary=[s1, ua1])) + # Test normal function + t = dynamicsymbols._t + a = symbols('a') + f1, f2 = symbols('f1:3', cls=Function) + _validate_coordinates([f1(a), f2(a)], is_dynamicsymbols=False) + raises(ValueError, lambda: _validate_coordinates([f1(a), f2(a)])) + raises(ValueError, lambda: _validate_coordinates(speeds=[f1(a), f2(a)])) + dynamicsymbols._t = a + _validate_coordinates([f1(a), f2(a)]) + raises(ValueError, lambda: _validate_coordinates([f1(t), f2(t)])) + dynamicsymbols._t = t + + +def test_parse_linear_solver(): + A, b = Matrix(3, 3, symbols('a:9')), Matrix(3, 2, symbols('b:6')) + assert _parse_linear_solver(Matrix.LUsolve) == Matrix.LUsolve # Test callable + assert _parse_linear_solver('LU')(A, b) == Matrix.LUsolve(A, b) + + +def test_deprecated_moved_functions(): + from sympy.physics.mechanics.functions import ( + inertia, inertia_of_point_mass, gravity) + N = ReferenceFrame('N') + with warns_deprecated_sympy(): + assert inertia(N, 0, 1, 0, 1) == (N.x | N.y) + (N.y | N.x) + (N.y | N.y) + with warns_deprecated_sympy(): + assert inertia_of_point_mass(1, N.x + N.y, N) == ( + (N.x | N.x) + (N.y | N.y) + 2 * (N.z | N.z) - + (N.x | N.y) - (N.y | N.x)) + p = Particle('P') + with warns_deprecated_sympy(): + assert gravity(-2 * N.z, p) == [(p.masscenter, -2 * p.mass * N.z)] diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_inertia.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_inertia.py new file mode 100644 index 0000000000000000000000000000000000000000..8d29e5f31868e539c4b50575af5180e5eb96f2cd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_inertia.py @@ -0,0 +1,71 @@ +from sympy import symbols +from sympy.testing.pytest import raises +from sympy.physics.mechanics import (inertia, inertia_of_point_mass, + Inertia, ReferenceFrame, Point) + + +def test_inertia_dyadic(): + N = ReferenceFrame('N') + ixx, iyy, izz = symbols('ixx iyy izz') + ixy, iyz, izx = symbols('ixy iyz izx') + assert inertia(N, ixx, iyy, izz) == (ixx * (N.x | N.x) + iyy * + (N.y | N.y) + izz * (N.z | N.z)) + assert inertia(N, 0, 0, 0) == 0 * (N.x | N.x) + raises(TypeError, lambda: inertia(0, 0, 0, 0)) + assert inertia(N, ixx, iyy, izz, ixy, iyz, izx) == (ixx * (N.x | N.x) + + ixy * (N.x | N.y) + izx * (N.x | N.z) + ixy * (N.y | N.x) + iyy * + (N.y | N.y) + iyz * (N.y | N.z) + izx * (N.z | N.x) + iyz * (N.z | + N.y) + izz * (N.z | N.z)) + + +def test_inertia_of_point_mass(): + r, s, t, m = symbols('r s t m') + N = ReferenceFrame('N') + + px = r * N.x + I = inertia_of_point_mass(m, px, N) + assert I == m * r**2 * (N.y | N.y) + m * r**2 * (N.z | N.z) + + py = s * N.y + I = inertia_of_point_mass(m, py, N) + assert I == m * s**2 * (N.x | N.x) + m * s**2 * (N.z | N.z) + + pz = t * N.z + I = inertia_of_point_mass(m, pz, N) + assert I == m * t**2 * (N.x | N.x) + m * t**2 * (N.y | N.y) + + p = px + py + pz + I = inertia_of_point_mass(m, p, N) + assert I == (m * (s**2 + t**2) * (N.x | N.x) - + m * r * s * (N.x | N.y) - + m * r * t * (N.x | N.z) - + m * r * s * (N.y | N.x) + + m * (r**2 + t**2) * (N.y | N.y) - + m * s * t * (N.y | N.z) - + m * r * t * (N.z | N.x) - + m * s * t * (N.z | N.y) + + m * (r**2 + s**2) * (N.z | N.z)) + + +def test_inertia_object(): + N = ReferenceFrame('N') + O = Point('O') + ixx, iyy, izz = symbols('ixx iyy izz') + I_dyadic = ixx * (N.x | N.x) + iyy * (N.y | N.y) + izz * (N.z | N.z) + I = Inertia(inertia(N, ixx, iyy, izz), O) + assert isinstance(I, tuple) + assert I.__repr__() == ('Inertia(dyadic=ixx*(N.x|N.x) + iyy*(N.y|N.y) + ' + 'izz*(N.z|N.z), point=O)') + assert I.dyadic == I_dyadic + assert I.point == O + assert I[0] == I_dyadic + assert I[1] == O + assert I == (I_dyadic, O) # Test tuple equal + raises(TypeError, lambda: I != (O, I_dyadic)) # Incorrect tuple order + assert I == Inertia(O, I_dyadic) # Parse changed argument order + assert I == Inertia.from_inertia_scalars(O, N, ixx, iyy, izz) + # Test invalid tuple operations + raises(TypeError, lambda: I + (1, 2)) + raises(TypeError, lambda: (1, 2) + I) + raises(TypeError, lambda: I * 2) + raises(TypeError, lambda: 2 * I) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_joint.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_joint.py new file mode 100644 index 0000000000000000000000000000000000000000..271801b5b7290a4479ee61e1414741e4c4d6f966 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_joint.py @@ -0,0 +1,1240 @@ +from sympy.core.function import expand_mul +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy import Matrix, simplify, eye, zeros +from sympy.core.symbol import symbols +from sympy.physics.mechanics import ( + dynamicsymbols, RigidBody, Particle, JointsMethod, PinJoint, PrismaticJoint, + CylindricalJoint, PlanarJoint, SphericalJoint, WeldJoint, Body) +from sympy.physics.mechanics.joint import Joint +from sympy.physics.vector import Vector, ReferenceFrame, Point +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +t = dynamicsymbols._t # type: ignore + + +def _generate_body(interframe=False): + N = ReferenceFrame('N') + A = ReferenceFrame('A') + P = RigidBody('P', frame=N) + C = RigidBody('C', frame=A) + if interframe: + Pint, Cint = ReferenceFrame('P_int'), ReferenceFrame('C_int') + Pint.orient_axis(N, N.x, pi) + Cint.orient_axis(A, A.y, -pi / 2) + return N, A, P, C, Pint, Cint + return N, A, P, C + + +def test_Joint(): + parent = RigidBody('parent') + child = RigidBody('child') + raises(TypeError, lambda: Joint('J', parent, child)) + + +def test_coordinate_generation(): + q, u, qj, uj = dynamicsymbols('q u q_J u_J') + q0j, q1j, q2j, q3j, u0j, u1j, u2j, u3j = dynamicsymbols('q0:4_J u0:4_J') + q0, q1, q2, q3, u0, u1, u2, u3 = dynamicsymbols('q0:4 u0:4') + _, _, P, C = _generate_body() + # Using PinJoint to access Joint's coordinate generation method + J = PinJoint('J', P, C) + # Test single given + assert J._fill_coordinate_list(q, 1) == Matrix([q]) + assert J._fill_coordinate_list([u], 1) == Matrix([u]) + assert J._fill_coordinate_list([u], 1, offset=2) == Matrix([u]) + # Test None + assert J._fill_coordinate_list(None, 1) == Matrix([qj]) + assert J._fill_coordinate_list([None], 1) == Matrix([qj]) + assert J._fill_coordinate_list([q0, None, None], 3) == Matrix( + [q0, q1j, q2j]) + # Test autofill + assert J._fill_coordinate_list(None, 3) == Matrix([q0j, q1j, q2j]) + assert J._fill_coordinate_list([], 3) == Matrix([q0j, q1j, q2j]) + # Test offset + assert J._fill_coordinate_list([], 3, offset=1) == Matrix([q1j, q2j, q3j]) + assert J._fill_coordinate_list([q1, None, q3], 3, offset=1) == Matrix( + [q1, q2j, q3]) + assert J._fill_coordinate_list(None, 2, offset=2) == Matrix([q2j, q3j]) + # Test label + assert J._fill_coordinate_list(None, 1, 'u') == Matrix([uj]) + assert J._fill_coordinate_list([], 3, 'u') == Matrix([u0j, u1j, u2j]) + # Test single numbering + assert J._fill_coordinate_list(None, 1, number_single=True) == Matrix([q0j]) + assert J._fill_coordinate_list([], 1, 'u', 2, True) == Matrix([u2j]) + assert J._fill_coordinate_list([], 3, 'q') == Matrix([q0j, q1j, q2j]) + # Test invalid number of coordinates supplied + raises(ValueError, lambda: J._fill_coordinate_list([q0, q1], 1)) + raises(ValueError, lambda: J._fill_coordinate_list([u0, u1, None], 2, 'u')) + raises(ValueError, lambda: J._fill_coordinate_list([q0, q1], 3)) + # Test incorrect coordinate type + raises(TypeError, lambda: J._fill_coordinate_list([q0, symbols('q1')], 2)) + raises(TypeError, lambda: J._fill_coordinate_list([q0 + q1, q1], 2)) + # Test if derivative as generalized speed is allowed + _, _, P, C = _generate_body() + PinJoint('J', P, C, q1, q1.diff(t)) + # Test duplicate coordinates + _, _, P, C = _generate_body() + raises(ValueError, lambda: SphericalJoint('J', P, C, [q1j, None, None])) + raises(ValueError, lambda: SphericalJoint('J', P, C, speeds=[u0, u0, u1])) + + +def test_pin_joint(): + P = RigidBody('P') + C = RigidBody('C') + l, m = symbols('l m') + q, u = dynamicsymbols('q_J, u_J') + Pj = PinJoint('J', P, C) + assert Pj.name == 'J' + assert Pj.parent == P + assert Pj.child == C + assert Pj.coordinates == Matrix([q]) + assert Pj.speeds == Matrix([u]) + assert Pj.kdes == Matrix([u - q.diff(t)]) + assert Pj.joint_axis == P.frame.x + assert Pj.child_point.pos_from(C.masscenter) == Vector(0) + assert Pj.parent_point.pos_from(P.masscenter) == Vector(0) + assert Pj.parent_point.pos_from(Pj._child_point) == Vector(0) + assert C.masscenter.pos_from(P.masscenter) == Vector(0) + assert Pj.parent_interframe == P.frame + assert Pj.child_interframe == C.frame + assert Pj.__str__() == 'PinJoint: J parent: P child: C' + + P1 = RigidBody('P1') + C1 = RigidBody('C1') + Pint = ReferenceFrame('P_int') + Pint.orient_axis(P1.frame, P1.y, pi / 2) + J1 = PinJoint('J1', P1, C1, parent_point=l*P1.frame.x, + child_point=m*C1.frame.y, joint_axis=P1.frame.z, + parent_interframe=Pint) + assert J1._joint_axis == P1.frame.z + assert J1._child_point.pos_from(C1.masscenter) == m * C1.frame.y + assert J1._parent_point.pos_from(P1.masscenter) == l * P1.frame.x + assert J1._parent_point.pos_from(J1._child_point) == Vector(0) + assert (P1.masscenter.pos_from(C1.masscenter) == + -l*P1.frame.x + m*C1.frame.y) + assert J1.parent_interframe == Pint + assert J1.child_interframe == C1.frame + + q, u = dynamicsymbols('q, u') + N, A, P, C, Pint, Cint = _generate_body(True) + parent_point = P.masscenter.locatenew('parent_point', N.x + N.y) + child_point = C.masscenter.locatenew('child_point', C.y + C.z) + J = PinJoint('J', P, C, q, u, parent_point=parent_point, + child_point=child_point, parent_interframe=Pint, + child_interframe=Cint, joint_axis=N.z) + assert J.joint_axis == N.z + assert J.parent_point.vel(N) == 0 + assert J.parent_point == parent_point + assert J.child_point == child_point + assert J.child_point.pos_from(P.masscenter) == N.x + N.y + assert J.parent_point.pos_from(C.masscenter) == C.y + C.z + assert C.masscenter.pos_from(P.masscenter) == N.x + N.y - C.y - C.z + assert C.masscenter.vel(N).express(N) == (u * sin(q) - u * cos(q)) * N.x + ( + -u * sin(q) - u * cos(q)) * N.y + assert J.parent_interframe == Pint + assert J.child_interframe == Cint + + +def test_particle_compatibility(): + m, l = symbols('m l') + C_frame = ReferenceFrame('C') + P = Particle('P') + C = Particle('C', mass=m) + q, u = dynamicsymbols('q, u') + J = PinJoint('J', P, C, q, u, child_interframe=C_frame, + child_point=l * C_frame.y) + assert J.child_interframe == C_frame + assert J.parent_interframe.name == 'J_P_frame' + assert C.masscenter.pos_from(P.masscenter) == -l * C_frame.y + assert C_frame.dcm(J.parent_interframe) == Matrix([[1, 0, 0], + [0, cos(q), sin(q)], + [0, -sin(q), cos(q)]]) + assert C.masscenter.vel(J.parent_interframe) == -l * u * C_frame.z + # Test with specified joint axis + P_frame = ReferenceFrame('P') + C_frame = ReferenceFrame('C') + P = Particle('P') + C = Particle('C', mass=m) + q, u = dynamicsymbols('q, u') + J = PinJoint('J', P, C, q, u, parent_interframe=P_frame, + child_interframe=C_frame, child_point=l * C_frame.y, + joint_axis=P_frame.z) + assert J.joint_axis == J.parent_interframe.z + assert C_frame.dcm(J.parent_interframe) == Matrix([[cos(q), sin(q), 0], + [-sin(q), cos(q), 0], + [0, 0, 1]]) + assert P.masscenter.vel(J.parent_interframe) == 0 + assert C.masscenter.vel(J.parent_interframe) == l * u * C_frame.x + q1, q2, q3, u1, u2, u3 = dynamicsymbols('q1:4 u1:4') + qdot_to_u = {qi.diff(t): ui for qi, ui in ((q1, u1), (q2, u2), (q3, u3))} + # Test compatibility for prismatic joint + P, C = Particle('P'), Particle('C') + J = PrismaticJoint('J', P, C, q, u) + assert J.parent_interframe.dcm(J.child_interframe) == eye(3) + assert C.masscenter.pos_from(P.masscenter) == q * J.parent_interframe.x + assert P.masscenter.vel(J.parent_interframe) == 0 + assert C.masscenter.vel(J.parent_interframe) == u * J.parent_interframe.x + # Test compatibility for cylindrical joint + P, C = Particle('P'), Particle('C') + P_frame = ReferenceFrame('P_frame') + J = CylindricalJoint('J', P, C, q1, q2, u1, u2, parent_interframe=P_frame, + parent_point=l * P_frame.x, joint_axis=P_frame.y) + assert J.parent_interframe.dcm(J.child_interframe) == Matrix([ + [cos(q1), 0, sin(q1)], [0, 1, 0], [-sin(q1), 0, cos(q1)]]) + assert C.masscenter.pos_from(P.masscenter) == l * P_frame.x + q2 * P_frame.y + assert C.masscenter.vel(J.parent_interframe) == u2 * P_frame.y + assert P.masscenter.vel(J.child_interframe).xreplace(qdot_to_u) == ( + -u2 * P_frame.y - l * u1 * P_frame.z) + # Test compatibility for planar joint + P, C = Particle('P'), Particle('C') + C_frame = ReferenceFrame('C_frame') + J = PlanarJoint('J', P, C, q1, [q2, q3], u1, [u2, u3], + child_interframe=C_frame, child_point=l * C_frame.z) + P_frame = J.parent_interframe + assert J.parent_interframe.dcm(J.child_interframe) == Matrix([ + [1, 0, 0], [0, cos(q1), -sin(q1)], [0, sin(q1), cos(q1)]]) + assert C.masscenter.pos_from(P.masscenter) == ( + -l * C_frame.z + q2 * P_frame.y + q3 * P_frame.z) + assert C.masscenter.vel(J.parent_interframe) == ( + l * u1 * C_frame.y + u2 * P_frame.y + u3 * P_frame.z) + # Test compatibility for weld joint + P, C = Particle('P'), Particle('C') + C_frame, P_frame = ReferenceFrame('C_frame'), ReferenceFrame('P_frame') + J = WeldJoint('J', P, C, parent_interframe=P_frame, + child_interframe=C_frame, parent_point=l * P_frame.x, + child_point=l * C_frame.y) + assert P_frame.dcm(C_frame) == eye(3) + assert C.masscenter.pos_from(P.masscenter) == l * P_frame.x - l * C_frame.y + assert C.masscenter.vel(J.parent_interframe) == 0 + + +def test_body_compatibility(): + m, l = symbols('m l') + C_frame = ReferenceFrame('C') + with warns_deprecated_sympy(): + P = Body('P') + C = Body('C', mass=m, frame=C_frame) + q, u = dynamicsymbols('q, u') + PinJoint('J', P, C, q, u, child_point=l * C_frame.y) + assert C.frame == C_frame + assert P.frame.name == 'P_frame' + assert C.masscenter.pos_from(P.masscenter) == -l * C.y + assert C.frame.dcm(P.frame) == Matrix([[1, 0, 0], + [0, cos(q), sin(q)], + [0, -sin(q), cos(q)]]) + assert C.masscenter.vel(P.frame) == -l * u * C.z + + +def test_pin_joint_double_pendulum(): + q1, q2 = dynamicsymbols('q1 q2') + u1, u2 = dynamicsymbols('u1 u2') + m, l = symbols('m l') + N = ReferenceFrame('N') + A = ReferenceFrame('A') + B = ReferenceFrame('B') + C = RigidBody('C', frame=N) # ceiling + PartP = RigidBody('P', frame=A, mass=m) + PartR = RigidBody('R', frame=B, mass=m) + + J1 = PinJoint('J1', C, PartP, speeds=u1, coordinates=q1, + child_point=-l*A.x, joint_axis=C.frame.z) + J2 = PinJoint('J2', PartP, PartR, speeds=u2, coordinates=q2, + child_point=-l*B.x, joint_axis=PartP.frame.z) + + # Check orientation + assert N.dcm(A) == Matrix([[cos(q1), -sin(q1), 0], + [sin(q1), cos(q1), 0], [0, 0, 1]]) + assert A.dcm(B) == Matrix([[cos(q2), -sin(q2), 0], + [sin(q2), cos(q2), 0], [0, 0, 1]]) + assert simplify(N.dcm(B)) == Matrix([[cos(q1 + q2), -sin(q1 + q2), 0], + [sin(q1 + q2), cos(q1 + q2), 0], + [0, 0, 1]]) + + # Check Angular Velocity + assert A.ang_vel_in(N) == u1 * N.z + assert B.ang_vel_in(A) == u2 * A.z + assert B.ang_vel_in(N) == u1 * N.z + u2 * A.z + + # Check kde + assert J1.kdes == Matrix([u1 - q1.diff(t)]) + assert J2.kdes == Matrix([u2 - q2.diff(t)]) + + # Check Linear Velocity + assert PartP.masscenter.vel(N) == l*u1*A.y + assert PartR.masscenter.vel(A) == l*u2*B.y + assert PartR.masscenter.vel(N) == l*u1*A.y + l*(u1 + u2)*B.y + + +def test_pin_joint_chaos_pendulum(): + mA, mB, lA, lB, h = symbols('mA, mB, lA, lB, h') + theta, phi, omega, alpha = dynamicsymbols('theta phi omega alpha') + N = ReferenceFrame('N') + A = ReferenceFrame('A') + B = ReferenceFrame('B') + lA = (lB - h / 2) / 2 + lC = (lB/2 + h/4) + rod = RigidBody('rod', frame=A, mass=mA) + plate = RigidBody('plate', mass=mB, frame=B) + C = RigidBody('C', frame=N) + J1 = PinJoint('J1', C, rod, coordinates=theta, speeds=omega, + child_point=lA*A.z, joint_axis=N.y) + J2 = PinJoint('J2', rod, plate, coordinates=phi, speeds=alpha, + parent_point=lC*A.z, joint_axis=A.z) + + # Check orientation + assert A.dcm(N) == Matrix([[cos(theta), 0, -sin(theta)], + [0, 1, 0], + [sin(theta), 0, cos(theta)]]) + assert A.dcm(B) == Matrix([[cos(phi), -sin(phi), 0], + [sin(phi), cos(phi), 0], + [0, 0, 1]]) + assert B.dcm(N) == Matrix([ + [cos(phi)*cos(theta), sin(phi), -sin(theta)*cos(phi)], + [-sin(phi)*cos(theta), cos(phi), sin(phi)*sin(theta)], + [sin(theta), 0, cos(theta)]]) + + # Check Angular Velocity + assert A.ang_vel_in(N) == omega*N.y + assert A.ang_vel_in(B) == -alpha*A.z + assert N.ang_vel_in(B) == -omega*N.y - alpha*A.z + + # Check kde + assert J1.kdes == Matrix([omega - theta.diff(t)]) + assert J2.kdes == Matrix([alpha - phi.diff(t)]) + + # Check pos of masscenters + assert C.masscenter.pos_from(rod.masscenter) == lA*A.z + assert rod.masscenter.pos_from(plate.masscenter) == - lC * A.z + + # Check Linear Velocities + assert rod.masscenter.vel(N) == (h/4 - lB/2)*omega*A.x + assert plate.masscenter.vel(N) == ((h/4 - lB/2)*omega + + (h/4 + lB/2)*omega)*A.x + + +def test_pin_joint_interframe(): + q, u = dynamicsymbols('q, u') + # Check not connected + N, A, P, C = _generate_body() + Pint, Cint = ReferenceFrame('Pint'), ReferenceFrame('Cint') + raises(ValueError, lambda: PinJoint('J', P, C, parent_interframe=Pint)) + raises(ValueError, lambda: PinJoint('J', P, C, child_interframe=Cint)) + # Check not fixed interframe + Pint.orient_axis(N, N.z, q) + Cint.orient_axis(A, A.z, q) + raises(ValueError, lambda: PinJoint('J', P, C, parent_interframe=Pint)) + raises(ValueError, lambda: PinJoint('J', P, C, child_interframe=Cint)) + # Check only parent_interframe + N, A, P, C = _generate_body() + Pint = ReferenceFrame('Pint') + Pint.orient_body_fixed(N, (pi / 4, pi, pi / 3), 'xyz') + PinJoint('J', P, C, q, u, parent_point=N.x, child_point=-C.y, + parent_interframe=Pint, joint_axis=Pint.x) + assert simplify(N.dcm(A)) - Matrix([ + [-1 / 2, sqrt(3) * cos(q) / 2, -sqrt(3) * sin(q) / 2], + [sqrt(6) / 4, sqrt(2) * (2 * sin(q) + cos(q)) / 4, + sqrt(2) * (-sin(q) + 2 * cos(q)) / 4], + [sqrt(6) / 4, sqrt(2) * (-2 * sin(q) + cos(q)) / 4, + -sqrt(2) * (sin(q) + 2 * cos(q)) / 4]]) == zeros(3) + assert A.ang_vel_in(N) == u * Pint.x + assert C.masscenter.pos_from(P.masscenter) == N.x + A.y + assert C.masscenter.vel(N) == u * A.z + assert P.masscenter.vel(Pint) == Vector(0) + assert C.masscenter.vel(Pint) == u * A.z + # Check only child_interframe + N, A, P, C = _generate_body() + Cint = ReferenceFrame('Cint') + Cint.orient_body_fixed(A, (2 * pi / 3, -pi, pi / 2), 'xyz') + PinJoint('J', P, C, q, u, parent_point=-N.z, child_point=C.x, + child_interframe=Cint, joint_axis=P.x + P.z) + assert simplify(N.dcm(A)) == Matrix([ + [-sqrt(2) * sin(q) / 2, + -sqrt(3) * (cos(q) - 1) / 4 - cos(q) / 4 - S(1) / 4, + sqrt(3) * (cos(q) + 1) / 4 - cos(q) / 4 + S(1) / 4], + [cos(q), (sqrt(2) + sqrt(6)) * -sin(q) / 4, + (-sqrt(2) + sqrt(6)) * sin(q) / 4], + [sqrt(2) * sin(q) / 2, + sqrt(3) * (cos(q) + 1) / 4 + cos(q) / 4 - S(1) / 4, + sqrt(3) * (1 - cos(q)) / 4 + cos(q) / 4 + S(1) / 4]]) + assert A.ang_vel_in(N) == sqrt(2) * u / 2 * N.x + sqrt(2) * u / 2 * N.z + assert C.masscenter.pos_from(P.masscenter) == - N.z - A.x + assert C.masscenter.vel(N).simplify() == ( + -sqrt(6) - sqrt(2)) * u / 4 * A.y + ( + -sqrt(2) + sqrt(6)) * u / 4 * A.z + assert C.masscenter.vel(Cint) == Vector(0) + # Check combination + N, A, P, C = _generate_body() + Pint, Cint = ReferenceFrame('Pint'), ReferenceFrame('Cint') + Pint.orient_body_fixed(N, (-pi / 2, pi, pi / 2), 'xyz') + Cint.orient_body_fixed(A, (2 * pi / 3, -pi, pi / 2), 'xyz') + PinJoint('J', P, C, q, u, parent_point=N.x - N.y, child_point=-C.z, + parent_interframe=Pint, child_interframe=Cint, + joint_axis=Pint.x + Pint.z) + assert simplify(N.dcm(A)) == Matrix([ + [cos(q), (sqrt(2) + sqrt(6)) * -sin(q) / 4, + (-sqrt(2) + sqrt(6)) * sin(q) / 4], + [-sqrt(2) * sin(q) / 2, + -sqrt(3) * (cos(q) + 1) / 4 - cos(q) / 4 + S(1) / 4, + sqrt(3) * (cos(q) - 1) / 4 - cos(q) / 4 - S(1) / 4], + [sqrt(2) * sin(q) / 2, + sqrt(3) * (cos(q) - 1) / 4 + cos(q) / 4 + S(1) / 4, + -sqrt(3) * (cos(q) + 1) / 4 + cos(q) / 4 - S(1) / 4]]) + assert A.ang_vel_in(N) == sqrt(2) * u / 2 * Pint.x + sqrt( + 2) * u / 2 * Pint.z + assert C.masscenter.pos_from(P.masscenter) == N.x - N.y + A.z + N_v_C = (-sqrt(2) + sqrt(6)) * u / 4 * A.x + assert C.masscenter.vel(N).simplify() == N_v_C + assert C.masscenter.vel(Pint).simplify() == N_v_C + assert C.masscenter.vel(Cint) == Vector(0) + + +def test_pin_joint_joint_axis(): + q, u = dynamicsymbols('q, u') + # Check parent as reference + N, A, P, C, Pint, Cint = _generate_body(True) + pin = PinJoint('J', P, C, q, u, parent_interframe=Pint, + child_interframe=Cint, joint_axis=P.y) + assert pin.joint_axis == P.y + assert N.dcm(A) == Matrix([[sin(q), 0, cos(q)], [0, -1, 0], + [cos(q), 0, -sin(q)]]) + # Check parent_interframe as reference + N, A, P, C, Pint, Cint = _generate_body(True) + pin = PinJoint('J', P, C, q, u, parent_interframe=Pint, + child_interframe=Cint, joint_axis=Pint.y) + assert pin.joint_axis == Pint.y + assert N.dcm(A) == Matrix([[-sin(q), 0, cos(q)], [0, -1, 0], + [cos(q), 0, sin(q)]]) + # Check combination of joint_axis with interframes supplied as vectors (2x) + N, A, P, C = _generate_body() + pin = PinJoint('J', P, C, q, u, parent_interframe=N.z, + child_interframe=-C.z, joint_axis=N.z) + assert pin.joint_axis == N.z + assert N.dcm(A) == Matrix([[-cos(q), -sin(q), 0], [-sin(q), cos(q), 0], + [0, 0, -1]]) + N, A, P, C = _generate_body() + pin = PinJoint('J', P, C, q, u, parent_interframe=N.z, + child_interframe=-C.z, joint_axis=N.x) + assert pin.joint_axis == N.x + assert N.dcm(A) == Matrix([[-1, 0, 0], [0, cos(q), sin(q)], + [0, sin(q), -cos(q)]]) + # Check time varying axis + N, A, P, C, Pint, Cint = _generate_body(True) + raises(ValueError, lambda: PinJoint('J', P, C, + joint_axis=cos(q) * N.x + sin(q) * N.y)) + # Check joint_axis provided in child frame + raises(ValueError, lambda: PinJoint('J', P, C, joint_axis=C.x)) + # Check some invalid combinations + raises(ValueError, lambda: PinJoint('J', P, C, joint_axis=P.x + C.y)) + raises(ValueError, lambda: PinJoint( + 'J', P, C, parent_interframe=Pint, child_interframe=Cint, + joint_axis=Pint.x + C.y)) + raises(ValueError, lambda: PinJoint( + 'J', P, C, parent_interframe=Pint, child_interframe=Cint, + joint_axis=P.x + Cint.y)) + # Check valid special combination + N, A, P, C, Pint, Cint = _generate_body(True) + PinJoint('J', P, C, parent_interframe=Pint, child_interframe=Cint, + joint_axis=Pint.x + P.y) + # Check invalid zero vector + raises(Exception, lambda: PinJoint( + 'J', P, C, parent_interframe=Pint, child_interframe=Cint, + joint_axis=Vector(0))) + raises(Exception, lambda: PinJoint( + 'J', P, C, parent_interframe=Pint, child_interframe=Cint, + joint_axis=P.y + Pint.y)) + + +def test_pin_joint_arbitrary_axis(): + q, u = dynamicsymbols('q_J, u_J') + + # When the bodies are attached though masscenters but axes are opposite. + N, A, P, C = _generate_body() + PinJoint('J', P, C, child_interframe=-A.x) + + assert (-A.x).angle_between(N.x) == 0 + assert -A.x.express(N) == N.x + assert A.dcm(N) == Matrix([[-1, 0, 0], + [0, -cos(q), -sin(q)], + [0, -sin(q), cos(q)]]) + assert A.ang_vel_in(N) == u*N.x + assert A.ang_vel_in(N).magnitude() == sqrt(u**2) + assert C.masscenter.pos_from(P.masscenter) == 0 + assert C.masscenter.pos_from(P.masscenter).express(N).simplify() == 0 + assert C.masscenter.vel(N) == 0 + + # When axes are different and parent joint is at masscenter but child joint + # is at a unit vector from child masscenter. + N, A, P, C = _generate_body() + PinJoint('J', P, C, child_interframe=A.y, child_point=A.x) + + assert A.y.angle_between(N.x) == 0 # Axis are aligned + assert A.y.express(N) == N.x + assert A.dcm(N) == Matrix([[0, -cos(q), -sin(q)], + [1, 0, 0], + [0, -sin(q), cos(q)]]) + assert A.ang_vel_in(N) == u*N.x + assert A.ang_vel_in(N).express(A) == u * A.y + assert A.ang_vel_in(N).magnitude() == sqrt(u**2) + assert A.ang_vel_in(N).cross(A.y) == 0 + assert C.masscenter.vel(N) == u*A.z + assert C.masscenter.pos_from(P.masscenter) == -A.x + assert (C.masscenter.pos_from(P.masscenter).express(N).simplify() == + cos(q)*N.y + sin(q)*N.z) + assert C.masscenter.vel(N).angle_between(A.x) == pi/2 + + # Similar to previous case but wrt parent body + N, A, P, C = _generate_body() + PinJoint('J', P, C, parent_interframe=N.y, parent_point=N.x) + + assert N.y.angle_between(A.x) == 0 # Axis are aligned + assert N.y.express(A) == A.x + assert A.dcm(N) == Matrix([[0, 1, 0], + [-cos(q), 0, sin(q)], + [sin(q), 0, cos(q)]]) + assert A.ang_vel_in(N) == u*N.y + assert A.ang_vel_in(N).express(A) == u*A.x + assert A.ang_vel_in(N).magnitude() == sqrt(u**2) + angle = A.ang_vel_in(N).angle_between(A.x) + assert angle.xreplace({u: 1}) == 0 + assert C.masscenter.vel(N) == 0 + assert C.masscenter.pos_from(P.masscenter) == N.x + + # Both joint pos id defined but different axes + N, A, P, C = _generate_body() + PinJoint('J', P, C, parent_point=N.x, child_point=A.x, + child_interframe=A.x + A.y) + assert expand_mul(N.x.angle_between(A.x + A.y)) == 0 # Axis are aligned + assert (A.x + A.y).express(N).simplify() == sqrt(2)*N.x + assert simplify(A.dcm(N)) == Matrix([ + [sqrt(2)/2, -sqrt(2)*cos(q)/2, -sqrt(2)*sin(q)/2], + [sqrt(2)/2, sqrt(2)*cos(q)/2, sqrt(2)*sin(q)/2], + [0, -sin(q), cos(q)]]) + assert A.ang_vel_in(N) == u*N.x + assert (A.ang_vel_in(N).express(A).simplify() == + (u*A.x + u*A.y)/sqrt(2)) + assert A.ang_vel_in(N).magnitude() == sqrt(u**2) + angle = A.ang_vel_in(N).angle_between(A.x + A.y) + assert angle.xreplace({u: 1}) == 0 + assert C.masscenter.vel(N).simplify() == (u * A.z)/sqrt(2) + assert C.masscenter.pos_from(P.masscenter) == N.x - A.x + assert (C.masscenter.pos_from(P.masscenter).express(N).simplify() == + (1 - sqrt(2)/2)*N.x + sqrt(2)*cos(q)/2*N.y + + sqrt(2)*sin(q)/2*N.z) + assert (C.masscenter.vel(N).express(N).simplify() == + -sqrt(2)*u*sin(q)/2*N.y + sqrt(2)*u*cos(q)/2*N.z) + assert C.masscenter.vel(N).angle_between(A.x) == pi/2 + + N, A, P, C = _generate_body() + PinJoint('J', P, C, parent_point=N.x, child_point=A.x, + child_interframe=A.x + A.y - A.z) + assert expand_mul(N.x.angle_between(A.x + A.y - A.z)) == 0 # Axis aligned + assert (A.x + A.y - A.z).express(N).simplify() == sqrt(3)*N.x + assert simplify(A.dcm(N)) == Matrix([ + [sqrt(3)/3, -sqrt(6)*sin(q + pi/4)/3, + sqrt(6)*cos(q + pi/4)/3], + [sqrt(3)/3, sqrt(6)*cos(q + pi/12)/3, + sqrt(6)*sin(q + pi/12)/3], + [-sqrt(3)/3, sqrt(6)*cos(q + 5*pi/12)/3, + sqrt(6)*sin(q + 5*pi/12)/3]]) + assert A.ang_vel_in(N) == u*N.x + assert A.ang_vel_in(N).express(A).simplify() == (u*A.x + u*A.y - + u*A.z)/sqrt(3) + assert A.ang_vel_in(N).magnitude() == sqrt(u**2) + angle = A.ang_vel_in(N).angle_between(A.x + A.y-A.z) + assert angle.xreplace({u: 1}).simplify() == 0 + assert C.masscenter.vel(N).simplify() == (u*A.y + u*A.z)/sqrt(3) + assert C.masscenter.pos_from(P.masscenter) == N.x - A.x + assert (C.masscenter.pos_from(P.masscenter).express(N).simplify() == + (1 - sqrt(3)/3)*N.x + sqrt(6)*sin(q + pi/4)/3*N.y - + sqrt(6)*cos(q + pi/4)/3*N.z) + assert (C.masscenter.vel(N).express(N).simplify() == + sqrt(6)*u*cos(q + pi/4)/3*N.y + + sqrt(6)*u*sin(q + pi/4)/3*N.z) + assert C.masscenter.vel(N).angle_between(A.x) == pi/2 + + N, A, P, C = _generate_body() + m, n = symbols('m n') + PinJoint('J', P, C, parent_point=m * N.x, child_point=n * A.x, + child_interframe=A.x + A.y - A.z, + parent_interframe=N.x - N.y + N.z) + angle = (N.x - N.y + N.z).angle_between(A.x + A.y - A.z) + assert expand_mul(angle) == 0 # Axis are aligned + assert ((A.x-A.y+A.z).express(N).simplify() == + (-4*cos(q)/3 - S(1)/3)*N.x + (S(1)/3 - 4*sin(q + pi/6)/3)*N.y + + (4*cos(q + pi/3)/3 - S(1)/3)*N.z) + assert simplify(A.dcm(N)) == Matrix([ + [S(1)/3 - 2*cos(q)/3, -2*sin(q + pi/6)/3 - S(1)/3, + 2*cos(q + pi/3)/3 + S(1)/3], + [2*cos(q + pi/3)/3 + S(1)/3, 2*cos(q)/3 - S(1)/3, + 2*sin(q + pi/6)/3 + S(1)/3], + [-2*sin(q + pi/6)/3 - S(1)/3, 2*cos(q + pi/3)/3 + S(1)/3, + 2*cos(q)/3 - S(1)/3]]) + assert (A.ang_vel_in(N) - (u*N.x - u*N.y + u*N.z)/sqrt(3)).simplify() + assert A.ang_vel_in(N).express(A).simplify() == (u*A.x + u*A.y - + u*A.z)/sqrt(3) + assert A.ang_vel_in(N).magnitude() == sqrt(u**2) + angle = A.ang_vel_in(N).angle_between(A.x+A.y-A.z) + assert angle.xreplace({u: 1}).simplify() == 0 + assert (C.masscenter.vel(N).simplify() == + sqrt(3)*n*u/3*A.y + sqrt(3)*n*u/3*A.z) + assert C.masscenter.pos_from(P.masscenter) == m*N.x - n*A.x + assert (C.masscenter.pos_from(P.masscenter).express(N).simplify() == + (m + n*(2*cos(q) - 1)/3)*N.x + n*(2*sin(q + pi/6) + + 1)/3*N.y - n*(2*cos(q + pi/3) + 1)/3*N.z) + assert (C.masscenter.vel(N).express(N).simplify() == + - 2*n*u*sin(q)/3*N.x + 2*n*u*cos(q + pi/6)/3*N.y + + 2*n*u*sin(q + pi/3)/3*N.z) + assert C.masscenter.vel(N).dot(N.x - N.y + N.z).simplify() == 0 + + +def test_create_aligned_frame_pi(): + N, A, P, C = _generate_body() + f = Joint._create_aligned_interframe(P, -P.x, P.x) + assert f.z == P.z + f = Joint._create_aligned_interframe(P, -P.y, P.y) + assert f.x == P.x + f = Joint._create_aligned_interframe(P, -P.z, P.z) + assert f.y == P.y + f = Joint._create_aligned_interframe(P, -P.x - P.y, P.x + P.y) + assert f.z == P.z + f = Joint._create_aligned_interframe(P, -P.y - P.z, P.y + P.z) + assert f.x == P.x + f = Joint._create_aligned_interframe(P, -P.x - P.z, P.x + P.z) + assert f.y == P.y + f = Joint._create_aligned_interframe(P, -P.x - P.y - P.z, P.x + P.y + P.z) + assert f.y - f.z == P.y - P.z + + +def test_pin_joint_axis(): + q, u = dynamicsymbols('q u') + # Test default joint axis + N, A, P, C, Pint, Cint = _generate_body(True) + J = PinJoint('J', P, C, q, u, parent_interframe=Pint, child_interframe=Cint) + assert J.joint_axis == Pint.x + # Test for the same joint axis expressed in different frames + N_R_A = Matrix([[0, sin(q), cos(q)], + [0, -cos(q), sin(q)], + [1, 0, 0]]) + N, A, P, C, Pint, Cint = _generate_body(True) + PinJoint('J', P, C, q, u, parent_interframe=Pint, child_interframe=Cint, + joint_axis=N.z) + assert N.dcm(A) == N_R_A + N, A, P, C, Pint, Cint = _generate_body(True) + PinJoint('J', P, C, q, u, parent_interframe=Pint, child_interframe=Cint, + joint_axis=-Pint.z) + assert N.dcm(A) == N_R_A + # Test time varying joint axis + N, A, P, C, Pint, Cint = _generate_body(True) + raises(ValueError, lambda: PinJoint('J', P, C, joint_axis=q * N.z)) + + +def test_locate_joint_pos(): + # Test Vector and default + N, A, P, C = _generate_body() + joint = PinJoint('J', P, C, parent_point=N.y + N.z) + assert joint.parent_point.name == 'J_P_joint' + assert joint.parent_point.pos_from(P.masscenter) == N.y + N.z + assert joint.child_point == C.masscenter + # Test Point objects + N, A, P, C = _generate_body() + parent_point = P.masscenter.locatenew('p', N.y + N.z) + joint = PinJoint('J', P, C, parent_point=parent_point, + child_point=C.masscenter) + assert joint.parent_point == parent_point + assert joint.child_point == C.masscenter + # Check invalid type + N, A, P, C = _generate_body() + raises(TypeError, + lambda: PinJoint('J', P, C, parent_point=N.x.to_matrix(N))) + # Test time varying positions + q = dynamicsymbols('q') + N, A, P, C = _generate_body() + raises(ValueError, lambda: PinJoint('J', P, C, parent_point=q * N.x)) + N, A, P, C = _generate_body() + child_point = C.masscenter.locatenew('p', q * A.y) + raises(ValueError, lambda: PinJoint('J', P, C, child_point=child_point)) + # Test undefined position + child_point = Point('p') + raises(ValueError, lambda: PinJoint('J', P, C, child_point=child_point)) + + +def test_locate_joint_frame(): + # Test rotated frame and default + N, A, P, C = _generate_body() + parent_interframe = ReferenceFrame('int_frame') + parent_interframe.orient_axis(N, N.z, 1) + joint = PinJoint('J', P, C, parent_interframe=parent_interframe) + assert joint.parent_interframe == parent_interframe + assert joint.parent_interframe.ang_vel_in(N) == 0 + assert joint.child_interframe == A + # Test time varying orientations + q = dynamicsymbols('q') + N, A, P, C = _generate_body() + parent_interframe = ReferenceFrame('int_frame') + parent_interframe.orient_axis(N, N.z, q) + raises(ValueError, + lambda: PinJoint('J', P, C, parent_interframe=parent_interframe)) + # Test undefined frame + N, A, P, C = _generate_body() + child_interframe = ReferenceFrame('int_frame') + child_interframe.orient_axis(N, N.z, 1) # Defined with respect to parent + raises(ValueError, + lambda: PinJoint('J', P, C, child_interframe=child_interframe)) + + +def test_prismatic_joint(): + _, _, P, C = _generate_body() + q, u = dynamicsymbols('q_S, u_S') + S = PrismaticJoint('S', P, C) + assert S.name == 'S' + assert S.parent == P + assert S.child == C + assert S.coordinates == Matrix([q]) + assert S.speeds == Matrix([u]) + assert S.kdes == Matrix([u - q.diff(t)]) + assert S.joint_axis == P.frame.x + assert S.child_point.pos_from(C.masscenter) == Vector(0) + assert S.parent_point.pos_from(P.masscenter) == Vector(0) + assert S.parent_point.pos_from(S.child_point) == - q * P.frame.x + assert P.masscenter.pos_from(C.masscenter) == - q * P.frame.x + assert C.masscenter.vel(P.frame) == u * P.frame.x + assert P.frame.ang_vel_in(C.frame) == 0 + assert C.frame.ang_vel_in(P.frame) == 0 + assert S.__str__() == 'PrismaticJoint: S parent: P child: C' + + N, A, P, C = _generate_body() + l, m = symbols('l m') + Pint = ReferenceFrame('P_int') + Pint.orient_axis(P.frame, P.y, pi / 2) + S = PrismaticJoint('S', P, C, parent_point=l * P.frame.x, + child_point=m * C.frame.y, joint_axis=P.frame.z, + parent_interframe=Pint) + + assert S.joint_axis == P.frame.z + assert S.child_point.pos_from(C.masscenter) == m * C.frame.y + assert S.parent_point.pos_from(P.masscenter) == l * P.frame.x + assert S.parent_point.pos_from(S.child_point) == - q * P.frame.z + assert P.masscenter.pos_from(C.masscenter) == - l * N.x - q * N.z + m * A.y + assert C.masscenter.vel(P.frame) == u * P.frame.z + assert P.masscenter.vel(Pint) == Vector(0) + assert C.frame.ang_vel_in(P.frame) == 0 + assert P.frame.ang_vel_in(C.frame) == 0 + + _, _, P, C = _generate_body() + Pint = ReferenceFrame('P_int') + Pint.orient_axis(P.frame, P.y, pi / 2) + S = PrismaticJoint('S', P, C, parent_point=l * P.frame.z, + child_point=m * C.frame.x, joint_axis=P.frame.z, + parent_interframe=Pint) + assert S.joint_axis == P.frame.z + assert S.child_point.pos_from(C.masscenter) == m * C.frame.x + assert S.parent_point.pos_from(P.masscenter) == l * P.frame.z + assert S.parent_point.pos_from(S.child_point) == - q * P.frame.z + assert P.masscenter.pos_from(C.masscenter) == (-l - q)*P.frame.z + m*C.frame.x + assert C.masscenter.vel(P.frame) == u * P.frame.z + assert C.frame.ang_vel_in(P.frame) == 0 + assert P.frame.ang_vel_in(C.frame) == 0 + + +def test_prismatic_joint_arbitrary_axis(): + q, u = dynamicsymbols('q_S, u_S') + + N, A, P, C = _generate_body() + PrismaticJoint('S', P, C, child_interframe=-A.x) + + assert (-A.x).angle_between(N.x) == 0 + assert -A.x.express(N) == N.x + assert A.dcm(N) == Matrix([[-1, 0, 0], [0, -1, 0], [0, 0, 1]]) + assert C.masscenter.pos_from(P.masscenter) == q * N.x + assert C.masscenter.pos_from(P.masscenter).express(A).simplify() == -q * A.x + assert C.masscenter.vel(N) == u * N.x + assert C.masscenter.vel(N).express(A) == -u * A.x + assert A.ang_vel_in(N) == 0 + assert N.ang_vel_in(A) == 0 + + #When axes are different and parent joint is at masscenter but child joint is at a unit vector from + #child masscenter. + N, A, P, C = _generate_body() + PrismaticJoint('S', P, C, child_interframe=A.y, child_point=A.x) + + assert A.y.angle_between(N.x) == 0 #Axis are aligned + assert A.y.express(N) == N.x + assert A.dcm(N) == Matrix([[0, -1, 0], [1, 0, 0], [0, 0, 1]]) + assert C.masscenter.vel(N) == u * N.x + assert C.masscenter.vel(N).express(A) == u * A.y + assert C.masscenter.pos_from(P.masscenter) == q*N.x - A.x + assert C.masscenter.pos_from(P.masscenter).express(N).simplify() == q*N.x + N.y + assert A.ang_vel_in(N) == 0 + assert N.ang_vel_in(A) == 0 + + #Similar to previous case but wrt parent body + N, A, P, C = _generate_body() + PrismaticJoint('S', P, C, parent_interframe=N.y, parent_point=N.x) + + assert N.y.angle_between(A.x) == 0 #Axis are aligned + assert N.y.express(A) == A.x + assert A.dcm(N) == Matrix([[0, 1, 0], [-1, 0, 0], [0, 0, 1]]) + assert C.masscenter.vel(N) == u * N.y + assert C.masscenter.vel(N).express(A) == u * A.x + assert C.masscenter.pos_from(P.masscenter) == N.x + q*N.y + assert A.ang_vel_in(N) == 0 + assert N.ang_vel_in(A) == 0 + + #Both joint pos is defined but different axes + N, A, P, C = _generate_body() + PrismaticJoint('S', P, C, parent_point=N.x, child_point=A.x, + child_interframe=A.x + A.y) + assert N.x.angle_between(A.x + A.y) == 0 #Axis are aligned + assert (A.x + A.y).express(N) == sqrt(2)*N.x + assert A.dcm(N) == Matrix([[sqrt(2)/2, -sqrt(2)/2, 0], [sqrt(2)/2, sqrt(2)/2, 0], [0, 0, 1]]) + assert C.masscenter.pos_from(P.masscenter) == (q + 1)*N.x - A.x + assert C.masscenter.pos_from(P.masscenter).express(N) == (q - sqrt(2)/2 + 1)*N.x + sqrt(2)/2*N.y + assert C.masscenter.vel(N).express(A) == u * (A.x + A.y)/sqrt(2) + assert C.masscenter.vel(N) == u*N.x + assert A.ang_vel_in(N) == 0 + assert N.ang_vel_in(A) == 0 + + N, A, P, C = _generate_body() + PrismaticJoint('S', P, C, parent_point=N.x, child_point=A.x, + child_interframe=A.x + A.y - A.z) + assert N.x.angle_between(A.x + A.y - A.z).simplify() == 0 #Axis are aligned + assert ((A.x + A.y - A.z).express(N) - sqrt(3)*N.x).simplify() == 0 + assert simplify(A.dcm(N)) == Matrix([[sqrt(3)/3, -sqrt(3)/3, sqrt(3)/3], + [sqrt(3)/3, sqrt(3)/6 + S(1)/2, S(1)/2 - sqrt(3)/6], + [-sqrt(3)/3, S(1)/2 - sqrt(3)/6, sqrt(3)/6 + S(1)/2]]) + assert C.masscenter.pos_from(P.masscenter) == (q + 1)*N.x - A.x + assert (C.masscenter.pos_from(P.masscenter).express(N) - + ((q - sqrt(3)/3 + 1)*N.x + sqrt(3)/3*N.y - sqrt(3)/3*N.z)).simplify() == 0 + assert C.masscenter.vel(N) == u*N.x + assert (C.masscenter.vel(N).express(A) - ( + sqrt(3)*u/3*A.x + sqrt(3)*u/3*A.y - sqrt(3)*u/3*A.z)).simplify() + assert A.ang_vel_in(N) == 0 + assert N.ang_vel_in(A) == 0 + + N, A, P, C = _generate_body() + m, n = symbols('m n') + PrismaticJoint('S', P, C, parent_point=m*N.x, child_point=n*A.x, + child_interframe=A.x + A.y - A.z, + parent_interframe=N.x - N.y + N.z) + # 0 angle means that the axis are aligned + assert (N.x-N.y+N.z).angle_between(A.x+A.y-A.z).simplify() == 0 + assert ((A.x+A.y-A.z).express(N) - (N.x - N.y + N.z)).simplify() == 0 + assert simplify(A.dcm(N)) == Matrix([[-S(1)/3, -S(2)/3, S(2)/3], + [S(2)/3, S(1)/3, S(2)/3], + [-S(2)/3, S(2)/3, S(1)/3]]) + assert (C.masscenter.pos_from(P.masscenter) - ( + (m + sqrt(3)*q/3)*N.x - sqrt(3)*q/3*N.y + sqrt(3)*q/3*N.z - n*A.x) + ).express(N).simplify() == 0 + assert (C.masscenter.pos_from(P.masscenter).express(N) - ( + (m + n/3 + sqrt(3)*q/3)*N.x + (2*n/3 - sqrt(3)*q/3)*N.y + + (-2*n/3 + sqrt(3)*q/3)*N.z)).simplify() == 0 + assert (C.masscenter.vel(N).express(N) - ( + sqrt(3)*u/3*N.x - sqrt(3)*u/3*N.y + sqrt(3)*u/3*N.z)).simplify() == 0 + assert (C.masscenter.vel(N).express(A) - + (sqrt(3)*u/3*A.x + sqrt(3)*u/3*A.y - sqrt(3)*u/3*A.z)).simplify() == 0 + assert A.ang_vel_in(N) == 0 + assert N.ang_vel_in(A) == 0 + + +def test_cylindrical_joint(): + N, A, P, C = _generate_body() + q0_def, q1_def, u0_def, u1_def = dynamicsymbols('q0:2_J, u0:2_J') + Cj = CylindricalJoint('J', P, C) + assert Cj.name == 'J' + assert Cj.parent == P + assert Cj.child == C + assert Cj.coordinates == Matrix([q0_def, q1_def]) + assert Cj.speeds == Matrix([u0_def, u1_def]) + assert Cj.rotation_coordinate == q0_def + assert Cj.translation_coordinate == q1_def + assert Cj.rotation_speed == u0_def + assert Cj.translation_speed == u1_def + assert Cj.kdes == Matrix([u0_def - q0_def.diff(t), u1_def - q1_def.diff(t)]) + assert Cj.joint_axis == N.x + assert Cj.child_point.pos_from(C.masscenter) == Vector(0) + assert Cj.parent_point.pos_from(P.masscenter) == Vector(0) + assert Cj.parent_point.pos_from(Cj._child_point) == -q1_def * N.x + assert C.masscenter.pos_from(P.masscenter) == q1_def * N.x + assert Cj.child_point.vel(N) == u1_def * N.x + assert A.ang_vel_in(N) == u0_def * N.x + assert Cj.parent_interframe == N + assert Cj.child_interframe == A + assert Cj.__str__() == 'CylindricalJoint: J parent: P child: C' + + q0, q1, u0, u1 = dynamicsymbols('q0:2, u0:2') + l, m = symbols('l, m') + N, A, P, C, Pint, Cint = _generate_body(True) + Cj = CylindricalJoint('J', P, C, rotation_coordinate=q0, rotation_speed=u0, + translation_speed=u1, parent_point=m * N.x, + child_point=l * A.y, parent_interframe=Pint, + child_interframe=Cint, joint_axis=2 * N.z) + assert Cj.coordinates == Matrix([q0, q1_def]) + assert Cj.speeds == Matrix([u0, u1]) + assert Cj.rotation_coordinate == q0 + assert Cj.translation_coordinate == q1_def + assert Cj.rotation_speed == u0 + assert Cj.translation_speed == u1 + assert Cj.kdes == Matrix([u0 - q0.diff(t), u1 - q1_def.diff(t)]) + assert Cj.joint_axis == 2 * N.z + assert Cj.child_point.pos_from(C.masscenter) == l * A.y + assert Cj.parent_point.pos_from(P.masscenter) == m * N.x + assert Cj.parent_point.pos_from(Cj._child_point) == -q1_def * N.z + assert C.masscenter.pos_from( + P.masscenter) == m * N.x + q1_def * N.z - l * A.y + assert C.masscenter.vel(N) == u1 * N.z - u0 * l * A.z + assert A.ang_vel_in(N) == u0 * N.z + + +def test_planar_joint(): + N, A, P, C = _generate_body() + q0_def, q1_def, q2_def = dynamicsymbols('q0:3_J') + u0_def, u1_def, u2_def = dynamicsymbols('u0:3_J') + Cj = PlanarJoint('J', P, C) + assert Cj.name == 'J' + assert Cj.parent == P + assert Cj.child == C + assert Cj.coordinates == Matrix([q0_def, q1_def, q2_def]) + assert Cj.speeds == Matrix([u0_def, u1_def, u2_def]) + assert Cj.rotation_coordinate == q0_def + assert Cj.planar_coordinates == Matrix([q1_def, q2_def]) + assert Cj.rotation_speed == u0_def + assert Cj.planar_speeds == Matrix([u1_def, u2_def]) + assert Cj.kdes == Matrix([u0_def - q0_def.diff(t), u1_def - q1_def.diff(t), + u2_def - q2_def.diff(t)]) + assert Cj.rotation_axis == N.x + assert Cj.planar_vectors == [N.y, N.z] + assert Cj.child_point.pos_from(C.masscenter) == Vector(0) + assert Cj.parent_point.pos_from(P.masscenter) == Vector(0) + r_P_C = q1_def * N.y + q2_def * N.z + assert Cj.parent_point.pos_from(Cj.child_point) == -r_P_C + assert C.masscenter.pos_from(P.masscenter) == r_P_C + assert Cj.child_point.vel(N) == u1_def * N.y + u2_def * N.z + assert A.ang_vel_in(N) == u0_def * N.x + assert Cj.parent_interframe == N + assert Cj.child_interframe == A + assert Cj.__str__() == 'PlanarJoint: J parent: P child: C' + + q0, q1, q2, u0, u1, u2 = dynamicsymbols('q0:3, u0:3') + l, m = symbols('l, m') + N, A, P, C, Pint, Cint = _generate_body(True) + Cj = PlanarJoint('J', P, C, rotation_coordinate=q0, + planar_coordinates=[q1, q2], planar_speeds=[u1, u2], + parent_point=m * N.x, child_point=l * A.y, + parent_interframe=Pint, child_interframe=Cint) + assert Cj.coordinates == Matrix([q0, q1, q2]) + assert Cj.speeds == Matrix([u0_def, u1, u2]) + assert Cj.rotation_coordinate == q0 + assert Cj.planar_coordinates == Matrix([q1, q2]) + assert Cj.rotation_speed == u0_def + assert Cj.planar_speeds == Matrix([u1, u2]) + assert Cj.kdes == Matrix([u0_def - q0.diff(t), u1 - q1.diff(t), + u2 - q2.diff(t)]) + assert Cj.rotation_axis == Pint.x + assert Cj.planar_vectors == [Pint.y, Pint.z] + assert Cj.child_point.pos_from(C.masscenter) == l * A.y + assert Cj.parent_point.pos_from(P.masscenter) == m * N.x + assert Cj.parent_point.pos_from(Cj.child_point) == q1 * N.y + q2 * N.z + assert C.masscenter.pos_from( + P.masscenter) == m * N.x - q1 * N.y - q2 * N.z - l * A.y + assert C.masscenter.vel(N) == -u1 * N.y - u2 * N.z + u0_def * l * A.x + assert A.ang_vel_in(N) == u0_def * N.x + + +def test_planar_joint_advanced(): + # Tests whether someone is able to just specify two normals, which will form + # the rotation axis seen from the parent and child body. + # This specific example is a block on a slope, which has that same slope of + # 30 degrees, so in the zero configuration the frames of the parent and + # child are actually aligned. + q0, q1, q2, u0, u1, u2 = dynamicsymbols('q0:3, u0:3') + l1, l2 = symbols('l1:3') + N, A, P, C = _generate_body() + J = PlanarJoint('J', P, C, q0, [q1, q2], u0, [u1, u2], + parent_point=l1 * N.z, + child_point=-l2 * C.z, + parent_interframe=N.z + N.y / sqrt(3), + child_interframe=A.z + A.y / sqrt(3)) + assert J.rotation_axis.express(N) == (N.z + N.y / sqrt(3)).normalize() + assert J.rotation_axis.express(A) == (A.z + A.y / sqrt(3)).normalize() + assert J.rotation_axis.angle_between(N.z) == pi / 6 + assert N.dcm(A).xreplace({q0: 0, q1: 0, q2: 0}) == eye(3) + N_R_A = Matrix([ + [cos(q0), -sqrt(3) * sin(q0) / 2, sin(q0) / 2], + [sqrt(3) * sin(q0) / 2, 3 * cos(q0) / 4 + 1 / 4, + sqrt(3) * (1 - cos(q0)) / 4], + [-sin(q0) / 2, sqrt(3) * (1 - cos(q0)) / 4, cos(q0) / 4 + 3 / 4]]) + # N.dcm(A) == N_R_A did not work + assert simplify(N.dcm(A) - N_R_A) == zeros(3) + + +def test_spherical_joint(): + N, A, P, C = _generate_body() + q0, q1, q2, u0, u1, u2 = dynamicsymbols('q0:3_S, u0:3_S') + S = SphericalJoint('S', P, C) + assert S.name == 'S' + assert S.parent == P + assert S.child == C + assert S.coordinates == Matrix([q0, q1, q2]) + assert S.speeds == Matrix([u0, u1, u2]) + assert S.kdes == Matrix([u0 - q0.diff(t), u1 - q1.diff(t), u2 - q2.diff(t)]) + assert S.child_point.pos_from(C.masscenter) == Vector(0) + assert S.parent_point.pos_from(P.masscenter) == Vector(0) + assert S.parent_point.pos_from(S.child_point) == Vector(0) + assert P.masscenter.pos_from(C.masscenter) == Vector(0) + assert C.masscenter.vel(N) == Vector(0) + assert N.ang_vel_in(A) == (-u0 * cos(q1) * cos(q2) - u1 * sin(q2)) * A.x + ( + u0 * sin(q2) * cos(q1) - u1 * cos(q2)) * A.y + ( + -u0 * sin(q1) - u2) * A.z + assert A.ang_vel_in(N) == (u0 * cos(q1) * cos(q2) + u1 * sin(q2)) * A.x + ( + -u0 * sin(q2) * cos(q1) + u1 * cos(q2)) * A.y + ( + u0 * sin(q1) + u2) * A.z + assert S.__str__() == 'SphericalJoint: S parent: P child: C' + assert S._rot_type == 'BODY' + assert S._rot_order == 123 + assert S._amounts is None + + +def test_spherical_joint_speeds_as_derivative_terms(): + # This tests checks whether the system remains valid if the user chooses to + # pass the derivative of the generalized coordinates as generalized speeds + q0, q1, q2 = dynamicsymbols('q0:3') + u0, u1, u2 = dynamicsymbols('q0:3', 1) + N, A, P, C = _generate_body() + S = SphericalJoint('S', P, C, coordinates=[q0, q1, q2], speeds=[u0, u1, u2]) + assert S.coordinates == Matrix([q0, q1, q2]) + assert S.speeds == Matrix([u0, u1, u2]) + assert S.kdes == Matrix([0, 0, 0]) + assert N.ang_vel_in(A) == (-u0 * cos(q1) * cos(q2) - u1 * sin(q2)) * A.x + ( + u0 * sin(q2) * cos(q1) - u1 * cos(q2)) * A.y + ( + -u0 * sin(q1) - u2) * A.z + + +def test_spherical_joint_coords(): + q0s, q1s, q2s, u0s, u1s, u2s = dynamicsymbols('q0:3_S, u0:3_S') + q0, q1, q2, q3, u0, u1, u2, u4 = dynamicsymbols('q0:4, u0:4') + # Test coordinates as list + N, A, P, C = _generate_body() + S = SphericalJoint('S', P, C, [q0, q1, q2], [u0, u1, u2]) + assert S.coordinates == Matrix([q0, q1, q2]) + assert S.speeds == Matrix([u0, u1, u2]) + # Test coordinates as Matrix + N, A, P, C = _generate_body() + S = SphericalJoint('S', P, C, Matrix([q0, q1, q2]), + Matrix([u0, u1, u2])) + assert S.coordinates == Matrix([q0, q1, q2]) + assert S.speeds == Matrix([u0, u1, u2]) + # Test too few generalized coordinates + N, A, P, C = _generate_body() + raises(ValueError, + lambda: SphericalJoint('S', P, C, Matrix([q0, q1]), Matrix([u0]))) + # Test too many generalized coordinates + raises(ValueError, lambda: SphericalJoint( + 'S', P, C, Matrix([q0, q1, q2, q3]), Matrix([u0, u1, u2]))) + raises(ValueError, lambda: SphericalJoint( + 'S', P, C, Matrix([q0, q1, q2]), Matrix([u0, u1, u2, u4]))) + + +def test_spherical_joint_orient_body(): + q0, q1, q2, u0, u1, u2 = dynamicsymbols('q0:3, u0:3') + N_R_A = Matrix([ + [-sin(q1), -sin(q2) * cos(q1), cos(q1) * cos(q2)], + [-sin(q0) * cos(q1), sin(q0) * sin(q1) * sin(q2) - cos(q0) * cos(q2), + -sin(q0) * sin(q1) * cos(q2) - sin(q2) * cos(q0)], + [cos(q0) * cos(q1), -sin(q0) * cos(q2) - sin(q1) * sin(q2) * cos(q0), + -sin(q0) * sin(q2) + sin(q1) * cos(q0) * cos(q2)]]) + N_w_A = Matrix([[-u0 * sin(q1) - u2], + [-u0 * sin(q2) * cos(q1) + u1 * cos(q2)], + [u0 * cos(q1) * cos(q2) + u1 * sin(q2)]]) + N_v_Co = Matrix([ + [-sqrt(2) * (u0 * cos(q2 + pi / 4) * cos(q1) + u1 * sin(q2 + pi / 4))], + [-u0 * sin(q1) - u2], [-u0 * sin(q1) - u2]]) + # Test default rot_type='BODY', rot_order=123 + N, A, P, C, Pint, Cint = _generate_body(True) + S = SphericalJoint('S', P, C, coordinates=[q0, q1, q2], speeds=[u0, u1, u2], + parent_point=N.x + N.y, child_point=-A.y + A.z, + parent_interframe=Pint, child_interframe=Cint, + rot_type='body', rot_order=123) + assert S._rot_type.upper() == 'BODY' + assert S._rot_order == 123 + assert simplify(N.dcm(A) - N_R_A) == zeros(3) + assert simplify(A.ang_vel_in(N).to_matrix(A) - N_w_A) == zeros(3, 1) + assert simplify(C.masscenter.vel(N).to_matrix(A)) == N_v_Co + # Test change of amounts + N, A, P, C, Pint, Cint = _generate_body(True) + S = SphericalJoint('S', P, C, coordinates=[q0, q1, q2], speeds=[u0, u1, u2], + parent_point=N.x + N.y, child_point=-A.y + A.z, + parent_interframe=Pint, child_interframe=Cint, + rot_type='BODY', amounts=(q1, q0, q2), rot_order=123) + switch_order = lambda expr: expr.xreplace( + {q0: q1, q1: q0, q2: q2, u0: u1, u1: u0, u2: u2}) + assert S._rot_type.upper() == 'BODY' + assert S._rot_order == 123 + assert simplify(N.dcm(A) - switch_order(N_R_A)) == zeros(3) + assert simplify(A.ang_vel_in(N).to_matrix(A) - switch_order(N_w_A) + ) == zeros(3, 1) + assert simplify(C.masscenter.vel(N).to_matrix(A)) == switch_order(N_v_Co) + # Test different rot_order + N, A, P, C, Pint, Cint = _generate_body(True) + S = SphericalJoint('S', P, C, coordinates=[q0, q1, q2], speeds=[u0, u1, u2], + parent_point=N.x + N.y, child_point=-A.y + A.z, + parent_interframe=Pint, child_interframe=Cint, + rot_type='BodY', rot_order='yxz') + assert S._rot_type.upper() == 'BODY' + assert S._rot_order == 'yxz' + assert simplify(N.dcm(A) - Matrix([ + [-sin(q0) * cos(q1), sin(q0) * sin(q1) * cos(q2) - sin(q2) * cos(q0), + sin(q0) * sin(q1) * sin(q2) + cos(q0) * cos(q2)], + [-sin(q1), -cos(q1) * cos(q2), -sin(q2) * cos(q1)], + [cos(q0) * cos(q1), -sin(q0) * sin(q2) - sin(q1) * cos(q0) * cos(q2), + sin(q0) * cos(q2) - sin(q1) * sin(q2) * cos(q0)]])) == zeros(3) + assert simplify(A.ang_vel_in(N).to_matrix(A) - Matrix([ + [u0 * sin(q1) - u2], [u0 * cos(q1) * cos(q2) - u1 * sin(q2)], + [u0 * sin(q2) * cos(q1) + u1 * cos(q2)]])) == zeros(3, 1) + assert simplify(C.masscenter.vel(N).to_matrix(A)) == Matrix([ + [-sqrt(2) * (u0 * sin(q2 + pi / 4) * cos(q1) + u1 * cos(q2 + pi / 4))], + [u0 * sin(q1) - u2], [u0 * sin(q1) - u2]]) + + +def test_spherical_joint_orient_space(): + q0, q1, q2, u0, u1, u2 = dynamicsymbols('q0:3, u0:3') + N_R_A = Matrix([ + [-sin(q0) * sin(q2) - sin(q1) * cos(q0) * cos(q2), + sin(q0) * sin(q1) * cos(q2) - sin(q2) * cos(q0), cos(q1) * cos(q2)], + [-sin(q0) * cos(q2) + sin(q1) * sin(q2) * cos(q0), + -sin(q0) * sin(q1) * sin(q2) - cos(q0) * cos(q2), -sin(q2) * cos(q1)], + [cos(q0) * cos(q1), -sin(q0) * cos(q1), sin(q1)]]) + N_w_A = Matrix([ + [u1 * sin(q0) - u2 * cos(q0) * cos(q1)], + [u1 * cos(q0) + u2 * sin(q0) * cos(q1)], [u0 - u2 * sin(q1)]]) + N_v_Co = Matrix([ + [u0 - u2 * sin(q1)], [u0 - u2 * sin(q1)], + [sqrt(2) * (-u1 * sin(q0 + pi / 4) + u2 * cos(q0 + pi / 4) * cos(q1))]]) + # Test default rot_type='BODY', rot_order=123 + N, A, P, C, Pint, Cint = _generate_body(True) + S = SphericalJoint('S', P, C, coordinates=[q0, q1, q2], speeds=[u0, u1, u2], + parent_point=N.x + N.z, child_point=-A.x + A.y, + parent_interframe=Pint, child_interframe=Cint, + rot_type='space', rot_order=123) + assert S._rot_type.upper() == 'SPACE' + assert S._rot_order == 123 + assert simplify(N.dcm(A) - N_R_A) == zeros(3) + assert simplify(A.ang_vel_in(N).to_matrix(A)) == N_w_A + assert simplify(C.masscenter.vel(N).to_matrix(A)) == N_v_Co + # Test change of amounts + switch_order = lambda expr: expr.xreplace( + {q0: q1, q1: q0, q2: q2, u0: u1, u1: u0, u2: u2}) + N, A, P, C, Pint, Cint = _generate_body(True) + S = SphericalJoint('S', P, C, coordinates=[q0, q1, q2], speeds=[u0, u1, u2], + parent_point=N.x + N.z, child_point=-A.x + A.y, + parent_interframe=Pint, child_interframe=Cint, + rot_type='SPACE', amounts=(q1, q0, q2), rot_order=123) + assert S._rot_type.upper() == 'SPACE' + assert S._rot_order == 123 + assert simplify(N.dcm(A) - switch_order(N_R_A)) == zeros(3) + assert simplify(A.ang_vel_in(N).to_matrix(A)) == switch_order(N_w_A) + assert simplify(C.masscenter.vel(N).to_matrix(A)) == switch_order(N_v_Co) + # Test different rot_order + N, A, P, C, Pint, Cint = _generate_body(True) + S = SphericalJoint('S', P, C, coordinates=[q0, q1, q2], speeds=[u0, u1, u2], + parent_point=N.x + N.z, child_point=-A.x + A.y, + parent_interframe=Pint, child_interframe=Cint, + rot_type='SPaCe', rot_order='zxy') + assert S._rot_type.upper() == 'SPACE' + assert S._rot_order == 'zxy' + assert simplify(N.dcm(A) - Matrix([ + [-sin(q2) * cos(q1), -sin(q0) * cos(q2) + sin(q1) * sin(q2) * cos(q0), + sin(q0) * sin(q1) * sin(q2) + cos(q0) * cos(q2)], + [-sin(q1), -cos(q0) * cos(q1), -sin(q0) * cos(q1)], + [cos(q1) * cos(q2), -sin(q0) * sin(q2) - sin(q1) * cos(q0) * cos(q2), + -sin(q0) * sin(q1) * cos(q2) + sin(q2) * cos(q0)]])) + assert simplify(A.ang_vel_in(N).to_matrix(A) - Matrix([ + [-u0 + u2 * sin(q1)], [-u1 * sin(q0) + u2 * cos(q0) * cos(q1)], + [u1 * cos(q0) + u2 * sin(q0) * cos(q1)]])) == zeros(3, 1) + assert simplify(C.masscenter.vel(N).to_matrix(A) - Matrix([ + [u1 * cos(q0) + u2 * sin(q0) * cos(q1)], + [u1 * cos(q0) + u2 * sin(q0) * cos(q1)], + [u0 + u1 * sin(q0) - u2 * sin(q1) - + u2 * cos(q0) * cos(q1)]])) == zeros(3, 1) + + +def test_weld_joint(): + _, _, P, C = _generate_body() + W = WeldJoint('W', P, C) + assert W.name == 'W' + assert W.parent == P + assert W.child == C + assert W.coordinates == Matrix() + assert W.speeds == Matrix() + assert W.kdes == Matrix(1, 0, []).T + assert P.frame.dcm(C.frame) == eye(3) + assert W.child_point.pos_from(C.masscenter) == Vector(0) + assert W.parent_point.pos_from(P.masscenter) == Vector(0) + assert W.parent_point.pos_from(W.child_point) == Vector(0) + assert P.masscenter.pos_from(C.masscenter) == Vector(0) + assert C.masscenter.vel(P.frame) == Vector(0) + assert P.frame.ang_vel_in(C.frame) == 0 + assert C.frame.ang_vel_in(P.frame) == 0 + assert W.__str__() == 'WeldJoint: W parent: P child: C' + + N, A, P, C = _generate_body() + l, m = symbols('l m') + Pint = ReferenceFrame('P_int') + Pint.orient_axis(P.frame, P.y, pi / 2) + W = WeldJoint('W', P, C, parent_point=l * P.frame.x, + child_point=m * C.frame.y, parent_interframe=Pint) + + assert W.child_point.pos_from(C.masscenter) == m * C.frame.y + assert W.parent_point.pos_from(P.masscenter) == l * P.frame.x + assert W.parent_point.pos_from(W.child_point) == Vector(0) + assert P.masscenter.pos_from(C.masscenter) == - l * N.x + m * A.y + assert C.masscenter.vel(P.frame) == Vector(0) + assert P.masscenter.vel(Pint) == Vector(0) + assert C.frame.ang_vel_in(P.frame) == 0 + assert P.frame.ang_vel_in(C.frame) == 0 + assert P.x == A.z + + with warns_deprecated_sympy(): + JointsMethod(P, W) # Tests #10770 + + +def test_deprecated_parent_child_axis(): + q, u = dynamicsymbols('q_J, u_J') + N, A, P, C = _generate_body() + with warns_deprecated_sympy(): + PinJoint('J', P, C, child_axis=-A.x) + assert (-A.x).angle_between(N.x) == 0 + assert -A.x.express(N) == N.x + assert A.dcm(N) == Matrix([[-1, 0, 0], + [0, -cos(q), -sin(q)], + [0, -sin(q), cos(q)]]) + assert A.ang_vel_in(N) == u * N.x + assert A.ang_vel_in(N).magnitude() == sqrt(u ** 2) + + N, A, P, C = _generate_body() + with warns_deprecated_sympy(): + PrismaticJoint('J', P, C, parent_axis=P.x + P.y) + assert (A.x).angle_between(N.x + N.y) == 0 + assert A.x.express(N) == (N.x + N.y) / sqrt(2) + assert A.dcm(N) == Matrix([[sqrt(2) / 2, sqrt(2) / 2, 0], + [-sqrt(2) / 2, sqrt(2) / 2, 0], [0, 0, 1]]) + assert A.ang_vel_in(N) == Vector(0) + + +def test_deprecated_joint_pos(): + N, A, P, C = _generate_body() + with warns_deprecated_sympy(): + pin = PinJoint('J', P, C, parent_joint_pos=N.x + N.y, + child_joint_pos=C.y - C.z) + assert pin.parent_point.pos_from(P.masscenter) == N.x + N.y + assert pin.child_point.pos_from(C.masscenter) == C.y - C.z + + N, A, P, C = _generate_body() + with warns_deprecated_sympy(): + slider = PrismaticJoint('J', P, C, parent_joint_pos=N.z + N.y, + child_joint_pos=C.y - C.x) + assert slider.parent_point.pos_from(P.masscenter) == N.z + N.y + assert slider.child_point.pos_from(C.masscenter) == C.y - C.x diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_jointsmethod.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_jointsmethod.py new file mode 100644 index 0000000000000000000000000000000000000000..1b48eae06dadc627442fd4e42445450be0393e33 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_jointsmethod.py @@ -0,0 +1,249 @@ +from sympy.core.function import expand +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import Matrix +from sympy.simplify.trigsimp import trigsimp +from sympy.physics.mechanics import ( + PinJoint, JointsMethod, RigidBody, Particle, Body, KanesMethod, + PrismaticJoint, LagrangesMethod, inertia) +from sympy.physics.vector import dynamicsymbols, ReferenceFrame +from sympy.testing.pytest import raises, warns_deprecated_sympy +from sympy import zeros +from sympy.utilities.lambdify import lambdify +from sympy.solvers.solvers import solve + + +t = dynamicsymbols._t # type: ignore + + +def test_jointsmethod(): + with warns_deprecated_sympy(): + P = Body('P') + C = Body('C') + Pin = PinJoint('P1', P, C) + C_ixx, g = symbols('C_ixx g') + q, u = dynamicsymbols('q_P1, u_P1') + P.apply_force(g*P.y) + with warns_deprecated_sympy(): + method = JointsMethod(P, Pin) + assert method.frame == P.frame + assert method.bodies == [C, P] + assert method.loads == [(P.masscenter, g*P.frame.y)] + assert method.q == Matrix([q]) + assert method.u == Matrix([u]) + assert method.kdes == Matrix([u - q.diff()]) + soln = method.form_eoms() + assert soln == Matrix([[-C_ixx*u.diff()]]) + assert method.forcing_full == Matrix([[u], [0]]) + assert method.mass_matrix_full == Matrix([[1, 0], [0, C_ixx]]) + assert isinstance(method.method, KanesMethod) + + +def test_rigid_body_particle_compatibility(): + l, m, g = symbols('l m g') + C = RigidBody('C') + b = Particle('b', mass=m) + b_frame = ReferenceFrame('b_frame') + q, u = dynamicsymbols('q u') + P = PinJoint('P', C, b, coordinates=q, speeds=u, child_interframe=b_frame, + child_point=-l * b_frame.x, joint_axis=C.z) + with warns_deprecated_sympy(): + method = JointsMethod(C, P) + method.loads.append((b.masscenter, m * g * C.x)) + method.form_eoms() + rhs = method.rhs() + assert rhs[1] == -g*sin(q)/l + + +def test_jointmethod_duplicate_coordinates_speeds(): + with warns_deprecated_sympy(): + P = Body('P') + C = Body('C') + T = Body('T') + q, u = dynamicsymbols('q u') + P1 = PinJoint('P1', P, C, q) + P2 = PrismaticJoint('P2', C, T, q) + with warns_deprecated_sympy(): + raises(ValueError, lambda: JointsMethod(P, P1, P2)) + + P1 = PinJoint('P1', P, C, speeds=u) + P2 = PrismaticJoint('P2', C, T, speeds=u) + with warns_deprecated_sympy(): + raises(ValueError, lambda: JointsMethod(P, P1, P2)) + + P1 = PinJoint('P1', P, C, q, u) + P2 = PrismaticJoint('P2', C, T, q, u) + with warns_deprecated_sympy(): + raises(ValueError, lambda: JointsMethod(P, P1, P2)) + +def test_complete_simple_double_pendulum(): + q1, q2 = dynamicsymbols('q1 q2') + u1, u2 = dynamicsymbols('u1 u2') + m, l, g = symbols('m l g') + with warns_deprecated_sympy(): + C = Body('C') # ceiling + PartP = Body('P', mass=m) + PartR = Body('R', mass=m) + J1 = PinJoint('J1', C, PartP, speeds=u1, coordinates=q1, + child_point=-l*PartP.x, joint_axis=C.z) + J2 = PinJoint('J2', PartP, PartR, speeds=u2, coordinates=q2, + child_point=-l*PartR.x, joint_axis=PartP.z) + + PartP.apply_force(m*g*C.x) + PartR.apply_force(m*g*C.x) + + with warns_deprecated_sympy(): + method = JointsMethod(C, J1, J2) + method.form_eoms() + + assert expand(method.mass_matrix_full) == Matrix([[1, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 2*l**2*m*cos(q2) + 3*l**2*m, l**2*m*cos(q2) + l**2*m], + [0, 0, l**2*m*cos(q2) + l**2*m, l**2*m]]) + assert trigsimp(method.forcing_full) == trigsimp(Matrix([[u1], [u2], [-g*l*m*(sin(q1 + q2) + sin(q1)) - + g*l*m*sin(q1) + l**2*m*(2*u1 + u2)*u2*sin(q2)], + [-g*l*m*sin(q1 + q2) - l**2*m*u1**2*sin(q2)]])) + +def test_two_dof_joints(): + q1, q2, u1, u2 = dynamicsymbols('q1 q2 u1 u2') + m, c1, c2, k1, k2 = symbols('m c1 c2 k1 k2') + with warns_deprecated_sympy(): + W = Body('W') + B1 = Body('B1', mass=m) + B2 = Body('B2', mass=m) + J1 = PrismaticJoint('J1', W, B1, coordinates=q1, speeds=u1) + J2 = PrismaticJoint('J2', B1, B2, coordinates=q2, speeds=u2) + W.apply_force(k1*q1*W.x, reaction_body=B1) + W.apply_force(c1*u1*W.x, reaction_body=B1) + B1.apply_force(k2*q2*W.x, reaction_body=B2) + B1.apply_force(c2*u2*W.x, reaction_body=B2) + with warns_deprecated_sympy(): + method = JointsMethod(W, J1, J2) + method.form_eoms() + MM = method.mass_matrix + forcing = method.forcing + rhs = MM.LUsolve(forcing) + assert expand(rhs[0]) == expand((-k1 * q1 - c1 * u1 + k2 * q2 + c2 * u2)/m) + assert expand(rhs[1]) == expand((k1 * q1 + c1 * u1 - 2 * k2 * q2 - 2 * + c2 * u2) / m) + +def test_simple_pedulum(): + l, m, g = symbols('l m g') + with warns_deprecated_sympy(): + C = Body('C') + b = Body('b', mass=m) + q = dynamicsymbols('q') + P = PinJoint('P', C, b, speeds=q.diff(t), coordinates=q, + child_point=-l * b.x, joint_axis=C.z) + b.potential_energy = - m * g * l * cos(q) + with warns_deprecated_sympy(): + method = JointsMethod(C, P) + method.form_eoms(LagrangesMethod) + rhs = method.rhs() + assert rhs[1] == -g*sin(q)/l + +def test_chaos_pendulum(): + #https://www.pydy.org/examples/chaos_pendulum.html + mA, mB, lA, lB, IAxx, IBxx, IByy, IBzz, g = symbols('mA, mB, lA, lB, IAxx, IBxx, IByy, IBzz, g') + theta, phi, omega, alpha = dynamicsymbols('theta phi omega alpha') + + A = ReferenceFrame('A') + B = ReferenceFrame('B') + + with warns_deprecated_sympy(): + rod = Body('rod', mass=mA, frame=A, + central_inertia=inertia(A, IAxx, IAxx, 0)) + plate = Body('plate', mass=mB, frame=B, + central_inertia=inertia(B, IBxx, IByy, IBzz)) + C = Body('C') + J1 = PinJoint('J1', C, rod, coordinates=theta, speeds=omega, + child_point=-lA * rod.z, joint_axis=C.y) + J2 = PinJoint('J2', rod, plate, coordinates=phi, speeds=alpha, + parent_point=(lB - lA) * rod.z, joint_axis=rod.z) + + rod.apply_force(mA*g*C.z) + plate.apply_force(mB*g*C.z) + + with warns_deprecated_sympy(): + method = JointsMethod(C, J1, J2) + method.form_eoms() + + MM = method.mass_matrix + forcing = method.forcing + rhs = MM.LUsolve(forcing) + xd = (-2 * IBxx * alpha * omega * sin(phi) * cos(phi) + 2 * IByy * alpha * omega * sin(phi) * + cos(phi) - g * lA * mA * sin(theta) - g * lB * mB * sin(theta)) / (IAxx + IBxx * + sin(phi)**2 + IByy * cos(phi)**2 + lA**2 * mA + lB**2 * mB) + assert (rhs[0] - xd).simplify() == 0 + xd = (IBxx - IByy) * omega**2 * sin(phi) * cos(phi) / IBzz + assert (rhs[1] - xd).simplify() == 0 + +def test_four_bar_linkage_with_manual_constraints(): + q1, q2, q3, u1, u2, u3 = dynamicsymbols('q1:4, u1:4') + l1, l2, l3, l4, rho = symbols('l1:5, rho') + + N = ReferenceFrame('N') + inertias = [inertia(N, 0, 0, rho * l ** 3 / 12) for l in (l1, l2, l3, l4)] + with warns_deprecated_sympy(): + link1 = Body('Link1', frame=N, mass=rho * l1, + central_inertia=inertias[0]) + link2 = Body('Link2', mass=rho * l2, central_inertia=inertias[1]) + link3 = Body('Link3', mass=rho * l3, central_inertia=inertias[2]) + link4 = Body('Link4', mass=rho * l4, central_inertia=inertias[3]) + + joint1 = PinJoint( + 'J1', link1, link2, coordinates=q1, speeds=u1, joint_axis=link1.z, + parent_point=l1 / 2 * link1.x, child_point=-l2 / 2 * link2.x) + joint2 = PinJoint( + 'J2', link2, link3, coordinates=q2, speeds=u2, joint_axis=link2.z, + parent_point=l2 / 2 * link2.x, child_point=-l3 / 2 * link3.x) + joint3 = PinJoint( + 'J3', link3, link4, coordinates=q3, speeds=u3, joint_axis=link3.z, + parent_point=l3 / 2 * link3.x, child_point=-l4 / 2 * link4.x) + + loop = link4.masscenter.pos_from(link1.masscenter) \ + + l1 / 2 * link1.x + l4 / 2 * link4.x + + fh = Matrix([loop.dot(link1.x), loop.dot(link1.y)]) + + with warns_deprecated_sympy(): + method = JointsMethod(link1, joint1, joint2, joint3) + + t = dynamicsymbols._t + qdots = solve(method.kdes, [q1.diff(t), q2.diff(t), q3.diff(t)]) + fhd = fh.diff(t).subs(qdots) + + kane = KanesMethod(method.frame, q_ind=[q1], u_ind=[u1], + q_dependent=[q2, q3], u_dependent=[u2, u3], + kd_eqs=method.kdes, configuration_constraints=fh, + velocity_constraints=fhd, forcelist=method.loads, + bodies=method.bodies) + fr, frs = kane.kanes_equations() + assert fr == zeros(1) + + # Numerically check the mass- and forcing-matrix + p = Matrix([l1, l2, l3, l4, rho]) + q = Matrix([q1, q2, q3]) + u = Matrix([u1, u2, u3]) + eval_m = lambdify((q, p), kane.mass_matrix) + eval_f = lambdify((q, u, p), kane.forcing) + eval_fhd = lambdify((q, u, p), fhd) + + p_vals = [0.13, 0.24, 0.21, 0.34, 997] + q_vals = [2.1, 0.6655470375077588, 2.527408138024188] # Satisfies fh + u_vals = [0.2, -0.17963733938852067, 0.1309060540601612] # Satisfies fhd + mass_check = Matrix([[3.452709815256506e+01, 7.003948798374735e+00, + -4.939690970641498e+00], + [-2.203792703880936e-14, 2.071702479957077e-01, + 2.842917573033711e-01], + [-1.300000000000123e-01, -8.836934896046506e-03, + 1.864891330060847e-01]]) + forcing_check = Matrix([[-0.031211821321648], + [-0.00066022608181], + [0.001813559741243]]) + eps = 1e-10 + assert all(abs(x) < eps for x in eval_fhd(q_vals, u_vals, p_vals)) + assert all(abs(x) < eps for x in + (Matrix(eval_m(q_vals, p_vals)) - mass_check)) + assert all(abs(x) < eps for x in + (Matrix(eval_f(q_vals, u_vals, p_vals)) - forcing_check)) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane.py new file mode 100644 index 0000000000000000000000000000000000000000..5f9310aae6d720c32615a86df5e488f46a513c76 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane.py @@ -0,0 +1,553 @@ +from sympy import solve +from sympy import (cos, expand, Matrix, sin, symbols, tan, sqrt, S, + zeros, eye) +from sympy.simplify.simplify import simplify +from sympy.physics.mechanics import (dynamicsymbols, ReferenceFrame, Point, + RigidBody, KanesMethod, inertia, Particle, + dot, find_dynamicsymbols) +from sympy.testing.pytest import raises + + +def test_invalid_coordinates(): + # Simple pendulum, but use symbols instead of dynamicsymbols + l, m, g = symbols('l m g') + q, u = symbols('q u') # Generalized coordinate + kd = [q.diff(dynamicsymbols._t) - u] + N, O = ReferenceFrame('N'), Point('O') + O.set_vel(N, 0) + P = Particle('P', Point('P'), m) + P.point.set_pos(O, l * (sin(q) * N.x - cos(q) * N.y)) + F = (P.point, -m * g * N.y) + raises(ValueError, lambda: KanesMethod(N, [q], [u], kd, bodies=[P], + forcelist=[F])) + + +def test_one_dof(): + # This is for a 1 dof spring-mass-damper case. + # It is described in more detail in the KanesMethod docstring. + q, u = dynamicsymbols('q u') + qd, ud = dynamicsymbols('q u', 1) + m, c, k = symbols('m c k') + N = ReferenceFrame('N') + P = Point('P') + P.set_vel(N, u * N.x) + + kd = [qd - u] + FL = [(P, (-k * q - c * u) * N.x)] + pa = Particle('pa', P, m) + BL = [pa] + + KM = KanesMethod(N, [q], [u], kd) + KM.kanes_equations(BL, FL) + + assert KM.bodies == BL + assert KM.loads == FL + + MM = KM.mass_matrix + forcing = KM.forcing + rhs = MM.inv() * forcing + assert expand(rhs[0]) == expand(-(q * k + u * c) / m) + + assert simplify(KM.rhs() - + KM.mass_matrix_full.LUsolve(KM.forcing_full)) == zeros(2, 1) + + assert (KM.linearize(A_and_B=True, )[0] == Matrix([[0, 1], [-k/m, -c/m]])) + + +def test_two_dof(): + # This is for a 2 d.o.f., 2 particle spring-mass-damper. + # The first coordinate is the displacement of the first particle, and the + # second is the relative displacement between the first and second + # particles. Speeds are defined as the time derivatives of the particles. + q1, q2, u1, u2 = dynamicsymbols('q1 q2 u1 u2') + q1d, q2d, u1d, u2d = dynamicsymbols('q1 q2 u1 u2', 1) + m, c1, c2, k1, k2 = symbols('m c1 c2 k1 k2') + N = ReferenceFrame('N') + P1 = Point('P1') + P2 = Point('P2') + P1.set_vel(N, u1 * N.x) + P2.set_vel(N, (u1 + u2) * N.x) + # Note we multiply the kinematic equation by an arbitrary factor + # to test the implicit vs explicit kinematics attribute + kd = [q1d/2 - u1/2, 2*q2d - 2*u2] + + # Now we create the list of forces, then assign properties to each + # particle, then create a list of all particles. + FL = [(P1, (-k1 * q1 - c1 * u1 + k2 * q2 + c2 * u2) * N.x), (P2, (-k2 * + q2 - c2 * u2) * N.x)] + pa1 = Particle('pa1', P1, m) + pa2 = Particle('pa2', P2, m) + BL = [pa1, pa2] + + # Finally we create the KanesMethod object, specify the inertial frame, + # pass relevant information, and form Fr & Fr*. Then we calculate the mass + # matrix and forcing terms, and finally solve for the udots. + KM = KanesMethod(N, q_ind=[q1, q2], u_ind=[u1, u2], kd_eqs=kd) + KM.kanes_equations(BL, FL) + MM = KM.mass_matrix + forcing = KM.forcing + rhs = MM.inv() * forcing + assert expand(rhs[0]) == expand((-k1 * q1 - c1 * u1 + k2 * q2 + c2 * u2)/m) + assert expand(rhs[1]) == expand((k1 * q1 + c1 * u1 - 2 * k2 * q2 - 2 * + c2 * u2) / m) + + # Check that the explicit form is the default and kinematic mass matrix is identity + assert KM.explicit_kinematics + assert KM.mass_matrix_kin == eye(2) + + # Check that for the implicit form the mass matrix is not identity + KM.explicit_kinematics = False + assert KM.mass_matrix_kin == Matrix([[S(1)/2, 0], [0, 2]]) + + # Check that whether using implicit or explicit kinematics the RHS + # equations are consistent with the matrix form + for explicit_kinematics in [False, True]: + KM.explicit_kinematics = explicit_kinematics + assert simplify(KM.rhs() - + KM.mass_matrix_full.LUsolve(KM.forcing_full)) == zeros(4, 1) + + # Make sure an error is raised if nonlinear kinematic differential + # equations are supplied. + kd = [q1d - u1**2, sin(q2d) - cos(u2)] + raises(ValueError, lambda: KanesMethod(N, q_ind=[q1, q2], + u_ind=[u1, u2], kd_eqs=kd)) + +def test_pend(): + q, u = dynamicsymbols('q u') + qd, ud = dynamicsymbols('q u', 1) + m, l, g = symbols('m l g') + N = ReferenceFrame('N') + P = Point('P') + P.set_vel(N, -l * u * sin(q) * N.x + l * u * cos(q) * N.y) + kd = [qd - u] + + FL = [(P, m * g * N.x)] + pa = Particle('pa', P, m) + BL = [pa] + + KM = KanesMethod(N, [q], [u], kd) + KM.kanes_equations(BL, FL) + MM = KM.mass_matrix + forcing = KM.forcing + rhs = MM.inv() * forcing + rhs.simplify() + assert expand(rhs[0]) == expand(-g / l * sin(q)) + assert simplify(KM.rhs() - + KM.mass_matrix_full.LUsolve(KM.forcing_full)) == zeros(2, 1) + + +def test_rolling_disc(): + # Rolling Disc Example + # Here the rolling disc is formed from the contact point up, removing the + # need to introduce generalized speeds. Only 3 configuration and three + # speed variables are need to describe this system, along with the disc's + # mass and radius, and the local gravity (note that mass will drop out). + q1, q2, q3, u1, u2, u3 = dynamicsymbols('q1 q2 q3 u1 u2 u3') + q1d, q2d, q3d, u1d, u2d, u3d = dynamicsymbols('q1 q2 q3 u1 u2 u3', 1) + r, m, g = symbols('r m g') + + # The kinematics are formed by a series of simple rotations. Each simple + # rotation creates a new frame, and the next rotation is defined by the new + # frame's basis vectors. This example uses a 3-1-2 series of rotations, or + # Z, X, Y series of rotations. Angular velocity for this is defined using + # the second frame's basis (the lean frame). + N = ReferenceFrame('N') + Y = N.orientnew('Y', 'Axis', [q1, N.z]) + L = Y.orientnew('L', 'Axis', [q2, Y.x]) + R = L.orientnew('R', 'Axis', [q3, L.y]) + w_R_N_qd = R.ang_vel_in(N) + R.set_ang_vel(N, u1 * L.x + u2 * L.y + u3 * L.z) + + # This is the translational kinematics. We create a point with no velocity + # in N; this is the contact point between the disc and ground. Next we form + # the position vector from the contact point to the disc's center of mass. + # Finally we form the velocity and acceleration of the disc. + C = Point('C') + C.set_vel(N, 0) + Dmc = C.locatenew('Dmc', r * L.z) + Dmc.v2pt_theory(C, N, R) + + # This is a simple way to form the inertia dyadic. + I = inertia(L, m / 4 * r**2, m / 2 * r**2, m / 4 * r**2) + + # Kinematic differential equations; how the generalized coordinate time + # derivatives relate to generalized speeds. + kd = [dot(R.ang_vel_in(N) - w_R_N_qd, uv) for uv in L] + + # Creation of the force list; it is the gravitational force at the mass + # center of the disc. Then we create the disc by assigning a Point to the + # center of mass attribute, a ReferenceFrame to the frame attribute, and mass + # and inertia. Then we form the body list. + ForceList = [(Dmc, - m * g * Y.z)] + BodyD = RigidBody('BodyD', Dmc, R, m, (I, Dmc)) + BodyList = [BodyD] + + # Finally we form the equations of motion, using the same steps we did + # before. Specify inertial frame, supply generalized speeds, supply + # kinematic differential equation dictionary, compute Fr from the force + # list and Fr* from the body list, compute the mass matrix and forcing + # terms, then solve for the u dots (time derivatives of the generalized + # speeds). + KM = KanesMethod(N, q_ind=[q1, q2, q3], u_ind=[u1, u2, u3], kd_eqs=kd) + KM.kanes_equations(BodyList, ForceList) + MM = KM.mass_matrix + forcing = KM.forcing + rhs = MM.inv() * forcing + kdd = KM.kindiffdict() + rhs = rhs.subs(kdd) + rhs.simplify() + assert rhs.expand() == Matrix([(6*u2*u3*r - u3**2*r*tan(q2) + + 4*g*sin(q2))/(5*r), -2*u1*u3/3, u1*(-2*u2 + u3*tan(q2))]).expand() + assert simplify(KM.rhs() - + KM.mass_matrix_full.LUsolve(KM.forcing_full)) == zeros(6, 1) + + # This code tests our output vs. benchmark values. When r=g=m=1, the + # critical speed (where all eigenvalues of the linearized equations are 0) + # is 1 / sqrt(3) for the upright case. + A = KM.linearize(A_and_B=True)[0] + A_upright = A.subs({r: 1, g: 1, m: 1}).subs({q1: 0, q2: 0, q3: 0, u1: 0, u3: 0}) + import sympy + assert sympy.sympify(A_upright.subs({u2: 1 / sqrt(3)})).eigenvals() == {S.Zero: 6} + + +def test_aux(): + # Same as above, except we have 2 auxiliary speeds for the ground contact + # point, which is known to be zero. In one case, we go through then + # substitute the aux. speeds in at the end (they are zero, as well as their + # derivative), in the other case, we use the built-in auxiliary speed part + # of KanesMethod. The equations from each should be the same. + q1, q2, q3, u1, u2, u3 = dynamicsymbols('q1 q2 q3 u1 u2 u3') + q1d, q2d, q3d, u1d, u2d, u3d = dynamicsymbols('q1 q2 q3 u1 u2 u3', 1) + u4, u5, f1, f2 = dynamicsymbols('u4, u5, f1, f2') + u4d, u5d = dynamicsymbols('u4, u5', 1) + r, m, g = symbols('r m g') + + N = ReferenceFrame('N') + Y = N.orientnew('Y', 'Axis', [q1, N.z]) + L = Y.orientnew('L', 'Axis', [q2, Y.x]) + R = L.orientnew('R', 'Axis', [q3, L.y]) + w_R_N_qd = R.ang_vel_in(N) + R.set_ang_vel(N, u1 * L.x + u2 * L.y + u3 * L.z) + + C = Point('C') + C.set_vel(N, u4 * L.x + u5 * (Y.z ^ L.x)) + Dmc = C.locatenew('Dmc', r * L.z) + Dmc.v2pt_theory(C, N, R) + Dmc.a2pt_theory(C, N, R) + + I = inertia(L, m / 4 * r**2, m / 2 * r**2, m / 4 * r**2) + + kd = [dot(R.ang_vel_in(N) - w_R_N_qd, uv) for uv in L] + + ForceList = [(Dmc, - m * g * Y.z), (C, f1 * L.x + f2 * (Y.z ^ L.x))] + BodyD = RigidBody('BodyD', Dmc, R, m, (I, Dmc)) + BodyList = [BodyD] + + KM = KanesMethod(N, q_ind=[q1, q2, q3], u_ind=[u1, u2, u3, u4, u5], + kd_eqs=kd) + (fr, frstar) = KM.kanes_equations(BodyList, ForceList) + fr = fr.subs({u4d: 0, u5d: 0}).subs({u4: 0, u5: 0}) + frstar = frstar.subs({u4d: 0, u5d: 0}).subs({u4: 0, u5: 0}) + + KM2 = KanesMethod(N, q_ind=[q1, q2, q3], u_ind=[u1, u2, u3], kd_eqs=kd, + u_auxiliary=[u4, u5]) + (fr2, frstar2) = KM2.kanes_equations(BodyList, ForceList) + fr2 = fr2.subs({u4d: 0, u5d: 0}).subs({u4: 0, u5: 0}) + frstar2 = frstar2.subs({u4d: 0, u5d: 0}).subs({u4: 0, u5: 0}) + + frstar.simplify() + frstar2.simplify() + + assert (fr - fr2).expand() == Matrix([0, 0, 0, 0, 0]) + assert (frstar - frstar2).expand() == Matrix([0, 0, 0, 0, 0]) + + +def test_parallel_axis(): + # This is for a 2 dof inverted pendulum on a cart. + # This tests the parallel axis code in KanesMethod. The inertia of the + # pendulum is defined about the hinge, not about the center of mass. + + # Defining the constants and knowns of the system + gravity = symbols('g') + k, ls = symbols('k ls') + a, mA, mC = symbols('a mA mC') + F = dynamicsymbols('F') + Ix, Iy, Iz = symbols('Ix Iy Iz') + + # Declaring the Generalized coordinates and speeds + q1, q2 = dynamicsymbols('q1 q2') + q1d, q2d = dynamicsymbols('q1 q2', 1) + u1, u2 = dynamicsymbols('u1 u2') + u1d, u2d = dynamicsymbols('u1 u2', 1) + + # Creating reference frames + N = ReferenceFrame('N') + A = ReferenceFrame('A') + + A.orient(N, 'Axis', [-q2, N.z]) + A.set_ang_vel(N, -u2 * N.z) + + # Origin of Newtonian reference frame + O = Point('O') + + # Creating and Locating the positions of the cart, C, and the + # center of mass of the pendulum, A + C = O.locatenew('C', q1 * N.x) + Ao = C.locatenew('Ao', a * A.y) + + # Defining velocities of the points + O.set_vel(N, 0) + C.set_vel(N, u1 * N.x) + Ao.v2pt_theory(C, N, A) + Cart = Particle('Cart', C, mC) + Pendulum = RigidBody('Pendulum', Ao, A, mA, (inertia(A, Ix, Iy, Iz), C)) + + # kinematical differential equations + + kindiffs = [q1d - u1, q2d - u2] + + bodyList = [Cart, Pendulum] + + forceList = [(Ao, -N.y * gravity * mA), + (C, -N.y * gravity * mC), + (C, -N.x * k * (q1 - ls)), + (C, N.x * F)] + + km = KanesMethod(N, [q1, q2], [u1, u2], kindiffs) + (fr, frstar) = km.kanes_equations(bodyList, forceList) + mm = km.mass_matrix_full + assert mm[3, 3] == Iz + +def test_input_format(): + # 1 dof problem from test_one_dof + q, u = dynamicsymbols('q u') + qd, ud = dynamicsymbols('q u', 1) + m, c, k = symbols('m c k') + N = ReferenceFrame('N') + P = Point('P') + P.set_vel(N, u * N.x) + + kd = [qd - u] + FL = [(P, (-k * q - c * u) * N.x)] + pa = Particle('pa', P, m) + BL = [pa] + + KM = KanesMethod(N, [q], [u], kd) + # test for input format kane.kanes_equations((body1, body2, particle1)) + assert KM.kanes_equations(BL)[0] == Matrix([0]) + # test for input format kane.kanes_equations(bodies=(body1, body 2), loads=(load1,load2)) + assert KM.kanes_equations(bodies=BL, loads=None)[0] == Matrix([0]) + # test for input format kane.kanes_equations(bodies=(body1, body 2), loads=None) + assert KM.kanes_equations(BL, loads=None)[0] == Matrix([0]) + # test for input format kane.kanes_equations(bodies=(body1, body 2)) + assert KM.kanes_equations(BL)[0] == Matrix([0]) + # test for input format kane.kanes_equations(bodies=(body1, body2), loads=[]) + assert KM.kanes_equations(BL, [])[0] == Matrix([0]) + # test for error raised when a wrong force list (in this case a string) is provided + raises(ValueError, lambda: KM._form_fr('bad input')) + + # 1 dof problem from test_one_dof with FL & BL in instance + KM = KanesMethod(N, [q], [u], kd, bodies=BL, forcelist=FL) + assert KM.kanes_equations()[0] == Matrix([-c*u - k*q]) + + # 2 dof problem from test_two_dof + q1, q2, u1, u2 = dynamicsymbols('q1 q2 u1 u2') + q1d, q2d, u1d, u2d = dynamicsymbols('q1 q2 u1 u2', 1) + m, c1, c2, k1, k2 = symbols('m c1 c2 k1 k2') + N = ReferenceFrame('N') + P1 = Point('P1') + P2 = Point('P2') + P1.set_vel(N, u1 * N.x) + P2.set_vel(N, (u1 + u2) * N.x) + kd = [q1d - u1, q2d - u2] + + FL = ((P1, (-k1 * q1 - c1 * u1 + k2 * q2 + c2 * u2) * N.x), (P2, (-k2 * + q2 - c2 * u2) * N.x)) + pa1 = Particle('pa1', P1, m) + pa2 = Particle('pa2', P2, m) + BL = (pa1, pa2) + + KM = KanesMethod(N, q_ind=[q1, q2], u_ind=[u1, u2], kd_eqs=kd) + # test for input format + # kane.kanes_equations((body1, body2), (load1, load2)) + KM.kanes_equations(BL, FL) + MM = KM.mass_matrix + forcing = KM.forcing + rhs = MM.inv() * forcing + assert expand(rhs[0]) == expand((-k1 * q1 - c1 * u1 + k2 * q2 + c2 * u2)/m) + assert expand(rhs[1]) == expand((k1 * q1 + c1 * u1 - 2 * k2 * q2 - 2 * + c2 * u2) / m) + + +def test_implicit_kinematics(): + # Test that implicit kinematics can handle complicated + # equations that explicit form struggles with + # See https://github.com/sympy/sympy/issues/22626 + + # Inertial frame + NED = ReferenceFrame('NED') + NED_o = Point('NED_o') + NED_o.set_vel(NED, 0) + + # body frame + q_att = dynamicsymbols('lambda_0:4', real=True) + B = NED.orientnew('B', 'Quaternion', q_att) + + # Generalized coordinates + q_pos = dynamicsymbols('B_x:z') + B_cm = NED_o.locatenew('B_cm', q_pos[0]*B.x + q_pos[1]*B.y + q_pos[2]*B.z) + + q_ind = q_att[1:] + q_pos + q_dep = [q_att[0]] + + kinematic_eqs = [] + + # Generalized velocities + B_ang_vel = B.ang_vel_in(NED) + P, Q, R = dynamicsymbols('P Q R') + B.set_ang_vel(NED, P*B.x + Q*B.y + R*B.z) + + B_ang_vel_kd = (B.ang_vel_in(NED) - B_ang_vel).simplify() + + # Equating the two gives us the kinematic equation + kinematic_eqs += [ + B_ang_vel_kd & B.x, + B_ang_vel_kd & B.y, + B_ang_vel_kd & B.z + ] + + B_cm_vel = B_cm.vel(NED) + U, V, W = dynamicsymbols('U V W') + B_cm.set_vel(NED, U*B.x + V*B.y + W*B.z) + + # Compute the velocity of the point using the two methods + B_ref_vel_kd = (B_cm.vel(NED) - B_cm_vel) + + # taking dot product with unit vectors to get kinematic equations + # relating body coordinates and velocities + + # Note, there is a choice to dot with NED.xyz here. That makes + # the implicit form have some bigger terms but is still fine, the + # explicit form still struggles though + kinematic_eqs += [ + B_ref_vel_kd & B.x, + B_ref_vel_kd & B.y, + B_ref_vel_kd & B.z, + ] + + u_ind = [U, V, W, P, Q, R] + + # constraints + q_att_vec = Matrix(q_att) + config_cons = [(q_att_vec.T*q_att_vec)[0] - 1] #unit norm + kinematic_eqs = kinematic_eqs + [(q_att_vec.T * q_att_vec.diff())[0]] + + try: + KM = KanesMethod(NED, q_ind, u_ind, + q_dependent= q_dep, + kd_eqs = kinematic_eqs, + configuration_constraints = config_cons, + velocity_constraints= [], + u_dependent= [], #no dependent speeds + u_auxiliary = [], # No auxiliary speeds + explicit_kinematics = False # implicit kinematics + ) + except Exception as e: + raise e + + # mass and inertia dyadic relative to CM + M_B = symbols('M_B') + J_B = inertia(B, *[S(f'J_B_{ax}')*(1 if ax[0] == ax[1] else -1) + for ax in ['xx', 'yy', 'zz', 'xy', 'yz', 'xz']]) + J_B = J_B.subs({S('J_B_xy'): 0, S('J_B_yz'): 0}) + RB = RigidBody('RB', B_cm, B, M_B, (J_B, B_cm)) + + rigid_bodies = [RB] + # Forces + force_list = [ + #gravity pointing down + (RB.masscenter, RB.mass*S('g')*NED.z), + #generic forces and torques in body frame(inputs) + (RB.frame, dynamicsymbols('T_z')*B.z), + (RB.masscenter, dynamicsymbols('F_z')*B.z) + ] + + KM.kanes_equations(rigid_bodies, force_list) + + # Expecting implicit form to be less than 5% of the flops + n_ops_implicit = sum( + [x.count_ops() for x in KM.forcing_full] + + [x.count_ops() for x in KM.mass_matrix_full] + ) + # Save implicit kinematic matrices to use later + mass_matrix_kin_implicit = KM.mass_matrix_kin + forcing_kin_implicit = KM.forcing_kin + + KM.explicit_kinematics = True + n_ops_explicit = sum( + [x.count_ops() for x in KM.forcing_full] + + [x.count_ops() for x in KM.mass_matrix_full] + ) + forcing_kin_explicit = KM.forcing_kin + + assert n_ops_implicit / n_ops_explicit < .05 + + # Ideally we would check that implicit and explicit equations give the same result as done in test_one_dof + # But the whole raison-d'etre of the implicit equations is to deal with problems such + # as this one where the explicit form is too complicated to handle, especially the angular part + # (i.e. tests would be too slow) + # Instead, we check that the kinematic equations are correct using more fundamental tests: + # + # (1) that we recover the kinematic equations we have provided + assert (mass_matrix_kin_implicit * KM.q.diff() - forcing_kin_implicit) == Matrix(kinematic_eqs) + + # (2) that rate of quaternions matches what 'textbook' solutions give + # Note that we just use the explicit kinematics for the linear velocities + # as they are not as complicated as the angular ones + qdot_candidate = forcing_kin_explicit + + quat_dot_textbook = Matrix([ + [0, -P, -Q, -R], + [P, 0, R, -Q], + [Q, -R, 0, P], + [R, Q, -P, 0], + ]) * q_att_vec / 2 + + # Again, if we don't use this "textbook" solution + # sympy will struggle to deal with the terms related to quaternion rates + # due to the number of operations involved + qdot_candidate[-1] = quat_dot_textbook[0] # lambda_0, note the [-1] as sympy's Kane puts the dependent coordinate last + qdot_candidate[0] = quat_dot_textbook[1] # lambda_1 + qdot_candidate[1] = quat_dot_textbook[2] # lambda_2 + qdot_candidate[2] = quat_dot_textbook[3] # lambda_3 + + # sub the config constraint in the candidate solution and compare to the implicit rhs + lambda_0_sol = solve(config_cons[0], q_att_vec[0])[1] + lhs_candidate = simplify(mass_matrix_kin_implicit * qdot_candidate).subs({q_att_vec[0]: lambda_0_sol}) + assert lhs_candidate == forcing_kin_implicit + +def test_issue_24887(): + # Spherical pendulum + g, l, m, c = symbols('g l m c') + q1, q2, q3, u1, u2, u3 = dynamicsymbols('q1:4 u1:4') + N = ReferenceFrame('N') + A = ReferenceFrame('A') + A.orient_body_fixed(N, (q1, q2, q3), 'zxy') + N_w_A = A.ang_vel_in(N) + # A.set_ang_vel(N, u1 * A.x + u2 * A.y + u3 * A.z) + kdes = [N_w_A.dot(A.x) - u1, N_w_A.dot(A.y) - u2, N_w_A.dot(A.z) - u3] + O = Point('O') + O.set_vel(N, 0) + Po = O.locatenew('Po', -l * A.y) + Po.set_vel(A, 0) + P = Particle('P', Po, m) + kane = KanesMethod(N, [q1, q2, q3], [u1, u2, u3], kdes, bodies=[P], + forcelist=[(Po, -m * g * N.y)]) + kane.kanes_equations() + expected_md = m * l ** 2 * Matrix([[1, 0, 0], [0, 0, 0], [0, 0, 1]]) + expected_fd = Matrix([ + [l*m*(g*(sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3)) - l*u2*u3)], + [0], [l*m*(-g*(sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1)) + l*u1*u2)]]) + assert find_dynamicsymbols(kane.forcing).issubset({q1, q2, q3, u1, u2, u3}) + assert simplify(kane.mass_matrix - expected_md) == zeros(3, 3) + assert simplify(kane.forcing - expected_fd) == zeros(3, 1) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane2.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane2.py new file mode 100644 index 0000000000000000000000000000000000000000..e55866672aec0adcd951e772964a4ed205b56405 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane2.py @@ -0,0 +1,464 @@ +from sympy import cos, Matrix, sin, zeros, tan, pi, symbols +from sympy.simplify.simplify import simplify +from sympy.simplify.trigsimp import trigsimp +from sympy.solvers.solvers import solve +from sympy.physics.mechanics import (cross, dot, dynamicsymbols, + find_dynamicsymbols, KanesMethod, inertia, + inertia_of_point_mass, Point, + ReferenceFrame, RigidBody) + + +def test_aux_dep(): + # This test is about rolling disc dynamics, comparing the results found + # with KanesMethod to those found when deriving the equations "manually" + # with SymPy. + # The terms Fr, Fr*, and Fr*_steady are all compared between the two + # methods. Here, Fr*_steady refers to the generalized inertia forces for an + # equilibrium configuration. + # Note: comparing to the test of test_rolling_disc() in test_kane.py, this + # test also tests auxiliary speeds and configuration and motion constraints + #, seen in the generalized dependent coordinates q[3], and depend speeds + # u[3], u[4] and u[5]. + + + # First, manual derivation of Fr, Fr_star, Fr_star_steady. + + # Symbols for time and constant parameters. + # Symbols for contact forces: Fx, Fy, Fz. + t, r, m, g, I, J = symbols('t r m g I J') + Fx, Fy, Fz = symbols('Fx Fy Fz') + + # Configuration variables and their time derivatives: + # q[0] -- yaw + # q[1] -- lean + # q[2] -- spin + # q[3] -- dot(-r*B.z, A.z) -- distance from ground plane to disc center in + # A.z direction + # Generalized speeds and their time derivatives: + # u[0] -- disc angular velocity component, disc fixed x direction + # u[1] -- disc angular velocity component, disc fixed y direction + # u[2] -- disc angular velocity component, disc fixed z direction + # u[3] -- disc velocity component, A.x direction + # u[4] -- disc velocity component, A.y direction + # u[5] -- disc velocity component, A.z direction + # Auxiliary generalized speeds: + # ua[0] -- contact point auxiliary generalized speed, A.x direction + # ua[1] -- contact point auxiliary generalized speed, A.y direction + # ua[2] -- contact point auxiliary generalized speed, A.z direction + q = dynamicsymbols('q:4') + qd = [qi.diff(t) for qi in q] + u = dynamicsymbols('u:6') + ud = [ui.diff(t) for ui in u] + ud_zero = dict(zip(ud, [0.]*len(ud))) + ua = dynamicsymbols('ua:3') + ua_zero = dict(zip(ua, [0.]*len(ua))) # noqa:F841 + + # Reference frames: + # Yaw intermediate frame: A. + # Lean intermediate frame: B. + # Disc fixed frame: C. + N = ReferenceFrame('N') + A = N.orientnew('A', 'Axis', [q[0], N.z]) + B = A.orientnew('B', 'Axis', [q[1], A.x]) + C = B.orientnew('C', 'Axis', [q[2], B.y]) + + # Angular velocity and angular acceleration of disc fixed frame + # u[0], u[1] and u[2] are generalized independent speeds. + C.set_ang_vel(N, u[0]*B.x + u[1]*B.y + u[2]*B.z) + C.set_ang_acc(N, C.ang_vel_in(N).diff(t, B) + + cross(B.ang_vel_in(N), C.ang_vel_in(N))) + + # Velocity and acceleration of points: + # Disc-ground contact point: P. + # Center of disc: O, defined from point P with depend coordinate: q[3] + # u[3], u[4] and u[5] are generalized dependent speeds. + P = Point('P') + P.set_vel(N, ua[0]*A.x + ua[1]*A.y + ua[2]*A.z) + O = P.locatenew('O', q[3]*A.z + r*sin(q[1])*A.y) + O.set_vel(N, u[3]*A.x + u[4]*A.y + u[5]*A.z) + O.set_acc(N, O.vel(N).diff(t, A) + cross(A.ang_vel_in(N), O.vel(N))) + + # Kinematic differential equations: + # Two equalities: one is w_c_n_qd = C.ang_vel_in(N) in three coordinates + # directions of B, for qd0, qd1 and qd2. + # the other is v_o_n_qd = O.vel(N) in A.z direction for qd3. + # Then, solve for dq/dt's in terms of u's: qd_kd. + w_c_n_qd = qd[0]*A.z + qd[1]*B.x + qd[2]*B.y + v_o_n_qd = O.pos_from(P).diff(t, A) + cross(A.ang_vel_in(N), O.pos_from(P)) + kindiffs = Matrix([dot(w_c_n_qd - C.ang_vel_in(N), uv) for uv in B] + + [dot(v_o_n_qd - O.vel(N), A.z)]) + qd_kd = solve(kindiffs, qd) # noqa:F841 + + # Values of generalized speeds during a steady turn for later substitution + # into the Fr_star_steady. + steady_conditions = solve(kindiffs.subs({qd[1] : 0, qd[3] : 0}), u) + steady_conditions.update({qd[1] : 0, qd[3] : 0}) + + # Partial angular velocities and velocities. + partial_w_C = [C.ang_vel_in(N).diff(ui, N) for ui in u + ua] + partial_v_O = [O.vel(N).diff(ui, N) for ui in u + ua] + partial_v_P = [P.vel(N).diff(ui, N) for ui in u + ua] + + # Configuration constraint: f_c, the projection of radius r in A.z direction + # is q[3]. + # Velocity constraints: f_v, for u3, u4 and u5. + # Acceleration constraints: f_a. + f_c = Matrix([dot(-r*B.z, A.z) - q[3]]) + f_v = Matrix([dot(O.vel(N) - (P.vel(N) + cross(C.ang_vel_in(N), + O.pos_from(P))), ai).expand() for ai in A]) + v_o_n = cross(C.ang_vel_in(N), O.pos_from(P)) + a_o_n = v_o_n.diff(t, A) + cross(A.ang_vel_in(N), v_o_n) + f_a = Matrix([dot(O.acc(N) - a_o_n, ai) for ai in A]) # noqa:F841 + + # Solve for constraint equations in the form of + # u_dependent = A_rs * [u_i; u_aux]. + # First, obtain constraint coefficient matrix: M_v * [u; ua] = 0; + # Second, taking u[0], u[1], u[2] as independent, + # taking u[3], u[4], u[5] as dependent, + # rearranging the matrix of M_v to be A_rs for u_dependent. + # Third, u_aux ==0 for u_dep, and resulting dictionary of u_dep_dict. + M_v = zeros(3, 9) + for i in range(3): + for j, ui in enumerate(u + ua): + M_v[i, j] = f_v[i].diff(ui) + + M_v_i = M_v[:, :3] + M_v_d = M_v[:, 3:6] + M_v_aux = M_v[:, 6:] + M_v_i_aux = M_v_i.row_join(M_v_aux) + A_rs = - M_v_d.inv() * M_v_i_aux + + u_dep = A_rs[:, :3] * Matrix(u[:3]) + u_dep_dict = dict(zip(u[3:], u_dep)) + + # Active forces: F_O acting on point O; F_P acting on point P. + # Generalized active forces (unconstrained): Fr_u = F_point * pv_point. + F_O = m*g*A.z + F_P = Fx * A.x + Fy * A.y + Fz * A.z + Fr_u = Matrix([dot(F_O, pv_o) + dot(F_P, pv_p) for pv_o, pv_p in + zip(partial_v_O, partial_v_P)]) + + # Inertia force: R_star_O. + # Inertia of disc: I_C_O, where J is a inertia component about principal axis. + # Inertia torque: T_star_C. + # Generalized inertia forces (unconstrained): Fr_star_u. + R_star_O = -m*O.acc(N) + I_C_O = inertia(B, I, J, I) + T_star_C = -(dot(I_C_O, C.ang_acc_in(N)) \ + + cross(C.ang_vel_in(N), dot(I_C_O, C.ang_vel_in(N)))) + Fr_star_u = Matrix([dot(R_star_O, pv) + dot(T_star_C, pav) for pv, pav in + zip(partial_v_O, partial_w_C)]) + + # Form nonholonomic Fr: Fr_c, and nonholonomic Fr_star: Fr_star_c. + # Also, nonholonomic Fr_star in steady turning condition: Fr_star_steady. + Fr_c = Fr_u[:3, :].col_join(Fr_u[6:, :]) + A_rs.T * Fr_u[3:6, :] + Fr_star_c = Fr_star_u[:3, :].col_join(Fr_star_u[6:, :])\ + + A_rs.T * Fr_star_u[3:6, :] + Fr_star_steady = Fr_star_c.subs(ud_zero).subs(u_dep_dict)\ + .subs(steady_conditions).subs({q[3]: -r*cos(q[1])}).expand() + + + # Second, using KaneMethod in mechanics for fr, frstar and frstar_steady. + + # Rigid Bodies: disc, with inertia I_C_O. + iner_tuple = (I_C_O, O) + disc = RigidBody('disc', O, C, m, iner_tuple) + bodyList = [disc] + + # Generalized forces: Gravity: F_o; Auxiliary forces: F_p. + F_o = (O, F_O) + F_p = (P, F_P) + forceList = [F_o, F_p] + + # KanesMethod. + kane = KanesMethod( + N, q_ind= q[:3], u_ind= u[:3], kd_eqs=kindiffs, + q_dependent=q[3:], configuration_constraints = f_c, + u_dependent=u[3:], velocity_constraints= f_v, + u_auxiliary=ua + ) + + # fr, frstar, frstar_steady and kdd(kinematic differential equations). + (fr, frstar)= kane.kanes_equations(bodyList, forceList) + frstar_steady = frstar.subs(ud_zero).subs(u_dep_dict).subs(steady_conditions)\ + .subs({q[3]: -r*cos(q[1])}).expand() + kdd = kane.kindiffdict() + + assert Matrix(Fr_c).expand() == fr.expand() + assert Matrix(Fr_star_c.subs(kdd)).expand() == frstar.expand() + # These Matrices have some Integer(0) and some Float(0). Running under + # SymEngine gives different types of zero. + assert (simplify(Matrix(Fr_star_steady).expand()).xreplace({0:0.0}) == + simplify(frstar_steady.expand()).xreplace({0:0.0})) + + syms_in_forcing = find_dynamicsymbols(kane.forcing) + for qdi in qd: + assert qdi not in syms_in_forcing + + +def test_non_central_inertia(): + # This tests that the calculation of Fr* does not depend the point + # about which the inertia of a rigid body is defined. This test solves + # exercises 8.12, 8.17 from Kane 1985. + + # Declare symbols + q1, q2, q3 = dynamicsymbols('q1:4') + q1d, q2d, q3d = dynamicsymbols('q1:4', level=1) + u1, u2, u3, u4, u5 = dynamicsymbols('u1:6') + u_prime, R, M, g, e, f, theta = symbols('u\' R, M, g, e, f, theta') + a, b, mA, mB, IA, J, K, t = symbols('a b mA mB IA J K t') + Q1, Q2, Q3 = symbols('Q1, Q2 Q3') + IA22, IA23, IA33 = symbols('IA22 IA23 IA33') + + # Reference Frames + F = ReferenceFrame('F') + P = F.orientnew('P', 'axis', [-theta, F.y]) + A = P.orientnew('A', 'axis', [q1, P.x]) + A.set_ang_vel(F, u1*A.x + u3*A.z) + # define frames for wheels + B = A.orientnew('B', 'axis', [q2, A.z]) + C = A.orientnew('C', 'axis', [q3, A.z]) + B.set_ang_vel(A, u4 * A.z) + C.set_ang_vel(A, u5 * A.z) + + # define points D, S*, Q on frame A and their velocities + pD = Point('D') + pD.set_vel(A, 0) + # u3 will not change v_D_F since wheels are still assumed to roll without slip. + pD.set_vel(F, u2 * A.y) + + pS_star = pD.locatenew('S*', e*A.y) + pQ = pD.locatenew('Q', f*A.y - R*A.x) + for p in [pS_star, pQ]: + p.v2pt_theory(pD, F, A) + + # masscenters of bodies A, B, C + pA_star = pD.locatenew('A*', a*A.y) + pB_star = pD.locatenew('B*', b*A.z) + pC_star = pD.locatenew('C*', -b*A.z) + for p in [pA_star, pB_star, pC_star]: + p.v2pt_theory(pD, F, A) + + # points of B, C touching the plane P + pB_hat = pB_star.locatenew('B^', -R*A.x) + pC_hat = pC_star.locatenew('C^', -R*A.x) + pB_hat.v2pt_theory(pB_star, F, B) + pC_hat.v2pt_theory(pC_star, F, C) + + # the velocities of B^, C^ are zero since B, C are assumed to roll without slip + kde = [q1d - u1, q2d - u4, q3d - u5] + vc = [dot(p.vel(F), A.y) for p in [pB_hat, pC_hat]] + + # inertias of bodies A, B, C + # IA22, IA23, IA33 are not specified in the problem statement, but are + # necessary to define an inertia object. Although the values of + # IA22, IA23, IA33 are not known in terms of the variables given in the + # problem statement, they do not appear in the general inertia terms. + inertia_A = inertia(A, IA, IA22, IA33, 0, IA23, 0) + inertia_B = inertia(B, K, K, J) + inertia_C = inertia(C, K, K, J) + + # define the rigid bodies A, B, C + rbA = RigidBody('rbA', pA_star, A, mA, (inertia_A, pA_star)) + rbB = RigidBody('rbB', pB_star, B, mB, (inertia_B, pB_star)) + rbC = RigidBody('rbC', pC_star, C, mB, (inertia_C, pC_star)) + + km = KanesMethod(F, q_ind=[q1, q2, q3], u_ind=[u1, u2], kd_eqs=kde, + u_dependent=[u4, u5], velocity_constraints=vc, + u_auxiliary=[u3]) + + forces = [(pS_star, -M*g*F.x), (pQ, Q1*A.x + Q2*A.y + Q3*A.z)] + bodies = [rbA, rbB, rbC] + fr, fr_star = km.kanes_equations(bodies, forces) + vc_map = solve(vc, [u4, u5]) + + # KanesMethod returns the negative of Fr, Fr* as defined in Kane1985. + fr_star_expected = Matrix([ + -(IA + 2*J*b**2/R**2 + 2*K + + mA*a**2 + 2*mB*b**2) * u1.diff(t) - mA*a*u1*u2, + -(mA + 2*mB +2*J/R**2) * u2.diff(t) + mA*a*u1**2, + 0]) + t = trigsimp(fr_star.subs(vc_map).subs({u3: 0})).doit().expand() + assert ((fr_star_expected - t).expand() == zeros(3, 1)) + + # define inertias of rigid bodies A, B, C about point D + # I_S/O = I_S/S* + I_S*/O + bodies2 = [] + for rb, I_star in zip([rbA, rbB, rbC], [inertia_A, inertia_B, inertia_C]): + I = I_star + inertia_of_point_mass(rb.mass, + rb.masscenter.pos_from(pD), + rb.frame) + bodies2.append(RigidBody('', rb.masscenter, rb.frame, rb.mass, + (I, pD))) + fr2, fr_star2 = km.kanes_equations(bodies2, forces) + + t = trigsimp(fr_star2.subs(vc_map).subs({u3: 0})).doit() + assert (fr_star_expected - t).expand() == zeros(3, 1) + +def test_sub_qdot(): + # This test solves exercises 8.12, 8.17 from Kane 1985 and defines + # some velocities in terms of q, qdot. + + ## --- Declare symbols --- + q1, q2, q3 = dynamicsymbols('q1:4') + q1d, q2d, q3d = dynamicsymbols('q1:4', level=1) + u1, u2, u3 = dynamicsymbols('u1:4') + u_prime, R, M, g, e, f, theta = symbols('u\' R, M, g, e, f, theta') + a, b, mA, mB, IA, J, K, t = symbols('a b mA mB IA J K t') + IA22, IA23, IA33 = symbols('IA22 IA23 IA33') + Q1, Q2, Q3 = symbols('Q1 Q2 Q3') + + # --- Reference Frames --- + F = ReferenceFrame('F') + P = F.orientnew('P', 'axis', [-theta, F.y]) + A = P.orientnew('A', 'axis', [q1, P.x]) + A.set_ang_vel(F, u1*A.x + u3*A.z) + # define frames for wheels + B = A.orientnew('B', 'axis', [q2, A.z]) + C = A.orientnew('C', 'axis', [q3, A.z]) + + ## --- define points D, S*, Q on frame A and their velocities --- + pD = Point('D') + pD.set_vel(A, 0) + # u3 will not change v_D_F since wheels are still assumed to roll w/o slip + pD.set_vel(F, u2 * A.y) + + pS_star = pD.locatenew('S*', e*A.y) + pQ = pD.locatenew('Q', f*A.y - R*A.x) + # masscenters of bodies A, B, C + pA_star = pD.locatenew('A*', a*A.y) + pB_star = pD.locatenew('B*', b*A.z) + pC_star = pD.locatenew('C*', -b*A.z) + for p in [pS_star, pQ, pA_star, pB_star, pC_star]: + p.v2pt_theory(pD, F, A) + + # points of B, C touching the plane P + pB_hat = pB_star.locatenew('B^', -R*A.x) + pC_hat = pC_star.locatenew('C^', -R*A.x) + pB_hat.v2pt_theory(pB_star, F, B) + pC_hat.v2pt_theory(pC_star, F, C) + + # --- relate qdot, u --- + # the velocities of B^, C^ are zero since B, C are assumed to roll w/o slip + kde = [dot(p.vel(F), A.y) for p in [pB_hat, pC_hat]] + kde += [u1 - q1d] + kde_map = solve(kde, [q1d, q2d, q3d]) + for k, v in list(kde_map.items()): + kde_map[k.diff(t)] = v.diff(t) + + # inertias of bodies A, B, C + # IA22, IA23, IA33 are not specified in the problem statement, but are + # necessary to define an inertia object. Although the values of + # IA22, IA23, IA33 are not known in terms of the variables given in the + # problem statement, they do not appear in the general inertia terms. + inertia_A = inertia(A, IA, IA22, IA33, 0, IA23, 0) + inertia_B = inertia(B, K, K, J) + inertia_C = inertia(C, K, K, J) + + # define the rigid bodies A, B, C + rbA = RigidBody('rbA', pA_star, A, mA, (inertia_A, pA_star)) + rbB = RigidBody('rbB', pB_star, B, mB, (inertia_B, pB_star)) + rbC = RigidBody('rbC', pC_star, C, mB, (inertia_C, pC_star)) + + ## --- use kanes method --- + km = KanesMethod(F, [q1, q2, q3], [u1, u2], kd_eqs=kde, u_auxiliary=[u3]) + + forces = [(pS_star, -M*g*F.x), (pQ, Q1*A.x + Q2*A.y + Q3*A.z)] + bodies = [rbA, rbB, rbC] + + # Q2 = -u_prime * u2 * Q1 / sqrt(u2**2 + f**2 * u1**2) + # -u_prime * R * u2 / sqrt(u2**2 + f**2 * u1**2) = R / Q1 * Q2 + fr_expected = Matrix([ + f*Q3 + M*g*e*sin(theta)*cos(q1), + Q2 + M*g*sin(theta)*sin(q1), + e*M*g*cos(theta) - Q1*f - Q2*R]) + #Q1 * (f - u_prime * R * u2 / sqrt(u2**2 + f**2 * u1**2)))]) + fr_star_expected = Matrix([ + -(IA + 2*J*b**2/R**2 + 2*K + + mA*a**2 + 2*mB*b**2) * u1.diff(t) - mA*a*u1*u2, + -(mA + 2*mB +2*J/R**2) * u2.diff(t) + mA*a*u1**2, + 0]) + + fr, fr_star = km.kanes_equations(bodies, forces) + assert (fr.expand() == fr_expected.expand()) + assert ((fr_star_expected - trigsimp(fr_star)).expand() == zeros(3, 1)) + +def test_sub_qdot2(): + # This test solves exercises 8.3 from Kane 1985 and defines + # all velocities in terms of q, qdot. We check that the generalized active + # forces are correctly computed if u terms are only defined in the + # kinematic differential equations. + # + # This functionality was added in PR 8948. Without qdot/u substitution, the + # KanesMethod constructor will fail during the constraint initialization as + # the B matrix will be poorly formed and inversion of the dependent part + # will fail. + + g, m, Px, Py, Pz, R, t = symbols('g m Px Py Pz R t') + q = dynamicsymbols('q:5') + qd = dynamicsymbols('q:5', level=1) + u = dynamicsymbols('u:5') + + ## Define inertial, intermediate, and rigid body reference frames + A = ReferenceFrame('A') + B_prime = A.orientnew('B_prime', 'Axis', [q[0], A.z]) + B = B_prime.orientnew('B', 'Axis', [pi/2 - q[1], B_prime.x]) + C = B.orientnew('C', 'Axis', [q[2], B.z]) + + ## Define points of interest and their velocities + pO = Point('O') + pO.set_vel(A, 0) + + # R is the point in plane H that comes into contact with disk C. + pR = pO.locatenew('R', q[3]*A.x + q[4]*A.y) + pR.set_vel(A, pR.pos_from(pO).diff(t, A)) + pR.set_vel(B, 0) + + # C^ is the point in disk C that comes into contact with plane H. + pC_hat = pR.locatenew('C^', 0) + pC_hat.set_vel(C, 0) + + # C* is the point at the center of disk C. + pCs = pC_hat.locatenew('C*', R*B.y) + pCs.set_vel(C, 0) + pCs.set_vel(B, 0) + + # calculate velocites of points C* and C^ in frame A + pCs.v2pt_theory(pR, A, B) # points C* and R are fixed in frame B + pC_hat.v2pt_theory(pCs, A, C) # points C* and C^ are fixed in frame C + + ## Define forces on each point of the system + R_C_hat = Px*A.x + Py*A.y + Pz*A.z + R_Cs = -m*g*A.z + forces = [(pC_hat, R_C_hat), (pCs, R_Cs)] + + ## Define kinematic differential equations + # let ui = omega_C_A & bi (i = 1, 2, 3) + # u4 = qd4, u5 = qd5 + u_expr = [C.ang_vel_in(A) & uv for uv in B] + u_expr += qd[3:] + kde = [ui - e for ui, e in zip(u, u_expr)] + km1 = KanesMethod(A, q, u, kde) + fr1, _ = km1.kanes_equations([], forces) + + ## Calculate generalized active forces if we impose the condition that the + # disk C is rolling without slipping + u_indep = u[:3] + u_dep = list(set(u) - set(u_indep)) + vc = [pC_hat.vel(A) & uv for uv in [A.x, A.y]] + km2 = KanesMethod(A, q, u_indep, kde, + u_dependent=u_dep, velocity_constraints=vc) + fr2, _ = km2.kanes_equations([], forces) + + fr1_expected = Matrix([ + -R*g*m*sin(q[1]), + -R*(Px*cos(q[0]) + Py*sin(q[0]))*tan(q[1]), + R*(Px*cos(q[0]) + Py*sin(q[0])), + Px, + Py]) + fr2_expected = Matrix([ + -R*g*m*sin(q[1]), + 0, + 0]) + assert (trigsimp(fr1.expand()) == trigsimp(fr1_expected.expand())) + assert (trigsimp(fr2.expand()) == trigsimp(fr2_expected.expand())) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane3.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane3.py new file mode 100644 index 0000000000000000000000000000000000000000..438759451cfb142c488b9b5c67ac269b668cac68 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane3.py @@ -0,0 +1,315 @@ +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import acos, sin, cos +from sympy.matrices.dense import Matrix +from sympy.physics.mechanics import (ReferenceFrame, dynamicsymbols, + KanesMethod, inertia, Point, RigidBody, + dot) +from sympy.testing.pytest import slow + + +@slow +def test_bicycle(): + # Code to get equations of motion for a bicycle modeled as in: + # J.P Meijaard, Jim M Papadopoulos, Andy Ruina and A.L Schwab. Linearized + # dynamics equations for the balance and steer of a bicycle: a benchmark + # and review. Proceedings of The Royal Society (2007) 463, 1955-1982 + # doi: 10.1098/rspa.2007.1857 + + # Note that this code has been crudely ported from Autolev, which is the + # reason for some of the unusual naming conventions. It was purposefully as + # similar as possible in order to aide debugging. + + # Declare Coordinates & Speeds + # Simple definitions for qdots - qd = u + # Speeds are: + # - u1: yaw frame ang. rate + # - u2: roll frame ang. rate + # - u3: rear wheel frame ang. rate (spinning motion) + # - u4: frame ang. rate (pitching motion) + # - u5: steering frame ang. rate + # - u6: front wheel ang. rate (spinning motion) + # Wheel positions are ignorable coordinates, so they are not introduced. + q1, q2, q4, q5 = dynamicsymbols('q1 q2 q4 q5') + q1d, q2d, q4d, q5d = dynamicsymbols('q1 q2 q4 q5', 1) + u1, u2, u3, u4, u5, u6 = dynamicsymbols('u1 u2 u3 u4 u5 u6') + u1d, u2d, u3d, u4d, u5d, u6d = dynamicsymbols('u1 u2 u3 u4 u5 u6', 1) + + # Declare System's Parameters + WFrad, WRrad, htangle, forkoffset = symbols('WFrad WRrad htangle forkoffset') + forklength, framelength, forkcg1 = symbols('forklength framelength forkcg1') + forkcg3, framecg1, framecg3, Iwr11 = symbols('forkcg3 framecg1 framecg3 Iwr11') + Iwr22, Iwf11, Iwf22, Iframe11 = symbols('Iwr22 Iwf11 Iwf22 Iframe11') + Iframe22, Iframe33, Iframe31, Ifork11 = symbols('Iframe22 Iframe33 Iframe31 Ifork11') + Ifork22, Ifork33, Ifork31, g = symbols('Ifork22 Ifork33 Ifork31 g') + mframe, mfork, mwf, mwr = symbols('mframe mfork mwf mwr') + + # Set up reference frames for the system + # N - inertial + # Y - yaw + # R - roll + # WR - rear wheel, rotation angle is ignorable coordinate so not oriented + # Frame - bicycle frame + # TempFrame - statically rotated frame for easier reference inertia definition + # Fork - bicycle fork + # TempFork - statically rotated frame for easier reference inertia definition + # WF - front wheel, again posses a ignorable coordinate + N = ReferenceFrame('N') + Y = N.orientnew('Y', 'Axis', [q1, N.z]) + R = Y.orientnew('R', 'Axis', [q2, Y.x]) + Frame = R.orientnew('Frame', 'Axis', [q4 + htangle, R.y]) + WR = ReferenceFrame('WR') + TempFrame = Frame.orientnew('TempFrame', 'Axis', [-htangle, Frame.y]) + Fork = Frame.orientnew('Fork', 'Axis', [q5, Frame.x]) + TempFork = Fork.orientnew('TempFork', 'Axis', [-htangle, Fork.y]) + WF = ReferenceFrame('WF') + + # Kinematics of the Bicycle First block of code is forming the positions of + # the relevant points + # rear wheel contact -> rear wheel mass center -> frame mass center + + # frame/fork connection -> fork mass center + front wheel mass center -> + # front wheel contact point + WR_cont = Point('WR_cont') + WR_mc = WR_cont.locatenew('WR_mc', WRrad * R.z) + Steer = WR_mc.locatenew('Steer', framelength * Frame.z) + Frame_mc = WR_mc.locatenew('Frame_mc', - framecg1 * Frame.x + + framecg3 * Frame.z) + Fork_mc = Steer.locatenew('Fork_mc', - forkcg1 * Fork.x + + forkcg3 * Fork.z) + WF_mc = Steer.locatenew('WF_mc', forklength * Fork.x + forkoffset * Fork.z) + WF_cont = WF_mc.locatenew('WF_cont', WFrad * (dot(Fork.y, Y.z) * Fork.y - + Y.z).normalize()) + + # Set the angular velocity of each frame. + # Angular accelerations end up being calculated automatically by + # differentiating the angular velocities when first needed. + # u1 is yaw rate + # u2 is roll rate + # u3 is rear wheel rate + # u4 is frame pitch rate + # u5 is fork steer rate + # u6 is front wheel rate + Y.set_ang_vel(N, u1 * Y.z) + R.set_ang_vel(Y, u2 * R.x) + WR.set_ang_vel(Frame, u3 * Frame.y) + Frame.set_ang_vel(R, u4 * Frame.y) + Fork.set_ang_vel(Frame, u5 * Fork.x) + WF.set_ang_vel(Fork, u6 * Fork.y) + + # Form the velocities of the previously defined points, using the 2 - point + # theorem (written out by hand here). Accelerations again are calculated + # automatically when first needed. + WR_cont.set_vel(N, 0) + WR_mc.v2pt_theory(WR_cont, N, WR) + Steer.v2pt_theory(WR_mc, N, Frame) + Frame_mc.v2pt_theory(WR_mc, N, Frame) + Fork_mc.v2pt_theory(Steer, N, Fork) + WF_mc.v2pt_theory(Steer, N, Fork) + WF_cont.v2pt_theory(WF_mc, N, WF) + + # Sets the inertias of each body. Uses the inertia frame to construct the + # inertia dyadics. Wheel inertias are only defined by principle moments of + # inertia, and are in fact constant in the frame and fork reference frames; + # it is for this reason that the orientations of the wheels does not need + # to be defined. The frame and fork inertias are defined in the 'Temp' + # frames which are fixed to the appropriate body frames; this is to allow + # easier input of the reference values of the benchmark paper. Note that + # due to slightly different orientations, the products of inertia need to + # have their signs flipped; this is done later when entering the numerical + # value. + + Frame_I = (inertia(TempFrame, Iframe11, Iframe22, Iframe33, 0, 0, Iframe31), Frame_mc) + Fork_I = (inertia(TempFork, Ifork11, Ifork22, Ifork33, 0, 0, Ifork31), Fork_mc) + WR_I = (inertia(Frame, Iwr11, Iwr22, Iwr11), WR_mc) + WF_I = (inertia(Fork, Iwf11, Iwf22, Iwf11), WF_mc) + + # Declaration of the RigidBody containers. :: + + BodyFrame = RigidBody('BodyFrame', Frame_mc, Frame, mframe, Frame_I) + BodyFork = RigidBody('BodyFork', Fork_mc, Fork, mfork, Fork_I) + BodyWR = RigidBody('BodyWR', WR_mc, WR, mwr, WR_I) + BodyWF = RigidBody('BodyWF', WF_mc, WF, mwf, WF_I) + + # The kinematic differential equations; they are defined quite simply. Each + # entry in this list is equal to zero. + kd = [q1d - u1, q2d - u2, q4d - u4, q5d - u5] + + # The nonholonomic constraints are the velocity of the front wheel contact + # point dotted into the X, Y, and Z directions; the yaw frame is used as it + # is "closer" to the front wheel (1 less DCM connecting them). These + # constraints force the velocity of the front wheel contact point to be 0 + # in the inertial frame; the X and Y direction constraints enforce a + # "no-slip" condition, and the Z direction constraint forces the front + # wheel contact point to not move away from the ground frame, essentially + # replicating the holonomic constraint which does not allow the frame pitch + # to change in an invalid fashion. + + conlist_speed = [WF_cont.vel(N) & Y.x, WF_cont.vel(N) & Y.y, WF_cont.vel(N) & Y.z] + + # The holonomic constraint is that the position from the rear wheel contact + # point to the front wheel contact point when dotted into the + # normal-to-ground plane direction must be zero; effectively that the front + # and rear wheel contact points are always touching the ground plane. This + # is actually not part of the dynamic equations, but instead is necessary + # for the lineraization process. + + conlist_coord = [WF_cont.pos_from(WR_cont) & Y.z] + + # The force list; each body has the appropriate gravitational force applied + # at its mass center. + FL = [(Frame_mc, -mframe * g * Y.z), + (Fork_mc, -mfork * g * Y.z), + (WF_mc, -mwf * g * Y.z), + (WR_mc, -mwr * g * Y.z)] + BL = [BodyFrame, BodyFork, BodyWR, BodyWF] + + + # The N frame is the inertial frame, coordinates are supplied in the order + # of independent, dependent coordinates, as are the speeds. The kinematic + # differential equation are also entered here. Here the dependent speeds + # are specified, in the same order they were provided in earlier, along + # with the non-holonomic constraints. The dependent coordinate is also + # provided, with the holonomic constraint. Again, this is only provided + # for the linearization process. + + KM = KanesMethod(N, q_ind=[q1, q2, q5], + q_dependent=[q4], configuration_constraints=conlist_coord, + u_ind=[u2, u3, u5], + u_dependent=[u1, u4, u6], velocity_constraints=conlist_speed, + kd_eqs=kd, + constraint_solver="CRAMER") + (fr, frstar) = KM.kanes_equations(BL, FL) + + # This is the start of entering in the numerical values from the benchmark + # paper to validate the eigen values of the linearized equations from this + # model to the reference eigen values. Look at the aforementioned paper for + # more information. Some of these are intermediate values, used to + # transform values from the paper into the coordinate systems used in this + # model. + PaperRadRear = 0.3 + PaperRadFront = 0.35 + HTA = (pi / 2 - pi / 10).evalf() + TrailPaper = 0.08 + rake = (-(TrailPaper*sin(HTA)-(PaperRadFront*cos(HTA)))).evalf() + PaperWb = 1.02 + PaperFrameCgX = 0.3 + PaperFrameCgZ = 0.9 + PaperForkCgX = 0.9 + PaperForkCgZ = 0.7 + FrameLength = (PaperWb*sin(HTA)-(rake-(PaperRadFront-PaperRadRear)*cos(HTA))).evalf() + FrameCGNorm = ((PaperFrameCgZ - PaperRadRear-(PaperFrameCgX/sin(HTA))*cos(HTA))*sin(HTA)).evalf() + FrameCGPar = (PaperFrameCgX / sin(HTA) + (PaperFrameCgZ - PaperRadRear - PaperFrameCgX / sin(HTA) * cos(HTA)) * cos(HTA)).evalf() + tempa = (PaperForkCgZ - PaperRadFront) + tempb = (PaperWb-PaperForkCgX) + tempc = (sqrt(tempa**2+tempb**2)).evalf() + PaperForkL = (PaperWb*cos(HTA)-(PaperRadFront-PaperRadRear)*sin(HTA)).evalf() + ForkCGNorm = (rake+(tempc * sin(pi/2-HTA-acos(tempa/tempc)))).evalf() + ForkCGPar = (tempc * cos((pi/2-HTA)-acos(tempa/tempc))-PaperForkL).evalf() + + # Here is the final assembly of the numerical values. The symbol 'v' is the + # forward speed of the bicycle (a concept which only makes sense in the + # upright, static equilibrium case?). These are in a dictionary which will + # later be substituted in. Again the sign on the *product* of inertia + # values is flipped here, due to different orientations of coordinate + # systems. + v = symbols('v') + val_dict = {WFrad: PaperRadFront, + WRrad: PaperRadRear, + htangle: HTA, + forkoffset: rake, + forklength: PaperForkL, + framelength: FrameLength, + forkcg1: ForkCGPar, + forkcg3: ForkCGNorm, + framecg1: FrameCGNorm, + framecg3: FrameCGPar, + Iwr11: 0.0603, + Iwr22: 0.12, + Iwf11: 0.1405, + Iwf22: 0.28, + Ifork11: 0.05892, + Ifork22: 0.06, + Ifork33: 0.00708, + Ifork31: 0.00756, + Iframe11: 9.2, + Iframe22: 11, + Iframe33: 2.8, + Iframe31: -2.4, + mfork: 4, + mframe: 85, + mwf: 3, + mwr: 2, + g: 9.81, + q1: 0, + q2: 0, + q4: 0, + q5: 0, + u1: 0, + u2: 0, + u3: v / PaperRadRear, + u4: 0, + u5: 0, + u6: v / PaperRadFront} + + # Linearizes the forcing vector; the equations are set up as MM udot = + # forcing, where MM is the mass matrix, udot is the vector representing the + # time derivatives of the generalized speeds, and forcing is a vector which + # contains both external forcing terms and internal forcing terms, such as + # centripital or coriolis forces. This actually returns a matrix with as + # many rows as *total* coordinates and speeds, but only as many columns as + # independent coordinates and speeds. + + A, B, _ = KM.linearize( + A_and_B=True, + op_point={ + # Operating points for the accelerations are required for the + # linearizer to eliminate u' terms showing up in the coefficient + # matrices. + u1.diff(): 0, + u2.diff(): 0, + u3.diff(): 0, + u4.diff(): 0, + u5.diff(): 0, + u6.diff(): 0, + u1: 0, + u2: 0, + u3: v / PaperRadRear, + u4: 0, + u5: 0, + u6: v / PaperRadFront, + q1: 0, + q2: 0, + q4: 0, + q5: 0, + }, + linear_solver="CRAMER", + ) + # As mentioned above, the size of the linearized forcing terms is expanded + # to include both q's and u's, so the mass matrix must have this done as + # well. This will likely be changed to be part of the linearized process, + # for future reference. + A_s = A.xreplace(val_dict) + B_s = B.xreplace(val_dict) + + A_s = A_s.evalf() + B_s = B_s.evalf() + + # Finally, we construct an "A" matrix for the form xdot = A x (x being the + # state vector, although in this case, the sizes are a little off). The + # following line extracts only the minimum entries required for eigenvalue + # analysis, which correspond to rows and columns for lean, steer, lean + # rate, and steer rate. + A = A_s.extract([1, 2, 3, 5], [1, 2, 3, 5]) + + # Precomputed for comparison + Res = Matrix([[ 0, 0, 1.0, 0], + [ 0, 0, 0, 1.0], + [9.48977444677355, -0.891197738059089*v**2 - 0.571523173729245, -0.105522449805691*v, -0.330515398992311*v], + [11.7194768719633, -1.97171508499972*v**2 + 30.9087533932407, 3.67680523332152*v, -3.08486552743311*v]]) + + # Actual eigenvalue comparison + eps = 1.e-12 + for i in range(6): + error = Res.subs(v, i) - A.subs(v, i) + assert all(abs(x) < eps for x in error) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane4.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane4.py new file mode 100644 index 0000000000000000000000000000000000000000..a44dd2d407056ea36669268d478780fc581def51 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane4.py @@ -0,0 +1,115 @@ +from sympy import (cos, sin, Matrix, symbols) +from sympy.physics.mechanics import (dynamicsymbols, ReferenceFrame, Point, + KanesMethod, Particle) + +def test_replace_qdots_in_force(): + # Test PR 16700 "Replaces qdots with us in force-list in kanes.py" + # The new functionality allows one to specify forces in qdots which will + # automatically be replaced with u:s which are defined by the kde supplied + # to KanesMethod. The test case is the double pendulum with interacting + # forces in the example of chapter 4.7 "CONTRIBUTING INTERACTION FORCES" + # in Ref. [1]. Reference list at end test function. + + q1, q2 = dynamicsymbols('q1, q2') + qd1, qd2 = dynamicsymbols('q1, q2', level=1) + u1, u2 = dynamicsymbols('u1, u2') + + l, m = symbols('l, m') + + N = ReferenceFrame('N') # Inertial frame + A = N.orientnew('A', 'Axis', (q1, N.z)) # Rod A frame + B = A.orientnew('B', 'Axis', (q2, N.z)) # Rod B frame + + O = Point('O') # Origo + O.set_vel(N, 0) + + P = O.locatenew('P', ( l * A.x )) # Point @ end of rod A + P.v2pt_theory(O, N, A) + + Q = P.locatenew('Q', ( l * B.x )) # Point @ end of rod B + Q.v2pt_theory(P, N, B) + + Ap = Particle('Ap', P, m) + Bp = Particle('Bp', Q, m) + + # The forces are specified below. sigma is the torsional spring stiffness + # and delta is the viscous damping coefficient acting between the two + # bodies. Here, we specify the viscous damper as function of qdots prior + # forming the kde. In more complex systems it not might be obvious which + # kde is most efficient, why it is convenient to specify viscous forces in + # qdots independently of the kde. + sig, delta = symbols('sigma, delta') + Ta = (sig * q2 + delta * qd2) * N.z + forces = [(A, Ta), (B, -Ta)] + + # Try different kdes. + kde1 = [u1 - qd1, u2 - qd2] + kde2 = [u1 - qd1, u2 - (qd1 + qd2)] + + KM1 = KanesMethod(N, [q1, q2], [u1, u2], kd_eqs=kde1) + fr1, fstar1 = KM1.kanes_equations([Ap, Bp], forces) + + KM2 = KanesMethod(N, [q1, q2], [u1, u2], kd_eqs=kde2) + fr2, fstar2 = KM2.kanes_equations([Ap, Bp], forces) + + # Check EOM for KM2: + # Mass and force matrix from p.6 in Ref. [2] with added forces from + # example of chapter 4.7 in [1] and without gravity. + forcing_matrix_expected = Matrix( [ [ m * l**2 * sin(q2) * u2**2 + sig * q2 + + delta * (u2 - u1)], + [ m * l**2 * sin(q2) * -u1**2 - sig * q2 + - delta * (u2 - u1)] ] ) + mass_matrix_expected = Matrix( [ [ 2 * m * l**2, m * l**2 * cos(q2) ], + [ m * l**2 * cos(q2), m * l**2 ] ] ) + + assert (KM2.mass_matrix.expand() == mass_matrix_expected.expand()) + assert (KM2.forcing.expand() == forcing_matrix_expected.expand()) + + # Check fr1 with reference fr_expected from [1] with u:s instead of qdots. + fr1_expected = Matrix([ 0, -(sig*q2 + delta * u2) ]) + assert fr1.expand() == fr1_expected.expand() + + # Check fr2 + fr2_expected = Matrix([sig * q2 + delta * (u2 - u1), + - sig * q2 - delta * (u2 - u1)]) + assert fr2.expand() == fr2_expected.expand() + + # Specifying forces in u:s should stay the same: + Ta = (sig * q2 + delta * u2) * N.z + forces = [(A, Ta), (B, -Ta)] + KM1 = KanesMethod(N, [q1, q2], [u1, u2], kd_eqs=kde1) + fr1, fstar1 = KM1.kanes_equations([Ap, Bp], forces) + + assert fr1.expand() == fr1_expected.expand() + + Ta = (sig * q2 + delta * (u2-u1)) * N.z + forces = [(A, Ta), (B, -Ta)] + KM2 = KanesMethod(N, [q1, q2], [u1, u2], kd_eqs=kde2) + fr2, fstar2 = KM2.kanes_equations([Ap, Bp], forces) + + assert fr2.expand() == fr2_expected.expand() + + # Test if we have a qubic qdot force: + Ta = (sig * q2 + delta * qd2**3) * N.z + forces = [(A, Ta), (B, -Ta)] + + KM1 = KanesMethod(N, [q1, q2], [u1, u2], kd_eqs=kde1) + fr1, fstar1 = KM1.kanes_equations([Ap, Bp], forces) + + fr1_cubic_expected = Matrix([ 0, -(sig*q2 + delta * u2**3) ]) + + assert fr1.expand() == fr1_cubic_expected.expand() + + KM2 = KanesMethod(N, [q1, q2], [u1, u2], kd_eqs=kde2) + fr2, fstar2 = KM2.kanes_equations([Ap, Bp], forces) + + fr2_cubic_expected = Matrix([sig * q2 + delta * (u2 - u1)**3, + - sig * q2 - delta * (u2 - u1)**3]) + + assert fr2.expand() == fr2_cubic_expected.expand() + + # References: + # [1] T.R. Kane, D. a Levinson, Dynamics Theory and Applications, 2005. + # [2] Arun K Banerjee, Flexible Multibody Dynamics:Efficient Formulations + # and Applications, John Wiley and Sons, Ltd, 2016. + # doi:http://dx.doi.org/10.1002/9781119015635. diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane5.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane5.py new file mode 100644 index 0000000000000000000000000000000000000000..1d0f863e8fa0f46bcd8ae729a1a8852b702bdafa --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_kane5.py @@ -0,0 +1,128 @@ +from sympy import (zeros, Matrix, symbols, lambdify, sqrt, pi, + simplify) +from sympy.physics.mechanics import (dynamicsymbols, cross, inertia, RigidBody, + ReferenceFrame, KanesMethod) + + +def _create_rolling_disc(): + # Define symbols and coordinates + t = dynamicsymbols._t + q1, q2, q3, q4, q5, u1, u2, u3, u4, u5 = dynamicsymbols('q1:6 u1:6') + g, r, m = symbols('g r m') + # Define bodies and frames + ground = RigidBody('ground') + disc = RigidBody('disk', mass=m) + disc.inertia = (m * r ** 2 / 4 * inertia(disc.frame, 1, 2, 1), + disc.masscenter) + ground.masscenter.set_vel(ground.frame, 0) + disc.masscenter.set_vel(disc.frame, 0) + int_frame = ReferenceFrame('int_frame') + # Orient frames + int_frame.orient_body_fixed(ground.frame, (q1, q2, 0), 'zxy') + disc.frame.orient_axis(int_frame, int_frame.y, q3) + g_w_d = disc.frame.ang_vel_in(ground.frame) + disc.frame.set_ang_vel(ground.frame, + u1 * disc.x + u2 * disc.y + u3 * disc.z) + # Define points + cp = ground.masscenter.locatenew('contact_point', + q4 * ground.x + q5 * ground.y) + cp.set_vel(ground.frame, u4 * ground.x + u5 * ground.y) + disc.masscenter.set_pos(cp, r * int_frame.z) + disc.masscenter.set_vel(ground.frame, cross( + disc.frame.ang_vel_in(ground.frame), disc.masscenter.pos_from(cp))) + # Define kinematic differential equations + kdes = [g_w_d.dot(disc.x) - u1, g_w_d.dot(disc.y) - u2, + g_w_d.dot(disc.z) - u3, q4.diff(t) - u4, q5.diff(t) - u5] + # Define nonholonomic constraints + v0 = cp.vel(ground.frame) + cross( + disc.frame.ang_vel_in(int_frame), cp.pos_from(disc.masscenter)) + fnh = [v0.dot(ground.x), v0.dot(ground.y)] + # Define loads + loads = [(disc.masscenter, -disc.mass * g * ground.z)] + bodies = [disc] + return { + 'frame': ground.frame, + 'q_ind': [q1, q2, q3, q4, q5], + 'u_ind': [u1, u2, u3], + 'u_dep': [u4, u5], + 'kdes': kdes, + 'fnh': fnh, + 'bodies': bodies, + 'loads': loads + } + + +def _verify_rolling_disc_numerically(kane, all_zero=False): + q, u, p = dynamicsymbols('q1:6'), dynamicsymbols('u1:6'), symbols('g r m') + eval_sys = lambdify((q, u, p), (kane.mass_matrix_full, kane.forcing_full), + cse=True) + solve_sys = lambda q, u, p: Matrix.LUsolve( + *(Matrix(mat) for mat in eval_sys(q, u, p))) + solve_u_dep = lambdify((q, u[:3], p), kane._Ars * Matrix(u[:3]), cse=True) + eps = 1e-10 + p_vals = (9.81, 0.26, 3.43) + # First numeric test + q_vals = (0.3, 0.1, 1.97, -0.35, 2.27) + u_vals = [-0.2, 1.3, 0.15] + u_vals.extend(solve_u_dep(q_vals, u_vals, p_vals)[:2, 0]) + expected = Matrix([ + 0.126603940595934, 0.215942571601660, 1.28736069604936, + 0.319764288376543, 0.0989146857254898, -0.925848952664489, + -0.0181350656532944, 2.91695398184589, -0.00992793421754526, + 0.0412861634829171]) + assert all(abs(x) < eps for x in + (solve_sys(q_vals, u_vals, p_vals) - expected)) + # Second numeric test + q_vals = (3.97, -0.28, 8.2, -0.35, 2.27) + u_vals = [-0.25, -2.2, 0.62] + u_vals.extend(solve_u_dep(q_vals, u_vals, p_vals)[:2, 0]) + expected = Matrix([ + 0.0259159090798597, 0.668041660387416, -2.19283799213811, + 0.385441810852219, 0.420109283790573, 1.45030568179066, + -0.0110924422400793, -8.35617840186040, -0.154098542632173, + -0.146102664410010]) + assert all(abs(x) < eps for x in + (solve_sys(q_vals, u_vals, p_vals) - expected)) + if all_zero: + q_vals = (0, 0, 0, 0, 0) + u_vals = (0, 0, 0, 0, 0) + assert solve_sys(q_vals, u_vals, p_vals) == zeros(10, 1) + + +def test_kane_rolling_disc_lu(): + props = _create_rolling_disc() + kane = KanesMethod(props['frame'], props['q_ind'], props['u_ind'], + props['kdes'], u_dependent=props['u_dep'], + velocity_constraints=props['fnh'], + bodies=props['bodies'], forcelist=props['loads'], + explicit_kinematics=False, constraint_solver='LU') + kane.kanes_equations() + _verify_rolling_disc_numerically(kane) + + +def test_kane_rolling_disc_kdes_callable(): + props = _create_rolling_disc() + kane = KanesMethod( + props['frame'], props['q_ind'], props['u_ind'], props['kdes'], + u_dependent=props['u_dep'], velocity_constraints=props['fnh'], + bodies=props['bodies'], forcelist=props['loads'], + explicit_kinematics=False, + kd_eqs_solver=lambda A, b: simplify(A.LUsolve(b))) + q, u, p = dynamicsymbols('q1:6'), dynamicsymbols('u1:6'), symbols('g r m') + qd = dynamicsymbols('q1:6', 1) + eval_kdes = lambdify((q, qd, u, p), tuple(kane.kindiffdict().items())) + eps = 1e-10 + # Test with only zeros. If 'LU' would be used this would result in nan. + p_vals = (9.81, 0.25, 3.5) + zero_vals = (0, 0, 0, 0, 0) + assert all(abs(qdi - fui) < eps for qdi, fui in + eval_kdes(zero_vals, zero_vals, zero_vals, p_vals)) + # Test with some arbitrary values + q_vals = tuple(map(float, (pi / 6, pi / 3, pi / 2, 0.42, 0.62))) + qd_vals = tuple(map(float, (4, 1 / 3, 4 - 2 * sqrt(3), + 0.25 * (2 * sqrt(3) - 3), + 0.25 * (2 - sqrt(3))))) + u_vals = tuple(map(float, (-2, 4, 1 / 3, 0.25 * (-3 + 2 * sqrt(3)), + 0.25 * (-sqrt(3) + 2)))) + assert all(abs(qdi - fui) < eps for qdi, fui in + eval_kdes(q_vals, qd_vals, u_vals, p_vals)) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_lagrange.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_lagrange.py new file mode 100644 index 0000000000000000000000000000000000000000..81552bc7a4d0f6766dc46dcd47b7c7b1b0151b3f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_lagrange.py @@ -0,0 +1,247 @@ +from sympy.physics.mechanics import (dynamicsymbols, ReferenceFrame, Point, + RigidBody, LagrangesMethod, Particle, + inertia, Lagrangian) +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.matrices.dense import Matrix +from sympy.simplify.simplify import simplify +from sympy.testing.pytest import raises + + +def test_invalid_coordinates(): + # Simple pendulum, but use symbol instead of dynamicsymbol + l, m, g = symbols('l m g') + q = symbols('q') # Generalized coordinate + N, O = ReferenceFrame('N'), Point('O') + O.set_vel(N, 0) + P = Particle('P', Point('P'), m) + P.point.set_pos(O, l * (sin(q) * N.x - cos(q) * N.y)) + P.potential_energy = m * g * P.point.pos_from(O).dot(N.y) + L = Lagrangian(N, P) + raises(ValueError, lambda: LagrangesMethod(L, [q], bodies=P)) + + +def test_disc_on_an_incline_plane(): + # Disc rolling on an inclined plane + # First the generalized coordinates are created. The mass center of the + # disc is located from top vertex of the inclined plane by the generalized + # coordinate 'y'. The orientation of the disc is defined by the angle + # 'theta'. The mass of the disc is 'm' and its radius is 'R'. The length of + # the inclined path is 'l', the angle of inclination is 'alpha'. 'g' is the + # gravitational constant. + y, theta = dynamicsymbols('y theta') + yd, thetad = dynamicsymbols('y theta', 1) + m, g, R, l, alpha = symbols('m g R l alpha') + + # Next, we create the inertial reference frame 'N'. A reference frame 'A' + # is attached to the inclined plane. Finally a frame is created which is attached to the disk. + N = ReferenceFrame('N') + A = N.orientnew('A', 'Axis', [pi/2 - alpha, N.z]) + B = A.orientnew('B', 'Axis', [-theta, A.z]) + + # Creating the disc 'D'; we create the point that represents the mass + # center of the disc and set its velocity. The inertia dyadic of the disc + # is created. Finally, we create the disc. + Do = Point('Do') + Do.set_vel(N, yd * A.x) + I = m * R**2/2 * B.z | B.z + D = RigidBody('D', Do, B, m, (I, Do)) + + # To construct the Lagrangian, 'L', of the disc, we determine its kinetic + # and potential energies, T and U, respectively. L is defined as the + # difference between T and U. + D.potential_energy = m * g * (l - y) * sin(alpha) + L = Lagrangian(N, D) + + # We then create the list of generalized coordinates and constraint + # equations. The constraint arises due to the disc rolling without slip on + # on the inclined path. We then invoke the 'LagrangesMethod' class and + # supply it the necessary arguments and generate the equations of motion. + # The'rhs' method solves for the q_double_dots (i.e. the second derivative + # with respect to time of the generalized coordinates and the lagrange + # multipliers. + q = [y, theta] + hol_coneqs = [y - R * theta] + m = LagrangesMethod(L, q, hol_coneqs=hol_coneqs) + m.form_lagranges_equations() + rhs = m.rhs() + rhs.simplify() + assert rhs[2] == 2*g*sin(alpha)/3 + + +def test_simp_pen(): + # This tests that the equations generated by LagrangesMethod are identical + # to those obtained by hand calculations. The system under consideration is + # the simple pendulum. + # We begin by creating the generalized coordinates as per the requirements + # of LagrangesMethod. Also we created the associate symbols + # that characterize the system: 'm' is the mass of the bob, l is the length + # of the massless rigid rod connecting the bob to a point O fixed in the + # inertial frame. + q, u = dynamicsymbols('q u') + qd, ud = dynamicsymbols('q u ', 1) + l, m, g = symbols('l m g') + + # We then create the inertial frame and a frame attached to the massless + # string following which we define the inertial angular velocity of the + # string. + N = ReferenceFrame('N') + A = N.orientnew('A', 'Axis', [q, N.z]) + A.set_ang_vel(N, qd * N.z) + + # Next, we create the point O and fix it in the inertial frame. We then + # locate the point P to which the bob is attached. Its corresponding + # velocity is then determined by the 'two point formula'. + O = Point('O') + O.set_vel(N, 0) + P = O.locatenew('P', l * A.x) + P.v2pt_theory(O, N, A) + + # The 'Particle' which represents the bob is then created and its + # Lagrangian generated. + Pa = Particle('Pa', P, m) + Pa.potential_energy = - m * g * l * cos(q) + L = Lagrangian(N, Pa) + + # The 'LagrangesMethod' class is invoked to obtain equations of motion. + lm = LagrangesMethod(L, [q]) + lm.form_lagranges_equations() + RHS = lm.rhs() + assert RHS[1] == -g*sin(q)/l + + +def test_nonminimal_pendulum(): + q1, q2 = dynamicsymbols('q1:3') + q1d, q2d = dynamicsymbols('q1:3', level=1) + L, m, t = symbols('L, m, t') + g = 9.8 + # Compose World Frame + N = ReferenceFrame('N') + pN = Point('N*') + pN.set_vel(N, 0) + # Create point P, the pendulum mass + P = pN.locatenew('P1', q1*N.x + q2*N.y) + P.set_vel(N, P.pos_from(pN).dt(N)) + pP = Particle('pP', P, m) + # Constraint Equations + f_c = Matrix([q1**2 + q2**2 - L**2]) + # Calculate the lagrangian, and form the equations of motion + Lag = Lagrangian(N, pP) + LM = LagrangesMethod(Lag, [q1, q2], hol_coneqs=f_c, + forcelist=[(P, m*g*N.x)], frame=N) + LM.form_lagranges_equations() + # Check solution + lam1 = LM.lam_vec[0, 0] + eom_sol = Matrix([[m*Derivative(q1, t, t) - 9.8*m + 2*lam1*q1], + [m*Derivative(q2, t, t) + 2*lam1*q2]]) + assert LM.eom == eom_sol + # Check multiplier solution + lam_sol = Matrix([(19.6*q1 + 2*q1d**2 + 2*q2d**2)/(4*q1**2/m + 4*q2**2/m)]) + assert simplify(LM.solve_multipliers(sol_type='Matrix')) == simplify(lam_sol) + + +def test_dub_pen(): + + # The system considered is the double pendulum. Like in the + # test of the simple pendulum above, we begin by creating the generalized + # coordinates and the simple generalized speeds and accelerations which + # will be used later. Following this we create frames and points necessary + # for the kinematics. The procedure isn't explicitly explained as this is + # similar to the simple pendulum. Also this is documented on the pydy.org + # website. + q1, q2 = dynamicsymbols('q1 q2') + q1d, q2d = dynamicsymbols('q1 q2', 1) + q1dd, q2dd = dynamicsymbols('q1 q2', 2) + u1, u2 = dynamicsymbols('u1 u2') + u1d, u2d = dynamicsymbols('u1 u2', 1) + l, m, g = symbols('l m g') + + N = ReferenceFrame('N') + A = N.orientnew('A', 'Axis', [q1, N.z]) + B = N.orientnew('B', 'Axis', [q2, N.z]) + + A.set_ang_vel(N, q1d * A.z) + B.set_ang_vel(N, q2d * A.z) + + O = Point('O') + P = O.locatenew('P', l * A.x) + R = P.locatenew('R', l * B.x) + + O.set_vel(N, 0) + P.v2pt_theory(O, N, A) + R.v2pt_theory(P, N, B) + + ParP = Particle('ParP', P, m) + ParR = Particle('ParR', R, m) + + ParP.potential_energy = - m * g * l * cos(q1) + ParR.potential_energy = - m * g * l * cos(q1) - m * g * l * cos(q2) + L = Lagrangian(N, ParP, ParR) + lm = LagrangesMethod(L, [q1, q2], bodies=[ParP, ParR]) + lm.form_lagranges_equations() + + assert simplify(l*m*(2*g*sin(q1) + l*sin(q1)*sin(q2)*q2dd + + l*sin(q1)*cos(q2)*q2d**2 - l*sin(q2)*cos(q1)*q2d**2 + + l*cos(q1)*cos(q2)*q2dd + 2*l*q1dd) - lm.eom[0]) == 0 + assert simplify(l*m*(g*sin(q2) + l*sin(q1)*sin(q2)*q1dd + - l*sin(q1)*cos(q2)*q1d**2 + l*sin(q2)*cos(q1)*q1d**2 + + l*cos(q1)*cos(q2)*q1dd + l*q2dd) - lm.eom[1]) == 0 + assert lm.bodies == [ParP, ParR] + + +def test_rolling_disc(): + # Rolling Disc Example + # Here the rolling disc is formed from the contact point up, removing the + # need to introduce generalized speeds. Only 3 configuration and 3 + # speed variables are need to describe this system, along with the + # disc's mass and radius, and the local gravity. + q1, q2, q3 = dynamicsymbols('q1 q2 q3') + q1d, q2d, q3d = dynamicsymbols('q1 q2 q3', 1) + r, m, g = symbols('r m g') + + # The kinematics are formed by a series of simple rotations. Each simple + # rotation creates a new frame, and the next rotation is defined by the new + # frame's basis vectors. This example uses a 3-1-2 series of rotations, or + # Z, X, Y series of rotations. Angular velocity for this is defined using + # the second frame's basis (the lean frame). + N = ReferenceFrame('N') + Y = N.orientnew('Y', 'Axis', [q1, N.z]) + L = Y.orientnew('L', 'Axis', [q2, Y.x]) + R = L.orientnew('R', 'Axis', [q3, L.y]) + + # This is the translational kinematics. We create a point with no velocity + # in N; this is the contact point between the disc and ground. Next we form + # the position vector from the contact point to the disc's center of mass. + # Finally we form the velocity and acceleration of the disc. + C = Point('C') + C.set_vel(N, 0) + Dmc = C.locatenew('Dmc', r * L.z) + Dmc.v2pt_theory(C, N, R) + + # Forming the inertia dyadic. + I = inertia(L, m/4 * r**2, m/2 * r**2, m/4 * r**2) + BodyD = RigidBody('BodyD', Dmc, R, m, (I, Dmc)) + + # Finally we form the equations of motion, using the same steps we did + # before. Supply the Lagrangian, the generalized speeds. + BodyD.potential_energy = - m * g * r * cos(q2) + Lag = Lagrangian(N, BodyD) + q = [q1, q2, q3] + q1 = Function('q1') + q2 = Function('q2') + q3 = Function('q3') + l = LagrangesMethod(Lag, q) + l.form_lagranges_equations() + RHS = l.rhs() + RHS.simplify() + t = symbols('t') + + assert (l.mass_matrix[3:6] == [0, 5*m*r**2/4, 0]) + assert RHS[4].simplify() == ( + (-8*g*sin(q2(t)) + r*(5*sin(2*q2(t))*Derivative(q1(t), t) + + 12*cos(q2(t))*Derivative(q3(t), t))*Derivative(q1(t), t))/(10*r)) + assert RHS[5] == (-5*cos(q2(t))*Derivative(q1(t), t) + 6*tan(q2(t) + )*Derivative(q3(t), t) + 4*Derivative(q1(t), t)/cos(q2(t)) + )*Derivative(q2(t), t) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_lagrange2.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_lagrange2.py new file mode 100644 index 0000000000000000000000000000000000000000..7602df157e9beb13db1dbb68a2980765cdc49bf2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_lagrange2.py @@ -0,0 +1,46 @@ +from sympy import symbols +from sympy.physics.mechanics import dynamicsymbols +from sympy.physics.mechanics import ReferenceFrame, Point, Particle +from sympy.physics.mechanics import LagrangesMethod, Lagrangian + +### This test asserts that a system with more than one external forces +### is accurately formed with Lagrange method (see issue #8626) + +def test_lagrange_2forces(): + ### Equations for two damped springs in series with two forces + + ### generalized coordinates + q1, q2 = dynamicsymbols('q1, q2') + ### generalized speeds + q1d, q2d = dynamicsymbols('q1, q2', 1) + + ### Mass, spring strength, friction coefficient + m, k, nu = symbols('m, k, nu') + + N = ReferenceFrame('N') + O = Point('O') + + ### Two points + P1 = O.locatenew('P1', q1 * N.x) + P1.set_vel(N, q1d * N.x) + P2 = O.locatenew('P1', q2 * N.x) + P2.set_vel(N, q2d * N.x) + + pP1 = Particle('pP1', P1, m) + pP1.potential_energy = k * q1**2 / 2 + + pP2 = Particle('pP2', P2, m) + pP2.potential_energy = k * (q1 - q2)**2 / 2 + + #### Friction forces + forcelist = [(P1, - nu * q1d * N.x), + (P2, - nu * q2d * N.x)] + lag = Lagrangian(N, pP1, pP2) + + l_method = LagrangesMethod(lag, (q1, q2), forcelist=forcelist, frame=N) + l_method.form_lagranges_equations() + + eq1 = l_method.eom[0] + assert eq1.diff(q1d) == nu + eq2 = l_method.eom[1] + assert eq2.diff(q2d) == nu diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_linearity_of_velocity_constraints.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_linearity_of_velocity_constraints.py new file mode 100644 index 0000000000000000000000000000000000000000..33c9e7ec070a3e6db2a6e26697d670964b0a32b9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_linearity_of_velocity_constraints.py @@ -0,0 +1,41 @@ +from sympy import symbols, sin, cos +from sympy.physics.mechanics import (dynamicsymbols, ReferenceFrame, Point, + KanesMethod) +from sympy.testing import pytest +from sympy.solvers.solveset import NonlinearError + +def test_linearity_of_motion_constraints(): + # Test that an error is raised by KanesMethod if nonlinear velocity + # constraints are supplied. + # It is a simple pendulum. + t = dynamicsymbols._t + N, A = ReferenceFrame('N'), ReferenceFrame('A') + O, P = Point('O'), Point('P') + O.set_vel(N, 0) + + l = symbols('l') + q, x, y, u, ux, uy = dynamicsymbols('q x y u ux uy') + + A.orient_axis(N, q, N.z) + A.set_ang_vel(N, u * N.z) + P.set_pos(O, -l * A.y) + P.v2pt_theory(O, N, A) + + kd = [u - q.diff(t), ux - x.diff(t), uy - y.diff(t)] + config_constr = [x - l * sin(q), y - l * cos(q)] + + q_ind = [q] + q_dep = [x, y] + u_ind = [u] + u_dep = [ux, uy] + + # Make sure an error is raised if nonlinear velocity constraints are + # supplied. + speed_constr = [ux - l * q.diff(t) * cos(q), sin(uy) + + l * q.diff(t) * sin(q)] + + with pytest.raises(NonlinearError): + KanesMethod(N, q_ind=q_ind, q_dependent=q_dep, u_ind=u_ind, + u_dependent=u_dep, kd_eqs=kd, + configuration_constraints=config_constr, + velocity_constraints=speed_constr) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_linearize.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_linearize.py new file mode 100644 index 0000000000000000000000000000000000000000..ec62b960b71d7fce5a5504478431ca23eb371fe0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_linearize.py @@ -0,0 +1,372 @@ +from sympy import symbols, Matrix, cos, sin, atan, sqrt, Rational +from sympy.core.sympify import sympify +from sympy.simplify.simplify import simplify +from sympy.solvers.solvers import solve +from sympy.physics.mechanics import dynamicsymbols, ReferenceFrame, Point,\ + dot, cross, inertia, KanesMethod, Particle, RigidBody, Lagrangian,\ + LagrangesMethod +from sympy.testing.pytest import slow + + +@slow +def test_linearize_rolling_disc_kane(): + # Symbols for time and constant parameters + t, r, m, g, v = symbols('t r m g v') + + # Configuration variables and their time derivatives + q1, q2, q3, q4, q5, q6 = q = dynamicsymbols('q1:7') + q1d, q2d, q3d, q4d, q5d, q6d = qd = [qi.diff(t) for qi in q] + + # Generalized speeds and their time derivatives + u = dynamicsymbols('u:6') + u1, u2, u3, u4, u5, u6 = u = dynamicsymbols('u1:7') + u1d, u2d, u3d, u4d, u5d, u6d = [ui.diff(t) for ui in u] + + # Reference frames + N = ReferenceFrame('N') # Inertial frame + NO = Point('NO') # Inertial origin + A = N.orientnew('A', 'Axis', [q1, N.z]) # Yaw intermediate frame + B = A.orientnew('B', 'Axis', [q2, A.x]) # Lean intermediate frame + C = B.orientnew('C', 'Axis', [q3, B.y]) # Disc fixed frame + CO = NO.locatenew('CO', q4*N.x + q5*N.y + q6*N.z) # Disc center + + # Disc angular velocity in N expressed using time derivatives of coordinates + w_c_n_qd = C.ang_vel_in(N) + w_b_n_qd = B.ang_vel_in(N) + + # Inertial angular velocity and angular acceleration of disc fixed frame + C.set_ang_vel(N, u1*B.x + u2*B.y + u3*B.z) + + # Disc center velocity in N expressed using time derivatives of coordinates + v_co_n_qd = CO.pos_from(NO).dt(N) + + # Disc center velocity in N expressed using generalized speeds + CO.set_vel(N, u4*C.x + u5*C.y + u6*C.z) + + # Disc Ground Contact Point + P = CO.locatenew('P', r*B.z) + P.v2pt_theory(CO, N, C) + + # Configuration constraint + f_c = Matrix([q6 - dot(CO.pos_from(P), N.z)]) + + # Velocity level constraints + f_v = Matrix([dot(P.vel(N), uv) for uv in C]) + + # Kinematic differential equations + kindiffs = Matrix([dot(w_c_n_qd - C.ang_vel_in(N), uv) for uv in B] + + [dot(v_co_n_qd - CO.vel(N), uv) for uv in N]) + qdots = solve(kindiffs, qd) + + # Set angular velocity of remaining frames + B.set_ang_vel(N, w_b_n_qd.subs(qdots)) + C.set_ang_acc(N, C.ang_vel_in(N).dt(B) + cross(B.ang_vel_in(N), C.ang_vel_in(N))) + + # Active forces + F_CO = m*g*A.z + + # Create inertia dyadic of disc C about point CO + I = (m * r**2) / 4 + J = (m * r**2) / 2 + I_C_CO = inertia(C, I, J, I) + + Disc = RigidBody('Disc', CO, C, m, (I_C_CO, CO)) + BL = [Disc] + FL = [(CO, F_CO)] + KM = KanesMethod(N, [q1, q2, q3, q4, q5], [u1, u2, u3], kd_eqs=kindiffs, + q_dependent=[q6], configuration_constraints=f_c, + u_dependent=[u4, u5, u6], velocity_constraints=f_v) + (fr, fr_star) = KM.kanes_equations(BL, FL) + + # Test generalized form equations + linearizer = KM.to_linearizer() + assert linearizer.f_c == f_c + assert linearizer.f_v == f_v + assert linearizer.f_a == f_v.diff(t).subs(KM.kindiffdict()) + sol = solve(linearizer.f_0 + linearizer.f_1, qd) + for qi in qdots.keys(): + assert sol[qi] == qdots[qi] + assert simplify(linearizer.f_2 + linearizer.f_3 - fr - fr_star) == Matrix([0, 0, 0]) + + # Perform the linearization + # Precomputed operating point + q_op = {q6: -r*cos(q2)} + u_op = {u1: 0, + u2: sin(q2)*q1d + q3d, + u3: cos(q2)*q1d, + u4: -r*(sin(q2)*q1d + q3d)*cos(q3), + u5: 0, + u6: -r*(sin(q2)*q1d + q3d)*sin(q3)} + qd_op = {q2d: 0, + q4d: -r*(sin(q2)*q1d + q3d)*cos(q1), + q5d: -r*(sin(q2)*q1d + q3d)*sin(q1), + q6d: 0} + ud_op = {u1d: 4*g*sin(q2)/(5*r) + sin(2*q2)*q1d**2/2 + 6*cos(q2)*q1d*q3d/5, + u2d: 0, + u3d: 0, + u4d: r*(sin(q2)*sin(q3)*q1d*q3d + sin(q3)*q3d**2), + u5d: r*(4*g*sin(q2)/(5*r) + sin(2*q2)*q1d**2/2 + 6*cos(q2)*q1d*q3d/5), + u6d: -r*(sin(q2)*cos(q3)*q1d*q3d + cos(q3)*q3d**2)} + + A, B = linearizer.linearize(op_point=[q_op, u_op, qd_op, ud_op], A_and_B=True, simplify=True) + + upright_nominal = {q1d: 0, q2: 0, m: 1, r: 1, g: 1} + + # Precomputed solution + A_sol = Matrix([[0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0], + [sin(q1)*q3d, 0, 0, 0, 0, -sin(q1), -cos(q1), 0], + [-cos(q1)*q3d, 0, 0, 0, 0, cos(q1), -sin(q1), 0], + [0, Rational(4, 5), 0, 0, 0, 0, 0, 6*q3d/5], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, -2*q3d, 0, 0]]) + B_sol = Matrix([]) + + # Check that linearization is correct + assert A.subs(upright_nominal) == A_sol + assert B.subs(upright_nominal) == B_sol + + # Check eigenvalues at critical speed are all zero: + assert sympify(A.subs(upright_nominal).subs(q3d, 1/sqrt(3))).eigenvals() == {0: 8} + + # Check whether alternative solvers work + # symengine doesn't support method='GJ' + linearizer = KM.to_linearizer(linear_solver='GJ') + A, B = linearizer.linearize(op_point=[q_op, u_op, qd_op, ud_op], + A_and_B=True, simplify=True) + assert A.subs(upright_nominal) == A_sol + assert B.subs(upright_nominal) == B_sol + +def test_linearize_pendulum_kane_minimal(): + q1 = dynamicsymbols('q1') # angle of pendulum + u1 = dynamicsymbols('u1') # Angular velocity + q1d = dynamicsymbols('q1', 1) # Angular velocity + L, m, t = symbols('L, m, t') + g = 9.8 + + # Compose world frame + N = ReferenceFrame('N') + pN = Point('N*') + pN.set_vel(N, 0) + + # A.x is along the pendulum + A = N.orientnew('A', 'axis', [q1, N.z]) + A.set_ang_vel(N, u1*N.z) + + # Locate point P relative to the origin N* + P = pN.locatenew('P', L*A.x) + P.v2pt_theory(pN, N, A) + pP = Particle('pP', P, m) + + # Create Kinematic Differential Equations + kde = Matrix([q1d - u1]) + + # Input the force resultant at P + R = m*g*N.x + + # Solve for eom with kanes method + KM = KanesMethod(N, q_ind=[q1], u_ind=[u1], kd_eqs=kde) + (fr, frstar) = KM.kanes_equations([pP], [(P, R)]) + + # Linearize + A, B, inp_vec = KM.linearize(A_and_B=True, simplify=True) + + assert A == Matrix([[0, 1], [-9.8*cos(q1)/L, 0]]) + assert B == Matrix([]) + +def test_linearize_pendulum_kane_nonminimal(): + # Create generalized coordinates and speeds for this non-minimal realization + # q1, q2 = N.x and N.y coordinates of pendulum + # u1, u2 = N.x and N.y velocities of pendulum + q1, q2 = dynamicsymbols('q1:3') + q1d, q2d = dynamicsymbols('q1:3', level=1) + u1, u2 = dynamicsymbols('u1:3') + u1d, u2d = dynamicsymbols('u1:3', level=1) + L, m, t = symbols('L, m, t') + g = 9.8 + + # Compose world frame + N = ReferenceFrame('N') + pN = Point('N*') + pN.set_vel(N, 0) + + # A.x is along the pendulum + theta1 = atan(q2/q1) + A = N.orientnew('A', 'axis', [theta1, N.z]) + + # Locate the pendulum mass + P = pN.locatenew('P1', q1*N.x + q2*N.y) + pP = Particle('pP', P, m) + + # Calculate the kinematic differential equations + kde = Matrix([q1d - u1, + q2d - u2]) + dq_dict = solve(kde, [q1d, q2d]) + + # Set velocity of point P + P.set_vel(N, P.pos_from(pN).dt(N).subs(dq_dict)) + + # Configuration constraint is length of pendulum + f_c = Matrix([P.pos_from(pN).magnitude() - L]) + + # Velocity constraint is that the velocity in the A.x direction is + # always zero (the pendulum is never getting longer). + f_v = Matrix([P.vel(N).express(A).dot(A.x)]) + f_v.simplify() + + # Acceleration constraints is the time derivative of the velocity constraint + f_a = f_v.diff(t) + f_a.simplify() + + # Input the force resultant at P + R = m*g*N.x + + # Derive the equations of motion using the KanesMethod class. + KM = KanesMethod(N, q_ind=[q2], u_ind=[u2], q_dependent=[q1], + u_dependent=[u1], configuration_constraints=f_c, + velocity_constraints=f_v, acceleration_constraints=f_a, kd_eqs=kde) + (fr, frstar) = KM.kanes_equations([pP], [(P, R)]) + + # Set the operating point to be straight down, and non-moving + q_op = {q1: L, q2: 0} + u_op = {u1: 0, u2: 0} + ud_op = {u1d: 0, u2d: 0} + + A, B, inp_vec = KM.linearize(op_point=[q_op, u_op, ud_op], A_and_B=True, + simplify=True) + + assert A.expand() == Matrix([[0, 1], [-9.8/L, 0]]) + assert B == Matrix([]) + + + # symengine doesn't support method='GJ' + A, B, inp_vec = KM.linearize(op_point=[q_op, u_op, ud_op], A_and_B=True, + simplify=True, linear_solver='GJ') + + assert A.expand() == Matrix([[0, 1], [-9.8/L, 0]]) + assert B == Matrix([]) + + A, B, inp_vec = KM.linearize(op_point=[q_op, u_op, ud_op], + A_and_B=True, + simplify=True, + linear_solver=lambda A, b: A.LUsolve(b)) + + assert A.expand() == Matrix([[0, 1], [-9.8/L, 0]]) + assert B == Matrix([]) + + +def test_linearize_pendulum_lagrange_minimal(): + q1 = dynamicsymbols('q1') # angle of pendulum + q1d = dynamicsymbols('q1', 1) # Angular velocity + L, m, t = symbols('L, m, t') + g = 9.8 + + # Compose world frame + N = ReferenceFrame('N') + pN = Point('N*') + pN.set_vel(N, 0) + + # A.x is along the pendulum + A = N.orientnew('A', 'axis', [q1, N.z]) + A.set_ang_vel(N, q1d*N.z) + + # Locate point P relative to the origin N* + P = pN.locatenew('P', L*A.x) + P.v2pt_theory(pN, N, A) + pP = Particle('pP', P, m) + + # Solve for eom with Lagranges method + Lag = Lagrangian(N, pP) + LM = LagrangesMethod(Lag, [q1], forcelist=[(P, m*g*N.x)], frame=N) + LM.form_lagranges_equations() + + # Linearize + A, B, inp_vec = LM.linearize([q1], [q1d], A_and_B=True) + + assert simplify(A) == Matrix([[0, 1], [-9.8*cos(q1)/L, 0]]) + assert B == Matrix([]) + + # Check an alternative solver + A, B, inp_vec = LM.linearize([q1], [q1d], A_and_B=True, linear_solver='GJ') + + assert simplify(A) == Matrix([[0, 1], [-9.8*cos(q1)/L, 0]]) + assert B == Matrix([]) + + +def test_linearize_pendulum_lagrange_nonminimal(): + q1, q2 = dynamicsymbols('q1:3') + q1d, q2d = dynamicsymbols('q1:3', level=1) + L, m, t = symbols('L, m, t') + g = 9.8 + # Compose World Frame + N = ReferenceFrame('N') + pN = Point('N*') + pN.set_vel(N, 0) + # A.x is along the pendulum + theta1 = atan(q2/q1) + A = N.orientnew('A', 'axis', [theta1, N.z]) + # Create point P, the pendulum mass + P = pN.locatenew('P1', q1*N.x + q2*N.y) + P.set_vel(N, P.pos_from(pN).dt(N)) + pP = Particle('pP', P, m) + # Constraint Equations + f_c = Matrix([q1**2 + q2**2 - L**2]) + # Calculate the lagrangian, and form the equations of motion + Lag = Lagrangian(N, pP) + LM = LagrangesMethod(Lag, [q1, q2], hol_coneqs=f_c, forcelist=[(P, m*g*N.x)], frame=N) + LM.form_lagranges_equations() + # Compose operating point + op_point = {q1: L, q2: 0, q1d: 0, q2d: 0, q1d.diff(t): 0, q2d.diff(t): 0} + # Solve for multiplier operating point + lam_op = LM.solve_multipliers(op_point=op_point) + op_point.update(lam_op) + # Perform the Linearization + A, B, inp_vec = LM.linearize([q2], [q2d], [q1], [q1d], + op_point=op_point, A_and_B=True) + assert simplify(A) == Matrix([[0, 1], [-9.8/L, 0]]) + assert B == Matrix([]) + + # Check if passing a function to linear_solver works + A, B, inp_vec = LM.linearize([q2], [q2d], [q1], [q1d], op_point=op_point, + A_and_B=True, linear_solver=lambda A, b: + A.LUsolve(b)) + assert simplify(A) == Matrix([[0, 1], [-9.8/L, 0]]) + assert B == Matrix([]) + +def test_linearize_rolling_disc_lagrange(): + q1, q2, q3 = q = dynamicsymbols('q1 q2 q3') + q1d, q2d, q3d = qd = dynamicsymbols('q1 q2 q3', 1) + r, m, g = symbols('r m g') + + N = ReferenceFrame('N') + Y = N.orientnew('Y', 'Axis', [q1, N.z]) + L = Y.orientnew('L', 'Axis', [q2, Y.x]) + R = L.orientnew('R', 'Axis', [q3, L.y]) + + C = Point('C') + C.set_vel(N, 0) + Dmc = C.locatenew('Dmc', r * L.z) + Dmc.v2pt_theory(C, N, R) + + I = inertia(L, m / 4 * r**2, m / 2 * r**2, m / 4 * r**2) + BodyD = RigidBody('BodyD', Dmc, R, m, (I, Dmc)) + BodyD.potential_energy = - m * g * r * cos(q2) + + Lag = Lagrangian(N, BodyD) + l = LagrangesMethod(Lag, q) + l.form_lagranges_equations() + + # Linearize about steady-state upright rolling + op_point = {q1: 0, q2: 0, q3: 0, + q1d: 0, q2d: 0, + q1d.diff(): 0, q2d.diff(): 0, q3d.diff(): 0} + A = l.linearize(q_ind=q, qd_ind=qd, op_point=op_point, A_and_B=True)[0] + sol = Matrix([[0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, -6*q3d, 0], + [0, -4*g/(5*r), 0, 6*q3d/5, 0, 0], + [0, 0, 0, 0, 0, 0]]) + + assert A == sol diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_loads.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_loads.py new file mode 100644 index 0000000000000000000000000000000000000000..8aa0cec14887f0778fc1e60e7ff33830ceef72d3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_loads.py @@ -0,0 +1,86 @@ +from pytest import raises + +from sympy import symbols +from sympy.physics.mechanics import (RigidBody, Particle, ReferenceFrame, Point, + outer, dynamicsymbols, Force, Torque) +from sympy.physics.mechanics.loads import gravity, _parse_load + + +def test_force_default(): + N = ReferenceFrame('N') + Po = Point('Po') + f1 = Force(Po, N.x) + assert f1.point == Po + assert f1.force == N.x + assert f1.__repr__() == 'Force(point=Po, force=N.x)' + # Test tuple behaviour + assert isinstance(f1, tuple) + assert f1[0] == Po + assert f1[1] == N.x + assert f1 == (Po, N.x) + assert f1 != (N.x, Po) + assert f1 != (Po, N.x + N.y) + assert f1 != (Point('Co'), N.x) + # Test body as input + P = Particle('P', Po) + f2 = Force(P, N.x) + assert f1 == f2 + + +def test_torque_default(): + N = ReferenceFrame('N') + f1 = Torque(N, N.x) + assert f1.frame == N + assert f1.torque == N.x + assert f1.__repr__() == 'Torque(frame=N, torque=N.x)' + # Test tuple behaviour + assert isinstance(f1, tuple) + assert f1[0] == N + assert f1[1] == N.x + assert f1 == (N, N.x) + assert f1 != (N.x, N) + assert f1 != (N, N.x + N.y) + assert f1 != (ReferenceFrame('A'), N.x) + # Test body as input + rb = RigidBody('P', frame=N) + f2 = Torque(rb, N.x) + assert f1 == f2 + + +def test_gravity(): + N = ReferenceFrame('N') + m, M, g = symbols('m M g') + F1, F2 = dynamicsymbols('F1 F2') + po = Point('po') + pa = Particle('pa', po, m) + A = ReferenceFrame('A') + P = Point('P') + I = outer(A.x, A.x) + B = RigidBody('B', P, A, M, (I, P)) + forceList = [(po, F1), (P, F2)] + forceList.extend(gravity(g * N.y, pa, B)) + l = [(po, F1), (P, F2), (po, g * m * N.y), (P, g * M * N.y)] + + for i in range(len(l)): + for j in range(len(l[i])): + assert forceList[i][j] == l[i][j] + + +def test_parse_loads(): + N = ReferenceFrame('N') + po = Point('po') + assert _parse_load(Force(po, N.z)) == (po, N.z) + assert _parse_load(Torque(N, N.x)) == (N, N.x) + f1 = _parse_load((po, N.x)) # Test whether a force is recognized + assert isinstance(f1, Force) + assert f1 == Force(po, N.x) + t1 = _parse_load((N, N.y)) # Test whether a torque is recognized + assert isinstance(t1, Torque) + assert t1 == Torque(N, N.y) + # Bodies should be undetermined (even in case of a Particle) + raises(ValueError, lambda: _parse_load((Particle('pa', po), N.x))) + raises(ValueError, lambda: _parse_load((RigidBody('pa', po, N), N.x))) + # Invalid tuple length + raises(ValueError, lambda: _parse_load((po, N.x, po, N.x))) + # Invalid type + raises(TypeError, lambda: _parse_load([po, N.x])) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_method.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_method.py new file mode 100644 index 0000000000000000000000000000000000000000..4a8fd5fb50c3178f5a5cdab1e80423df8b52f525 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_method.py @@ -0,0 +1,5 @@ +from sympy.physics.mechanics.method import _Methods +from sympy.testing.pytest import raises + +def test_method(): + raises(TypeError, lambda: _Methods()) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_models.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_models.py new file mode 100644 index 0000000000000000000000000000000000000000..2b3d3ae89b44d774ead1a3ea641a8274ba951638 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_models.py @@ -0,0 +1,117 @@ +import sympy.physics.mechanics.models as models +from sympy import (cos, sin, Matrix, symbols, zeros) +from sympy.simplify.simplify import simplify +from sympy.physics.mechanics import (dynamicsymbols) + + +def test_multi_mass_spring_damper_inputs(): + + c0, k0, m0 = symbols("c0 k0 m0") + g = symbols("g") + v0, x0, f0 = dynamicsymbols("v0 x0 f0") + + kane1 = models.multi_mass_spring_damper(1) + massmatrix1 = Matrix([[m0]]) + forcing1 = Matrix([[-c0*v0 - k0*x0]]) + assert simplify(massmatrix1 - kane1.mass_matrix) == Matrix([0]) + assert simplify(forcing1 - kane1.forcing) == Matrix([0]) + + kane2 = models.multi_mass_spring_damper(1, True) + massmatrix2 = Matrix([[m0]]) + forcing2 = Matrix([[-c0*v0 + g*m0 - k0*x0]]) + assert simplify(massmatrix2 - kane2.mass_matrix) == Matrix([0]) + assert simplify(forcing2 - kane2.forcing) == Matrix([0]) + + kane3 = models.multi_mass_spring_damper(1, True, True) + massmatrix3 = Matrix([[m0]]) + forcing3 = Matrix([[-c0*v0 + g*m0 - k0*x0 + f0]]) + assert simplify(massmatrix3 - kane3.mass_matrix) == Matrix([0]) + assert simplify(forcing3 - kane3.forcing) == Matrix([0]) + + kane4 = models.multi_mass_spring_damper(1, False, True) + massmatrix4 = Matrix([[m0]]) + forcing4 = Matrix([[-c0*v0 - k0*x0 + f0]]) + assert simplify(massmatrix4 - kane4.mass_matrix) == Matrix([0]) + assert simplify(forcing4 - kane4.forcing) == Matrix([0]) + + +def test_multi_mass_spring_damper_higher_order(): + c0, k0, m0 = symbols("c0 k0 m0") + c1, k1, m1 = symbols("c1 k1 m1") + c2, k2, m2 = symbols("c2 k2 m2") + v0, x0 = dynamicsymbols("v0 x0") + v1, x1 = dynamicsymbols("v1 x1") + v2, x2 = dynamicsymbols("v2 x2") + + kane1 = models.multi_mass_spring_damper(3) + massmatrix1 = Matrix([[m0 + m1 + m2, m1 + m2, m2], + [m1 + m2, m1 + m2, m2], + [m2, m2, m2]]) + forcing1 = Matrix([[-c0*v0 - k0*x0], + [-c1*v1 - k1*x1], + [-c2*v2 - k2*x2]]) + assert simplify(massmatrix1 - kane1.mass_matrix) == zeros(3) + assert simplify(forcing1 - kane1.forcing) == Matrix([0, 0, 0]) + + +def test_n_link_pendulum_on_cart_inputs(): + l0, m0 = symbols("l0 m0") + m1 = symbols("m1") + g = symbols("g") + q0, q1, F, T1 = dynamicsymbols("q0 q1 F T1") + u0, u1 = dynamicsymbols("u0 u1") + + kane1 = models.n_link_pendulum_on_cart(1) + massmatrix1 = Matrix([[m0 + m1, -l0*m1*cos(q1)], + [-l0*m1*cos(q1), l0**2*m1]]) + forcing1 = Matrix([[-l0*m1*u1**2*sin(q1) + F], [g*l0*m1*sin(q1)]]) + assert simplify(massmatrix1 - kane1.mass_matrix) == zeros(2) + assert simplify(forcing1 - kane1.forcing) == Matrix([0, 0]) + + kane2 = models.n_link_pendulum_on_cart(1, False) + massmatrix2 = Matrix([[m0 + m1, -l0*m1*cos(q1)], + [-l0*m1*cos(q1), l0**2*m1]]) + forcing2 = Matrix([[-l0*m1*u1**2*sin(q1)], [g*l0*m1*sin(q1)]]) + assert simplify(massmatrix2 - kane2.mass_matrix) == zeros(2) + assert simplify(forcing2 - kane2.forcing) == Matrix([0, 0]) + + kane3 = models.n_link_pendulum_on_cart(1, False, True) + massmatrix3 = Matrix([[m0 + m1, -l0*m1*cos(q1)], + [-l0*m1*cos(q1), l0**2*m1]]) + forcing3 = Matrix([[-l0*m1*u1**2*sin(q1)], [g*l0*m1*sin(q1) + T1]]) + assert simplify(massmatrix3 - kane3.mass_matrix) == zeros(2) + assert simplify(forcing3 - kane3.forcing) == Matrix([0, 0]) + + kane4 = models.n_link_pendulum_on_cart(1, True, False) + massmatrix4 = Matrix([[m0 + m1, -l0*m1*cos(q1)], + [-l0*m1*cos(q1), l0**2*m1]]) + forcing4 = Matrix([[-l0*m1*u1**2*sin(q1) + F], [g*l0*m1*sin(q1)]]) + assert simplify(massmatrix4 - kane4.mass_matrix) == zeros(2) + assert simplify(forcing4 - kane4.forcing) == Matrix([0, 0]) + + +def test_n_link_pendulum_on_cart_higher_order(): + l0, m0 = symbols("l0 m0") + l1, m1 = symbols("l1 m1") + m2 = symbols("m2") + g = symbols("g") + q0, q1, q2 = dynamicsymbols("q0 q1 q2") + u0, u1, u2 = dynamicsymbols("u0 u1 u2") + F, T1 = dynamicsymbols("F T1") + + kane1 = models.n_link_pendulum_on_cart(2) + massmatrix1 = Matrix([[m0 + m1 + m2, -l0*m1*cos(q1) - l0*m2*cos(q1), + -l1*m2*cos(q2)], + [-l0*m1*cos(q1) - l0*m2*cos(q1), l0**2*m1 + l0**2*m2, + l0*l1*m2*(sin(q1)*sin(q2) + cos(q1)*cos(q2))], + [-l1*m2*cos(q2), + l0*l1*m2*(sin(q1)*sin(q2) + cos(q1)*cos(q2)), + l1**2*m2]]) + forcing1 = Matrix([[-l0*m1*u1**2*sin(q1) - l0*m2*u1**2*sin(q1) - + l1*m2*u2**2*sin(q2) + F], + [g*l0*m1*sin(q1) + g*l0*m2*sin(q1) - + l0*l1*m2*(sin(q1)*cos(q2) - sin(q2)*cos(q1))*u2**2], + [g*l1*m2*sin(q2) - l0*l1*m2*(-sin(q1)*cos(q2) + + sin(q2)*cos(q1))*u1**2]]) + assert simplify(massmatrix1 - kane1.mass_matrix) == zeros(3) + assert simplify(forcing1 - kane1.forcing) == Matrix([0, 0, 0]) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_particle.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_particle.py new file mode 100644 index 0000000000000000000000000000000000000000..8eec80275b532055eacaf2339a276c0fd19b330a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_particle.py @@ -0,0 +1,78 @@ +from sympy import symbols +from sympy.physics.mechanics import Point, Particle, ReferenceFrame, inertia +from sympy.physics.mechanics.body_base import BodyBase +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +def test_particle_default(): + # Test default + p = Particle('P') + assert p.name == 'P' + assert p.mass == symbols('P_mass') + assert p.masscenter.name == 'P_masscenter' + assert p.potential_energy == 0 + assert p.__str__() == 'P' + assert p.__repr__() == ("Particle('P', masscenter=P_masscenter, " + "mass=P_mass)") + raises(AttributeError, lambda: p.frame) + + +def test_particle(): + # Test initializing with parameters + m, m2, v1, v2, v3, r, g, h = symbols('m m2 v1 v2 v3 r g h') + P = Point('P') + P2 = Point('P2') + p = Particle('pa', P, m) + assert isinstance(p, BodyBase) + assert p.mass == m + assert p.point == P + # Test the mass setter + p.mass = m2 + assert p.mass == m2 + # Test the point setter + p.point = P2 + assert p.point == P2 + # Test the linear momentum function + N = ReferenceFrame('N') + O = Point('O') + P2.set_pos(O, r * N.y) + P2.set_vel(N, v1 * N.x) + raises(TypeError, lambda: Particle(P, P, m)) + raises(TypeError, lambda: Particle('pa', m, m)) + assert p.linear_momentum(N) == m2 * v1 * N.x + assert p.angular_momentum(O, N) == -m2 * r * v1 * N.z + P2.set_vel(N, v2 * N.y) + assert p.linear_momentum(N) == m2 * v2 * N.y + assert p.angular_momentum(O, N) == 0 + P2.set_vel(N, v3 * N.z) + assert p.linear_momentum(N) == m2 * v3 * N.z + assert p.angular_momentum(O, N) == m2 * r * v3 * N.x + P2.set_vel(N, v1 * N.x + v2 * N.y + v3 * N.z) + assert p.linear_momentum(N) == m2 * (v1 * N.x + v2 * N.y + v3 * N.z) + assert p.angular_momentum(O, N) == m2 * r * (v3 * N.x - v1 * N.z) + p.potential_energy = m * g * h + assert p.potential_energy == m * g * h + # TODO make the result not be system-dependent + assert p.kinetic_energy( + N) in [m2 * (v1 ** 2 + v2 ** 2 + v3 ** 2) / 2, + m2 * v1 ** 2 / 2 + m2 * v2 ** 2 / 2 + m2 * v3 ** 2 / 2] + + +def test_parallel_axis(): + N = ReferenceFrame('N') + m, a, b = symbols('m, a, b') + o = Point('o') + p = o.locatenew('p', a * N.x + b * N.y) + P = Particle('P', o, m) + Ip = P.parallel_axis(p, N) + Ip_expected = inertia(N, m * b ** 2, m * a ** 2, m * (a ** 2 + b ** 2), + ixy=-m * a * b) + assert Ip == Ip_expected + + +def test_deprecated_set_potential_energy(): + m, g, h = symbols('m g h') + P = Point('P') + p = Particle('pa', P, m) + with warns_deprecated_sympy(): + p.set_potential_energy(m * g * h) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_pathway.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_pathway.py new file mode 100644 index 0000000000000000000000000000000000000000..49dc4bd4d61300745833f9d32f3a91d9054c4839 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_pathway.py @@ -0,0 +1,691 @@ +"""Tests for the ``sympy.physics.mechanics.pathway.py`` module.""" + +import pytest + +from sympy import ( + Rational, + Symbol, + cos, + pi, + sin, + sqrt, +) +from sympy.physics.mechanics import ( + Force, + LinearPathway, + ObstacleSetPathway, + PathwayBase, + Point, + ReferenceFrame, + WrappingCylinder, + WrappingGeometryBase, + WrappingPathway, + WrappingSphere, + dynamicsymbols, +) +from sympy.simplify.simplify import simplify + + +def _simplify_loads(loads): + return [ + load.__class__(load.location, load.vector.simplify()) + for load in loads + ] + + +class TestLinearPathway: + + def test_is_pathway_base_subclass(self): + assert issubclass(LinearPathway, PathwayBase) + + @staticmethod + @pytest.mark.parametrize( + 'args, kwargs', + [ + ((Point('pA'), Point('pB')), {}), + ] + ) + def test_valid_constructor(args, kwargs): + pointA, pointB = args + instance = LinearPathway(*args, **kwargs) + assert isinstance(instance, LinearPathway) + assert hasattr(instance, 'attachments') + assert len(instance.attachments) == 2 + assert instance.attachments[0] is pointA + assert instance.attachments[1] is pointB + assert isinstance(instance.attachments[0], Point) + assert instance.attachments[0].name == 'pA' + assert isinstance(instance.attachments[1], Point) + assert instance.attachments[1].name == 'pB' + + @staticmethod + @pytest.mark.parametrize( + 'attachments', + [ + (Point('pA'), ), + (Point('pA'), Point('pB'), Point('pZ')), + ] + ) + def test_invalid_attachments_incorrect_number(attachments): + with pytest.raises(ValueError): + _ = LinearPathway(*attachments) + + @staticmethod + @pytest.mark.parametrize( + 'attachments', + [ + (None, Point('pB')), + (Point('pA'), None), + ] + ) + def test_invalid_attachments_not_point(attachments): + with pytest.raises(TypeError): + _ = LinearPathway(*attachments) + + @pytest.fixture(autouse=True) + def _linear_pathway_fixture(self): + self.N = ReferenceFrame('N') + self.pA = Point('pA') + self.pB = Point('pB') + self.pathway = LinearPathway(self.pA, self.pB) + self.q1 = dynamicsymbols('q1') + self.q2 = dynamicsymbols('q2') + self.q3 = dynamicsymbols('q3') + self.q1d = dynamicsymbols('q1', 1) + self.q2d = dynamicsymbols('q2', 1) + self.q3d = dynamicsymbols('q3', 1) + self.F = Symbol('F') + + def test_properties_are_immutable(self): + instance = LinearPathway(self.pA, self.pB) + with pytest.raises(AttributeError): + instance.attachments = None + with pytest.raises(TypeError): + instance.attachments[0] = None + with pytest.raises(TypeError): + instance.attachments[1] = None + + def test_repr(self): + pathway = LinearPathway(self.pA, self.pB) + expected = 'LinearPathway(pA, pB)' + assert repr(pathway) == expected + + def test_static_pathway_length(self): + self.pB.set_pos(self.pA, 2*self.N.x) + assert self.pathway.length == 2 + + def test_static_pathway_extension_velocity(self): + self.pB.set_pos(self.pA, 2*self.N.x) + assert self.pathway.extension_velocity == 0 + + def test_static_pathway_to_loads(self): + self.pB.set_pos(self.pA, 2*self.N.x) + expected = [ + (self.pA, - self.F*self.N.x), + (self.pB, self.F*self.N.x), + ] + assert self.pathway.to_loads(self.F) == expected + + def test_2D_pathway_length(self): + self.pB.set_pos(self.pA, 2*self.q1*self.N.x) + expected = 2*sqrt(self.q1**2) + assert self.pathway.length == expected + + def test_2D_pathway_extension_velocity(self): + self.pB.set_pos(self.pA, 2*self.q1*self.N.x) + expected = 2*sqrt(self.q1**2)*self.q1d/self.q1 + assert self.pathway.extension_velocity == expected + + def test_2D_pathway_to_loads(self): + self.pB.set_pos(self.pA, 2*self.q1*self.N.x) + expected = [ + (self.pA, - self.F*(self.q1 / sqrt(self.q1**2))*self.N.x), + (self.pB, self.F*(self.q1 / sqrt(self.q1**2))*self.N.x), + ] + assert self.pathway.to_loads(self.F) == expected + + def test_3D_pathway_length(self): + self.pB.set_pos( + self.pA, + self.q1*self.N.x - self.q2*self.N.y + 2*self.q3*self.N.z, + ) + expected = sqrt(self.q1**2 + self.q2**2 + 4*self.q3**2) + assert simplify(self.pathway.length - expected) == 0 + + def test_3D_pathway_extension_velocity(self): + self.pB.set_pos( + self.pA, + self.q1*self.N.x - self.q2*self.N.y + 2*self.q3*self.N.z, + ) + length = sqrt(self.q1**2 + self.q2**2 + 4*self.q3**2) + expected = ( + self.q1*self.q1d/length + + self.q2*self.q2d/length + + 4*self.q3*self.q3d/length + ) + assert simplify(self.pathway.extension_velocity - expected) == 0 + + def test_3D_pathway_to_loads(self): + self.pB.set_pos( + self.pA, + self.q1*self.N.x - self.q2*self.N.y + 2*self.q3*self.N.z, + ) + length = sqrt(self.q1**2 + self.q2**2 + 4*self.q3**2) + pO_force = ( + - self.F*self.q1*self.N.x/length + + self.F*self.q2*self.N.y/length + - 2*self.F*self.q3*self.N.z/length + ) + pI_force = ( + self.F*self.q1*self.N.x/length + - self.F*self.q2*self.N.y/length + + 2*self.F*self.q3*self.N.z/length + ) + expected = [ + (self.pA, pO_force), + (self.pB, pI_force), + ] + assert self.pathway.to_loads(self.F) == expected + + +class TestObstacleSetPathway: + + def test_is_pathway_base_subclass(self): + assert issubclass(ObstacleSetPathway, PathwayBase) + + @staticmethod + @pytest.mark.parametrize( + 'num_attachments, attachments', + [ + (3, [Point(name) for name in ('pO', 'pA', 'pI')]), + (4, [Point(name) for name in ('pO', 'pA', 'pB', 'pI')]), + (5, [Point(name) for name in ('pO', 'pA', 'pB', 'pC', 'pI')]), + (6, [Point(name) for name in ('pO', 'pA', 'pB', 'pC', 'pD', 'pI')]), + ] + ) + def test_valid_constructor(num_attachments, attachments): + instance = ObstacleSetPathway(*attachments) + assert isinstance(instance, ObstacleSetPathway) + assert hasattr(instance, 'attachments') + assert len(instance.attachments) == num_attachments + for attachment in instance.attachments: + assert isinstance(attachment, Point) + + @staticmethod + @pytest.mark.parametrize( + 'attachments', + [[Point('pO')], [Point('pO'), Point('pI')]], + ) + def test_invalid_constructor_attachments_incorrect_number(attachments): + with pytest.raises(ValueError): + _ = ObstacleSetPathway(*attachments) + + @staticmethod + @pytest.mark.parametrize( + 'attachments', + [ + (None, Point('pA'), Point('pI')), + (Point('pO'), None, Point('pI')), + (Point('pO'), Point('pA'), None), + ] + ) + def test_invalid_constructor_attachments_not_point(attachments): + with pytest.raises(TypeError): + _ = WrappingPathway(*attachments) # type: ignore + + def test_properties_are_immutable(self): + pathway = ObstacleSetPathway(Point('pO'), Point('pA'), Point('pI')) + with pytest.raises(AttributeError): + pathway.attachments = None # type: ignore + with pytest.raises(TypeError): + pathway.attachments[0] = None # type: ignore + with pytest.raises(TypeError): + pathway.attachments[1] = None # type: ignore + with pytest.raises(TypeError): + pathway.attachments[-1] = None # type: ignore + + @staticmethod + @pytest.mark.parametrize( + 'attachments, expected', + [ + ( + [Point(name) for name in ('pO', 'pA', 'pI')], + 'ObstacleSetPathway(pO, pA, pI)' + ), + ( + [Point(name) for name in ('pO', 'pA', 'pB', 'pI')], + 'ObstacleSetPathway(pO, pA, pB, pI)' + ), + ( + [Point(name) for name in ('pO', 'pA', 'pB', 'pC', 'pI')], + 'ObstacleSetPathway(pO, pA, pB, pC, pI)' + ), + ] + ) + def test_repr(attachments, expected): + pathway = ObstacleSetPathway(*attachments) + assert repr(pathway) == expected + + @pytest.fixture(autouse=True) + def _obstacle_set_pathway_fixture(self): + self.N = ReferenceFrame('N') + self.pO = Point('pO') + self.pI = Point('pI') + self.pA = Point('pA') + self.pB = Point('pB') + self.q = dynamicsymbols('q') + self.qd = dynamicsymbols('q', 1) + self.F = Symbol('F') + + def test_static_pathway_length(self): + self.pA.set_pos(self.pO, self.N.x) + self.pB.set_pos(self.pO, self.N.y) + self.pI.set_pos(self.pO, self.N.z) + pathway = ObstacleSetPathway(self.pO, self.pA, self.pB, self.pI) + assert pathway.length == 1 + 2 * sqrt(2) + + def test_static_pathway_extension_velocity(self): + self.pA.set_pos(self.pO, self.N.x) + self.pB.set_pos(self.pO, self.N.y) + self.pI.set_pos(self.pO, self.N.z) + pathway = ObstacleSetPathway(self.pO, self.pA, self.pB, self.pI) + assert pathway.extension_velocity == 0 + + def test_static_pathway_to_loads(self): + self.pA.set_pos(self.pO, self.N.x) + self.pB.set_pos(self.pO, self.N.y) + self.pI.set_pos(self.pO, self.N.z) + pathway = ObstacleSetPathway(self.pO, self.pA, self.pB, self.pI) + expected = [ + Force(self.pO, -self.F * self.N.x), + Force(self.pA, self.F * self.N.x), + Force(self.pA, self.F * sqrt(2) / 2 * (self.N.x - self.N.y)), + Force(self.pB, self.F * sqrt(2) / 2 * (self.N.y - self.N.x)), + Force(self.pB, self.F * sqrt(2) / 2 * (self.N.y - self.N.z)), + Force(self.pI, self.F * sqrt(2) / 2 * (self.N.z - self.N.y)), + ] + assert pathway.to_loads(self.F) == expected + + def test_2D_pathway_length(self): + self.pA.set_pos(self.pO, -(self.N.x + self.N.y)) + self.pB.set_pos( + self.pO, cos(self.q) * self.N.x - (sin(self.q) + 1) * self.N.y + ) + self.pI.set_pos( + self.pO, sin(self.q) * self.N.x + (cos(self.q) - 1) * self.N.y + ) + pathway = ObstacleSetPathway(self.pO, self.pA, self.pB, self.pI) + expected = 2 * sqrt(2) + sqrt(2 + 2*cos(self.q)) + assert (pathway.length - expected).simplify() == 0 + + def test_2D_pathway_extension_velocity(self): + self.pA.set_pos(self.pO, -(self.N.x + self.N.y)) + self.pB.set_pos( + self.pO, cos(self.q) * self.N.x - (sin(self.q) + 1) * self.N.y + ) + self.pI.set_pos( + self.pO, sin(self.q) * self.N.x + (cos(self.q) - 1) * self.N.y + ) + pathway = ObstacleSetPathway(self.pO, self.pA, self.pB, self.pI) + expected = - (sqrt(2) * sin(self.q) * self.qd) / (2 * sqrt(cos(self.q) + 1)) + assert (pathway.extension_velocity - expected).simplify() == 0 + + def test_2D_pathway_to_loads(self): + self.pA.set_pos(self.pO, -(self.N.x + self.N.y)) + self.pB.set_pos( + self.pO, cos(self.q) * self.N.x - (sin(self.q) + 1) * self.N.y + ) + self.pI.set_pos( + self.pO, sin(self.q) * self.N.x + (cos(self.q) - 1) * self.N.y + ) + pathway = ObstacleSetPathway(self.pO, self.pA, self.pB, self.pI) + pO_pA_force_vec = sqrt(2) / 2 * (self.N.x + self.N.y) + pA_pB_force_vec = ( + - sqrt(2 * cos(self.q) + 2) / 2 * self.N.x + + sqrt(2) * sin(self.q) / (2 * sqrt(cos(self.q) + 1)) * self.N.y + ) + pB_pI_force_vec = cos(self.q + pi/4) * self.N.x - sin(self.q + pi/4) * self.N.y + expected = [ + Force(self.pO, self.F * pO_pA_force_vec), + Force(self.pA, -self.F * pO_pA_force_vec), + Force(self.pA, self.F * pA_pB_force_vec), + Force(self.pB, -self.F * pA_pB_force_vec), + Force(self.pB, self.F * pB_pI_force_vec), + Force(self.pI, -self.F * pB_pI_force_vec), + ] + assert _simplify_loads(pathway.to_loads(self.F)) == expected + + +class TestWrappingPathway: + + def test_is_pathway_base_subclass(self): + assert issubclass(WrappingPathway, PathwayBase) + + @pytest.fixture(autouse=True) + def _wrapping_pathway_fixture(self): + self.pA = Point('pA') + self.pB = Point('pB') + self.r = Symbol('r', positive=True) + self.pO = Point('pO') + self.N = ReferenceFrame('N') + self.ax = self.N.z + self.sphere = WrappingSphere(self.r, self.pO) + self.cylinder = WrappingCylinder(self.r, self.pO, self.ax) + self.pathway = WrappingPathway(self.pA, self.pB, self.cylinder) + self.F = Symbol('F') + + def test_valid_constructor(self): + instance = WrappingPathway(self.pA, self.pB, self.cylinder) + assert isinstance(instance, WrappingPathway) + assert hasattr(instance, 'attachments') + assert len(instance.attachments) == 2 + assert isinstance(instance.attachments[0], Point) + assert instance.attachments[0] == self.pA + assert isinstance(instance.attachments[1], Point) + assert instance.attachments[1] == self.pB + assert hasattr(instance, 'geometry') + assert isinstance(instance.geometry, WrappingGeometryBase) + assert instance.geometry == self.cylinder + + @pytest.mark.parametrize( + 'attachments', + [ + (Point('pA'), ), + (Point('pA'), Point('pB'), Point('pZ')), + ] + ) + def test_invalid_constructor_attachments_incorrect_number(self, attachments): + with pytest.raises(TypeError): + _ = WrappingPathway(*attachments, self.cylinder) + + @staticmethod + @pytest.mark.parametrize( + 'attachments', + [ + (None, Point('pB')), + (Point('pA'), None), + ] + ) + def test_invalid_constructor_attachments_not_point(attachments): + with pytest.raises(TypeError): + _ = WrappingPathway(*attachments) + + def test_invalid_constructor_geometry_is_not_supplied(self): + with pytest.raises(TypeError): + _ = WrappingPathway(self.pA, self.pB) + + @pytest.mark.parametrize( + 'geometry', + [ + Symbol('r'), + dynamicsymbols('q'), + ReferenceFrame('N'), + ReferenceFrame('N').x, + ] + ) + def test_invalid_geometry_not_geometry(self, geometry): + with pytest.raises(TypeError): + _ = WrappingPathway(self.pA, self.pB, geometry) + + def test_attachments_property_is_immutable(self): + with pytest.raises(TypeError): + self.pathway.attachments[0] = self.pB + with pytest.raises(TypeError): + self.pathway.attachments[1] = self.pA + + def test_geometry_property_is_immutable(self): + with pytest.raises(AttributeError): + self.pathway.geometry = None + + def test_repr(self): + expected = ( + f'WrappingPathway(pA, pB, ' + f'geometry={self.cylinder!r})' + ) + assert repr(self.pathway) == expected + + @staticmethod + def _expand_pos_to_vec(pos, frame): + return sum(mag*unit for (mag, unit) in zip(pos, frame)) + + @pytest.mark.parametrize( + 'pA_vec, pB_vec, factor', + [ + ((1, 0, 0), (0, 1, 0), pi/2), + ((0, 1, 0), (sqrt(2)/2, -sqrt(2)/2, 0), 3*pi/4), + ((1, 0, 0), (Rational(1, 2), sqrt(3)/2, 0), pi/3), + ] + ) + def test_static_pathway_on_sphere_length(self, pA_vec, pB_vec, factor): + pA_vec = self._expand_pos_to_vec(pA_vec, self.N) + pB_vec = self._expand_pos_to_vec(pB_vec, self.N) + self.pA.set_pos(self.pO, self.r*pA_vec) + self.pB.set_pos(self.pO, self.r*pB_vec) + pathway = WrappingPathway(self.pA, self.pB, self.sphere) + expected = factor*self.r + assert simplify(pathway.length - expected) == 0 + + @pytest.mark.parametrize( + 'pA_vec, pB_vec, factor', + [ + ((1, 0, 0), (0, 1, 0), Rational(1, 2)*pi), + ((1, 0, 0), (-1, 0, 0), pi), + ((-1, 0, 0), (1, 0, 0), pi), + ((0, 1, 0), (sqrt(2)/2, -sqrt(2)/2, 0), 5*pi/4), + ((1, 0, 0), (Rational(1, 2), sqrt(3)/2, 0), pi/3), + ( + (0, 1, 0), + (sqrt(2)*Rational(1, 2), -sqrt(2)*Rational(1, 2), 1), + sqrt(1 + (Rational(5, 4)*pi)**2), + ), + ( + (1, 0, 0), + (Rational(1, 2), sqrt(3)*Rational(1, 2), 1), + sqrt(1 + (Rational(1, 3)*pi)**2), + ), + ] + ) + def test_static_pathway_on_cylinder_length(self, pA_vec, pB_vec, factor): + pA_vec = self._expand_pos_to_vec(pA_vec, self.N) + pB_vec = self._expand_pos_to_vec(pB_vec, self.N) + self.pA.set_pos(self.pO, self.r*pA_vec) + self.pB.set_pos(self.pO, self.r*pB_vec) + pathway = WrappingPathway(self.pA, self.pB, self.cylinder) + expected = factor*sqrt(self.r**2) + assert simplify(pathway.length - expected) == 0 + + @pytest.mark.parametrize( + 'pA_vec, pB_vec', + [ + ((1, 0, 0), (0, 1, 0)), + ((0, 1, 0), (sqrt(2)*Rational(1, 2), -sqrt(2)*Rational(1, 2), 0)), + ((1, 0, 0), (Rational(1, 2), sqrt(3)*Rational(1, 2), 0)), + ] + ) + def test_static_pathway_on_sphere_extension_velocity(self, pA_vec, pB_vec): + pA_vec = self._expand_pos_to_vec(pA_vec, self.N) + pB_vec = self._expand_pos_to_vec(pB_vec, self.N) + self.pA.set_pos(self.pO, self.r*pA_vec) + self.pB.set_pos(self.pO, self.r*pB_vec) + pathway = WrappingPathway(self.pA, self.pB, self.sphere) + assert pathway.extension_velocity == 0 + + @pytest.mark.parametrize( + 'pA_vec, pB_vec', + [ + ((1, 0, 0), (0, 1, 0)), + ((1, 0, 0), (-1, 0, 0)), + ((-1, 0, 0), (1, 0, 0)), + ((0, 1, 0), (sqrt(2)/2, -sqrt(2)/2, 0)), + ((1, 0, 0), (Rational(1, 2), sqrt(3)/2, 0)), + ((0, 1, 0), (sqrt(2)*Rational(1, 2), -sqrt(2)/2, 1)), + ((1, 0, 0), (Rational(1, 2), sqrt(3)/2, 1)), + ] + ) + def test_static_pathway_on_cylinder_extension_velocity(self, pA_vec, pB_vec): + pA_vec = self._expand_pos_to_vec(pA_vec, self.N) + pB_vec = self._expand_pos_to_vec(pB_vec, self.N) + self.pA.set_pos(self.pO, self.r*pA_vec) + self.pB.set_pos(self.pO, self.r*pB_vec) + pathway = WrappingPathway(self.pA, self.pB, self.cylinder) + assert pathway.extension_velocity == 0 + + @pytest.mark.parametrize( + 'pA_vec, pB_vec, pA_vec_expected, pB_vec_expected, pO_vec_expected', + ( + ((1, 0, 0), (0, 1, 0), (0, 1, 0), (1, 0, 0), (-1, -1, 0)), + ( + (0, 1, 0), + (sqrt(2)/2, -sqrt(2)/2, 0), + (1, 0, 0), + (sqrt(2)/2, sqrt(2)/2, 0), + (-1 - sqrt(2)/2, -sqrt(2)/2, 0) + ), + ( + (1, 0, 0), + (Rational(1, 2), sqrt(3)/2, 0), + (0, 1, 0), + (sqrt(3)/2, -Rational(1, 2), 0), + (-sqrt(3)/2, Rational(1, 2) - 1, 0), + ), + ) + ) + def test_static_pathway_on_sphere_to_loads( + self, + pA_vec, + pB_vec, + pA_vec_expected, + pB_vec_expected, + pO_vec_expected, + ): + pA_vec = self._expand_pos_to_vec(pA_vec, self.N) + pB_vec = self._expand_pos_to_vec(pB_vec, self.N) + self.pA.set_pos(self.pO, self.r*pA_vec) + self.pB.set_pos(self.pO, self.r*pB_vec) + pathway = WrappingPathway(self.pA, self.pB, self.sphere) + + pA_vec_expected = sum( + mag*unit for (mag, unit) in zip(pA_vec_expected, self.N) + ) + pB_vec_expected = sum( + mag*unit for (mag, unit) in zip(pB_vec_expected, self.N) + ) + pO_vec_expected = sum( + mag*unit for (mag, unit) in zip(pO_vec_expected, self.N) + ) + expected = [ + Force(self.pA, self.F*(self.r**3/sqrt(self.r**6))*pA_vec_expected), + Force(self.pB, self.F*(self.r**3/sqrt(self.r**6))*pB_vec_expected), + Force(self.pO, self.F*(self.r**3/sqrt(self.r**6))*pO_vec_expected), + ] + assert pathway.to_loads(self.F) == expected + + @pytest.mark.parametrize( + 'pA_vec, pB_vec, pA_vec_expected, pB_vec_expected, pO_vec_expected', + ( + ((1, 0, 0), (0, 1, 0), (0, 1, 0), (1, 0, 0), (-1, -1, 0)), + ((1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, 1, 0), (0, -2, 0)), + ((-1, 0, 0), (1, 0, 0), (0, -1, 0), (0, -1, 0), (0, 2, 0)), + ( + (0, 1, 0), + (sqrt(2)/2, -sqrt(2)/2, 0), + (-1, 0, 0), + (-sqrt(2)/2, -sqrt(2)/2, 0), + (1 + sqrt(2)/2, sqrt(2)/2, 0) + ), + ( + (1, 0, 0), + (Rational(1, 2), sqrt(3)/2, 0), + (0, 1, 0), + (sqrt(3)/2, -Rational(1, 2), 0), + (-sqrt(3)/2, Rational(1, 2) - 1, 0), + ), + ( + (1, 0, 0), + (sqrt(2)/2, sqrt(2)/2, 0), + (0, 1, 0), + (sqrt(2)/2, -sqrt(2)/2, 0), + (-sqrt(2)/2, sqrt(2)/2 - 1, 0), + ), + ((0, 1, 0), (0, 1, 1), (0, 0, 1), (0, 0, -1), (0, 0, 0)), + ( + (0, 1, 0), + (sqrt(2)/2, -sqrt(2)/2, 1), + (-5*pi/sqrt(16 + 25*pi**2), 0, 4/sqrt(16 + 25*pi**2)), + ( + -5*sqrt(2)*pi/(2*sqrt(16 + 25*pi**2)), + -5*sqrt(2)*pi/(2*sqrt(16 + 25*pi**2)), + -4/sqrt(16 + 25*pi**2), + ), + ( + 5*(sqrt(2) + 2)*pi/(2*sqrt(16 + 25*pi**2)), + 5*sqrt(2)*pi/(2*sqrt(16 + 25*pi**2)), + 0, + ), + ), + ) + ) + def test_static_pathway_on_cylinder_to_loads( + self, + pA_vec, + pB_vec, + pA_vec_expected, + pB_vec_expected, + pO_vec_expected, + ): + pA_vec = self._expand_pos_to_vec(pA_vec, self.N) + pB_vec = self._expand_pos_to_vec(pB_vec, self.N) + self.pA.set_pos(self.pO, self.r*pA_vec) + self.pB.set_pos(self.pO, self.r*pB_vec) + pathway = WrappingPathway(self.pA, self.pB, self.cylinder) + + pA_force_expected = self.F*self._expand_pos_to_vec(pA_vec_expected, + self.N) + pB_force_expected = self.F*self._expand_pos_to_vec(pB_vec_expected, + self.N) + pO_force_expected = self.F*self._expand_pos_to_vec(pO_vec_expected, + self.N) + expected = [ + Force(self.pA, pA_force_expected), + Force(self.pB, pB_force_expected), + Force(self.pO, pO_force_expected), + ] + assert _simplify_loads(pathway.to_loads(self.F)) == expected + + def test_2D_pathway_on_cylinder_length(self): + q = dynamicsymbols('q') + pA_pos = self.r*self.N.x + pB_pos = self.r*(cos(q)*self.N.x + sin(q)*self.N.y) + self.pA.set_pos(self.pO, pA_pos) + self.pB.set_pos(self.pO, pB_pos) + expected = self.r*sqrt(q**2) + assert simplify(self.pathway.length - expected) == 0 + + def test_2D_pathway_on_cylinder_extension_velocity(self): + q = dynamicsymbols('q') + qd = dynamicsymbols('q', 1) + pA_pos = self.r*self.N.x + pB_pos = self.r*(cos(q)*self.N.x + sin(q)*self.N.y) + self.pA.set_pos(self.pO, pA_pos) + self.pB.set_pos(self.pO, pB_pos) + expected = self.r*(sqrt(q**2)/q)*qd + assert simplify(self.pathway.extension_velocity - expected) == 0 + + def test_2D_pathway_on_cylinder_to_loads(self): + q = dynamicsymbols('q') + pA_pos = self.r*self.N.x + pB_pos = self.r*(cos(q)*self.N.x + sin(q)*self.N.y) + self.pA.set_pos(self.pO, pA_pos) + self.pB.set_pos(self.pO, pB_pos) + + pA_force = self.F*self.N.y + pB_force = self.F*(sin(q)*self.N.x - cos(q)*self.N.y) + pO_force = self.F*(-sin(q)*self.N.x + (cos(q) - 1)*self.N.y) + expected = [ + Force(self.pA, pA_force), + Force(self.pB, pB_force), + Force(self.pO, pO_force), + ] + + loads = _simplify_loads(self.pathway.to_loads(self.F)) + assert loads == expected diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_rigidbody.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_rigidbody.py new file mode 100644 index 0000000000000000000000000000000000000000..78161e0c9fc33be6e3d274034b67278c8ceee8fd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_rigidbody.py @@ -0,0 +1,184 @@ +from sympy.physics.mechanics import Point, ReferenceFrame, Dyadic, RigidBody +from sympy.physics.mechanics import dynamicsymbols, outer, inertia, Inertia +from sympy.physics.mechanics import inertia_of_point_mass +from sympy import expand, zeros, simplify, symbols +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +def test_rigidbody_default(): + # Test default + b = RigidBody('B') + I = inertia(b.frame, *symbols('B_ixx B_iyy B_izz B_ixy B_iyz B_izx')) + assert b.name == 'B' + assert b.mass == symbols('B_mass') + assert b.masscenter.name == 'B_masscenter' + assert b.inertia == (I, b.masscenter) + assert b.central_inertia == I + assert b.frame.name == 'B_frame' + assert b.__str__() == 'B' + assert b.__repr__() == ( + "RigidBody('B', masscenter=B_masscenter, frame=B_frame, mass=B_mass, " + "inertia=Inertia(dyadic=B_ixx*(B_frame.x|B_frame.x) + " + "B_ixy*(B_frame.x|B_frame.y) + B_izx*(B_frame.x|B_frame.z) + " + "B_ixy*(B_frame.y|B_frame.x) + B_iyy*(B_frame.y|B_frame.y) + " + "B_iyz*(B_frame.y|B_frame.z) + B_izx*(B_frame.z|B_frame.x) + " + "B_iyz*(B_frame.z|B_frame.y) + B_izz*(B_frame.z|B_frame.z), " + "point=B_masscenter))") + + +def test_rigidbody(): + m, m2, v1, v2, v3, omega = symbols('m m2 v1 v2 v3 omega') + A = ReferenceFrame('A') + A2 = ReferenceFrame('A2') + P = Point('P') + P2 = Point('P2') + I = Dyadic(0) + I2 = Dyadic(0) + B = RigidBody('B', P, A, m, (I, P)) + assert B.mass == m + assert B.frame == A + assert B.masscenter == P + assert B.inertia == (I, B.masscenter) + + B.mass = m2 + B.frame = A2 + B.masscenter = P2 + B.inertia = (I2, B.masscenter) + raises(TypeError, lambda: RigidBody(P, P, A, m, (I, P))) + raises(TypeError, lambda: RigidBody('B', P, P, m, (I, P))) + raises(TypeError, lambda: RigidBody('B', P, A, m, (P, P))) + raises(TypeError, lambda: RigidBody('B', P, A, m, (I, I))) + assert B.__str__() == 'B' + assert B.mass == m2 + assert B.frame == A2 + assert B.masscenter == P2 + assert B.inertia == (I2, B.masscenter) + assert isinstance(B.inertia, Inertia) + + # Testing linear momentum function assuming A2 is the inertial frame + N = ReferenceFrame('N') + P2.set_vel(N, v1 * N.x + v2 * N.y + v3 * N.z) + assert B.linear_momentum(N) == m2 * (v1 * N.x + v2 * N.y + v3 * N.z) + + +def test_rigidbody2(): + M, v, r, omega, g, h = dynamicsymbols('M v r omega g h') + N = ReferenceFrame('N') + b = ReferenceFrame('b') + b.set_ang_vel(N, omega * b.x) + P = Point('P') + I = outer(b.x, b.x) + Inertia_tuple = (I, P) + B = RigidBody('B', P, b, M, Inertia_tuple) + P.set_vel(N, v * b.x) + assert B.angular_momentum(P, N) == omega * b.x + O = Point('O') + O.set_vel(N, v * b.x) + P.set_pos(O, r * b.y) + assert B.angular_momentum(O, N) == omega * b.x - M*v*r*b.z + B.potential_energy = M * g * h + assert B.potential_energy == M * g * h + assert expand(2 * B.kinetic_energy(N)) == omega**2 + M * v**2 + + +def test_rigidbody3(): + q1, q2, q3, q4 = dynamicsymbols('q1:5') + p1, p2, p3 = symbols('p1:4') + m = symbols('m') + + A = ReferenceFrame('A') + B = A.orientnew('B', 'axis', [q1, A.x]) + O = Point('O') + O.set_vel(A, q2*A.x + q3*A.y + q4*A.z) + P = O.locatenew('P', p1*B.x + p2*B.y + p3*B.z) + P.v2pt_theory(O, A, B) + I = outer(B.x, B.x) + + rb1 = RigidBody('rb1', P, B, m, (I, P)) + # I_S/O = I_S/S* + I_S*/O + rb2 = RigidBody('rb2', P, B, m, + (I + inertia_of_point_mass(m, P.pos_from(O), B), O)) + + assert rb1.central_inertia == rb2.central_inertia + assert rb1.angular_momentum(O, A) == rb2.angular_momentum(O, A) + + +def test_pendulum_angular_momentum(): + """Consider a pendulum of length OA = 2a, of mass m as a rigid body of + center of mass G (OG = a) which turn around (O,z). The angle between the + reference frame R and the rod is q. The inertia of the body is I = + (G,0,ma^2/3,ma^2/3). """ + + m, a = symbols('m, a') + q = dynamicsymbols('q') + + R = ReferenceFrame('R') + R1 = R.orientnew('R1', 'Axis', [q, R.z]) + R1.set_ang_vel(R, q.diff() * R.z) + + I = inertia(R1, 0, m * a**2 / 3, m * a**2 / 3) + + O = Point('O') + + A = O.locatenew('A', 2*a * R1.x) + G = O.locatenew('G', a * R1.x) + + S = RigidBody('S', G, R1, m, (I, G)) + + O.set_vel(R, 0) + A.v2pt_theory(O, R, R1) + G.v2pt_theory(O, R, R1) + + assert (4 * m * a**2 / 3 * q.diff() * R.z - + S.angular_momentum(O, R).express(R)) == 0 + + +def test_rigidbody_inertia(): + N = ReferenceFrame('N') + m, Ix, Iy, Iz, a, b = symbols('m, I_x, I_y, I_z, a, b') + Io = inertia(N, Ix, Iy, Iz) + o = Point('o') + p = o.locatenew('p', a * N.x + b * N.y) + R = RigidBody('R', o, N, m, (Io, p)) + I_check = inertia(N, Ix - b ** 2 * m, Iy - a ** 2 * m, + Iz - m * (a ** 2 + b ** 2), m * a * b) + assert isinstance(R.inertia, Inertia) + assert R.inertia == (Io, p) + assert R.central_inertia == I_check + R.central_inertia = Io + assert R.inertia == (Io, o) + assert R.central_inertia == Io + R.inertia = (Io, p) + assert R.inertia == (Io, p) + assert R.central_inertia == I_check + # parse Inertia object + R.inertia = Inertia(Io, o) + assert R.inertia == (Io, o) + + +def test_parallel_axis(): + N = ReferenceFrame('N') + m, Ix, Iy, Iz, a, b = symbols('m, I_x, I_y, I_z, a, b') + Io = inertia(N, Ix, Iy, Iz) + o = Point('o') + p = o.locatenew('p', a * N.x + b * N.y) + R = RigidBody('R', o, N, m, (Io, o)) + Ip = R.parallel_axis(p) + Ip_expected = inertia(N, Ix + m * b**2, Iy + m * a**2, + Iz + m * (a**2 + b**2), ixy=-m * a * b) + assert Ip == Ip_expected + # Reference frame from which the parallel axis is viewed should not matter + A = ReferenceFrame('A') + A.orient_axis(N, N.z, 1) + assert simplify( + (R.parallel_axis(p, A) - Ip_expected).to_matrix(A)) == zeros(3, 3) + + +def test_deprecated_set_potential_energy(): + m, g, h = symbols('m g h') + A = ReferenceFrame('A') + P = Point('P') + I = Dyadic(0) + B = RigidBody('B', P, A, m, (I, P)) + with warns_deprecated_sympy(): + B.set_potential_energy(m*g*h) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_system.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_system.py new file mode 100644 index 0000000000000000000000000000000000000000..6fdac1ea10e9f71f8cf999cc5069da7567f67adf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_system.py @@ -0,0 +1,245 @@ +from sympy import symbols, Matrix, atan, zeros +from sympy.simplify.simplify import simplify +from sympy.physics.mechanics import (dynamicsymbols, Particle, Point, + ReferenceFrame, SymbolicSystem) +from sympy.testing.pytest import raises + +# This class is going to be tested using a simple pendulum set up in x and y +# coordinates +x, y, u, v, lam = dynamicsymbols('x y u v lambda') +m, l, g = symbols('m l g') + +# Set up the different forms the equations can take +# [1] Explicit form where the kinematics and dynamics are combined +# x' = F(x, t, r, p) +# +# [2] Implicit form where the kinematics and dynamics are combined +# M(x, p) x' = F(x, t, r, p) +# +# [3] Implicit form where the kinematics and dynamics are separate +# M(q, p) u' = F(q, u, t, r, p) +# q' = G(q, u, t, r, p) +dyn_implicit_mat = Matrix([[1, 0, -x/m], + [0, 1, -y/m], + [0, 0, l**2/m]]) + +dyn_implicit_rhs = Matrix([0, 0, u**2 + v**2 - g*y]) + +comb_implicit_mat = Matrix([[1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, -x/m], + [0, 0, 0, 1, -y/m], + [0, 0, 0, 0, l**2/m]]) + +comb_implicit_rhs = Matrix([u, v, 0, 0, u**2 + v**2 - g*y]) + +kin_explicit_rhs = Matrix([u, v]) + +comb_explicit_rhs = comb_implicit_mat.LUsolve(comb_implicit_rhs) + +# Set up a body and load to pass into the system +theta = atan(x/y) +N = ReferenceFrame('N') +A = N.orientnew('A', 'Axis', [theta, N.z]) +O = Point('O') +P = O.locatenew('P', l * A.x) + +Pa = Particle('Pa', P, m) + +bodies = [Pa] +loads = [(P, g * m * N.x)] + +# Set up some output equations to be given to SymbolicSystem +# Change to make these fit the pendulum +PE = symbols("PE") +out_eqns = {PE: m*g*(l+y)} + +# Set up remaining arguments that can be passed to SymbolicSystem +alg_con = [2] +alg_con_full = [4] +coordinates = (x, y, lam) +speeds = (u, v) +states = (x, y, u, v, lam) +coord_idxs = (0, 1) +speed_idxs = (2, 3) + + +def test_form_1(): + symsystem1 = SymbolicSystem(states, comb_explicit_rhs, + alg_con=alg_con_full, output_eqns=out_eqns, + coord_idxs=coord_idxs, speed_idxs=speed_idxs, + bodies=bodies, loads=loads) + + assert symsystem1.coordinates == Matrix([x, y]) + assert symsystem1.speeds == Matrix([u, v]) + assert symsystem1.states == Matrix([x, y, u, v, lam]) + + assert symsystem1.alg_con == [4] + + inter = comb_explicit_rhs + assert simplify(symsystem1.comb_explicit_rhs - inter) == zeros(5, 1) + + assert set(symsystem1.dynamic_symbols()) == {y, v, lam, u, x} + assert type(symsystem1.dynamic_symbols()) == tuple + assert set(symsystem1.constant_symbols()) == {l, g, m} + assert type(symsystem1.constant_symbols()) == tuple + + assert symsystem1.output_eqns == out_eqns + + assert symsystem1.bodies == (Pa,) + assert symsystem1.loads == ((P, g * m * N.x),) + + +def test_form_2(): + symsystem2 = SymbolicSystem(coordinates, comb_implicit_rhs, speeds=speeds, + mass_matrix=comb_implicit_mat, + alg_con=alg_con_full, output_eqns=out_eqns, + bodies=bodies, loads=loads) + + assert symsystem2.coordinates == Matrix([x, y, lam]) + assert symsystem2.speeds == Matrix([u, v]) + assert symsystem2.states == Matrix([x, y, lam, u, v]) + + assert symsystem2.alg_con == [4] + + inter = comb_implicit_rhs + assert simplify(symsystem2.comb_implicit_rhs - inter) == zeros(5, 1) + assert simplify(symsystem2.comb_implicit_mat-comb_implicit_mat) == zeros(5) + + assert set(symsystem2.dynamic_symbols()) == {y, v, lam, u, x} + assert type(symsystem2.dynamic_symbols()) == tuple + assert set(symsystem2.constant_symbols()) == {l, g, m} + assert type(symsystem2.constant_symbols()) == tuple + + inter = comb_explicit_rhs + symsystem2.compute_explicit_form() + assert simplify(symsystem2.comb_explicit_rhs - inter) == zeros(5, 1) + + + assert symsystem2.output_eqns == out_eqns + + assert symsystem2.bodies == (Pa,) + assert symsystem2.loads == ((P, g * m * N.x),) + + +def test_form_3(): + symsystem3 = SymbolicSystem(states, dyn_implicit_rhs, + mass_matrix=dyn_implicit_mat, + coordinate_derivatives=kin_explicit_rhs, + alg_con=alg_con, coord_idxs=coord_idxs, + speed_idxs=speed_idxs, bodies=bodies, + loads=loads) + + assert symsystem3.coordinates == Matrix([x, y]) + assert symsystem3.speeds == Matrix([u, v]) + assert symsystem3.states == Matrix([x, y, u, v, lam]) + + assert symsystem3.alg_con == [4] + + inter1 = kin_explicit_rhs + inter2 = dyn_implicit_rhs + assert simplify(symsystem3.kin_explicit_rhs - inter1) == zeros(2, 1) + assert simplify(symsystem3.dyn_implicit_mat - dyn_implicit_mat) == zeros(3) + assert simplify(symsystem3.dyn_implicit_rhs - inter2) == zeros(3, 1) + + inter = comb_implicit_rhs + assert simplify(symsystem3.comb_implicit_rhs - inter) == zeros(5, 1) + assert simplify(symsystem3.comb_implicit_mat-comb_implicit_mat) == zeros(5) + + inter = comb_explicit_rhs + symsystem3.compute_explicit_form() + assert simplify(symsystem3.comb_explicit_rhs - inter) == zeros(5, 1) + + assert set(symsystem3.dynamic_symbols()) == {y, v, lam, u, x} + assert type(symsystem3.dynamic_symbols()) == tuple + assert set(symsystem3.constant_symbols()) == {l, g, m} + assert type(symsystem3.constant_symbols()) == tuple + + assert symsystem3.output_eqns == {} + + assert symsystem3.bodies == (Pa,) + assert symsystem3.loads == ((P, g * m * N.x),) + + +def test_property_attributes(): + symsystem = SymbolicSystem(states, comb_explicit_rhs, + alg_con=alg_con_full, output_eqns=out_eqns, + coord_idxs=coord_idxs, speed_idxs=speed_idxs, + bodies=bodies, loads=loads) + + with raises(AttributeError): + symsystem.bodies = 42 + with raises(AttributeError): + symsystem.coordinates = 42 + with raises(AttributeError): + symsystem.dyn_implicit_rhs = 42 + with raises(AttributeError): + symsystem.comb_implicit_rhs = 42 + with raises(AttributeError): + symsystem.loads = 42 + with raises(AttributeError): + symsystem.dyn_implicit_mat = 42 + with raises(AttributeError): + symsystem.comb_implicit_mat = 42 + with raises(AttributeError): + symsystem.kin_explicit_rhs = 42 + with raises(AttributeError): + symsystem.comb_explicit_rhs = 42 + with raises(AttributeError): + symsystem.speeds = 42 + with raises(AttributeError): + symsystem.states = 42 + with raises(AttributeError): + symsystem.alg_con = 42 + + +def test_not_specified_errors(): + """This test will cover errors that arise from trying to access attributes + that were not specified upon object creation or were specified on creation + and the user tries to recalculate them.""" + # Trying to access form 2 when form 1 given + # Trying to access form 3 when form 2 given + + symsystem1 = SymbolicSystem(states, comb_explicit_rhs) + + with raises(AttributeError): + symsystem1.comb_implicit_mat + with raises(AttributeError): + symsystem1.comb_implicit_rhs + with raises(AttributeError): + symsystem1.dyn_implicit_mat + with raises(AttributeError): + symsystem1.dyn_implicit_rhs + with raises(AttributeError): + symsystem1.kin_explicit_rhs + with raises(AttributeError): + symsystem1.compute_explicit_form() + + symsystem2 = SymbolicSystem(coordinates, comb_implicit_rhs, speeds=speeds, + mass_matrix=comb_implicit_mat) + + with raises(AttributeError): + symsystem2.dyn_implicit_mat + with raises(AttributeError): + symsystem2.dyn_implicit_rhs + with raises(AttributeError): + symsystem2.kin_explicit_rhs + + # Attribute error when trying to access coordinates and speeds when only the + # states were given. + with raises(AttributeError): + symsystem1.coordinates + with raises(AttributeError): + symsystem1.speeds + + # Attribute error when trying to access bodies and loads when they are not + # given + with raises(AttributeError): + symsystem1.bodies + with raises(AttributeError): + symsystem1.loads + + # Attribute error when trying to access comb_explicit_rhs before it was + # calculated + with raises(AttributeError): + symsystem2.comb_explicit_rhs diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_system_class.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_system_class.py new file mode 100644 index 0000000000000000000000000000000000000000..924cb8272c27c4f978aa4c3b1999f6ac56e47335 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_system_class.py @@ -0,0 +1,831 @@ +import pytest + +from sympy.core.symbol import symbols +from sympy.core.sympify import sympify +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.matrices.dense import eye, zeros +from sympy.matrices.immutable import ImmutableMatrix +from sympy.physics.mechanics import ( + Force, KanesMethod, LagrangesMethod, Particle, PinJoint, Point, + PrismaticJoint, ReferenceFrame, RigidBody, Torque, TorqueActuator, System, + dynamicsymbols) +from sympy.simplify.simplify import simplify +from sympy.solvers.solvers import solve + +t = dynamicsymbols._t # type: ignore +q = dynamicsymbols('q:6') # type: ignore +qd = dynamicsymbols('q:6', 1) # type: ignore +u = dynamicsymbols('u:6') # type: ignore +ua = dynamicsymbols('ua:3') # type: ignore + + +class TestSystemBase: + @pytest.fixture() + def _empty_system_setup(self): + self.system = System(ReferenceFrame('frame'), Point('fixed_point')) + + def _empty_system_check(self, exclude=()): + matrices = ('q_ind', 'q_dep', 'q', 'u_ind', 'u_dep', 'u', 'u_aux', + 'kdes', 'holonomic_constraints', 'nonholonomic_constraints') + tuples = ('loads', 'bodies', 'joints', 'actuators') + for attr in matrices: + if attr not in exclude: + assert getattr(self.system, attr)[:] == [] + for attr in tuples: + if attr not in exclude: + assert getattr(self.system, attr) == () + if 'eom_method' not in exclude: + assert self.system.eom_method is None + + def _create_filled_system(self, with_speeds=True): + self.system = System(ReferenceFrame('frame'), Point('fixed_point')) + u = dynamicsymbols('u:6') if with_speeds else qd + self.bodies = symbols('rb1:5', cls=RigidBody) + self.joints = ( + PinJoint('J1', self.bodies[0], self.bodies[1], q[0], u[0]), + PrismaticJoint('J2', self.bodies[1], self.bodies[2], q[1], u[1]), + PinJoint('J3', self.bodies[2], self.bodies[3], q[2], u[2]) + ) + self.system.add_joints(*self.joints) + self.system.add_coordinates(q[3], independent=[False]) + self.system.add_speeds(u[3], independent=False) + if with_speeds: + self.system.add_kdes(u[3] - qd[3]) + self.system.add_auxiliary_speeds(ua[0], ua[1]) + self.system.add_holonomic_constraints(q[2] - q[0] + q[1]) + self.system.add_nonholonomic_constraints(u[3] - qd[1] + u[2]) + self.system.u_ind = u[:2] + self.system.u_dep = u[2:4] + self.q_ind, self.q_dep = self.system.q_ind[:], self.system.q_dep[:] + self.u_ind, self.u_dep = self.system.u_ind[:], self.system.u_dep[:] + self.kdes = self.system.kdes[:] + self.hc = self.system.holonomic_constraints[:] + self.vc = self.system.velocity_constraints[:] + self.nhc = self.system.nonholonomic_constraints[:] + + @pytest.fixture() + def _filled_system_setup(self): + self._create_filled_system(with_speeds=True) + + @pytest.fixture() + def _filled_system_setup_no_speeds(self): + self._create_filled_system(with_speeds=False) + + def _filled_system_check(self, exclude=()): + assert 'q_ind' in exclude or self.system.q_ind[:] == q[:3] + assert 'q_dep' in exclude or self.system.q_dep[:] == [q[3]] + assert 'q' in exclude or self.system.q[:] == q[:4] + assert 'u_ind' in exclude or self.system.u_ind[:] == u[:2] + assert 'u_dep' in exclude or self.system.u_dep[:] == u[2:4] + assert 'u' in exclude or self.system.u[:] == u[:4] + assert 'u_aux' in exclude or self.system.u_aux[:] == ua[:2] + assert 'kdes' in exclude or self.system.kdes[:] == [ + ui - qdi for ui, qdi in zip(u[:4], qd[:4])] + assert ('holonomic_constraints' in exclude or + self.system.holonomic_constraints[:] == [q[2] - q[0] + q[1]]) + assert ('nonholonomic_constraints' in exclude or + self.system.nonholonomic_constraints[:] == [u[3] - qd[1] + u[2]] + ) + assert ('velocity_constraints' in exclude or + self.system.velocity_constraints[:] == [ + qd[2] - qd[0] + qd[1], u[3] - qd[1] + u[2]]) + assert ('bodies' in exclude or + self.system.bodies == tuple(self.bodies)) + assert ('joints' in exclude or + self.system.joints == tuple(self.joints)) + + @pytest.fixture() + def _moving_point_mass(self, _empty_system_setup): + self.system.q_ind = q[0] + self.system.u_ind = u[0] + self.system.kdes = u[0] - q[0].diff(t) + p = Particle('p', mass=symbols('m')) + self.system.add_bodies(p) + p.masscenter.set_pos(self.system.fixed_point, q[0] * self.system.x) + + +class TestSystem(TestSystemBase): + def test_empty_system(self, _empty_system_setup): + self._empty_system_check() + self.system.validate_system() + + def test_filled_system(self, _filled_system_setup): + self._filled_system_check() + self.system.validate_system() + + @pytest.mark.parametrize('frame', [None, ReferenceFrame('frame')]) + @pytest.mark.parametrize('fixed_point', [None, Point('fixed_point')]) + def test_init(self, frame, fixed_point): + if fixed_point is None and frame is None: + self.system = System() + else: + self.system = System(frame, fixed_point) + if fixed_point is None: + assert self.system.fixed_point.name == 'inertial_point' + else: + assert self.system.fixed_point == fixed_point + if frame is None: + assert self.system.frame.name == 'inertial_frame' + else: + assert self.system.frame == frame + self._empty_system_check() + assert isinstance(self.system.q_ind, ImmutableMatrix) + assert isinstance(self.system.q_dep, ImmutableMatrix) + assert isinstance(self.system.q, ImmutableMatrix) + assert isinstance(self.system.u_ind, ImmutableMatrix) + assert isinstance(self.system.u_dep, ImmutableMatrix) + assert isinstance(self.system.u, ImmutableMatrix) + assert isinstance(self.system.kdes, ImmutableMatrix) + assert isinstance(self.system.holonomic_constraints, ImmutableMatrix) + assert isinstance(self.system.nonholonomic_constraints, ImmutableMatrix) + + def test_from_newtonian_rigid_body(self): + rb = RigidBody('body') + self.system = System.from_newtonian(rb) + assert self.system.fixed_point == rb.masscenter + assert self.system.frame == rb.frame + self._empty_system_check(exclude=('bodies',)) + self.system.bodies = (rb,) + + def test_from_newtonian_particle(self): + pt = Particle('particle') + with pytest.raises(TypeError): + System.from_newtonian(pt) + + @pytest.mark.parametrize('args, kwargs, exp_q_ind, exp_q_dep, exp_q', [ + (q[:3], {}, q[:3], [], q[:3]), + (q[:3], {'independent': True}, q[:3], [], q[:3]), + (q[:3], {'independent': False}, [], q[:3], q[:3]), + (q[:3], {'independent': [True, False, True]}, [q[0], q[2]], [q[1]], + [q[0], q[2], q[1]]), + ]) + def test_coordinates(self, _empty_system_setup, args, kwargs, + exp_q_ind, exp_q_dep, exp_q): + # Test add_coordinates + self.system.add_coordinates(*args, **kwargs) + assert self.system.q_ind[:] == exp_q_ind + assert self.system.q_dep[:] == exp_q_dep + assert self.system.q[:] == exp_q + self._empty_system_check(exclude=('q_ind', 'q_dep', 'q')) + # Test setter for q_ind and q_dep + self.system.q_ind = exp_q_ind + self.system.q_dep = exp_q_dep + assert self.system.q_ind[:] == exp_q_ind + assert self.system.q_dep[:] == exp_q_dep + assert self.system.q[:] == exp_q + self._empty_system_check(exclude=('q_ind', 'q_dep', 'q')) + + @pytest.mark.parametrize('func', ['add_coordinates', 'add_speeds']) + @pytest.mark.parametrize('args, kwargs', [ + ((q[0], q[5]), {}), + ((u[0], u[5]), {}), + ((q[0],), {'independent': False}), + ((u[0],), {'independent': False}), + ((u[0], q[5]), {}), + ((symbols('a'), q[5]), {}), + ]) + def test_coordinates_speeds_invalid(self, _filled_system_setup, func, args, + kwargs): + with pytest.raises(ValueError): + getattr(self.system, func)(*args, **kwargs) + self._filled_system_check() + + @pytest.mark.parametrize('args, kwargs, exp_u_ind, exp_u_dep, exp_u', [ + (u[:3], {}, u[:3], [], u[:3]), + (u[:3], {'independent': True}, u[:3], [], u[:3]), + (u[:3], {'independent': False}, [], u[:3], u[:3]), + (u[:3], {'independent': [True, False, True]}, [u[0], u[2]], [u[1]], + [u[0], u[2], u[1]]), + ]) + def test_speeds(self, _empty_system_setup, args, kwargs, exp_u_ind, + exp_u_dep, exp_u): + # Test add_speeds + self.system.add_speeds(*args, **kwargs) + assert self.system.u_ind[:] == exp_u_ind + assert self.system.u_dep[:] == exp_u_dep + assert self.system.u[:] == exp_u + self._empty_system_check(exclude=('u_ind', 'u_dep', 'u')) + # Test setter for u_ind and u_dep + self.system.u_ind = exp_u_ind + self.system.u_dep = exp_u_dep + assert self.system.u_ind[:] == exp_u_ind + assert self.system.u_dep[:] == exp_u_dep + assert self.system.u[:] == exp_u + self._empty_system_check(exclude=('u_ind', 'u_dep', 'u')) + + @pytest.mark.parametrize('args, kwargs, exp_u_aux', [ + (ua[:3], {}, ua[:3]), + ]) + def test_auxiliary_speeds(self, _empty_system_setup, args, kwargs, + exp_u_aux): + # Test add_speeds + self.system.add_auxiliary_speeds(*args, **kwargs) + assert self.system.u_aux[:] == exp_u_aux + self._empty_system_check(exclude=('u_aux',)) + # Test setter for u_ind and u_dep + self.system.u_aux = exp_u_aux + assert self.system.u_aux[:] == exp_u_aux + self._empty_system_check(exclude=('u_aux',)) + + @pytest.mark.parametrize('args, kwargs', [ + ((ua[2], q[0]), {}), + ((ua[2], u[1]), {}), + ((ua[0], ua[2]), {}), + ((symbols('a'), ua[2]), {}), + ]) + def test_auxiliary_invalid(self, _filled_system_setup, args, kwargs): + with pytest.raises(ValueError): + self.system.add_auxiliary_speeds(*args, **kwargs) + self._filled_system_check() + + @pytest.mark.parametrize('prop, add_func, args, kwargs', [ + ('q_ind', 'add_coordinates', (q[0],), {}), + ('q_dep', 'add_coordinates', (q[3],), {'independent': False}), + ('u_ind', 'add_speeds', (u[0],), {}), + ('u_dep', 'add_speeds', (u[3],), {'independent': False}), + ('u_aux', 'add_auxiliary_speeds', (ua[2],), {}), + ('kdes', 'add_kdes', (qd[0] - u[0],), {}), + ('holonomic_constraints', 'add_holonomic_constraints', + (q[0] - q[1],), {}), + ('nonholonomic_constraints', 'add_nonholonomic_constraints', + (u[0] - u[1],), {}), + ('bodies', 'add_bodies', (RigidBody('body'),), {}), + ('loads', 'add_loads', (Force(Point('P'), ReferenceFrame('N').x),), {}), + ('actuators', 'add_actuators', (TorqueActuator( + symbols('T'), ReferenceFrame('N').x, ReferenceFrame('A')),), {}), + ]) + def test_add_after_reset(self, _filled_system_setup, prop, add_func, args, + kwargs): + setattr(self.system, prop, ()) + exclude = (prop, 'q', 'u') + if prop in ('holonomic_constraints', 'nonholonomic_constraints'): + exclude += ('velocity_constraints',) + self._filled_system_check(exclude=exclude) + assert list(getattr(self.system, prop)[:]) == [] + getattr(self.system, add_func)(*args, **kwargs) + assert list(getattr(self.system, prop)[:]) == list(args) + + @pytest.mark.parametrize('prop, add_func, value, error', [ + ('q_ind', 'add_coordinates', symbols('a'), ValueError), + ('q_dep', 'add_coordinates', symbols('a'), ValueError), + ('u_ind', 'add_speeds', symbols('a'), ValueError), + ('u_dep', 'add_speeds', symbols('a'), ValueError), + ('u_aux', 'add_auxiliary_speeds', symbols('a'), ValueError), + ('kdes', 'add_kdes', 7, TypeError), + ('holonomic_constraints', 'add_holonomic_constraints', 7, TypeError), + ('nonholonomic_constraints', 'add_nonholonomic_constraints', 7, + TypeError), + ('bodies', 'add_bodies', symbols('a'), TypeError), + ('loads', 'add_loads', symbols('a'), TypeError), + ('actuators', 'add_actuators', symbols('a'), TypeError), + ]) + def test_type_error(self, _filled_system_setup, prop, add_func, value, + error): + with pytest.raises(error): + getattr(self.system, add_func)(value) + with pytest.raises(error): + setattr(self.system, prop, value) + self._filled_system_check() + + @pytest.mark.parametrize('args, kwargs, exp_kdes', [ + ((), {}, [ui - qdi for ui, qdi in zip(u[:4], qd[:4])]), + ((u[4] - qd[4], u[5] - qd[5]), {}, + [ui - qdi for ui, qdi in zip(u[:6], qd[:6])]), + ]) + def test_kdes(self, _filled_system_setup, args, kwargs, exp_kdes): + # Test add_speeds + self.system.add_kdes(*args, **kwargs) + self._filled_system_check(exclude=('kdes',)) + assert self.system.kdes[:] == exp_kdes + # Test setter for kdes + self.system.kdes = exp_kdes + self._filled_system_check(exclude=('kdes',)) + assert self.system.kdes[:] == exp_kdes + + @pytest.mark.parametrize('args, kwargs', [ + ((u[0] - qd[0], u[4] - qd[4]), {}), + ((-(u[0] - qd[0]), u[4] - qd[4]), {}), + (([u[0] - u[0], u[4] - qd[4]]), {}), + ]) + def test_kdes_invalid(self, _filled_system_setup, args, kwargs): + with pytest.raises(ValueError): + self.system.add_kdes(*args, **kwargs) + self._filled_system_check() + + @pytest.mark.parametrize('args, kwargs, exp_con', [ + ((), {}, [q[2] - q[0] + q[1]]), + ((q[4] - q[5], q[5] + q[3]), {}, + [q[2] - q[0] + q[1], q[4] - q[5], q[5] + q[3]]), + ]) + def test_holonomic_constraints(self, _filled_system_setup, args, kwargs, + exp_con): + exclude = ('holonomic_constraints', 'velocity_constraints') + exp_vel_con = [c.diff(t) for c in exp_con] + self.nhc + # Test add_holonomic_constraints + self.system.add_holonomic_constraints(*args, **kwargs) + self._filled_system_check(exclude=exclude) + assert self.system.holonomic_constraints[:] == exp_con + assert self.system.velocity_constraints[:] == exp_vel_con + # Test setter for holonomic_constraints + self.system.holonomic_constraints = exp_con + self._filled_system_check(exclude=exclude) + assert self.system.holonomic_constraints[:] == exp_con + assert self.system.velocity_constraints[:] == exp_vel_con + + @pytest.mark.parametrize('args, kwargs', [ + ((q[2] - q[0] + q[1], q[4] - q[3]), {}), + ((-(q[2] - q[0] + q[1]), q[4] - q[3]), {}), + ((q[0] - q[0], q[4] - q[3]), {}), + ]) + def test_holonomic_constraints_invalid(self, _filled_system_setup, args, + kwargs): + with pytest.raises(ValueError): + self.system.add_holonomic_constraints(*args, **kwargs) + self._filled_system_check() + + @pytest.mark.parametrize('args, kwargs, exp_con', [ + ((), {}, [u[3] - qd[1] + u[2]]), + ((u[4] - u[5], u[5] + u[3]), {}, + [u[3] - qd[1] + u[2], u[4] - u[5], u[5] + u[3]]), + ]) + def test_nonholonomic_constraints(self, _filled_system_setup, args, kwargs, + exp_con): + exclude = ('nonholonomic_constraints', 'velocity_constraints') + exp_vel_con = self.vc[:len(self.hc)] + exp_con + # Test add_nonholonomic_constraints + self.system.add_nonholonomic_constraints(*args, **kwargs) + self._filled_system_check(exclude=exclude) + assert self.system.nonholonomic_constraints[:] == exp_con + assert self.system.velocity_constraints[:] == exp_vel_con + # Test setter for nonholonomic_constraints + self.system.nonholonomic_constraints = exp_con + self._filled_system_check(exclude=exclude) + assert self.system.nonholonomic_constraints[:] == exp_con + assert self.system.velocity_constraints[:] == exp_vel_con + + @pytest.mark.parametrize('args, kwargs', [ + ((u[3] - qd[1] + u[2], u[4] - u[3]), {}), + ((-(u[3] - qd[1] + u[2]), u[4] - u[3]), {}), + ((u[0] - u[0], u[4] - u[3]), {}), + (([u[0] - u[0], u[4] - u[3]]), {}), + ]) + def test_nonholonomic_constraints_invalid(self, _filled_system_setup, args, + kwargs): + with pytest.raises(ValueError): + self.system.add_nonholonomic_constraints(*args, **kwargs) + self._filled_system_check() + + @pytest.mark.parametrize('constraints, expected', [ + ([], []), + (qd[2] - qd[0] + qd[1], [qd[2] - qd[0] + qd[1]]), + ([qd[2] + qd[1], u[2] - u[1]], [qd[2] + qd[1], u[2] - u[1]]), + ]) + def test_velocity_constraints_overwrite(self, _filled_system_setup, + constraints, expected): + self.system.velocity_constraints = constraints + self._filled_system_check(exclude=('velocity_constraints',)) + assert self.system.velocity_constraints[:] == expected + + def test_velocity_constraints_back_to_auto(self, _filled_system_setup): + self.system.velocity_constraints = qd[3] - qd[2] + self._filled_system_check(exclude=('velocity_constraints',)) + assert self.system.velocity_constraints[:] == [qd[3] - qd[2]] + self.system.velocity_constraints = None + self._filled_system_check() + + def test_bodies(self, _filled_system_setup): + rb1, rb2 = RigidBody('rb1'), RigidBody('rb2') + p1, p2 = Particle('p1'), Particle('p2') + self.system.add_bodies(rb1, p1) + assert self.system.bodies == (*self.bodies, rb1, p1) + self.system.add_bodies(p2) + assert self.system.bodies == (*self.bodies, rb1, p1, p2) + self.system.bodies = [] + assert self.system.bodies == () + self.system.bodies = p2 + assert self.system.bodies == (p2,) + symb = symbols('symb') + pytest.raises(TypeError, lambda: self.system.add_bodies(symb)) + pytest.raises(ValueError, lambda: self.system.add_bodies(p2)) + with pytest.raises(TypeError): + self.system.bodies = (rb1, rb2, p1, p2, symb) + assert self.system.bodies == (p2,) + + def test_add_loads(self): + system = System() + N, A = ReferenceFrame('N'), ReferenceFrame('A') + rb1 = RigidBody('rb1', frame=N) + mc1 = Point('mc1') + p1 = Particle('p1', mc1) + system.add_loads(Torque(rb1, N.x), (mc1, A.x), Force(p1, A.x)) + assert system.loads == ((N, N.x), (mc1, A.x), (mc1, A.x)) + system.loads = [(A, A.x)] + assert system.loads == ((A, A.x),) + pytest.raises(ValueError, lambda: system.add_loads((N, N.x, N.y))) + with pytest.raises(TypeError): + system.loads = (N, N.x) + assert system.loads == ((A, A.x),) + + def test_add_actuators(self): + system = System() + N, A = ReferenceFrame('N'), ReferenceFrame('A') + act1 = TorqueActuator(symbols('T1'), N.x, N) + act2 = TorqueActuator(symbols('T2'), N.y, N, A) + system.add_actuators(act1) + assert system.actuators == (act1,) + assert system.loads == () + system.actuators = (act2,) + assert system.actuators == (act2,) + + def test_add_joints(self): + q1, q2, q3, q4, u1, u2, u3 = dynamicsymbols('q1:5 u1:4') + rb1, rb2, rb3, rb4, rb5 = symbols('rb1:6', cls=RigidBody) + J1 = PinJoint('J1', rb1, rb2, q1, u1) + J2 = PrismaticJoint('J2', rb2, rb3, q2, u2) + J3 = PinJoint('J3', rb3, rb4, q3, u3) + J_lag = PinJoint('J_lag', rb4, rb5, q4, q4.diff(t)) + system = System() + system.add_joints(J1) + assert system.joints == (J1,) + assert system.bodies == (rb1, rb2) + assert system.q_ind == ImmutableMatrix([q1]) + assert system.u_ind == ImmutableMatrix([u1]) + assert system.kdes == ImmutableMatrix([u1 - q1.diff(t)]) + system.add_bodies(rb4) + system.add_coordinates(q3) + system.add_kdes(u3 - q3.diff(t)) + system.add_joints(J3) + assert system.joints == (J1, J3) + assert system.bodies == (rb1, rb2, rb4, rb3) + assert system.q_ind == ImmutableMatrix([q1, q3]) + assert system.u_ind == ImmutableMatrix([u1, u3]) + assert system.kdes == ImmutableMatrix( + [u1 - q1.diff(t), u3 - q3.diff(t)]) + system.add_kdes(-(u2 - q2.diff(t))) + system.add_joints(J2) + assert system.joints == (J1, J3, J2) + assert system.bodies == (rb1, rb2, rb4, rb3) + assert system.q_ind == ImmutableMatrix([q1, q3, q2]) + assert system.u_ind == ImmutableMatrix([u1, u3, u2]) + assert system.kdes == ImmutableMatrix([u1 - q1.diff(t), u3 - q3.diff(t), + -(u2 - q2.diff(t))]) + system.add_joints(J_lag) + assert system.joints == (J1, J3, J2, J_lag) + assert system.bodies == (rb1, rb2, rb4, rb3, rb5) + assert system.q_ind == ImmutableMatrix([q1, q3, q2, q4]) + assert system.u_ind == ImmutableMatrix([u1, u3, u2, q4.diff(t)]) + assert system.kdes == ImmutableMatrix([u1 - q1.diff(t), u3 - q3.diff(t), + -(u2 - q2.diff(t))]) + assert system.q_dep[:] == [] + assert system.u_dep[:] == [] + pytest.raises(ValueError, lambda: system.add_joints(J2)) + pytest.raises(TypeError, lambda: system.add_joints(rb1)) + + def test_joints_setter(self, _filled_system_setup): + self.system.joints = self.joints[1:] + assert self.system.joints == self.joints[1:] + self._filled_system_check(exclude=('joints',)) + self.system.q_ind = () + self.system.u_ind = () + self.system.joints = self.joints + self._filled_system_check() + + @pytest.mark.parametrize('name, joint_index', [ + ('J1', 0), + ('J2', 1), + ('not_existing', None), + ]) + def test_get_joint(self, _filled_system_setup, name, joint_index): + joint = self.system.get_joint(name) + if joint_index is None: + assert joint is None + else: + assert joint == self.joints[joint_index] + + @pytest.mark.parametrize('name, body_index', [ + ('rb1', 0), + ('rb3', 2), + ('not_existing', None), + ]) + def test_get_body(self, _filled_system_setup, name, body_index): + body = self.system.get_body(name) + if body_index is None: + assert body is None + else: + assert body == self.bodies[body_index] + + @pytest.mark.parametrize('eom_method', [KanesMethod, LagrangesMethod]) + def test_form_eoms_calls_subclass(self, _moving_point_mass, eom_method): + class MyMethod(eom_method): + pass + + self.system.form_eoms(eom_method=MyMethod) + assert isinstance(self.system.eom_method, MyMethod) + + @pytest.mark.parametrize('kwargs, expected', [ + ({}, ImmutableMatrix([[-1, 0], [0, symbols('m')]])), + ({'explicit_kinematics': True}, ImmutableMatrix([[1, 0], + [0, symbols('m')]])), + ]) + def test_system_kane_form_eoms_kwargs(self, _moving_point_mass, kwargs, + expected): + self.system.form_eoms(**kwargs) + assert self.system.mass_matrix_full == expected + + @pytest.mark.parametrize('kwargs, mm, gm', [ + ({}, ImmutableMatrix([[1, 0], [0, symbols('m')]]), + ImmutableMatrix([q[0].diff(t), 0])), + ]) + def test_system_lagrange_form_eoms_kwargs(self, _moving_point_mass, kwargs, + mm, gm): + self.system.form_eoms(eom_method=LagrangesMethod, **kwargs) + assert self.system.mass_matrix_full == mm + assert self.system.forcing_full == gm + + @pytest.mark.parametrize('eom_method, kwargs, error', [ + (KanesMethod, {'non_existing_kwarg': 1}, TypeError), + (LagrangesMethod, {'non_existing_kwarg': 1}, TypeError), + (KanesMethod, {'bodies': []}, ValueError), + (KanesMethod, {'kd_eqs': []}, ValueError), + (LagrangesMethod, {'bodies': []}, ValueError), + (LagrangesMethod, {'Lagrangian': 1}, ValueError), + ]) + def test_form_eoms_kwargs_errors(self, _empty_system_setup, eom_method, + kwargs, error): + self.system.q_ind = q[0] + p = Particle('p', mass=symbols('m')) + self.system.add_bodies(p) + p.masscenter.set_pos(self.system.fixed_point, q[0] * self.system.x) + with pytest.raises(error): + self.system.form_eoms(eom_method=eom_method, **kwargs) + + +class TestValidateSystem(TestSystemBase): + @pytest.mark.parametrize('valid_method, invalid_method, with_speeds', [ + (KanesMethod, LagrangesMethod, True), + (LagrangesMethod, KanesMethod, False) + ]) + def test_only_valid(self, valid_method, invalid_method, with_speeds): + self._create_filled_system(with_speeds=with_speeds) + self.system.validate_system(valid_method) + # Test Lagrange should fail due to the usage of generalized speeds + with pytest.raises(ValueError): + self.system.validate_system(invalid_method) + + @pytest.mark.parametrize('method, with_speeds', [ + (KanesMethod, True), (LagrangesMethod, False)]) + def test_missing_joint_coordinate(self, method, with_speeds): + self._create_filled_system(with_speeds=with_speeds) + self.system.q_ind = self.q_ind[1:] + self.system.u_ind = self.u_ind[:-1] + self.system.kdes = self.kdes[:-1] + pytest.raises(ValueError, lambda: self.system.validate_system(method)) + + def test_missing_joint_speed(self, _filled_system_setup): + self.system.q_ind = self.q_ind[:-1] + self.system.u_ind = self.u_ind[1:] + self.system.kdes = self.kdes[:-1] + pytest.raises(ValueError, lambda: self.system.validate_system()) + + def test_missing_joint_kdes(self, _filled_system_setup): + self.system.kdes = self.kdes[1:] + pytest.raises(ValueError, lambda: self.system.validate_system()) + + def test_negative_joint_kdes(self, _filled_system_setup): + self.system.kdes = [-self.kdes[0]] + self.kdes[1:] + self.system.validate_system() + + @pytest.mark.parametrize('method, with_speeds', [ + (KanesMethod, True), (LagrangesMethod, False)]) + def test_missing_holonomic_constraint(self, method, with_speeds): + self._create_filled_system(with_speeds=with_speeds) + self.system.holonomic_constraints = [] + self.system.nonholonomic_constraints = self.nhc + [ + self.u_ind[1] - self.u_dep[0] + self.u_ind[0]] + pytest.raises(ValueError, lambda: self.system.validate_system(method)) + self.system.q_dep = [] + self.system.q_ind = self.q_ind + self.q_dep + self.system.validate_system(method) + + def test_missing_nonholonomic_constraint(self, _filled_system_setup): + self.system.nonholonomic_constraints = [] + pytest.raises(ValueError, lambda: self.system.validate_system()) + self.system.u_dep = self.u_dep[1] + self.system.u_ind = self.u_ind + [self.u_dep[0]] + self.system.validate_system() + + def test_number_of_coordinates_speeds(self, _filled_system_setup): + # Test more speeds than coordinates + self.system.u_ind = self.u_ind + [u[5]] + self.system.kdes = self.kdes + [u[5] - qd[5]] + self.system.validate_system() + # Test more coordinates than speeds + self.system.q_ind = self.q_ind + self.system.u_ind = self.u_ind[:-1] + self.system.kdes = self.kdes[:-1] + pytest.raises(ValueError, lambda: self.system.validate_system()) + + def test_number_of_kdes(self, _filled_system_setup): + # Test wrong number of kdes + self.system.kdes = self.kdes[:-1] + pytest.raises(ValueError, lambda: self.system.validate_system()) + self.system.kdes = self.kdes + [u[2] + u[1] - qd[2]] + pytest.raises(ValueError, lambda: self.system.validate_system()) + + def test_duplicates(self, _filled_system_setup): + # This is basically a redundant feature, which should never fail + self.system.validate_system(check_duplicates=True) + + def test_speeds_in_lagrange(self, _filled_system_setup_no_speeds): + self.system.u_ind = u[:len(self.u_ind)] + with pytest.raises(ValueError): + self.system.validate_system(LagrangesMethod) + self.system.u_ind = [] + self.system.validate_system(LagrangesMethod) + self.system.u_aux = ua + with pytest.raises(ValueError): + self.system.validate_system(LagrangesMethod) + self.system.u_aux = [] + self.system.validate_system(LagrangesMethod) + self.system.add_joints( + PinJoint('Ju', RigidBody('rbu1'), RigidBody('rbu2'))) + self.system.u_ind = [] + with pytest.raises(ValueError): + self.system.validate_system(LagrangesMethod) + + +class TestSystemExamples: + def test_cart_pendulum_kanes(self): + # This example is the same as in the top documentation of System + # Added a spring to the cart + g, l, mc, mp, k = symbols('g l mc mp k') + F, qp, qc, up, uc = dynamicsymbols('F qp qc up uc') + rail = RigidBody('rail') + cart = RigidBody('cart', mass=mc) + bob = Particle('bob', mass=mp) + bob_frame = ReferenceFrame('bob_frame') + system = System.from_newtonian(rail) + assert system.bodies == (rail,) + assert system.frame == rail.frame + assert system.fixed_point == rail.masscenter + slider = PrismaticJoint('slider', rail, cart, qc, uc, joint_axis=rail.x) + pin = PinJoint('pin', cart, bob, qp, up, joint_axis=cart.z, + child_interframe=bob_frame, child_point=l * bob_frame.y) + system.add_joints(slider, pin) + assert system.joints == (slider, pin) + assert system.get_joint('slider') == slider + assert system.get_body('bob') == bob + system.apply_uniform_gravity(-g * system.y) + system.add_loads((cart.masscenter, F * rail.x)) + system.add_actuators(TorqueActuator(k * qp, cart.z, bob_frame, cart)) + system.validate_system() + system.form_eoms() + assert isinstance(system.eom_method, KanesMethod) + assert (simplify(system.mass_matrix - ImmutableMatrix( + [[mp + mc, mp * l * cos(qp)], [mp * l * cos(qp), mp * l ** 2]])) + == zeros(2, 2)) + assert (simplify(system.forcing - ImmutableMatrix([ + [mp * l * up ** 2 * sin(qp) + F], + [-mp * g * l * sin(qp) + k * qp]])) == zeros(2, 1)) + + system.add_holonomic_constraints( + sympify(bob.masscenter.pos_from(rail.masscenter).dot(system.x))) + assert system.eom_method is None + system.q_ind, system.q_dep = qp, qc + system.u_ind, system.u_dep = up, uc + system.validate_system() + + # Computed solution based on manually solving the constraints + subs = {qc: -l * sin(qp), + uc: -l * cos(qp) * up, + uc.diff(t): l * (up ** 2 * sin(qp) - up.diff(t) * cos(qp))} + upd_expected = ( + (-g * mp * sin(qp) + k * qp / l + l * mc * sin(2 * qp) * up ** 2 / 2 + - l * mp * sin(2 * qp) * up ** 2 / 2 - F * cos(qp)) / + (l * (mc * cos(qp) ** 2 + mp * sin(qp) ** 2))) + upd_sol = tuple(solve(system.form_eoms().xreplace(subs), + up.diff(t)).values())[0] + assert simplify(upd_sol - upd_expected) == 0 + assert isinstance(system.eom_method, KanesMethod) + + # Test other output + Mk = -ImmutableMatrix([[0, 1], [1, 0]]) + gk = -ImmutableMatrix([uc, up]) + Md = ImmutableMatrix([[-l ** 2 * mp * cos(qp) ** 2 + l ** 2 * mp, + l * mp * cos(qp) - l * (mc + mp) * cos(qp)], + [l * cos(qp), 1]]) + gd = ImmutableMatrix( + [[-g * l * mp * sin(qp) + k * qp - l ** 2 * mp * up ** 2 * sin(qp) * + cos(qp) - l * F * cos(qp)], [l * up ** 2 * sin(qp)]]) + Mm = (Mk.row_join(zeros(2, 2))).col_join(zeros(2, 2).row_join(Md)) + gm = gk.col_join(gd) + assert simplify(system.mass_matrix - Md) == zeros(2, 2) + assert simplify(system.forcing - gd) == zeros(2, 1) + assert simplify(system.mass_matrix_full - Mm) == zeros(4, 4) + assert simplify(system.forcing_full - gm) == zeros(4, 1) + + def test_cart_pendulum_lagrange(self): + # Lagrange version of test_cart_pendulus_kanes + # Added a spring to the cart + g, l, mc, mp, k = symbols('g l mc mp k') + F, qp, qc = dynamicsymbols('F qp qc') + qpd, qcd = dynamicsymbols('qp qc', 1) + rail = RigidBody('rail') + cart = RigidBody('cart', mass=mc) + bob = Particle('bob', mass=mp) + bob_frame = ReferenceFrame('bob_frame') + system = System.from_newtonian(rail) + assert system.bodies == (rail,) + assert system.frame == rail.frame + assert system.fixed_point == rail.masscenter + slider = PrismaticJoint('slider', rail, cart, qc, qcd, + joint_axis=rail.x) + pin = PinJoint('pin', cart, bob, qp, qpd, joint_axis=cart.z, + child_interframe=bob_frame, child_point=l * bob_frame.y) + system.add_joints(slider, pin) + assert system.joints == (slider, pin) + assert system.get_joint('slider') == slider + assert system.get_body('bob') == bob + for body in system.bodies: + body.potential_energy = body.mass * g * body.masscenter.pos_from( + system.fixed_point).dot(system.y) + system.add_loads((cart.masscenter, F * rail.x)) + system.add_actuators(TorqueActuator(k * qp, cart.z, bob_frame, cart)) + system.validate_system(LagrangesMethod) + system.form_eoms(LagrangesMethod) + assert (simplify(system.mass_matrix - ImmutableMatrix( + [[mp + mc, mp * l * cos(qp)], [mp * l * cos(qp), mp * l ** 2]])) + == zeros(2, 2)) + assert (simplify(system.forcing - ImmutableMatrix([ + [mp * l * qpd ** 2 * sin(qp) + F], [-mp * g * l * sin(qp) + k * qp]] + )) == zeros(2, 1)) + + system.add_holonomic_constraints( + sympify(bob.masscenter.pos_from(rail.masscenter).dot(system.x))) + assert system.eom_method is None + system.q_ind, system.q_dep = qp, qc + + # Computed solution based on manually solving the constraints + subs = {qc: -l * sin(qp), + qcd: -l * cos(qp) * qpd, + qcd.diff(t): l * (qpd ** 2 * sin(qp) - qpd.diff(t) * cos(qp))} + qpdd_expected = ( + (-g * mp * sin(qp) + k * qp / l + l * mc * sin(2 * qp) * qpd ** 2 / + 2 - l * mp * sin(2 * qp) * qpd ** 2 / 2 - F * cos(qp)) / + (l * (mc * cos(qp) ** 2 + mp * sin(qp) ** 2))) + eoms = system.form_eoms(LagrangesMethod) + lam1 = system.eom_method.lam_vec[0] + lam1_sol = system.eom_method.solve_multipliers()[lam1] + qpdd_sol = solve(eoms[0].xreplace({lam1: lam1_sol}).xreplace(subs), + qpd.diff(t))[0] + assert simplify(qpdd_sol - qpdd_expected) == 0 + assert isinstance(system.eom_method, LagrangesMethod) + + # Test other output + Md = ImmutableMatrix([[l ** 2 * mp, l * mp * cos(qp), -l * cos(qp)], + [l * mp * cos(qp), mc + mp, -1]]) + gd = ImmutableMatrix( + [[-g * l * mp * sin(qp) + k * qp], + [l * mp * sin(qp) * qpd ** 2 + F]]) + Mm = (eye(2).row_join(zeros(2, 3))).col_join(zeros(3, 2).row_join( + Md.col_join(ImmutableMatrix([l * cos(qp), 1, 0]).T))) + gm = ImmutableMatrix([qpd, qcd] + gd[:] + [l * sin(qp) * qpd ** 2]) + assert simplify(system.mass_matrix - Md) == zeros(2, 3) + assert simplify(system.forcing - gd) == zeros(2, 1) + assert simplify(system.mass_matrix_full - Mm) == zeros(5, 5) + assert simplify(system.forcing_full - gm) == zeros(5, 1) + + def test_box_on_ground(self): + # Particle sliding on ground with friction. The applied force is assumed + # to be positive and to be higher than the friction force. + g, m, mu = symbols('g m mu') + q, u, ua = dynamicsymbols('q u ua') + N, F = dynamicsymbols('N F', positive=True) + P = Particle("P", mass=m) + system = System() + system.add_bodies(P) + P.masscenter.set_pos(system.fixed_point, q * system.x) + P.masscenter.set_vel(system.frame, u * system.x + ua * system.y) + system.q_ind, system.u_ind, system.u_aux = [q], [u], [ua] + system.kdes = [q.diff(t) - u] + system.apply_uniform_gravity(-g * system.y) + system.add_loads( + Force(P, N * system.y), + Force(P, F * system.x - mu * N * system.x)) + system.validate_system() + system.form_eoms() + + # Test other output + Mk = ImmutableMatrix([1]) + gk = ImmutableMatrix([u]) + Md = ImmutableMatrix([m]) + gd = ImmutableMatrix([F - mu * N]) + Mm = (Mk.row_join(zeros(1, 1))).col_join(zeros(1, 1).row_join(Md)) + gm = gk.col_join(gd) + aux_eqs = ImmutableMatrix([N - m * g]) + assert simplify(system.mass_matrix - Md) == zeros(1, 1) + assert simplify(system.forcing - gd) == zeros(1, 1) + assert simplify(system.mass_matrix_full - Mm) == zeros(2, 2) + assert simplify(system.forcing_full - gm) == zeros(2, 1) + assert simplify(system.eom_method.auxiliary_eqs - aux_eqs + ) == zeros(1, 1) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_wrapping_geometry.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_wrapping_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..30c3ae71db5da75238ebb3d4cc53e11a29a72e5d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/tests/test_wrapping_geometry.py @@ -0,0 +1,363 @@ +"""Tests for the ``sympy.physics.mechanics.wrapping_geometry.py`` module.""" + +import pytest + +from sympy import ( + Integer, + Rational, + S, + Symbol, + acos, + cos, + pi, + sin, + sqrt, +) +from sympy.core.relational import Eq +from sympy.physics.mechanics import ( + Point, + ReferenceFrame, + WrappingCylinder, + WrappingSphere, + dynamicsymbols, +) +from sympy.simplify.simplify import simplify + + +r = Symbol('r', positive=True) +x = Symbol('x') +q = dynamicsymbols('q') +N = ReferenceFrame('N') + + +class TestWrappingSphere: + + @staticmethod + def test_valid_constructor(): + r = Symbol('r', positive=True) + pO = Point('pO') + sphere = WrappingSphere(r, pO) + assert isinstance(sphere, WrappingSphere) + assert hasattr(sphere, 'radius') + assert sphere.radius == r + assert hasattr(sphere, 'point') + assert sphere.point == pO + + @staticmethod + @pytest.mark.parametrize('position', [S.Zero, Integer(2)*r*N.x]) + def test_geodesic_length_point_not_on_surface_invalid(position): + r = Symbol('r', positive=True) + pO = Point('pO') + sphere = WrappingSphere(r, pO) + + p1 = Point('p1') + p1.set_pos(pO, position) + p2 = Point('p2') + p2.set_pos(pO, position) + + error_msg = r'point .* does not lie on the surface of' + with pytest.raises(ValueError, match=error_msg): + sphere.geodesic_length(p1, p2) + + @staticmethod + @pytest.mark.parametrize( + 'position_1, position_2, expected', + [ + (r*N.x, r*N.x, S.Zero), + (r*N.x, r*N.y, S.Half*pi*r), + (r*N.x, r*-N.x, pi*r), + (r*-N.x, r*N.x, pi*r), + (r*N.x, r*sqrt(2)*S.Half*(N.x + N.y), Rational(1, 4)*pi*r), + ( + r*sqrt(2)*S.Half*(N.x + N.y), + r*sqrt(3)*Rational(1, 3)*(N.x + N.y + N.z), + r*acos(sqrt(6)*Rational(1, 3)), + ), + ] + ) + def test_geodesic_length(position_1, position_2, expected): + r = Symbol('r', positive=True) + pO = Point('pO') + sphere = WrappingSphere(r, pO) + + p1 = Point('p1') + p1.set_pos(pO, position_1) + p2 = Point('p2') + p2.set_pos(pO, position_2) + + assert simplify(Eq(sphere.geodesic_length(p1, p2), expected)) + + @staticmethod + @pytest.mark.parametrize( + 'position_1, position_2, vector_1, vector_2', + [ + (r * N.x, r * N.y, N.y, N.x), + (r * N.x, -r * N.y, -N.y, N.x), + ( + r * N.y, + sqrt(2)/2 * r * N.x - sqrt(2)/2 * r * N.y, + N.x, + sqrt(2)/2 * N.x + sqrt(2)/2 * N.y, + ), + ( + r * N.x, + r / 2 * N.x + sqrt(3)/2 * r * N.y, + N.y, + sqrt(3)/2 * N.x - 1/2 * N.y, + ), + ( + r * N.x, + sqrt(2)/2 * r * N.x + sqrt(2)/2 * r * N.y, + N.y, + sqrt(2)/2 * N.x - sqrt(2)/2 * N.y, + ), + ] + ) + def test_geodesic_end_vectors(position_1, position_2, vector_1, vector_2): + r = Symbol('r', positive=True) + pO = Point('pO') + sphere = WrappingSphere(r, pO) + + p1 = Point('p1') + p1.set_pos(pO, position_1) + p2 = Point('p2') + p2.set_pos(pO, position_2) + + expected = (vector_1, vector_2) + + assert sphere.geodesic_end_vectors(p1, p2) == expected + + @staticmethod + @pytest.mark.parametrize( + 'position', + [r * N.x, r * cos(q) * N.x + r * sin(q) * N.y] + ) + def test_geodesic_end_vectors_invalid_coincident(position): + r = Symbol('r', positive=True) + pO = Point('pO') + sphere = WrappingSphere(r, pO) + + p1 = Point('p1') + p1.set_pos(pO, position) + p2 = Point('p2') + p2.set_pos(pO, position) + + with pytest.raises(ValueError): + _ = sphere.geodesic_end_vectors(p1, p2) + + @staticmethod + @pytest.mark.parametrize( + 'position_1, position_2', + [ + (r * N.x, -r * N.x), + (-r * N.y, r * N.y), + ( + r * cos(q) * N.x + r * sin(q) * N.y, + -r * cos(q) * N.x - r * sin(q) * N.y, + ) + ] + ) + def test_geodesic_end_vectors_invalid_diametrically_opposite( + position_1, + position_2, + ): + r = Symbol('r', positive=True) + pO = Point('pO') + sphere = WrappingSphere(r, pO) + + p1 = Point('p1') + p1.set_pos(pO, position_1) + p2 = Point('p2') + p2.set_pos(pO, position_2) + + with pytest.raises(ValueError): + _ = sphere.geodesic_end_vectors(p1, p2) + + +class TestWrappingCylinder: + + @staticmethod + def test_valid_constructor(): + N = ReferenceFrame('N') + r = Symbol('r', positive=True) + pO = Point('pO') + cylinder = WrappingCylinder(r, pO, N.x) + assert isinstance(cylinder, WrappingCylinder) + assert hasattr(cylinder, 'radius') + assert cylinder.radius == r + assert hasattr(cylinder, 'point') + assert cylinder.point == pO + assert hasattr(cylinder, 'axis') + assert cylinder.axis == N.x + + @staticmethod + @pytest.mark.parametrize( + 'position, expected', + [ + (S.Zero, False), + (r*N.y, True), + (r*N.z, True), + (r*(N.y + N.z).normalize(), True), + (Integer(2)*r*N.y, False), + (r*(N.x + N.y), True), + (r*(Integer(2)*N.x + N.y), True), + (Integer(2)*N.x + r*(Integer(2)*N.y + N.z).normalize(), True), + (r*(cos(q)*N.y + sin(q)*N.z), True) + ] + ) + def test_point_is_on_surface(position, expected): + r = Symbol('r', positive=True) + pO = Point('pO') + cylinder = WrappingCylinder(r, pO, N.x) + + p1 = Point('p1') + p1.set_pos(pO, position) + + assert cylinder.point_on_surface(p1) is expected + + @staticmethod + @pytest.mark.parametrize('position', [S.Zero, Integer(2)*r*N.y]) + def test_geodesic_length_point_not_on_surface_invalid(position): + r = Symbol('r', positive=True) + pO = Point('pO') + cylinder = WrappingCylinder(r, pO, N.x) + + p1 = Point('p1') + p1.set_pos(pO, position) + p2 = Point('p2') + p2.set_pos(pO, position) + + error_msg = r'point .* does not lie on the surface of' + with pytest.raises(ValueError, match=error_msg): + cylinder.geodesic_length(p1, p2) + + @staticmethod + @pytest.mark.parametrize( + 'axis, position_1, position_2, expected', + [ + (N.x, r*N.y, r*N.y, S.Zero), + (N.x, r*N.y, N.x + r*N.y, S.One), + (N.x, r*N.y, -x*N.x + r*N.y, sqrt(x**2)), + (-N.x, r*N.y, x*N.x + r*N.y, sqrt(x**2)), + (N.x, r*N.y, r*N.z, S.Half*pi*sqrt(r**2)), + (-N.x, r*N.y, r*N.z, Integer(3)*S.Half*pi*sqrt(r**2)), + (N.x, r*N.z, r*N.y, Integer(3)*S.Half*pi*sqrt(r**2)), + (-N.x, r*N.z, r*N.y, S.Half*pi*sqrt(r**2)), + (N.x, r*N.y, r*(cos(q)*N.y + sin(q)*N.z), sqrt(r**2*q**2)), + ( + -N.x, r*N.y, + r*(cos(q)*N.y + sin(q)*N.z), + sqrt(r**2*(Integer(2)*pi - q)**2), + ), + ] + ) + def test_geodesic_length(axis, position_1, position_2, expected): + r = Symbol('r', positive=True) + pO = Point('pO') + cylinder = WrappingCylinder(r, pO, axis) + + p1 = Point('p1') + p1.set_pos(pO, position_1) + p2 = Point('p2') + p2.set_pos(pO, position_2) + + assert simplify(Eq(cylinder.geodesic_length(p1, p2), expected)) + + @staticmethod + @pytest.mark.parametrize( + 'axis, position_1, position_2, vector_1, vector_2', + [ + (N.z, r * N.x, r * N.y, N.y, N.x), + (N.z, r * N.x, -r * N.x, N.y, N.y), + (N.z, -r * N.x, r * N.x, -N.y, -N.y), + (-N.z, r * N.x, -r * N.x, -N.y, -N.y), + (-N.z, -r * N.x, r * N.x, N.y, N.y), + (N.z, r * N.x, -r * N.y, N.y, -N.x), + ( + N.z, + r * N.y, + sqrt(2)/2 * r * N.x - sqrt(2)/2 * r * N.y, + - N.x, + - sqrt(2)/2 * N.x - sqrt(2)/2 * N.y, + ), + ( + N.z, + r * N.x, + r / 2 * N.x + sqrt(3)/2 * r * N.y, + N.y, + sqrt(3)/2 * N.x - 1/2 * N.y, + ), + ( + N.z, + r * N.x, + sqrt(2)/2 * r * N.x + sqrt(2)/2 * r * N.y, + N.y, + sqrt(2)/2 * N.x - sqrt(2)/2 * N.y, + ), + ( + N.z, + r * N.x, + r * N.x + N.z, + N.z, + -N.z, + ), + ( + N.z, + r * N.x, + r * N.y + pi/2 * r * N.z, + sqrt(2)/2 * N.y + sqrt(2)/2 * N.z, + sqrt(2)/2 * N.x - sqrt(2)/2 * N.z, + ), + ( + N.z, + r * N.x, + r * cos(q) * N.x + r * sin(q) * N.y, + N.y, + sin(q) * N.x - cos(q) * N.y, + ), + ] + ) + def test_geodesic_end_vectors( + axis, + position_1, + position_2, + vector_1, + vector_2, + ): + r = Symbol('r', positive=True) + pO = Point('pO') + cylinder = WrappingCylinder(r, pO, axis) + + p1 = Point('p1') + p1.set_pos(pO, position_1) + p2 = Point('p2') + p2.set_pos(pO, position_2) + + expected = (vector_1, vector_2) + end_vectors = tuple( + end_vector.simplify() + for end_vector in cylinder.geodesic_end_vectors(p1, p2) + ) + + assert end_vectors == expected + + @staticmethod + @pytest.mark.parametrize( + 'axis, position', + [ + (N.z, r * N.x), + (N.z, r * cos(q) * N.x + r * sin(q) * N.y + N.z), + ] + ) + def test_geodesic_end_vectors_invalid_coincident(axis, position): + r = Symbol('r', positive=True) + pO = Point('pO') + cylinder = WrappingCylinder(r, pO, axis) + + p1 = Point('p1') + p1.set_pos(pO, position) + p2 = Point('p2') + p2.set_pos(pO, position) + + with pytest.raises(ValueError): + _ = cylinder.geodesic_end_vectors(p1, p2) diff --git a/phivenv/Lib/site-packages/sympy/physics/mechanics/wrapping_geometry.py b/phivenv/Lib/site-packages/sympy/physics/mechanics/wrapping_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..47ed3c1c463499b024afb9e31cfa2ecd77534132 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/mechanics/wrapping_geometry.py @@ -0,0 +1,641 @@ +"""Geometry objects for use by wrapping pathways.""" + +from abc import ABC, abstractmethod + +from sympy import Integer, acos, pi, sqrt, sympify, tan +from sympy.core.relational import Eq +from sympy.functions.elementary.trigonometric import atan2 +from sympy.polys.polytools import cancel +from sympy.physics.vector import Vector, dot +from sympy.simplify.simplify import trigsimp + + +__all__ = [ + 'WrappingGeometryBase', + 'WrappingCylinder', + 'WrappingSphere', +] + + +class WrappingGeometryBase(ABC): + """Abstract base class for all geometry classes to inherit from. + + Notes + ===== + + Instances of this class cannot be directly instantiated by users. However, + it can be used to created custom geometry types through subclassing. + + """ + + @property + @abstractmethod + def point(cls): + """The point with which the geometry is associated.""" + pass + + @abstractmethod + def point_on_surface(self, point): + """Returns ``True`` if a point is on the geometry's surface. + + Parameters + ========== + point : Point + The point for which it's to be ascertained if it's on the + geometry's surface or not. + + """ + pass + + @abstractmethod + def geodesic_length(self, point_1, point_2): + """Returns the shortest distance between two points on a geometry's + surface. + + Parameters + ========== + + point_1 : Point + The point from which the geodesic length should be calculated. + point_2 : Point + The point to which the geodesic length should be calculated. + + """ + pass + + @abstractmethod + def geodesic_end_vectors(self, point_1, point_2): + """The vectors parallel to the geodesic at the two end points. + + Parameters + ========== + + point_1 : Point + The point from which the geodesic originates. + point_2 : Point + The point at which the geodesic terminates. + + """ + pass + + def __repr__(self): + """Default representation of a geometry object.""" + return f'{self.__class__.__name__}()' + + +class WrappingSphere(WrappingGeometryBase): + """A solid spherical object. + + Explanation + =========== + + A wrapping geometry that allows for circular arcs to be defined between + pairs of points. These paths are always geodetic (the shortest possible). + + Examples + ======== + + To create a ``WrappingSphere`` instance, a ``Symbol`` denoting its radius + and ``Point`` at which its center will be located are needed: + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Point, WrappingSphere + >>> r = symbols('r') + >>> pO = Point('pO') + + A sphere with radius ``r`` centered on ``pO`` can be instantiated with: + + >>> WrappingSphere(r, pO) + WrappingSphere(radius=r, point=pO) + + Parameters + ========== + + radius : Symbol + Radius of the sphere. This symbol must represent a value that is + positive and constant, i.e. it cannot be a dynamic symbol, nor can it + be an expression. + point : Point + A point at which the sphere is centered. + + See Also + ======== + + WrappingCylinder: Cylindrical geometry where the wrapping direction can be + defined. + + """ + + def __init__(self, radius, point): + """Initializer for ``WrappingSphere``. + + Parameters + ========== + + radius : Symbol + The radius of the sphere. + point : Point + A point on which the sphere is centered. + + """ + self.radius = radius + self.point = point + + @property + def radius(self): + """Radius of the sphere.""" + return self._radius + + @radius.setter + def radius(self, radius): + self._radius = radius + + @property + def point(self): + """A point on which the sphere is centered.""" + return self._point + + @point.setter + def point(self, point): + self._point = point + + def point_on_surface(self, point): + """Returns ``True`` if a point is on the sphere's surface. + + Parameters + ========== + + point : Point + The point for which it's to be ascertained if it's on the sphere's + surface or not. This point's position relative to the sphere's + center must be a simple expression involving the radius of the + sphere, otherwise this check will likely not work. + + """ + point_vector = point.pos_from(self.point) + if isinstance(point_vector, Vector): + point_radius_squared = dot(point_vector, point_vector) + else: + point_radius_squared = point_vector**2 + return Eq(point_radius_squared, self.radius**2) == True + + def geodesic_length(self, point_1, point_2): + r"""Returns the shortest distance between two points on the sphere's + surface. + + Explanation + =========== + + The geodesic length, i.e. the shortest arc along the surface of a + sphere, connecting two points can be calculated using the formula: + + .. math:: + + l = \arccos\left(\mathbf{v}_1 \cdot \mathbf{v}_2\right) + + where $\mathbf{v}_1$ and $\mathbf{v}_2$ are the unit vectors from the + sphere's center to the first and second points on the sphere's surface + respectively. Note that the actual path that the geodesic will take is + undefined when the two points are directly opposite one another. + + Examples + ======== + + A geodesic length can only be calculated between two points on the + sphere's surface. Firstly, a ``WrappingSphere`` instance must be + created along with two points that will lie on its surface: + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (Point, ReferenceFrame, + ... WrappingSphere) + >>> N = ReferenceFrame('N') + >>> r = symbols('r') + >>> pO = Point('pO') + >>> pO.set_vel(N, 0) + >>> sphere = WrappingSphere(r, pO) + >>> p1 = Point('p1') + >>> p2 = Point('p2') + + Let's assume that ``p1`` lies at a distance of ``r`` in the ``N.x`` + direction from ``pO`` and that ``p2`` is located on the sphere's + surface in the ``N.y + N.z`` direction from ``pO``. These positions can + be set with: + + >>> p1.set_pos(pO, r*N.x) + >>> p1.pos_from(pO) + r*N.x + >>> p2.set_pos(pO, r*(N.y + N.z).normalize()) + >>> p2.pos_from(pO) + sqrt(2)*r/2*N.y + sqrt(2)*r/2*N.z + + The geodesic length, which is in this case is a quarter of the sphere's + circumference, can be calculated using the ``geodesic_length`` method: + + >>> sphere.geodesic_length(p1, p2) + pi*r/2 + + If the ``geodesic_length`` method is passed an argument, the ``Point`` + that doesn't lie on the sphere's surface then a ``ValueError`` is + raised because it's not possible to calculate a value in this case. + + Parameters + ========== + + point_1 : Point + Point from which the geodesic length should be calculated. + point_2 : Point + Point to which the geodesic length should be calculated. + + """ + for point in (point_1, point_2): + if not self.point_on_surface(point): + msg = ( + f'Geodesic length cannot be calculated as point {point} ' + f'with radius {point.pos_from(self.point).magnitude()} ' + f'from the sphere\'s center {self.point} does not lie on ' + f'the surface of {self} with radius {self.radius}.' + ) + raise ValueError(msg) + point_1_vector = point_1.pos_from(self.point).normalize() + point_2_vector = point_2.pos_from(self.point).normalize() + central_angle = acos(point_2_vector.dot(point_1_vector)) + geodesic_length = self.radius*central_angle + return geodesic_length + + def geodesic_end_vectors(self, point_1, point_2): + """The vectors parallel to the geodesic at the two end points. + + Parameters + ========== + + point_1 : Point + The point from which the geodesic originates. + point_2 : Point + The point at which the geodesic terminates. + + """ + pA, pB = point_1, point_2 + pO = self.point + pA_vec = pA.pos_from(pO) + pB_vec = pB.pos_from(pO) + + if pA_vec.cross(pB_vec) == 0: + msg = ( + f'Can\'t compute geodesic end vectors for the pair of points ' + f'{pA} and {pB} on a sphere {self} as they are diametrically ' + f'opposed, thus the geodesic is not defined.' + ) + raise ValueError(msg) + + return ( + pA_vec.cross(pB.pos_from(pA)).cross(pA_vec).normalize(), + pB_vec.cross(pA.pos_from(pB)).cross(pB_vec).normalize(), + ) + + def __repr__(self): + """Representation of a ``WrappingSphere``.""" + return ( + f'{self.__class__.__name__}(radius={self.radius}, ' + f'point={self.point})' + ) + + +class WrappingCylinder(WrappingGeometryBase): + """A solid (infinite) cylindrical object. + + Explanation + =========== + + A wrapping geometry that allows for circular arcs to be defined between + pairs of points. These paths are always geodetic (the shortest possible) in + the sense that they will be a straight line on the unwrapped cylinder's + surface. However, it is also possible for a direction to be specified, i.e. + paths can be influenced such that they either wrap along the shortest side + or the longest side of the cylinder. To define these directions, rotations + are in the positive direction following the right-hand rule. + + Examples + ======== + + To create a ``WrappingCylinder`` instance, a ``Symbol`` denoting its + radius, a ``Vector`` defining its axis, and a ``Point`` through which its + axis passes are needed: + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (Point, ReferenceFrame, + ... WrappingCylinder) + >>> N = ReferenceFrame('N') + >>> r = symbols('r') + >>> pO = Point('pO') + >>> ax = N.x + + A cylinder with radius ``r``, and axis parallel to ``N.x`` passing through + ``pO`` can be instantiated with: + + >>> WrappingCylinder(r, pO, ax) + WrappingCylinder(radius=r, point=pO, axis=N.x) + + Parameters + ========== + + radius : Symbol + The radius of the cylinder. + point : Point + A point through which the cylinder's axis passes. + axis : Vector + The axis along which the cylinder is aligned. + + See Also + ======== + + WrappingSphere: Spherical geometry where the wrapping direction is always + geodetic. + + """ + + def __init__(self, radius, point, axis): + """Initializer for ``WrappingCylinder``. + + Parameters + ========== + + radius : Symbol + The radius of the cylinder. This symbol must represent a value that + is positive and constant, i.e. it cannot be a dynamic symbol. + point : Point + A point through which the cylinder's axis passes. + axis : Vector + The axis along which the cylinder is aligned. + + """ + self.radius = radius + self.point = point + self.axis = axis + + @property + def radius(self): + """Radius of the cylinder.""" + return self._radius + + @radius.setter + def radius(self, radius): + self._radius = radius + + @property + def point(self): + """A point through which the cylinder's axis passes.""" + return self._point + + @point.setter + def point(self, point): + self._point = point + + @property + def axis(self): + """Axis along which the cylinder is aligned.""" + return self._axis + + @axis.setter + def axis(self, axis): + self._axis = axis.normalize() + + def point_on_surface(self, point): + """Returns ``True`` if a point is on the cylinder's surface. + + Parameters + ========== + + point : Point + The point for which it's to be ascertained if it's on the + cylinder's surface or not. This point's position relative to the + cylinder's axis must be a simple expression involving the radius of + the sphere, otherwise this check will likely not work. + + """ + relative_position = point.pos_from(self.point) + parallel = relative_position.dot(self.axis) * self.axis + point_vector = relative_position - parallel + if isinstance(point_vector, Vector): + point_radius_squared = dot(point_vector, point_vector) + else: + point_radius_squared = point_vector**2 + return Eq(trigsimp(point_radius_squared), self.radius**2) == True + + def geodesic_length(self, point_1, point_2): + """The shortest distance between two points on a geometry's surface. + + Explanation + =========== + + The geodesic length, i.e. the shortest arc along the surface of a + cylinder, connecting two points. It can be calculated using Pythagoras' + theorem. The first short side is the distance between the two points on + the cylinder's surface parallel to the cylinder's axis. The second + short side is the arc of a circle between the two points of the + cylinder's surface perpendicular to the cylinder's axis. The resulting + hypotenuse is the geodesic length. + + Examples + ======== + + A geodesic length can only be calculated between two points on the + cylinder's surface. Firstly, a ``WrappingCylinder`` instance must be + created along with two points that will lie on its surface: + + >>> from sympy import symbols, cos, sin + >>> from sympy.physics.mechanics import (Point, ReferenceFrame, + ... WrappingCylinder, dynamicsymbols) + >>> N = ReferenceFrame('N') + >>> r = symbols('r') + >>> pO = Point('pO') + >>> pO.set_vel(N, 0) + >>> cylinder = WrappingCylinder(r, pO, N.x) + >>> p1 = Point('p1') + >>> p2 = Point('p2') + + Let's assume that ``p1`` is located at ``N.x + r*N.y`` relative to + ``pO`` and that ``p2`` is located at ``r*(cos(q)*N.y + sin(q)*N.z)`` + relative to ``pO``, where ``q(t)`` is a generalized coordinate + specifying the angle rotated around the ``N.x`` axis according to the + right-hand rule where ``N.y`` is zero. These positions can be set with: + + >>> q = dynamicsymbols('q') + >>> p1.set_pos(pO, N.x + r*N.y) + >>> p1.pos_from(pO) + N.x + r*N.y + >>> p2.set_pos(pO, r*(cos(q)*N.y + sin(q)*N.z).normalize()) + >>> p2.pos_from(pO).simplify() + r*cos(q(t))*N.y + r*sin(q(t))*N.z + + The geodesic length, which is in this case a is the hypotenuse of a + right triangle where the other two side lengths are ``1`` (parallel to + the cylinder's axis) and ``r*q(t)`` (parallel to the cylinder's cross + section), can be calculated using the ``geodesic_length`` method: + + >>> cylinder.geodesic_length(p1, p2).simplify() + sqrt(r**2*q(t)**2 + 1) + + If the ``geodesic_length`` method is passed an argument ``Point`` that + doesn't lie on the sphere's surface then a ``ValueError`` is raised + because it's not possible to calculate a value in this case. + + Parameters + ========== + + point_1 : Point + Point from which the geodesic length should be calculated. + point_2 : Point + Point to which the geodesic length should be calculated. + + """ + for point in (point_1, point_2): + if not self.point_on_surface(point): + msg = ( + f'Geodesic length cannot be calculated as point {point} ' + f'with radius {point.pos_from(self.point).magnitude()} ' + f'from the cylinder\'s center {self.point} does not lie on ' + f'the surface of {self} with radius {self.radius} and axis ' + f'{self.axis}.' + ) + raise ValueError(msg) + + relative_position = point_2.pos_from(point_1) + parallel_length = relative_position.dot(self.axis) + + point_1_relative_position = point_1.pos_from(self.point) + point_1_perpendicular_vector = ( + point_1_relative_position + - point_1_relative_position.dot(self.axis)*self.axis + ).normalize() + + point_2_relative_position = point_2.pos_from(self.point) + point_2_perpendicular_vector = ( + point_2_relative_position + - point_2_relative_position.dot(self.axis)*self.axis + ).normalize() + + central_angle = _directional_atan( + cancel(point_1_perpendicular_vector + .cross(point_2_perpendicular_vector) + .dot(self.axis)), + cancel(point_1_perpendicular_vector.dot(point_2_perpendicular_vector)), + ) + + planar_arc_length = self.radius*central_angle + geodesic_length = sqrt(parallel_length**2 + planar_arc_length**2) + return geodesic_length + + def geodesic_end_vectors(self, point_1, point_2): + """The vectors parallel to the geodesic at the two end points. + + Parameters + ========== + + point_1 : Point + The point from which the geodesic originates. + point_2 : Point + The point at which the geodesic terminates. + + """ + point_1_from_origin_point = point_1.pos_from(self.point) + point_2_from_origin_point = point_2.pos_from(self.point) + + if point_1_from_origin_point == point_2_from_origin_point: + msg = ( + f'Cannot compute geodesic end vectors for coincident points ' + f'{point_1} and {point_2} as no geodesic exists.' + ) + raise ValueError(msg) + + point_1_parallel = point_1_from_origin_point.dot(self.axis) * self.axis + point_2_parallel = point_2_from_origin_point.dot(self.axis) * self.axis + point_1_normal = (point_1_from_origin_point - point_1_parallel) + point_2_normal = (point_2_from_origin_point - point_2_parallel) + + if point_1_normal == point_2_normal: + point_1_perpendicular = Vector(0) + point_2_perpendicular = Vector(0) + else: + point_1_perpendicular = self.axis.cross(point_1_normal).normalize() + point_2_perpendicular = -self.axis.cross(point_2_normal).normalize() + + geodesic_length = self.geodesic_length(point_1, point_2) + relative_position = point_2.pos_from(point_1) + parallel_length = relative_position.dot(self.axis) + planar_arc_length = sqrt(geodesic_length**2 - parallel_length**2) + + point_1_vector = ( + planar_arc_length * point_1_perpendicular + + parallel_length * self.axis + ).normalize() + point_2_vector = ( + planar_arc_length * point_2_perpendicular + - parallel_length * self.axis + ).normalize() + + return (point_1_vector, point_2_vector) + + def __repr__(self): + """Representation of a ``WrappingCylinder``.""" + return ( + f'{self.__class__.__name__}(radius={self.radius}, ' + f'point={self.point}, axis={self.axis})' + ) + + +def _directional_atan(numerator, denominator): + """Compute atan in a directional sense as required for geodesics. + + Explanation + =========== + + To be able to control the direction of the geodesic length along the + surface of a cylinder a dedicated arctangent function is needed that + properly handles the directionality of different case. This function + ensures that the central angle is always positive but shifting the case + where ``atan2`` would return a negative angle to be centered around + ``2*pi``. + + Notes + ===== + + This function only handles very specific cases, i.e. the ones that are + expected to be encountered when calculating symbolic geodesics on uniformly + curved surfaces. As such, ``NotImplemented`` errors can be raised in many + cases. This function is named with a leader underscore to indicate that it + only aims to provide very specific functionality within the private scope + of this module. + + """ + + if numerator.is_number and denominator.is_number: + angle = atan2(numerator, denominator) + if angle < 0: + angle += 2 * pi + elif numerator.is_number: + msg = ( + f'Cannot compute a directional atan when the numerator {numerator} ' + f'is numeric and the denominator {denominator} is symbolic.' + ) + raise NotImplementedError(msg) + elif denominator.is_number: + msg = ( + f'Cannot compute a directional atan when the numerator {numerator} ' + f'is symbolic and the denominator {denominator} is numeric.' + ) + raise NotImplementedError(msg) + else: + ratio = sympify(trigsimp(numerator / denominator)) + if isinstance(ratio, tan): + angle = ratio.args[0] + elif ( + ratio.is_Mul + and ratio.args[0] == Integer(-1) + and isinstance(ratio.args[1], tan) + ): + angle = 2 * pi - ratio.args[1].args[0] + else: + msg = f'Cannot compute a directional atan for the value {ratio}.' + raise NotImplementedError(msg) + + return angle diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/__init__.py b/phivenv/Lib/site-packages/sympy/physics/optics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2d83d452fd30e718546c0eac26fe03bbef59c06 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/optics/__init__.py @@ -0,0 +1,38 @@ +__all__ = [ + 'TWave', + + 'RayTransferMatrix', 'FreeSpace', 'FlatRefraction', 'CurvedRefraction', + 'FlatMirror', 'CurvedMirror', 'ThinLens', 'GeometricRay', 'BeamParameter', + 'waist2rayleigh', 'rayleigh2waist', 'geometric_conj_ab', + 'geometric_conj_af', 'geometric_conj_bf', 'gaussian_conj', + 'conjugate_gauss_beams', + + 'Medium', + + 'refraction_angle', 'deviation', 'fresnel_coefficients', 'brewster_angle', + 'critical_angle', 'lens_makers_formula', 'mirror_formula', 'lens_formula', + 'hyperfocal_distance', 'transverse_magnification', + + 'jones_vector', 'stokes_vector', 'jones_2_stokes', 'linear_polarizer', + 'phase_retarder', 'half_wave_retarder', 'quarter_wave_retarder', + 'transmissive_filter', 'reflective_filter', 'mueller_matrix', + 'polarizing_beam_splitter', +] +from .waves import TWave + +from .gaussopt import (RayTransferMatrix, FreeSpace, FlatRefraction, + CurvedRefraction, FlatMirror, CurvedMirror, ThinLens, GeometricRay, + BeamParameter, waist2rayleigh, rayleigh2waist, geometric_conj_ab, + geometric_conj_af, geometric_conj_bf, gaussian_conj, + conjugate_gauss_beams) + +from .medium import Medium + +from .utils import (refraction_angle, deviation, fresnel_coefficients, + brewster_angle, critical_angle, lens_makers_formula, mirror_formula, + lens_formula, hyperfocal_distance, transverse_magnification) + +from .polarization import (jones_vector, stokes_vector, jones_2_stokes, + linear_polarizer, phase_retarder, half_wave_retarder, + quarter_wave_retarder, transmissive_filter, reflective_filter, + mueller_matrix, polarizing_beam_splitter) diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7370e8ff40983788cd4e755fcc569611fafb92e4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/gaussopt.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/gaussopt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93464085d08fde5529c02c8b20e04273acbd27db Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/gaussopt.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/medium.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/medium.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb4333a68653e5b5c997c63c3b55ae4300966359 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/medium.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/polarization.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/polarization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8330fae98fe0f7f6e9242383ae93a33f5d354f01 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/polarization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c33aade448185264b699fddabd2713c5118f06a5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/waves.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/waves.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..031bbbb9985e0ce51934076323d9329acd51fd3f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/optics/__pycache__/waves.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/gaussopt.py b/phivenv/Lib/site-packages/sympy/physics/optics/gaussopt.py new file mode 100644 index 0000000000000000000000000000000000000000..d9e8ef555d60e3204341cdc65cdd05fb02b2f196 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/optics/gaussopt.py @@ -0,0 +1,923 @@ +""" +Gaussian optics. + +The module implements: + +- Ray transfer matrices for geometrical and gaussian optics. + + See RayTransferMatrix, GeometricRay and BeamParameter + +- Conjugation relations for geometrical and gaussian optics. + + See geometric_conj*, gauss_conj and conjugate_gauss_beams + +The conventions for the distances are as follows: + +focal distance + positive for convergent lenses +object distance + positive for real objects +image distance + positive for real images +""" + +__all__ = [ + 'RayTransferMatrix', + 'FreeSpace', + 'FlatRefraction', + 'CurvedRefraction', + 'FlatMirror', + 'CurvedMirror', + 'ThinLens', + 'GeometricRay', + 'BeamParameter', + 'waist2rayleigh', + 'rayleigh2waist', + 'geometric_conj_ab', + 'geometric_conj_af', + 'geometric_conj_bf', + 'gaussian_conj', + 'conjugate_gauss_beams', +] + + +from sympy.core.expr import Expr +from sympy.core.numbers import (I, pi) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import atan2 +from sympy.matrices.dense import Matrix, MutableDenseMatrix +from sympy.polys.rationaltools import together +from sympy.utilities.misc import filldedent + +### +# A, B, C, D matrices +### + + +class RayTransferMatrix(MutableDenseMatrix): + """ + Base class for a Ray Transfer Matrix. + + It should be used if there is not already a more specific subclass mentioned + in See Also. + + Parameters + ========== + + parameters : + A, B, C and D or 2x2 matrix (Matrix(2, 2, [A, B, C, D])) + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix, ThinLens + >>> from sympy import Symbol, Matrix + + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat + Matrix([ + [1, 2], + [3, 4]]) + + >>> RayTransferMatrix(Matrix([[1, 2], [3, 4]])) + Matrix([ + [1, 2], + [3, 4]]) + + >>> mat.A + 1 + + >>> f = Symbol('f') + >>> lens = ThinLens(f) + >>> lens + Matrix([ + [ 1, 0], + [-1/f, 1]]) + + >>> lens.C + -1/f + + See Also + ======== + + GeometricRay, BeamParameter, + FreeSpace, FlatRefraction, CurvedRefraction, + FlatMirror, CurvedMirror, ThinLens + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Ray_transfer_matrix_analysis + """ + + def __new__(cls, *args): + + if len(args) == 4: + temp = ((args[0], args[1]), (args[2], args[3])) + elif len(args) == 1 \ + and isinstance(args[0], Matrix) \ + and args[0].shape == (2, 2): + temp = args[0] + else: + raise ValueError(filldedent(''' + Expecting 2x2 Matrix or the 4 elements of + the Matrix but got %s''' % str(args))) + return Matrix.__new__(cls, temp) + + def __mul__(self, other): + if isinstance(other, RayTransferMatrix): + return RayTransferMatrix(Matrix(self)*Matrix(other)) + elif isinstance(other, GeometricRay): + return GeometricRay(Matrix(self)*Matrix(other)) + elif isinstance(other, BeamParameter): + temp = Matrix(self)*Matrix(((other.q,), (1,))) + q = (temp[0]/temp[1]).expand(complex=True) + return BeamParameter(other.wavelen, + together(re(q)), + z_r=together(im(q))) + else: + return Matrix.__mul__(self, other) + + @property + def A(self): + """ + The A parameter of the Matrix. + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat.A + 1 + """ + return self[0, 0] + + @property + def B(self): + """ + The B parameter of the Matrix. + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat.B + 2 + """ + return self[0, 1] + + @property + def C(self): + """ + The C parameter of the Matrix. + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat.C + 3 + """ + return self[1, 0] + + @property + def D(self): + """ + The D parameter of the Matrix. + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat.D + 4 + """ + return self[1, 1] + + +class FreeSpace(RayTransferMatrix): + """ + Ray Transfer Matrix for free space. + + Parameters + ========== + + distance + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import FreeSpace + >>> from sympy import symbols + >>> d = symbols('d') + >>> FreeSpace(d) + Matrix([ + [1, d], + [0, 1]]) + """ + def __new__(cls, d): + return RayTransferMatrix.__new__(cls, 1, d, 0, 1) + + +class FlatRefraction(RayTransferMatrix): + """ + Ray Transfer Matrix for refraction. + + Parameters + ========== + + n1 : + Refractive index of one medium. + n2 : + Refractive index of other medium. + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import FlatRefraction + >>> from sympy import symbols + >>> n1, n2 = symbols('n1 n2') + >>> FlatRefraction(n1, n2) + Matrix([ + [1, 0], + [0, n1/n2]]) + """ + def __new__(cls, n1, n2): + n1, n2 = map(sympify, (n1, n2)) + return RayTransferMatrix.__new__(cls, 1, 0, 0, n1/n2) + + +class CurvedRefraction(RayTransferMatrix): + """ + Ray Transfer Matrix for refraction on curved interface. + + Parameters + ========== + + R : + Radius of curvature (positive for concave). + n1 : + Refractive index of one medium. + n2 : + Refractive index of other medium. + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import CurvedRefraction + >>> from sympy import symbols + >>> R, n1, n2 = symbols('R n1 n2') + >>> CurvedRefraction(R, n1, n2) + Matrix([ + [ 1, 0], + [(n1 - n2)/(R*n2), n1/n2]]) + """ + def __new__(cls, R, n1, n2): + R, n1, n2 = map(sympify, (R, n1, n2)) + return RayTransferMatrix.__new__(cls, 1, 0, (n1 - n2)/R/n2, n1/n2) + + +class FlatMirror(RayTransferMatrix): + """ + Ray Transfer Matrix for reflection. + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import FlatMirror + >>> FlatMirror() + Matrix([ + [1, 0], + [0, 1]]) + """ + def __new__(cls): + return RayTransferMatrix.__new__(cls, 1, 0, 0, 1) + + +class CurvedMirror(RayTransferMatrix): + """ + Ray Transfer Matrix for reflection from curved surface. + + Parameters + ========== + + R : radius of curvature (positive for concave) + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import CurvedMirror + >>> from sympy import symbols + >>> R = symbols('R') + >>> CurvedMirror(R) + Matrix([ + [ 1, 0], + [-2/R, 1]]) + """ + def __new__(cls, R): + R = sympify(R) + return RayTransferMatrix.__new__(cls, 1, 0, -2/R, 1) + + +class ThinLens(RayTransferMatrix): + """ + Ray Transfer Matrix for a thin lens. + + Parameters + ========== + + f : + The focal distance. + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import ThinLens + >>> from sympy import symbols + >>> f = symbols('f') + >>> ThinLens(f) + Matrix([ + [ 1, 0], + [-1/f, 1]]) + """ + def __new__(cls, f): + f = sympify(f) + return RayTransferMatrix.__new__(cls, 1, 0, -1/f, 1) + + +### +# Representation for geometric ray +### + +class GeometricRay(MutableDenseMatrix): + """ + Representation for a geometric ray in the Ray Transfer Matrix formalism. + + Parameters + ========== + + h : height, and + angle : angle, or + matrix : a 2x1 matrix (Matrix(2, 1, [height, angle])) + + Examples + ======== + + >>> from sympy.physics.optics import GeometricRay, FreeSpace + >>> from sympy import symbols, Matrix + >>> d, h, angle = symbols('d, h, angle') + + >>> GeometricRay(h, angle) + Matrix([ + [ h], + [angle]]) + + >>> FreeSpace(d)*GeometricRay(h, angle) + Matrix([ + [angle*d + h], + [ angle]]) + + >>> GeometricRay( Matrix( ((h,), (angle,)) ) ) + Matrix([ + [ h], + [angle]]) + + See Also + ======== + + RayTransferMatrix + + """ + + def __new__(cls, *args): + if len(args) == 1 and isinstance(args[0], Matrix) \ + and args[0].shape == (2, 1): + temp = args[0] + elif len(args) == 2: + temp = ((args[0],), (args[1],)) + else: + raise ValueError(filldedent(''' + Expecting 2x1 Matrix or the 2 elements of + the Matrix but got %s''' % str(args))) + return Matrix.__new__(cls, temp) + + @property + def height(self): + """ + The distance from the optical axis. + + Examples + ======== + + >>> from sympy.physics.optics import GeometricRay + >>> from sympy import symbols + >>> h, angle = symbols('h, angle') + >>> gRay = GeometricRay(h, angle) + >>> gRay.height + h + """ + return self[0] + + @property + def angle(self): + """ + The angle with the optical axis. + + Examples + ======== + + >>> from sympy.physics.optics import GeometricRay + >>> from sympy import symbols + >>> h, angle = symbols('h, angle') + >>> gRay = GeometricRay(h, angle) + >>> gRay.angle + angle + """ + return self[1] + + +### +# Representation for gauss beam +### + +class BeamParameter(Expr): + """ + Representation for a gaussian ray in the Ray Transfer Matrix formalism. + + Parameters + ========== + + wavelen : the wavelength, + z : the distance to waist, and + w : the waist, or + z_r : the rayleigh range. + n : the refractive index of medium. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.q + 1 + 1.88679245283019*I*pi + + >>> p.q.n() + 1.0 + 5.92753330865999*I + >>> p.w_0.n() + 0.00100000000000000 + >>> p.z_r.n() + 5.92753330865999 + + >>> from sympy.physics.optics import FreeSpace + >>> fs = FreeSpace(10) + >>> p1 = fs*p + >>> p.w.n() + 0.00101413072159615 + >>> p1.w.n() + 0.00210803120913829 + + See Also + ======== + + RayTransferMatrix + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Complex_beam_parameter + .. [2] https://en.wikipedia.org/wiki/Gaussian_beam + """ + #TODO A class Complex may be implemented. The BeamParameter may + # subclass it. See: + # https://groups.google.com/d/topic/sympy/7XkU07NRBEs/discussion + + def __new__(cls, wavelen, z, z_r=None, w=None, n=1): + wavelen = sympify(wavelen) + z = sympify(z) + n = sympify(n) + + if z_r is not None and w is None: + z_r = sympify(z_r) + elif w is not None and z_r is None: + z_r = waist2rayleigh(sympify(w), wavelen, n) + elif z_r is None and w is None: + raise ValueError('Must specify one of w and z_r.') + + return Expr.__new__(cls, wavelen, z, z_r, n) + + @property + def wavelen(self): + return self.args[0] + + @property + def z(self): + return self.args[1] + + @property + def z_r(self): + return self.args[2] + + @property + def n(self): + return self.args[3] + + @property + def q(self): + """ + The complex parameter representing the beam. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.q + 1 + 1.88679245283019*I*pi + """ + return self.z + I*self.z_r + + @property + def radius(self): + """ + The radius of curvature of the phase front. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.radius + 1 + 3.55998576005696*pi**2 + """ + return self.z*(1 + (self.z_r/self.z)**2) + + @property + def w(self): + """ + The radius of the beam w(z), at any position z along the beam. + The beam radius at `1/e^2` intensity (axial value). + + See Also + ======== + + w_0 : + The minimal radius of beam. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.w + 0.001*sqrt(0.2809/pi**2 + 1) + """ + return self.w_0*sqrt(1 + (self.z/self.z_r)**2) + + @property + def w_0(self): + """ + The minimal radius of beam at `1/e^2` intensity (peak value). + + See Also + ======== + + w : the beam radius at `1/e^2` intensity (axial value). + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.w_0 + 0.00100000000000000 + """ + return sqrt(self.z_r/(pi*self.n)*self.wavelen) + + @property + def divergence(self): + """ + Half of the total angular spread. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.divergence + 0.00053/pi + """ + return self.wavelen/pi/self.w_0 + + @property + def gouy(self): + """ + The Gouy phase. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.gouy + atan(0.53/pi) + """ + return atan2(self.z, self.z_r) + + @property + def waist_approximation_limit(self): + """ + The minimal waist for which the gauss beam approximation is valid. + + Explanation + =========== + + The gauss beam is a solution to the paraxial equation. For curvatures + that are too great it is not a valid approximation. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.waist_approximation_limit + 1.06e-6/pi + """ + return 2*self.wavelen/pi + + +### +# Utilities +### + +def waist2rayleigh(w, wavelen, n=1): + """ + Calculate the rayleigh range from the waist of a gaussian beam. + + See Also + ======== + + rayleigh2waist, BeamParameter + + Examples + ======== + + >>> from sympy.physics.optics import waist2rayleigh + >>> from sympy import symbols + >>> w, wavelen = symbols('w wavelen') + >>> waist2rayleigh(w, wavelen) + pi*w**2/wavelen + """ + w, wavelen = map(sympify, (w, wavelen)) + return w**2*n*pi/wavelen + + +def rayleigh2waist(z_r, wavelen): + """Calculate the waist from the rayleigh range of a gaussian beam. + + See Also + ======== + + waist2rayleigh, BeamParameter + + Examples + ======== + + >>> from sympy.physics.optics import rayleigh2waist + >>> from sympy import symbols + >>> z_r, wavelen = symbols('z_r wavelen') + >>> rayleigh2waist(z_r, wavelen) + sqrt(wavelen*z_r)/sqrt(pi) + """ + z_r, wavelen = map(sympify, (z_r, wavelen)) + return sqrt(z_r/pi*wavelen) + + +def geometric_conj_ab(a, b): + """ + Conjugation relation for geometrical beams under paraxial conditions. + + Explanation + =========== + + Takes the distances to the optical element and returns the needed + focal distance. + + See Also + ======== + + geometric_conj_af, geometric_conj_bf + + Examples + ======== + + >>> from sympy.physics.optics import geometric_conj_ab + >>> from sympy import symbols + >>> a, b = symbols('a b') + >>> geometric_conj_ab(a, b) + a*b/(a + b) + """ + a, b = map(sympify, (a, b)) + if a.is_infinite or b.is_infinite: + return a if b.is_infinite else b + else: + return a*b/(a + b) + + +def geometric_conj_af(a, f): + """ + Conjugation relation for geometrical beams under paraxial conditions. + + Explanation + =========== + + Takes the object distance (for geometric_conj_af) or the image distance + (for geometric_conj_bf) to the optical element and the focal distance. + Then it returns the other distance needed for conjugation. + + See Also + ======== + + geometric_conj_ab + + Examples + ======== + + >>> from sympy.physics.optics.gaussopt import geometric_conj_af, geometric_conj_bf + >>> from sympy import symbols + >>> a, b, f = symbols('a b f') + >>> geometric_conj_af(a, f) + a*f/(a - f) + >>> geometric_conj_bf(b, f) + b*f/(b - f) + """ + a, f = map(sympify, (a, f)) + return -geometric_conj_ab(a, -f) + +geometric_conj_bf = geometric_conj_af + + +def gaussian_conj(s_in, z_r_in, f): + """ + Conjugation relation for gaussian beams. + + Parameters + ========== + + s_in : + The distance to optical element from the waist. + z_r_in : + The rayleigh range of the incident beam. + f : + The focal length of the optical element. + + Returns + ======= + + a tuple containing (s_out, z_r_out, m) + s_out : + The distance between the new waist and the optical element. + z_r_out : + The rayleigh range of the emergent beam. + m : + The ration between the new and the old waists. + + Examples + ======== + + >>> from sympy.physics.optics import gaussian_conj + >>> from sympy import symbols + >>> s_in, z_r_in, f = symbols('s_in z_r_in f') + + >>> gaussian_conj(s_in, z_r_in, f)[0] + 1/(-1/(s_in + z_r_in**2/(-f + s_in)) + 1/f) + + >>> gaussian_conj(s_in, z_r_in, f)[1] + z_r_in/(1 - s_in**2/f**2 + z_r_in**2/f**2) + + >>> gaussian_conj(s_in, z_r_in, f)[2] + 1/sqrt(1 - s_in**2/f**2 + z_r_in**2/f**2) + """ + s_in, z_r_in, f = map(sympify, (s_in, z_r_in, f)) + s_out = 1 / ( -1/(s_in + z_r_in**2/(s_in - f)) + 1/f ) + m = 1/sqrt((1 - (s_in/f)**2) + (z_r_in/f)**2) + z_r_out = z_r_in / ((1 - (s_in/f)**2) + (z_r_in/f)**2) + return (s_out, z_r_out, m) + + +def conjugate_gauss_beams(wavelen, waist_in, waist_out, **kwargs): + """ + Find the optical setup conjugating the object/image waists. + + Parameters + ========== + + wavelen : + The wavelength of the beam. + waist_in and waist_out : + The waists to be conjugated. + f : + The focal distance of the element used in the conjugation. + + Returns + ======= + + a tuple containing (s_in, s_out, f) + s_in : + The distance before the optical element. + s_out : + The distance after the optical element. + f : + The focal distance of the optical element. + + Examples + ======== + + >>> from sympy.physics.optics import conjugate_gauss_beams + >>> from sympy import symbols, factor + >>> l, w_i, w_o, f = symbols('l w_i w_o f') + + >>> conjugate_gauss_beams(l, w_i, w_o, f=f)[0] + f*(1 - sqrt(w_i**2/w_o**2 - pi**2*w_i**4/(f**2*l**2))) + + >>> factor(conjugate_gauss_beams(l, w_i, w_o, f=f)[1]) + f*w_o**2*(w_i**2/w_o**2 - sqrt(w_i**2/w_o**2 - + pi**2*w_i**4/(f**2*l**2)))/w_i**2 + + >>> conjugate_gauss_beams(l, w_i, w_o, f=f)[2] + f + """ + #TODO add the other possible arguments + wavelen, waist_in, waist_out = map(sympify, (wavelen, waist_in, waist_out)) + m = waist_out / waist_in + z = waist2rayleigh(waist_in, wavelen) + if len(kwargs) != 1: + raise ValueError("The function expects only one named argument") + elif 'dist' in kwargs: + raise NotImplementedError(filldedent(''' + Currently only focal length is supported as a parameter''')) + elif 'f' in kwargs: + f = sympify(kwargs['f']) + s_in = f * (1 - sqrt(1/m**2 - z**2/f**2)) + s_out = gaussian_conj(s_in, z, f)[0] + elif 's_in' in kwargs: + raise NotImplementedError(filldedent(''' + Currently only focal length is supported as a parameter''')) + else: + raise ValueError(filldedent(''' + The functions expects the focal length as a named argument''')) + return (s_in, s_out, f) + +#TODO +#def plot_beam(): +# """Plot the beam radius as it propagates in space.""" +# pass + +#TODO +#def plot_beam_conjugation(): +# """ +# Plot the intersection of two beams. +# +# Represents the conjugation relation. +# +# See Also +# ======== +# +# conjugate_gauss_beams +# """ +# pass diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/medium.py b/phivenv/Lib/site-packages/sympy/physics/optics/medium.py new file mode 100644 index 0000000000000000000000000000000000000000..764b68caad5865b8f3cee028a14cfa304796b4c0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/optics/medium.py @@ -0,0 +1,253 @@ +""" +**Contains** + +* Medium +""" +from sympy.physics.units import second, meter, kilogram, ampere + +__all__ = ['Medium'] + +from sympy.core.basic import Basic +from sympy.core.symbol import Str +from sympy.core.sympify import _sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.units import speed_of_light, u0, e0 + + +c = speed_of_light.convert_to(meter/second) +_e0mksa = e0.convert_to(ampere**2*second**4/(kilogram*meter**3)) +_u0mksa = u0.convert_to(meter*kilogram/(ampere**2*second**2)) + + +class Medium(Basic): + + """ + This class represents an optical medium. The prime reason to implement this is + to facilitate refraction, Fermat's principle, etc. + + Explanation + =========== + + An optical medium is a material through which electromagnetic waves propagate. + The permittivity and permeability of the medium define how electromagnetic + waves propagate in it. + + + Parameters + ========== + + name: string + The display name of the Medium. + + permittivity: Sympifyable + Electric permittivity of the space. + + permeability: Sympifyable + Magnetic permeability of the space. + + n: Sympifyable + Index of refraction of the medium. + + + Examples + ======== + + >>> from sympy.abc import epsilon, mu + >>> from sympy.physics.optics import Medium + >>> m1 = Medium('m1') + >>> m2 = Medium('m2', epsilon, mu) + >>> m1.intrinsic_impedance + 149896229*pi*kilogram*meter**2/(1250000*ampere**2*second**3) + >>> m2.refractive_index + 299792458*meter*sqrt(epsilon*mu)/second + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Optical_medium + + """ + + def __new__(cls, name, permittivity=None, permeability=None, n=None): + if not isinstance(name, Str): + name = Str(name) + + permittivity = _sympify(permittivity) if permittivity is not None else permittivity + permeability = _sympify(permeability) if permeability is not None else permeability + n = _sympify(n) if n is not None else n + + if n is not None: + if permittivity is not None and permeability is None: + permeability = n**2/(c**2*permittivity) + return MediumPP(name, permittivity, permeability) + elif permeability is not None and permittivity is None: + permittivity = n**2/(c**2*permeability) + return MediumPP(name, permittivity, permeability) + elif permittivity is not None and permittivity is not None: + raise ValueError("Specifying all of permittivity, permeability, and n is not allowed") + else: + return MediumN(name, n) + elif permittivity is not None and permeability is not None: + return MediumPP(name, permittivity, permeability) + elif permittivity is None and permeability is None: + return MediumPP(name, _e0mksa, _u0mksa) + else: + raise ValueError("Arguments are underspecified. Either specify n or any two of permittivity, " + "permeability, and n") + + @property + def name(self): + return self.args[0] + + @property + def speed(self): + """ + Returns speed of the electromagnetic wave travelling in the medium. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.speed + 299792458*meter/second + >>> m2 = Medium('m2', n=1) + >>> m.speed == m2.speed + True + + """ + return c / self.n + + @property + def refractive_index(self): + """ + Returns refractive index of the medium. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.refractive_index + 1 + + """ + return (c/self.speed) + + +class MediumN(Medium): + + """ + Represents an optical medium for which only the refractive index is known. + Useful for simple ray optics. + + This class should never be instantiated directly. + Instead it should be instantiated indirectly by instantiating Medium with + only n specified. + + Examples + ======== + >>> from sympy.physics.optics import Medium + >>> m = Medium('m', n=2) + >>> m + MediumN(Str('m'), 2) + """ + + def __new__(cls, name, n): + obj = super(Medium, cls).__new__(cls, name, n) + return obj + + @property + def n(self): + return self.args[1] + + +class MediumPP(Medium): + """ + Represents an optical medium for which the permittivity and permeability are known. + + This class should never be instantiated directly. Instead it should be + instantiated indirectly by instantiating Medium with any two of + permittivity, permeability, and n specified, or by not specifying any + of permittivity, permeability, or n, in which case default values for + permittivity and permeability will be used. + + Examples + ======== + >>> from sympy.physics.optics import Medium + >>> from sympy.abc import epsilon, mu + >>> m1 = Medium('m1', permittivity=epsilon, permeability=mu) + >>> m1 + MediumPP(Str('m1'), epsilon, mu) + >>> m2 = Medium('m2') + >>> m2 + MediumPP(Str('m2'), 625000*ampere**2*second**4/(22468879468420441*pi*kilogram*meter**3), pi*kilogram*meter/(2500000*ampere**2*second**2)) + """ + + + def __new__(cls, name, permittivity, permeability): + obj = super(Medium, cls).__new__(cls, name, permittivity, permeability) + return obj + + @property + def intrinsic_impedance(self): + """ + Returns intrinsic impedance of the medium. + + Explanation + =========== + + The intrinsic impedance of a medium is the ratio of the + transverse components of the electric and magnetic fields + of the electromagnetic wave travelling in the medium. + In a region with no electrical conductivity it simplifies + to the square root of ratio of magnetic permeability to + electric permittivity. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.intrinsic_impedance + 149896229*pi*kilogram*meter**2/(1250000*ampere**2*second**3) + + """ + return sqrt(self.permeability / self.permittivity) + + @property + def permittivity(self): + """ + Returns electric permittivity of the medium. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.permittivity + 625000*ampere**2*second**4/(22468879468420441*pi*kilogram*meter**3) + + """ + return self.args[1] + + @property + def permeability(self): + """ + Returns magnetic permeability of the medium. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.permeability + pi*kilogram*meter/(2500000*ampere**2*second**2) + + """ + return self.args[2] + + @property + def n(self): + return c*sqrt(self.permittivity*self.permeability) diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/polarization.py b/phivenv/Lib/site-packages/sympy/physics/optics/polarization.py new file mode 100644 index 0000000000000000000000000000000000000000..0bdb546548ad082ef38f5f0c159d7eadd38f6d30 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/optics/polarization.py @@ -0,0 +1,732 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +The module implements routines to model the polarization of optical fields +and can be used to calculate the effects of polarization optical elements on +the fields. + +- Jones vectors. + +- Stokes vectors. + +- Jones matrices. + +- Mueller matrices. + +Examples +======== + +We calculate a generic Jones vector: + +>>> from sympy import symbols, pprint, zeros, simplify +>>> from sympy.physics.optics.polarization import (jones_vector, stokes_vector, +... half_wave_retarder, polarizing_beam_splitter, jones_2_stokes) + +>>> psi, chi, p, I0 = symbols("psi, chi, p, I0", real=True) +>>> x0 = jones_vector(psi, chi) +>>> pprint(x0, use_unicode=True) +⎡-ⅈ⋅sin(χ)⋅sin(ψ) + cos(χ)⋅cos(ψ)⎤ +⎢ ⎥ +⎣ⅈ⋅sin(χ)⋅cos(ψ) + sin(ψ)⋅cos(χ) ⎦ + +And the more general Stokes vector: + +>>> s0 = stokes_vector(psi, chi, p, I0) +>>> pprint(s0, use_unicode=True) +⎡ I₀ ⎤ +⎢ ⎥ +⎢I₀⋅p⋅cos(2⋅χ)⋅cos(2⋅ψ)⎥ +⎢ ⎥ +⎢I₀⋅p⋅sin(2⋅ψ)⋅cos(2⋅χ)⎥ +⎢ ⎥ +⎣ I₀⋅p⋅sin(2⋅χ) ⎦ + +We calculate how the Jones vector is modified by a half-wave plate: + +>>> alpha = symbols("alpha", real=True) +>>> HWP = half_wave_retarder(alpha) +>>> x1 = simplify(HWP*x0) + +We calculate the very common operation of passing a beam through a half-wave +plate and then through a polarizing beam-splitter. We do this by putting this +Jones vector as the first entry of a two-Jones-vector state that is transformed +by a 4x4 Jones matrix modelling the polarizing beam-splitter to get the +transmitted and reflected Jones vectors: + +>>> PBS = polarizing_beam_splitter() +>>> X1 = zeros(4, 1) +>>> X1[:2, :] = x1 +>>> X2 = PBS*X1 +>>> transmitted_port = X2[:2, :] +>>> reflected_port = X2[2:, :] + +This allows us to calculate how the power in both ports depends on the initial +polarization: + +>>> transmitted_power = jones_2_stokes(transmitted_port)[0] +>>> reflected_power = jones_2_stokes(reflected_port)[0] +>>> print(transmitted_power) +cos(-2*alpha + chi + psi)**2/2 + cos(2*alpha + chi - psi)**2/2 + + +>>> print(reflected_power) +sin(-2*alpha + chi + psi)**2/2 + sin(2*alpha + chi - psi)**2/2 + +Please see the description of the individual functions for further +details and examples. + +References +========== + +.. [1] https://en.wikipedia.org/wiki/Jones_calculus +.. [2] https://en.wikipedia.org/wiki/Mueller_calculus +.. [3] https://en.wikipedia.org/wiki/Stokes_parameters + +""" + +from sympy.core.numbers import (I, pi) +from sympy.functions.elementary.complexes import (Abs, im, re) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import Matrix +from sympy.simplify.simplify import simplify +from sympy.physics.quantum import TensorProduct + + +def jones_vector(psi, chi): + """A Jones vector corresponding to a polarization ellipse with `psi` tilt, + and `chi` circularity. + + Parameters + ========== + + psi : numeric type or SymPy Symbol + The tilt of the polarization relative to the `x` axis. + + chi : numeric type or SymPy Symbol + The angle adjacent to the mayor axis of the polarization ellipse. + + + Returns + ======= + + Matrix : + A Jones vector. + + Examples + ======== + + The axes on the Poincaré sphere. + + >>> from sympy import pprint, symbols, pi + >>> from sympy.physics.optics.polarization import jones_vector + >>> psi, chi = symbols("psi, chi", real=True) + + A general Jones vector. + + >>> pprint(jones_vector(psi, chi), use_unicode=True) + ⎡-ⅈ⋅sin(χ)⋅sin(ψ) + cos(χ)⋅cos(ψ)⎤ + ⎢ ⎥ + ⎣ⅈ⋅sin(χ)⋅cos(ψ) + sin(ψ)⋅cos(χ) ⎦ + + Horizontal polarization. + + >>> pprint(jones_vector(0, 0), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎣0⎦ + + Vertical polarization. + + >>> pprint(jones_vector(pi/2, 0), use_unicode=True) + ⎡0⎤ + ⎢ ⎥ + ⎣1⎦ + + Diagonal polarization. + + >>> pprint(jones_vector(pi/4, 0), use_unicode=True) + ⎡√2⎤ + ⎢──⎥ + ⎢2 ⎥ + ⎢ ⎥ + ⎢√2⎥ + ⎢──⎥ + ⎣2 ⎦ + + Anti-diagonal polarization. + + >>> pprint(jones_vector(-pi/4, 0), use_unicode=True) + ⎡ √2 ⎤ + ⎢ ── ⎥ + ⎢ 2 ⎥ + ⎢ ⎥ + ⎢-√2 ⎥ + ⎢────⎥ + ⎣ 2 ⎦ + + Right-hand circular polarization. + + >>> pprint(jones_vector(0, pi/4), use_unicode=True) + ⎡ √2 ⎤ + ⎢ ── ⎥ + ⎢ 2 ⎥ + ⎢ ⎥ + ⎢√2⋅ⅈ⎥ + ⎢────⎥ + ⎣ 2 ⎦ + + Left-hand circular polarization. + + >>> pprint(jones_vector(0, -pi/4), use_unicode=True) + ⎡ √2 ⎤ + ⎢ ── ⎥ + ⎢ 2 ⎥ + ⎢ ⎥ + ⎢-√2⋅ⅈ ⎥ + ⎢──────⎥ + ⎣ 2 ⎦ + + """ + return Matrix([-I*sin(chi)*sin(psi) + cos(chi)*cos(psi), + I*sin(chi)*cos(psi) + sin(psi)*cos(chi)]) + + +def stokes_vector(psi, chi, p=1, I=1): + """A Stokes vector corresponding to a polarization ellipse with ``psi`` + tilt, and ``chi`` circularity. + + Parameters + ========== + + psi : numeric type or SymPy Symbol + The tilt of the polarization relative to the ``x`` axis. + chi : numeric type or SymPy Symbol + The angle adjacent to the mayor axis of the polarization ellipse. + p : numeric type or SymPy Symbol + The degree of polarization. + I : numeric type or SymPy Symbol + The intensity of the field. + + + Returns + ======= + + Matrix : + A Stokes vector. + + Examples + ======== + + The axes on the Poincaré sphere. + + >>> from sympy import pprint, symbols, pi + >>> from sympy.physics.optics.polarization import stokes_vector + >>> psi, chi, p, I = symbols("psi, chi, p, I", real=True) + >>> pprint(stokes_vector(psi, chi, p, I), use_unicode=True) + ⎡ I ⎤ + ⎢ ⎥ + ⎢I⋅p⋅cos(2⋅χ)⋅cos(2⋅ψ)⎥ + ⎢ ⎥ + ⎢I⋅p⋅sin(2⋅ψ)⋅cos(2⋅χ)⎥ + ⎢ ⎥ + ⎣ I⋅p⋅sin(2⋅χ) ⎦ + + + Horizontal polarization + + >>> pprint(stokes_vector(0, 0), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎢1⎥ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎣0⎦ + + Vertical polarization + + >>> pprint(stokes_vector(pi/2, 0), use_unicode=True) + ⎡1 ⎤ + ⎢ ⎥ + ⎢-1⎥ + ⎢ ⎥ + ⎢0 ⎥ + ⎢ ⎥ + ⎣0 ⎦ + + Diagonal polarization + + >>> pprint(stokes_vector(pi/4, 0), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎢1⎥ + ⎢ ⎥ + ⎣0⎦ + + Anti-diagonal polarization + + >>> pprint(stokes_vector(-pi/4, 0), use_unicode=True) + ⎡1 ⎤ + ⎢ ⎥ + ⎢0 ⎥ + ⎢ ⎥ + ⎢-1⎥ + ⎢ ⎥ + ⎣0 ⎦ + + Right-hand circular polarization + + >>> pprint(stokes_vector(0, pi/4), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎣1⎦ + + Left-hand circular polarization + + >>> pprint(stokes_vector(0, -pi/4), use_unicode=True) + ⎡1 ⎤ + ⎢ ⎥ + ⎢0 ⎥ + ⎢ ⎥ + ⎢0 ⎥ + ⎢ ⎥ + ⎣-1⎦ + + Unpolarized light + + >>> pprint(stokes_vector(0, 0, 0), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎣0⎦ + + """ + S0 = I + S1 = I*p*cos(2*psi)*cos(2*chi) + S2 = I*p*sin(2*psi)*cos(2*chi) + S3 = I*p*sin(2*chi) + return Matrix([S0, S1, S2, S3]) + + +def jones_2_stokes(e): + """Return the Stokes vector for a Jones vector ``e``. + + Parameters + ========== + + e : SymPy Matrix + A Jones vector. + + Returns + ======= + + SymPy Matrix + A Jones vector. + + Examples + ======== + + The axes on the Poincaré sphere. + + >>> from sympy import pprint, pi + >>> from sympy.physics.optics.polarization import jones_vector + >>> from sympy.physics.optics.polarization import jones_2_stokes + >>> H = jones_vector(0, 0) + >>> V = jones_vector(pi/2, 0) + >>> D = jones_vector(pi/4, 0) + >>> A = jones_vector(-pi/4, 0) + >>> R = jones_vector(0, pi/4) + >>> L = jones_vector(0, -pi/4) + >>> pprint([jones_2_stokes(e) for e in [H, V, D, A, R, L]], + ... use_unicode=True) + ⎡⎡1⎤ ⎡1 ⎤ ⎡1⎤ ⎡1 ⎤ ⎡1⎤ ⎡1 ⎤⎤ + ⎢⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥⎥ + ⎢⎢1⎥ ⎢-1⎥ ⎢0⎥ ⎢0 ⎥ ⎢0⎥ ⎢0 ⎥⎥ + ⎢⎢ ⎥, ⎢ ⎥, ⎢ ⎥, ⎢ ⎥, ⎢ ⎥, ⎢ ⎥⎥ + ⎢⎢0⎥ ⎢0 ⎥ ⎢1⎥ ⎢-1⎥ ⎢0⎥ ⎢0 ⎥⎥ + ⎢⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥⎥ + ⎣⎣0⎦ ⎣0 ⎦ ⎣0⎦ ⎣0 ⎦ ⎣1⎦ ⎣-1⎦⎦ + + """ + ex, ey = e + return Matrix([Abs(ex)**2 + Abs(ey)**2, + Abs(ex)**2 - Abs(ey)**2, + 2*re(ex*ey.conjugate()), + -2*im(ex*ey.conjugate())]) + + +def linear_polarizer(theta=0): + """A linear polarizer Jones matrix with transmission axis at + an angle ``theta``. + + Parameters + ========== + + theta : numeric type or SymPy Symbol + The angle of the transmission axis relative to the horizontal plane. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the polarizer. + + Examples + ======== + + A generic polarizer. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import linear_polarizer + >>> theta = symbols("theta", real=True) + >>> J = linear_polarizer(theta) + >>> pprint(J, use_unicode=True) + ⎡ 2 ⎤ + ⎢ cos (θ) sin(θ)⋅cos(θ)⎥ + ⎢ ⎥ + ⎢ 2 ⎥ + ⎣sin(θ)⋅cos(θ) sin (θ) ⎦ + + + """ + M = Matrix([[cos(theta)**2, sin(theta)*cos(theta)], + [sin(theta)*cos(theta), sin(theta)**2]]) + return M + + +def phase_retarder(theta=0, delta=0): + """A phase retarder Jones matrix with retardance ``delta`` at angle ``theta``. + + Parameters + ========== + + theta : numeric type or SymPy Symbol + The angle of the fast axis relative to the horizontal plane. + delta : numeric type or SymPy Symbol + The phase difference between the fast and slow axes of the + transmitted light. + + Returns + ======= + + SymPy Matrix : + A Jones matrix representing the retarder. + + Examples + ======== + + A generic retarder. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import phase_retarder + >>> theta, delta = symbols("theta, delta", real=True) + >>> R = phase_retarder(theta, delta) + >>> pprint(R, use_unicode=True) + ⎡ -ⅈ⋅δ -ⅈ⋅δ ⎤ + ⎢ ───── ───── ⎥ + ⎢⎛ ⅈ⋅δ 2 2 ⎞ 2 ⎛ ⅈ⋅δ⎞ 2 ⎥ + ⎢⎝ℯ ⋅sin (θ) + cos (θ)⎠⋅ℯ ⎝1 - ℯ ⎠⋅ℯ ⋅sin(θ)⋅cos(θ)⎥ + ⎢ ⎥ + ⎢ -ⅈ⋅δ -ⅈ⋅δ ⎥ + ⎢ ───── ─────⎥ + ⎢⎛ ⅈ⋅δ⎞ 2 ⎛ ⅈ⋅δ 2 2 ⎞ 2 ⎥ + ⎣⎝1 - ℯ ⎠⋅ℯ ⋅sin(θ)⋅cos(θ) ⎝ℯ ⋅cos (θ) + sin (θ)⎠⋅ℯ ⎦ + + """ + R = Matrix([[cos(theta)**2 + exp(I*delta)*sin(theta)**2, + (1-exp(I*delta))*cos(theta)*sin(theta)], + [(1-exp(I*delta))*cos(theta)*sin(theta), + sin(theta)**2 + exp(I*delta)*cos(theta)**2]]) + return R*exp(-I*delta/2) + + +def half_wave_retarder(theta): + """A half-wave retarder Jones matrix at angle ``theta``. + + Parameters + ========== + + theta : numeric type or SymPy Symbol + The angle of the fast axis relative to the horizontal plane. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the retarder. + + Examples + ======== + + A generic half-wave plate. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import half_wave_retarder + >>> theta= symbols("theta", real=True) + >>> HWP = half_wave_retarder(theta) + >>> pprint(HWP, use_unicode=True) + ⎡ ⎛ 2 2 ⎞ ⎤ + ⎢-ⅈ⋅⎝- sin (θ) + cos (θ)⎠ -2⋅ⅈ⋅sin(θ)⋅cos(θ) ⎥ + ⎢ ⎥ + ⎢ ⎛ 2 2 ⎞⎥ + ⎣ -2⋅ⅈ⋅sin(θ)⋅cos(θ) -ⅈ⋅⎝sin (θ) - cos (θ)⎠⎦ + + """ + return phase_retarder(theta, pi) + + +def quarter_wave_retarder(theta): + """A quarter-wave retarder Jones matrix at angle ``theta``. + + Parameters + ========== + + theta : numeric type or SymPy Symbol + The angle of the fast axis relative to the horizontal plane. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the retarder. + + Examples + ======== + + A generic quarter-wave plate. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import quarter_wave_retarder + >>> theta= symbols("theta", real=True) + >>> QWP = quarter_wave_retarder(theta) + >>> pprint(QWP, use_unicode=True) + ⎡ -ⅈ⋅π -ⅈ⋅π ⎤ + ⎢ ───── ───── ⎥ + ⎢⎛ 2 2 ⎞ 4 4 ⎥ + ⎢⎝ⅈ⋅sin (θ) + cos (θ)⎠⋅ℯ (1 - ⅈ)⋅ℯ ⋅sin(θ)⋅cos(θ)⎥ + ⎢ ⎥ + ⎢ -ⅈ⋅π -ⅈ⋅π ⎥ + ⎢ ───── ─────⎥ + ⎢ 4 ⎛ 2 2 ⎞ 4 ⎥ + ⎣(1 - ⅈ)⋅ℯ ⋅sin(θ)⋅cos(θ) ⎝sin (θ) + ⅈ⋅cos (θ)⎠⋅ℯ ⎦ + + """ + return phase_retarder(theta, pi/2) + + +def transmissive_filter(T): + """An attenuator Jones matrix with transmittance ``T``. + + Parameters + ========== + + T : numeric type or SymPy Symbol + The transmittance of the attenuator. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the filter. + + Examples + ======== + + A generic filter. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import transmissive_filter + >>> T = symbols("T", real=True) + >>> NDF = transmissive_filter(T) + >>> pprint(NDF, use_unicode=True) + ⎡√T 0 ⎤ + ⎢ ⎥ + ⎣0 √T⎦ + + """ + return Matrix([[sqrt(T), 0], [0, sqrt(T)]]) + + +def reflective_filter(R): + """A reflective filter Jones matrix with reflectance ``R``. + + Parameters + ========== + + R : numeric type or SymPy Symbol + The reflectance of the filter. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the filter. + + Examples + ======== + + A generic filter. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import reflective_filter + >>> R = symbols("R", real=True) + >>> pprint(reflective_filter(R), use_unicode=True) + ⎡√R 0 ⎤ + ⎢ ⎥ + ⎣0 -√R⎦ + + """ + return Matrix([[sqrt(R), 0], [0, -sqrt(R)]]) + + +def mueller_matrix(J): + """The Mueller matrix corresponding to Jones matrix `J`. + + Parameters + ========== + + J : SymPy Matrix + A Jones matrix. + + Returns + ======= + + SymPy Matrix + The corresponding Mueller matrix. + + Examples + ======== + + Generic optical components. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import (mueller_matrix, + ... linear_polarizer, half_wave_retarder, quarter_wave_retarder) + >>> theta = symbols("theta", real=True) + + A linear_polarizer + + >>> pprint(mueller_matrix(linear_polarizer(theta)), use_unicode=True) + ⎡ cos(2⋅θ) sin(2⋅θ) ⎤ + ⎢ 1/2 ──────── ──────── 0⎥ + ⎢ 2 2 ⎥ + ⎢ ⎥ + ⎢cos(2⋅θ) cos(4⋅θ) 1 sin(4⋅θ) ⎥ + ⎢──────── ──────── + ─ ──────── 0⎥ + ⎢ 2 4 4 4 ⎥ + ⎢ ⎥ + ⎢sin(2⋅θ) sin(4⋅θ) 1 cos(4⋅θ) ⎥ + ⎢──────── ──────── ─ - ──────── 0⎥ + ⎢ 2 4 4 4 ⎥ + ⎢ ⎥ + ⎣ 0 0 0 0⎦ + + A half-wave plate + + >>> pprint(mueller_matrix(half_wave_retarder(theta)), use_unicode=True) + ⎡1 0 0 0 ⎤ + ⎢ ⎥ + ⎢ 4 2 ⎥ + ⎢0 8⋅sin (θ) - 8⋅sin (θ) + 1 sin(4⋅θ) 0 ⎥ + ⎢ ⎥ + ⎢ 4 2 ⎥ + ⎢0 sin(4⋅θ) - 8⋅sin (θ) + 8⋅sin (θ) - 1 0 ⎥ + ⎢ ⎥ + ⎣0 0 0 -1⎦ + + A quarter-wave plate + + >>> pprint(mueller_matrix(quarter_wave_retarder(theta)), use_unicode=True) + ⎡1 0 0 0 ⎤ + ⎢ ⎥ + ⎢ cos(4⋅θ) 1 sin(4⋅θ) ⎥ + ⎢0 ──────── + ─ ──────── -sin(2⋅θ)⎥ + ⎢ 2 2 2 ⎥ + ⎢ ⎥ + ⎢ sin(4⋅θ) 1 cos(4⋅θ) ⎥ + ⎢0 ──────── ─ - ──────── cos(2⋅θ) ⎥ + ⎢ 2 2 2 ⎥ + ⎢ ⎥ + ⎣0 sin(2⋅θ) -cos(2⋅θ) 0 ⎦ + + """ + A = Matrix([[1, 0, 0, 1], + [1, 0, 0, -1], + [0, 1, 1, 0], + [0, -I, I, 0]]) + + return simplify(A*TensorProduct(J, J.conjugate())*A.inv()) + + +def polarizing_beam_splitter(Tp=1, Rs=1, Ts=0, Rp=0, phia=0, phib=0): + r"""A polarizing beam splitter Jones matrix at angle `theta`. + + Parameters + ========== + + J : SymPy Matrix + A Jones matrix. + Tp : numeric type or SymPy Symbol + The transmissivity of the P-polarized component. + Rs : numeric type or SymPy Symbol + The reflectivity of the S-polarized component. + Ts : numeric type or SymPy Symbol + The transmissivity of the S-polarized component. + Rp : numeric type or SymPy Symbol + The reflectivity of the P-polarized component. + phia : numeric type or SymPy Symbol + The phase difference between transmitted and reflected component for + output mode a. + phib : numeric type or SymPy Symbol + The phase difference between transmitted and reflected component for + output mode b. + + + Returns + ======= + + SymPy Matrix + A 4x4 matrix representing the PBS. This matrix acts on a 4x1 vector + whose first two entries are the Jones vector on one of the PBS ports, + and the last two entries the Jones vector on the other port. + + Examples + ======== + + Generic polarizing beam-splitter. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import polarizing_beam_splitter + >>> Ts, Rs, Tp, Rp = symbols(r"Ts, Rs, Tp, Rp", positive=True) + >>> phia, phib = symbols("phi_a, phi_b", real=True) + >>> PBS = polarizing_beam_splitter(Tp, Rs, Ts, Rp, phia, phib) + >>> pprint(PBS, use_unicode=False) + [ ____ ____ ] + [ \/ Tp 0 I*\/ Rp 0 ] + [ ] + [ ____ ____ I*phi_a] + [ 0 \/ Ts 0 -I*\/ Rs *e ] + [ ] + [ ____ ____ ] + [I*\/ Rp 0 \/ Tp 0 ] + [ ] + [ ____ I*phi_b ____ ] + [ 0 -I*\/ Rs *e 0 \/ Ts ] + + """ + PBS = Matrix([[sqrt(Tp), 0, I*sqrt(Rp), 0], + [0, sqrt(Ts), 0, -I*sqrt(Rs)*exp(I*phia)], + [I*sqrt(Rp), 0, sqrt(Tp), 0], + [0, -I*sqrt(Rs)*exp(I*phib), 0, sqrt(Ts)]]) + return PBS diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/tests/__init__.py b/phivenv/Lib/site-packages/sympy/physics/optics/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2d5777b994d82d9f265d9662f5265b5574c2007 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_gaussopt.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_gaussopt.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5b930850425594056de9018eacb6e993304335c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_gaussopt.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_medium.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_medium.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01a29e35a778e6aca9529618084f7d5cf94e3fa3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_medium.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_polarization.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_polarization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be7bff2df5ce944d4c4710772cd99825d2e62488 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_polarization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_utils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e31fcfe18a20372ef8f7a6ef971f9f71a1d6de0c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_waves.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_waves.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce2046fa0bd4cd6b858dcc5983b64c200ad22ff1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/optics/tests/__pycache__/test_waves.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_gaussopt.py b/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_gaussopt.py new file mode 100644 index 0000000000000000000000000000000000000000..5271f3cbb69cf5de861ff332d36418b79daeb1b5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_gaussopt.py @@ -0,0 +1,102 @@ +from sympy.core.evalf import N +from sympy.core.numbers import (Float, I, oo, pi) +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import atan2 +from sympy.matrices.dense import Matrix +from sympy.polys.polytools import factor + +from sympy.physics.optics import (BeamParameter, CurvedMirror, + CurvedRefraction, FlatMirror, FlatRefraction, FreeSpace, GeometricRay, + RayTransferMatrix, ThinLens, conjugate_gauss_beams, + gaussian_conj, geometric_conj_ab, geometric_conj_af, geometric_conj_bf, + rayleigh2waist, waist2rayleigh) + + +def streq(a, b): + return str(a) == str(b) + + +def test_gauss_opt(): + mat = RayTransferMatrix(1, 2, 3, 4) + assert mat == Matrix([[1, 2], [3, 4]]) + assert mat == RayTransferMatrix( Matrix([[1, 2], [3, 4]]) ) + assert [mat.A, mat.B, mat.C, mat.D] == [1, 2, 3, 4] + + d, f, h, n1, n2, R = symbols('d f h n1 n2 R') + lens = ThinLens(f) + assert lens == Matrix([[ 1, 0], [-1/f, 1]]) + assert lens.C == -1/f + assert FreeSpace(d) == Matrix([[ 1, d], [0, 1]]) + assert FlatRefraction(n1, n2) == Matrix([[1, 0], [0, n1/n2]]) + assert CurvedRefraction( + R, n1, n2) == Matrix([[1, 0], [(n1 - n2)/(R*n2), n1/n2]]) + assert FlatMirror() == Matrix([[1, 0], [0, 1]]) + assert CurvedMirror(R) == Matrix([[ 1, 0], [-2/R, 1]]) + assert ThinLens(f) == Matrix([[ 1, 0], [-1/f, 1]]) + + mul = CurvedMirror(R)*FreeSpace(d) + mul_mat = Matrix([[ 1, 0], [-2/R, 1]])*Matrix([[ 1, d], [0, 1]]) + assert mul.A == mul_mat[0, 0] + assert mul.B == mul_mat[0, 1] + assert mul.C == mul_mat[1, 0] + assert mul.D == mul_mat[1, 1] + + angle = symbols('angle') + assert GeometricRay(h, angle) == Matrix([[ h], [angle]]) + assert FreeSpace( + d)*GeometricRay(h, angle) == Matrix([[angle*d + h], [angle]]) + assert GeometricRay( Matrix( ((h,), (angle,)) ) ) == Matrix([[h], [angle]]) + assert (FreeSpace(d)*GeometricRay(h, angle)).height == angle*d + h + assert (FreeSpace(d)*GeometricRay(h, angle)).angle == angle + + p = BeamParameter(530e-9, 1, w=1e-3) + assert streq(p.q, 1 + 1.88679245283019*I*pi) + assert streq(N(p.q), 1.0 + 5.92753330865999*I) + assert streq(N(p.w_0), Float(0.00100000000000000)) + assert streq(N(p.z_r), Float(5.92753330865999)) + fs = FreeSpace(10) + p1 = fs*p + assert streq(N(p.w), Float(0.00101413072159615)) + assert streq(N(p1.w), Float(0.00210803120913829)) + + w, wavelen = symbols('w wavelen') + assert waist2rayleigh(w, wavelen) == pi*w**2/wavelen + z_r, wavelen = symbols('z_r wavelen') + assert rayleigh2waist(z_r, wavelen) == sqrt(wavelen*z_r)/sqrt(pi) + + a, b, f = symbols('a b f') + assert geometric_conj_ab(a, b) == a*b/(a + b) + assert geometric_conj_af(a, f) == a*f/(a - f) + assert geometric_conj_bf(b, f) == b*f/(b - f) + assert geometric_conj_ab(oo, b) == b + assert geometric_conj_ab(a, oo) == a + + s_in, z_r_in, f = symbols('s_in z_r_in f') + assert gaussian_conj( + s_in, z_r_in, f)[0] == 1/(-1/(s_in + z_r_in**2/(-f + s_in)) + 1/f) + assert gaussian_conj( + s_in, z_r_in, f)[1] == z_r_in/(1 - s_in**2/f**2 + z_r_in**2/f**2) + assert gaussian_conj( + s_in, z_r_in, f)[2] == 1/sqrt(1 - s_in**2/f**2 + z_r_in**2/f**2) + + l, w_i, w_o, f = symbols('l w_i w_o f') + assert conjugate_gauss_beams(l, w_i, w_o, f=f)[0] == f*( + -sqrt(w_i**2/w_o**2 - pi**2*w_i**4/(f**2*l**2)) + 1) + assert factor(conjugate_gauss_beams(l, w_i, w_o, f=f)[1]) == f*w_o**2*( + w_i**2/w_o**2 - sqrt(w_i**2/w_o**2 - pi**2*w_i**4/(f**2*l**2)))/w_i**2 + assert conjugate_gauss_beams(l, w_i, w_o, f=f)[2] == f + + z, l, w_0 = symbols('z l w_0', positive=True) + p = BeamParameter(l, z, w=w_0) + assert p.radius == z*(pi**2*w_0**4/(l**2*z**2) + 1) + assert p.w == w_0*sqrt(l**2*z**2/(pi**2*w_0**4) + 1) + assert p.w_0 == w_0 + assert p.divergence == l/(pi*w_0) + assert p.gouy == atan2(z, pi*w_0**2/l) + assert p.waist_approximation_limit == 2*l/pi + + p = BeamParameter(530e-9, 1, w=1e-3, n=2) + assert streq(p.q, 1 + 3.77358490566038*I*pi) + assert streq(N(p.z_r), Float(11.8550666173200)) + assert streq(N(p.w_0), Float(0.00100000000000000)) diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_medium.py b/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_medium.py new file mode 100644 index 0000000000000000000000000000000000000000..dfbb485f5b8e401f38c7f1cfa573f960a2479d7b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_medium.py @@ -0,0 +1,48 @@ +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.optics import Medium +from sympy.abc import epsilon, mu, n +from sympy.physics.units import speed_of_light, u0, e0, m, kg, s, A + +from sympy.testing.pytest import raises + +c = speed_of_light.convert_to(m/s) +e0 = e0.convert_to(A**2*s**4/(kg*m**3)) +u0 = u0.convert_to(m*kg/(A**2*s**2)) + + +def test_medium(): + m1 = Medium('m1') + assert m1.intrinsic_impedance == sqrt(u0/e0) + assert m1.speed == 1/sqrt(e0*u0) + assert m1.refractive_index == c*sqrt(e0*u0) + assert m1.permittivity == e0 + assert m1.permeability == u0 + m2 = Medium('m2', epsilon, mu) + assert m2.intrinsic_impedance == sqrt(mu/epsilon) + assert m2.speed == 1/sqrt(epsilon*mu) + assert m2.refractive_index == c*sqrt(epsilon*mu) + assert m2.permittivity == epsilon + assert m2.permeability == mu + # Increasing electric permittivity and magnetic permeability + # by small amount from its value in vacuum. + m3 = Medium('m3', 9.0*10**(-12)*s**4*A**2/(m**3*kg), 1.45*10**(-6)*kg*m/(A**2*s**2)) + assert m3.refractive_index > m1.refractive_index + assert m3 != m1 + # Decreasing electric permittivity and magnetic permeability + # by small amount from its value in vacuum. + m4 = Medium('m4', 7.0*10**(-12)*s**4*A**2/(m**3*kg), 1.15*10**(-6)*kg*m/(A**2*s**2)) + assert m4.refractive_index < m1.refractive_index + m5 = Medium('m5', permittivity=710*10**(-12)*s**4*A**2/(m**3*kg), n=1.33) + assert abs(m5.intrinsic_impedance - 6.24845417765552*kg*m**2/(A**2*s**3)) \ + < 1e-12*kg*m**2/(A**2*s**3) + assert abs(m5.speed - 225407863.157895*m/s) < 1e-6*m/s + assert abs(m5.refractive_index - 1.33000000000000) < 1e-12 + assert abs(m5.permittivity - 7.1e-10*A**2*s**4/(kg*m**3)) \ + < 1e-20*A**2*s**4/(kg*m**3) + assert abs(m5.permeability - 2.77206575232851e-8*kg*m/(A**2*s**2)) \ + < 1e-20*kg*m/(A**2*s**2) + m6 = Medium('m6', None, mu, n) + assert m6.permittivity == n**2/(c**2*mu) + # test for equality of refractive indices + assert Medium('m7').refractive_index == Medium('m8', e0, u0).refractive_index + raises(ValueError, lambda:Medium('m9', e0, u0, 2)) diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_polarization.py b/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_polarization.py new file mode 100644 index 0000000000000000000000000000000000000000..99c595d82a4a296066d5075f6182895a8de54d91 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_polarization.py @@ -0,0 +1,57 @@ +from sympy.physics.optics.polarization import (jones_vector, stokes_vector, + jones_2_stokes, linear_polarizer, phase_retarder, half_wave_retarder, + quarter_wave_retarder, transmissive_filter, reflective_filter, + mueller_matrix, polarizing_beam_splitter) +from sympy.core.numbers import (I, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.matrices.dense import Matrix + + +def test_polarization(): + assert jones_vector(0, 0) == Matrix([1, 0]) + assert jones_vector(pi/2, 0) == Matrix([0, 1]) + ################################################################# + assert stokes_vector(0, 0) == Matrix([1, 1, 0, 0]) + assert stokes_vector(pi/2, 0) == Matrix([1, -1, 0, 0]) + ################################################################# + H = jones_vector(0, 0) + V = jones_vector(pi/2, 0) + D = jones_vector(pi/4, 0) + A = jones_vector(-pi/4, 0) + R = jones_vector(0, pi/4) + L = jones_vector(0, -pi/4) + + res = [Matrix([1, 1, 0, 0]), + Matrix([1, -1, 0, 0]), + Matrix([1, 0, 1, 0]), + Matrix([1, 0, -1, 0]), + Matrix([1, 0, 0, 1]), + Matrix([1, 0, 0, -1])] + + assert [jones_2_stokes(e) for e in [H, V, D, A, R, L]] == res + ################################################################# + assert linear_polarizer(0) == Matrix([[1, 0], [0, 0]]) + ################################################################# + delta = symbols("delta", real=True) + res = Matrix([[exp(-I*delta/2), 0], [0, exp(I*delta/2)]]) + assert phase_retarder(0, delta) == res + ################################################################# + assert half_wave_retarder(0) == Matrix([[-I, 0], [0, I]]) + ################################################################# + res = Matrix([[exp(-I*pi/4), 0], [0, I*exp(-I*pi/4)]]) + assert quarter_wave_retarder(0) == res + ################################################################# + assert transmissive_filter(1) == Matrix([[1, 0], [0, 1]]) + ################################################################# + assert reflective_filter(1) == Matrix([[1, 0], [0, -1]]) + + res = Matrix([[S(1)/2, S(1)/2, 0, 0], + [S(1)/2, S(1)/2, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + assert mueller_matrix(linear_polarizer(0)) == res + ################################################################# + res = Matrix([[1, 0, 0, 0], [0, 0, 0, -I], [0, 0, 1, 0], [0, -I, 0, 0]]) + assert polarizing_beam_splitter() == res diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_utils.py b/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6c93883a081d3614a604aeadc8a4b617181de669 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_utils.py @@ -0,0 +1,202 @@ +from sympy.core.numbers import comp, Rational +from sympy.physics.optics.utils import (refraction_angle, fresnel_coefficients, + deviation, brewster_angle, critical_angle, lens_makers_formula, + mirror_formula, lens_formula, hyperfocal_distance, + transverse_magnification) +from sympy.physics.optics.medium import Medium +from sympy.physics.units import e0 + +from sympy.core.numbers import oo +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix +from sympy.geometry.point import Point3D +from sympy.geometry.line import Ray3D +from sympy.geometry.plane import Plane + +from sympy.testing.pytest import raises + + +ae = lambda a, b, n: comp(a, b, 10**-n) + + +def test_refraction_angle(): + n1, n2 = symbols('n1, n2') + m1 = Medium('m1') + m2 = Medium('m2') + r1 = Ray3D(Point3D(-1, -1, 1), Point3D(0, 0, 0)) + i = Matrix([1, 1, 1]) + n = Matrix([0, 0, 1]) + normal_ray = Ray3D(Point3D(0, 0, 0), Point3D(0, 0, 1)) + P = Plane(Point3D(0, 0, 0), normal_vector=[0, 0, 1]) + assert refraction_angle(r1, 1, 1, n) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle([1, 1, 1], 1, 1, n) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle((1, 1, 1), 1, 1, n) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(i, 1, 1, [0, 0, 1]) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(i, 1, 1, (0, 0, 1)) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(i, 1, 1, normal_ray) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(i, 1, 1, plane=P) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(r1, 1, 1, plane=P) == \ + Ray3D(Point3D(0, 0, 0), Point3D(1, 1, -1)) + assert refraction_angle(r1, m1, 1.33, plane=P) == \ + Ray3D(Point3D(0, 0, 0), Point3D(Rational(100, 133), Rational(100, 133), -789378201649271*sqrt(3)/1000000000000000)) + assert refraction_angle(r1, 1, m2, plane=P) == \ + Ray3D(Point3D(0, 0, 0), Point3D(1, 1, -1)) + assert refraction_angle(r1, n1, n2, plane=P) == \ + Ray3D(Point3D(0, 0, 0), Point3D(n1/n2, n1/n2, -sqrt(3)*sqrt(-2*n1**2/(3*n2**2) + 1))) + assert refraction_angle(r1, 1.33, 1, plane=P) == 0 # TIR + assert refraction_angle(r1, 1, 1, normal_ray) == \ + Ray3D(Point3D(0, 0, 0), direction_ratio=[1, 1, -1]) + assert ae(refraction_angle(0.5, 1, 2), 0.24207, 5) + assert ae(refraction_angle(0.5, 2, 1), 1.28293, 5) + raises(ValueError, lambda: refraction_angle(r1, m1, m2, normal_ray, P)) + raises(TypeError, lambda: refraction_angle(m1, m1, m2)) # can add other values for arg[0] + raises(TypeError, lambda: refraction_angle(r1, m1, m2, None, i)) + raises(TypeError, lambda: refraction_angle(r1, m1, m2, m2)) + + +def test_fresnel_coefficients(): + assert all(ae(i, j, 5) for i, j in zip( + fresnel_coefficients(0.5, 1, 1.33), + [0.11163, -0.17138, 0.83581, 0.82862])) + assert all(ae(i, j, 5) for i, j in zip( + fresnel_coefficients(0.5, 1.33, 1), + [-0.07726, 0.20482, 1.22724, 1.20482])) + m1 = Medium('m1') + m2 = Medium('m2', n=2) + assert all(ae(i, j, 5) for i, j in zip( + fresnel_coefficients(0.3, m1, m2), + [0.31784, -0.34865, 0.65892, 0.65135])) + ans = [[-0.23563, -0.97184], [0.81648, -0.57738]] + got = fresnel_coefficients(0.6, m2, m1) + for i, j in zip(got, ans): + for a, b in zip(i.as_real_imag(), j): + assert ae(a, b, 5) + + +def test_deviation(): + n1, n2 = symbols('n1, n2') + r1 = Ray3D(Point3D(-1, -1, 1), Point3D(0, 0, 0)) + n = Matrix([0, 0, 1]) + i = Matrix([-1, -1, -1]) + normal_ray = Ray3D(Point3D(0, 0, 0), Point3D(0, 0, 1)) + P = Plane(Point3D(0, 0, 0), normal_vector=[0, 0, 1]) + assert deviation(r1, 1, 1, normal=n) == 0 + assert deviation(r1, 1, 1, plane=P) == 0 + assert deviation(r1, 1, 1.1, plane=P).evalf(3) + 0.119 < 1e-3 + assert deviation(i, 1, 1.1, normal=normal_ray).evalf(3) + 0.119 < 1e-3 + assert deviation(r1, 1.33, 1, plane=P) is None # TIR + assert deviation(r1, 1, 1, normal=[0, 0, 1]) == 0 + assert deviation([-1, -1, -1], 1, 1, normal=[0, 0, 1]) == 0 + assert ae(deviation(0.5, 1, 2), -0.25793, 5) + assert ae(deviation(0.5, 2, 1), 0.78293, 5) + + +def test_brewster_angle(): + m1 = Medium('m1', n=1) + m2 = Medium('m2', n=1.33) + assert ae(brewster_angle(m1, m2), 0.93, 2) + m1 = Medium('m1', permittivity=e0, n=1) + m2 = Medium('m2', permittivity=e0, n=1.33) + assert ae(brewster_angle(m1, m2), 0.93, 2) + assert ae(brewster_angle(1, 1.33), 0.93, 2) + + +def test_critical_angle(): + m1 = Medium('m1', n=1) + m2 = Medium('m2', n=1.33) + assert ae(critical_angle(m2, m1), 0.85, 2) + + +def test_lens_makers_formula(): + n1, n2 = symbols('n1, n2') + m1 = Medium('m1', permittivity=e0, n=1) + m2 = Medium('m2', permittivity=e0, n=1.33) + assert lens_makers_formula(n1, n2, 10, -10) == 5.0*n2/(n1 - n2) + assert ae(lens_makers_formula(m1, m2, 10, -10), -20.15, 2) + assert ae(lens_makers_formula(1.33, 1, 10, -10), 15.15, 2) + + +def test_mirror_formula(): + u, v, f = symbols('u, v, f') + assert mirror_formula(focal_length=f, u=u) == f*u/(-f + u) + assert mirror_formula(focal_length=f, v=v) == f*v/(-f + v) + assert mirror_formula(u=u, v=v) == u*v/(u + v) + assert mirror_formula(u=oo, v=v) == v + assert mirror_formula(u=oo, v=oo) is oo + assert mirror_formula(focal_length=oo, u=u) == -u + assert mirror_formula(u=u, v=oo) == u + assert mirror_formula(focal_length=oo, v=oo) is oo + assert mirror_formula(focal_length=f, v=oo) == f + assert mirror_formula(focal_length=oo, v=v) == -v + assert mirror_formula(focal_length=oo, u=oo) is oo + assert mirror_formula(focal_length=f, u=oo) == f + assert mirror_formula(focal_length=oo, u=u) == -u + raises(ValueError, lambda: mirror_formula(focal_length=f, u=u, v=v)) + + +def test_lens_formula(): + u, v, f = symbols('u, v, f') + assert lens_formula(focal_length=f, u=u) == f*u/(f + u) + assert lens_formula(focal_length=f, v=v) == f*v/(f - v) + assert lens_formula(u=u, v=v) == u*v/(u - v) + assert lens_formula(u=oo, v=v) == v + assert lens_formula(u=oo, v=oo) is oo + assert lens_formula(focal_length=oo, u=u) == u + assert lens_formula(u=u, v=oo) == -u + assert lens_formula(focal_length=oo, v=oo) is -oo + assert lens_formula(focal_length=oo, v=v) == v + assert lens_formula(focal_length=f, v=oo) == -f + assert lens_formula(focal_length=oo, u=oo) is oo + assert lens_formula(focal_length=oo, u=u) == u + assert lens_formula(focal_length=f, u=oo) == f + raises(ValueError, lambda: lens_formula(focal_length=f, u=u, v=v)) + + +def test_hyperfocal_distance(): + f, N, c = symbols('f, N, c') + assert hyperfocal_distance(f=f, N=N, c=c) == f**2/(N*c) + assert ae(hyperfocal_distance(f=0.5, N=8, c=0.0033), 9.47, 2) + + +def test_transverse_magnification(): + si, so = symbols('si, so') + assert transverse_magnification(si, so) == -si/so + assert transverse_magnification(30, 15) == -2 + + +def test_lens_makers_formula_thick_lens(): + n1, n2 = symbols('n1, n2') + m1 = Medium('m1', permittivity=e0, n=1) + m2 = Medium('m2', permittivity=e0, n=1.33) + assert ae(lens_makers_formula(m1, m2, 10, -10, d=1), -19.82, 2) + assert lens_makers_formula(n1, n2, 1, -1, d=0.1) == n2/((2.0 - (0.1*n1 - 0.1*n2)/n1)*(n1 - n2)) + + +def test_lens_makers_formula_plano_lens(): + n1, n2 = symbols('n1, n2') + m1 = Medium('m1', permittivity=e0, n=1) + m2 = Medium('m2', permittivity=e0, n=1.33) + assert ae(lens_makers_formula(m1, m2, 10, oo), -40.30, 2) + assert lens_makers_formula(n1, n2, 10, oo) == 10.0*n2/(n1 - n2) diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_waves.py b/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_waves.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb8f804fb5be86d6174cb7c7b15fd8979c85ff8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/optics/tests/test_waves.py @@ -0,0 +1,82 @@ +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import (I, pi) +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan2, cos, sin) +from sympy.simplify.simplify import simplify +from sympy.abc import epsilon, mu +from sympy.functions.elementary.exponential import exp +from sympy.physics.units import speed_of_light, m, s +from sympy.physics.optics import TWave + +from sympy.testing.pytest import raises + +c = speed_of_light.convert_to(m/s) + +def test_twave(): + A1, phi1, A2, phi2, f = symbols('A1, phi1, A2, phi2, f') + n = Symbol('n') # Refractive index + t = Symbol('t') # Time + x = Symbol('x') # Spatial variable + E = Function('E') + w1 = TWave(A1, f, phi1) + w2 = TWave(A2, f, phi2) + assert w1.amplitude == A1 + assert w1.frequency == f + assert w1.phase == phi1 + assert w1.wavelength == c/(f*n) + assert w1.time_period == 1/f + assert w1.angular_velocity == 2*pi*f + assert w1.wavenumber == 2*pi*f*n/c + assert w1.speed == c/n + + w3 = w1 + w2 + assert w3.amplitude == sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + A2**2) + assert w3.frequency == f + assert w3.phase == atan2(A1*sin(phi1) + A2*sin(phi2), A1*cos(phi1) + A2*cos(phi2)) + assert w3.wavelength == c/(f*n) + assert w3.time_period == 1/f + assert w3.angular_velocity == 2*pi*f + assert w3.wavenumber == 2*pi*f*n/c + assert w3.speed == c/n + assert simplify(w3.rewrite(sin) - w2.rewrite(sin) - w1.rewrite(sin)) == 0 + assert w3.rewrite('pde') == epsilon*mu*Derivative(E(x, t), t, t) + Derivative(E(x, t), x, x) + assert w3.rewrite(cos) == sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + + A2**2)*cos(pi*f*n*x*s/(149896229*m) - 2*pi*f*t + atan2(A1*sin(phi1) + + A2*sin(phi2), A1*cos(phi1) + A2*cos(phi2))) + assert w3.rewrite(exp) == sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + + A2**2)*exp(I*(-2*pi*f*t + atan2(A1*sin(phi1) + A2*sin(phi2), A1*cos(phi1) + + A2*cos(phi2)) + pi*s*f*n*x/(149896229*m))) + + w4 = TWave(A1, None, 0, 1/f) + assert w4.frequency == f + + w5 = w1 - w2 + assert w5.amplitude == sqrt(A1**2 - 2*A1*A2*cos(phi1 - phi2) + A2**2) + assert w5.frequency == f + assert w5.phase == atan2(A1*sin(phi1) - A2*sin(phi2), A1*cos(phi1) - A2*cos(phi2)) + assert w5.wavelength == c/(f*n) + assert w5.time_period == 1/f + assert w5.angular_velocity == 2*pi*f + assert w5.wavenumber == 2*pi*f*n/c + assert w5.speed == c/n + assert simplify(w5.rewrite(sin) - w1.rewrite(sin) + w2.rewrite(sin)) == 0 + assert w5.rewrite('pde') == epsilon*mu*Derivative(E(x, t), t, t) + Derivative(E(x, t), x, x) + assert w5.rewrite(cos) == sqrt(A1**2 - 2*A1*A2*cos(phi1 - phi2) + + A2**2)*cos(-2*pi*f*t + atan2(A1*sin(phi1) - A2*sin(phi2), A1*cos(phi1) + - A2*cos(phi2)) + pi*s*f*n*x/(149896229*m)) + assert w5.rewrite(exp) == sqrt(A1**2 - 2*A1*A2*cos(phi1 - phi2) + + A2**2)*exp(I*(-2*pi*f*t + atan2(A1*sin(phi1) - A2*sin(phi2), A1*cos(phi1) + - A2*cos(phi2)) + pi*s*f*n*x/(149896229*m))) + + w6 = 2*w1 + assert w6.amplitude == 2*A1 + assert w6.frequency == f + assert w6.phase == phi1 + w7 = -w6 + assert w7.amplitude == -2*A1 + assert w7.frequency == f + assert w7.phase == phi1 + + raises(ValueError, lambda:TWave(A1)) + raises(ValueError, lambda:TWave(A1, f, phi1, t)) diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/utils.py b/phivenv/Lib/site-packages/sympy/physics/optics/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..72c3b78bd4b09eb069757fb3f8d3632f09ec4b80 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/optics/utils.py @@ -0,0 +1,698 @@ +""" +**Contains** + +* refraction_angle +* fresnel_coefficients +* deviation +* brewster_angle +* critical_angle +* lens_makers_formula +* mirror_formula +* lens_formula +* hyperfocal_distance +* transverse_magnification +""" + +__all__ = ['refraction_angle', + 'deviation', + 'fresnel_coefficients', + 'brewster_angle', + 'critical_angle', + 'lens_makers_formula', + 'mirror_formula', + 'lens_formula', + 'hyperfocal_distance', + 'transverse_magnification' + ] + +from sympy.core.numbers import (Float, I, oo, pi, zoo) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, asin, atan2, cos, sin, tan) +from sympy.matrices.dense import Matrix +from sympy.polys.polytools import cancel +from sympy.series.limits import Limit +from sympy.geometry.line import Ray3D +from sympy.geometry.util import intersection +from sympy.geometry.plane import Plane +from sympy.utilities.iterables import is_sequence +from .medium import Medium + + +def refractive_index_of_medium(medium): + """ + Helper function that returns refractive index, given a medium + """ + if isinstance(medium, Medium): + n = medium.refractive_index + else: + n = sympify(medium) + return n + + +def refraction_angle(incident, medium1, medium2, normal=None, plane=None): + """ + This function calculates transmitted vector after refraction at planar + surface. ``medium1`` and ``medium2`` can be ``Medium`` or any sympifiable object. + If ``incident`` is a number then treated as angle of incidence (in radians) + in which case refraction angle is returned. + + If ``incident`` is an object of `Ray3D`, `normal` also has to be an instance + of `Ray3D` in order to get the output as a `Ray3D`. Please note that if + plane of separation is not provided and normal is an instance of `Ray3D`, + ``normal`` will be assumed to be intersecting incident ray at the plane of + separation. This will not be the case when `normal` is a `Matrix` or + any other sequence. + If ``incident`` is an instance of `Ray3D` and `plane` has not been provided + and ``normal`` is not `Ray3D`, output will be a `Matrix`. + + Parameters + ========== + + incident : Matrix, Ray3D, sequence or a number + Incident vector or angle of incidence + medium1 : sympy.physics.optics.medium.Medium or sympifiable + Medium 1 or its refractive index + medium2 : sympy.physics.optics.medium.Medium or sympifiable + Medium 2 or its refractive index + normal : Matrix, Ray3D, or sequence + Normal vector + plane : Plane + Plane of separation of the two media. + + Returns + ======= + + Returns an angle of refraction or a refracted ray depending on inputs. + + Examples + ======== + + >>> from sympy.physics.optics import refraction_angle + >>> from sympy.geometry import Point3D, Ray3D, Plane + >>> from sympy.matrices import Matrix + >>> from sympy import symbols, pi + >>> n = Matrix([0, 0, 1]) + >>> P = Plane(Point3D(0, 0, 0), normal_vector=[0, 0, 1]) + >>> r1 = Ray3D(Point3D(-1, -1, 1), Point3D(0, 0, 0)) + >>> refraction_angle(r1, 1, 1, n) + Matrix([ + [ 1], + [ 1], + [-1]]) + >>> refraction_angle(r1, 1, 1, plane=P) + Ray3D(Point3D(0, 0, 0), Point3D(1, 1, -1)) + + With different index of refraction of the two media + + >>> n1, n2 = symbols('n1, n2') + >>> refraction_angle(r1, n1, n2, n) + Matrix([ + [ n1/n2], + [ n1/n2], + [-sqrt(3)*sqrt(-2*n1**2/(3*n2**2) + 1)]]) + >>> refraction_angle(r1, n1, n2, plane=P) + Ray3D(Point3D(0, 0, 0), Point3D(n1/n2, n1/n2, -sqrt(3)*sqrt(-2*n1**2/(3*n2**2) + 1))) + >>> round(refraction_angle(pi/6, 1.2, 1.5), 5) + 0.41152 + """ + + n1 = refractive_index_of_medium(medium1) + n2 = refractive_index_of_medium(medium2) + + # check if an incidence angle was supplied instead of a ray + try: + angle_of_incidence = float(incident) + except TypeError: + angle_of_incidence = None + + try: + critical_angle_ = critical_angle(medium1, medium2) + except (ValueError, TypeError): + critical_angle_ = None + + if angle_of_incidence is not None: + if normal is not None or plane is not None: + raise ValueError('Normal/plane not allowed if incident is an angle') + + if not 0.0 <= angle_of_incidence < pi*0.5: + raise ValueError('Angle of incidence not in range [0:pi/2)') + + if critical_angle_ and angle_of_incidence > critical_angle_: + raise ValueError('Ray undergoes total internal reflection') + return asin(n1*sin(angle_of_incidence)/n2) + + # Treat the incident as ray below + # A flag to check whether to return Ray3D or not + return_ray = False + + if plane is not None and normal is not None: + raise ValueError("Either plane or normal is acceptable.") + + if not isinstance(incident, Matrix): + if is_sequence(incident): + _incident = Matrix(incident) + elif isinstance(incident, Ray3D): + _incident = Matrix(incident.direction_ratio) + else: + raise TypeError( + "incident should be a Matrix, Ray3D, or sequence") + else: + _incident = incident + + # If plane is provided, get direction ratios of the normal + # to the plane from the plane else go with `normal` param. + if plane is not None: + if not isinstance(plane, Plane): + raise TypeError("plane should be an instance of geometry.plane.Plane") + # If we have the plane, we can get the intersection + # point of incident ray and the plane and thus return + # an instance of Ray3D. + if isinstance(incident, Ray3D): + return_ray = True + intersection_pt = plane.intersection(incident)[0] + _normal = Matrix(plane.normal_vector) + else: + if not isinstance(normal, Matrix): + if is_sequence(normal): + _normal = Matrix(normal) + elif isinstance(normal, Ray3D): + _normal = Matrix(normal.direction_ratio) + if isinstance(incident, Ray3D): + intersection_pt = intersection(incident, normal) + if len(intersection_pt) == 0: + raise ValueError( + "Normal isn't concurrent with the incident ray.") + else: + return_ray = True + intersection_pt = intersection_pt[0] + else: + raise TypeError( + "Normal should be a Matrix, Ray3D, or sequence") + else: + _normal = normal + + eta = n1/n2 # Relative index of refraction + # Calculating magnitude of the vectors + mag_incident = sqrt(sum(i**2 for i in _incident)) + mag_normal = sqrt(sum(i**2 for i in _normal)) + # Converting vectors to unit vectors by dividing + # them with their magnitudes + _incident /= mag_incident + _normal /= mag_normal + c1 = -_incident.dot(_normal) # cos(angle_of_incidence) + cs2 = 1 - eta**2*(1 - c1**2) # cos(angle_of_refraction)**2 + if cs2.is_negative: # This is the case of total internal reflection(TIR). + return S.Zero + drs = eta*_incident + (eta*c1 - sqrt(cs2))*_normal + # Multiplying unit vector by its magnitude + drs = drs*mag_incident + if not return_ray: + return drs + else: + return Ray3D(intersection_pt, direction_ratio=drs) + + +def fresnel_coefficients(angle_of_incidence, medium1, medium2): + """ + This function uses Fresnel equations to calculate reflection and + transmission coefficients. Those are obtained for both polarisations + when the electric field vector is in the plane of incidence (labelled 'p') + and when the electric field vector is perpendicular to the plane of + incidence (labelled 's'). There are four real coefficients unless the + incident ray reflects in total internal in which case there are two complex + ones. Angle of incidence is the angle between the incident ray and the + surface normal. ``medium1`` and ``medium2`` can be ``Medium`` or any + sympifiable object. + + Parameters + ========== + + angle_of_incidence : sympifiable + + medium1 : Medium or sympifiable + Medium 1 or its refractive index + + medium2 : Medium or sympifiable + Medium 2 or its refractive index + + Returns + ======= + + Returns a list with four real Fresnel coefficients: + [reflection p (TM), reflection s (TE), + transmission p (TM), transmission s (TE)] + If the ray is undergoes total internal reflection then returns a + list of two complex Fresnel coefficients: + [reflection p (TM), reflection s (TE)] + + Examples + ======== + + >>> from sympy.physics.optics import fresnel_coefficients + >>> fresnel_coefficients(0.3, 1, 2) + [0.317843553417859, -0.348645229818821, + 0.658921776708929, 0.651354770181179] + >>> fresnel_coefficients(0.6, 2, 1) + [-0.235625382192159 - 0.971843958291041*I, + 0.816477005968898 - 0.577377951366403*I] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fresnel_equations + """ + if not 0 <= 2*angle_of_incidence < pi: + raise ValueError('Angle of incidence not in range [0:pi/2)') + + n1 = refractive_index_of_medium(medium1) + n2 = refractive_index_of_medium(medium2) + + angle_of_refraction = asin(n1*sin(angle_of_incidence)/n2) + try: + angle_of_total_internal_reflection_onset = critical_angle(n1, n2) + except ValueError: + angle_of_total_internal_reflection_onset = None + + if angle_of_total_internal_reflection_onset is None or\ + angle_of_total_internal_reflection_onset > angle_of_incidence: + R_s = -sin(angle_of_incidence - angle_of_refraction)\ + /sin(angle_of_incidence + angle_of_refraction) + R_p = tan(angle_of_incidence - angle_of_refraction)\ + /tan(angle_of_incidence + angle_of_refraction) + T_s = 2*sin(angle_of_refraction)*cos(angle_of_incidence)\ + /sin(angle_of_incidence + angle_of_refraction) + T_p = 2*sin(angle_of_refraction)*cos(angle_of_incidence)\ + /(sin(angle_of_incidence + angle_of_refraction)\ + *cos(angle_of_incidence - angle_of_refraction)) + return [R_p, R_s, T_p, T_s] + else: + n = n2/n1 + R_s = cancel((cos(angle_of_incidence)-\ + I*sqrt(sin(angle_of_incidence)**2 - n**2))\ + /(cos(angle_of_incidence)+\ + I*sqrt(sin(angle_of_incidence)**2 - n**2))) + R_p = cancel((n**2*cos(angle_of_incidence)-\ + I*sqrt(sin(angle_of_incidence)**2 - n**2))\ + /(n**2*cos(angle_of_incidence)+\ + I*sqrt(sin(angle_of_incidence)**2 - n**2))) + return [R_p, R_s] + + +def deviation(incident, medium1, medium2, normal=None, plane=None): + """ + This function calculates the angle of deviation of a ray + due to refraction at planar surface. + + Parameters + ========== + + incident : Matrix, Ray3D, sequence or float + Incident vector or angle of incidence + medium1 : sympy.physics.optics.medium.Medium or sympifiable + Medium 1 or its refractive index + medium2 : sympy.physics.optics.medium.Medium or sympifiable + Medium 2 or its refractive index + normal : Matrix, Ray3D, or sequence + Normal vector + plane : Plane + Plane of separation of the two media. + + Returns angular deviation between incident and refracted rays + + Examples + ======== + + >>> from sympy.physics.optics import deviation + >>> from sympy.geometry import Point3D, Ray3D, Plane + >>> from sympy.matrices import Matrix + >>> from sympy import symbols + >>> n1, n2 = symbols('n1, n2') + >>> n = Matrix([0, 0, 1]) + >>> P = Plane(Point3D(0, 0, 0), normal_vector=[0, 0, 1]) + >>> r1 = Ray3D(Point3D(-1, -1, 1), Point3D(0, 0, 0)) + >>> deviation(r1, 1, 1, n) + 0 + >>> deviation(r1, n1, n2, plane=P) + -acos(-sqrt(-2*n1**2/(3*n2**2) + 1)) + acos(-sqrt(3)/3) + >>> round(deviation(0.1, 1.2, 1.5), 5) + -0.02005 + """ + refracted = refraction_angle(incident, + medium1, + medium2, + normal=normal, + plane=plane) + try: + angle_of_incidence = Float(incident) + except TypeError: + angle_of_incidence = None + + if angle_of_incidence is not None: + return float(refracted) - angle_of_incidence + + if refracted != 0: + if isinstance(refracted, Ray3D): + refracted = Matrix(refracted.direction_ratio) + + if not isinstance(incident, Matrix): + if is_sequence(incident): + _incident = Matrix(incident) + elif isinstance(incident, Ray3D): + _incident = Matrix(incident.direction_ratio) + else: + raise TypeError( + "incident should be a Matrix, Ray3D, or sequence") + else: + _incident = incident + + if plane is None: + if not isinstance(normal, Matrix): + if is_sequence(normal): + _normal = Matrix(normal) + elif isinstance(normal, Ray3D): + _normal = Matrix(normal.direction_ratio) + else: + raise TypeError( + "normal should be a Matrix, Ray3D, or sequence") + else: + _normal = normal + else: + _normal = Matrix(plane.normal_vector) + + mag_incident = sqrt(sum(i**2 for i in _incident)) + mag_normal = sqrt(sum(i**2 for i in _normal)) + mag_refracted = sqrt(sum(i**2 for i in refracted)) + _incident /= mag_incident + _normal /= mag_normal + refracted /= mag_refracted + i = acos(_incident.dot(_normal)) + r = acos(refracted.dot(_normal)) + return i - r + + +def brewster_angle(medium1, medium2): + """ + This function calculates the Brewster's angle of incidence to Medium 2 from + Medium 1 in radians. + + Parameters + ========== + + medium 1 : Medium or sympifiable + Refractive index of Medium 1 + medium 2 : Medium or sympifiable + Refractive index of Medium 1 + + Examples + ======== + + >>> from sympy.physics.optics import brewster_angle + >>> brewster_angle(1, 1.33) + 0.926093295503462 + + """ + + n1 = refractive_index_of_medium(medium1) + n2 = refractive_index_of_medium(medium2) + + return atan2(n2, n1) + +def critical_angle(medium1, medium2): + """ + This function calculates the critical angle of incidence (marking the onset + of total internal) to Medium 2 from Medium 1 in radians. + + Parameters + ========== + + medium 1 : Medium or sympifiable + Refractive index of Medium 1. + medium 2 : Medium or sympifiable + Refractive index of Medium 1. + + Examples + ======== + + >>> from sympy.physics.optics import critical_angle + >>> critical_angle(1.33, 1) + 0.850908514477849 + + """ + + n1 = refractive_index_of_medium(medium1) + n2 = refractive_index_of_medium(medium2) + + if n2 > n1: + raise ValueError('Total internal reflection impossible for n1 < n2') + else: + return asin(n2/n1) + + + +def lens_makers_formula(n_lens, n_surr, r1, r2, d=0): + """ + This function calculates focal length of a lens. + It follows cartesian sign convention. + + Parameters + ========== + + n_lens : Medium or sympifiable + Index of refraction of lens. + n_surr : Medium or sympifiable + Index of reflection of surrounding. + r1 : sympifiable + Radius of curvature of first surface. + r2 : sympifiable + Radius of curvature of second surface. + d : sympifiable, optional + Thickness of lens, default value is 0. + + Examples + ======== + + >>> from sympy.physics.optics import lens_makers_formula + >>> from sympy import S + >>> lens_makers_formula(1.33, 1, 10, -10) + 15.1515151515151 + >>> lens_makers_formula(1.2, 1, 10, S.Infinity) + 50.0000000000000 + >>> lens_makers_formula(1.33, 1, 10, -10, d=1) + 15.3418463277618 + + """ + + if isinstance(n_lens, Medium): + n_lens = n_lens.refractive_index + else: + n_lens = sympify(n_lens) + if isinstance(n_surr, Medium): + n_surr = n_surr.refractive_index + else: + n_surr = sympify(n_surr) + d = sympify(d) + + focal_length = 1/((n_lens - n_surr) / n_surr*(1/r1 - 1/r2 + (((n_lens - n_surr) * d) / (n_lens * r1 * r2)))) + + if focal_length == zoo: + return S.Infinity + return focal_length + + +def mirror_formula(focal_length=None, u=None, v=None): + """ + This function provides one of the three parameters + when two of them are supplied. + This is valid only for paraxial rays. + + Parameters + ========== + + focal_length : sympifiable + Focal length of the mirror. + u : sympifiable + Distance of object from the pole on + the principal axis. + v : sympifiable + Distance of the image from the pole + on the principal axis. + + Examples + ======== + + >>> from sympy.physics.optics import mirror_formula + >>> from sympy.abc import f, u, v + >>> mirror_formula(focal_length=f, u=u) + f*u/(-f + u) + >>> mirror_formula(focal_length=f, v=v) + f*v/(-f + v) + >>> mirror_formula(u=u, v=v) + u*v/(u + v) + + """ + if focal_length and u and v: + raise ValueError("Please provide only two parameters") + + focal_length = sympify(focal_length) + u = sympify(u) + v = sympify(v) + if u is oo: + _u = Symbol('u') + if v is oo: + _v = Symbol('v') + if focal_length is oo: + _f = Symbol('f') + if focal_length is None: + if u is oo and v is oo: + return Limit(Limit(_v*_u/(_v + _u), _u, oo), _v, oo).doit() + if u is oo: + return Limit(v*_u/(v + _u), _u, oo).doit() + if v is oo: + return Limit(_v*u/(_v + u), _v, oo).doit() + return v*u/(v + u) + if u is None: + if v is oo and focal_length is oo: + return Limit(Limit(_v*_f/(_v - _f), _v, oo), _f, oo).doit() + if v is oo: + return Limit(_v*focal_length/(_v - focal_length), _v, oo).doit() + if focal_length is oo: + return Limit(v*_f/(v - _f), _f, oo).doit() + return v*focal_length/(v - focal_length) + if v is None: + if u is oo and focal_length is oo: + return Limit(Limit(_u*_f/(_u - _f), _u, oo), _f, oo).doit() + if u is oo: + return Limit(_u*focal_length/(_u - focal_length), _u, oo).doit() + if focal_length is oo: + return Limit(u*_f/(u - _f), _f, oo).doit() + return u*focal_length/(u - focal_length) + + +def lens_formula(focal_length=None, u=None, v=None): + """ + This function provides one of the three parameters + when two of them are supplied. + This is valid only for paraxial rays. + + Parameters + ========== + + focal_length : sympifiable + Focal length of the mirror. + u : sympifiable + Distance of object from the optical center on + the principal axis. + v : sympifiable + Distance of the image from the optical center + on the principal axis. + + Examples + ======== + + >>> from sympy.physics.optics import lens_formula + >>> from sympy.abc import f, u, v + >>> lens_formula(focal_length=f, u=u) + f*u/(f + u) + >>> lens_formula(focal_length=f, v=v) + f*v/(f - v) + >>> lens_formula(u=u, v=v) + u*v/(u - v) + + """ + if focal_length and u and v: + raise ValueError("Please provide only two parameters") + + focal_length = sympify(focal_length) + u = sympify(u) + v = sympify(v) + if u is oo: + _u = Symbol('u') + if v is oo: + _v = Symbol('v') + if focal_length is oo: + _f = Symbol('f') + if focal_length is None: + if u is oo and v is oo: + return Limit(Limit(_v*_u/(_u - _v), _u, oo), _v, oo).doit() + if u is oo: + return Limit(v*_u/(_u - v), _u, oo).doit() + if v is oo: + return Limit(_v*u/(u - _v), _v, oo).doit() + return v*u/(u - v) + if u is None: + if v is oo and focal_length is oo: + return Limit(Limit(_v*_f/(_f - _v), _v, oo), _f, oo).doit() + if v is oo: + return Limit(_v*focal_length/(focal_length - _v), _v, oo).doit() + if focal_length is oo: + return Limit(v*_f/(_f - v), _f, oo).doit() + return v*focal_length/(focal_length - v) + if v is None: + if u is oo and focal_length is oo: + return Limit(Limit(_u*_f/(_u + _f), _u, oo), _f, oo).doit() + if u is oo: + return Limit(_u*focal_length/(_u + focal_length), _u, oo).doit() + if focal_length is oo: + return Limit(u*_f/(u + _f), _f, oo).doit() + return u*focal_length/(u + focal_length) + +def hyperfocal_distance(f, N, c): + """ + + Parameters + ========== + + f: sympifiable + Focal length of a given lens. + + N: sympifiable + F-number of a given lens. + + c: sympifiable + Circle of Confusion (CoC) of a given image format. + + Example + ======= + + >>> from sympy.physics.optics import hyperfocal_distance + >>> round(hyperfocal_distance(f = 0.5, N = 8, c = 0.0033), 2) + 9.47 + """ + + f = sympify(f) + N = sympify(N) + c = sympify(c) + + return (1/(N * c))*(f**2) + +def transverse_magnification(si, so): + """ + + Calculates the transverse magnification upon reflection in a mirror, + which is the ratio of the image size to the object size. + + Parameters + ========== + + so: sympifiable + Lens-object distance. + + si: sympifiable + Lens-image distance. + + Example + ======= + + >>> from sympy.physics.optics import transverse_magnification + >>> transverse_magnification(30, 15) + -2 + + """ + + si = sympify(si) + so = sympify(so) + + return (-(si/so)) diff --git a/phivenv/Lib/site-packages/sympy/physics/optics/waves.py b/phivenv/Lib/site-packages/sympy/physics/optics/waves.py new file mode 100644 index 0000000000000000000000000000000000000000..61e2ff4db578543f9f2694f239f03439bfab2c41 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/optics/waves.py @@ -0,0 +1,340 @@ +""" +This module has all the classes and functions related to waves in optics. + +**Contains** + +* TWave +""" + +__all__ = ['TWave'] + +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.function import Derivative, Function +from sympy.core.numbers import (Number, pi, I) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import _sympify, sympify +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan2, cos, sin) +from sympy.physics.units import speed_of_light, meter, second + + +c = speed_of_light.convert_to(meter/second) + + +class TWave(Expr): + + r""" + This is a simple transverse sine wave travelling in a one-dimensional space. + Basic properties are required at the time of creation of the object, + but they can be changed later with respective methods provided. + + Explanation + =========== + + It is represented as :math:`A \times cos(k*x - \omega \times t + \phi )`, + where :math:`A` is the amplitude, :math:`\omega` is the angular frequency, + :math:`k` is the wavenumber (spatial frequency), :math:`x` is a spatial variable + to represent the position on the dimension on which the wave propagates, + and :math:`\phi` is the phase angle of the wave. + + + Arguments + ========= + + amplitude : Sympifyable + Amplitude of the wave. + frequency : Sympifyable + Frequency of the wave. + phase : Sympifyable + Phase angle of the wave. + time_period : Sympifyable + Time period of the wave. + n : Sympifyable + Refractive index of the medium. + + Raises + ======= + + ValueError : When neither frequency nor time period is provided + or they are not consistent. + TypeError : When anything other than TWave objects is added. + + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A1, phi1, A2, phi2, f = symbols('A1, phi1, A2, phi2, f') + >>> w1 = TWave(A1, f, phi1) + >>> w2 = TWave(A2, f, phi2) + >>> w3 = w1 + w2 # Superposition of two waves + >>> w3 + TWave(sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + A2**2), f, + atan2(A1*sin(phi1) + A2*sin(phi2), A1*cos(phi1) + A2*cos(phi2)), 1/f, n) + >>> w3.amplitude + sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + A2**2) + >>> w3.phase + atan2(A1*sin(phi1) + A2*sin(phi2), A1*cos(phi1) + A2*cos(phi2)) + >>> w3.speed + 299792458*meter/(second*n) + >>> w3.angular_velocity + 2*pi*f + + """ + + def __new__( + cls, + amplitude, + frequency=None, + phase=S.Zero, + time_period=None, + n=Symbol('n')): + if time_period is not None: + time_period = _sympify(time_period) + _frequency = S.One/time_period + if frequency is not None: + frequency = _sympify(frequency) + _time_period = S.One/frequency + if time_period is not None: + if frequency != S.One/time_period: + raise ValueError("frequency and time_period should be consistent.") + if frequency is None and time_period is None: + raise ValueError("Either frequency or time period is needed.") + if frequency is None: + frequency = _frequency + if time_period is None: + time_period = _time_period + + amplitude = _sympify(amplitude) + phase = _sympify(phase) + n = sympify(n) + obj = Basic.__new__(cls, amplitude, frequency, phase, time_period, n) + return obj + + @property + def amplitude(self): + """ + Returns the amplitude of the wave. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.amplitude + A + """ + return self.args[0] + + @property + def frequency(self): + """ + Returns the frequency of the wave, + in cycles per second. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.frequency + f + """ + return self.args[1] + + @property + def phase(self): + """ + Returns the phase angle of the wave, + in radians. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.phase + phi + """ + return self.args[2] + + @property + def time_period(self): + """ + Returns the temporal period of the wave, + in seconds per cycle. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.time_period + 1/f + """ + return self.args[3] + + @property + def n(self): + """ + Returns the refractive index of the medium + """ + return self.args[4] + + @property + def wavelength(self): + """ + Returns the wavelength (spatial period) of the wave, + in meters per cycle. + It depends on the medium of the wave. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.wavelength + 299792458*meter/(second*f*n) + """ + return c/(self.frequency*self.n) + + + @property + def speed(self): + """ + Returns the propagation speed of the wave, + in meters per second. + It is dependent on the propagation medium. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.speed + 299792458*meter/(second*n) + """ + return self.wavelength*self.frequency + + @property + def angular_velocity(self): + """ + Returns the angular velocity of the wave, + in radians per second. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.angular_velocity + 2*pi*f + """ + return 2*pi*self.frequency + + @property + def wavenumber(self): + """ + Returns the wavenumber of the wave, + in radians per meter. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.wavenumber + pi*second*f*n/(149896229*meter) + """ + return 2*pi/self.wavelength + + def __str__(self): + """String representation of a TWave.""" + from sympy.printing import sstr + return type(self).__name__ + sstr(self.args) + + __repr__ = __str__ + + def __add__(self, other): + """ + Addition of two waves will result in their superposition. + The type of interference will depend on their phase angles. + """ + if isinstance(other, TWave): + if self.frequency == other.frequency and self.wavelength == other.wavelength: + return TWave(sqrt(self.amplitude**2 + other.amplitude**2 + 2 * + self.amplitude*other.amplitude*cos( + self.phase - other.phase)), + self.frequency, + atan2(self.amplitude*sin(self.phase) + + other.amplitude*sin(other.phase), + self.amplitude*cos(self.phase) + + other.amplitude*cos(other.phase)) + ) + else: + raise NotImplementedError("Interference of waves with different frequencies" + " has not been implemented.") + else: + raise TypeError(type(other).__name__ + " and TWave objects cannot be added.") + + def __mul__(self, other): + """ + Multiplying a wave by a scalar rescales the amplitude of the wave. + """ + other = sympify(other) + if isinstance(other, Number): + return TWave(self.amplitude*other, *self.args[1:]) + else: + raise TypeError(type(other).__name__ + " and TWave objects cannot be multiplied.") + + def __sub__(self, other): + return self.__add__(-1*other) + + def __neg__(self): + return self.__mul__(-1) + + def __radd__(self, other): + return self.__add__(other) + + def __rmul__(self, other): + return self.__mul__(other) + + def __rsub__(self, other): + return (-self).__radd__(other) + + def _eval_rewrite_as_sin(self, *args, **kwargs): + return self.amplitude*sin(self.wavenumber*Symbol('x') + - self.angular_velocity*Symbol('t') + self.phase + pi/2, evaluate=False) + + def _eval_rewrite_as_cos(self, *args, **kwargs): + return self.amplitude*cos(self.wavenumber*Symbol('x') + - self.angular_velocity*Symbol('t') + self.phase) + + def _eval_rewrite_as_pde(self, *args, **kwargs): + mu, epsilon, x, t = symbols('mu, epsilon, x, t') + E = Function('E') + return Derivative(E(x, t), x, 2) + mu*epsilon*Derivative(E(x, t), t, 2) + + def _eval_rewrite_as_exp(self, *args, **kwargs): + return self.amplitude*exp(I*(self.wavenumber*Symbol('x') + - self.angular_velocity*Symbol('t') + self.phase)) diff --git a/phivenv/Lib/site-packages/sympy/physics/paulialgebra.py b/phivenv/Lib/site-packages/sympy/physics/paulialgebra.py new file mode 100644 index 0000000000000000000000000000000000000000..300957354ff34907035aa1d1a48b00276230a1e5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/paulialgebra.py @@ -0,0 +1,231 @@ +""" +This module implements Pauli algebra by subclassing Symbol. Only algebraic +properties of Pauli matrices are used (we do not use the Matrix class). + +See the documentation to the class Pauli for examples. + +References +========== + +.. [1] https://en.wikipedia.org/wiki/Pauli_matrices +""" + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.core.symbol import Symbol +from sympy.physics.quantum import TensorProduct + +__all__ = ['evaluate_pauli_product'] + + +def delta(i, j): + """ + Returns 1 if ``i == j``, else 0. + + This is used in the multiplication of Pauli matrices. + + Examples + ======== + + >>> from sympy.physics.paulialgebra import delta + >>> delta(1, 1) + 1 + >>> delta(2, 3) + 0 + """ + if i == j: + return 1 + else: + return 0 + + +def epsilon(i, j, k): + """ + Return 1 if i,j,k is equal to (1,2,3), (2,3,1), or (3,1,2); + -1 if ``i``,``j``,``k`` is equal to (1,3,2), (3,2,1), or (2,1,3); + else return 0. + + This is used in the multiplication of Pauli matrices. + + Examples + ======== + + >>> from sympy.physics.paulialgebra import epsilon + >>> epsilon(1, 2, 3) + 1 + >>> epsilon(1, 3, 2) + -1 + """ + if (i, j, k) in ((1, 2, 3), (2, 3, 1), (3, 1, 2)): + return 1 + elif (i, j, k) in ((1, 3, 2), (3, 2, 1), (2, 1, 3)): + return -1 + else: + return 0 + + +class Pauli(Symbol): + """ + The class representing algebraic properties of Pauli matrices. + + Explanation + =========== + + The symbol used to display the Pauli matrices can be changed with an + optional parameter ``label="sigma"``. Pauli matrices with different + ``label`` attributes cannot multiply together. + + If the left multiplication of symbol or number with Pauli matrix is needed, + please use parentheses to separate Pauli and symbolic multiplication + (for example: 2*I*(Pauli(3)*Pauli(2))). + + Another variant is to use evaluate_pauli_product function to evaluate + the product of Pauli matrices and other symbols (with commutative + multiply rules). + + See Also + ======== + + evaluate_pauli_product + + Examples + ======== + + >>> from sympy.physics.paulialgebra import Pauli + >>> Pauli(1) + sigma1 + >>> Pauli(1)*Pauli(2) + I*sigma3 + >>> Pauli(1)*Pauli(1) + 1 + >>> Pauli(3)**4 + 1 + >>> Pauli(1)*Pauli(2)*Pauli(3) + I + + >>> from sympy.physics.paulialgebra import Pauli + >>> Pauli(1, label="tau") + tau1 + >>> Pauli(1)*Pauli(2, label="tau") + sigma1*tau2 + >>> Pauli(1, label="tau")*Pauli(2, label="tau") + I*tau3 + + >>> from sympy import I + >>> I*(Pauli(2)*Pauli(3)) + -sigma1 + + >>> from sympy.physics.paulialgebra import evaluate_pauli_product + >>> f = I*Pauli(2)*Pauli(3) + >>> f + I*sigma2*sigma3 + >>> evaluate_pauli_product(f) + -sigma1 + """ + + __slots__ = ("i", "label") + + def __new__(cls, i, label="sigma"): + if i not in [1, 2, 3]: + raise IndexError("Invalid Pauli index") + obj = Symbol.__new__(cls, "%s%d" %(label,i), commutative=False, hermitian=True) + obj.i = i + obj.label = label + return obj + + def __getnewargs_ex__(self): + return (self.i, self.label), {} + + def _hashable_content(self): + return (self.i, self.label) + + # FIXME don't work for -I*Pauli(2)*Pauli(3) + def __mul__(self, other): + if isinstance(other, Pauli): + j = self.i + k = other.i + jlab = self.label + klab = other.label + + if jlab == klab: + return delta(j, k) \ + + I*epsilon(j, k, 1)*Pauli(1,jlab) \ + + I*epsilon(j, k, 2)*Pauli(2,jlab) \ + + I*epsilon(j, k, 3)*Pauli(3,jlab) + return super().__mul__(other) + + def _eval_power(b, e): + if e.is_Integer and e.is_positive: + return super().__pow__(int(e) % 2) + + +def evaluate_pauli_product(arg): + '''Help function to evaluate Pauli matrices product + with symbolic objects. + + Parameters + ========== + + arg: symbolic expression that contains Paulimatrices + + Examples + ======== + + >>> from sympy.physics.paulialgebra import Pauli, evaluate_pauli_product + >>> from sympy import I + >>> evaluate_pauli_product(I*Pauli(1)*Pauli(2)) + -sigma3 + + >>> from sympy.abc import x + >>> evaluate_pauli_product(x**2*Pauli(2)*Pauli(1)) + -I*x**2*sigma3 + ''' + start = arg + end = arg + + if isinstance(arg, Pow) and isinstance(arg.args[0], Pauli): + if arg.args[1].is_odd: + return arg.args[0] + else: + return 1 + + if isinstance(arg, Add): + return Add(*[evaluate_pauli_product(part) for part in arg.args]) + + if isinstance(arg, TensorProduct): + return TensorProduct(*[evaluate_pauli_product(part) for part in arg.args]) + + elif not(isinstance(arg, Mul)): + return arg + + while not start == end or start == arg and end == arg: + start = end + + tmp = start.as_coeff_mul() + sigma_product = 1 + com_product = 1 + keeper = 1 + + for el in tmp[1]: + if isinstance(el, Pauli): + sigma_product *= el + elif not el.is_commutative: + if isinstance(el, Pow) and isinstance(el.args[0], Pauli): + if el.args[1].is_odd: + sigma_product *= el.args[0] + elif isinstance(el, TensorProduct): + keeper = keeper*sigma_product*\ + TensorProduct( + *[evaluate_pauli_product(part) for part in el.args] + ) + sigma_product = 1 + else: + keeper = keeper*sigma_product*el + sigma_product = 1 + else: + com_product *= el + end = tmp[0]*keeper*sigma_product*com_product + if end == arg: break + return end diff --git a/phivenv/Lib/site-packages/sympy/physics/pring.py b/phivenv/Lib/site-packages/sympy/physics/pring.py new file mode 100644 index 0000000000000000000000000000000000000000..325f4ff98a8c9fc428b4e332153af533f4d199ca --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/pring.py @@ -0,0 +1,94 @@ +from sympy.core.numbers import (I, pi) +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.quantum.constants import hbar + + +def wavefunction(n, x): + """ + Returns the wavefunction for particle on ring. + + Parameters + ========== + + n : The quantum number. + Here ``n`` can be positive as well as negative + which can be used to describe the direction of motion of particle. + x : + The angle. + + Examples + ======== + + >>> from sympy.physics.pring import wavefunction + >>> from sympy import Symbol, integrate, pi + >>> x=Symbol("x") + >>> wavefunction(1, x) + sqrt(2)*exp(I*x)/(2*sqrt(pi)) + >>> wavefunction(2, x) + sqrt(2)*exp(2*I*x)/(2*sqrt(pi)) + >>> wavefunction(3, x) + sqrt(2)*exp(3*I*x)/(2*sqrt(pi)) + + The normalization of the wavefunction is: + + >>> integrate(wavefunction(2, x)*wavefunction(-2, x), (x, 0, 2*pi)) + 1 + >>> integrate(wavefunction(4, x)*wavefunction(-4, x), (x, 0, 2*pi)) + 1 + + References + ========== + + .. [1] Atkins, Peter W.; Friedman, Ronald (2005). Molecular Quantum + Mechanics (4th ed.). Pages 71-73. + + """ + # sympify arguments + n, x = S(n), S(x) + return exp(n * I * x) / sqrt(2 * pi) + + +def energy(n, m, r): + """ + Returns the energy of the state corresponding to quantum number ``n``. + + E=(n**2 * (hcross)**2) / (2 * m * r**2) + + Parameters + ========== + + n : + The quantum number. + m : + Mass of the particle. + r : + Radius of circle. + + Examples + ======== + + >>> from sympy.physics.pring import energy + >>> from sympy import Symbol + >>> m=Symbol("m") + >>> r=Symbol("r") + >>> energy(1, m, r) + hbar**2/(2*m*r**2) + >>> energy(2, m, r) + 2*hbar**2/(m*r**2) + >>> energy(-2, 2.0, 3.0) + 0.111111111111111*hbar**2 + + References + ========== + + .. [1] Atkins, Peter W.; Friedman, Ronald (2005). Molecular Quantum + Mechanics (4th ed.). Pages 71-73. + + """ + n, m, r = S(n), S(m), S(r) + if n.is_integer: + return (n**2 * hbar**2) / (2 * m * r**2) + else: + raise ValueError("'n' must be integer") diff --git a/phivenv/Lib/site-packages/sympy/physics/qho_1d.py b/phivenv/Lib/site-packages/sympy/physics/qho_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..f418e0e954656923fbfa64cea2145581ddf65aea --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/qho_1d.py @@ -0,0 +1,88 @@ +from sympy.core import S, pi, Rational +from sympy.functions import hermite, sqrt, exp, factorial, Abs +from sympy.physics.quantum.constants import hbar + + +def psi_n(n, x, m, omega): + """ + Returns the wavefunction psi_{n} for the One-dimensional harmonic oscillator. + + Parameters + ========== + + n : + the "nodal" quantum number. Corresponds to the number of nodes in the + wavefunction. ``n >= 0`` + x : + x coordinate. + m : + Mass of the particle. + omega : + Angular frequency of the oscillator. + + Examples + ======== + + >>> from sympy.physics.qho_1d import psi_n + >>> from sympy.abc import m, x, omega + >>> psi_n(0, x, m, omega) + (m*omega)**(1/4)*exp(-m*omega*x**2/(2*hbar))/(hbar**(1/4)*pi**(1/4)) + + """ + + # sympify arguments + n, x, m, omega = map(S, [n, x, m, omega]) + nu = m * omega / hbar + # normalization coefficient + C = (nu/pi)**Rational(1, 4) * sqrt(1/(2**n*factorial(n))) + + return C * exp(-nu* x**2 /2) * hermite(n, sqrt(nu)*x) + + +def E_n(n, omega): + """ + Returns the Energy of the One-dimensional harmonic oscillator. + + Parameters + ========== + + n : + The "nodal" quantum number. + omega : + The harmonic oscillator angular frequency. + + Notes + ===== + + The unit of the returned value matches the unit of hw, since the energy is + calculated as: + + E_n = hbar * omega*(n + 1/2) + + Examples + ======== + + >>> from sympy.physics.qho_1d import E_n + >>> from sympy.abc import x, omega + >>> E_n(x, omega) + hbar*omega*(x + 1/2) + """ + + return hbar * omega * (n + S.Half) + + +def coherent_state(n, alpha): + """ + Returns for the coherent states of 1D harmonic oscillator. + See https://en.wikipedia.org/wiki/Coherent_states + + Parameters + ========== + + n : + The "nodal" quantum number. + alpha : + The eigen value of annihilation operator. + """ + + return exp(- Abs(alpha)**2/2)*(alpha**n)/sqrt(factorial(n)) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__init__.py b/phivenv/Lib/site-packages/sympy/physics/quantum/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36203f1a48c4c53832ce44942878ddc7b89f8091 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/__init__.py @@ -0,0 +1,65 @@ +# Names exposed by 'from sympy.physics.quantum import *' + +__all__ = [ + 'AntiCommutator', + + 'qapply', + + 'Commutator', + + 'Dagger', + + 'HilbertSpaceError', 'HilbertSpace', 'TensorProductHilbertSpace', + 'TensorPowerHilbertSpace', 'DirectSumHilbertSpace', 'ComplexSpace', 'L2', + 'FockSpace', + + 'InnerProduct', + + 'Operator', 'HermitianOperator', 'UnitaryOperator', 'IdentityOperator', + 'OuterProduct', 'DifferentialOperator', + + 'represent', 'rep_innerproduct', 'rep_expectation', 'integrate_result', + 'get_basis', 'enumerate_states', + + 'KetBase', 'BraBase', 'StateBase', 'State', 'Ket', 'Bra', 'TimeDepState', + 'TimeDepBra', 'TimeDepKet', 'OrthogonalKet', 'OrthogonalBra', + 'OrthogonalState', 'Wavefunction', + + 'TensorProduct', 'tensor_product_simp', + + 'hbar', 'HBar', + + '_postprocess_state_mul', '_postprocess_state_pow' +] + +from .anticommutator import AntiCommutator + +from .qapply import qapply + +from .commutator import Commutator + +from .dagger import Dagger + +from .hilbert import (HilbertSpaceError, HilbertSpace, + TensorProductHilbertSpace, TensorPowerHilbertSpace, + DirectSumHilbertSpace, ComplexSpace, L2, FockSpace) + +from .innerproduct import InnerProduct + +from .operator import (Operator, HermitianOperator, UnitaryOperator, + IdentityOperator, OuterProduct, DifferentialOperator) + +from .represent import (represent, rep_innerproduct, rep_expectation, + integrate_result, get_basis, enumerate_states) + +from .state import (KetBase, BraBase, StateBase, State, Ket, Bra, + TimeDepState, TimeDepBra, TimeDepKet, OrthogonalKet, + OrthogonalBra, OrthogonalState, Wavefunction) + +from .tensorproduct import TensorProduct, tensor_product_simp + +from .constants import hbar, HBar + +# These are private, but need to be imported so they are registered +# as postprocessing transformers with Mul and Pow. +from .transforms import _postprocess_state_mul, _postprocess_state_pow diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4abae2fdbe4dfdd7c6d7f22b982de87e415e0099 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/anticommutator.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/anticommutator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea3a6456eb9e573db50940929b165e784258b52c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/anticommutator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/boson.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/boson.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2561f562dc58bfb7dff605a166e4983b32d5cce3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/boson.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/cartesian.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/cartesian.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d522b74c8bdf3b3d2b9b94d999225c0288cb1d52 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/cartesian.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/cg.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/cg.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e98ff12076c687539b7c070b9d036288997e9376 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/cg.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/circuitplot.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/circuitplot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1acfe0ef364f5f1ee9a34936397a5a422563f817 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/circuitplot.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/circuitutils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/circuitutils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2b6af9367bd5e5007a7ed4814451cc1bd67c58e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/circuitutils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/commutator.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/commutator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..830421d2171811ee1e6c381ea956fd80bb1fd1aa Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/commutator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/constants.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74f053aab824399a993bd0d373033b9802bccce5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/constants.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/dagger.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/dagger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd95b6e5ff2e4e31fd04dd98caa60033936634b2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/dagger.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/density.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/density.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a24e94192941498ff968a24e194b45550a50a8ee Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/density.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/fermion.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/fermion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b17c46c7e9b7864d190c54b44997ad12a9deee4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/fermion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/gate.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/gate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3757311fc11aa061bb04bc91e2f391dcf72e616 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/gate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/grover.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/grover.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b29269ac9539ea4f59a125d3baadb3c71d9ad2b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/grover.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/hilbert.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/hilbert.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a262b7a22fc9cb9da74233f38260cb7d70645a3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/hilbert.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/identitysearch.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/identitysearch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a98f4a4050061a8d921e70bcf74c9ca9c4030e8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/identitysearch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/innerproduct.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/innerproduct.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3089c3904c3cb06a2fc0319e83270adbe68047a9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/innerproduct.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/kind.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/kind.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..468d669b20d24ea9b34838d6cb3585ad8fff53a0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/kind.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/matrixcache.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/matrixcache.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff09be6af18640f25ea4a1e9d6a3bf3ff216e718 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/matrixcache.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/matrixutils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/matrixutils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dc39d2866f999fcf8c88a152949e09332c6c832 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/matrixutils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/operator.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/operator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1249605f8695bd6301a8cdae9baefabf2dee964 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/operator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/operatorordering.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/operatorordering.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eee34aef501c43e86ea2a91fdcbcc83b9cd19311 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/operatorordering.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/operatorset.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/operatorset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b5c9eafe8e581330b5321fa1e3fd3422aca3114 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/operatorset.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/pauli.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/pauli.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e694de8ab56cff2720b3fb8a9d783319dd8abd7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/pauli.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/piab.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/piab.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45e407c1da460d53afbf4d3f491b2790d0697568 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/piab.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qapply.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qapply.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8fb762925f8afaadb1ef3e8a77e91d8bf5c8bf00 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qapply.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qasm.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qasm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68179bf626b2b7d686f7cc87f00e268413a4146d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qasm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qexpr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qexpr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52e7569a27654ed747867f7e0830037020b44480 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qexpr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qft.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qft.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc7a28cbd1e774de47c8bd2d2461c6c5b6f71bb2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qft.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qubit.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qubit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..807234fbe8be4ce771ebef5445258434d4fe355b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/qubit.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/represent.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/represent.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0383beba235b47bc37e547b0bd62ca2983089ddb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/represent.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/sho1d.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/sho1d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f0bdf4d6ba39811e2a951ee6fd2a48c45e0ba72 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/sho1d.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/shor.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/shor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf8788c57e2d3e11a971ed079b1a0d41b3ade305 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/shor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/spin.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/spin.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e72d9782fad74fcb7c8fcbfd959079607a319f2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/spin.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/state.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/state.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da4e86ecac4d5c196c2889a61df402237e9890d3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/state.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/tensorproduct.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/tensorproduct.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b14e1d9b27f35c3f4caee91edbdd30eefe299c7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/tensorproduct.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/trace.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/trace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1de43a5a3b9aafbf08c1ffc4fcdf520b383d611 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/trace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/transforms.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/transforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..685754b93fc730c9e7978c0c23291485c3a7800e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/__pycache__/transforms.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/anticommutator.py b/phivenv/Lib/site-packages/sympy/physics/quantum/anticommutator.py new file mode 100644 index 0000000000000000000000000000000000000000..cbd26eade640b60a48eaac8c8b0abaf236478ca9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/anticommutator.py @@ -0,0 +1,166 @@ +"""The anti-commutator: ``{A,B} = A*B + B*A``.""" + +from sympy.core.expr import Expr +from sympy.core.kind import KindDispatcher +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.printing.pretty.stringpict import prettyForm + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.kind import _OperatorKind, OperatorKind + +__all__ = [ + 'AntiCommutator' +] + +#----------------------------------------------------------------------------- +# Anti-commutator +#----------------------------------------------------------------------------- + + +class AntiCommutator(Expr): + """The standard anticommutator, in an unevaluated state. + + Explanation + =========== + + Evaluating an anticommutator is defined [1]_ as: ``{A, B} = A*B + B*A``. + This class returns the anticommutator in an unevaluated form. To evaluate + the anticommutator, use the ``.doit()`` method. + + Canonical ordering of an anticommutator is ``{A, B}`` for ``A < B``. The + arguments of the anticommutator are put into canonical order using + ``__cmp__``. If ``B < A``, then ``{A, B}`` is returned as ``{B, A}``. + + Parameters + ========== + + A : Expr + The first argument of the anticommutator {A,B}. + B : Expr + The second argument of the anticommutator {A,B}. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.quantum import AntiCommutator + >>> from sympy.physics.quantum import Operator, Dagger + >>> x, y = symbols('x,y') + >>> A = Operator('A') + >>> B = Operator('B') + + Create an anticommutator and use ``doit()`` to multiply them out. + + >>> ac = AntiCommutator(A,B); ac + {A,B} + >>> ac.doit() + A*B + B*A + + The commutator orders it arguments in canonical order: + + >>> ac = AntiCommutator(B,A); ac + {A,B} + + Commutative constants are factored out: + + >>> AntiCommutator(3*x*A,x*y*B) + 3*x**2*y*{A,B} + + Adjoint operations applied to the anticommutator are properly applied to + the arguments: + + >>> Dagger(AntiCommutator(A,B)) + {Dagger(A),Dagger(B)} + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Commutator + """ + is_commutative = False + + _kind_dispatcher = KindDispatcher("AntiCommutator_kind_dispatcher", commutative=True) + + @property + def kind(self): + arg_kinds = (a.kind for a in self.args) + return self._kind_dispatcher(*arg_kinds) + + def __new__(cls, A, B): + r = cls.eval(A, B) + if r is not None: + return r + obj = Expr.__new__(cls, A, B) + return obj + + @classmethod + def eval(cls, a, b): + if not (a and b): + return S.Zero + if a == b: + return Integer(2)*a**2 + if a.is_commutative or b.is_commutative: + return Integer(2)*a*b + + # [xA,yB] -> xy*[A,B] + ca, nca = a.args_cnc() + cb, ncb = b.args_cnc() + c_part = ca + cb + if c_part: + return Mul(Mul(*c_part), cls(Mul._from_args(nca), Mul._from_args(ncb))) + + # Canonical ordering of arguments + #The Commutator [A,B] is on canonical form if A < B. + if a.compare(b) == 1: + return cls(b, a) + + def doit(self, **hints): + """ Evaluate anticommutator """ + # Keep the import of Operator here to avoid problems with + # circular imports. + from sympy.physics.quantum.operator import Operator + A = self.args[0] + B = self.args[1] + if isinstance(A, Operator) and isinstance(B, Operator): + try: + comm = A._eval_anticommutator(B, **hints) + except NotImplementedError: + try: + comm = B._eval_anticommutator(A, **hints) + except NotImplementedError: + comm = None + if comm is not None: + return comm.doit(**hints) + return (A*B + B*A).doit(**hints) + + def _eval_adjoint(self): + return AntiCommutator(Dagger(self.args[0]), Dagger(self.args[1])) + + def _sympyrepr(self, printer, *args): + return "%s(%s,%s)" % ( + self.__class__.__name__, printer._print( + self.args[0]), printer._print(self.args[1]) + ) + + def _sympystr(self, printer, *args): + return "{%s,%s}" % ( + printer._print(self.args[0]), printer._print(self.args[1])) + + def _pretty(self, printer, *args): + pform = printer._print(self.args[0], *args) + pform = prettyForm(*pform.right(prettyForm(','))) + pform = prettyForm(*pform.right(printer._print(self.args[1], *args))) + pform = prettyForm(*pform.parens(left='{', right='}')) + return pform + + def _latex(self, printer, *args): + return "\\left\\{%s,%s\\right\\}" % tuple([ + printer._print(arg, *args) for arg in self.args]) + + +@AntiCommutator._kind_dispatcher.register(_OperatorKind, _OperatorKind) +def find_op_kind(e1, e2): + """Find the kind of an anticommutator of two OperatorKinds.""" + return OperatorKind diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/boson.py b/phivenv/Lib/site-packages/sympy/physics/quantum/boson.py new file mode 100644 index 0000000000000000000000000000000000000000..0f24cae2a7ad2f438234fcf00dadb2a4a9d76fe8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/boson.py @@ -0,0 +1,243 @@ +"""Bosonic quantum operators.""" + +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.quantum import Operator +from sympy.physics.quantum import HilbertSpace, FockSpace, Ket, Bra +from sympy.functions.special.tensor_functions import KroneckerDelta + + +__all__ = [ + 'BosonOp', + 'BosonFockKet', + 'BosonFockBra', + 'BosonCoherentKet', + 'BosonCoherentBra' +] + + +class BosonOp(Operator): + """A bosonic operator that satisfies [a, Dagger(a)] == 1. + + Parameters + ========== + + name : str + A string that labels the bosonic mode. + + annihilation : bool + A bool that indicates if the bosonic operator is an annihilation (True, + default value) or creation operator (False) + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger, Commutator + >>> from sympy.physics.quantum.boson import BosonOp + >>> a = BosonOp("a") + >>> Commutator(a, Dagger(a)).doit() + 1 + """ + + @property + def name(self): + return self.args[0] + + @property + def is_annihilation(self): + return bool(self.args[1]) + + @classmethod + def default_args(self): + return ("a", True) + + def __new__(cls, *args, **hints): + if not len(args) in [1, 2]: + raise ValueError('1 or 2 parameters expected, got %s' % args) + + if len(args) == 1: + args = (args[0], S.One) + + if len(args) == 2: + args = (args[0], Integer(args[1])) + + return Operator.__new__(cls, *args) + + def _eval_commutator_BosonOp(self, other, **hints): + if self.name == other.name: + # [a^\dagger, a] = -1 + if not self.is_annihilation and other.is_annihilation: + return S.NegativeOne + + elif 'independent' in hints and hints['independent']: + # [a, b] = 0 + return S.Zero + + return None + + def _eval_commutator_FermionOp(self, other, **hints): + return S.Zero + + def _eval_anticommutator_BosonOp(self, other, **hints): + if 'independent' in hints and hints['independent']: + # {a, b} = 2 * a * b, because [a, b] = 0 + return 2 * self * other + + return None + + def _eval_adjoint(self): + return BosonOp(str(self.name), not self.is_annihilation) + + def _print_contents_latex(self, printer, *args): + if self.is_annihilation: + return r'{%s}' % str(self.name) + else: + return r'{{%s}^\dagger}' % str(self.name) + + def _print_contents(self, printer, *args): + if self.is_annihilation: + return r'%s' % str(self.name) + else: + return r'Dagger(%s)' % str(self.name) + + def _print_contents_pretty(self, printer, *args): + from sympy.printing.pretty.stringpict import prettyForm + pform = printer._print(self.args[0], *args) + if self.is_annihilation: + return pform + else: + return pform**prettyForm('\N{DAGGER}') + + +class BosonFockKet(Ket): + """Fock state ket for a bosonic mode. + + Parameters + ========== + + n : Number + The Fock state number. + + """ + + def __new__(cls, n): + return Ket.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return BosonFockBra + + @classmethod + def _eval_hilbert_space(cls, label): + return FockSpace() + + def _eval_innerproduct_BosonFockBra(self, bra, **hints): + return KroneckerDelta(self.n, bra.n) + + def _apply_from_right_to_BosonOp(self, op, **options): + if op.is_annihilation: + return sqrt(self.n) * BosonFockKet(self.n - 1) + else: + return sqrt(self.n + 1) * BosonFockKet(self.n + 1) + + +class BosonFockBra(Bra): + """Fock state bra for a bosonic mode. + + Parameters + ========== + + n : Number + The Fock state number. + + """ + + def __new__(cls, n): + return Bra.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return BosonFockKet + + @classmethod + def _eval_hilbert_space(cls, label): + return FockSpace() + + +class BosonCoherentKet(Ket): + """Coherent state ket for a bosonic mode. + + Parameters + ========== + + alpha : Number, Symbol + The complex amplitude of the coherent state. + + """ + + def __new__(cls, alpha): + return Ket.__new__(cls, alpha) + + @property + def alpha(self): + return self.label[0] + + @classmethod + def dual_class(self): + return BosonCoherentBra + + @classmethod + def _eval_hilbert_space(cls, label): + return HilbertSpace() + + def _eval_innerproduct_BosonCoherentBra(self, bra, **hints): + if self.alpha == bra.alpha: + return S.One + else: + return exp(-(abs(self.alpha)**2 + abs(bra.alpha)**2 - 2 * conjugate(bra.alpha) * self.alpha)/2) + + def _apply_from_right_to_BosonOp(self, op, **options): + if op.is_annihilation: + return self.alpha * self + else: + return None + + +class BosonCoherentBra(Bra): + """Coherent state bra for a bosonic mode. + + Parameters + ========== + + alpha : Number, Symbol + The complex amplitude of the coherent state. + + """ + + def __new__(cls, alpha): + return Bra.__new__(cls, alpha) + + @property + def alpha(self): + return self.label[0] + + @classmethod + def dual_class(self): + return BosonCoherentKet + + def _apply_operator_BosonOp(self, op, **options): + if not op.is_annihilation: + return self.alpha * self + else: + return None diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/cartesian.py b/phivenv/Lib/site-packages/sympy/physics/quantum/cartesian.py new file mode 100644 index 0000000000000000000000000000000000000000..f3af1856f22c8fe4535b24be30bf99d0b3541a50 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/cartesian.py @@ -0,0 +1,341 @@ +"""Operators and states for 1D cartesian position and momentum. + +TODO: + +* Add 3D classes to mappings in operatorset.py + +""" + +from sympy.core.numbers import (I, pi) +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.delta_functions import DiracDelta +from sympy.sets.sets import Interval + +from sympy.physics.quantum.constants import hbar +from sympy.physics.quantum.hilbert import L2 +from sympy.physics.quantum.operator import DifferentialOperator, HermitianOperator +from sympy.physics.quantum.state import Ket, Bra, State + +__all__ = [ + 'XOp', + 'YOp', + 'ZOp', + 'PxOp', + 'X', + 'Y', + 'Z', + 'Px', + 'XKet', + 'XBra', + 'PxKet', + 'PxBra', + 'PositionState3D', + 'PositionKet3D', + 'PositionBra3D' +] + +#------------------------------------------------------------------------- +# Position operators +#------------------------------------------------------------------------- + + +class XOp(HermitianOperator): + """1D cartesian position operator.""" + + @classmethod + def default_args(self): + return ("X",) + + @classmethod + def _eval_hilbert_space(self, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _eval_commutator_PxOp(self, other): + return I*hbar + + def _apply_operator_XKet(self, ket, **options): + return ket.position*ket + + def _apply_operator_PositionKet3D(self, ket, **options): + return ket.position_x*ket + + def _represent_PxKet(self, basis, *, index=1, **options): + states = basis._enumerate_state(2, start_index=index) + coord1 = states[0].momentum + coord2 = states[1].momentum + d = DifferentialOperator(coord1) + delta = DiracDelta(coord1 - coord2) + + return I*hbar*(d*delta) + + +class YOp(HermitianOperator): + """ Y cartesian coordinate operator (for 2D or 3D systems) """ + + @classmethod + def default_args(self): + return ("Y",) + + @classmethod + def _eval_hilbert_space(self, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _apply_operator_PositionKet3D(self, ket, **options): + return ket.position_y*ket + + +class ZOp(HermitianOperator): + """ Z cartesian coordinate operator (for 3D systems) """ + + @classmethod + def default_args(self): + return ("Z",) + + @classmethod + def _eval_hilbert_space(self, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _apply_operator_PositionKet3D(self, ket, **options): + return ket.position_z*ket + +#------------------------------------------------------------------------- +# Momentum operators +#------------------------------------------------------------------------- + + +class PxOp(HermitianOperator): + """1D cartesian momentum operator.""" + + @classmethod + def default_args(self): + return ("Px",) + + @classmethod + def _eval_hilbert_space(self, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _apply_operator_PxKet(self, ket, **options): + return ket.momentum*ket + + def _represent_XKet(self, basis, *, index=1, **options): + states = basis._enumerate_state(2, start_index=index) + coord1 = states[0].position + coord2 = states[1].position + d = DifferentialOperator(coord1) + delta = DiracDelta(coord1 - coord2) + + return -I*hbar*(d*delta) + +X = XOp('X') +Y = YOp('Y') +Z = ZOp('Z') +Px = PxOp('Px') + +#------------------------------------------------------------------------- +# Position eigenstates +#------------------------------------------------------------------------- + + +class XKet(Ket): + """1D cartesian position eigenket.""" + + @classmethod + def _operators_to_state(self, op, **options): + return self.__new__(self, *_lowercase_labels(op), **options) + + def _state_to_operators(self, op_class, **options): + return op_class.__new__(op_class, + *_uppercase_labels(self), **options) + + @classmethod + def default_args(self): + return ("x",) + + @classmethod + def dual_class(self): + return XBra + + @property + def position(self): + """The position of the state.""" + return self.label[0] + + def _enumerate_state(self, num_states, **options): + return _enumerate_continuous_1D(self, num_states, **options) + + def _eval_innerproduct_XBra(self, bra, **hints): + return DiracDelta(self.position - bra.position) + + def _eval_innerproduct_PxBra(self, bra, **hints): + return exp(-I*self.position*bra.momentum/hbar)/sqrt(2*pi*hbar) + + +class XBra(Bra): + """1D cartesian position eigenbra.""" + + @classmethod + def default_args(self): + return ("x",) + + @classmethod + def dual_class(self): + return XKet + + @property + def position(self): + """The position of the state.""" + return self.label[0] + + +class PositionState3D(State): + """ Base class for 3D cartesian position eigenstates """ + + @classmethod + def _operators_to_state(self, op, **options): + return self.__new__(self, *_lowercase_labels(op), **options) + + def _state_to_operators(self, op_class, **options): + return op_class.__new__(op_class, + *_uppercase_labels(self), **options) + + @classmethod + def default_args(self): + return ("x", "y", "z") + + @property + def position_x(self): + """ The x coordinate of the state """ + return self.label[0] + + @property + def position_y(self): + """ The y coordinate of the state """ + return self.label[1] + + @property + def position_z(self): + """ The z coordinate of the state """ + return self.label[2] + + +class PositionKet3D(Ket, PositionState3D): + """ 3D cartesian position eigenket """ + + def _eval_innerproduct_PositionBra3D(self, bra, **options): + x_diff = self.position_x - bra.position_x + y_diff = self.position_y - bra.position_y + z_diff = self.position_z - bra.position_z + + return DiracDelta(x_diff)*DiracDelta(y_diff)*DiracDelta(z_diff) + + @classmethod + def dual_class(self): + return PositionBra3D + + +# XXX: The type:ignore here is because mypy gives Definition of +# "_state_to_operators" in base class "PositionState3D" is incompatible with +# definition in base class "BraBase" +class PositionBra3D(Bra, PositionState3D): # type: ignore + """ 3D cartesian position eigenbra """ + + @classmethod + def dual_class(self): + return PositionKet3D + +#------------------------------------------------------------------------- +# Momentum eigenstates +#------------------------------------------------------------------------- + + +class PxKet(Ket): + """1D cartesian momentum eigenket.""" + + @classmethod + def _operators_to_state(self, op, **options): + return self.__new__(self, *_lowercase_labels(op), **options) + + def _state_to_operators(self, op_class, **options): + return op_class.__new__(op_class, + *_uppercase_labels(self), **options) + + @classmethod + def default_args(self): + return ("px",) + + @classmethod + def dual_class(self): + return PxBra + + @property + def momentum(self): + """The momentum of the state.""" + return self.label[0] + + def _enumerate_state(self, *args, **options): + return _enumerate_continuous_1D(self, *args, **options) + + def _eval_innerproduct_XBra(self, bra, **hints): + return exp(I*self.momentum*bra.position/hbar)/sqrt(2*pi*hbar) + + def _eval_innerproduct_PxBra(self, bra, **hints): + return DiracDelta(self.momentum - bra.momentum) + + +class PxBra(Bra): + """1D cartesian momentum eigenbra.""" + + @classmethod + def default_args(self): + return ("px",) + + @classmethod + def dual_class(self): + return PxKet + + @property + def momentum(self): + """The momentum of the state.""" + return self.label[0] + +#------------------------------------------------------------------------- +# Global helper functions +#------------------------------------------------------------------------- + + +def _enumerate_continuous_1D(*args, **options): + state = args[0] + num_states = args[1] + state_class = state.__class__ + index_list = options.pop('index_list', []) + + if len(index_list) == 0: + start_index = options.pop('start_index', 1) + index_list = list(range(start_index, start_index + num_states)) + + enum_states = [0 for i in range(len(index_list))] + + for i, ind in enumerate(index_list): + label = state.args[0] + enum_states[i] = state_class(str(label) + "_" + str(ind), **options) + + return enum_states + + +def _lowercase_labels(ops): + if not isinstance(ops, set): + ops = [ops] + + return [str(arg.label[0]).lower() for arg in ops] + + +def _uppercase_labels(ops): + if not isinstance(ops, set): + ops = [ops] + + new_args = [str(arg.label[0])[0].upper() + + str(arg.label[0])[1:] for arg in ops] + + return new_args diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/cg.py b/phivenv/Lib/site-packages/sympy/physics/quantum/cg.py new file mode 100644 index 0000000000000000000000000000000000000000..0f285cd39413a953246777c42fb6763c22a5716b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/cg.py @@ -0,0 +1,754 @@ +#TODO: +# -Implement Clebsch-Gordan symmetries +# -Improve simplification method +# -Implement new simplifications +"""Clebsch-Gordon Coefficients.""" + +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import expand +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Wild, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.printing.pretty.stringpict import prettyForm, stringPict + +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.physics.wigner import clebsch_gordan, wigner_3j, wigner_6j, wigner_9j +from sympy.printing.precedence import PRECEDENCE + +__all__ = [ + 'CG', + 'Wigner3j', + 'Wigner6j', + 'Wigner9j', + 'cg_simp' +] + +#----------------------------------------------------------------------------- +# CG Coefficients +#----------------------------------------------------------------------------- + + +class Wigner3j(Expr): + """Class for the Wigner-3j symbols. + + Explanation + =========== + + Wigner 3j-symbols are coefficients determined by the coupling of + two angular momenta. When created, they are expressed as symbolic + quantities that, for numerical parameters, can be evaluated using the + ``.doit()`` method [1]_. + + Parameters + ========== + + j1, m1, j2, m2, j3, m3 : Number, Symbol + Terms determining the angular momentum of coupled angular momentum + systems. + + Examples + ======== + + Declare a Wigner-3j coefficient and calculate its value + + >>> from sympy.physics.quantum.cg import Wigner3j + >>> w3j = Wigner3j(6,0,4,0,2,0) + >>> w3j + Wigner3j(6, 0, 4, 0, 2, 0) + >>> w3j.doit() + sqrt(715)/143 + + See Also + ======== + + CG: Clebsch-Gordan coefficients + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + """ + + is_commutative = True + + def __new__(cls, j1, m1, j2, m2, j3, m3): + args = map(sympify, (j1, m1, j2, m2, j3, m3)) + return Expr.__new__(cls, *args) + + @property + def j1(self): + return self.args[0] + + @property + def m1(self): + return self.args[1] + + @property + def j2(self): + return self.args[2] + + @property + def m2(self): + return self.args[3] + + @property + def j3(self): + return self.args[4] + + @property + def m3(self): + return self.args[5] + + @property + def is_symbolic(self): + return not all(arg.is_number for arg in self.args) + + # This is modified from the _print_Matrix method + def _pretty(self, printer, *args): + m = ((printer._print(self.j1), printer._print(self.m1)), + (printer._print(self.j2), printer._print(self.m2)), + (printer._print(self.j3), printer._print(self.m3))) + hsep = 2 + vsep = 1 + maxw = [-1]*3 + for j in range(3): + maxw[j] = max(m[j][i].width() for i in range(2)) + D = None + for i in range(2): + D_row = None + for j in range(3): + s = m[j][i] + wdelta = maxw[j] - s.width() + wleft = wdelta //2 + wright = wdelta - wleft + + s = prettyForm(*s.right(' '*wright)) + s = prettyForm(*s.left(' '*wleft)) + + if D_row is None: + D_row = s + continue + D_row = prettyForm(*D_row.right(' '*hsep)) + D_row = prettyForm(*D_row.right(s)) + if D is None: + D = D_row + continue + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + D = prettyForm(*D.below(D_row)) + D = prettyForm(*D.parens()) + return D + + def _latex(self, printer, *args): + label = map(printer._print, (self.j1, self.j2, self.j3, + self.m1, self.m2, self.m3)) + return r'\left(\begin{array}{ccc} %s & %s & %s \\ %s & %s & %s \end{array}\right)' % \ + tuple(label) + + def doit(self, **hints): + if self.is_symbolic: + raise ValueError("Coefficients must be numerical") + return wigner_3j(self.j1, self.j2, self.j3, self.m1, self.m2, self.m3) + + +class CG(Wigner3j): + r"""Class for Clebsch-Gordan coefficient. + + Explanation + =========== + + Clebsch-Gordan coefficients describe the angular momentum coupling between + two systems. The coefficients give the expansion of a coupled total angular + momentum state and an uncoupled tensor product state. The Clebsch-Gordan + coefficients are defined as [1]_: + + .. math :: + C^{j_3,m_3}_{j_1,m_1,j_2,m_2} = \left\langle j_1,m_1;j_2,m_2 | j_3,m_3\right\rangle + + Parameters + ========== + + j1, m1, j2, m2 : Number, Symbol + Angular momenta of states 1 and 2. + + j3, m3: Number, Symbol + Total angular momentum of the coupled system. + + Examples + ======== + + Define a Clebsch-Gordan coefficient and evaluate its value + + >>> from sympy.physics.quantum.cg import CG + >>> from sympy import S + >>> cg = CG(S(3)/2, S(3)/2, S(1)/2, -S(1)/2, 1, 1) + >>> cg + CG(3/2, 3/2, 1/2, -1/2, 1, 1) + >>> cg.doit() + sqrt(3)/2 + >>> CG(j1=S(1)/2, m1=-S(1)/2, j2=S(1)/2, m2=+S(1)/2, j3=1, m3=0).doit() + sqrt(2)/2 + + + Compare [2]_. + + See Also + ======== + + Wigner3j: Wigner-3j symbols + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + .. [2] `Clebsch-Gordan Coefficients, Spherical Harmonics, and d Functions + `_ + in P.A. Zyla *et al.* (Particle Data Group), Prog. Theor. Exp. Phys. + 2020, 083C01 (2020). + """ + precedence = PRECEDENCE["Pow"] - 1 + + def doit(self, **hints): + if self.is_symbolic: + raise ValueError("Coefficients must be numerical") + return clebsch_gordan(self.j1, self.j2, self.j3, self.m1, self.m2, self.m3) + + def _pretty(self, printer, *args): + bot = printer._print_seq( + (self.j1, self.m1, self.j2, self.m2), delimiter=',') + top = printer._print_seq((self.j3, self.m3), delimiter=',') + + pad = max(top.width(), bot.width()) + bot = prettyForm(*bot.left(' ')) + top = prettyForm(*top.left(' ')) + + if not pad == bot.width(): + bot = prettyForm(*bot.right(' '*(pad - bot.width()))) + if not pad == top.width(): + top = prettyForm(*top.right(' '*(pad - top.width()))) + s = stringPict('C' + ' '*pad) + s = prettyForm(*s.below(bot)) + s = prettyForm(*s.above(top)) + return s + + def _latex(self, printer, *args): + label = map(printer._print, (self.j3, self.m3, self.j1, + self.m1, self.j2, self.m2)) + return r'C^{%s,%s}_{%s,%s,%s,%s}' % tuple(label) + + +class Wigner6j(Expr): + """Class for the Wigner-6j symbols + + See Also + ======== + + Wigner3j: Wigner-3j symbols + + """ + def __new__(cls, j1, j2, j12, j3, j, j23): + args = map(sympify, (j1, j2, j12, j3, j, j23)) + return Expr.__new__(cls, *args) + + @property + def j1(self): + return self.args[0] + + @property + def j2(self): + return self.args[1] + + @property + def j12(self): + return self.args[2] + + @property + def j3(self): + return self.args[3] + + @property + def j(self): + return self.args[4] + + @property + def j23(self): + return self.args[5] + + @property + def is_symbolic(self): + return not all(arg.is_number for arg in self.args) + + # This is modified from the _print_Matrix method + def _pretty(self, printer, *args): + m = ((printer._print(self.j1), printer._print(self.j3)), + (printer._print(self.j2), printer._print(self.j)), + (printer._print(self.j12), printer._print(self.j23))) + hsep = 2 + vsep = 1 + maxw = [-1]*3 + for j in range(3): + maxw[j] = max(m[j][i].width() for i in range(2)) + D = None + for i in range(2): + D_row = None + for j in range(3): + s = m[j][i] + wdelta = maxw[j] - s.width() + wleft = wdelta //2 + wright = wdelta - wleft + + s = prettyForm(*s.right(' '*wright)) + s = prettyForm(*s.left(' '*wleft)) + + if D_row is None: + D_row = s + continue + D_row = prettyForm(*D_row.right(' '*hsep)) + D_row = prettyForm(*D_row.right(s)) + if D is None: + D = D_row + continue + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + D = prettyForm(*D.below(D_row)) + D = prettyForm(*D.parens(left='{', right='}')) + return D + + def _latex(self, printer, *args): + label = map(printer._print, (self.j1, self.j2, self.j12, + self.j3, self.j, self.j23)) + return r'\left\{\begin{array}{ccc} %s & %s & %s \\ %s & %s & %s \end{array}\right\}' % \ + tuple(label) + + def doit(self, **hints): + if self.is_symbolic: + raise ValueError("Coefficients must be numerical") + return wigner_6j(self.j1, self.j2, self.j12, self.j3, self.j, self.j23) + + +class Wigner9j(Expr): + """Class for the Wigner-9j symbols + + See Also + ======== + + Wigner3j: Wigner-3j symbols + + """ + def __new__(cls, j1, j2, j12, j3, j4, j34, j13, j24, j): + args = map(sympify, (j1, j2, j12, j3, j4, j34, j13, j24, j)) + return Expr.__new__(cls, *args) + + @property + def j1(self): + return self.args[0] + + @property + def j2(self): + return self.args[1] + + @property + def j12(self): + return self.args[2] + + @property + def j3(self): + return self.args[3] + + @property + def j4(self): + return self.args[4] + + @property + def j34(self): + return self.args[5] + + @property + def j13(self): + return self.args[6] + + @property + def j24(self): + return self.args[7] + + @property + def j(self): + return self.args[8] + + @property + def is_symbolic(self): + return not all(arg.is_number for arg in self.args) + + # This is modified from the _print_Matrix method + def _pretty(self, printer, *args): + m = ( + (printer._print( + self.j1), printer._print(self.j3), printer._print(self.j13)), + (printer._print( + self.j2), printer._print(self.j4), printer._print(self.j24)), + (printer._print(self.j12), printer._print(self.j34), printer._print(self.j))) + hsep = 2 + vsep = 1 + maxw = [-1]*3 + for j in range(3): + maxw[j] = max(m[j][i].width() for i in range(3)) + D = None + for i in range(3): + D_row = None + for j in range(3): + s = m[j][i] + wdelta = maxw[j] - s.width() + wleft = wdelta //2 + wright = wdelta - wleft + + s = prettyForm(*s.right(' '*wright)) + s = prettyForm(*s.left(' '*wleft)) + + if D_row is None: + D_row = s + continue + D_row = prettyForm(*D_row.right(' '*hsep)) + D_row = prettyForm(*D_row.right(s)) + if D is None: + D = D_row + continue + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + D = prettyForm(*D.below(D_row)) + D = prettyForm(*D.parens(left='{', right='}')) + return D + + def _latex(self, printer, *args): + label = map(printer._print, (self.j1, self.j2, self.j12, self.j3, + self.j4, self.j34, self.j13, self.j24, self.j)) + return r'\left\{\begin{array}{ccc} %s & %s & %s \\ %s & %s & %s \\ %s & %s & %s \end{array}\right\}' % \ + tuple(label) + + def doit(self, **hints): + if self.is_symbolic: + raise ValueError("Coefficients must be numerical") + return wigner_9j(self.j1, self.j2, self.j12, self.j3, self.j4, self.j34, self.j13, self.j24, self.j) + + +def cg_simp(e): + """Simplify and combine CG coefficients. + + Explanation + =========== + + This function uses various symmetry and properties of sums and + products of Clebsch-Gordan coefficients to simplify statements + involving these terms [1]_. + + Examples + ======== + + Simplify the sum over CG(a,alpha,0,0,a,alpha) for all alpha to + 2*a+1 + + >>> from sympy.physics.quantum.cg import CG, cg_simp + >>> a = CG(1,1,0,0,1,1) + >>> b = CG(1,0,0,0,1,0) + >>> c = CG(1,-1,0,0,1,-1) + >>> cg_simp(a+b+c) + 3 + + See Also + ======== + + CG: Clebsh-Gordan coefficients + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + """ + if isinstance(e, Add): + return _cg_simp_add(e) + elif isinstance(e, Sum): + return _cg_simp_sum(e) + elif isinstance(e, Mul): + return Mul(*[cg_simp(arg) for arg in e.args]) + elif isinstance(e, Pow): + return Pow(cg_simp(e.base), e.exp) + else: + return e + + +def _cg_simp_add(e): + #TODO: Improve simplification method + """Takes a sum of terms involving Clebsch-Gordan coefficients and + simplifies the terms. + + Explanation + =========== + + First, we create two lists, cg_part, which is all the terms involving CG + coefficients, and other_part, which is all other terms. The cg_part list + is then passed to the simplification methods, which return the new cg_part + and any additional terms that are added to other_part + """ + cg_part = [] + other_part = [] + + e = expand(e) + for arg in e.args: + if arg.has(CG): + if isinstance(arg, Sum): + other_part.append(_cg_simp_sum(arg)) + elif isinstance(arg, Mul): + terms = 1 + for term in arg.args: + if isinstance(term, Sum): + terms *= _cg_simp_sum(term) + else: + terms *= term + if terms.has(CG): + cg_part.append(terms) + else: + other_part.append(terms) + else: + cg_part.append(arg) + else: + other_part.append(arg) + + cg_part, other = _check_varsh_871_1(cg_part) + other_part.append(other) + cg_part, other = _check_varsh_871_2(cg_part) + other_part.append(other) + cg_part, other = _check_varsh_872_9(cg_part) + other_part.append(other) + return Add(*cg_part) + Add(*other_part) + + +def _check_varsh_871_1(term_list): + # Sum( CG(a,alpha,b,0,a,alpha), (alpha, -a, a)) == KroneckerDelta(b,0) + a, alpha, b, lt = map(Wild, ('a', 'alpha', 'b', 'lt')) + expr = lt*CG(a, alpha, b, 0, a, alpha) + simp = (2*a + 1)*KroneckerDelta(b, 0) + sign = lt/abs(lt) + build_expr = 2*a + 1 + index_expr = a + alpha + return _check_cg_simp(expr, simp, sign, lt, term_list, (a, alpha, b, lt), (a, b), build_expr, index_expr) + + +def _check_varsh_871_2(term_list): + # Sum((-1)**(a-alpha)*CG(a,alpha,a,-alpha,c,0),(alpha,-a,a)) + a, alpha, c, lt = map(Wild, ('a', 'alpha', 'c', 'lt')) + expr = lt*CG(a, alpha, a, -alpha, c, 0) + simp = sqrt(2*a + 1)*KroneckerDelta(c, 0) + sign = (-1)**(a - alpha)*lt/abs(lt) + build_expr = 2*a + 1 + index_expr = a + alpha + return _check_cg_simp(expr, simp, sign, lt, term_list, (a, alpha, c, lt), (a, c), build_expr, index_expr) + + +def _check_varsh_872_9(term_list): + # Sum( CG(a,alpha,b,beta,c,gamma)*CG(a,alpha',b,beta',c,gamma), (gamma, -c, c), (c, abs(a-b), a+b)) + a, alpha, alphap, b, beta, betap, c, gamma, lt = map(Wild, ( + 'a', 'alpha', 'alphap', 'b', 'beta', 'betap', 'c', 'gamma', 'lt')) + # Case alpha==alphap, beta==betap + + # For numerical alpha,beta + expr = lt*CG(a, alpha, b, beta, c, gamma)**2 + simp = S.One + sign = lt/abs(lt) + x = abs(a - b) + y = abs(alpha + beta) + build_expr = a + b + 1 - Piecewise((x, x > y), (0, Eq(x, y)), (y, y > x)) + index_expr = a + b - c + term_list, other1 = _check_cg_simp(expr, simp, sign, lt, term_list, (a, alpha, b, beta, c, gamma, lt), (a, alpha, b, beta), build_expr, index_expr) + + # For symbolic alpha,beta + x = abs(a - b) + y = a + b + build_expr = (y + 1 - x)*(x + y + 1) + index_expr = (c - x)*(x + c) + c + gamma + term_list, other2 = _check_cg_simp(expr, simp, sign, lt, term_list, (a, alpha, b, beta, c, gamma, lt), (a, alpha, b, beta), build_expr, index_expr) + + # Case alpha!=alphap or beta!=betap + # Note: this only works with leading term of 1, pattern matching is unable to match when there is a Wild leading term + # For numerical alpha,alphap,beta,betap + expr = CG(a, alpha, b, beta, c, gamma)*CG(a, alphap, b, betap, c, gamma) + simp = KroneckerDelta(alpha, alphap)*KroneckerDelta(beta, betap) + sign = S.One + x = abs(a - b) + y = abs(alpha + beta) + build_expr = a + b + 1 - Piecewise((x, x > y), (0, Eq(x, y)), (y, y > x)) + index_expr = a + b - c + term_list, other3 = _check_cg_simp(expr, simp, sign, S.One, term_list, (a, alpha, alphap, b, beta, betap, c, gamma), (a, alpha, alphap, b, beta, betap), build_expr, index_expr) + + # For symbolic alpha,alphap,beta,betap + x = abs(a - b) + y = a + b + build_expr = (y + 1 - x)*(x + y + 1) + index_expr = (c - x)*(x + c) + c + gamma + term_list, other4 = _check_cg_simp(expr, simp, sign, S.One, term_list, (a, alpha, alphap, b, beta, betap, c, gamma), (a, alpha, alphap, b, beta, betap), build_expr, index_expr) + + return term_list, other1 + other2 + other4 + + +def _check_cg_simp(expr, simp, sign, lt, term_list, variables, dep_variables, build_index_expr, index_expr): + """ Checks for simplifications that can be made, returning a tuple of the + simplified list of terms and any terms generated by simplification. + + Parameters + ========== + + expr: expression + The expression with Wild terms that will be matched to the terms in + the sum + + simp: expression + The expression with Wild terms that is substituted in place of the CG + terms in the case of simplification + + sign: expression + The expression with Wild terms denoting the sign that is on expr that + must match + + lt: expression + The expression with Wild terms that gives the leading term of the + matched expr + + term_list: list + A list of all of the terms is the sum to be simplified + + variables: list + A list of all the variables that appears in expr + + dep_variables: list + A list of the variables that must match for all the terms in the sum, + i.e. the dependent variables + + build_index_expr: expression + Expression with Wild terms giving the number of elements in cg_index + + index_expr: expression + Expression with Wild terms giving the index terms have when storing + them to cg_index + + """ + other_part = 0 + i = 0 + while i < len(term_list): + sub_1 = _check_cg(term_list[i], expr, len(variables)) + if sub_1 is None: + i += 1 + continue + if not build_index_expr.subs(sub_1).is_number: + i += 1 + continue + sub_dep = [(x, sub_1[x]) for x in dep_variables] + cg_index = [None]*build_index_expr.subs(sub_1) + for j in range(i, len(term_list)): + sub_2 = _check_cg(term_list[j], expr.subs(sub_dep), len(variables) - len(dep_variables), sign=(sign.subs(sub_1), sign.subs(sub_dep))) + if sub_2 is None: + continue + if not index_expr.subs(sub_dep).subs(sub_2).is_number: + continue + cg_index[index_expr.subs(sub_dep).subs(sub_2)] = j, expr.subs(lt, 1).subs(sub_dep).subs(sub_2), lt.subs(sub_2), sign.subs(sub_dep).subs(sub_2) + if not any(i is None for i in cg_index): + min_lt = min(*[ abs(term[2]) for term in cg_index ]) + indices = [ term[0] for term in cg_index] + indices.sort() + indices.reverse() + [ term_list.pop(j) for j in indices ] + for term in cg_index: + if abs(term[2]) > min_lt: + term_list.append( (term[2] - min_lt*term[3])*term[1] ) + other_part += min_lt*(sign*simp).subs(sub_1) + else: + i += 1 + return term_list, other_part + + +def _check_cg(cg_term, expr, length, sign=None): + """Checks whether a term matches the given expression""" + # TODO: Check for symmetries + matches = cg_term.match(expr) + if matches is None: + return + if sign is not None: + if not isinstance(sign, tuple): + raise TypeError('sign must be a tuple') + if not sign[0] == (sign[1]).subs(matches): + return + if len(matches) == length: + return matches + + +def _cg_simp_sum(e): + e = _check_varsh_sum_871_1(e) + e = _check_varsh_sum_871_2(e) + e = _check_varsh_sum_872_4(e) + return e + + +def _check_varsh_sum_871_1(e): + a = Wild('a') + alpha = symbols('alpha') + b = Wild('b') + match = e.match(Sum(CG(a, alpha, b, 0, a, alpha), (alpha, -a, a))) + if match is not None and len(match) == 2: + return ((2*a + 1)*KroneckerDelta(b, 0)).subs(match) + return e + + +def _check_varsh_sum_871_2(e): + a = Wild('a') + alpha = symbols('alpha') + c = Wild('c') + match = e.match( + Sum((-1)**(a - alpha)*CG(a, alpha, a, -alpha, c, 0), (alpha, -a, a))) + if match is not None and len(match) == 2: + return (sqrt(2*a + 1)*KroneckerDelta(c, 0)).subs(match) + return e + + +def _check_varsh_sum_872_4(e): + alpha = symbols('alpha') + beta = symbols('beta') + a = Wild('a') + b = Wild('b') + c = Wild('c') + cp = Wild('cp') + gamma = Wild('gamma') + gammap = Wild('gammap') + cg1 = CG(a, alpha, b, beta, c, gamma) + cg2 = CG(a, alpha, b, beta, cp, gammap) + match1 = e.match(Sum(cg1*cg2, (alpha, -a, a), (beta, -b, b))) + if match1 is not None and len(match1) == 6: + return (KroneckerDelta(c, cp)*KroneckerDelta(gamma, gammap)).subs(match1) + match2 = e.match(Sum(cg1**2, (alpha, -a, a), (beta, -b, b))) + if match2 is not None and len(match2) == 4: + return S.One + return e + + +def _cg_list(term): + if isinstance(term, CG): + return (term,), 1, 1 + cg = [] + coeff = 1 + if not isinstance(term, (Mul, Pow)): + raise NotImplementedError('term must be CG, Add, Mul or Pow') + if isinstance(term, Pow) and term.exp.is_number: + if term.exp.is_number: + [ cg.append(term.base) for _ in range(term.exp) ] + else: + return (term,), 1, 1 + if isinstance(term, Mul): + for arg in term.args: + if isinstance(arg, CG): + cg.append(arg) + else: + coeff *= arg + return cg, coeff, coeff/abs(coeff) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/circuitplot.py b/phivenv/Lib/site-packages/sympy/physics/quantum/circuitplot.py new file mode 100644 index 0000000000000000000000000000000000000000..316a4be613b2e275565999130c06ea678acd8b96 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/circuitplot.py @@ -0,0 +1,370 @@ +"""Matplotlib based plotting of quantum circuits. + +Todo: + +* Optimize printing of large circuits. +* Get this to work with single gates. +* Do a better job checking the form of circuits to make sure it is a Mul of + Gates. +* Get multi-target gates plotting. +* Get initial and final states to plot. +* Get measurements to plot. Might need to rethink measurement as a gate + issue. +* Get scale and figsize to be handled in a better way. +* Write some tests/examples! +""" + +from __future__ import annotations + +from sympy.core.mul import Mul +from sympy.external import import_module +from sympy.physics.quantum.gate import Gate, OneQubitGate, CGate, CGateS + + +__all__ = [ + 'CircuitPlot', + 'circuit_plot', + 'labeller', + 'Mz', + 'Mx', + 'CreateOneQubitGate', + 'CreateCGate', +] + +np = import_module('numpy') +matplotlib = import_module( + 'matplotlib', import_kwargs={'fromlist': ['pyplot']}, + catch=(RuntimeError,)) # This is raised in environments that have no display. + +if np and matplotlib: + pyplot = matplotlib.pyplot + Line2D = matplotlib.lines.Line2D + Circle = matplotlib.patches.Circle + +#from matplotlib import rc +#rc('text',usetex=True) + +class CircuitPlot: + """A class for managing a circuit plot.""" + + scale = 1.0 + fontsize = 20.0 + linewidth = 1.0 + control_radius = 0.05 + not_radius = 0.15 + swap_delta = 0.05 + labels: list[str] = [] + inits: dict[str, str] = {} + label_buffer = 0.5 + + def __init__(self, c, nqubits, **kwargs): + if not np or not matplotlib: + raise ImportError('numpy or matplotlib not available.') + self.circuit = c + self.ngates = len(self.circuit.args) + self.nqubits = nqubits + self.update(kwargs) + self._create_grid() + self._create_figure() + self._plot_wires() + self._plot_gates() + self._finish() + + def update(self, kwargs): + """Load the kwargs into the instance dict.""" + self.__dict__.update(kwargs) + + def _create_grid(self): + """Create the grid of wires.""" + scale = self.scale + wire_grid = np.arange(0.0, self.nqubits*scale, scale, dtype=float) + gate_grid = np.arange(0.0, self.ngates*scale, scale, dtype=float) + self._wire_grid = wire_grid + self._gate_grid = gate_grid + + def _create_figure(self): + """Create the main matplotlib figure.""" + self._figure = pyplot.figure( + figsize=(self.ngates*self.scale, self.nqubits*self.scale), + facecolor='w', + edgecolor='w' + ) + ax = self._figure.add_subplot( + 1, 1, 1, + frameon=True + ) + ax.set_axis_off() + offset = 0.5*self.scale + ax.set_xlim(self._gate_grid[0] - offset, self._gate_grid[-1] + offset) + ax.set_ylim(self._wire_grid[0] - offset, self._wire_grid[-1] + offset) + ax.set_aspect('equal') + self._axes = ax + + def _plot_wires(self): + """Plot the wires of the circuit diagram.""" + xstart = self._gate_grid[0] + xstop = self._gate_grid[-1] + xdata = (xstart - self.scale, xstop + self.scale) + for i in range(self.nqubits): + ydata = (self._wire_grid[i], self._wire_grid[i]) + line = Line2D( + xdata, ydata, + color='k', + lw=self.linewidth + ) + self._axes.add_line(line) + if self.labels: + init_label_buffer = 0 + if self.inits.get(self.labels[i]): init_label_buffer = 0.25 + self._axes.text( + xdata[0]-self.label_buffer-init_label_buffer,ydata[0], + render_label(self.labels[i],self.inits), + size=self.fontsize, + color='k',ha='center',va='center') + self._plot_measured_wires() + + def _plot_measured_wires(self): + ismeasured = self._measurements() + xstop = self._gate_grid[-1] + dy = 0.04 # amount to shift wires when doubled + # Plot doubled wires after they are measured + for im in ismeasured: + xdata = (self._gate_grid[ismeasured[im]],xstop+self.scale) + ydata = (self._wire_grid[im]+dy,self._wire_grid[im]+dy) + line = Line2D( + xdata, ydata, + color='k', + lw=self.linewidth + ) + self._axes.add_line(line) + # Also double any controlled lines off these wires + for i,g in enumerate(self._gates()): + if isinstance(g, (CGate, CGateS)): + wires = g.controls + g.targets + for wire in wires: + if wire in ismeasured and \ + self._gate_grid[i] > self._gate_grid[ismeasured[wire]]: + ydata = min(wires), max(wires) + xdata = self._gate_grid[i]-dy, self._gate_grid[i]-dy + line = Line2D( + xdata, ydata, + color='k', + lw=self.linewidth + ) + self._axes.add_line(line) + def _gates(self): + """Create a list of all gates in the circuit plot.""" + gates = [] + if isinstance(self.circuit, Mul): + for g in reversed(self.circuit.args): + if isinstance(g, Gate): + gates.append(g) + elif isinstance(self.circuit, Gate): + gates.append(self.circuit) + return gates + + def _plot_gates(self): + """Iterate through the gates and plot each of them.""" + for i, gate in enumerate(self._gates()): + gate.plot_gate(self, i) + + def _measurements(self): + """Return a dict ``{i:j}`` where i is the index of the wire that has + been measured, and j is the gate where the wire is measured. + """ + ismeasured = {} + for i,g in enumerate(self._gates()): + if getattr(g,'measurement',False): + for target in g.targets: + if target in ismeasured: + if ismeasured[target] > i: + ismeasured[target] = i + else: + ismeasured[target] = i + return ismeasured + + def _finish(self): + # Disable clipping to make panning work well for large circuits. + for o in self._figure.findobj(): + o.set_clip_on(False) + + def one_qubit_box(self, t, gate_idx, wire_idx): + """Draw a box for a single qubit gate.""" + x = self._gate_grid[gate_idx] + y = self._wire_grid[wire_idx] + self._axes.text( + x, y, t, + color='k', + ha='center', + va='center', + bbox={"ec": 'k', "fc": 'w', "fill": True, "lw": self.linewidth}, + size=self.fontsize + ) + + def two_qubit_box(self, t, gate_idx, wire_idx): + """Draw a box for a two qubit gate. Does not work yet. + """ + # x = self._gate_grid[gate_idx] + # y = self._wire_grid[wire_idx]+0.5 + print(self._gate_grid) + print(self._wire_grid) + # unused: + # obj = self._axes.text( + # x, y, t, + # color='k', + # ha='center', + # va='center', + # bbox=dict(ec='k', fc='w', fill=True, lw=self.linewidth), + # size=self.fontsize + # ) + + def control_line(self, gate_idx, min_wire, max_wire): + """Draw a vertical control line.""" + xdata = (self._gate_grid[gate_idx], self._gate_grid[gate_idx]) + ydata = (self._wire_grid[min_wire], self._wire_grid[max_wire]) + line = Line2D( + xdata, ydata, + color='k', + lw=self.linewidth + ) + self._axes.add_line(line) + + def control_point(self, gate_idx, wire_idx): + """Draw a control point.""" + x = self._gate_grid[gate_idx] + y = self._wire_grid[wire_idx] + radius = self.control_radius + c = Circle( + (x, y), + radius*self.scale, + ec='k', + fc='k', + fill=True, + lw=self.linewidth + ) + self._axes.add_patch(c) + + def not_point(self, gate_idx, wire_idx): + """Draw a NOT gates as the circle with plus in the middle.""" + x = self._gate_grid[gate_idx] + y = self._wire_grid[wire_idx] + radius = self.not_radius + c = Circle( + (x, y), + radius, + ec='k', + fc='w', + fill=False, + lw=self.linewidth + ) + self._axes.add_patch(c) + l = Line2D( + (x, x), (y - radius, y + radius), + color='k', + lw=self.linewidth + ) + self._axes.add_line(l) + + def swap_point(self, gate_idx, wire_idx): + """Draw a swap point as a cross.""" + x = self._gate_grid[gate_idx] + y = self._wire_grid[wire_idx] + d = self.swap_delta + l1 = Line2D( + (x - d, x + d), + (y - d, y + d), + color='k', + lw=self.linewidth + ) + l2 = Line2D( + (x - d, x + d), + (y + d, y - d), + color='k', + lw=self.linewidth + ) + self._axes.add_line(l1) + self._axes.add_line(l2) + +def circuit_plot(c, nqubits, **kwargs): + """Draw the circuit diagram for the circuit with nqubits. + + Parameters + ========== + + c : circuit + The circuit to plot. Should be a product of Gate instances. + nqubits : int + The number of qubits to include in the circuit. Must be at least + as big as the largest ``min_qubits`` of the gates. + """ + return CircuitPlot(c, nqubits, **kwargs) + +def render_label(label, inits={}): + """Slightly more flexible way to render labels. + + >>> from sympy.physics.quantum.circuitplot import render_label + >>> render_label('q0') + '$\\\\left|q0\\\\right\\\\rangle$' + >>> render_label('q0', {'q0':'0'}) + '$\\\\left|q0\\\\right\\\\rangle=\\\\left|0\\\\right\\\\rangle$' + """ + init = inits.get(label) + if init: + return r'$\left|%s\right\rangle=\left|%s\right\rangle$' % (label, init) + return r'$\left|%s\right\rangle$' % label + +def labeller(n, symbol='q'): + """Autogenerate labels for wires of quantum circuits. + + Parameters + ========== + + n : int + number of qubits in the circuit. + symbol : string + A character string to precede all gate labels. E.g. 'q_0', 'q_1', etc. + + >>> from sympy.physics.quantum.circuitplot import labeller + >>> labeller(2) + ['q_1', 'q_0'] + >>> labeller(3,'j') + ['j_2', 'j_1', 'j_0'] + """ + return ['%s_%d' % (symbol,n-i-1) for i in range(n)] + +class Mz(OneQubitGate): + """Mock-up of a z measurement gate. + + This is in circuitplot rather than gate.py because it's not a real + gate, it just draws one. + """ + measurement = True + gate_name='Mz' + gate_name_latex='M_z' + +class Mx(OneQubitGate): + """Mock-up of an x measurement gate. + + This is in circuitplot rather than gate.py because it's not a real + gate, it just draws one. + """ + measurement = True + gate_name='Mx' + gate_name_latex='M_x' + +class CreateOneQubitGate(type): + def __new__(mcl, name, latexname=None): + if not latexname: + latexname = name + return type(name + "Gate", (OneQubitGate,), + {'gate_name': name, 'gate_name_latex': latexname}) + +def CreateCGate(name, latexname=None): + """Use a lexical closure to make a controlled gate. + """ + if not latexname: + latexname = name + onequbitgate = CreateOneQubitGate(name, latexname) + def ControlledGate(ctrls,target): + return CGate(tuple(ctrls),onequbitgate(target)) + return ControlledGate diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/circuitutils.py b/phivenv/Lib/site-packages/sympy/physics/quantum/circuitutils.py new file mode 100644 index 0000000000000000000000000000000000000000..84955d3d724a2658f2dc3b26738133bd46f1aa57 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/circuitutils.py @@ -0,0 +1,488 @@ +"""Primitive circuit operations on quantum circuits.""" + +from functools import reduce + +from sympy.core.sorting import default_sort_key +from sympy.core.containers import Tuple +from sympy.core.mul import Mul +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.utilities import numbered_symbols +from sympy.physics.quantum.gate import Gate + +__all__ = [ + 'kmp_table', + 'find_subcircuit', + 'replace_subcircuit', + 'convert_to_symbolic_indices', + 'convert_to_real_indices', + 'random_reduce', + 'random_insert' +] + + +def kmp_table(word): + """Build the 'partial match' table of the Knuth-Morris-Pratt algorithm. + + Note: This is applicable to strings or + quantum circuits represented as tuples. + """ + + # Current position in subcircuit + pos = 2 + # Beginning position of candidate substring that + # may reappear later in word + cnd = 0 + # The 'partial match' table that helps one determine + # the next location to start substring search + table = [] + table.append(-1) + table.append(0) + + while pos < len(word): + if word[pos - 1] == word[cnd]: + cnd = cnd + 1 + table.append(cnd) + pos = pos + 1 + elif cnd > 0: + cnd = table[cnd] + else: + table.append(0) + pos = pos + 1 + + return table + + +def find_subcircuit(circuit, subcircuit, start=0, end=0): + """Finds the subcircuit in circuit, if it exists. + + Explanation + =========== + + If the subcircuit exists, the index of the start of + the subcircuit in circuit is returned; otherwise, + -1 is returned. The algorithm that is implemented + is the Knuth-Morris-Pratt algorithm. + + Parameters + ========== + + circuit : tuple, Gate or Mul + A tuple of Gates or Mul representing a quantum circuit + subcircuit : tuple, Gate or Mul + A tuple of Gates or Mul to find in circuit + start : int + The location to start looking for subcircuit. + If start is the same or past end, -1 is returned. + end : int + The last place to look for a subcircuit. If end + is less than 1 (one), then the length of circuit + is taken to be end. + + Examples + ======== + + Find the first instance of a subcircuit: + + >>> from sympy.physics.quantum.circuitutils import find_subcircuit + >>> from sympy.physics.quantum.gate import X, Y, Z, H + >>> circuit = X(0)*Z(0)*Y(0)*H(0) + >>> subcircuit = Z(0)*Y(0) + >>> find_subcircuit(circuit, subcircuit) + 1 + + Find the first instance starting at a specific position: + + >>> find_subcircuit(circuit, subcircuit, start=1) + 1 + + >>> find_subcircuit(circuit, subcircuit, start=2) + -1 + + >>> circuit = circuit*subcircuit + >>> find_subcircuit(circuit, subcircuit, start=2) + 4 + + Find the subcircuit within some interval: + + >>> find_subcircuit(circuit, subcircuit, start=2, end=2) + -1 + """ + + if isinstance(circuit, Mul): + circuit = circuit.args + + if isinstance(subcircuit, Mul): + subcircuit = subcircuit.args + + if len(subcircuit) == 0 or len(subcircuit) > len(circuit): + return -1 + + if end < 1: + end = len(circuit) + + # Location in circuit + pos = start + # Location in the subcircuit + index = 0 + # 'Partial match' table + table = kmp_table(subcircuit) + + while (pos + index) < end: + if subcircuit[index] == circuit[pos + index]: + index = index + 1 + else: + pos = pos + index - table[index] + index = table[index] if table[index] > -1 else 0 + + if index == len(subcircuit): + return pos + + return -1 + + +def replace_subcircuit(circuit, subcircuit, replace=None, pos=0): + """Replaces a subcircuit with another subcircuit in circuit, + if it exists. + + Explanation + =========== + + If multiple instances of subcircuit exists, the first instance is + replaced. The position to being searching from (if different from + 0) may be optionally given. If subcircuit cannot be found, circuit + is returned. + + Parameters + ========== + + circuit : tuple, Gate or Mul + A quantum circuit. + subcircuit : tuple, Gate or Mul + The circuit to be replaced. + replace : tuple, Gate or Mul + The replacement circuit. + pos : int + The location to start search and replace + subcircuit, if it exists. This may be used + if it is known beforehand that multiple + instances exist, and it is desirable to + replace a specific instance. If a negative number + is given, pos will be defaulted to 0. + + Examples + ======== + + Find and remove the subcircuit: + + >>> from sympy.physics.quantum.circuitutils import replace_subcircuit + >>> from sympy.physics.quantum.gate import X, Y, Z, H + >>> circuit = X(0)*Z(0)*Y(0)*H(0)*X(0)*H(0)*Y(0) + >>> subcircuit = Z(0)*Y(0) + >>> replace_subcircuit(circuit, subcircuit) + (X(0), H(0), X(0), H(0), Y(0)) + + Remove the subcircuit given a starting search point: + + >>> replace_subcircuit(circuit, subcircuit, pos=1) + (X(0), H(0), X(0), H(0), Y(0)) + + >>> replace_subcircuit(circuit, subcircuit, pos=2) + (X(0), Z(0), Y(0), H(0), X(0), H(0), Y(0)) + + Replace the subcircuit: + + >>> replacement = H(0)*Z(0) + >>> replace_subcircuit(circuit, subcircuit, replace=replacement) + (X(0), H(0), Z(0), H(0), X(0), H(0), Y(0)) + """ + + if pos < 0: + pos = 0 + + if isinstance(circuit, Mul): + circuit = circuit.args + + if isinstance(subcircuit, Mul): + subcircuit = subcircuit.args + + if isinstance(replace, Mul): + replace = replace.args + elif replace is None: + replace = () + + # Look for the subcircuit starting at pos + loc = find_subcircuit(circuit, subcircuit, start=pos) + + # If subcircuit was found + if loc > -1: + # Get the gates to the left of subcircuit + left = circuit[0:loc] + # Get the gates to the right of subcircuit + right = circuit[loc + len(subcircuit):len(circuit)] + # Recombine the left and right side gates into a circuit + circuit = left + replace + right + + return circuit + + +def _sympify_qubit_map(mapping): + new_map = {} + for key in mapping: + new_map[key] = sympify(mapping[key]) + return new_map + + +def convert_to_symbolic_indices(seq, start=None, gen=None, qubit_map=None): + """Returns the circuit with symbolic indices and the + dictionary mapping symbolic indices to real indices. + + The mapping is 1 to 1 and onto (bijective). + + Parameters + ========== + + seq : tuple, Gate/Integer/tuple or Mul + A tuple of Gate, Integer, or tuple objects, or a Mul + start : Symbol + An optional starting symbolic index + gen : object + An optional numbered symbol generator + qubit_map : dict + An existing mapping of symbolic indices to real indices + + All symbolic indices have the format 'i#', where # is + some number >= 0. + """ + + if isinstance(seq, Mul): + seq = seq.args + + # A numbered symbol generator + index_gen = numbered_symbols(prefix='i', start=-1) + cur_ndx = next(index_gen) + + # keys are symbolic indices; values are real indices + ndx_map = {} + + def create_inverse_map(symb_to_real_map): + rev_items = lambda item: (item[1], item[0]) + return dict(map(rev_items, symb_to_real_map.items())) + + if start is not None: + if not isinstance(start, Symbol): + msg = 'Expected Symbol for starting index, got %r.' % start + raise TypeError(msg) + cur_ndx = start + + if gen is not None: + if not isinstance(gen, numbered_symbols().__class__): + msg = 'Expected a generator, got %r.' % gen + raise TypeError(msg) + index_gen = gen + + if qubit_map is not None: + if not isinstance(qubit_map, dict): + msg = ('Expected dict for existing map, got ' + + '%r.' % qubit_map) + raise TypeError(msg) + ndx_map = qubit_map + + ndx_map = _sympify_qubit_map(ndx_map) + # keys are real indices; keys are symbolic indices + inv_map = create_inverse_map(ndx_map) + + sym_seq = () + for item in seq: + # Nested items, so recurse + if isinstance(item, Gate): + result = convert_to_symbolic_indices(item.args, + qubit_map=ndx_map, + start=cur_ndx, + gen=index_gen) + sym_item, new_map, cur_ndx, index_gen = result + ndx_map.update(new_map) + inv_map = create_inverse_map(ndx_map) + + elif isinstance(item, (tuple, Tuple)): + result = convert_to_symbolic_indices(item, + qubit_map=ndx_map, + start=cur_ndx, + gen=index_gen) + sym_item, new_map, cur_ndx, index_gen = result + ndx_map.update(new_map) + inv_map = create_inverse_map(ndx_map) + + elif item in inv_map: + sym_item = inv_map[item] + + else: + cur_ndx = next(gen) + ndx_map[cur_ndx] = item + inv_map[item] = cur_ndx + sym_item = cur_ndx + + if isinstance(item, Gate): + sym_item = item.__class__(*sym_item) + + sym_seq = sym_seq + (sym_item,) + + return sym_seq, ndx_map, cur_ndx, index_gen + + +def convert_to_real_indices(seq, qubit_map): + """Returns the circuit with real indices. + + Parameters + ========== + + seq : tuple, Gate/Integer/tuple or Mul + A tuple of Gate, Integer, or tuple objects or a Mul + qubit_map : dict + A dictionary mapping symbolic indices to real indices. + + Examples + ======== + + Change the symbolic indices to real integers: + + >>> from sympy import symbols + >>> from sympy.physics.quantum.circuitutils import convert_to_real_indices + >>> from sympy.physics.quantum.gate import X, Y, H + >>> i0, i1 = symbols('i:2') + >>> index_map = {i0 : 0, i1 : 1} + >>> convert_to_real_indices(X(i0)*Y(i1)*H(i0)*X(i1), index_map) + (X(0), Y(1), H(0), X(1)) + """ + + if isinstance(seq, Mul): + seq = seq.args + + if not isinstance(qubit_map, dict): + msg = 'Expected dict for qubit_map, got %r.' % qubit_map + raise TypeError(msg) + + qubit_map = _sympify_qubit_map(qubit_map) + real_seq = () + for item in seq: + # Nested items, so recurse + if isinstance(item, Gate): + real_item = convert_to_real_indices(item.args, qubit_map) + + elif isinstance(item, (tuple, Tuple)): + real_item = convert_to_real_indices(item, qubit_map) + + else: + real_item = qubit_map[item] + + if isinstance(item, Gate): + real_item = item.__class__(*real_item) + + real_seq = real_seq + (real_item,) + + return real_seq + + +def random_reduce(circuit, gate_ids, seed=None): + """Shorten the length of a quantum circuit. + + Explanation + =========== + + random_reduce looks for circuit identities in circuit, randomly chooses + one to remove, and returns a shorter yet equivalent circuit. If no + identities are found, the same circuit is returned. + + Parameters + ========== + + circuit : Gate tuple of Mul + A tuple of Gates representing a quantum circuit + gate_ids : list, GateIdentity + List of gate identities to find in circuit + seed : int or list + seed used for _randrange; to override the random selection, provide a + list of integers: the elements of gate_ids will be tested in the order + given by the list + + """ + from sympy.core.random import _randrange + + if not gate_ids: + return circuit + + if isinstance(circuit, Mul): + circuit = circuit.args + + ids = flatten_ids(gate_ids) + + # Create the random integer generator with the seed + randrange = _randrange(seed) + + # Look for an identity in the circuit + while ids: + i = randrange(len(ids)) + id = ids.pop(i) + if find_subcircuit(circuit, id) != -1: + break + else: + # no identity was found + return circuit + + # return circuit with the identity removed + return replace_subcircuit(circuit, id) + + +def random_insert(circuit, choices, seed=None): + """Insert a circuit into another quantum circuit. + + Explanation + =========== + + random_insert randomly chooses a location in the circuit to insert + a randomly selected circuit from amongst the given choices. + + Parameters + ========== + + circuit : Gate tuple or Mul + A tuple or Mul of Gates representing a quantum circuit + choices : list + Set of circuit choices + seed : int or list + seed used for _randrange; to override the random selections, give + a list two integers, [i, j] where i is the circuit location where + choice[j] will be inserted. + + Notes + ===== + + Indices for insertion should be [0, n] if n is the length of the + circuit. + """ + from sympy.core.random import _randrange + + if not choices: + return circuit + + if isinstance(circuit, Mul): + circuit = circuit.args + + # get the location in the circuit and the element to insert from choices + randrange = _randrange(seed) + loc = randrange(len(circuit) + 1) + choice = choices[randrange(len(choices))] + + circuit = list(circuit) + circuit[loc: loc] = choice + return tuple(circuit) + +# Flatten the GateIdentity objects (with gate rules) into one single list + + +def flatten_ids(ids): + collapse = lambda acc, an_id: acc + sorted(an_id.equivalent_ids, + key=default_sort_key) + ids = reduce(collapse, ids, []) + ids.sort(key=default_sort_key) + return ids diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/commutator.py b/phivenv/Lib/site-packages/sympy/physics/quantum/commutator.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d97a679e27387077429a9973de21ad868e84ac --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/commutator.py @@ -0,0 +1,256 @@ +"""The commutator: [A,B] = A*B - B*A.""" + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.kind import KindDispatcher +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.printing.pretty.stringpict import prettyForm + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.kind import _OperatorKind, OperatorKind + + +__all__ = [ + 'Commutator' +] + +#----------------------------------------------------------------------------- +# Commutator +#----------------------------------------------------------------------------- + + +class Commutator(Expr): + """The standard commutator, in an unevaluated state. + + Explanation + =========== + + Evaluating a commutator is defined [1]_ as: ``[A, B] = A*B - B*A``. This + class returns the commutator in an unevaluated form. To evaluate the + commutator, use the ``.doit()`` method. + + Canonical ordering of a commutator is ``[A, B]`` for ``A < B``. The + arguments of the commutator are put into canonical order using ``__cmp__``. + If ``B < A``, then ``[B, A]`` is returned as ``-[A, B]``. + + Parameters + ========== + + A : Expr + The first argument of the commutator [A,B]. + B : Expr + The second argument of the commutator [A,B]. + + Examples + ======== + + >>> from sympy.physics.quantum import Commutator, Dagger, Operator + >>> from sympy.abc import x, y + >>> A = Operator('A') + >>> B = Operator('B') + >>> C = Operator('C') + + Create a commutator and use ``.doit()`` to evaluate it: + + >>> comm = Commutator(A, B) + >>> comm + [A,B] + >>> comm.doit() + A*B - B*A + + The commutator orders it arguments in canonical order: + + >>> comm = Commutator(B, A); comm + -[A,B] + + Commutative constants are factored out: + + >>> Commutator(3*x*A, x*y*B) + 3*x**2*y*[A,B] + + Using ``.expand(commutator=True)``, the standard commutator expansion rules + can be applied: + + >>> Commutator(A+B, C).expand(commutator=True) + [A,C] + [B,C] + >>> Commutator(A, B+C).expand(commutator=True) + [A,B] + [A,C] + >>> Commutator(A*B, C).expand(commutator=True) + [A,C]*B + A*[B,C] + >>> Commutator(A, B*C).expand(commutator=True) + [A,B]*C + B*[A,C] + + Adjoint operations applied to the commutator are properly applied to the + arguments: + + >>> Dagger(Commutator(A, B)) + -[Dagger(A),Dagger(B)] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Commutator + """ + is_commutative = False + + _kind_dispatcher = KindDispatcher("Commutator_kind_dispatcher", commutative=True) + + @property + def kind(self): + arg_kinds = (a.kind for a in self.args) + return self._kind_dispatcher(*arg_kinds) + + def __new__(cls, A, B): + r = cls.eval(A, B) + if r is not None: + return r + obj = Expr.__new__(cls, A, B) + return obj + + @classmethod + def eval(cls, a, b): + if not (a and b): + return S.Zero + if a == b: + return S.Zero + if a.is_commutative or b.is_commutative: + return S.Zero + + # [xA,yB] -> xy*[A,B] + ca, nca = a.args_cnc() + cb, ncb = b.args_cnc() + c_part = ca + cb + if c_part: + return Mul(Mul(*c_part), cls(Mul._from_args(nca), Mul._from_args(ncb))) + + # Canonical ordering of arguments + # The Commutator [A, B] is in canonical form if A < B. + if a.compare(b) == 1: + return S.NegativeOne*cls(b, a) + + def _expand_pow(self, A, B, sign): + exp = A.exp + if not exp.is_integer or not exp.is_constant() or abs(exp) <= 1: + # nothing to do + return self + base = A.base + if exp.is_negative: + base = A.base**-1 + exp = -exp + comm = Commutator(base, B).expand(commutator=True) + + result = base**(exp - 1) * comm + for i in range(1, exp): + result += base**(exp - 1 - i) * comm * base**i + return sign*result.expand() + + def _eval_expand_commutator(self, **hints): + A = self.args[0] + B = self.args[1] + + if isinstance(A, Add): + # [A + B, C] -> [A, C] + [B, C] + sargs = [] + for term in A.args: + comm = Commutator(term, B) + if isinstance(comm, Commutator): + comm = comm._eval_expand_commutator() + sargs.append(comm) + return Add(*sargs) + elif isinstance(B, Add): + # [A, B + C] -> [A, B] + [A, C] + sargs = [] + for term in B.args: + comm = Commutator(A, term) + if isinstance(comm, Commutator): + comm = comm._eval_expand_commutator() + sargs.append(comm) + return Add(*sargs) + elif isinstance(A, Mul): + # [A*B, C] -> A*[B, C] + [A, C]*B + a = A.args[0] + b = Mul(*A.args[1:]) + c = B + comm1 = Commutator(b, c) + comm2 = Commutator(a, c) + if isinstance(comm1, Commutator): + comm1 = comm1._eval_expand_commutator() + if isinstance(comm2, Commutator): + comm2 = comm2._eval_expand_commutator() + first = Mul(a, comm1) + second = Mul(comm2, b) + return Add(first, second) + elif isinstance(B, Mul): + # [A, B*C] -> [A, B]*C + B*[A, C] + a = A + b = B.args[0] + c = Mul(*B.args[1:]) + comm1 = Commutator(a, b) + comm2 = Commutator(a, c) + if isinstance(comm1, Commutator): + comm1 = comm1._eval_expand_commutator() + if isinstance(comm2, Commutator): + comm2 = comm2._eval_expand_commutator() + first = Mul(comm1, c) + second = Mul(b, comm2) + return Add(first, second) + elif isinstance(A, Pow): + # [A**n, C] -> A**(n - 1)*[A, C] + A**(n - 2)*[A, C]*A + ... + [A, C]*A**(n-1) + return self._expand_pow(A, B, 1) + elif isinstance(B, Pow): + # [A, C**n] -> C**(n - 1)*[C, A] + C**(n - 2)*[C, A]*C + ... + [C, A]*C**(n-1) + return self._expand_pow(B, A, -1) + + # No changes, so return self + return self + + def doit(self, **hints): + """ Evaluate commutator """ + # Keep the import of Operator here to avoid problems with + # circular imports. + from sympy.physics.quantum.operator import Operator + A = self.args[0] + B = self.args[1] + if isinstance(A, Operator) and isinstance(B, Operator): + try: + comm = A._eval_commutator(B, **hints) + except NotImplementedError: + try: + comm = -1*B._eval_commutator(A, **hints) + except NotImplementedError: + comm = None + if comm is not None: + return comm.doit(**hints) + return (A*B - B*A).doit(**hints) + + def _eval_adjoint(self): + return Commutator(Dagger(self.args[1]), Dagger(self.args[0])) + + def _sympyrepr(self, printer, *args): + return "%s(%s,%s)" % ( + self.__class__.__name__, printer._print( + self.args[0]), printer._print(self.args[1]) + ) + + def _sympystr(self, printer, *args): + return "[%s,%s]" % ( + printer._print(self.args[0]), printer._print(self.args[1])) + + def _pretty(self, printer, *args): + pform = printer._print(self.args[0], *args) + pform = prettyForm(*pform.right(prettyForm(','))) + pform = prettyForm(*pform.right(printer._print(self.args[1], *args))) + pform = prettyForm(*pform.parens(left='[', right=']')) + return pform + + def _latex(self, printer, *args): + return "\\left[%s,%s\\right]" % tuple([ + printer._print(arg, *args) for arg in self.args]) + + +@Commutator._kind_dispatcher.register(_OperatorKind, _OperatorKind) +def find_op_kind(e1, e2): + """Find the kind of an anticommutator of two OperatorKinds.""" + return OperatorKind diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/constants.py b/phivenv/Lib/site-packages/sympy/physics/quantum/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..3e848bf24e95e3bd612169128a1845202066c6e9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/constants.py @@ -0,0 +1,59 @@ +"""Constants (like hbar) related to quantum mechanics.""" + +from sympy.core.numbers import NumberSymbol +from sympy.core.singleton import Singleton +from sympy.printing.pretty.stringpict import prettyForm +import mpmath.libmp as mlib + +#----------------------------------------------------------------------------- +# Constants +#----------------------------------------------------------------------------- + +__all__ = [ + 'hbar', + 'HBar', +] + + +class HBar(NumberSymbol, metaclass=Singleton): + """Reduced Plank's constant in numerical and symbolic form [1]_. + + Examples + ======== + + >>> from sympy.physics.quantum.constants import hbar + >>> hbar.evalf() + 1.05457162000000e-34 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Planck_constant + """ + + is_real = True + is_positive = True + is_negative = False + is_irrational = True + + __slots__ = () + + def _as_mpf_val(self, prec): + return mlib.from_float(1.05457162e-34, prec) + + def _sympyrepr(self, printer, *args): + return 'HBar()' + + def _sympystr(self, printer, *args): + return 'hbar' + + def _pretty(self, printer, *args): + if printer._use_unicode: + return prettyForm('\N{PLANCK CONSTANT OVER TWO PI}') + return prettyForm('hbar') + + def _latex(self, printer, *args): + return r'\hbar' + +# Create an instance for everyone to use. +hbar = HBar() diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/dagger.py b/phivenv/Lib/site-packages/sympy/physics/quantum/dagger.py new file mode 100644 index 0000000000000000000000000000000000000000..f96f01e3b9ac86ae30b03e3b97293bbafceaed8a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/dagger.py @@ -0,0 +1,95 @@ +"""Hermitian conjugation.""" + +from sympy.core import Expr, sympify +from sympy.functions.elementary.complexes import adjoint + +__all__ = [ + 'Dagger' +] + + +class Dagger(adjoint): + """General Hermitian conjugate operation. + + Explanation + =========== + + Take the Hermetian conjugate of an argument [1]_. For matrices this + operation is equivalent to transpose and complex conjugate [2]_. + + Parameters + ========== + + arg : Expr + The SymPy expression that we want to take the dagger of. + evaluate : bool + Whether the resulting expression should be directly evaluated. + + Examples + ======== + + Daggering various quantum objects: + + >>> from sympy.physics.quantum.dagger import Dagger + >>> from sympy.physics.quantum.state import Ket, Bra + >>> from sympy.physics.quantum.operator import Operator + >>> Dagger(Ket('psi')) + >> Dagger(Bra('phi')) + |phi> + >>> Dagger(Operator('A')) + Dagger(A) + + Inner and outer products:: + + >>> from sympy.physics.quantum import InnerProduct, OuterProduct + >>> Dagger(InnerProduct(Bra('a'), Ket('b'))) + + >>> Dagger(OuterProduct(Ket('a'), Bra('b'))) + |b>>> A = Operator('A') + >>> B = Operator('B') + >>> Dagger(A*B) + Dagger(B)*Dagger(A) + >>> Dagger(A+B) + Dagger(A) + Dagger(B) + >>> Dagger(A**2) + Dagger(A)**2 + + Dagger also seamlessly handles complex numbers and matrices:: + + >>> from sympy import Matrix, I + >>> m = Matrix([[1,I],[2,I]]) + >>> m + Matrix([ + [1, I], + [2, I]]) + >>> Dagger(m) + Matrix([ + [ 1, 2], + [-I, -I]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hermitian_adjoint + .. [2] https://en.wikipedia.org/wiki/Hermitian_transpose + """ + + @property + def kind(self): + """Find the kind of a dagger of something (just the kind of the something).""" + return self.args[0].kind + + def __new__(cls, arg, evaluate=True): + if hasattr(arg, 'adjoint') and evaluate: + return arg.adjoint() + elif hasattr(arg, 'conjugate') and hasattr(arg, 'transpose') and evaluate: + return arg.conjugate().transpose() + return Expr.__new__(cls, sympify(arg)) + +adjoint.__name__ = "Dagger" +adjoint._sympyrepr = lambda a, b: "Dagger(%s)" % b._print(a.args[0]) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/density.py b/phivenv/Lib/site-packages/sympy/physics/quantum/density.py new file mode 100644 index 0000000000000000000000000000000000000000..941373e8105dd0c725626396dfd9cd794b19d3f5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/density.py @@ -0,0 +1,315 @@ +from itertools import product + +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.function import expand +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import log +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.printing.pretty.stringpict import prettyForm +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.operator import HermitianOperator +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.matrixutils import numpy_ndarray, scipy_sparse_matrix, to_numpy +from sympy.physics.quantum.trace import Tr + + +class Density(HermitianOperator): + """Density operator for representing mixed states. + + TODO: Density operator support for Qubits + + Parameters + ========== + + values : tuples/lists + Each tuple/list should be of form (state, prob) or [state,prob] + + Examples + ======== + + Create a density operator with 2 states represented by Kets. + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d + Density((|0>, 0.5),(|1>, 0.5)) + + """ + @classmethod + def _eval_args(cls, args): + # call this to qsympify the args + args = super()._eval_args(args) + + for arg in args: + # Check if arg is a tuple + if not (isinstance(arg, Tuple) and len(arg) == 2): + raise ValueError("Each argument should be of form [state,prob]" + " or ( state, prob )") + + return args + + def states(self): + """Return list of all states. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.states() + (|0>, |1>) + + """ + return Tuple(*[arg[0] for arg in self.args]) + + def probs(self): + """Return list of all probabilities. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.probs() + (0.5, 0.5) + + """ + return Tuple(*[arg[1] for arg in self.args]) + + def get_state(self, index): + """Return specific state by index. + + Parameters + ========== + + index : index of state to be returned + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.states()[1] + |1> + + """ + state = self.args[index][0] + return state + + def get_prob(self, index): + """Return probability of specific state by index. + + Parameters + =========== + + index : index of states whose probability is returned. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.probs()[1] + 0.500000000000000 + + """ + prob = self.args[index][1] + return prob + + def apply_op(self, op): + """op will operate on each individual state. + + Parameters + ========== + + op : Operator + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> from sympy.physics.quantum.operator import Operator + >>> A = Operator('A') + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.apply_op(A) + Density((A*|0>, 0.5),(A*|1>, 0.5)) + + """ + new_args = [(op*state, prob) for (state, prob) in self.args] + return Density(*new_args) + + def doit(self, **hints): + """Expand the density operator into an outer product format. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> from sympy.physics.quantum.operator import Operator + >>> A = Operator('A') + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.doit() + 0.5*|0><0| + 0.5*|1><1| + + """ + + terms = [] + for (state, prob) in self.args: + state = state.expand() # needed to break up (a+b)*c + if (isinstance(state, Add)): + for arg in product(state.args, repeat=2): + terms.append(prob*self._generate_outer_prod(arg[0], + arg[1])) + else: + terms.append(prob*self._generate_outer_prod(state, state)) + + return Add(*terms) + + def _generate_outer_prod(self, arg1, arg2): + c_part1, nc_part1 = arg1.args_cnc() + c_part2, nc_part2 = arg2.args_cnc() + + if (len(nc_part1) == 0 or len(nc_part2) == 0): + raise ValueError('Atleast one-pair of' + ' Non-commutative instance required' + ' for outer product.') + + # We were able to remove some tensor product simplifications that + # used to be here as those transformations are not automatically + # applied by transforms.py. + op = Mul(*nc_part1)*Dagger(Mul(*nc_part2)) + + return Mul(*c_part1)*Mul(*c_part2) * op + + def _represent(self, **options): + return represent(self.doit(), **options) + + def _print_operator_name_latex(self, printer, *args): + return r'\rho' + + def _print_operator_name_pretty(self, printer, *args): + return prettyForm('\N{GREEK SMALL LETTER RHO}') + + def _eval_trace(self, **kwargs): + indices = kwargs.get('indices', []) + return Tr(self.doit(), indices).doit() + + def entropy(self): + """ Compute the entropy of a density matrix. + + Refer to density.entropy() method for examples. + """ + return entropy(self) + + +def entropy(density): + """Compute the entropy of a matrix/density object. + + This computes -Tr(density*ln(density)) using the eigenvalue decomposition + of density, which is given as either a Density instance or a matrix + (numpy.ndarray, sympy.Matrix or scipy.sparse). + + Parameters + ========== + + density : density matrix of type Density, SymPy matrix, + scipy.sparse or numpy.ndarray + + Examples + ======== + + >>> from sympy.physics.quantum.density import Density, entropy + >>> from sympy.physics.quantum.spin import JzKet + >>> from sympy import S + >>> up = JzKet(S(1)/2,S(1)/2) + >>> down = JzKet(S(1)/2,-S(1)/2) + >>> d = Density((up,S(1)/2),(down,S(1)/2)) + >>> entropy(d) + log(2)/2 + + """ + if isinstance(density, Density): + density = represent(density) # represent in Matrix + + if isinstance(density, scipy_sparse_matrix): + density = to_numpy(density) + + if isinstance(density, Matrix): + eigvals = density.eigenvals().keys() + return expand(-sum(e*log(e) for e in eigvals)) + elif isinstance(density, numpy_ndarray): + import numpy as np + eigvals = np.linalg.eigvals(density) + return -np.sum(eigvals*np.log(eigvals)) + else: + raise ValueError( + "numpy.ndarray, scipy.sparse or SymPy matrix expected") + + +def fidelity(state1, state2): + """ Computes the fidelity [1]_ between two quantum states + + The arguments provided to this function should be a square matrix or a + Density object. If it is a square matrix, it is assumed to be diagonalizable. + + Parameters + ========== + + state1, state2 : a density matrix or Matrix + + + Examples + ======== + + >>> from sympy import S, sqrt + >>> from sympy.physics.quantum.dagger import Dagger + >>> from sympy.physics.quantum.spin import JzKet + >>> from sympy.physics.quantum.density import fidelity + >>> from sympy.physics.quantum.represent import represent + >>> + >>> up = JzKet(S(1)/2,S(1)/2) + >>> down = JzKet(S(1)/2,-S(1)/2) + >>> amp = 1/sqrt(2) + >>> updown = (amp*up) + (amp*down) + >>> + >>> # represent turns Kets into matrices + >>> up_dm = represent(up*Dagger(up)) + >>> down_dm = represent(down*Dagger(down)) + >>> updown_dm = represent(updown*Dagger(updown)) + >>> + >>> fidelity(up_dm, up_dm) + 1 + >>> fidelity(up_dm, down_dm) #orthogonal states + 0 + >>> fidelity(up_dm, updown_dm).evalf().round(3) + 0.707 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fidelity_of_quantum_states + + """ + state1 = represent(state1) if isinstance(state1, Density) else state1 + state2 = represent(state2) if isinstance(state2, Density) else state2 + + if not isinstance(state1, Matrix) or not isinstance(state2, Matrix): + raise ValueError("state1 and state2 must be of type Density or Matrix " + "received type=%s for state1 and type=%s for state2" % + (type(state1), type(state2))) + + if state1.shape != state2.shape and state1.is_square: + raise ValueError("The dimensions of both args should be equal and the " + "matrix obtained should be a square matrix") + + sqrt_state1 = state1**S.Half + return Tr((sqrt_state1*state2*sqrt_state1)**S.Half).doit() diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/fermion.py b/phivenv/Lib/site-packages/sympy/physics/quantum/fermion.py new file mode 100644 index 0000000000000000000000000000000000000000..8080bd3b0904b837652fdae7be0bd526da2d508f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/fermion.py @@ -0,0 +1,191 @@ +"""Fermionic quantum operators.""" + +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.physics.quantum import Operator +from sympy.physics.quantum import HilbertSpace, Ket, Bra +from sympy.functions.special.tensor_functions import KroneckerDelta + + +__all__ = [ + 'FermionOp', + 'FermionFockKet', + 'FermionFockBra' +] + + +class FermionOp(Operator): + """A fermionic operator that satisfies {c, Dagger(c)} == 1. + + Parameters + ========== + + name : str + A string that labels the fermionic mode. + + annihilation : bool + A bool that indicates if the fermionic operator is an annihilation + (True, default value) or creation operator (False) + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger, AntiCommutator + >>> from sympy.physics.quantum.fermion import FermionOp + >>> c = FermionOp("c") + >>> AntiCommutator(c, Dagger(c)).doit() + 1 + """ + @property + def name(self): + return self.args[0] + + @property + def is_annihilation(self): + return bool(self.args[1]) + + @classmethod + def default_args(self): + return ("c", True) + + def __new__(cls, *args, **hints): + if not len(args) in [1, 2]: + raise ValueError('1 or 2 parameters expected, got %s' % args) + + if len(args) == 1: + args = (args[0], S.One) + + if len(args) == 2: + args = (args[0], Integer(args[1])) + + return Operator.__new__(cls, *args) + + def _eval_commutator_FermionOp(self, other, **hints): + if 'independent' in hints and hints['independent']: + # [c, d] = 0 + return S.Zero + + return None + + def _eval_anticommutator_FermionOp(self, other, **hints): + if self.name == other.name: + # {a^\dagger, a} = 1 + if not self.is_annihilation and other.is_annihilation: + return S.One + + elif 'independent' in hints and hints['independent']: + # {c, d} = 2 * c * d, because [c, d] = 0 for independent operators + return 2 * self * other + + return None + + def _eval_anticommutator_BosonOp(self, other, **hints): + # because fermions and bosons commute + return 2 * self * other + + def _eval_commutator_BosonOp(self, other, **hints): + return S.Zero + + def _eval_adjoint(self): + return FermionOp(str(self.name), not self.is_annihilation) + + def _print_contents_latex(self, printer, *args): + if self.is_annihilation: + return r'{%s}' % str(self.name) + else: + return r'{{%s}^\dagger}' % str(self.name) + + def _print_contents(self, printer, *args): + if self.is_annihilation: + return r'%s' % str(self.name) + else: + return r'Dagger(%s)' % str(self.name) + + def _print_contents_pretty(self, printer, *args): + from sympy.printing.pretty.stringpict import prettyForm + pform = printer._print(self.args[0], *args) + if self.is_annihilation: + return pform + else: + return pform**prettyForm('\N{DAGGER}') + + def _eval_power(self, exp): + from sympy.core.singleton import S + if exp == 0: + return S.One + elif exp == 1: + return self + elif (exp > 1) == True and exp.is_integer == True: + return S.Zero + elif (exp < 0) == True or exp.is_integer == False: + raise ValueError("Fermionic operators can only be raised to a" + " positive integer power") + return Operator._eval_power(self, exp) + +class FermionFockKet(Ket): + """Fock state ket for a fermionic mode. + + Parameters + ========== + + n : Number + The Fock state number. + + """ + + def __new__(cls, n): + if n not in (0, 1): + raise ValueError("n must be 0 or 1") + return Ket.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return FermionFockBra + + @classmethod + def _eval_hilbert_space(cls, label): + return HilbertSpace() + + def _eval_innerproduct_FermionFockBra(self, bra, **hints): + return KroneckerDelta(self.n, bra.n) + + def _apply_from_right_to_FermionOp(self, op, **options): + if op.is_annihilation: + if self.n == 1: + return FermionFockKet(0) + else: + return S.Zero + else: + if self.n == 0: + return FermionFockKet(1) + else: + return S.Zero + + +class FermionFockBra(Bra): + """Fock state bra for a fermionic mode. + + Parameters + ========== + + n : Number + The Fock state number. + + """ + + def __new__(cls, n): + if n not in (0, 1): + raise ValueError("n must be 0 or 1") + return Bra.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return FermionFockKet diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/gate.py b/phivenv/Lib/site-packages/sympy/physics/quantum/gate.py new file mode 100644 index 0000000000000000000000000000000000000000..f8bcf5cd3611173cd9ebd6308dbbc896f5257f20 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/gate.py @@ -0,0 +1,1309 @@ +"""An implementation of gates that act on qubits. + +Gates are unitary operators that act on the space of qubits. + +Medium Term Todo: + +* Optimize Gate._apply_operators_Qubit to remove the creation of many + intermediate Qubit objects. +* Add commutation relationships to all operators and use this in gate_sort. +* Fix gate_sort and gate_simp. +* Get multi-target UGates plotting properly. +* Get UGate to work with either sympy/numpy matrices and output either + format. This should also use the matrix slots. +""" + +from itertools import chain +import random + +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer) +from sympy.core.power import Pow +from sympy.core.numbers import Number +from sympy.core.singleton import S as _S +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import _sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.printing.pretty.stringpict import prettyForm, stringPict + +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.qexpr import QuantumError +from sympy.physics.quantum.hilbert import ComplexSpace +from sympy.physics.quantum.operator import (UnitaryOperator, Operator, + HermitianOperator) +from sympy.physics.quantum.matrixutils import matrix_tensor_product, matrix_eye +from sympy.physics.quantum.matrixcache import matrix_cache + +from sympy.matrices.matrixbase import MatrixBase + +from sympy.utilities.iterables import is_sequence + +__all__ = [ + 'Gate', + 'CGate', + 'UGate', + 'OneQubitGate', + 'TwoQubitGate', + 'IdentityGate', + 'HadamardGate', + 'XGate', + 'YGate', + 'ZGate', + 'TGate', + 'PhaseGate', + 'SwapGate', + 'CNotGate', + # Aliased gate names + 'CNOT', + 'SWAP', + 'H', + 'X', + 'Y', + 'Z', + 'T', + 'S', + 'Phase', + 'normalized', + 'gate_sort', + 'gate_simp', + 'random_circuit', + 'CPHASE', + 'CGateS', +] + +#----------------------------------------------------------------------------- +# Gate Super-Classes +#----------------------------------------------------------------------------- + +_normalized = True + + +def _max(*args, **kwargs): + if "key" not in kwargs: + kwargs["key"] = default_sort_key + return max(*args, **kwargs) + + +def _min(*args, **kwargs): + if "key" not in kwargs: + kwargs["key"] = default_sort_key + return min(*args, **kwargs) + + +def normalized(normalize): + r"""Set flag controlling normalization of Hadamard gates by `1/\sqrt{2}`. + + This is a global setting that can be used to simplify the look of various + expressions, by leaving off the leading `1/\sqrt{2}` of the Hadamard gate. + + Parameters + ---------- + normalize : bool + Should the Hadamard gate include the `1/\sqrt{2}` normalization factor? + When True, the Hadamard gate will have the `1/\sqrt{2}`. When False, the + Hadamard gate will not have this factor. + """ + global _normalized + _normalized = normalize + + +def _validate_targets_controls(tandc): + tandc = list(tandc) + # Check for integers + for bit in tandc: + if not bit.is_Integer and not bit.is_Symbol: + raise TypeError('Integer expected, got: %r' % tandc[bit]) + # Detect duplicates + if len(set(tandc)) != len(tandc): + raise QuantumError( + 'Target/control qubits in a gate cannot be duplicated' + ) + + +class Gate(UnitaryOperator): + """Non-controlled unitary gate operator that acts on qubits. + + This is a general abstract gate that needs to be subclassed to do anything + useful. + + Parameters + ---------- + label : tuple, int + A list of the target qubits (as ints) that the gate will apply to. + + Examples + ======== + + + """ + + _label_separator = ',' + + gate_name = 'G' + gate_name_latex = 'G' + + #------------------------------------------------------------------------- + # Initialization/creation + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + args = Tuple(*UnitaryOperator._eval_args(args)) + _validate_targets_controls(args) + return args + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**(_max(args) + 1) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def nqubits(self): + """The total number of qubits this gate acts on. + + For controlled gate subclasses this includes both target and control + qubits, so that, for examples the CNOT gate acts on 2 qubits. + """ + return len(self.targets) + + @property + def min_qubits(self): + """The minimum number of qubits this gate needs to act on.""" + return _max(self.targets) + 1 + + @property + def targets(self): + """A tuple of target qubits.""" + return self.label + + @property + def gate_name_plot(self): + return r'$%s$' % self.gate_name_latex + + #------------------------------------------------------------------------- + # Gate methods + #------------------------------------------------------------------------- + + def get_target_matrix(self, format='sympy'): + """The matrix representation of the target part of the gate. + + Parameters + ---------- + format : str + The format string ('sympy','numpy', etc.) + """ + raise NotImplementedError( + 'get_target_matrix is not implemented in Gate.') + + #------------------------------------------------------------------------- + # Apply + #------------------------------------------------------------------------- + + def _apply_operator_IntQubit(self, qubits, **options): + """Redirect an apply from IntQubit to Qubit""" + return self._apply_operator_Qubit(qubits, **options) + + def _apply_operator_Qubit(self, qubits, **options): + """Apply this gate to a Qubit.""" + + # Check number of qubits this gate acts on. + if qubits.nqubits < self.min_qubits: + raise QuantumError( + 'Gate needs a minimum of %r qubits to act on, got: %r' % + (self.min_qubits, qubits.nqubits) + ) + + # If the controls are not met, just return + if isinstance(self, CGate): + if not self.eval_controls(qubits): + return qubits + + targets = self.targets + target_matrix = self.get_target_matrix(format='sympy') + + # Find which column of the target matrix this applies to. + column_index = 0 + n = 1 + for target in targets: + column_index += n*qubits[target] + n = n << 1 + column = target_matrix[:, int(column_index)] + + # Now apply each column element to the qubit. + result = 0 + for index in range(column.rows): + # TODO: This can be optimized to reduce the number of Qubit + # creations. We should simply manipulate the raw list of qubit + # values and then build the new Qubit object once. + # Make a copy of the incoming qubits. + new_qubit = qubits.__class__(*qubits.args) + # Flip the bits that need to be flipped. + for bit, target in enumerate(targets): + if new_qubit[target] != (index >> bit) & 1: + new_qubit = new_qubit.flip(target) + # The value in that row and column times the flipped-bit qubit + # is the result for that part. + result += column[index]*new_qubit + return result + + #------------------------------------------------------------------------- + # Represent + #------------------------------------------------------------------------- + + def _represent_default_basis(self, **options): + return self._represent_ZGate(None, **options) + + def _represent_ZGate(self, basis, **options): + format = options.get('format', 'sympy') + nqubits = options.get('nqubits', 0) + if nqubits == 0: + raise QuantumError( + 'The number of qubits must be given as nqubits.') + + # Make sure we have enough qubits for the gate. + if nqubits < self.min_qubits: + raise QuantumError( + 'The number of qubits %r is too small for the gate.' % nqubits + ) + + target_matrix = self.get_target_matrix(format) + targets = self.targets + if isinstance(self, CGate): + controls = self.controls + else: + controls = [] + m = represent_zbasis( + controls, targets, target_matrix, nqubits, format + ) + return m + + #------------------------------------------------------------------------- + # Print methods + #------------------------------------------------------------------------- + + def _sympystr(self, printer, *args): + label = self._print_label(printer, *args) + return '%s(%s)' % (self.gate_name, label) + + def _pretty(self, printer, *args): + a = stringPict(self.gate_name) + b = self._print_label_pretty(printer, *args) + return self._print_subscript_pretty(a, b) + + def _latex(self, printer, *args): + label = self._print_label(printer, *args) + return '%s_{%s}' % (self.gate_name_latex, label) + + def plot_gate(self, axes, gate_idx, gate_grid, wire_grid): + raise NotImplementedError('plot_gate is not implemented.') + + +class CGate(Gate): + """A general unitary gate with control qubits. + + A general control gate applies a target gate to a set of targets if all + of the control qubits have a particular values (set by + ``CGate.control_value``). + + Parameters + ---------- + label : tuple + The label in this case has the form (controls, gate), where controls + is a tuple/list of control qubits (as ints) and gate is a ``Gate`` + instance that is the target operator. + + Examples + ======== + + """ + + gate_name = 'C' + gate_name_latex = 'C' + + # The values this class controls for. + control_value = _S.One + + simplify_cgate = False + + #------------------------------------------------------------------------- + # Initialization + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + # _eval_args has the right logic for the controls argument. + controls = args[0] + gate = args[1] + if not is_sequence(controls): + controls = (controls,) + controls = UnitaryOperator._eval_args(controls) + _validate_targets_controls(chain(controls, gate.targets)) + return (Tuple(*controls), gate) + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**_max(_max(args[0]) + 1, args[1].min_qubits) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def nqubits(self): + """The total number of qubits this gate acts on. + + For controlled gate subclasses this includes both target and control + qubits, so that, for examples the CNOT gate acts on 2 qubits. + """ + return len(self.targets) + len(self.controls) + + @property + def min_qubits(self): + """The minimum number of qubits this gate needs to act on.""" + return _max(_max(self.controls), _max(self.targets)) + 1 + + @property + def targets(self): + """A tuple of target qubits.""" + return self.gate.targets + + @property + def controls(self): + """A tuple of control qubits.""" + return tuple(self.label[0]) + + @property + def gate(self): + """The non-controlled gate that will be applied to the targets.""" + return self.label[1] + + #------------------------------------------------------------------------- + # Gate methods + #------------------------------------------------------------------------- + + def get_target_matrix(self, format='sympy'): + return self.gate.get_target_matrix(format) + + def eval_controls(self, qubit): + """Return True/False to indicate if the controls are satisfied.""" + return all(qubit[bit] == self.control_value for bit in self.controls) + + def decompose(self, **options): + """Decompose the controlled gate into CNOT and single qubits gates.""" + if len(self.controls) == 1: + c = self.controls[0] + t = self.gate.targets[0] + if isinstance(self.gate, YGate): + g1 = PhaseGate(t) + g2 = CNotGate(c, t) + g3 = PhaseGate(t) + g4 = ZGate(t) + return g1*g2*g3*g4 + if isinstance(self.gate, ZGate): + g1 = HadamardGate(t) + g2 = CNotGate(c, t) + g3 = HadamardGate(t) + return g1*g2*g3 + else: + return self + + #------------------------------------------------------------------------- + # Print methods + #------------------------------------------------------------------------- + + def _print_label(self, printer, *args): + controls = self._print_sequence(self.controls, ',', printer, *args) + gate = printer._print(self.gate, *args) + return '(%s),%s' % (controls, gate) + + def _pretty(self, printer, *args): + controls = self._print_sequence_pretty( + self.controls, ',', printer, *args) + gate = printer._print(self.gate) + gate_name = stringPict(self.gate_name) + first = self._print_subscript_pretty(gate_name, controls) + gate = self._print_parens_pretty(gate) + final = prettyForm(*first.right(gate)) + return final + + def _latex(self, printer, *args): + controls = self._print_sequence(self.controls, ',', printer, *args) + gate = printer._print(self.gate, *args) + return r'%s_{%s}{\left(%s\right)}' % \ + (self.gate_name_latex, controls, gate) + + def plot_gate(self, circ_plot, gate_idx): + """ + Plot the controlled gate. If *simplify_cgate* is true, simplify + C-X and C-Z gates into their more familiar forms. + """ + min_wire = int(_min(chain(self.controls, self.targets))) + max_wire = int(_max(chain(self.controls, self.targets))) + circ_plot.control_line(gate_idx, min_wire, max_wire) + for c in self.controls: + circ_plot.control_point(gate_idx, int(c)) + if self.simplify_cgate: + if self.gate.gate_name == 'X': + self.gate.plot_gate_plus(circ_plot, gate_idx) + elif self.gate.gate_name == 'Z': + circ_plot.control_point(gate_idx, self.targets[0]) + else: + self.gate.plot_gate(circ_plot, gate_idx) + else: + self.gate.plot_gate(circ_plot, gate_idx) + + #------------------------------------------------------------------------- + # Miscellaneous + #------------------------------------------------------------------------- + + def _eval_dagger(self): + if isinstance(self.gate, HermitianOperator): + return self + else: + return Gate._eval_dagger(self) + + def _eval_inverse(self): + if isinstance(self.gate, HermitianOperator): + return self + else: + return Gate._eval_inverse(self) + + def _eval_power(self, exp): + if isinstance(self.gate, HermitianOperator): + if exp == -1: + return Gate._eval_power(self, exp) + elif abs(exp) % 2 == 0: + return self*(Gate._eval_inverse(self)) + else: + return self + else: + return Gate._eval_power(self, exp) + +class CGateS(CGate): + """Version of CGate that allows gate simplifications. + I.e. cnot looks like an oplus, cphase has dots, etc. + """ + simplify_cgate=True + + +class UGate(Gate): + """General gate specified by a set of targets and a target matrix. + + Parameters + ---------- + label : tuple + A tuple of the form (targets, U), where targets is a tuple of the + target qubits and U is a unitary matrix with dimension of + len(targets). + """ + gate_name = 'U' + gate_name_latex = 'U' + + #------------------------------------------------------------------------- + # Initialization + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + targets = args[0] + if not is_sequence(targets): + targets = (targets,) + targets = Gate._eval_args(targets) + _validate_targets_controls(targets) + mat = args[1] + if not isinstance(mat, MatrixBase): + raise TypeError('Matrix expected, got: %r' % mat) + #make sure this matrix is of a Basic type + mat = _sympify(mat) + dim = 2**len(targets) + if not all(dim == shape for shape in mat.shape): + raise IndexError( + 'Number of targets must match the matrix size: %r %r' % + (targets, mat) + ) + return (targets, mat) + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**(_max(args[0]) + 1) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def targets(self): + """A tuple of target qubits.""" + return tuple(self.label[0]) + + #------------------------------------------------------------------------- + # Gate methods + #------------------------------------------------------------------------- + + def get_target_matrix(self, format='sympy'): + """The matrix rep. of the target part of the gate. + + Parameters + ---------- + format : str + The format string ('sympy','numpy', etc.) + """ + return self.label[1] + + #------------------------------------------------------------------------- + # Print methods + #------------------------------------------------------------------------- + def _pretty(self, printer, *args): + targets = self._print_sequence_pretty( + self.targets, ',', printer, *args) + gate_name = stringPict(self.gate_name) + return self._print_subscript_pretty(gate_name, targets) + + def _latex(self, printer, *args): + targets = self._print_sequence(self.targets, ',', printer, *args) + return r'%s_{%s}' % (self.gate_name_latex, targets) + + def plot_gate(self, circ_plot, gate_idx): + circ_plot.one_qubit_box( + self.gate_name_plot, + gate_idx, int(self.targets[0]) + ) + + +class OneQubitGate(Gate): + """A single qubit unitary gate base class.""" + + nqubits = _S.One + + def plot_gate(self, circ_plot, gate_idx): + circ_plot.one_qubit_box( + self.gate_name_plot, + gate_idx, int(self.targets[0]) + ) + + def _eval_commutator(self, other, **hints): + if isinstance(other, OneQubitGate): + if self.targets != other.targets or self.__class__ == other.__class__: + return _S.Zero + return Operator._eval_commutator(self, other, **hints) + + def _eval_anticommutator(self, other, **hints): + if isinstance(other, OneQubitGate): + if self.targets != other.targets or self.__class__ == other.__class__: + return Integer(2)*self*other + return Operator._eval_anticommutator(self, other, **hints) + + +class TwoQubitGate(Gate): + """A two qubit unitary gate base class.""" + + nqubits = Integer(2) + +#----------------------------------------------------------------------------- +# Single Qubit Gates +#----------------------------------------------------------------------------- + + +class IdentityGate(OneQubitGate): + """The single qubit identity gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + is_hermitian = True + gate_name = '1' + gate_name_latex = '1' + + # Short cut version of gate._apply_operator_Qubit + def _apply_operator_Qubit(self, qubits, **options): + # Check number of qubits this gate acts on (see gate._apply_operator_Qubit) + if qubits.nqubits < self.min_qubits: + raise QuantumError( + 'Gate needs a minimum of %r qubits to act on, got: %r' % + (self.min_qubits, qubits.nqubits) + ) + return qubits # no computation required for IdentityGate + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('eye2', format) + + def _eval_commutator(self, other, **hints): + return _S.Zero + + def _eval_anticommutator(self, other, **hints): + return Integer(2)*other + + +class HadamardGate(HermitianOperator, OneQubitGate): + """The single qubit Hadamard gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.physics.quantum.qubit import Qubit + >>> from sympy.physics.quantum.gate import HadamardGate + >>> from sympy.physics.quantum.qapply import qapply + >>> qapply(HadamardGate(0)*Qubit('1')) + sqrt(2)*|0>/2 - sqrt(2)*|1>/2 + >>> # Hadamard on bell state, applied on 2 qubits. + >>> psi = 1/sqrt(2)*(Qubit('00')+Qubit('11')) + >>> qapply(HadamardGate(0)*HadamardGate(1)*psi) + sqrt(2)*|00>/2 + sqrt(2)*|11>/2 + + """ + gate_name = 'H' + gate_name_latex = 'H' + + def get_target_matrix(self, format='sympy'): + if _normalized: + return matrix_cache.get_matrix('H', format) + else: + return matrix_cache.get_matrix('Hsqrt2', format) + + def _eval_commutator_XGate(self, other, **hints): + return I*sqrt(2)*YGate(self.targets[0]) + + def _eval_commutator_YGate(self, other, **hints): + return I*sqrt(2)*(ZGate(self.targets[0]) - XGate(self.targets[0])) + + def _eval_commutator_ZGate(self, other, **hints): + return -I*sqrt(2)*YGate(self.targets[0]) + + def _eval_anticommutator_XGate(self, other, **hints): + return sqrt(2)*IdentityGate(self.targets[0]) + + def _eval_anticommutator_YGate(self, other, **hints): + return _S.Zero + + def _eval_anticommutator_ZGate(self, other, **hints): + return sqrt(2)*IdentityGate(self.targets[0]) + + +class XGate(HermitianOperator, OneQubitGate): + """The single qubit X, or NOT, gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + gate_name = 'X' + gate_name_latex = 'X' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('X', format) + + def plot_gate(self, circ_plot, gate_idx): + OneQubitGate.plot_gate(self,circ_plot,gate_idx) + + def plot_gate_plus(self, circ_plot, gate_idx): + circ_plot.not_point( + gate_idx, int(self.label[0]) + ) + + def _eval_commutator_YGate(self, other, **hints): + return Integer(2)*I*ZGate(self.targets[0]) + + def _eval_anticommutator_XGate(self, other, **hints): + return Integer(2)*IdentityGate(self.targets[0]) + + def _eval_anticommutator_YGate(self, other, **hints): + return _S.Zero + + def _eval_anticommutator_ZGate(self, other, **hints): + return _S.Zero + + +class YGate(HermitianOperator, OneQubitGate): + """The single qubit Y gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + gate_name = 'Y' + gate_name_latex = 'Y' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('Y', format) + + def _eval_commutator_ZGate(self, other, **hints): + return Integer(2)*I*XGate(self.targets[0]) + + def _eval_anticommutator_YGate(self, other, **hints): + return Integer(2)*IdentityGate(self.targets[0]) + + def _eval_anticommutator_ZGate(self, other, **hints): + return _S.Zero + + +class ZGate(HermitianOperator, OneQubitGate): + """The single qubit Z gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + gate_name = 'Z' + gate_name_latex = 'Z' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('Z', format) + + def _eval_commutator_XGate(self, other, **hints): + return Integer(2)*I*YGate(self.targets[0]) + + def _eval_anticommutator_YGate(self, other, **hints): + return _S.Zero + + +class PhaseGate(OneQubitGate): + """The single qubit phase, or S, gate. + + This gate rotates the phase of the state by pi/2 if the state is ``|1>`` and + does nothing if the state is ``|0>``. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + is_hermitian = False + gate_name = 'S' + gate_name_latex = 'S' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('S', format) + + def _eval_commutator_ZGate(self, other, **hints): + return _S.Zero + + def _eval_commutator_TGate(self, other, **hints): + return _S.Zero + + +class TGate(OneQubitGate): + """The single qubit pi/8 gate. + + This gate rotates the phase of the state by pi/4 if the state is ``|1>`` and + does nothing if the state is ``|0>``. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + is_hermitian = False + gate_name = 'T' + gate_name_latex = 'T' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('T', format) + + def _eval_commutator_ZGate(self, other, **hints): + return _S.Zero + + def _eval_commutator_PhaseGate(self, other, **hints): + return _S.Zero + + +# Aliases for gate names. +H = HadamardGate +X = XGate +Y = YGate +Z = ZGate +T = TGate +Phase = S = PhaseGate + + +#----------------------------------------------------------------------------- +# 2 Qubit Gates +#----------------------------------------------------------------------------- + + +class CNotGate(HermitianOperator, CGate, TwoQubitGate): + """Two qubit controlled-NOT. + + This gate performs the NOT or X gate on the target qubit if the control + qubits all have the value 1. + + Parameters + ---------- + label : tuple + A tuple of the form (control, target). + + Examples + ======== + + >>> from sympy.physics.quantum.gate import CNOT + >>> from sympy.physics.quantum.qapply import qapply + >>> from sympy.physics.quantum.qubit import Qubit + >>> c = CNOT(1,0) + >>> qapply(c*Qubit('10')) # note that qubits are indexed from right to left + |11> + + """ + gate_name = 'CNOT' + gate_name_latex = r'\text{CNOT}' + simplify_cgate = True + + #------------------------------------------------------------------------- + # Initialization + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + args = Gate._eval_args(args) + return args + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**(_max(args) + 1) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def min_qubits(self): + """The minimum number of qubits this gate needs to act on.""" + return _max(self.label) + 1 + + @property + def targets(self): + """A tuple of target qubits.""" + return (self.label[1],) + + @property + def controls(self): + """A tuple of control qubits.""" + return (self.label[0],) + + @property + def gate(self): + """The non-controlled gate that will be applied to the targets.""" + return XGate(self.label[1]) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + # The default printing of Gate works better than those of CGate, so we + # go around the overridden methods in CGate. + + def _print_label(self, printer, *args): + return Gate._print_label(self, printer, *args) + + def _pretty(self, printer, *args): + return Gate._pretty(self, printer, *args) + + def _latex(self, printer, *args): + return Gate._latex(self, printer, *args) + + #------------------------------------------------------------------------- + # Commutator/AntiCommutator + #------------------------------------------------------------------------- + + def _eval_commutator_ZGate(self, other, **hints): + """[CNOT(i, j), Z(i)] == 0.""" + if self.controls[0] == other.targets[0]: + return _S.Zero + else: + raise NotImplementedError('Commutator not implemented: %r' % other) + + def _eval_commutator_TGate(self, other, **hints): + """[CNOT(i, j), T(i)] == 0.""" + return self._eval_commutator_ZGate(other, **hints) + + def _eval_commutator_PhaseGate(self, other, **hints): + """[CNOT(i, j), S(i)] == 0.""" + return self._eval_commutator_ZGate(other, **hints) + + def _eval_commutator_XGate(self, other, **hints): + """[CNOT(i, j), X(j)] == 0.""" + if self.targets[0] == other.targets[0]: + return _S.Zero + else: + raise NotImplementedError('Commutator not implemented: %r' % other) + + def _eval_commutator_CNotGate(self, other, **hints): + """[CNOT(i, j), CNOT(i,k)] == 0.""" + if self.controls[0] == other.controls[0]: + return _S.Zero + else: + raise NotImplementedError('Commutator not implemented: %r' % other) + + +class SwapGate(TwoQubitGate): + """Two qubit SWAP gate. + + This gate swap the values of the two qubits. + + Parameters + ---------- + label : tuple + A tuple of the form (target1, target2). + + Examples + ======== + + """ + is_hermitian = True + gate_name = 'SWAP' + gate_name_latex = r'\text{SWAP}' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('SWAP', format) + + def decompose(self, **options): + """Decompose the SWAP gate into CNOT gates.""" + i, j = self.targets[0], self.targets[1] + g1 = CNotGate(i, j) + g2 = CNotGate(j, i) + return g1*g2*g1 + + def plot_gate(self, circ_plot, gate_idx): + min_wire = int(_min(self.targets)) + max_wire = int(_max(self.targets)) + circ_plot.control_line(gate_idx, min_wire, max_wire) + circ_plot.swap_point(gate_idx, min_wire) + circ_plot.swap_point(gate_idx, max_wire) + + def _represent_ZGate(self, basis, **options): + """Represent the SWAP gate in the computational basis. + + The following representation is used to compute this: + + SWAP = |1><1|x|1><1| + |0><0|x|0><0| + |1><0|x|0><1| + |0><1|x|1><0| + """ + format = options.get('format', 'sympy') + targets = [int(t) for t in self.targets] + min_target = _min(targets) + max_target = _max(targets) + nqubits = options.get('nqubits', self.min_qubits) + + op01 = matrix_cache.get_matrix('op01', format) + op10 = matrix_cache.get_matrix('op10', format) + op11 = matrix_cache.get_matrix('op11', format) + op00 = matrix_cache.get_matrix('op00', format) + eye2 = matrix_cache.get_matrix('eye2', format) + + result = None + for i, j in ((op01, op10), (op10, op01), (op00, op00), (op11, op11)): + product = nqubits*[eye2] + product[nqubits - min_target - 1] = i + product[nqubits - max_target - 1] = j + new_result = matrix_tensor_product(*product) + if result is None: + result = new_result + else: + result = result + new_result + + return result + + +# Aliases for gate names. +CNOT = CNotGate +SWAP = SwapGate +def CPHASE(a,b): return CGateS((a,),Z(b)) + + +#----------------------------------------------------------------------------- +# Represent +#----------------------------------------------------------------------------- + + +def represent_zbasis(controls, targets, target_matrix, nqubits, format='sympy'): + """Represent a gate with controls, targets and target_matrix. + + This function does the low-level work of representing gates as matrices + in the standard computational basis (ZGate). Currently, we support two + main cases: + + 1. One target qubit and no control qubits. + 2. One target qubits and multiple control qubits. + + For the base of multiple controls, we use the following expression [1]: + + 1_{2**n} + (|1><1|)^{(n-1)} x (target-matrix - 1_{2}) + + Parameters + ---------- + controls : list, tuple + A sequence of control qubits. + targets : list, tuple + A sequence of target qubits. + target_matrix : sympy.Matrix, numpy.matrix, scipy.sparse + The matrix form of the transformation to be performed on the target + qubits. The format of this matrix must match that passed into + the `format` argument. + nqubits : int + The total number of qubits used for the representation. + format : str + The format of the final matrix ('sympy', 'numpy', 'scipy.sparse'). + + Examples + ======== + + References + ---------- + [1] http://www.johnlapeyre.com/qinf/qinf_html/node6.html. + """ + controls = [int(x) for x in controls] + targets = [int(x) for x in targets] + nqubits = int(nqubits) + + # This checks for the format as well. + op11 = matrix_cache.get_matrix('op11', format) + eye2 = matrix_cache.get_matrix('eye2', format) + + # Plain single qubit case + if len(controls) == 0 and len(targets) == 1: + product = [] + bit = targets[0] + # Fill product with [I1,Gate,I2] such that the unitaries, + # I, cause the gate to be applied to the correct Qubit + if bit != nqubits - 1: + product.append(matrix_eye(2**(nqubits - bit - 1), format=format)) + product.append(target_matrix) + if bit != 0: + product.append(matrix_eye(2**bit, format=format)) + return matrix_tensor_product(*product) + + # Single target, multiple controls. + elif len(targets) == 1 and len(controls) >= 1: + target = targets[0] + + # Build the non-trivial part. + product2 = [] + for i in range(nqubits): + product2.append(matrix_eye(2, format=format)) + for control in controls: + product2[nqubits - 1 - control] = op11 + product2[nqubits - 1 - target] = target_matrix - eye2 + + return matrix_eye(2**nqubits, format=format) + \ + matrix_tensor_product(*product2) + + # Multi-target, multi-control is not yet implemented. + else: + raise NotImplementedError( + 'The representation of multi-target, multi-control gates ' + 'is not implemented.' + ) + + +#----------------------------------------------------------------------------- +# Gate manipulation functions. +#----------------------------------------------------------------------------- + + +def gate_simp(circuit): + """Simplifies gates symbolically + + It first sorts gates using gate_sort. It then applies basic + simplification rules to the circuit, e.g., XGate**2 = Identity + """ + + # Bubble sort out gates that commute. + circuit = gate_sort(circuit) + + # Do simplifications by subing a simplification into the first element + # which can be simplified. We recursively call gate_simp with new circuit + # as input more simplifications exist. + if isinstance(circuit, Add): + return sum(gate_simp(t) for t in circuit.args) + elif isinstance(circuit, Mul): + circuit_args = circuit.args + elif isinstance(circuit, Pow): + b, e = circuit.as_base_exp() + circuit_args = (gate_simp(b)**e,) + else: + return circuit + + # Iterate through each element in circuit, simplify if possible. + for i in range(len(circuit_args)): + # H,X,Y or Z squared is 1. + # T**2 = S, S**2 = Z + if isinstance(circuit_args[i], Pow): + if isinstance(circuit_args[i].base, + (HadamardGate, XGate, YGate, ZGate)) \ + and isinstance(circuit_args[i].exp, Number): + # Build a new circuit taking replacing the + # H,X,Y,Z squared with one. + newargs = (circuit_args[:i] + + (circuit_args[i].base**(circuit_args[i].exp % 2),) + + circuit_args[i + 1:]) + # Recursively simplify the new circuit. + circuit = gate_simp(Mul(*newargs)) + break + elif isinstance(circuit_args[i].base, PhaseGate): + # Build a new circuit taking old circuit but splicing + # in simplification. + newargs = circuit_args[:i] + # Replace PhaseGate**2 with ZGate. + newargs = newargs + (ZGate(circuit_args[i].base.args[0])** + (Integer(circuit_args[i].exp/2)), circuit_args[i].base** + (circuit_args[i].exp % 2)) + # Append the last elements. + newargs = newargs + circuit_args[i + 1:] + # Recursively simplify the new circuit. + circuit = gate_simp(Mul(*newargs)) + break + elif isinstance(circuit_args[i].base, TGate): + # Build a new circuit taking all the old elements. + newargs = circuit_args[:i] + + # Put an Phasegate in place of any TGate**2. + newargs = newargs + (PhaseGate(circuit_args[i].base.args[0])** + Integer(circuit_args[i].exp/2), circuit_args[i].base** + (circuit_args[i].exp % 2)) + + # Append the last elements. + newargs = newargs + circuit_args[i + 1:] + # Recursively simplify the new circuit. + circuit = gate_simp(Mul(*newargs)) + break + return circuit + + +def gate_sort(circuit): + """Sorts the gates while keeping track of commutation relations + + This function uses a bubble sort to rearrange the order of gate + application. Keeps track of Quantum computations special commutation + relations (e.g. things that apply to the same Qubit do not commute with + each other) + + circuit is the Mul of gates that are to be sorted. + """ + # Make sure we have an Add or Mul. + if isinstance(circuit, Add): + return sum(gate_sort(t) for t in circuit.args) + if isinstance(circuit, Pow): + return gate_sort(circuit.base)**circuit.exp + elif isinstance(circuit, Gate): + return circuit + if not isinstance(circuit, Mul): + return circuit + + changes = True + while changes: + changes = False + circ_array = circuit.args + for i in range(len(circ_array) - 1): + # Go through each element and switch ones that are in wrong order + if isinstance(circ_array[i], (Gate, Pow)) and \ + isinstance(circ_array[i + 1], (Gate, Pow)): + # If we have a Pow object, look at only the base + first_base, first_exp = circ_array[i].as_base_exp() + second_base, second_exp = circ_array[i + 1].as_base_exp() + + # Use SymPy's hash based sorting. This is not mathematical + # sorting, but is rather based on comparing hashes of objects. + # See Basic.compare for details. + if first_base.compare(second_base) > 0: + if Commutator(first_base, second_base).doit() == 0: + new_args = (circuit.args[:i] + (circuit.args[i + 1],) + + (circuit.args[i],) + circuit.args[i + 2:]) + circuit = Mul(*new_args) + changes = True + break + if AntiCommutator(first_base, second_base).doit() == 0: + new_args = (circuit.args[:i] + (circuit.args[i + 1],) + + (circuit.args[i],) + circuit.args[i + 2:]) + sign = _S.NegativeOne**(first_exp*second_exp) + circuit = sign*Mul(*new_args) + changes = True + break + return circuit + + +#----------------------------------------------------------------------------- +# Utility functions +#----------------------------------------------------------------------------- + + +def random_circuit(ngates, nqubits, gate_space=(X, Y, Z, S, T, H, CNOT, SWAP)): + """Return a random circuit of ngates and nqubits. + + This uses an equally weighted sample of (X, Y, Z, S, T, H, CNOT, SWAP) + gates. + + Parameters + ---------- + ngates : int + The number of gates in the circuit. + nqubits : int + The number of qubits in the circuit. + gate_space : tuple + A tuple of the gate classes that will be used in the circuit. + Repeating gate classes multiple times in this tuple will increase + the frequency they appear in the random circuit. + """ + qubit_space = range(nqubits) + result = [] + for i in range(ngates): + g = random.choice(gate_space) + if g == CNotGate or g == SwapGate: + qubits = random.sample(qubit_space, 2) + g = g(*qubits) + else: + qubit = random.choice(qubit_space) + g = g(qubit) + result.append(g) + return Mul(*result) + + +def zx_basis_transform(self, format='sympy'): + """Transformation matrix from Z to X basis.""" + return matrix_cache.get_matrix('ZX', format) + + +def zy_basis_transform(self, format='sympy'): + """Transformation matrix from Z to Y basis.""" + return matrix_cache.get_matrix('ZY', format) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/grover.py b/phivenv/Lib/site-packages/sympy/physics/quantum/grover.py new file mode 100644 index 0000000000000000000000000000000000000000..a03bd3a61a6e0960ab66d55bcc0fc7f25936199e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/grover.py @@ -0,0 +1,345 @@ +"""Grover's algorithm and helper functions. + +Todo: + +* W gate construction (or perhaps -W gate based on Mermin's book) +* Generalize the algorithm for an unknown function that returns 1 on multiple + qubit states, not just one. +* Implement _represent_ZGate in OracleGate +""" + +from sympy.core.numbers import pi +from sympy.core.sympify import sympify +from sympy.core.basic import Atom +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import eye +from sympy.core.numbers import NegativeOne +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qexpr import QuantumError +from sympy.physics.quantum.hilbert import ComplexSpace +from sympy.physics.quantum.operator import UnitaryOperator +from sympy.physics.quantum.gate import Gate +from sympy.physics.quantum.qubit import IntQubit + +__all__ = [ + 'OracleGate', + 'WGate', + 'superposition_basis', + 'grover_iteration', + 'apply_grover' +] + + +def superposition_basis(nqubits): + """Creates an equal superposition of the computational basis. + + Parameters + ========== + + nqubits : int + The number of qubits. + + Returns + ======= + + state : Qubit + An equal superposition of the computational basis with nqubits. + + Examples + ======== + + Create an equal superposition of 2 qubits:: + + >>> from sympy.physics.quantum.grover import superposition_basis + >>> superposition_basis(2) + |0>/2 + |1>/2 + |2>/2 + |3>/2 + """ + + amp = 1/sqrt(2**nqubits) + return sum(amp*IntQubit(n, nqubits=nqubits) for n in range(2**nqubits)) + +class OracleGateFunction(Atom): + """Wrapper for python functions used in `OracleGate`s""" + + def __new__(cls, function): + if not callable(function): + raise TypeError('Callable expected, got: %r' % function) + obj = Atom.__new__(cls) + obj.function = function + return obj + + def _hashable_content(self): + return type(self), self.function + + def __call__(self, *args): + return self.function(*args) + + +class OracleGate(Gate): + """A black box gate. + + The gate marks the desired qubits of an unknown function by flipping + the sign of the qubits. The unknown function returns true when it + finds its desired qubits and false otherwise. + + Parameters + ========== + + qubits : int + Number of qubits. + + oracle : callable + A callable function that returns a boolean on a computational basis. + + Examples + ======== + + Apply an Oracle gate that flips the sign of ``|2>`` on different qubits:: + + >>> from sympy.physics.quantum.qubit import IntQubit + >>> from sympy.physics.quantum.qapply import qapply + >>> from sympy.physics.quantum.grover import OracleGate + >>> f = lambda qubits: qubits == IntQubit(2) + >>> v = OracleGate(2, f) + >>> qapply(v*IntQubit(2)) + -|2> + >>> qapply(v*IntQubit(3)) + |3> + """ + + gate_name = 'V' + gate_name_latex = 'V' + + #------------------------------------------------------------------------- + # Initialization/creation + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + if len(args) != 2: + raise QuantumError( + 'Insufficient/excessive arguments to Oracle. Please ' + + 'supply the number of qubits and an unknown function.' + ) + sub_args = (args[0],) + sub_args = UnitaryOperator._eval_args(sub_args) + if not sub_args[0].is_Integer: + raise TypeError('Integer expected, got: %r' % sub_args[0]) + + function = args[1] + if not isinstance(function, OracleGateFunction): + function = OracleGateFunction(function) + + return (sub_args[0], function) + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**args[0] + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def search_function(self): + """The unknown function that helps find the sought after qubits.""" + return self.label[1] + + @property + def targets(self): + """A tuple of target qubits.""" + return sympify(tuple(range(self.args[0]))) + + #------------------------------------------------------------------------- + # Apply + #------------------------------------------------------------------------- + + def _apply_operator_Qubit(self, qubits, **options): + """Apply this operator to a Qubit subclass. + + Parameters + ========== + + qubits : Qubit + The qubit subclass to apply this operator to. + + Returns + ======= + + state : Expr + The resulting quantum state. + """ + if qubits.nqubits != self.nqubits: + raise QuantumError( + 'OracleGate operates on %r qubits, got: %r' + % (self.nqubits, qubits.nqubits) + ) + # If function returns 1 on qubits + # return the negative of the qubits (flip the sign) + if self.search_function(qubits): + return -qubits + else: + return qubits + + #------------------------------------------------------------------------- + # Represent + #------------------------------------------------------------------------- + + def _represent_ZGate(self, basis, **options): + """ + Represent the OracleGate in the computational basis. + """ + nbasis = 2**self.nqubits # compute it only once + matrixOracle = eye(nbasis) + # Flip the sign given the output of the oracle function + for i in range(nbasis): + if self.search_function(IntQubit(i, nqubits=self.nqubits)): + matrixOracle[i, i] = NegativeOne() + return matrixOracle + + +class WGate(Gate): + """General n qubit W Gate in Grover's algorithm. + + The gate performs the operation ``2|phi> = (tensor product of n Hadamards)*(|0> with n qubits)`` + + Parameters + ========== + + nqubits : int + The number of qubits to operate on + + """ + + gate_name = 'W' + gate_name_latex = 'W' + + @classmethod + def _eval_args(cls, args): + if len(args) != 1: + raise QuantumError( + 'Insufficient/excessive arguments to W gate. Please ' + + 'supply the number of qubits to operate on.' + ) + args = UnitaryOperator._eval_args(args) + if not args[0].is_Integer: + raise TypeError('Integer expected, got: %r' % args[0]) + return args + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def targets(self): + return sympify(tuple(reversed(range(self.args[0])))) + + #------------------------------------------------------------------------- + # Apply + #------------------------------------------------------------------------- + + def _apply_operator_Qubit(self, qubits, **options): + """ + qubits: a set of qubits (Qubit) + Returns: quantum object (quantum expression - QExpr) + """ + if qubits.nqubits != self.nqubits: + raise QuantumError( + 'WGate operates on %r qubits, got: %r' + % (self.nqubits, qubits.nqubits) + ) + + # See 'Quantum Computer Science' by David Mermin p.92 -> W|a> result + # Return (2/(sqrt(2^n)))|phi> - |a> where |a> is the current basis + # state and phi is the superposition of basis states (see function + # create_computational_basis above) + basis_states = superposition_basis(self.nqubits) + change_to_basis = (2/sqrt(2**self.nqubits))*basis_states + return change_to_basis - qubits + + +def grover_iteration(qstate, oracle): + """Applies one application of the Oracle and W Gate, WV. + + Parameters + ========== + + qstate : Qubit + A superposition of qubits. + oracle : OracleGate + The black box operator that flips the sign of the desired basis qubits. + + Returns + ======= + + Qubit : The qubits after applying the Oracle and W gate. + + Examples + ======== + + Perform one iteration of grover's algorithm to see a phase change:: + + >>> from sympy.physics.quantum.qapply import qapply + >>> from sympy.physics.quantum.qubit import IntQubit + >>> from sympy.physics.quantum.grover import OracleGate + >>> from sympy.physics.quantum.grover import superposition_basis + >>> from sympy.physics.quantum.grover import grover_iteration + >>> numqubits = 2 + >>> basis_states = superposition_basis(numqubits) + >>> f = lambda qubits: qubits == IntQubit(2) + >>> v = OracleGate(numqubits, f) + >>> qapply(grover_iteration(basis_states, v)) + |2> + + """ + wgate = WGate(oracle.nqubits) + return wgate*oracle*qstate + + +def apply_grover(oracle, nqubits, iterations=None): + """Applies grover's algorithm. + + Parameters + ========== + + oracle : callable + The unknown callable function that returns true when applied to the + desired qubits and false otherwise. + + Returns + ======= + + state : Expr + The resulting state after Grover's algorithm has been iterated. + + Examples + ======== + + Apply grover's algorithm to an even superposition of 2 qubits:: + + >>> from sympy.physics.quantum.qapply import qapply + >>> from sympy.physics.quantum.qubit import IntQubit + >>> from sympy.physics.quantum.grover import apply_grover + >>> f = lambda qubits: qubits == IntQubit(2) + >>> qapply(apply_grover(f, 2)) + |2> + + """ + if nqubits <= 0: + raise QuantumError( + 'Grover\'s algorithm needs nqubits > 0, received %r qubits' + % nqubits + ) + if iterations is None: + iterations = floor(sqrt(2**nqubits)*(pi/4)) + + v = OracleGate(nqubits, oracle) + iterated = superposition_basis(nqubits) + for iter in range(iterations): + iterated = grover_iteration(iterated, v) + iterated = qapply(iterated) + + return iterated diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/hilbert.py b/phivenv/Lib/site-packages/sympy/physics/quantum/hilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..f475a9e83a6ccc93e9e2dbb9873ad111c1d05f93 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/hilbert.py @@ -0,0 +1,653 @@ +"""Hilbert spaces for quantum mechanics. + +Authors: +* Brian Granger +* Matt Curry +""" + +from functools import reduce + +from sympy.core.basic import Basic +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.sets.sets import Interval +from sympy.printing.pretty.stringpict import prettyForm +from sympy.physics.quantum.qexpr import QuantumError + + +__all__ = [ + 'HilbertSpaceError', + 'HilbertSpace', + 'TensorProductHilbertSpace', + 'TensorPowerHilbertSpace', + 'DirectSumHilbertSpace', + 'ComplexSpace', + 'L2', + 'FockSpace' +] + +#----------------------------------------------------------------------------- +# Main objects +#----------------------------------------------------------------------------- + + +class HilbertSpaceError(QuantumError): + pass + +#----------------------------------------------------------------------------- +# Main objects +#----------------------------------------------------------------------------- + + +class HilbertSpace(Basic): + """An abstract Hilbert space for quantum mechanics. + + In short, a Hilbert space is an abstract vector space that is complete + with inner products defined [1]_. + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import HilbertSpace + >>> hs = HilbertSpace() + >>> hs + H + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hilbert_space + """ + + def __new__(cls): + obj = Basic.__new__(cls) + return obj + + @property + def dimension(self): + """Return the Hilbert dimension of the space.""" + raise NotImplementedError('This Hilbert space has no dimension.') + + def __add__(self, other): + return DirectSumHilbertSpace(self, other) + + def __radd__(self, other): + return DirectSumHilbertSpace(other, self) + + def __mul__(self, other): + return TensorProductHilbertSpace(self, other) + + def __rmul__(self, other): + return TensorProductHilbertSpace(other, self) + + def __pow__(self, other, mod=None): + if mod is not None: + raise ValueError('The third argument to __pow__ is not supported \ + for Hilbert spaces.') + return TensorPowerHilbertSpace(self, other) + + def __contains__(self, other): + """Is the operator or state in this Hilbert space. + + This is checked by comparing the classes of the Hilbert spaces, not + the instances. This is to allow Hilbert Spaces with symbolic + dimensions. + """ + if other.hilbert_space.__class__ == self.__class__: + return True + else: + return False + + def _sympystr(self, printer, *args): + return 'H' + + def _pretty(self, printer, *args): + ustr = '\N{LATIN CAPITAL LETTER H}' + return prettyForm(ustr) + + def _latex(self, printer, *args): + return r'\mathcal{H}' + + +class ComplexSpace(HilbertSpace): + """Finite dimensional Hilbert space of complex vectors. + + The elements of this Hilbert space are n-dimensional complex valued + vectors with the usual inner product that takes the complex conjugate + of the vector on the right. + + A classic example of this type of Hilbert space is spin-1/2, which is + ``ComplexSpace(2)``. Generalizing to spin-s, the space is + ``ComplexSpace(2*s+1)``. Quantum computing with N qubits is done with the + direct product space ``ComplexSpace(2)**N``. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.quantum.hilbert import ComplexSpace + >>> c1 = ComplexSpace(2) + >>> c1 + C(2) + >>> c1.dimension + 2 + + >>> n = symbols('n') + >>> c2 = ComplexSpace(n) + >>> c2 + C(n) + >>> c2.dimension + n + + """ + + def __new__(cls, dimension): + dimension = sympify(dimension) + r = cls.eval(dimension) + if isinstance(r, Basic): + return r + obj = Basic.__new__(cls, dimension) + return obj + + @classmethod + def eval(cls, dimension): + if len(dimension.atoms()) == 1: + if not (dimension.is_Integer and dimension > 0 or dimension is S.Infinity + or dimension.is_Symbol): + raise TypeError('The dimension of a ComplexSpace can only' + 'be a positive integer, oo, or a Symbol: %r' + % dimension) + else: + for dim in dimension.atoms(): + if not (dim.is_Integer or dim is S.Infinity or dim.is_Symbol): + raise TypeError('The dimension of a ComplexSpace can only' + ' contain integers, oo, or a Symbol: %r' + % dim) + + @property + def dimension(self): + return self.args[0] + + def _sympyrepr(self, printer, *args): + return "%s(%s)" % (self.__class__.__name__, + printer._print(self.dimension, *args)) + + def _sympystr(self, printer, *args): + return "C(%s)" % printer._print(self.dimension, *args) + + def _pretty(self, printer, *args): + ustr = '\N{LATIN CAPITAL LETTER C}' + pform_exp = printer._print(self.dimension, *args) + pform_base = prettyForm(ustr) + return pform_base**pform_exp + + def _latex(self, printer, *args): + return r'\mathcal{C}^{%s}' % printer._print(self.dimension, *args) + + +class L2(HilbertSpace): + """The Hilbert space of square integrable functions on an interval. + + An L2 object takes in a single SymPy Interval argument which represents + the interval its functions (vectors) are defined on. + + Examples + ======== + + >>> from sympy import Interval, oo + >>> from sympy.physics.quantum.hilbert import L2 + >>> hs = L2(Interval(0,oo)) + >>> hs + L2(Interval(0, oo)) + >>> hs.dimension + oo + >>> hs.interval + Interval(0, oo) + + """ + + def __new__(cls, interval): + if not isinstance(interval, Interval): + raise TypeError('L2 interval must be an Interval instance: %r' + % interval) + obj = Basic.__new__(cls, interval) + return obj + + @property + def dimension(self): + return S.Infinity + + @property + def interval(self): + return self.args[0] + + def _sympyrepr(self, printer, *args): + return "L2(%s)" % printer._print(self.interval, *args) + + def _sympystr(self, printer, *args): + return "L2(%s)" % printer._print(self.interval, *args) + + def _pretty(self, printer, *args): + pform_exp = prettyForm('2') + pform_base = prettyForm('L') + return pform_base**pform_exp + + def _latex(self, printer, *args): + interval = printer._print(self.interval, *args) + return r'{\mathcal{L}^2}\left( %s \right)' % interval + + +class FockSpace(HilbertSpace): + """The Hilbert space for second quantization. + + Technically, this Hilbert space is a infinite direct sum of direct + products of single particle Hilbert spaces [1]_. This is a mess, so we have + a class to represent it directly. + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import FockSpace + >>> hs = FockSpace() + >>> hs + F + >>> hs.dimension + oo + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fock_space + """ + + def __new__(cls): + obj = Basic.__new__(cls) + return obj + + @property + def dimension(self): + return S.Infinity + + def _sympyrepr(self, printer, *args): + return "FockSpace()" + + def _sympystr(self, printer, *args): + return "F" + + def _pretty(self, printer, *args): + ustr = '\N{LATIN CAPITAL LETTER F}' + return prettyForm(ustr) + + def _latex(self, printer, *args): + return r'\mathcal{F}' + + +class TensorProductHilbertSpace(HilbertSpace): + """A tensor product of Hilbert spaces [1]_. + + The tensor product between Hilbert spaces is represented by the + operator ``*`` Products of the same Hilbert space will be combined into + tensor powers. + + A ``TensorProductHilbertSpace`` object takes in an arbitrary number of + ``HilbertSpace`` objects as its arguments. In addition, multiplication of + ``HilbertSpace`` objects will automatically return this tensor product + object. + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import ComplexSpace, FockSpace + >>> from sympy import symbols + + >>> c = ComplexSpace(2) + >>> f = FockSpace() + >>> hs = c*f + >>> hs + C(2)*F + >>> hs.dimension + oo + >>> hs.spaces + (C(2), F) + + >>> c1 = ComplexSpace(2) + >>> n = symbols('n') + >>> c2 = ComplexSpace(n) + >>> hs = c1*c2 + >>> hs + C(2)*C(n) + >>> hs.dimension + 2*n + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hilbert_space#Tensor_products + """ + + def __new__(cls, *args): + r = cls.eval(args) + if isinstance(r, Basic): + return r + obj = Basic.__new__(cls, *args) + return obj + + @classmethod + def eval(cls, args): + """Evaluates the direct product.""" + new_args = [] + recall = False + #flatten arguments + for arg in args: + if isinstance(arg, TensorProductHilbertSpace): + new_args.extend(arg.args) + recall = True + elif isinstance(arg, (HilbertSpace, TensorPowerHilbertSpace)): + new_args.append(arg) + else: + raise TypeError('Hilbert spaces can only be multiplied by \ + other Hilbert spaces: %r' % arg) + #combine like arguments into direct powers + comb_args = [] + prev_arg = None + for new_arg in new_args: + if prev_arg is not None: + if isinstance(new_arg, TensorPowerHilbertSpace) and \ + isinstance(prev_arg, TensorPowerHilbertSpace) and \ + new_arg.base == prev_arg.base: + prev_arg = new_arg.base**(new_arg.exp + prev_arg.exp) + elif isinstance(new_arg, TensorPowerHilbertSpace) and \ + new_arg.base == prev_arg: + prev_arg = prev_arg**(new_arg.exp + 1) + elif isinstance(prev_arg, TensorPowerHilbertSpace) and \ + new_arg == prev_arg.base: + prev_arg = new_arg**(prev_arg.exp + 1) + elif new_arg == prev_arg: + prev_arg = new_arg**2 + else: + comb_args.append(prev_arg) + prev_arg = new_arg + elif prev_arg is None: + prev_arg = new_arg + comb_args.append(prev_arg) + if recall: + return TensorProductHilbertSpace(*comb_args) + elif len(comb_args) == 1: + return TensorPowerHilbertSpace(comb_args[0].base, comb_args[0].exp) + else: + return None + + @property + def dimension(self): + arg_list = [arg.dimension for arg in self.args] + if S.Infinity in arg_list: + return S.Infinity + else: + return reduce(lambda x, y: x*y, arg_list) + + @property + def spaces(self): + """A tuple of the Hilbert spaces in this tensor product.""" + return self.args + + def _spaces_printer(self, printer, *args): + spaces_strs = [] + for arg in self.args: + s = printer._print(arg, *args) + if isinstance(arg, DirectSumHilbertSpace): + s = '(%s)' % s + spaces_strs.append(s) + return spaces_strs + + def _sympyrepr(self, printer, *args): + spaces_reprs = self._spaces_printer(printer, *args) + return "TensorProductHilbertSpace(%s)" % ','.join(spaces_reprs) + + def _sympystr(self, printer, *args): + spaces_strs = self._spaces_printer(printer, *args) + return '*'.join(spaces_strs) + + def _pretty(self, printer, *args): + length = len(self.args) + pform = printer._print('', *args) + for i in range(length): + next_pform = printer._print(self.args[i], *args) + if isinstance(self.args[i], (DirectSumHilbertSpace, + TensorProductHilbertSpace)): + next_pform = prettyForm( + *next_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(next_pform)) + if i != length - 1: + if printer._use_unicode: + pform = prettyForm(*pform.right(' ' + '\N{N-ARY CIRCLED TIMES OPERATOR}' + ' ')) + else: + pform = prettyForm(*pform.right(' x ')) + return pform + + def _latex(self, printer, *args): + length = len(self.args) + s = '' + for i in range(length): + arg_s = printer._print(self.args[i], *args) + if isinstance(self.args[i], (DirectSumHilbertSpace, + TensorProductHilbertSpace)): + arg_s = r'\left(%s\right)' % arg_s + s = s + arg_s + if i != length - 1: + s = s + r'\otimes ' + return s + + +class DirectSumHilbertSpace(HilbertSpace): + """A direct sum of Hilbert spaces [1]_. + + This class uses the ``+`` operator to represent direct sums between + different Hilbert spaces. + + A ``DirectSumHilbertSpace`` object takes in an arbitrary number of + ``HilbertSpace`` objects as its arguments. Also, addition of + ``HilbertSpace`` objects will automatically return a direct sum object. + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import ComplexSpace, FockSpace + + >>> c = ComplexSpace(2) + >>> f = FockSpace() + >>> hs = c+f + >>> hs + C(2)+F + >>> hs.dimension + oo + >>> list(hs.spaces) + [C(2), F] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hilbert_space#Direct_sums + """ + def __new__(cls, *args): + r = cls.eval(args) + if isinstance(r, Basic): + return r + obj = Basic.__new__(cls, *args) + return obj + + @classmethod + def eval(cls, args): + """Evaluates the direct product.""" + new_args = [] + recall = False + #flatten arguments + for arg in args: + if isinstance(arg, DirectSumHilbertSpace): + new_args.extend(arg.args) + recall = True + elif isinstance(arg, HilbertSpace): + new_args.append(arg) + else: + raise TypeError('Hilbert spaces can only be summed with other \ + Hilbert spaces: %r' % arg) + if recall: + return DirectSumHilbertSpace(*new_args) + else: + return None + + @property + def dimension(self): + arg_list = [arg.dimension for arg in self.args] + if S.Infinity in arg_list: + return S.Infinity + else: + return reduce(lambda x, y: x + y, arg_list) + + @property + def spaces(self): + """A tuple of the Hilbert spaces in this direct sum.""" + return self.args + + def _sympyrepr(self, printer, *args): + spaces_reprs = [printer._print(arg, *args) for arg in self.args] + return "DirectSumHilbertSpace(%s)" % ','.join(spaces_reprs) + + def _sympystr(self, printer, *args): + spaces_strs = [printer._print(arg, *args) for arg in self.args] + return '+'.join(spaces_strs) + + def _pretty(self, printer, *args): + length = len(self.args) + pform = printer._print('', *args) + for i in range(length): + next_pform = printer._print(self.args[i], *args) + if isinstance(self.args[i], (DirectSumHilbertSpace, + TensorProductHilbertSpace)): + next_pform = prettyForm( + *next_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(next_pform)) + if i != length - 1: + if printer._use_unicode: + pform = prettyForm(*pform.right(' \N{CIRCLED PLUS} ')) + else: + pform = prettyForm(*pform.right(' + ')) + return pform + + def _latex(self, printer, *args): + length = len(self.args) + s = '' + for i in range(length): + arg_s = printer._print(self.args[i], *args) + if isinstance(self.args[i], (DirectSumHilbertSpace, + TensorProductHilbertSpace)): + arg_s = r'\left(%s\right)' % arg_s + s = s + arg_s + if i != length - 1: + s = s + r'\oplus ' + return s + + +class TensorPowerHilbertSpace(HilbertSpace): + """An exponentiated Hilbert space [1]_. + + Tensor powers (repeated tensor products) are represented by the + operator ``**`` Identical Hilbert spaces that are multiplied together + will be automatically combined into a single tensor power object. + + Any Hilbert space, product, or sum may be raised to a tensor power. The + ``TensorPowerHilbertSpace`` takes two arguments: the Hilbert space; and the + tensor power (number). + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import ComplexSpace, FockSpace + >>> from sympy import symbols + + >>> n = symbols('n') + >>> c = ComplexSpace(2) + >>> hs = c**n + >>> hs + C(2)**n + >>> hs.dimension + 2**n + + >>> c = ComplexSpace(2) + >>> c*c + C(2)**2 + >>> f = FockSpace() + >>> c*f*f + C(2)*F**2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hilbert_space#Tensor_products + """ + + def __new__(cls, *args): + r = cls.eval(args) + if isinstance(r, Basic): + return r + return Basic.__new__(cls, *r) + + @classmethod + def eval(cls, args): + new_args = args[0], sympify(args[1]) + exp = new_args[1] + #simplify hs**1 -> hs + if exp is S.One: + return args[0] + #simplify hs**0 -> 1 + if exp is S.Zero: + return S.One + #check (and allow) for hs**(x+42+y...) case + if len(exp.atoms()) == 1: + if not (exp.is_Integer and exp >= 0 or exp.is_Symbol): + raise ValueError('Hilbert spaces can only be raised to \ + positive integers or Symbols: %r' % exp) + else: + for power in exp.atoms(): + if not (power.is_Integer or power.is_Symbol): + raise ValueError('Tensor powers can only contain integers \ + or Symbols: %r' % power) + return new_args + + @property + def base(self): + return self.args[0] + + @property + def exp(self): + return self.args[1] + + @property + def dimension(self): + if self.base.dimension is S.Infinity: + return S.Infinity + else: + return self.base.dimension**self.exp + + def _sympyrepr(self, printer, *args): + return "TensorPowerHilbertSpace(%s,%s)" % (printer._print(self.base, + *args), printer._print(self.exp, *args)) + + def _sympystr(self, printer, *args): + return "%s**%s" % (printer._print(self.base, *args), + printer._print(self.exp, *args)) + + def _pretty(self, printer, *args): + pform_exp = printer._print(self.exp, *args) + if printer._use_unicode: + pform_exp = prettyForm(*pform_exp.left(prettyForm('\N{N-ARY CIRCLED TIMES OPERATOR}'))) + else: + pform_exp = prettyForm(*pform_exp.left(prettyForm('x'))) + pform_base = printer._print(self.base, *args) + return pform_base**pform_exp + + def _latex(self, printer, *args): + base = printer._print(self.base, *args) + exp = printer._print(self.exp, *args) + return r'{%s}^{\otimes %s}' % (base, exp) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/identitysearch.py b/phivenv/Lib/site-packages/sympy/physics/quantum/identitysearch.py new file mode 100644 index 0000000000000000000000000000000000000000..9a178e9b808450b7ce91175600d6b393fc9797d6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/identitysearch.py @@ -0,0 +1,853 @@ +from collections import deque +from sympy.core.random import randint + +from sympy.external import import_module +from sympy.core.basic import Basic +from sympy.core.mul import Mul +from sympy.core.numbers import Number, equal_valued +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.dagger import Dagger + +__all__ = [ + # Public interfaces + 'generate_gate_rules', + 'generate_equivalent_ids', + 'GateIdentity', + 'bfs_identity_search', + 'random_identity_search', + + # "Private" functions + 'is_scalar_sparse_matrix', + 'is_scalar_nonsparse_matrix', + 'is_degenerate', + 'is_reducible', +] + +np = import_module('numpy') +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + + +def is_scalar_sparse_matrix(circuit, nqubits, identity_only, eps=1e-11): + """Checks if a given scipy.sparse matrix is a scalar matrix. + + A scalar matrix is such that B = bI, where B is the scalar + matrix, b is some scalar multiple, and I is the identity + matrix. A scalar matrix would have only the element b along + it's main diagonal and zeroes elsewhere. + + Parameters + ========== + + circuit : Gate tuple + Sequence of quantum gates representing a quantum circuit + nqubits : int + Number of qubits in the circuit + identity_only : bool + Check for only identity matrices + eps : number + The tolerance value for zeroing out elements in the matrix. + Values in the range [-eps, +eps] will be changed to a zero. + """ + + if not np or not scipy: + pass + + matrix = represent(Mul(*circuit), nqubits=nqubits, + format='scipy.sparse') + + # In some cases, represent returns a 1D scalar value in place + # of a multi-dimensional scalar matrix + if (isinstance(matrix, int)): + return matrix == 1 if identity_only else True + + # If represent returns a matrix, check if the matrix is diagonal + # and if every item along the diagonal is the same + else: + # Due to floating pointing operations, must zero out + # elements that are "very" small in the dense matrix + # See parameter for default value. + + # Get the ndarray version of the dense matrix + dense_matrix = matrix.todense().getA() + # Since complex values can't be compared, must split + # the matrix into real and imaginary components + # Find the real values in between -eps and eps + bool_real = np.logical_and(dense_matrix.real > -eps, + dense_matrix.real < eps) + # Find the imaginary values between -eps and eps + bool_imag = np.logical_and(dense_matrix.imag > -eps, + dense_matrix.imag < eps) + # Replaces values between -eps and eps with 0 + corrected_real = np.where(bool_real, 0.0, dense_matrix.real) + corrected_imag = np.where(bool_imag, 0.0, dense_matrix.imag) + # Convert the matrix with real values into imaginary values + corrected_imag = corrected_imag * complex(1j) + # Recombine the real and imaginary components + corrected_dense = corrected_real + corrected_imag + + # Check if it's diagonal + row_indices = corrected_dense.nonzero()[0] + col_indices = corrected_dense.nonzero()[1] + # Check if the rows indices and columns indices are the same + # If they match, then matrix only contains elements along diagonal + bool_indices = row_indices == col_indices + is_diagonal = bool_indices.all() + + first_element = corrected_dense[0][0] + # If the first element is a zero, then can't rescale matrix + # and definitely not diagonal + if (first_element == 0.0 + 0.0j): + return False + + # The dimensions of the dense matrix should still + # be 2^nqubits if there are elements all along the + # the main diagonal + trace_of_corrected = (corrected_dense/first_element).trace() + expected_trace = pow(2, nqubits) + has_correct_trace = trace_of_corrected == expected_trace + + # If only looking for identity matrices + # first element must be a 1 + real_is_one = abs(first_element.real - 1.0) < eps + imag_is_zero = abs(first_element.imag) < eps + is_one = real_is_one and imag_is_zero + is_identity = is_one if identity_only else True + return bool(is_diagonal and has_correct_trace and is_identity) + + +def is_scalar_nonsparse_matrix(circuit, nqubits, identity_only, eps=None): + """Checks if a given circuit, in matrix form, is equivalent to + a scalar value. + + Parameters + ========== + + circuit : Gate tuple + Sequence of quantum gates representing a quantum circuit + nqubits : int + Number of qubits in the circuit + identity_only : bool + Check for only identity matrices + eps : number + This argument is ignored. It is just for signature compatibility with + is_scalar_sparse_matrix. + + Note: Used in situations when is_scalar_sparse_matrix has bugs + """ + + matrix = represent(Mul(*circuit), nqubits=nqubits) + + # In some cases, represent returns a 1D scalar value in place + # of a multi-dimensional scalar matrix + if (isinstance(matrix, Number)): + return matrix == 1 if identity_only else True + + # If represent returns a matrix, check if the matrix is diagonal + # and if every item along the diagonal is the same + else: + # Added up the diagonal elements + matrix_trace = matrix.trace() + # Divide the trace by the first element in the matrix + # if matrix is not required to be the identity matrix + adjusted_matrix_trace = (matrix_trace/matrix[0] + if not identity_only + else matrix_trace) + + is_identity = equal_valued(matrix[0], 1) if identity_only else True + + has_correct_trace = adjusted_matrix_trace == pow(2, nqubits) + + # The matrix is scalar if it's diagonal and the adjusted trace + # value is equal to 2^nqubits + return bool( + matrix.is_diagonal() and has_correct_trace and is_identity) + +if np and scipy: + is_scalar_matrix = is_scalar_sparse_matrix +else: + is_scalar_matrix = is_scalar_nonsparse_matrix + + +def _get_min_qubits(a_gate): + if isinstance(a_gate, Pow): + return a_gate.base.min_qubits + else: + return a_gate.min_qubits + + +def ll_op(left, right): + """Perform a LL operation. + + A LL operation multiplies both left and right circuits + with the dagger of the left circuit's leftmost gate, and + the dagger is multiplied on the left side of both circuits. + + If a LL is possible, it returns the new gate rule as a + 2-tuple (LHS, RHS), where LHS is the left circuit and + and RHS is the right circuit of the new rule. + If a LL is not possible, None is returned. + + Parameters + ========== + + left : Gate tuple + The left circuit of a gate rule expression. + right : Gate tuple + The right circuit of a gate rule expression. + + Examples + ======== + + Generate a new gate rule using a LL operation: + + >>> from sympy.physics.quantum.identitysearch import ll_op + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> ll_op((x, y, z), ()) + ((Y(0), Z(0)), (X(0),)) + + >>> ll_op((y, z), (x,)) + ((Z(0),), (Y(0), X(0))) + """ + + if (len(left) > 0): + ll_gate = left[0] + ll_gate_is_unitary = is_scalar_matrix( + (Dagger(ll_gate), ll_gate), _get_min_qubits(ll_gate), True) + + if (len(left) > 0 and ll_gate_is_unitary): + # Get the new left side w/o the leftmost gate + new_left = left[1:len(left)] + # Add the leftmost gate to the left position on the right side + new_right = (Dagger(ll_gate),) + right + # Return the new gate rule + return (new_left, new_right) + + return None + + +def lr_op(left, right): + """Perform a LR operation. + + A LR operation multiplies both left and right circuits + with the dagger of the left circuit's rightmost gate, and + the dagger is multiplied on the right side of both circuits. + + If a LR is possible, it returns the new gate rule as a + 2-tuple (LHS, RHS), where LHS is the left circuit and + and RHS is the right circuit of the new rule. + If a LR is not possible, None is returned. + + Parameters + ========== + + left : Gate tuple + The left circuit of a gate rule expression. + right : Gate tuple + The right circuit of a gate rule expression. + + Examples + ======== + + Generate a new gate rule using a LR operation: + + >>> from sympy.physics.quantum.identitysearch import lr_op + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> lr_op((x, y, z), ()) + ((X(0), Y(0)), (Z(0),)) + + >>> lr_op((x, y), (z,)) + ((X(0),), (Z(0), Y(0))) + """ + + if (len(left) > 0): + lr_gate = left[len(left) - 1] + lr_gate_is_unitary = is_scalar_matrix( + (Dagger(lr_gate), lr_gate), _get_min_qubits(lr_gate), True) + + if (len(left) > 0 and lr_gate_is_unitary): + # Get the new left side w/o the rightmost gate + new_left = left[0:len(left) - 1] + # Add the rightmost gate to the right position on the right side + new_right = right + (Dagger(lr_gate),) + # Return the new gate rule + return (new_left, new_right) + + return None + + +def rl_op(left, right): + """Perform a RL operation. + + A RL operation multiplies both left and right circuits + with the dagger of the right circuit's leftmost gate, and + the dagger is multiplied on the left side of both circuits. + + If a RL is possible, it returns the new gate rule as a + 2-tuple (LHS, RHS), where LHS is the left circuit and + and RHS is the right circuit of the new rule. + If a RL is not possible, None is returned. + + Parameters + ========== + + left : Gate tuple + The left circuit of a gate rule expression. + right : Gate tuple + The right circuit of a gate rule expression. + + Examples + ======== + + Generate a new gate rule using a RL operation: + + >>> from sympy.physics.quantum.identitysearch import rl_op + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> rl_op((x,), (y, z)) + ((Y(0), X(0)), (Z(0),)) + + >>> rl_op((x, y), (z,)) + ((Z(0), X(0), Y(0)), ()) + """ + + if (len(right) > 0): + rl_gate = right[0] + rl_gate_is_unitary = is_scalar_matrix( + (Dagger(rl_gate), rl_gate), _get_min_qubits(rl_gate), True) + + if (len(right) > 0 and rl_gate_is_unitary): + # Get the new right side w/o the leftmost gate + new_right = right[1:len(right)] + # Add the leftmost gate to the left position on the left side + new_left = (Dagger(rl_gate),) + left + # Return the new gate rule + return (new_left, new_right) + + return None + + +def rr_op(left, right): + """Perform a RR operation. + + A RR operation multiplies both left and right circuits + with the dagger of the right circuit's rightmost gate, and + the dagger is multiplied on the right side of both circuits. + + If a RR is possible, it returns the new gate rule as a + 2-tuple (LHS, RHS), where LHS is the left circuit and + and RHS is the right circuit of the new rule. + If a RR is not possible, None is returned. + + Parameters + ========== + + left : Gate tuple + The left circuit of a gate rule expression. + right : Gate tuple + The right circuit of a gate rule expression. + + Examples + ======== + + Generate a new gate rule using a RR operation: + + >>> from sympy.physics.quantum.identitysearch import rr_op + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> rr_op((x, y), (z,)) + ((X(0), Y(0), Z(0)), ()) + + >>> rr_op((x,), (y, z)) + ((X(0), Z(0)), (Y(0),)) + """ + + if (len(right) > 0): + rr_gate = right[len(right) - 1] + rr_gate_is_unitary = is_scalar_matrix( + (Dagger(rr_gate), rr_gate), _get_min_qubits(rr_gate), True) + + if (len(right) > 0 and rr_gate_is_unitary): + # Get the new right side w/o the rightmost gate + new_right = right[0:len(right) - 1] + # Add the rightmost gate to the right position on the right side + new_left = left + (Dagger(rr_gate),) + # Return the new gate rule + return (new_left, new_right) + + return None + + +def generate_gate_rules(gate_seq, return_as_muls=False): + """Returns a set of gate rules. Each gate rules is represented + as a 2-tuple of tuples or Muls. An empty tuple represents an arbitrary + scalar value. + + This function uses the four operations (LL, LR, RL, RR) + to generate the gate rules. + + A gate rule is an expression such as ABC = D or AB = CD, where + A, B, C, and D are gates. Each value on either side of the + equal sign represents a circuit. The four operations allow + one to find a set of equivalent circuits from a gate identity. + The letters denoting the operation tell the user what + activities to perform on each expression. The first letter + indicates which side of the equal sign to focus on. The + second letter indicates which gate to focus on given the + side. Once this information is determined, the inverse + of the gate is multiplied on both circuits to create a new + gate rule. + + For example, given the identity, ABCD = 1, a LL operation + means look at the left value and multiply both left sides by the + inverse of the leftmost gate A. If A is Hermitian, the inverse + of A is still A. The resulting new rule is BCD = A. + + The following is a summary of the four operations. Assume + that in the examples, all gates are Hermitian. + + LL : left circuit, left multiply + ABCD = E -> AABCD = AE -> BCD = AE + LR : left circuit, right multiply + ABCD = E -> ABCDD = ED -> ABC = ED + RL : right circuit, left multiply + ABC = ED -> EABC = EED -> EABC = D + RR : right circuit, right multiply + AB = CD -> ABD = CDD -> ABD = C + + The number of gate rules generated is n*(n+1), where n + is the number of gates in the sequence (unproven). + + Parameters + ========== + + gate_seq : Gate tuple, Mul, or Number + A variable length tuple or Mul of Gates whose product is equal to + a scalar matrix + return_as_muls : bool + True to return a set of Muls; False to return a set of tuples + + Examples + ======== + + Find the gate rules of the current circuit using tuples: + + >>> from sympy.physics.quantum.identitysearch import generate_gate_rules + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> generate_gate_rules((x, x)) + {((X(0),), (X(0),)), ((X(0), X(0)), ())} + + >>> generate_gate_rules((x, y, z)) + {((), (X(0), Z(0), Y(0))), ((), (Y(0), X(0), Z(0))), + ((), (Z(0), Y(0), X(0))), ((X(0),), (Z(0), Y(0))), + ((Y(0),), (X(0), Z(0))), ((Z(0),), (Y(0), X(0))), + ((X(0), Y(0)), (Z(0),)), ((Y(0), Z(0)), (X(0),)), + ((Z(0), X(0)), (Y(0),)), ((X(0), Y(0), Z(0)), ()), + ((Y(0), Z(0), X(0)), ()), ((Z(0), X(0), Y(0)), ())} + + Find the gate rules of the current circuit using Muls: + + >>> generate_gate_rules(x*x, return_as_muls=True) + {(1, 1)} + + >>> generate_gate_rules(x*y*z, return_as_muls=True) + {(1, X(0)*Z(0)*Y(0)), (1, Y(0)*X(0)*Z(0)), + (1, Z(0)*Y(0)*X(0)), (X(0)*Y(0), Z(0)), + (Y(0)*Z(0), X(0)), (Z(0)*X(0), Y(0)), + (X(0)*Y(0)*Z(0), 1), (Y(0)*Z(0)*X(0), 1), + (Z(0)*X(0)*Y(0), 1), (X(0), Z(0)*Y(0)), + (Y(0), X(0)*Z(0)), (Z(0), Y(0)*X(0))} + """ + + if isinstance(gate_seq, Number): + if return_as_muls: + return {(S.One, S.One)} + else: + return {((), ())} + + elif isinstance(gate_seq, Mul): + gate_seq = gate_seq.args + + # Each item in queue is a 3-tuple: + # i) first item is the left side of an equality + # ii) second item is the right side of an equality + # iii) third item is the number of operations performed + # The argument, gate_seq, will start on the left side, and + # the right side will be empty, implying the presence of an + # identity. + queue = deque() + # A set of gate rules + rules = set() + # Maximum number of operations to perform + max_ops = len(gate_seq) + + def process_new_rule(new_rule, ops): + if new_rule is not None: + new_left, new_right = new_rule + + if new_rule not in rules and (new_right, new_left) not in rules: + rules.add(new_rule) + # If haven't reached the max limit on operations + if ops + 1 < max_ops: + queue.append(new_rule + (ops + 1,)) + + queue.append((gate_seq, (), 0)) + rules.add((gate_seq, ())) + + while len(queue) > 0: + left, right, ops = queue.popleft() + + # Do a LL + new_rule = ll_op(left, right) + process_new_rule(new_rule, ops) + # Do a LR + new_rule = lr_op(left, right) + process_new_rule(new_rule, ops) + # Do a RL + new_rule = rl_op(left, right) + process_new_rule(new_rule, ops) + # Do a RR + new_rule = rr_op(left, right) + process_new_rule(new_rule, ops) + + if return_as_muls: + # Convert each rule as tuples into a rule as muls + mul_rules = set() + for rule in rules: + left, right = rule + mul_rules.add((Mul(*left), Mul(*right))) + + rules = mul_rules + + return rules + + +def generate_equivalent_ids(gate_seq, return_as_muls=False): + """Returns a set of equivalent gate identities. + + A gate identity is a quantum circuit such that the product + of the gates in the circuit is equal to a scalar value. + For example, XYZ = i, where X, Y, Z are the Pauli gates and + i is the imaginary value, is considered a gate identity. + + This function uses the four operations (LL, LR, RL, RR) + to generate the gate rules and, subsequently, to locate equivalent + gate identities. + + Note that all equivalent identities are reachable in n operations + from the starting gate identity, where n is the number of gates + in the sequence. + + The max number of gate identities is 2n, where n is the number + of gates in the sequence (unproven). + + Parameters + ========== + + gate_seq : Gate tuple, Mul, or Number + A variable length tuple or Mul of Gates whose product is equal to + a scalar matrix. + return_as_muls: bool + True to return as Muls; False to return as tuples + + Examples + ======== + + Find equivalent gate identities from the current circuit with tuples: + + >>> from sympy.physics.quantum.identitysearch import generate_equivalent_ids + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> generate_equivalent_ids((x, x)) + {(X(0), X(0))} + + >>> generate_equivalent_ids((x, y, z)) + {(X(0), Y(0), Z(0)), (X(0), Z(0), Y(0)), (Y(0), X(0), Z(0)), + (Y(0), Z(0), X(0)), (Z(0), X(0), Y(0)), (Z(0), Y(0), X(0))} + + Find equivalent gate identities from the current circuit with Muls: + + >>> generate_equivalent_ids(x*x, return_as_muls=True) + {1} + + >>> generate_equivalent_ids(x*y*z, return_as_muls=True) + {X(0)*Y(0)*Z(0), X(0)*Z(0)*Y(0), Y(0)*X(0)*Z(0), + Y(0)*Z(0)*X(0), Z(0)*X(0)*Y(0), Z(0)*Y(0)*X(0)} + """ + + if isinstance(gate_seq, Number): + return {S.One} + elif isinstance(gate_seq, Mul): + gate_seq = gate_seq.args + + # Filter through the gate rules and keep the rules + # with an empty tuple either on the left or right side + + # A set of equivalent gate identities + eq_ids = set() + + gate_rules = generate_gate_rules(gate_seq) + for rule in gate_rules: + l, r = rule + if l == (): + eq_ids.add(r) + elif r == (): + eq_ids.add(l) + + if return_as_muls: + convert_to_mul = lambda id_seq: Mul(*id_seq) + eq_ids = set(map(convert_to_mul, eq_ids)) + + return eq_ids + + +class GateIdentity(Basic): + """Wrapper class for circuits that reduce to a scalar value. + + A gate identity is a quantum circuit such that the product + of the gates in the circuit is equal to a scalar value. + For example, XYZ = i, where X, Y, Z are the Pauli gates and + i is the imaginary value, is considered a gate identity. + + Parameters + ========== + + args : Gate tuple + A variable length tuple of Gates that form an identity. + + Examples + ======== + + Create a GateIdentity and look at its attributes: + + >>> from sympy.physics.quantum.identitysearch import GateIdentity + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> an_identity = GateIdentity(x, y, z) + >>> an_identity.circuit + X(0)*Y(0)*Z(0) + + >>> an_identity.equivalent_ids + {(X(0), Y(0), Z(0)), (X(0), Z(0), Y(0)), (Y(0), X(0), Z(0)), + (Y(0), Z(0), X(0)), (Z(0), X(0), Y(0)), (Z(0), Y(0), X(0))} + """ + + def __new__(cls, *args): + # args should be a tuple - a variable length argument list + obj = Basic.__new__(cls, *args) + obj._circuit = Mul(*args) + obj._rules = generate_gate_rules(args) + obj._eq_ids = generate_equivalent_ids(args) + + return obj + + @property + def circuit(self): + return self._circuit + + @property + def gate_rules(self): + return self._rules + + @property + def equivalent_ids(self): + return self._eq_ids + + @property + def sequence(self): + return self.args + + def __str__(self): + """Returns the string of gates in a tuple.""" + return str(self.circuit) + + +def is_degenerate(identity_set, gate_identity): + """Checks if a gate identity is a permutation of another identity. + + Parameters + ========== + + identity_set : set + A Python set with GateIdentity objects. + gate_identity : GateIdentity + The GateIdentity to check for existence in the set. + + Examples + ======== + + Check if the identity is a permutation of another identity: + + >>> from sympy.physics.quantum.identitysearch import ( + ... GateIdentity, is_degenerate) + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> an_identity = GateIdentity(x, y, z) + >>> id_set = {an_identity} + >>> another_id = (y, z, x) + >>> is_degenerate(id_set, another_id) + True + + >>> another_id = (x, x) + >>> is_degenerate(id_set, another_id) + False + """ + + # For now, just iteratively go through the set and check if the current + # gate_identity is a permutation of an identity in the set + for an_id in identity_set: + if (gate_identity in an_id.equivalent_ids): + return True + return False + + +def is_reducible(circuit, nqubits, begin, end): + """Determines if a circuit is reducible by checking + if its subcircuits are scalar values. + + Parameters + ========== + + circuit : Gate tuple + A tuple of Gates representing a circuit. The circuit to check + if a gate identity is contained in a subcircuit. + nqubits : int + The number of qubits the circuit operates on. + begin : int + The leftmost gate in the circuit to include in a subcircuit. + end : int + The rightmost gate in the circuit to include in a subcircuit. + + Examples + ======== + + Check if the circuit can be reduced: + + >>> from sympy.physics.quantum.identitysearch import is_reducible + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> is_reducible((x, y, z), 1, 0, 3) + True + + Check if an interval in the circuit can be reduced: + + >>> is_reducible((x, y, z), 1, 1, 3) + False + + >>> is_reducible((x, y, y), 1, 1, 3) + True + """ + + current_circuit = () + # Start from the gate at "end" and go down to almost the gate at "begin" + for ndx in reversed(range(begin, end)): + next_gate = circuit[ndx] + current_circuit = (next_gate,) + current_circuit + + # If a circuit as a matrix is equivalent to a scalar value + if (is_scalar_matrix(current_circuit, nqubits, False)): + return True + + return False + + +def bfs_identity_search(gate_list, nqubits, max_depth=None, + identity_only=False): + """Constructs a set of gate identities from the list of possible gates. + + Performs a breadth first search over the space of gate identities. + This allows the finding of the shortest gate identities first. + + Parameters + ========== + + gate_list : list, Gate + A list of Gates from which to search for gate identities. + nqubits : int + The number of qubits the quantum circuit operates on. + max_depth : int + The longest quantum circuit to construct from gate_list. + identity_only : bool + True to search for gate identities that reduce to identity; + False to search for gate identities that reduce to a scalar. + + Examples + ======== + + Find a list of gate identities: + + >>> from sympy.physics.quantum.identitysearch import bfs_identity_search + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> bfs_identity_search([x], 1, max_depth=2) + {GateIdentity(X(0), X(0))} + + >>> bfs_identity_search([x, y, z], 1) + {GateIdentity(X(0), X(0)), GateIdentity(Y(0), Y(0)), + GateIdentity(Z(0), Z(0)), GateIdentity(X(0), Y(0), Z(0))} + + Find a list of identities that only equal to 1: + + >>> bfs_identity_search([x, y, z], 1, identity_only=True) + {GateIdentity(X(0), X(0)), GateIdentity(Y(0), Y(0)), + GateIdentity(Z(0), Z(0))} + """ + + if max_depth is None or max_depth <= 0: + max_depth = len(gate_list) + + id_only = identity_only + + # Start with an empty sequence (implicitly contains an IdentityGate) + queue = deque([()]) + + # Create an empty set of gate identities + ids = set() + + # Begin searching for gate identities in given space. + while (len(queue) > 0): + current_circuit = queue.popleft() + + for next_gate in gate_list: + new_circuit = current_circuit + (next_gate,) + + # Determines if a (strict) subcircuit is a scalar matrix + circuit_reducible = is_reducible(new_circuit, nqubits, + 1, len(new_circuit)) + + # In many cases when the matrix is a scalar value, + # the evaluated matrix will actually be an integer + if (is_scalar_matrix(new_circuit, nqubits, id_only) and + not is_degenerate(ids, new_circuit) and + not circuit_reducible): + ids.add(GateIdentity(*new_circuit)) + + elif (len(new_circuit) < max_depth and + not circuit_reducible): + queue.append(new_circuit) + + return ids + + +def random_identity_search(gate_list, numgates, nqubits): + """Randomly selects numgates from gate_list and checks if it is + a gate identity. + + If the circuit is a gate identity, the circuit is returned; + Otherwise, None is returned. + """ + + gate_size = len(gate_list) + circuit = () + + for i in range(numgates): + next_gate = gate_list[randint(0, gate_size - 1)] + circuit = circuit + (next_gate,) + + is_scalar = is_scalar_matrix(circuit, nqubits, False) + + return circuit if is_scalar else None diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/innerproduct.py b/phivenv/Lib/site-packages/sympy/physics/quantum/innerproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..11fed882b6068a4df5a787ff90eee5392f97447a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/innerproduct.py @@ -0,0 +1,138 @@ +"""Symbolic inner product.""" + +from sympy.core.expr import Expr +from sympy.core.kind import NumberKind +from sympy.functions.elementary.complexes import conjugate +from sympy.printing.pretty.stringpict import prettyForm +from sympy.physics.quantum.dagger import Dagger + + +__all__ = [ + 'InnerProduct' +] + + +# InnerProduct is not an QExpr because it is really just a regular commutative +# number. We have gone back and forth about this, but we gain a lot by having +# it subclass Expr. The main challenges were getting Dagger to work +# (we use _eval_conjugate) and represent (we can use atoms and subs). Having +# it be an Expr, mean that there are no commutative QExpr subclasses, +# which simplifies the design of everything. + +class InnerProduct(Expr): + """An unevaluated inner product between a Bra and a Ket [1]. + + Parameters + ========== + + bra : BraBase or subclass + The bra on the left side of the inner product. + ket : KetBase or subclass + The ket on the right side of the inner product. + + Examples + ======== + + Create an InnerProduct and check its properties: + + >>> from sympy.physics.quantum import Bra, Ket + >>> b = Bra('b') + >>> k = Ket('k') + >>> ip = b*k + >>> ip + + >>> ip.bra + >> ip.ket + |k> + + In quantum expressions, inner products will be automatically + identified and created:: + + >>> b*k + + + In more complex expressions, where there is ambiguity in whether inner or + outer products should be created, inner products have high priority:: + + >>> k*b*k*b + *|k> moved to the left of the expression + because inner products are commutative complex numbers. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Inner_product + """ + + kind = NumberKind + + is_complex = True + + def __new__(cls, bra, ket): + # Keep the import of BraBase and KetBase here to avoid problems + # with circular imports. + from sympy.physics.quantum.state import KetBase, BraBase + if not isinstance(ket, KetBase): + raise TypeError('KetBase subclass expected, got: %r' % ket) + if not isinstance(bra, BraBase): + raise TypeError('BraBase subclass expected, got: %r' % ket) + obj = Expr.__new__(cls, bra, ket) + return obj + + @property + def bra(self): + return self.args[0] + + @property + def ket(self): + return self.args[1] + + def _eval_conjugate(self): + return InnerProduct(Dagger(self.ket), Dagger(self.bra)) + + def _sympyrepr(self, printer, *args): + return '%s(%s,%s)' % (self.__class__.__name__, + printer._print(self.bra, *args), printer._print(self.ket, *args)) + + def _sympystr(self, printer, *args): + sbra = printer._print(self.bra) + sket = printer._print(self.ket) + return '%s|%s' % (sbra[:-1], sket[1:]) + + def _pretty(self, printer, *args): + # Print state contents + bra = self.bra._print_contents_pretty(printer, *args) + ket = self.ket._print_contents_pretty(printer, *args) + # Print brackets + height = max(bra.height(), ket.height()) + use_unicode = printer._use_unicode + lbracket, _ = self.bra._pretty_brackets(height, use_unicode) + cbracket, rbracket = self.ket._pretty_brackets(height, use_unicode) + # Build innerproduct + pform = prettyForm(*bra.left(lbracket)) + pform = prettyForm(*pform.right(cbracket)) + pform = prettyForm(*pform.right(ket)) + pform = prettyForm(*pform.right(rbracket)) + return pform + + def _latex(self, printer, *args): + bra_label = self.bra._print_contents_latex(printer, *args) + ket = printer._print(self.ket, *args) + return r'\left\langle %s \right. %s' % (bra_label, ket) + + def doit(self, **hints): + try: + r = self.ket._eval_innerproduct(self.bra, **hints) + except NotImplementedError: + try: + r = conjugate( + self.bra.dual._eval_innerproduct(self.ket.dual, **hints) + ) + except NotImplementedError: + r = None + if r is not None: + return r + return self diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/kind.py b/phivenv/Lib/site-packages/sympy/physics/quantum/kind.py new file mode 100644 index 0000000000000000000000000000000000000000..14b5bd2c7b0c87f49dc7e6dc9c1b492fbfad6d56 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/kind.py @@ -0,0 +1,103 @@ +"""Kinds for Operators, Bras, and Kets. + +This module defines kinds for operators, bras, and kets. These are useful +in various places in ``sympy.physics.quantum`` as you often want to know +what the kind is of a compound expression. For example, if you multiply +an operator, bra, or ket by a number, you get back another operator, bra, +or ket - even though if you did an ``isinstance`` check you would find that +you have a ``Mul`` instead. The kind system is meant to give you a quick +way of determining how a compound expression behaves in terms of lower +level kinds. + +The resolution calculation of kinds for compound expressions can be found +either in container classes or in functions that are registered with +kind dispatchers. +""" + +from sympy.core.mul import Mul +from sympy.core.kind import Kind, _NumberKind + + +__all__ = [ + '_KetKind', + 'KetKind', + '_BraKind', + 'BraKind', + '_OperatorKind', + 'OperatorKind', +] + + +class _KetKind(Kind): + """A kind for quantum kets.""" + + def __new__(cls): + obj = super().__new__(cls) + return obj + + def __repr__(self): + return "KetKind" + +# Create an instance as many situations need this. +KetKind = _KetKind() + + +class _BraKind(Kind): + """A kind for quantum bras.""" + + def __new__(cls): + obj = super().__new__(cls) + return obj + + def __repr__(self): + return "BraKind" + +# Create an instance as many situations need this. +BraKind = _BraKind() + + +from sympy.core.kind import Kind + +class _OperatorKind(Kind): + """A kind for quantum operators.""" + + def __new__(cls): + obj = super().__new__(cls) + return obj + + def __repr__(self): + return "OperatorKind" + +# Create an instance as many situations need this. +OperatorKind = _OperatorKind() + + +#----------------------------------------------------------------------------- +# Kind resolution. +#----------------------------------------------------------------------------- + +# Note: We can't currently add kind dispatchers for the following combinations +# as the Mul._kind_dispatcher is set to commutative and will also +# register the opposite order, which isn't correct for these pairs: +# +# 1. (_OperatorKind, _KetKind) +# 2. (_BraKind, _OperatorKind) +# 3. (_BraKind, _KetKind) + + +@Mul._kind_dispatcher.register(_NumberKind, _KetKind) +def _mul_number_ket_kind(lhs, rhs): + """Perform the kind calculation of NumberKind*KetKind -> KetKind.""" + return KetKind + + +@Mul._kind_dispatcher.register(_NumberKind, _BraKind) +def _mul_number_bra_kind(lhs, rhs): + """Perform the kind calculation of NumberKind*BraKind -> BraKind.""" + return BraKind + + +@Mul._kind_dispatcher.register(_NumberKind, _OperatorKind) +def _mul_operator_kind(lhs, rhs): + """Perform the kind calculation of NumberKind*OperatorKind -> OperatorKind.""" + return OperatorKind diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/matrixcache.py b/phivenv/Lib/site-packages/sympy/physics/quantum/matrixcache.py new file mode 100644 index 0000000000000000000000000000000000000000..3cfab3c3490c909966d8a56af395ffa578724ea7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/matrixcache.py @@ -0,0 +1,103 @@ +"""A cache for storing small matrices in multiple formats.""" + +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.power import Pow +from sympy.functions.elementary.exponential import exp +from sympy.matrices.dense import Matrix + +from sympy.physics.quantum.matrixutils import ( + to_sympy, to_numpy, to_scipy_sparse +) + + +class MatrixCache: + """A cache for small matrices in different formats. + + This class takes small matrices in the standard ``sympy.Matrix`` format, + and then converts these to both ``numpy.matrix`` and + ``scipy.sparse.csr_matrix`` matrices. These matrices are then stored for + future recovery. + """ + + def __init__(self, dtype='complex'): + self._cache = {} + self.dtype = dtype + + def cache_matrix(self, name, m): + """Cache a matrix by its name. + + Parameters + ---------- + name : str + A descriptive name for the matrix, like "identity2". + m : list of lists + The raw matrix data as a SymPy Matrix. + """ + try: + self._sympy_matrix(name, m) + except ImportError: + pass + try: + self._numpy_matrix(name, m) + except ImportError: + pass + try: + self._scipy_sparse_matrix(name, m) + except ImportError: + pass + + def get_matrix(self, name, format): + """Get a cached matrix by name and format. + + Parameters + ---------- + name : str + A descriptive name for the matrix, like "identity2". + format : str + The format desired ('sympy', 'numpy', 'scipy.sparse') + """ + m = self._cache.get((name, format)) + if m is not None: + return m + raise NotImplementedError( + 'Matrix with name %s and format %s is not available.' % + (name, format) + ) + + def _store_matrix(self, name, format, m): + self._cache[(name, format)] = m + + def _sympy_matrix(self, name, m): + self._store_matrix(name, 'sympy', to_sympy(m)) + + def _numpy_matrix(self, name, m): + m = to_numpy(m, dtype=self.dtype) + self._store_matrix(name, 'numpy', m) + + def _scipy_sparse_matrix(self, name, m): + # TODO: explore different sparse formats. But sparse.kron will use + # coo in most cases, so we use that here. + m = to_scipy_sparse(m, dtype=self.dtype) + self._store_matrix(name, 'scipy.sparse', m) + + +sqrt2_inv = Pow(2, Rational(-1, 2), evaluate=False) + +# Save the common matrices that we will need +matrix_cache = MatrixCache() +matrix_cache.cache_matrix('eye2', Matrix([[1, 0], [0, 1]])) +matrix_cache.cache_matrix('op11', Matrix([[0, 0], [0, 1]])) # |1><1| +matrix_cache.cache_matrix('op00', Matrix([[1, 0], [0, 0]])) # |0><0| +matrix_cache.cache_matrix('op10', Matrix([[0, 0], [1, 0]])) # |1><0| +matrix_cache.cache_matrix('op01', Matrix([[0, 1], [0, 0]])) # |0><1| +matrix_cache.cache_matrix('X', Matrix([[0, 1], [1, 0]])) +matrix_cache.cache_matrix('Y', Matrix([[0, -I], [I, 0]])) +matrix_cache.cache_matrix('Z', Matrix([[1, 0], [0, -1]])) +matrix_cache.cache_matrix('S', Matrix([[1, 0], [0, I]])) +matrix_cache.cache_matrix('T', Matrix([[1, 0], [0, exp(I*pi/4)]])) +matrix_cache.cache_matrix('H', sqrt2_inv*Matrix([[1, 1], [1, -1]])) +matrix_cache.cache_matrix('Hsqrt2', Matrix([[1, 1], [1, -1]])) +matrix_cache.cache_matrix( + 'SWAP', Matrix([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) +matrix_cache.cache_matrix('ZX', sqrt2_inv*Matrix([[1, 1], [1, -1]])) +matrix_cache.cache_matrix('ZY', Matrix([[I, 0], [0, -I]])) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/matrixutils.py b/phivenv/Lib/site-packages/sympy/physics/quantum/matrixutils.py new file mode 100644 index 0000000000000000000000000000000000000000..1082ea326b68256dac96030e36d72efa664495d2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/matrixutils.py @@ -0,0 +1,272 @@ +"""Utilities to deal with sympy.Matrix, numpy and scipy.sparse.""" + +from sympy.core.expr import Expr +from sympy.core.numbers import I +from sympy.core.singleton import S +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices import eye, zeros +from sympy.external import import_module + +__all__ = [ + 'numpy_ndarray', + 'scipy_sparse_matrix', + 'sympy_to_numpy', + 'sympy_to_scipy_sparse', + 'numpy_to_sympy', + 'scipy_sparse_to_sympy', + 'flatten_scalar', + 'matrix_dagger', + 'to_sympy', + 'to_numpy', + 'to_scipy_sparse', + 'matrix_tensor_product', + 'matrix_zeros' +] + +# Conditionally define the base classes for numpy and scipy.sparse arrays +# for use in isinstance tests. + +np = import_module('numpy') +if not np: + class numpy_ndarray: + pass +else: + numpy_ndarray = np.ndarray # type: ignore + +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) +if not scipy: + class scipy_sparse_matrix: + pass + sparse = None +else: + sparse = scipy.sparse + scipy_sparse_matrix = sparse.spmatrix # type: ignore + + +def sympy_to_numpy(m, **options): + """Convert a SymPy Matrix/complex number to a numpy matrix or scalar.""" + if not np: + raise ImportError + dtype = options.get('dtype', 'complex') + if isinstance(m, MatrixBase): + return np.array(m.tolist(), dtype=dtype) + elif isinstance(m, Expr): + if m.is_Number or m.is_NumberSymbol or m == I: + return complex(m) + raise TypeError('Expected MatrixBase or complex scalar, got: %r' % m) + + +def sympy_to_scipy_sparse(m, **options): + """Convert a SymPy Matrix/complex number to a numpy matrix or scalar.""" + if not np or not sparse: + raise ImportError + dtype = options.get('dtype', 'complex') + if isinstance(m, MatrixBase): + return sparse.csr_matrix(np.array(m.tolist(), dtype=dtype)) + elif isinstance(m, Expr): + if m.is_Number or m.is_NumberSymbol or m == I: + return complex(m) + raise TypeError('Expected MatrixBase or complex scalar, got: %r' % m) + + +def scipy_sparse_to_sympy(m, **options): + """Convert a scipy.sparse matrix to a SymPy matrix.""" + return MatrixBase(m.todense()) + + +def numpy_to_sympy(m, **options): + """Convert a numpy matrix to a SymPy matrix.""" + return MatrixBase(m) + + +def to_sympy(m, **options): + """Convert a numpy/scipy.sparse matrix to a SymPy matrix.""" + if isinstance(m, MatrixBase): + return m + elif isinstance(m, numpy_ndarray): + return numpy_to_sympy(m) + elif isinstance(m, scipy_sparse_matrix): + return scipy_sparse_to_sympy(m) + elif isinstance(m, Expr): + return m + raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m) + + +def to_numpy(m, **options): + """Convert a sympy/scipy.sparse matrix to a numpy matrix.""" + dtype = options.get('dtype', 'complex') + if isinstance(m, (MatrixBase, Expr)): + return sympy_to_numpy(m, dtype=dtype) + elif isinstance(m, numpy_ndarray): + return m + elif isinstance(m, scipy_sparse_matrix): + return m.todense() + raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m) + + +def to_scipy_sparse(m, **options): + """Convert a sympy/numpy matrix to a scipy.sparse matrix.""" + dtype = options.get('dtype', 'complex') + if isinstance(m, (MatrixBase, Expr)): + return sympy_to_scipy_sparse(m, dtype=dtype) + elif isinstance(m, numpy_ndarray): + if not sparse: + raise ImportError + return sparse.csr_matrix(m) + elif isinstance(m, scipy_sparse_matrix): + return m + raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m) + + +def flatten_scalar(e): + """Flatten a 1x1 matrix to a scalar, return larger matrices unchanged.""" + if isinstance(e, MatrixBase): + if e.shape == (1, 1): + e = e[0] + if isinstance(e, (numpy_ndarray, scipy_sparse_matrix)): + if e.shape == (1, 1): + e = complex(e[0, 0]) + return e + + +def matrix_dagger(e): + """Return the dagger of a sympy/numpy/scipy.sparse matrix.""" + if isinstance(e, MatrixBase): + return e.H + elif isinstance(e, (numpy_ndarray, scipy_sparse_matrix)): + return e.conjugate().transpose() + raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % e) + + +# TODO: Move this into sympy.matrices. +def _sympy_tensor_product(*matrices): + """Compute the kronecker product of a sequence of SymPy Matrices. + """ + from sympy.matrices.expressions.kronecker import matrix_kronecker_product + + return matrix_kronecker_product(*matrices) + + +def _numpy_tensor_product(*product): + """numpy version of tensor product of multiple arguments.""" + if not np: + raise ImportError + answer = product[0] + for item in product[1:]: + answer = np.kron(answer, item) + return answer + + +def _scipy_sparse_tensor_product(*product): + """scipy.sparse version of tensor product of multiple arguments.""" + if not sparse: + raise ImportError + answer = product[0] + for item in product[1:]: + answer = sparse.kron(answer, item) + # The final matrices will just be multiplied, so csr is a good final + # sparse format. + return sparse.csr_matrix(answer) + + +def matrix_tensor_product(*product): + """Compute the matrix tensor product of sympy/numpy/scipy.sparse matrices.""" + if isinstance(product[0], MatrixBase): + return _sympy_tensor_product(*product) + elif isinstance(product[0], numpy_ndarray): + return _numpy_tensor_product(*product) + elif isinstance(product[0], scipy_sparse_matrix): + return _scipy_sparse_tensor_product(*product) + + +def _numpy_eye(n): + """numpy version of complex eye.""" + if not np: + raise ImportError + return np.array(np.eye(n, dtype='complex')) + + +def _scipy_sparse_eye(n): + """scipy.sparse version of complex eye.""" + if not sparse: + raise ImportError + return sparse.eye(n, n, dtype='complex') + + +def matrix_eye(n, **options): + """Get the version of eye and tensor_product for a given format.""" + format = options.get('format', 'sympy') + if format == 'sympy': + return eye(n) + elif format == 'numpy': + return _numpy_eye(n) + elif format == 'scipy.sparse': + return _scipy_sparse_eye(n) + raise NotImplementedError('Invalid format: %r' % format) + + +def _numpy_zeros(m, n, **options): + """numpy version of zeros.""" + dtype = options.get('dtype', 'float64') + if not np: + raise ImportError + return np.zeros((m, n), dtype=dtype) + + +def _scipy_sparse_zeros(m, n, **options): + """scipy.sparse version of zeros.""" + spmatrix = options.get('spmatrix', 'csr') + dtype = options.get('dtype', 'float64') + if not sparse: + raise ImportError + if spmatrix == 'lil': + return sparse.lil_matrix((m, n), dtype=dtype) + elif spmatrix == 'csr': + return sparse.csr_matrix((m, n), dtype=dtype) + + +def matrix_zeros(m, n, **options): + """"Get a zeros matrix for a given format.""" + format = options.get('format', 'sympy') + if format == 'sympy': + return zeros(m, n) + elif format == 'numpy': + return _numpy_zeros(m, n, **options) + elif format == 'scipy.sparse': + return _scipy_sparse_zeros(m, n, **options) + raise NotImplementedError('Invaild format: %r' % format) + + +def _numpy_matrix_to_zero(e): + """Convert a numpy zero matrix to the zero scalar.""" + if not np: + raise ImportError + test = np.zeros_like(e) + if np.allclose(e, test): + return 0.0 + else: + return e + + +def _scipy_sparse_matrix_to_zero(e): + """Convert a scipy.sparse zero matrix to the zero scalar.""" + if not np: + raise ImportError + edense = e.todense() + test = np.zeros_like(edense) + if np.allclose(edense, test): + return 0.0 + else: + return e + + +def matrix_to_zero(e): + """Convert a zero matrix to the scalar zero.""" + if isinstance(e, MatrixBase): + if zeros(*e.shape) == e: + e = S.Zero + elif isinstance(e, numpy_ndarray): + e = _numpy_matrix_to_zero(e) + elif isinstance(e, scipy_sparse_matrix): + e = _scipy_sparse_matrix_to_zero(e) + return e diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/operator.py b/phivenv/Lib/site-packages/sympy/physics/quantum/operator.py new file mode 100644 index 0000000000000000000000000000000000000000..b44617e15c19e8b30b76f011630430787233e724 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/operator.py @@ -0,0 +1,653 @@ +"""Quantum mechanical operators. + +TODO: + +* Fix early 0 in apply_operators. +* Debug and test apply_operators. +* Get cse working with classes in this file. +* Doctests and documentation of special methods for InnerProduct, Commutator, + AntiCommutator, represent, apply_operators. +""" +from typing import Optional + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import (Derivative, expand) +from sympy.core.mul import Mul +from sympy.core.numbers import oo +from sympy.core.singleton import S +from sympy.printing.pretty.stringpict import prettyForm +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.kind import OperatorKind +from sympy.physics.quantum.qexpr import QExpr, dispatch_method +from sympy.matrices import eye +from sympy.utilities.exceptions import sympy_deprecation_warning + + + +__all__ = [ + 'Operator', + 'HermitianOperator', + 'UnitaryOperator', + 'IdentityOperator', + 'OuterProduct', + 'DifferentialOperator' +] + +#----------------------------------------------------------------------------- +# Operators and outer products +#----------------------------------------------------------------------------- + + +class Operator(QExpr): + """Base class for non-commuting quantum operators. + + An operator maps between quantum states [1]_. In quantum mechanics, + observables (including, but not limited to, measured physical values) are + represented as Hermitian operators [2]_. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. For time-dependent operators, this will include the time. + + Examples + ======== + + Create an operator and examine its attributes:: + + >>> from sympy.physics.quantum import Operator + >>> from sympy import I + >>> A = Operator('A') + >>> A + A + >>> A.hilbert_space + H + >>> A.label + (A,) + >>> A.is_commutative + False + + Create another operator and do some arithmetic operations:: + + >>> B = Operator('B') + >>> C = 2*A*A + I*B + >>> C + 2*A**2 + I*B + + Operators do not commute:: + + >>> A.is_commutative + False + >>> B.is_commutative + False + >>> A*B == B*A + False + + Polymonials of operators respect the commutation properties:: + + >>> e = (A+B)**3 + >>> e.expand() + A*B*A + A*B**2 + A**2*B + A**3 + B*A*B + B*A**2 + B**2*A + B**3 + + Operator inverses are handle symbolically:: + + >>> A.inv() + A**(-1) + >>> A*A.inv() + 1 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Operator_%28physics%29 + .. [2] https://en.wikipedia.org/wiki/Observable + """ + is_hermitian: Optional[bool] = None + is_unitary: Optional[bool] = None + @classmethod + def default_args(self): + return ("O",) + + kind = OperatorKind + + #------------------------------------------------------------------------- + # Printing + #------------------------------------------------------------------------- + + _label_separator = ',' + + def _print_operator_name(self, printer, *args): + return self.__class__.__name__ + + _print_operator_name_latex = _print_operator_name + + def _print_operator_name_pretty(self, printer, *args): + return prettyForm(self.__class__.__name__) + + def _print_contents(self, printer, *args): + if len(self.label) == 1: + return self._print_label(printer, *args) + else: + return '%s(%s)' % ( + self._print_operator_name(printer, *args), + self._print_label(printer, *args) + ) + + def _print_contents_pretty(self, printer, *args): + if len(self.label) == 1: + return self._print_label_pretty(printer, *args) + else: + pform = self._print_operator_name_pretty(printer, *args) + label_pform = self._print_label_pretty(printer, *args) + label_pform = prettyForm( + *label_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(label_pform)) + return pform + + def _print_contents_latex(self, printer, *args): + if len(self.label) == 1: + return self._print_label_latex(printer, *args) + else: + return r'%s\left(%s\right)' % ( + self._print_operator_name_latex(printer, *args), + self._print_label_latex(printer, *args) + ) + + #------------------------------------------------------------------------- + # _eval_* methods + #------------------------------------------------------------------------- + + def _eval_commutator(self, other, **options): + """Evaluate [self, other] if known, return None if not known.""" + return dispatch_method(self, '_eval_commutator', other, **options) + + def _eval_anticommutator(self, other, **options): + """Evaluate [self, other] if known.""" + return dispatch_method(self, '_eval_anticommutator', other, **options) + + #------------------------------------------------------------------------- + # Operator application + #------------------------------------------------------------------------- + + def _apply_operator(self, ket, **options): + return dispatch_method(self, '_apply_operator', ket, **options) + + def _apply_from_right_to(self, bra, **options): + return None + + def matrix_element(self, *args): + raise NotImplementedError('matrix_elements is not defined') + + def inverse(self): + return self._eval_inverse() + + inv = inverse + + def _eval_inverse(self): + return self**(-1) + + +class HermitianOperator(Operator): + """A Hermitian operator that satisfies H == Dagger(H). + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. For time-dependent operators, this will include the time. + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger, HermitianOperator + >>> H = HermitianOperator('H') + >>> Dagger(H) + H + """ + + is_hermitian = True + + def _eval_inverse(self): + if isinstance(self, UnitaryOperator): + return self + else: + return Operator._eval_inverse(self) + + def _eval_power(self, exp): + if isinstance(self, UnitaryOperator): + # so all eigenvalues of self are 1 or -1 + if exp.is_even: + from sympy.core.singleton import S + return S.One # is identity, see Issue 24153. + elif exp.is_odd: + return self + # No simplification in all other cases + return Operator._eval_power(self, exp) + + +class UnitaryOperator(Operator): + """A unitary operator that satisfies U*Dagger(U) == 1. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. For time-dependent operators, this will include the time. + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger, UnitaryOperator + >>> U = UnitaryOperator('U') + >>> U*Dagger(U) + 1 + """ + is_unitary = True + def _eval_adjoint(self): + return self._eval_inverse() + + +class IdentityOperator(Operator): + """An identity operator I that satisfies op * I == I * op == op for any + operator op. + + .. deprecated:: 1.14. + Use the scalar S.One instead as the multiplicative identity for + operators and states. + + Parameters + ========== + + N : Integer + Optional parameter that specifies the dimension of the Hilbert space + of operator. This is used when generating a matrix representation. + + Examples + ======== + + >>> from sympy.physics.quantum import IdentityOperator + >>> IdentityOperator() # doctest: +SKIP + I + """ + is_hermitian = True + is_unitary = True + @property + def dimension(self): + return self.N + + @classmethod + def default_args(self): + return (oo,) + + def __init__(self, *args, **hints): + sympy_deprecation_warning( + """ + IdentityOperator has been deprecated. In the future, please use + S.One as the identity for quantum operators and states. + """, + deprecated_since_version="1.14", + active_deprecations_target='deprecated-operator-identity', + ) + if not len(args) in (0, 1): + raise ValueError('0 or 1 parameters expected, got %s' % args) + + self.N = args[0] if (len(args) == 1 and args[0]) else oo + + def _eval_commutator(self, other, **hints): + return S.Zero + + def _eval_anticommutator(self, other, **hints): + return 2 * other + + def _eval_inverse(self): + return self + + def _eval_adjoint(self): + return self + + def _apply_operator(self, ket, **options): + return ket + + def _apply_from_right_to(self, bra, **options): + return bra + + def _eval_power(self, exp): + return self + + def _print_contents(self, printer, *args): + return 'I' + + def _print_contents_pretty(self, printer, *args): + return prettyForm('I') + + def _print_contents_latex(self, printer, *args): + return r'{\mathcal{I}}' + + def _represent_default_basis(self, **options): + if not self.N or self.N == oo: + raise NotImplementedError('Cannot represent infinite dimensional' + + ' identity operator as a matrix') + + format = options.get('format', 'sympy') + if format != 'sympy': + raise NotImplementedError('Representation in format ' + + '%s not implemented.' % format) + + return eye(self.N) + + +class OuterProduct(Operator): + """An unevaluated outer product between a ket and bra. + + This constructs an outer product between any subclass of ``KetBase`` and + ``BraBase`` as ``|a>>> from sympy.physics.quantum import Ket, Bra, OuterProduct, Dagger + + >>> k = Ket('k') + >>> b = Bra('b') + >>> op = OuterProduct(k, b) + >>> op + |k>>> op.hilbert_space + H + >>> op.ket + |k> + >>> op.bra + >> Dagger(op) + |b>>> k*b + |k>>> b*k*b + *>> from sympy import Derivative, Function, Symbol + >>> from sympy.physics.quantum.operator import DifferentialOperator + >>> from sympy.physics.quantum.state import Wavefunction + >>> from sympy.physics.quantum.qapply import qapply + >>> f = Function('f') + >>> x = Symbol('x') + >>> d = DifferentialOperator(1/x*Derivative(f(x), x), f(x)) + >>> w = Wavefunction(x**2, x) + >>> d.function + f(x) + >>> d.variables + (x,) + >>> qapply(d*w) + Wavefunction(2, x) + + """ + + @property + def variables(self): + """ + Returns the variables with which the function in the specified + arbitrary expression is evaluated + + Examples + ======== + + >>> from sympy.physics.quantum.operator import DifferentialOperator + >>> from sympy import Symbol, Function, Derivative + >>> x = Symbol('x') + >>> f = Function('f') + >>> d = DifferentialOperator(1/x*Derivative(f(x), x), f(x)) + >>> d.variables + (x,) + >>> y = Symbol('y') + >>> d = DifferentialOperator(Derivative(f(x, y), x) + + ... Derivative(f(x, y), y), f(x, y)) + >>> d.variables + (x, y) + """ + + return self.args[-1].args + + @property + def function(self): + """ + Returns the function which is to be replaced with the Wavefunction + + Examples + ======== + + >>> from sympy.physics.quantum.operator import DifferentialOperator + >>> from sympy import Function, Symbol, Derivative + >>> x = Symbol('x') + >>> f = Function('f') + >>> d = DifferentialOperator(Derivative(f(x), x), f(x)) + >>> d.function + f(x) + >>> y = Symbol('y') + >>> d = DifferentialOperator(Derivative(f(x, y), x) + + ... Derivative(f(x, y), y), f(x, y)) + >>> d.function + f(x, y) + """ + + return self.args[-1] + + @property + def expr(self): + """ + Returns the arbitrary expression which is to have the Wavefunction + substituted into it + + Examples + ======== + + >>> from sympy.physics.quantum.operator import DifferentialOperator + >>> from sympy import Function, Symbol, Derivative + >>> x = Symbol('x') + >>> f = Function('f') + >>> d = DifferentialOperator(Derivative(f(x), x), f(x)) + >>> d.expr + Derivative(f(x), x) + >>> y = Symbol('y') + >>> d = DifferentialOperator(Derivative(f(x, y), x) + + ... Derivative(f(x, y), y), f(x, y)) + >>> d.expr + Derivative(f(x, y), x) + Derivative(f(x, y), y) + """ + + return self.args[0] + + @property + def free_symbols(self): + """ + Return the free symbols of the expression. + """ + + return self.expr.free_symbols + + def _apply_operator_Wavefunction(self, func, **options): + from sympy.physics.quantum.state import Wavefunction + var = self.variables + wf_vars = func.args[1:] + + f = self.function + new_expr = self.expr.subs(f, func(*var)) + new_expr = new_expr.doit() + + return Wavefunction(new_expr, *wf_vars) + + def _eval_derivative(self, symbol): + new_expr = Derivative(self.expr, symbol) + return DifferentialOperator(new_expr, self.args[-1]) + + #------------------------------------------------------------------------- + # Printing + #------------------------------------------------------------------------- + + def _print(self, printer, *args): + return '%s(%s)' % ( + self._print_operator_name(printer, *args), + self._print_label(printer, *args) + ) + + def _print_pretty(self, printer, *args): + pform = self._print_operator_name_pretty(printer, *args) + label_pform = self._print_label_pretty(printer, *args) + label_pform = prettyForm( + *label_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(label_pform)) + return pform diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/operatorordering.py b/phivenv/Lib/site-packages/sympy/physics/quantum/operatorordering.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ba3dd83b4b79b773793b0094e636cc8a901f44 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/operatorordering.py @@ -0,0 +1,290 @@ +"""Functions for reordering operator expressions.""" + +import warnings + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.power import Pow +from sympy.physics.quantum import Commutator, AntiCommutator +from sympy.physics.quantum.boson import BosonOp +from sympy.physics.quantum.fermion import FermionOp + +__all__ = [ + 'normal_order', + 'normal_ordered_form' +] + + +def _expand_powers(factors): + """ + Helper function for normal_ordered_form and normal_order: Expand a + power expression to a multiplication expression so that that the + expression can be handled by the normal ordering functions. + """ + + new_factors = [] + for factor in factors.args: + if (isinstance(factor, Pow) + and isinstance(factor.args[1], Integer) + and factor.args[1] > 0): + for n in range(factor.args[1]): + new_factors.append(factor.args[0]) + else: + new_factors.append(factor) + + return new_factors + +def _normal_ordered_form_factor(product, independent=False, recursive_limit=10, + _recursive_depth=0): + """ + Helper function for normal_ordered_form_factor: Write multiplication + expression with bosonic or fermionic operators on normally ordered form, + using the bosonic and fermionic commutation relations. The resulting + operator expression is equivalent to the argument, but will in general be + a sum of operator products instead of a simple product. + """ + + factors = _expand_powers(product) + + new_factors = [] + n = 0 + while n < len(factors) - 1: + current, next = factors[n], factors[n + 1] + if any(not isinstance(f, (FermionOp, BosonOp)) for f in (current, next)): + new_factors.append(current) + n += 1 + continue + + key_1 = (current.is_annihilation, str(current.name)) + key_2 = (next.is_annihilation, str(next.name)) + + if key_1 <= key_2: + new_factors.append(current) + n += 1 + continue + + n += 2 + if current.is_annihilation and not next.is_annihilation: + if isinstance(current, BosonOp) and isinstance(next, BosonOp): + if current.args[0] != next.args[0]: + if independent: + c = 0 + else: + c = Commutator(current, next) + new_factors.append(next * current + c) + else: + new_factors.append(next * current + 1) + elif isinstance(current, FermionOp) and isinstance(next, FermionOp): + if current.args[0] != next.args[0]: + if independent: + c = 0 + else: + c = AntiCommutator(current, next) + new_factors.append(-next * current + c) + else: + new_factors.append(-next * current + 1) + elif (current.is_annihilation == next.is_annihilation and + isinstance(current, FermionOp) and isinstance(next, FermionOp)): + new_factors.append(-next * current) + else: + new_factors.append(next * current) + + if n == len(factors) - 1: + new_factors.append(factors[-1]) + + if new_factors == factors: + return product + else: + expr = Mul(*new_factors).expand() + return normal_ordered_form(expr, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth + 1, + independent=independent) + + +def _normal_ordered_form_terms(expr, independent=False, recursive_limit=10, + _recursive_depth=0): + """ + Helper function for normal_ordered_form: loop through each term in an + addition expression and call _normal_ordered_form_factor to perform the + factor to an normally ordered expression. + """ + + new_terms = [] + for term in expr.args: + if isinstance(term, Mul): + new_term = _normal_ordered_form_factor( + term, recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth, independent=independent) + new_terms.append(new_term) + else: + new_terms.append(term) + + return Add(*new_terms) + + +def normal_ordered_form(expr, independent=False, recursive_limit=10, + _recursive_depth=0): + """Write an expression with bosonic or fermionic operators on normal + ordered form, where each term is normally ordered. Note that this + normal ordered form is equivalent to the original expression. + + Parameters + ========== + + expr : expression + The expression write on normal ordered form. + independent : bool (default False) + Whether to consider operator with different names as operating in + different Hilbert spaces. If False, the (anti-)commutation is left + explicit. + recursive_limit : int (default 10) + The number of allowed recursive applications of the function. + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger + >>> from sympy.physics.quantum.boson import BosonOp + >>> from sympy.physics.quantum.operatorordering import normal_ordered_form + >>> a = BosonOp("a") + >>> normal_ordered_form(a * Dagger(a)) + 1 + Dagger(a)*a + """ + + if _recursive_depth > recursive_limit: + warnings.warn("Too many recursions, aborting") + return expr + + if isinstance(expr, Add): + return _normal_ordered_form_terms(expr, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth, + independent=independent) + elif isinstance(expr, Mul): + return _normal_ordered_form_factor(expr, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth, + independent=independent) + else: + return expr + + +def _normal_order_factor(product, recursive_limit=10, _recursive_depth=0): + """ + Helper function for normal_order: Normal order a multiplication expression + with bosonic or fermionic operators. In general the resulting operator + expression will not be equivalent to original product. + """ + + factors = _expand_powers(product) + + n = 0 + new_factors = [] + while n < len(factors) - 1: + + if (isinstance(factors[n], BosonOp) and + factors[n].is_annihilation): + # boson + if not isinstance(factors[n + 1], BosonOp): + new_factors.append(factors[n]) + else: + if factors[n + 1].is_annihilation: + new_factors.append(factors[n]) + else: + if factors[n].args[0] != factors[n + 1].args[0]: + new_factors.append(factors[n + 1] * factors[n]) + else: + new_factors.append(factors[n + 1] * factors[n]) + n += 1 + + elif (isinstance(factors[n], FermionOp) and + factors[n].is_annihilation): + # fermion + if not isinstance(factors[n + 1], FermionOp): + new_factors.append(factors[n]) + else: + if factors[n + 1].is_annihilation: + new_factors.append(factors[n]) + else: + if factors[n].args[0] != factors[n + 1].args[0]: + new_factors.append(-factors[n + 1] * factors[n]) + else: + new_factors.append(-factors[n + 1] * factors[n]) + n += 1 + + else: + new_factors.append(factors[n]) + + n += 1 + + if n == len(factors) - 1: + new_factors.append(factors[-1]) + + if new_factors == factors: + return product + else: + expr = Mul(*new_factors).expand() + return normal_order(expr, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth + 1) + + +def _normal_order_terms(expr, recursive_limit=10, _recursive_depth=0): + """ + Helper function for normal_order: look through each term in an addition + expression and call _normal_order_factor to perform the normal ordering + on the factors. + """ + + new_terms = [] + for term in expr.args: + if isinstance(term, Mul): + new_term = _normal_order_factor(term, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth) + new_terms.append(new_term) + else: + new_terms.append(term) + + return Add(*new_terms) + + +def normal_order(expr, recursive_limit=10, _recursive_depth=0): + """Normal order an expression with bosonic or fermionic operators. Note + that this normal order is not equivalent to the original expression, but + the creation and annihilation operators in each term in expr is reordered + so that the expression becomes normal ordered. + + Parameters + ========== + + expr : expression + The expression to normal order. + + recursive_limit : int (default 10) + The number of allowed recursive applications of the function. + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger + >>> from sympy.physics.quantum.boson import BosonOp + >>> from sympy.physics.quantum.operatorordering import normal_order + >>> a = BosonOp("a") + >>> normal_order(a * Dagger(a)) + Dagger(a)*a + """ + if _recursive_depth > recursive_limit: + warnings.warn("Too many recursions, aborting") + return expr + + if isinstance(expr, Add): + return _normal_order_terms(expr, recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth) + elif isinstance(expr, Mul): + return _normal_order_factor(expr, recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth) + else: + return expr diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/operatorset.py b/phivenv/Lib/site-packages/sympy/physics/quantum/operatorset.py new file mode 100644 index 0000000000000000000000000000000000000000..bf32bcabbe5d33381dff0b94a9b130375032adef --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/operatorset.py @@ -0,0 +1,279 @@ +""" A module for mapping operators to their corresponding eigenstates +and vice versa + +It contains a global dictionary with eigenstate-operator pairings. +If a new state-operator pair is created, this dictionary should be +updated as well. + +It also contains functions operators_to_state and state_to_operators +for mapping between the two. These can handle both classes and +instances of operators and states. See the individual function +descriptions for details. + +TODO List: +- Update the dictionary with a complete list of state-operator pairs +""" + +from sympy.physics.quantum.cartesian import (XOp, YOp, ZOp, XKet, PxOp, PxKet, + PositionKet3D) +from sympy.physics.quantum.operator import Operator +from sympy.physics.quantum.state import StateBase, BraBase, Ket +from sympy.physics.quantum.spin import (JxOp, JyOp, JzOp, J2Op, JxKet, JyKet, + JzKet) + +__all__ = [ + 'operators_to_state', + 'state_to_operators' +] + +#state_mapping stores the mappings between states and their associated +#operators or tuples of operators. This should be updated when new +#classes are written! Entries are of the form PxKet : PxOp or +#something like 3DKet : (ROp, ThetaOp, PhiOp) + +#frozenset is used so that the reverse mapping can be made +#(regular sets are not hashable because they are mutable +state_mapping = { JxKet: frozenset((J2Op, JxOp)), + JyKet: frozenset((J2Op, JyOp)), + JzKet: frozenset((J2Op, JzOp)), + Ket: Operator, + PositionKet3D: frozenset((XOp, YOp, ZOp)), + PxKet: PxOp, + XKet: XOp } + +op_mapping = {v: k for k, v in state_mapping.items()} + + +def operators_to_state(operators, **options): + """ Returns the eigenstate of the given operator or set of operators + + A global function for mapping operator classes to their associated + states. It takes either an Operator or a set of operators and + returns the state associated with these. + + This function can handle both instances of a given operator or + just the class itself (i.e. both XOp() and XOp) + + There are multiple use cases to consider: + + 1) A class or set of classes is passed: First, we try to + instantiate default instances for these operators. If this fails, + then the class is simply returned. If we succeed in instantiating + default instances, then we try to call state._operators_to_state + on the operator instances. If this fails, the class is returned. + Otherwise, the instance returned by _operators_to_state is returned. + + 2) An instance or set of instances is passed: In this case, + state._operators_to_state is called on the instances passed. If + this fails, a state class is returned. If the method returns an + instance, that instance is returned. + + In both cases, if the operator class or set does not exist in the + state_mapping dictionary, None is returned. + + Parameters + ========== + + arg: Operator or set + The class or instance of the operator or set of operators + to be mapped to a state + + Examples + ======== + + >>> from sympy.physics.quantum.cartesian import XOp, PxOp + >>> from sympy.physics.quantum.operatorset import operators_to_state + >>> from sympy.physics.quantum.operator import Operator + >>> operators_to_state(XOp) + |x> + >>> operators_to_state(XOp()) + |x> + >>> operators_to_state(PxOp) + |px> + >>> operators_to_state(PxOp()) + |px> + >>> operators_to_state(Operator) + |psi> + >>> operators_to_state(Operator()) + |psi> + """ + + if not (isinstance(operators, (Operator, set)) or issubclass(operators, Operator)): + raise NotImplementedError("Argument is not an Operator or a set!") + + if isinstance(operators, set): + for s in operators: + if not (isinstance(s, Operator) + or issubclass(s, Operator)): + raise NotImplementedError("Set is not all Operators!") + + ops = frozenset(operators) + + if ops in op_mapping: # ops is a list of classes in this case + #Try to get an object from default instances of the + #operators...if this fails, return the class + try: + op_instances = [op() for op in ops] + ret = _get_state(op_mapping[ops], set(op_instances), **options) + except NotImplementedError: + ret = op_mapping[ops] + + return ret + else: + tmp = [type(o) for o in ops] + classes = frozenset(tmp) + + if classes in op_mapping: + ret = _get_state(op_mapping[classes], ops, **options) + else: + ret = None + + return ret + else: + if operators in op_mapping: + try: + op_instance = operators() + ret = _get_state(op_mapping[operators], op_instance, **options) + except NotImplementedError: + ret = op_mapping[operators] + + return ret + elif type(operators) in op_mapping: + return _get_state(op_mapping[type(operators)], operators, **options) + else: + return None + + +def state_to_operators(state, **options): + """ Returns the operator or set of operators corresponding to the + given eigenstate + + A global function for mapping state classes to their associated + operators or sets of operators. It takes either a state class + or instance. + + This function can handle both instances of a given state or just + the class itself (i.e. both XKet() and XKet) + + There are multiple use cases to consider: + + 1) A state class is passed: In this case, we first try + instantiating a default instance of the class. If this succeeds, + then we try to call state._state_to_operators on that instance. + If the creation of the default instance or if the calling of + _state_to_operators fails, then either an operator class or set of + operator classes is returned. Otherwise, the appropriate + operator instances are returned. + + 2) A state instance is returned: Here, state._state_to_operators + is called for the instance. If this fails, then a class or set of + operator classes is returned. Otherwise, the instances are returned. + + In either case, if the state's class does not exist in + state_mapping, None is returned. + + Parameters + ========== + + arg: StateBase class or instance (or subclasses) + The class or instance of the state to be mapped to an + operator or set of operators + + Examples + ======== + + >>> from sympy.physics.quantum.cartesian import XKet, PxKet, XBra, PxBra + >>> from sympy.physics.quantum.operatorset import state_to_operators + >>> from sympy.physics.quantum.state import Ket, Bra + >>> state_to_operators(XKet) + X + >>> state_to_operators(XKet()) + X + >>> state_to_operators(PxKet) + Px + >>> state_to_operators(PxKet()) + Px + >>> state_to_operators(PxBra) + Px + >>> state_to_operators(XBra) + X + >>> state_to_operators(Ket) + O + >>> state_to_operators(Bra) + O + """ + + if not (isinstance(state, StateBase) or issubclass(state, StateBase)): + raise NotImplementedError("Argument is not a state!") + + if state in state_mapping: # state is a class + state_inst = _make_default(state) + try: + ret = _get_ops(state_inst, + _make_set(state_mapping[state]), **options) + except (NotImplementedError, TypeError): + ret = state_mapping[state] + elif type(state) in state_mapping: + ret = _get_ops(state, + _make_set(state_mapping[type(state)]), **options) + elif isinstance(state, BraBase) and state.dual_class() in state_mapping: + ret = _get_ops(state, + _make_set(state_mapping[state.dual_class()])) + elif issubclass(state, BraBase) and state.dual_class() in state_mapping: + state_inst = _make_default(state) + try: + ret = _get_ops(state_inst, + _make_set(state_mapping[state.dual_class()])) + except (NotImplementedError, TypeError): + ret = state_mapping[state.dual_class()] + else: + ret = None + + return _make_set(ret) + + +def _make_default(expr): + # XXX: Catching TypeError like this is a bad way of distinguishing between + # classes and instances. The logic using this function should be rewritten + # somehow. + try: + ret = expr() + except TypeError: + ret = expr + + return ret + + +def _get_state(state_class, ops, **options): + # Try to get a state instance from the operator INSTANCES. + # If this fails, get the class + try: + ret = state_class._operators_to_state(ops, **options) + except NotImplementedError: + ret = _make_default(state_class) + + return ret + + +def _get_ops(state_inst, op_classes, **options): + # Try to get operator instances from the state INSTANCE. + # If this fails, just return the classes + try: + ret = state_inst._state_to_operators(op_classes, **options) + except NotImplementedError: + if isinstance(op_classes, (set, tuple, frozenset)): + ret = tuple(_make_default(x) for x in op_classes) + else: + ret = _make_default(op_classes) + + if isinstance(ret, set) and len(ret) == 1: + return ret[0] + + return ret + + +def _make_set(ops): + if isinstance(ops, (tuple, list, frozenset)): + return set(ops) + else: + return ops diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/pauli.py b/phivenv/Lib/site-packages/sympy/physics/quantum/pauli.py new file mode 100644 index 0000000000000000000000000000000000000000..89762ed2b38e1c5df3775714ee08d3700df0fa65 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/pauli.py @@ -0,0 +1,675 @@ +"""Pauli operators and states""" + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.physics.quantum import Operator, Ket, Bra +from sympy.physics.quantum import ComplexSpace +from sympy.matrices import Matrix +from sympy.functions.special.tensor_functions import KroneckerDelta + +__all__ = [ + 'SigmaX', 'SigmaY', 'SigmaZ', 'SigmaMinus', 'SigmaPlus', 'SigmaZKet', + 'SigmaZBra', 'qsimplify_pauli' +] + + +class SigmaOpBase(Operator): + """Pauli sigma operator, base class""" + + @property + def name(self): + return self.args[0] + + @property + def use_name(self): + return bool(self.args[0]) is not False + + @classmethod + def default_args(self): + return (False,) + + def __new__(cls, *args, **hints): + return Operator.__new__(cls, *args, **hints) + + def _eval_commutator_BosonOp(self, other, **hints): + return S.Zero + + +class SigmaX(SigmaOpBase): + """Pauli sigma x operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent + >>> from sympy.physics.quantum.pauli import SigmaX + >>> sx = SigmaX() + >>> sx + SigmaX() + >>> represent(sx) + Matrix([ + [0, 1], + [1, 0]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args, **hints) + + def _eval_commutator_SigmaY(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return 2 * I * SigmaZ(self.name) + + def _eval_commutator_SigmaZ(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return - 2 * I * SigmaY(self.name) + + def _eval_commutator_BosonOp(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaY(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaZ(self, other, **hints): + return S.Zero + + def _eval_adjoint(self): + return self + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_x^{(%s)}}' % str(self.name) + else: + return r'{\sigma_x}' + + def _print_contents(self, printer, *args): + return 'SigmaX()' + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return SigmaX(self.name).__pow__(int(e) % 2) + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[0, 1], [1, 0]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaY(SigmaOpBase): + """Pauli sigma y operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent + >>> from sympy.physics.quantum.pauli import SigmaY + >>> sy = SigmaY() + >>> sy + SigmaY() + >>> represent(sy) + Matrix([ + [0, -I], + [I, 0]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args) + + def _eval_commutator_SigmaZ(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return 2 * I * SigmaX(self.name) + + def _eval_commutator_SigmaX(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return - 2 * I * SigmaZ(self.name) + + def _eval_anticommutator_SigmaX(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaZ(self, other, **hints): + return S.Zero + + def _eval_adjoint(self): + return self + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_y^{(%s)}}' % str(self.name) + else: + return r'{\sigma_y}' + + def _print_contents(self, printer, *args): + return 'SigmaY()' + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return SigmaY(self.name).__pow__(int(e) % 2) + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[0, -I], [I, 0]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaZ(SigmaOpBase): + """Pauli sigma z operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent + >>> from sympy.physics.quantum.pauli import SigmaZ + >>> sz = SigmaZ() + >>> sz ** 3 + SigmaZ() + >>> represent(sz) + Matrix([ + [1, 0], + [0, -1]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args) + + def _eval_commutator_SigmaX(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return 2 * I * SigmaY(self.name) + + def _eval_commutator_SigmaY(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return - 2 * I * SigmaX(self.name) + + def _eval_anticommutator_SigmaX(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaY(self, other, **hints): + return S.Zero + + def _eval_adjoint(self): + return self + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_z^{(%s)}}' % str(self.name) + else: + return r'{\sigma_z}' + + def _print_contents(self, printer, *args): + return 'SigmaZ()' + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return SigmaZ(self.name).__pow__(int(e) % 2) + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[1, 0], [0, -1]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaMinus(SigmaOpBase): + """Pauli sigma minus operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent, Dagger + >>> from sympy.physics.quantum.pauli import SigmaMinus + >>> sm = SigmaMinus() + >>> sm + SigmaMinus() + >>> Dagger(sm) + SigmaPlus() + >>> represent(sm) + Matrix([ + [0, 0], + [1, 0]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args) + + def _eval_commutator_SigmaX(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return -SigmaZ(self.name) + + def _eval_commutator_SigmaY(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return I * SigmaZ(self.name) + + def _eval_commutator_SigmaZ(self, other, **hints): + return 2 * self + + def _eval_commutator_SigmaMinus(self, other, **hints): + return SigmaZ(self.name) + + def _eval_anticommutator_SigmaZ(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaX(self, other, **hints): + return S.One + + def _eval_anticommutator_SigmaY(self, other, **hints): + return I * S.NegativeOne + + def _eval_anticommutator_SigmaPlus(self, other, **hints): + return S.One + + def _eval_adjoint(self): + return SigmaPlus(self.name) + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return S.Zero + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_-^{(%s)}}' % str(self.name) + else: + return r'{\sigma_-}' + + def _print_contents(self, printer, *args): + return 'SigmaMinus()' + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[0, 0], [1, 0]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaPlus(SigmaOpBase): + """Pauli sigma plus operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent, Dagger + >>> from sympy.physics.quantum.pauli import SigmaPlus + >>> sp = SigmaPlus() + >>> sp + SigmaPlus() + >>> Dagger(sp) + SigmaMinus() + >>> represent(sp) + Matrix([ + [0, 1], + [0, 0]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args) + + def _eval_commutator_SigmaX(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return SigmaZ(self.name) + + def _eval_commutator_SigmaY(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return I * SigmaZ(self.name) + + def _eval_commutator_SigmaZ(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return -2 * self + + def _eval_commutator_SigmaMinus(self, other, **hints): + return SigmaZ(self.name) + + def _eval_anticommutator_SigmaZ(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaX(self, other, **hints): + return S.One + + def _eval_anticommutator_SigmaY(self, other, **hints): + return I + + def _eval_anticommutator_SigmaMinus(self, other, **hints): + return S.One + + def _eval_adjoint(self): + return SigmaMinus(self.name) + + def _eval_mul(self, other): + return self * other + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return S.Zero + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_+^{(%s)}}' % str(self.name) + else: + return r'{\sigma_+}' + + def _print_contents(self, printer, *args): + return 'SigmaPlus()' + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[0, 1], [0, 0]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaZKet(Ket): + """Ket for a two-level system quantum system. + + Parameters + ========== + + n : Number + The state number (0 or 1). + + """ + + def __new__(cls, n): + if n not in (0, 1): + raise ValueError("n must be 0 or 1") + return Ket.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return SigmaZBra + + @classmethod + def _eval_hilbert_space(cls, label): + return ComplexSpace(2) + + def _eval_innerproduct_SigmaZBra(self, bra, **hints): + return KroneckerDelta(self.n, bra.n) + + def _apply_from_right_to_SigmaZ(self, op, **options): + if self.n == 0: + return self + else: + return S.NegativeOne * self + + def _apply_from_right_to_SigmaX(self, op, **options): + return SigmaZKet(1) if self.n == 0 else SigmaZKet(0) + + def _apply_from_right_to_SigmaY(self, op, **options): + return I * SigmaZKet(1) if self.n == 0 else (-I) * SigmaZKet(0) + + def _apply_from_right_to_SigmaMinus(self, op, **options): + if self.n == 0: + return SigmaZKet(1) + else: + return S.Zero + + def _apply_from_right_to_SigmaPlus(self, op, **options): + if self.n == 0: + return S.Zero + else: + return SigmaZKet(0) + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[1], [0]]) if self.n == 0 else Matrix([[0], [1]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaZBra(Bra): + """Bra for a two-level quantum system. + + Parameters + ========== + + n : Number + The state number (0 or 1). + + """ + + def __new__(cls, n): + if n not in (0, 1): + raise ValueError("n must be 0 or 1") + return Bra.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return SigmaZKet + + +def _qsimplify_pauli_product(a, b): + """ + Internal helper function for simplifying products of Pauli operators. + """ + if not (isinstance(a, SigmaOpBase) and isinstance(b, SigmaOpBase)): + return Mul(a, b) + + if a.name != b.name: + # Pauli matrices with different labels commute; sort by name + if a.name < b.name: + return Mul(a, b) + else: + return Mul(b, a) + + elif isinstance(a, SigmaX): + + if isinstance(b, SigmaX): + return S.One + + if isinstance(b, SigmaY): + return I * SigmaZ(a.name) + + if isinstance(b, SigmaZ): + return - I * SigmaY(a.name) + + if isinstance(b, SigmaMinus): + return (S.Half + SigmaZ(a.name)/2) + + if isinstance(b, SigmaPlus): + return (S.Half - SigmaZ(a.name)/2) + + elif isinstance(a, SigmaY): + + if isinstance(b, SigmaX): + return - I * SigmaZ(a.name) + + if isinstance(b, SigmaY): + return S.One + + if isinstance(b, SigmaZ): + return I * SigmaX(a.name) + + if isinstance(b, SigmaMinus): + return -I * (S.One + SigmaZ(a.name))/2 + + if isinstance(b, SigmaPlus): + return I * (S.One - SigmaZ(a.name))/2 + + elif isinstance(a, SigmaZ): + + if isinstance(b, SigmaX): + return I * SigmaY(a.name) + + if isinstance(b, SigmaY): + return - I * SigmaX(a.name) + + if isinstance(b, SigmaZ): + return S.One + + if isinstance(b, SigmaMinus): + return - SigmaMinus(a.name) + + if isinstance(b, SigmaPlus): + return SigmaPlus(a.name) + + elif isinstance(a, SigmaMinus): + + if isinstance(b, SigmaX): + return (S.One - SigmaZ(a.name))/2 + + if isinstance(b, SigmaY): + return - I * (S.One - SigmaZ(a.name))/2 + + if isinstance(b, SigmaZ): + # (SigmaX(a.name) - I * SigmaY(a.name))/2 + return SigmaMinus(b.name) + + if isinstance(b, SigmaMinus): + return S.Zero + + if isinstance(b, SigmaPlus): + return S.Half - SigmaZ(a.name)/2 + + elif isinstance(a, SigmaPlus): + + if isinstance(b, SigmaX): + return (S.One + SigmaZ(a.name))/2 + + if isinstance(b, SigmaY): + return I * (S.One + SigmaZ(a.name))/2 + + if isinstance(b, SigmaZ): + #-(SigmaX(a.name) + I * SigmaY(a.name))/2 + return -SigmaPlus(a.name) + + if isinstance(b, SigmaMinus): + return (S.One + SigmaZ(a.name))/2 + + if isinstance(b, SigmaPlus): + return S.Zero + + else: + return a * b + + +def qsimplify_pauli(e): + """ + Simplify an expression that includes products of pauli operators. + + Parameters + ========== + + e : expression + An expression that contains products of Pauli operators that is + to be simplified. + + Examples + ======== + + >>> from sympy.physics.quantum.pauli import SigmaX, SigmaY + >>> from sympy.physics.quantum.pauli import qsimplify_pauli + >>> sx, sy = SigmaX(), SigmaY() + >>> sx * sy + SigmaX()*SigmaY() + >>> qsimplify_pauli(sx * sy) + I*SigmaZ() + """ + if isinstance(e, Operator): + return e + + if isinstance(e, (Add, Pow, exp)): + t = type(e) + return t(*(qsimplify_pauli(arg) for arg in e.args)) + + if isinstance(e, Mul): + + c, nc = e.args_cnc() + + nc_s = [] + while nc: + curr = nc.pop(0) + + while (len(nc) and + isinstance(curr, SigmaOpBase) and + isinstance(nc[0], SigmaOpBase) and + curr.name == nc[0].name): + + x = nc.pop(0) + y = _qsimplify_pauli_product(curr, x) + c1, nc1 = y.args_cnc() + curr = Mul(*nc1) + c = c + c1 + + nc_s.append(curr) + + return Mul(*c) * Mul(*nc_s) + + return e diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/piab.py b/phivenv/Lib/site-packages/sympy/physics/quantum/piab.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ac8135ee03e640f745070602c7dd8ca20f2767 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/piab.py @@ -0,0 +1,72 @@ +"""1D quantum particle in a box.""" + +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.sets.sets import Interval + +from sympy.physics.quantum.operator import HermitianOperator +from sympy.physics.quantum.state import Ket, Bra +from sympy.physics.quantum.constants import hbar +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.physics.quantum.hilbert import L2 + +m = Symbol('m') +L = Symbol('L') + + +__all__ = [ + 'PIABHamiltonian', + 'PIABKet', + 'PIABBra' +] + + +class PIABHamiltonian(HermitianOperator): + """Particle in a box Hamiltonian operator.""" + + @classmethod + def _eval_hilbert_space(cls, label): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _apply_operator_PIABKet(self, ket, **options): + n = ket.label[0] + return (n**2*pi**2*hbar**2)/(2*m*L**2)*ket + + +class PIABKet(Ket): + """Particle in a box eigenket.""" + + @classmethod + def _eval_hilbert_space(cls, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + @classmethod + def dual_class(self): + return PIABBra + + def _represent_default_basis(self, **options): + return self._represent_XOp(None, **options) + + def _represent_XOp(self, basis, **options): + x = Symbol('x') + n = Symbol('n') + subs_info = options.get('subs', {}) + return sqrt(2/L)*sin(n*pi*x/L).subs(subs_info) + + def _eval_innerproduct_PIABBra(self, bra): + return KroneckerDelta(bra.label[0], self.label[0]) + + +class PIABBra(Bra): + """Particle in a box eigenbra.""" + + @classmethod + def _eval_hilbert_space(cls, label): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + @classmethod + def dual_class(self): + return PIABKet diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/qapply.py b/phivenv/Lib/site-packages/sympy/physics/quantum/qapply.py new file mode 100644 index 0000000000000000000000000000000000000000..a2d8c92e51552c8114d65a1304fcd1925ae752f4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/qapply.py @@ -0,0 +1,263 @@ +"""Logic for applying operators to states. + +Todo: +* Sometimes the final result needs to be expanded, we should do this by hand. +""" + +from sympy.concrete import Sum +from sympy.core.add import Add +from sympy.core.kind import NumberKind +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.sympify import sympify, _sympify + +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.operator import OuterProduct, Operator +from sympy.physics.quantum.state import State, KetBase, BraBase, Wavefunction +from sympy.physics.quantum.tensorproduct import TensorProduct + +__all__ = [ + 'qapply' +] + + +#----------------------------------------------------------------------------- +# Main code +#----------------------------------------------------------------------------- + + +def ip_doit_func(e): + """Transform the inner products in an expression by calling ``.doit()``.""" + return e.replace(InnerProduct, lambda *args: InnerProduct(*args).doit()) + + +def sum_doit_func(e): + """Transform the sums in an expression by calling ``.doit()``.""" + return e.replace(Sum, lambda *args: Sum(*args).doit()) + + +def qapply(e, **options): + """Apply operators to states in a quantum expression. + + Parameters + ========== + + e : Expr + The expression containing operators and states. This expression tree + will be walked to find operators acting on states symbolically. + options : dict + A dict of key/value pairs that determine how the operator actions + are carried out. + + The following options are valid: + + * ``dagger``: try to apply Dagger operators to the left + (default: False). + * ``ip_doit``: call ``.doit()`` in inner products when they are + encountered (default: True). + * ``sum_doit``: call ``.doit()`` on sums when they are encountered + (default: False). This is helpful for collapsing sums over Kronecker + delta's that are created when calling ``qapply``. + + Returns + ======= + + e : Expr + The original expression, but with the operators applied to states. + + Examples + ======== + + >>> from sympy.physics.quantum import qapply, Ket, Bra + >>> b = Bra('b') + >>> k = Ket('k') + >>> A = k * b + >>> A + |k>>> qapply(A * b.dual / (b * b.dual)) + |k> + >>> qapply(k.dual * A / (k.dual * k)) + and A*(|a>+|b>) and all Commutators and + # TensorProducts. The only problem with this is that if we can't apply + # all the Operators, we have just expanded everything. + # TODO: don't expand the scalars in front of each Mul. + e = e.expand(commutator=True, tensorproduct=True) + + # If we just have a raw ket, return it. + if isinstance(e, KetBase): + return e + + # We have an Add(a, b, c, ...) and compute + # Add(qapply(a), qapply(b), ...) + elif isinstance(e, Add): + result = 0 + for arg in e.args: + result += qapply(arg, **options) + return result.expand() + + # For a Density operator call qapply on its state + elif isinstance(e, Density): + new_args = [(qapply(state, **options), prob) for (state, + prob) in e.args] + return Density(*new_args) + + # For a raw TensorProduct, call qapply on its args. + elif isinstance(e, TensorProduct): + return TensorProduct(*[qapply(t, **options) for t in e.args]) + + # For a Sum, call qapply on its function. + elif isinstance(e, Sum): + result = Sum(qapply(e.function, **options), *e.limits) + result = sum_doit_func(result) if sum_doit else result + return result + + # For a Pow, call qapply on its base. + elif isinstance(e, Pow): + return qapply(e.base, **options)**e.exp + + # We have a Mul where there might be actual operators to apply to kets. + elif isinstance(e, Mul): + c_part, nc_part = e.args_cnc() + c_mul = Mul(*c_part) + nc_mul = Mul(*nc_part) + if not nc_part: # If we only have a commuting part, just return it. + result = c_mul + elif isinstance(nc_mul, Mul): + result = c_mul*qapply_Mul(nc_mul, **options) + else: + result = c_mul*qapply(nc_mul, **options) + if result == e and dagger: + result = Dagger(qapply_Mul(Dagger(e), **options)) + result = ip_doit_func(result) if ip_doit else result + result = sum_doit_func(result) if sum_doit else result + return result + + # In all other cases (State, Operator, Pow, Commutator, InnerProduct, + # OuterProduct) we won't ever have operators to apply to kets. + else: + return e + + +def qapply_Mul(e, **options): + + args = list(e.args) + extra = S.One + result = None + + # If we only have 0 or 1 args, we have nothing to do and return. + if len(args) <= 1 or not isinstance(e, Mul): + return e + rhs = args.pop() + lhs = args.pop() + + # Make sure we have two non-commutative objects before proceeding. + if (not isinstance(rhs, Wavefunction) and sympify(rhs).is_commutative) or \ + (not isinstance(lhs, Wavefunction) and sympify(lhs).is_commutative): + return e + + # For a Pow with an integer exponent, apply one of them and reduce the + # exponent by one. + if isinstance(lhs, Pow) and lhs.exp.is_Integer: + args.append(lhs.base**(lhs.exp - 1)) + lhs = lhs.base + + # Pull OuterProduct apart + if isinstance(lhs, OuterProduct): + args.append(lhs.ket) + lhs = lhs.bra + + if isinstance(rhs, OuterProduct): + extra = rhs.bra # Append to the right of the result + rhs = rhs.ket + + # Call .doit() on Commutator/AntiCommutator. + if isinstance(lhs, (Commutator, AntiCommutator)): + comm = lhs.doit() + if isinstance(comm, Add): + return qapply( + e.func(*(args + [comm.args[0], rhs])) + + e.func(*(args + [comm.args[1], rhs])), + **options + )*extra + else: + return qapply(e.func(*args)*comm*rhs, **options)*extra + + # Apply tensor products of operators to states + if isinstance(lhs, TensorProduct) and all(isinstance(arg, (Operator, State, Mul, Pow)) or arg == 1 for arg in lhs.args) and \ + isinstance(rhs, TensorProduct) and all(isinstance(arg, (Operator, State, Mul, Pow)) or arg == 1 for arg in rhs.args) and \ + len(lhs.args) == len(rhs.args): + result = TensorProduct(*[qapply(lhs.args[n]*rhs.args[n], **options) for n in range(len(lhs.args))]).expand(tensorproduct=True) + return qapply_Mul(e.func(*args), **options)*result*extra + + # For Sums, move the Sum to the right. + if isinstance(rhs, Sum): + if isinstance(lhs, Sum): + if set(lhs.variables).intersection(set(rhs.variables)): + raise ValueError('Duplicated dummy indices in separate sums in qapply.') + limits = lhs.limits + rhs.limits + result = Sum(qapply(lhs.function*rhs.function, **options), *limits) + return qapply_Mul(e.func(*args)*result, **options) + else: + result = Sum(qapply(lhs*rhs.function, **options), *rhs.limits) + return qapply_Mul(e.func(*args)*result, **options) + + if isinstance(lhs, Sum): + result = Sum(qapply(lhs.function*rhs, **options), *lhs.limits) + return qapply_Mul(e.func(*args)*result, **options) + + # Now try to actually apply the operator and build an inner product. + _apply = getattr(lhs, '_apply_operator', None) + if _apply is not None: + try: + result = _apply(rhs, **options) + except NotImplementedError: + result = None + else: + result = None + + if result is None: + _apply_right = getattr(rhs, '_apply_from_right_to', None) + if _apply_right is not None: + try: + result = _apply_right(lhs, **options) + except NotImplementedError: + result = None + + if result is None: + if isinstance(lhs, BraBase) and isinstance(rhs, KetBase): + result = InnerProduct(lhs, rhs) + + # TODO: I may need to expand before returning the final result. + if isinstance(result, (int, complex, float)): + return _sympify(result) + elif result is None: + if len(args) == 0: + # We had two args to begin with so args=[]. + return e + else: + return qapply_Mul(e.func(*(args + [lhs])), **options)*rhs*extra + elif isinstance(result, InnerProduct): + return result*qapply_Mul(e.func(*args), **options)*extra + else: # result is a scalar times a Mul, Add or TensorProduct + return qapply(e.func(*args)*result, **options)*extra diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/qasm.py b/phivenv/Lib/site-packages/sympy/physics/quantum/qasm.py new file mode 100644 index 0000000000000000000000000000000000000000..39b49d9a67399114e7d03f12148854b2e41b0b26 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/qasm.py @@ -0,0 +1,224 @@ +""" + +qasm.py - Functions to parse a set of qasm commands into a SymPy Circuit. + +Examples taken from Chuang's page: https://web.archive.org/web/20220120121541/https://www.media.mit.edu/quanta/qasm2circ/ + +The code returns a circuit and an associated list of labels. + +>>> from sympy.physics.quantum.qasm import Qasm +>>> q = Qasm('qubit q0', 'qubit q1', 'h q0', 'cnot q0,q1') +>>> q.get_circuit() +CNOT(1,0)*H(1) + +>>> q = Qasm('qubit q0', 'qubit q1', 'cnot q0,q1', 'cnot q1,q0', 'cnot q0,q1') +>>> q.get_circuit() +CNOT(1,0)*CNOT(0,1)*CNOT(1,0) +""" + +__all__ = [ + 'Qasm', + ] + +from math import prod + +from sympy.physics.quantum.gate import H, CNOT, X, Z, CGate, CGateS, SWAP, S, T,CPHASE +from sympy.physics.quantum.circuitplot import Mz + +def read_qasm(lines): + return Qasm(*lines.splitlines()) + +def read_qasm_file(filename): + return Qasm(*open(filename).readlines()) + +def flip_index(i, n): + """Reorder qubit indices from largest to smallest. + + >>> from sympy.physics.quantum.qasm import flip_index + >>> flip_index(0, 2) + 1 + >>> flip_index(1, 2) + 0 + """ + return n-i-1 + +def trim(line): + """Remove everything following comment # characters in line. + + >>> from sympy.physics.quantum.qasm import trim + >>> trim('nothing happens here') + 'nothing happens here' + >>> trim('something #happens here') + 'something ' + """ + if '#' not in line: + return line + return line.split('#')[0] + +def get_index(target, labels): + """Get qubit labels from the rest of the line,and return indices + + >>> from sympy.physics.quantum.qasm import get_index + >>> get_index('q0', ['q0', 'q1']) + 1 + >>> get_index('q1', ['q0', 'q1']) + 0 + """ + nq = len(labels) + return flip_index(labels.index(target), nq) + +def get_indices(targets, labels): + return [get_index(t, labels) for t in targets] + +def nonblank(args): + for line in args: + line = trim(line) + if line.isspace(): + continue + yield line + return + +def fullsplit(line): + words = line.split() + rest = ' '.join(words[1:]) + return fixcommand(words[0]), [s.strip() for s in rest.split(',')] + +def fixcommand(c): + """Fix Qasm command names. + + Remove all of forbidden characters from command c, and + replace 'def' with 'qdef'. + """ + forbidden_characters = ['-'] + c = c.lower() + for char in forbidden_characters: + c = c.replace(char, '') + if c == 'def': + return 'qdef' + return c + +def stripquotes(s): + """Replace explicit quotes in a string. + + >>> from sympy.physics.quantum.qasm import stripquotes + >>> stripquotes("'S'") == 'S' + True + >>> stripquotes('"S"') == 'S' + True + >>> stripquotes('S') == 'S' + True + """ + s = s.replace('"', '') # Remove second set of quotes? + s = s.replace("'", '') + return s + +class Qasm: + """Class to form objects from Qasm lines + + >>> from sympy.physics.quantum.qasm import Qasm + >>> q = Qasm('qubit q0', 'qubit q1', 'h q0', 'cnot q0,q1') + >>> q.get_circuit() + CNOT(1,0)*H(1) + >>> q = Qasm('qubit q0', 'qubit q1', 'cnot q0,q1', 'cnot q1,q0', 'cnot q0,q1') + >>> q.get_circuit() + CNOT(1,0)*CNOT(0,1)*CNOT(1,0) + """ + def __init__(self, *args, **kwargs): + self.defs = {} + self.circuit = [] + self.labels = [] + self.inits = {} + self.add(*args) + self.kwargs = kwargs + + def add(self, *lines): + for line in nonblank(lines): + command, rest = fullsplit(line) + if self.defs.get(command): #defs come first, since you can override built-in + function = self.defs.get(command) + indices = self.indices(rest) + if len(indices) == 1: + self.circuit.append(function(indices[0])) + else: + self.circuit.append(function(indices[:-1], indices[-1])) + elif hasattr(self, command): + function = getattr(self, command) + function(*rest) + else: + print("Function %s not defined. Skipping" % command) + + def get_circuit(self): + return prod(reversed(self.circuit)) + + def get_labels(self): + return list(reversed(self.labels)) + + def plot(self): + from sympy.physics.quantum.circuitplot import CircuitPlot + circuit, labels = self.get_circuit(), self.get_labels() + CircuitPlot(circuit, len(labels), labels=labels, inits=self.inits) + + def qubit(self, arg, init=None): + self.labels.append(arg) + if init: self.inits[arg] = init + + def indices(self, args): + return get_indices(args, self.labels) + + def index(self, arg): + return get_index(arg, self.labels) + + def nop(self, *args): + pass + + def x(self, arg): + self.circuit.append(X(self.index(arg))) + + def z(self, arg): + self.circuit.append(Z(self.index(arg))) + + def h(self, arg): + self.circuit.append(H(self.index(arg))) + + def s(self, arg): + self.circuit.append(S(self.index(arg))) + + def t(self, arg): + self.circuit.append(T(self.index(arg))) + + def measure(self, arg): + self.circuit.append(Mz(self.index(arg))) + + def cnot(self, a1, a2): + self.circuit.append(CNOT(*self.indices([a1, a2]))) + + def swap(self, a1, a2): + self.circuit.append(SWAP(*self.indices([a1, a2]))) + + def cphase(self, a1, a2): + self.circuit.append(CPHASE(*self.indices([a1, a2]))) + + def toffoli(self, a1, a2, a3): + i1, i2, i3 = self.indices([a1, a2, a3]) + self.circuit.append(CGateS((i1, i2), X(i3))) + + def cx(self, a1, a2): + fi, fj = self.indices([a1, a2]) + self.circuit.append(CGate(fi, X(fj))) + + def cz(self, a1, a2): + fi, fj = self.indices([a1, a2]) + self.circuit.append(CGate(fi, Z(fj))) + + def defbox(self, *args): + print("defbox not supported yet. Skipping: ", args) + + def qdef(self, name, ncontrols, symbol): + from sympy.physics.quantum.circuitplot import CreateOneQubitGate, CreateCGate + ncontrols = int(ncontrols) + command = fixcommand(name) + symbol = stripquotes(symbol) + if ncontrols > 0: + self.defs[command] = CreateCGate(symbol) + else: + self.defs[command] = CreateOneQubitGate(symbol) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/qexpr.py b/phivenv/Lib/site-packages/sympy/physics/quantum/qexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..64f7e2a200fa7d89b35db1da551bcbd25492f2d9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/qexpr.py @@ -0,0 +1,409 @@ +from sympy.core.expr import Expr +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.matrices.dense import Matrix +from sympy.printing.pretty.stringpict import prettyForm +from sympy.core.containers import Tuple +from sympy.utilities.iterables import is_sequence + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.matrixutils import ( + numpy_ndarray, scipy_sparse_matrix, + to_sympy, to_numpy, to_scipy_sparse +) + +__all__ = [ + 'QuantumError', + 'QExpr' +] + + +#----------------------------------------------------------------------------- +# Error handling +#----------------------------------------------------------------------------- + +class QuantumError(Exception): + pass + + +def _qsympify_sequence(seq): + """Convert elements of a sequence to standard form. + + This is like sympify, but it performs special logic for arguments passed + to QExpr. The following conversions are done: + + * (list, tuple, Tuple) => _qsympify_sequence each element and convert + sequence to a Tuple. + * basestring => Symbol + * Matrix => Matrix + * other => sympify + + Strings are passed to Symbol, not sympify to make sure that variables like + 'pi' are kept as Symbols, not the SymPy built-in number subclasses. + + Examples + ======== + + >>> from sympy.physics.quantum.qexpr import _qsympify_sequence + >>> _qsympify_sequence((1,2,[3,4,[1,]])) + (1, 2, (3, 4, (1,))) + + """ + + return tuple(__qsympify_sequence_helper(seq)) + + +def __qsympify_sequence_helper(seq): + """ + Helper function for _qsympify_sequence + This function does the actual work. + """ + #base case. If not a list, do Sympification + if not is_sequence(seq): + if isinstance(seq, Matrix): + return seq + elif isinstance(seq, str): + return Symbol(seq) + else: + return sympify(seq) + + # base condition, when seq is QExpr and also + # is iterable. + if isinstance(seq, QExpr): + return seq + + #if list, recurse on each item in the list + result = [__qsympify_sequence_helper(item) for item in seq] + + return Tuple(*result) + + +#----------------------------------------------------------------------------- +# Basic Quantum Expression from which all objects descend +#----------------------------------------------------------------------------- + +class QExpr(Expr): + """A base class for all quantum object like operators and states.""" + + # In sympy, slots are for instance attributes that are computed + # dynamically by the __new__ method. They are not part of args, but they + # derive from args. + + # The Hilbert space a quantum Object belongs to. + __slots__ = ('hilbert_space', ) + + is_commutative = False + + # The separator used in printing the label. + _label_separator = '' + + def __new__(cls, *args, **kwargs): + """Construct a new quantum object. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + quantum object. For a state, this will be its symbol or its + set of quantum numbers. + + Examples + ======== + + >>> from sympy.physics.quantum.qexpr import QExpr + >>> q = QExpr(0) + >>> q + 0 + >>> q.label + (0,) + >>> q.hilbert_space + H + >>> q.args + (0,) + >>> q.is_commutative + False + """ + + # First compute args and call Expr.__new__ to create the instance + args = cls._eval_args(args, **kwargs) + if len(args) == 0: + args = cls._eval_args(tuple(cls.default_args()), **kwargs) + inst = Expr.__new__(cls, *args) + # Now set the slots on the instance + inst.hilbert_space = cls._eval_hilbert_space(args) + return inst + + @classmethod + def _new_rawargs(cls, hilbert_space, *args, **old_assumptions): + """Create new instance of this class with hilbert_space and args. + + This is used to bypass the more complex logic in the ``__new__`` + method in cases where you already have the exact ``hilbert_space`` + and ``args``. This should be used when you are positive these + arguments are valid, in their final, proper form and want to optimize + the creation of the object. + """ + + obj = Expr.__new__(cls, *args, **old_assumptions) + obj.hilbert_space = hilbert_space + return obj + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def label(self): + """The label is the unique set of identifiers for the object. + + Usually, this will include all of the information about the state + *except* the time (in the case of time-dependent objects). + + This must be a tuple, rather than a Tuple. + """ + if len(self.args) == 0: # If there is no label specified, return the default + return self._eval_args(list(self.default_args())) + else: + return self.args + + @property + def is_symbolic(self): + return True + + @classmethod + def default_args(self): + """If no arguments are specified, then this will return a default set + of arguments to be run through the constructor. + + NOTE: Any classes that override this MUST return a tuple of arguments. + Should be overridden by subclasses to specify the default arguments for kets and operators + """ + raise NotImplementedError("No default arguments for this class!") + + #------------------------------------------------------------------------- + # _eval_* methods + #------------------------------------------------------------------------- + + def _eval_adjoint(self): + obj = Expr._eval_adjoint(self) + if obj is None: + obj = Expr.__new__(Dagger, self) + if isinstance(obj, QExpr): + obj.hilbert_space = self.hilbert_space + return obj + + @classmethod + def _eval_args(cls, args): + """Process the args passed to the __new__ method. + + This simply runs args through _qsympify_sequence. + """ + return _qsympify_sequence(args) + + @classmethod + def _eval_hilbert_space(cls, args): + """Compute the Hilbert space instance from the args. + """ + from sympy.physics.quantum.hilbert import HilbertSpace + return HilbertSpace() + + #------------------------------------------------------------------------- + # Printing + #------------------------------------------------------------------------- + + # Utilities for printing: these operate on raw SymPy objects + + def _print_sequence(self, seq, sep, printer, *args): + result = [] + for item in seq: + result.append(printer._print(item, *args)) + return sep.join(result) + + def _print_sequence_pretty(self, seq, sep, printer, *args): + pform = printer._print(seq[0], *args) + for item in seq[1:]: + pform = prettyForm(*pform.right(sep)) + pform = prettyForm(*pform.right(printer._print(item, *args))) + return pform + + # Utilities for printing: these operate prettyForm objects + + def _print_subscript_pretty(self, a, b): + top = prettyForm(*b.left(' '*a.width())) + bot = prettyForm(*a.right(' '*b.width())) + return prettyForm(binding=prettyForm.POW, *bot.below(top)) + + def _print_superscript_pretty(self, a, b): + return a**b + + def _print_parens_pretty(self, pform, left='(', right=')'): + return prettyForm(*pform.parens(left=left, right=right)) + + # Printing of labels (i.e. args) + + def _print_label(self, printer, *args): + """Prints the label of the QExpr + + This method prints self.label, using self._label_separator to separate + the elements. This method should not be overridden, instead, override + _print_contents to change printing behavior. + """ + return self._print_sequence( + self.label, self._label_separator, printer, *args + ) + + def _print_label_repr(self, printer, *args): + return self._print_sequence( + self.label, ',', printer, *args + ) + + def _print_label_pretty(self, printer, *args): + return self._print_sequence_pretty( + self.label, self._label_separator, printer, *args + ) + + def _print_label_latex(self, printer, *args): + return self._print_sequence( + self.label, self._label_separator, printer, *args + ) + + # Printing of contents (default to label) + + def _print_contents(self, printer, *args): + """Printer for contents of QExpr + + Handles the printing of any unique identifying contents of a QExpr to + print as its contents, such as any variables or quantum numbers. The + default is to print the label, which is almost always the args. This + should not include printing of any brackets or parentheses. + """ + return self._print_label(printer, *args) + + def _print_contents_pretty(self, printer, *args): + return self._print_label_pretty(printer, *args) + + def _print_contents_latex(self, printer, *args): + return self._print_label_latex(printer, *args) + + # Main printing methods + + def _sympystr(self, printer, *args): + """Default printing behavior of QExpr objects + + Handles the default printing of a QExpr. To add other things to the + printing of the object, such as an operator name to operators or + brackets to states, the class should override the _print/_pretty/_latex + functions directly and make calls to _print_contents where appropriate. + This allows things like InnerProduct to easily control its printing the + printing of contents. + """ + return self._print_contents(printer, *args) + + def _sympyrepr(self, printer, *args): + classname = self.__class__.__name__ + label = self._print_label_repr(printer, *args) + return '%s(%s)' % (classname, label) + + def _pretty(self, printer, *args): + pform = self._print_contents_pretty(printer, *args) + return pform + + def _latex(self, printer, *args): + return self._print_contents_latex(printer, *args) + + #------------------------------------------------------------------------- + # Represent + #------------------------------------------------------------------------- + + def _represent_default_basis(self, **options): + raise NotImplementedError('This object does not have a default basis') + + def _represent(self, *, basis=None, **options): + """Represent this object in a given basis. + + This method dispatches to the actual methods that perform the + representation. Subclases of QExpr should define various methods to + determine how the object will be represented in various bases. The + format of these methods is:: + + def _represent_BasisName(self, basis, **options): + + Thus to define how a quantum object is represented in the basis of + the operator Position, you would define:: + + def _represent_Position(self, basis, **options): + + Usually, basis object will be instances of Operator subclasses, but + there is a chance we will relax this in the future to accommodate other + types of basis sets that are not associated with an operator. + + If the ``format`` option is given it can be ("sympy", "numpy", + "scipy.sparse"). This will ensure that any matrices that result from + representing the object are returned in the appropriate matrix format. + + Parameters + ========== + + basis : Operator + The Operator whose basis functions will be used as the basis for + representation. + options : dict + A dictionary of key/value pairs that give options and hints for + the representation, such as the number of basis functions to + be used. + """ + if basis is None: + result = self._represent_default_basis(**options) + else: + result = dispatch_method(self, '_represent', basis, **options) + + # If we get a matrix representation, convert it to the right format. + format = options.get('format', 'sympy') + result = self._format_represent(result, format) + return result + + def _format_represent(self, result, format): + if format == 'sympy' and not isinstance(result, Matrix): + return to_sympy(result) + elif format == 'numpy' and not isinstance(result, numpy_ndarray): + return to_numpy(result) + elif format == 'scipy.sparse' and \ + not isinstance(result, scipy_sparse_matrix): + return to_scipy_sparse(result) + + return result + + +def split_commutative_parts(e): + """Split into commutative and non-commutative parts.""" + c_part, nc_part = e.args_cnc() + c_part = list(c_part) + return c_part, nc_part + + +def split_qexpr_parts(e): + """Split an expression into Expr and noncommutative QExpr parts.""" + expr_part = [] + qexpr_part = [] + for arg in e.args: + if not isinstance(arg, QExpr): + expr_part.append(arg) + else: + qexpr_part.append(arg) + return expr_part, qexpr_part + + +def dispatch_method(self, basename, arg, **options): + """Dispatch a method to the proper handlers.""" + method_name = '%s_%s' % (basename, arg.__class__.__name__) + if hasattr(self, method_name): + f = getattr(self, method_name) + # This can raise and we will allow it to propagate. + result = f(arg, **options) + if result is not None: + return result + raise NotImplementedError( + "%s.%s cannot handle: %r" % + (self.__class__.__name__, basename, arg) + ) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/qft.py b/phivenv/Lib/site-packages/sympy/physics/quantum/qft.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a3fa4539267f7bb6cf015521007e292b3d4cfd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/qft.py @@ -0,0 +1,215 @@ +"""An implementation of qubits and gates acting on them. + +Todo: + +* Update docstrings. +* Update tests. +* Implement apply using decompose. +* Implement represent using decompose or something smarter. For this to + work we first have to implement represent for SWAP. +* Decide if we want upper index to be inclusive in the constructor. +* Fix the printing of Rk gates in plotting. +""" + +from sympy.core.expr import Expr +from sympy.core.numbers import (I, Integer, pi) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.matrices.dense import Matrix +from sympy.functions import sqrt + +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qexpr import QuantumError, QExpr +from sympy.matrices import eye +from sympy.physics.quantum.tensorproduct import matrix_tensor_product + +from sympy.physics.quantum.gate import ( + Gate, HadamardGate, SwapGate, OneQubitGate, CGate, PhaseGate, TGate, ZGate +) + +from sympy.functions.elementary.complexes import sign + +__all__ = [ + 'QFT', + 'IQFT', + 'RkGate', + 'Rk' +] + +#----------------------------------------------------------------------------- +# Fourier stuff +#----------------------------------------------------------------------------- + + +class RkGate(OneQubitGate): + """This is the R_k gate of the QTF.""" + gate_name = 'Rk' + gate_name_latex = 'R' + + def __new__(cls, *args): + if len(args) != 2: + raise QuantumError( + 'Rk gates only take two arguments, got: %r' % args + ) + # For small k, Rk gates simplify to other gates, using these + # substitutions give us familiar results for the QFT for small numbers + # of qubits. + target = args[0] + k = args[1] + if k == 1: + return ZGate(target) + elif k == 2: + return PhaseGate(target) + elif k == 3: + return TGate(target) + args = cls._eval_args(args) + inst = Expr.__new__(cls, *args) + inst.hilbert_space = cls._eval_hilbert_space(args) + return inst + + @classmethod + def _eval_args(cls, args): + # Fall back to this, because Gate._eval_args assumes that args is + # all targets and can't contain duplicates. + return QExpr._eval_args(args) + + @property + def k(self): + return self.label[1] + + @property + def targets(self): + return self.label[:1] + + @property + def gate_name_plot(self): + return r'$%s_%s$' % (self.gate_name_latex, str(self.k)) + + def get_target_matrix(self, format='sympy'): + if format == 'sympy': + return Matrix([[1, 0], [0, exp(sign(self.k)*Integer(2)*pi*I/(Integer(2)**abs(self.k)))]]) + raise NotImplementedError( + 'Invalid format for the R_k gate: %r' % format) + + +Rk = RkGate + + +class Fourier(Gate): + """Superclass of Quantum Fourier and Inverse Quantum Fourier Gates.""" + + @classmethod + def _eval_args(self, args): + if len(args) != 2: + raise QuantumError( + 'QFT/IQFT only takes two arguments, got: %r' % args + ) + if args[0] >= args[1]: + raise QuantumError("Start must be smaller than finish") + return Gate._eval_args(args) + + def _represent_default_basis(self, **options): + return self._represent_ZGate(None, **options) + + def _represent_ZGate(self, basis, **options): + """ + Represents the (I)QFT In the Z Basis + """ + nqubits = options.get('nqubits', 0) + if nqubits == 0: + raise QuantumError( + 'The number of qubits must be given as nqubits.') + if nqubits < self.min_qubits: + raise QuantumError( + 'The number of qubits %r is too small for the gate.' % nqubits + ) + size = self.size + omega = self.omega + + #Make a matrix that has the basic Fourier Transform Matrix + arrayFT = [[omega**( + i*j % size)/sqrt(size) for i in range(size)] for j in range(size)] + matrixFT = Matrix(arrayFT) + + #Embed the FT Matrix in a higher space, if necessary + if self.label[0] != 0: + matrixFT = matrix_tensor_product(eye(2**self.label[0]), matrixFT) + if self.min_qubits < nqubits: + matrixFT = matrix_tensor_product( + matrixFT, eye(2**(nqubits - self.min_qubits))) + + return matrixFT + + @property + def targets(self): + return range(self.label[0], self.label[1]) + + @property + def min_qubits(self): + return self.label[1] + + @property + def size(self): + """Size is the size of the QFT matrix""" + return 2**(self.label[1] - self.label[0]) + + @property + def omega(self): + return Symbol('omega') + + +class QFT(Fourier): + """The forward quantum Fourier transform.""" + + gate_name = 'QFT' + gate_name_latex = 'QFT' + + def decompose(self): + """Decomposes QFT into elementary gates.""" + start = self.label[0] + finish = self.label[1] + circuit = 1 + for level in reversed(range(start, finish)): + circuit = HadamardGate(level)*circuit + for i in range(level - start): + circuit = CGate(level - i - 1, RkGate(level, i + 2))*circuit + for i in range((finish - start)//2): + circuit = SwapGate(i + start, finish - i - 1)*circuit + return circuit + + def _apply_operator_Qubit(self, qubits, **options): + return qapply(self.decompose()*qubits) + + def _eval_inverse(self): + return IQFT(*self.args) + + @property + def omega(self): + return exp(2*pi*I/self.size) + + +class IQFT(Fourier): + """The inverse quantum Fourier transform.""" + + gate_name = 'IQFT' + gate_name_latex = '{QFT^{-1}}' + + def decompose(self): + """Decomposes IQFT into elementary gates.""" + start = self.args[0] + finish = self.args[1] + circuit = 1 + for i in range((finish - start)//2): + circuit = SwapGate(i + start, finish - i - 1)*circuit + for level in range(start, finish): + for i in reversed(range(level - start)): + circuit = CGate(level - i - 1, RkGate(level, -i - 2))*circuit + circuit = HadamardGate(level)*circuit + return circuit + + def _eval_inverse(self): + return QFT(*self.args) + + @property + def omega(self): + return exp(-2*pi*I/self.size) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/qubit.py b/phivenv/Lib/site-packages/sympy/physics/quantum/qubit.py new file mode 100644 index 0000000000000000000000000000000000000000..71d1dbc01e3a16e2a4b64eec3c3800b7218b2636 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/qubit.py @@ -0,0 +1,811 @@ +"""Qubits for quantum computing. + +Todo: +* Finish implementing measurement logic. This should include POVM. +* Update docstrings. +* Update tests. +""" + + +import math + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import log +from sympy.core.basic import _sympify +from sympy.external.gmpy import SYMPY_INTS +from sympy.matrices import Matrix, zeros +from sympy.printing.pretty.stringpict import prettyForm + +from sympy.physics.quantum.hilbert import ComplexSpace +from sympy.physics.quantum.state import Ket, Bra, State + +from sympy.physics.quantum.qexpr import QuantumError +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.matrixutils import ( + numpy_ndarray, scipy_sparse_matrix +) +from mpmath.libmp.libintmath import bitcount + +__all__ = [ + 'Qubit', + 'QubitBra', + 'IntQubit', + 'IntQubitBra', + 'qubit_to_matrix', + 'matrix_to_qubit', + 'matrix_to_density', + 'measure_all', + 'measure_partial', + 'measure_partial_oneshot', + 'measure_all_oneshot' +] + +#----------------------------------------------------------------------------- +# Qubit Classes +#----------------------------------------------------------------------------- + + +class QubitState(State): + """Base class for Qubit and QubitBra.""" + + #------------------------------------------------------------------------- + # Initialization/creation + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + # If we are passed a QubitState or subclass, we just take its qubit + # values directly. + if len(args) == 1 and isinstance(args[0], QubitState): + return args[0].qubit_values + + # Turn strings into tuple of strings + if len(args) == 1 and isinstance(args[0], str): + args = tuple( S.Zero if qb == "0" else S.One for qb in args[0]) + else: + args = tuple( S.Zero if qb == "0" else S.One if qb == "1" else qb for qb in args) + args = tuple(_sympify(arg) for arg in args) + + # Validate input (must have 0 or 1 input) + for element in args: + if element not in (S.Zero, S.One): + raise ValueError( + "Qubit values must be 0 or 1, got: %r" % element) + return args + + @classmethod + def _eval_hilbert_space(cls, args): + return ComplexSpace(2)**len(args) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def dimension(self): + """The number of Qubits in the state.""" + return len(self.qubit_values) + + @property + def nqubits(self): + return self.dimension + + @property + def qubit_values(self): + """Returns the values of the qubits as a tuple.""" + return self.label + + #------------------------------------------------------------------------- + # Special methods + #------------------------------------------------------------------------- + + def __len__(self): + return self.dimension + + def __getitem__(self, bit): + return self.qubit_values[int(self.dimension - bit - 1)] + + #------------------------------------------------------------------------- + # Utility methods + #------------------------------------------------------------------------- + + def flip(self, *bits): + """Flip the bit(s) given.""" + newargs = list(self.qubit_values) + for i in bits: + bit = int(self.dimension - i - 1) + if newargs[bit] == 1: + newargs[bit] = 0 + else: + newargs[bit] = 1 + return self.__class__(*tuple(newargs)) + + +class Qubit(QubitState, Ket): + """A multi-qubit ket in the computational (z) basis. + + We use the normal convention that the least significant qubit is on the + right, so ``|00001>`` has a 1 in the least significant qubit. + + Parameters + ========== + + values : list, str + The qubit values as a list of ints ([0,0,0,1,1,]) or a string ('011'). + + Examples + ======== + + Create a qubit in a couple of different ways and look at their attributes: + + >>> from sympy.physics.quantum.qubit import Qubit + >>> Qubit(0,0,0) + |000> + >>> q = Qubit('0101') + >>> q + |0101> + + >>> q.nqubits + 4 + >>> len(q) + 4 + >>> q.dimension + 4 + >>> q.qubit_values + (0, 1, 0, 1) + + We can flip the value of an individual qubit: + + >>> q.flip(1) + |0111> + + We can take the dagger of a Qubit to get a bra: + + >>> from sympy.physics.quantum.dagger import Dagger + >>> Dagger(q) + <0101| + >>> type(Dagger(q)) + + + Inner products work as expected: + + >>> ip = Dagger(q)*q + >>> ip + <0101|0101> + >>> ip.doit() + 1 + """ + + @classmethod + def dual_class(self): + return QubitBra + + def _eval_innerproduct_QubitBra(self, bra, **hints): + if self.label == bra.label: + return S.One + else: + return S.Zero + + def _represent_default_basis(self, **options): + return self._represent_ZGate(None, **options) + + def _represent_ZGate(self, basis, **options): + """Represent this qubits in the computational basis (ZGate). + """ + _format = options.get('format', 'sympy') + n = 1 + definite_state = 0 + for it in reversed(self.qubit_values): + definite_state += n*it + n = n*2 + result = [0]*(2**self.dimension) + result[int(definite_state)] = 1 + if _format == 'sympy': + return Matrix(result) + elif _format == 'numpy': + import numpy as np + return np.array(result, dtype='complex').transpose() + elif _format == 'scipy.sparse': + from scipy import sparse + return sparse.csr_matrix(result, dtype='complex').transpose() + + def _eval_trace(self, bra, **kwargs): + indices = kwargs.get('indices', []) + + #sort index list to begin trace from most-significant + #qubit + sorted_idx = list(indices) + if len(sorted_idx) == 0: + sorted_idx = list(range(0, self.nqubits)) + sorted_idx.sort() + + #trace out for each of index + new_mat = self*bra + for i in range(len(sorted_idx) - 1, -1, -1): + # start from tracing out from leftmost qubit + new_mat = self._reduced_density(new_mat, int(sorted_idx[i])) + + if (len(sorted_idx) == self.nqubits): + #in case full trace was requested + return new_mat[0] + else: + return matrix_to_density(new_mat) + + def _reduced_density(self, matrix, qubit, **options): + """Compute the reduced density matrix by tracing out one qubit. + The qubit argument should be of type Python int, since it is used + in bit operations + """ + def find_index_that_is_projected(j, k, qubit): + bit_mask = 2**qubit - 1 + return ((j >> qubit) << (1 + qubit)) + (j & bit_mask) + (k << qubit) + + old_matrix = represent(matrix, **options) + old_size = old_matrix.cols + #we expect the old_size to be even + new_size = old_size//2 + new_matrix = Matrix().zeros(new_size) + + for i in range(new_size): + for j in range(new_size): + for k in range(2): + col = find_index_that_is_projected(j, k, qubit) + row = find_index_that_is_projected(i, k, qubit) + new_matrix[i, j] += old_matrix[row, col] + + return new_matrix + + +class QubitBra(QubitState, Bra): + """A multi-qubit bra in the computational (z) basis. + + We use the normal convention that the least significant qubit is on the + right, so ``|00001>`` has a 1 in the least significant qubit. + + Parameters + ========== + + values : list, str + The qubit values as a list of ints ([0,0,0,1,1,]) or a string ('011'). + + See also + ======== + + Qubit: Examples using qubits + + """ + @classmethod + def dual_class(self): + return Qubit + + +class IntQubitState(QubitState): + """A base class for qubits that work with binary representations.""" + + @classmethod + def _eval_args(cls, args, nqubits=None): + # The case of a QubitState instance + if len(args) == 1 and isinstance(args[0], QubitState): + return QubitState._eval_args(args) + # otherwise, args should be integer + elif not all(isinstance(a, (int, Integer)) for a in args): + raise ValueError('values must be integers, got (%s)' % (tuple(type(a) for a in args),)) + # use nqubits if specified + if nqubits is not None: + if not isinstance(nqubits, (int, Integer)): + raise ValueError('nqubits must be an integer, got (%s)' % type(nqubits)) + if len(args) != 1: + raise ValueError( + 'too many positional arguments (%s). should be (number, nqubits=n)' % (args,)) + return cls._eval_args_with_nqubits(args[0], nqubits) + # For a single argument, we construct the binary representation of + # that integer with the minimal number of bits. + if len(args) == 1 and args[0] > 1: + #rvalues is the minimum number of bits needed to express the number + rvalues = reversed(range(bitcount(abs(args[0])))) + qubit_values = [(args[0] >> i) & 1 for i in rvalues] + return QubitState._eval_args(qubit_values) + # For two numbers, the second number is the number of bits + # on which it is expressed, so IntQubit(0,5) == |00000>. + elif len(args) == 2 and args[1] > 1: + return cls._eval_args_with_nqubits(args[0], args[1]) + else: + return QubitState._eval_args(args) + + @classmethod + def _eval_args_with_nqubits(cls, number, nqubits): + need = bitcount(abs(number)) + if nqubits < need: + raise ValueError( + 'cannot represent %s with %s bits' % (number, nqubits)) + qubit_values = [(number >> i) & 1 for i in reversed(range(nqubits))] + return QubitState._eval_args(qubit_values) + + def as_int(self): + """Return the numerical value of the qubit.""" + number = 0 + n = 1 + for i in reversed(self.qubit_values): + number += n*i + n = n << 1 + return number + + def _print_label(self, printer, *args): + return str(self.as_int()) + + def _print_label_pretty(self, printer, *args): + label = self._print_label(printer, *args) + return prettyForm(label) + + _print_label_repr = _print_label + _print_label_latex = _print_label + + +class IntQubit(IntQubitState, Qubit): + """A qubit ket that store integers as binary numbers in qubit values. + + The differences between this class and ``Qubit`` are: + + * The form of the constructor. + * The qubit values are printed as their corresponding integer, rather + than the raw qubit values. The internal storage format of the qubit + values in the same as ``Qubit``. + + Parameters + ========== + + values : int, tuple + If a single argument, the integer we want to represent in the qubit + values. This integer will be represented using the fewest possible + number of qubits. + If a pair of integers and the second value is more than one, the first + integer gives the integer to represent in binary form and the second + integer gives the number of qubits to use. + List of zeros and ones is also accepted to generate qubit by bit pattern. + + nqubits : int + The integer that represents the number of qubits. + This number should be passed with keyword ``nqubits=N``. + You can use this in order to avoid ambiguity of Qubit-style tuple of bits. + Please see the example below for more details. + + Examples + ======== + + Create a qubit for the integer 5: + + >>> from sympy.physics.quantum.qubit import IntQubit + >>> from sympy.physics.quantum.qubit import Qubit + >>> q = IntQubit(5) + >>> q + |5> + + We can also create an ``IntQubit`` by passing a ``Qubit`` instance. + + >>> q = IntQubit(Qubit('101')) + >>> q + |5> + >>> q.as_int() + 5 + >>> q.nqubits + 3 + >>> q.qubit_values + (1, 0, 1) + + We can go back to the regular qubit form. + + >>> Qubit(q) + |101> + + Please note that ``IntQubit`` also accepts a ``Qubit``-style list of bits. + So, the code below yields qubits 3, not a single bit ``1``. + + >>> IntQubit(1, 1) + |3> + + To avoid ambiguity, use ``nqubits`` parameter. + Use of this keyword is recommended especially when you provide the values by variables. + + >>> IntQubit(1, nqubits=1) + |1> + >>> a = 1 + >>> IntQubit(a, nqubits=1) + |1> + """ + @classmethod + def dual_class(self): + return IntQubitBra + + def _eval_innerproduct_IntQubitBra(self, bra, **hints): + return Qubit._eval_innerproduct_QubitBra(self, bra) + +class IntQubitBra(IntQubitState, QubitBra): + """A qubit bra that store integers as binary numbers in qubit values.""" + + @classmethod + def dual_class(self): + return IntQubit + + +#----------------------------------------------------------------------------- +# Qubit <---> Matrix conversion functions +#----------------------------------------------------------------------------- + + +def matrix_to_qubit(matrix): + """Convert from the matrix repr. to a sum of Qubit objects. + + Parameters + ---------- + matrix : Matrix, numpy.matrix, scipy.sparse + The matrix to build the Qubit representation of. This works with + SymPy matrices, numpy matrices and scipy.sparse sparse matrices. + + Examples + ======== + + Represent a state and then go back to its qubit form: + + >>> from sympy.physics.quantum.qubit import matrix_to_qubit, Qubit + >>> from sympy.physics.quantum.represent import represent + >>> q = Qubit('01') + >>> matrix_to_qubit(represent(q)) + |01> + """ + # Determine the format based on the type of the input matrix + format = 'sympy' + if isinstance(matrix, numpy_ndarray): + format = 'numpy' + if isinstance(matrix, scipy_sparse_matrix): + format = 'scipy.sparse' + + # Make sure it is of correct dimensions for a Qubit-matrix representation. + # This logic should work with sympy, numpy or scipy.sparse matrices. + if matrix.shape[0] == 1: + mlistlen = matrix.shape[1] + nqubits = log(mlistlen, 2) + ket = False + cls = QubitBra + elif matrix.shape[1] == 1: + mlistlen = matrix.shape[0] + nqubits = log(mlistlen, 2) + ket = True + cls = Qubit + else: + raise QuantumError( + 'Matrix must be a row/column vector, got %r' % matrix + ) + if not isinstance(nqubits, Integer): + raise QuantumError('Matrix must be a row/column vector of size ' + '2**nqubits, got: %r' % matrix) + # Go through each item in matrix, if element is non-zero, make it into a + # Qubit item times the element. + result = 0 + for i in range(mlistlen): + if ket: + element = matrix[i, 0] + else: + element = matrix[0, i] + if format in ('numpy', 'scipy.sparse'): + element = complex(element) + if element: + # Form Qubit array; 0 in bit-locations where i is 0, 1 in + # bit-locations where i is 1 + qubit_array = [int(i & (1 << x) != 0) for x in range(nqubits)] + qubit_array.reverse() + result = result + element*cls(*qubit_array) + + # If SymPy simplified by pulling out a constant coefficient, undo that. + if isinstance(result, (Mul, Add, Pow)): + result = result.expand() + + return result + + +def matrix_to_density(mat): + """ + Works by finding the eigenvectors and eigenvalues of the matrix. + We know we can decompose rho by doing: + sum(EigenVal*|Eigenvect>>> from sympy.physics.quantum.qubit import Qubit, measure_all + >>> from sympy.physics.quantum.gate import H + >>> from sympy.physics.quantum.qapply import qapply + + >>> c = H(0)*H(1)*Qubit('00') + >>> c + H(0)*H(1)*|00> + >>> q = qapply(c) + >>> measure_all(q) + [(|00>, 1/4), (|01>, 1/4), (|10>, 1/4), (|11>, 1/4)] + """ + m = qubit_to_matrix(qubit, format) + + if format == 'sympy': + results = [] + + if normalize: + m = m.normalized() + + size = max(m.shape) # Max of shape to account for bra or ket + nqubits = int(math.log(size)/math.log(2)) + for i in range(size): + if m[i]: + results.append( + (Qubit(IntQubit(i, nqubits=nqubits)), m[i]*conjugate(m[i])) + ) + return results + else: + raise NotImplementedError( + "This function cannot handle non-SymPy matrix formats yet" + ) + + +def measure_partial(qubit, bits, format='sympy', normalize=True): + """Perform a partial ensemble measure on the specified qubits. + + Parameters + ========== + + qubits : Qubit + The qubit to measure. This can be any Qubit or a linear combination + of them. + bits : tuple + The qubits to measure. + format : str + The format of the intermediate matrices to use. Possible values are + ('sympy','numpy','scipy.sparse'). Currently only 'sympy' is + implemented. + + Returns + ======= + + result : list + A list that consists of primitive states and their probabilities. + + Examples + ======== + + >>> from sympy.physics.quantum.qubit import Qubit, measure_partial + >>> from sympy.physics.quantum.gate import H + >>> from sympy.physics.quantum.qapply import qapply + + >>> c = H(0)*H(1)*Qubit('00') + >>> c + H(0)*H(1)*|00> + >>> q = qapply(c) + >>> measure_partial(q, (0,)) + [(sqrt(2)*|00>/2 + sqrt(2)*|10>/2, 1/2), (sqrt(2)*|01>/2 + sqrt(2)*|11>/2, 1/2)] + """ + m = qubit_to_matrix(qubit, format) + + if isinstance(bits, (SYMPY_INTS, Integer)): + bits = (int(bits),) + + if format == 'sympy': + if normalize: + m = m.normalized() + + possible_outcomes = _get_possible_outcomes(m, bits) + + # Form output from function. + output = [] + for outcome in possible_outcomes: + # Calculate probability of finding the specified bits with + # given values. + prob_of_outcome = 0 + prob_of_outcome += (outcome.H*outcome)[0] + + # If the output has a chance, append it to output with found + # probability. + if prob_of_outcome != 0: + if normalize: + next_matrix = matrix_to_qubit(outcome.normalized()) + else: + next_matrix = matrix_to_qubit(outcome) + + output.append(( + next_matrix, + prob_of_outcome + )) + + return output + else: + raise NotImplementedError( + "This function cannot handle non-SymPy matrix formats yet" + ) + + +def measure_partial_oneshot(qubit, bits, format='sympy'): + """Perform a partial oneshot measurement on the specified qubits. + + A oneshot measurement is equivalent to performing a measurement on a + quantum system. This type of measurement does not return the probabilities + like an ensemble measurement does, but rather returns *one* of the + possible resulting states. The exact state that is returned is determined + by picking a state randomly according to the ensemble probabilities. + + Parameters + ---------- + qubits : Qubit + The qubit to measure. This can be any Qubit or a linear combination + of them. + bits : tuple + The qubits to measure. + format : str + The format of the intermediate matrices to use. Possible values are + ('sympy','numpy','scipy.sparse'). Currently only 'sympy' is + implemented. + + Returns + ------- + result : Qubit + The qubit that the system collapsed to upon measurement. + """ + import random + m = qubit_to_matrix(qubit, format) + + if format == 'sympy': + m = m.normalized() + possible_outcomes = _get_possible_outcomes(m, bits) + + # Form output from function + random_number = random.random() + total_prob = 0 + for outcome in possible_outcomes: + # Calculate probability of finding the specified bits + # with given values + total_prob += (outcome.H*outcome)[0] + if total_prob >= random_number: + return matrix_to_qubit(outcome.normalized()) + else: + raise NotImplementedError( + "This function cannot handle non-SymPy matrix formats yet" + ) + + +def _get_possible_outcomes(m, bits): + """Get the possible states that can be produced in a measurement. + + Parameters + ---------- + m : Matrix + The matrix representing the state of the system. + bits : tuple, list + Which bits will be measured. + + Returns + ------- + result : list + The list of possible states which can occur given this measurement. + These are un-normalized so we can derive the probability of finding + this state by taking the inner product with itself + """ + + # This is filled with loads of dirty binary tricks...You have been warned + + size = max(m.shape) # Max of shape to account for bra or ket + nqubits = int(math.log2(size) + .1) # Number of qubits possible + + # Make the output states and put in output_matrices, nothing in them now. + # Each state will represent a possible outcome of the measurement + # Thus, output_matrices[0] is the matrix which we get when all measured + # bits return 0. and output_matrices[1] is the matrix for only the 0th + # bit being true + output_matrices = [] + for i in range(1 << len(bits)): + output_matrices.append(zeros(2**nqubits, 1)) + + # Bitmasks will help sort how to determine possible outcomes. + # When the bit mask is and-ed with a matrix-index, + # it will determine which state that index belongs to + bit_masks = [] + for bit in bits: + bit_masks.append(1 << bit) + + # Make possible outcome states + for i in range(2**nqubits): + trueness = 0 # This tells us to which output_matrix this value belongs + # Find trueness + for j in range(len(bit_masks)): + if i & bit_masks[j]: + trueness += j + 1 + # Put the value in the correct output matrix + output_matrices[trueness][i] = m[i] + return output_matrices + + +def measure_all_oneshot(qubit, format='sympy'): + """Perform a oneshot ensemble measurement on all qubits. + + A oneshot measurement is equivalent to performing a measurement on a + quantum system. This type of measurement does not return the probabilities + like an ensemble measurement does, but rather returns *one* of the + possible resulting states. The exact state that is returned is determined + by picking a state randomly according to the ensemble probabilities. + + Parameters + ---------- + qubits : Qubit + The qubit to measure. This can be any Qubit or a linear combination + of them. + format : str + The format of the intermediate matrices to use. Possible values are + ('sympy','numpy','scipy.sparse'). Currently only 'sympy' is + implemented. + + Returns + ------- + result : Qubit + The qubit that the system collapsed to upon measurement. + """ + import random + m = qubit_to_matrix(qubit) + + if format == 'sympy': + m = m.normalized() + random_number = random.random() + total = 0 + result = 0 + for i in m: + total += i*i.conjugate() + if total > random_number: + break + result += 1 + return Qubit(IntQubit(result, nqubits=int(math.log2(max(m.shape)) + .1))) + else: + raise NotImplementedError( + "This function cannot handle non-SymPy matrix formats yet" + ) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/represent.py b/phivenv/Lib/site-packages/sympy/physics/quantum/represent.py new file mode 100644 index 0000000000000000000000000000000000000000..3a1ada80aa6a3dd2caad43ec132fb9a148947106 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/represent.py @@ -0,0 +1,574 @@ +"""Logic for representing operators in state in various bases. + +TODO: + +* Get represent working with continuous hilbert spaces. +* Document default basis functionality. +""" + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.integrals.integrals import integrate +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.qexpr import QExpr +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.matrixutils import flatten_scalar +from sympy.physics.quantum.state import KetBase, BraBase, StateBase +from sympy.physics.quantum.operator import Operator, OuterProduct +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.operatorset import operators_to_state, state_to_operators + + +__all__ = [ + 'represent', + 'rep_innerproduct', + 'rep_expectation', + 'integrate_result', + 'get_basis', + 'enumerate_states' +] + +#----------------------------------------------------------------------------- +# Represent +#----------------------------------------------------------------------------- + + +def _sympy_to_scalar(e): + """Convert from a SymPy scalar to a Python scalar.""" + if isinstance(e, Expr): + if e.is_Integer: + return int(e) + elif e.is_Float: + return float(e) + elif e.is_Rational: + return float(e) + elif e.is_Number or e.is_NumberSymbol or e == I: + return complex(e) + raise TypeError('Expected number, got: %r' % e) + + +def represent(expr, **options): + """Represent the quantum expression in the given basis. + + In quantum mechanics abstract states and operators can be represented in + various basis sets. Under this operation the follow transforms happen: + + * Ket -> column vector or function + * Bra -> row vector of function + * Operator -> matrix or differential operator + + This function is the top-level interface for this action. + + This function walks the SymPy expression tree looking for ``QExpr`` + instances that have a ``_represent`` method. This method is then called + and the object is replaced by the representation returned by this method. + By default, the ``_represent`` method will dispatch to other methods + that handle the representation logic for a particular basis set. The + naming convention for these methods is the following:: + + def _represent_FooBasis(self, e, basis, **options) + + This function will have the logic for representing instances of its class + in the basis set having a class named ``FooBasis``. + + Parameters + ========== + + expr : Expr + The expression to represent. + basis : Operator, basis set + An object that contains the information about the basis set. If an + operator is used, the basis is assumed to be the orthonormal + eigenvectors of that operator. In general though, the basis argument + can be any object that contains the basis set information. + options : dict + Key/value pairs of options that are passed to the underlying method + that finds the representation. These options can be used to + control how the representation is done. For example, this is where + the size of the basis set would be set. + + Returns + ======= + + e : Expr + The SymPy expression of the represented quantum expression. + + Examples + ======== + + Here we subclass ``Operator`` and ``Ket`` to create the z-spin operator + and its spin 1/2 up eigenstate. By defining the ``_represent_SzOp`` + method, the ket can be represented in the z-spin basis. + + >>> from sympy.physics.quantum import Operator, represent, Ket + >>> from sympy import Matrix + + >>> class SzUpKet(Ket): + ... def _represent_SzOp(self, basis, **options): + ... return Matrix([1,0]) + ... + >>> class SzOp(Operator): + ... pass + ... + >>> sz = SzOp('Sz') + >>> up = SzUpKet('up') + >>> represent(up, basis=sz) + Matrix([ + [1], + [0]]) + + Here we see an example of representations in a continuous + basis. We see that the result of representing various combinations + of cartesian position operators and kets give us continuous + expressions involving DiracDelta functions. + + >>> from sympy.physics.quantum.cartesian import XOp, XKet, XBra + >>> X = XOp() + >>> x = XKet() + >>> y = XBra('y') + >>> represent(X*x) + x*DiracDelta(x - x_2) + """ + + format = options.get('format', 'sympy') + if format == 'numpy': + import numpy as np + if isinstance(expr, QExpr) and not isinstance(expr, OuterProduct): + options['replace_none'] = False + temp_basis = get_basis(expr, **options) + if temp_basis is not None: + options['basis'] = temp_basis + try: + return expr._represent(**options) + except NotImplementedError as strerr: + #If no _represent_FOO method exists, map to the + #appropriate basis state and try + #the other methods of representation + options['replace_none'] = True + + if isinstance(expr, (KetBase, BraBase)): + try: + return rep_innerproduct(expr, **options) + except NotImplementedError: + raise NotImplementedError(strerr) + elif isinstance(expr, Operator): + try: + return rep_expectation(expr, **options) + except NotImplementedError: + raise NotImplementedError(strerr) + else: + raise NotImplementedError(strerr) + elif isinstance(expr, Add): + result = represent(expr.args[0], **options) + for args in expr.args[1:]: + # scipy.sparse doesn't support += so we use plain = here. + result = result + represent(args, **options) + return result + elif isinstance(expr, Pow): + base, exp = expr.as_base_exp() + if format in ('numpy', 'scipy.sparse'): + exp = _sympy_to_scalar(exp) + base = represent(base, **options) + # scipy.sparse doesn't support negative exponents + # and warns when inverting a matrix in csr format. + if format == 'scipy.sparse' and exp < 0: + from scipy.sparse.linalg import inv + exp = - exp + base = inv(base.tocsc()).tocsr() + if format == 'numpy': + return np.linalg.matrix_power(base, exp) + return base ** exp + elif isinstance(expr, TensorProduct): + new_args = [represent(arg, **options) for arg in expr.args] + return TensorProduct(*new_args) + elif isinstance(expr, Dagger): + return Dagger(represent(expr.args[0], **options)) + elif isinstance(expr, Commutator): + A = expr.args[0] + B = expr.args[1] + return represent(Mul(A, B) - Mul(B, A), **options) + elif isinstance(expr, AntiCommutator): + A = expr.args[0] + B = expr.args[1] + return represent(Mul(A, B) + Mul(B, A), **options) + elif not isinstance(expr, (Mul, OuterProduct, InnerProduct)): + # We have removed special handling of inner products that used to be + # required (before automatic transforms). + # For numpy and scipy.sparse, we can only handle numerical prefactors. + if format in ('numpy', 'scipy.sparse'): + return _sympy_to_scalar(expr) + return expr + + if not isinstance(expr, (Mul, OuterProduct, InnerProduct)): + raise TypeError('Mul expected, got: %r' % expr) + + if "index" in options: + options["index"] += 1 + else: + options["index"] = 1 + + if "unities" not in options: + options["unities"] = [] + + result = represent(expr.args[-1], **options) + last_arg = expr.args[-1] + + for arg in reversed(expr.args[:-1]): + if isinstance(last_arg, Operator): + options["index"] += 1 + options["unities"].append(options["index"]) + elif isinstance(last_arg, BraBase) and isinstance(arg, KetBase): + options["index"] += 1 + elif isinstance(last_arg, KetBase) and isinstance(arg, Operator): + options["unities"].append(options["index"]) + elif isinstance(last_arg, KetBase) and isinstance(arg, BraBase): + options["unities"].append(options["index"]) + + next_arg = represent(arg, **options) + if format == 'numpy' and isinstance(next_arg, np.ndarray): + # Must use np.matmult to "matrix multiply" two np.ndarray + result = np.matmul(next_arg, result) + else: + result = next_arg*result + last_arg = arg + + # All three matrix formats create 1 by 1 matrices when inner products of + # vectors are taken. In these cases, we simply return a scalar. + result = flatten_scalar(result) + + result = integrate_result(expr, result, **options) + + return result + + +def rep_innerproduct(expr, **options): + """ + Returns an innerproduct like representation (e.g. ````) for the + given state. + + Attempts to calculate inner product with a bra from the specified + basis. Should only be passed an instance of KetBase or BraBase + + Parameters + ========== + + expr : KetBase or BraBase + The expression to be represented + + Examples + ======== + + >>> from sympy.physics.quantum.represent import rep_innerproduct + >>> from sympy.physics.quantum.cartesian import XOp, XKet, PxOp, PxKet + >>> rep_innerproduct(XKet()) + DiracDelta(x - x_1) + >>> rep_innerproduct(XKet(), basis=PxOp()) + sqrt(2)*exp(-I*px_1*x/hbar)/(2*sqrt(hbar)*sqrt(pi)) + >>> rep_innerproduct(PxKet(), basis=XOp()) + sqrt(2)*exp(I*px*x_1/hbar)/(2*sqrt(hbar)*sqrt(pi)) + + """ + + if not isinstance(expr, (KetBase, BraBase)): + raise TypeError("expr passed is not a Bra or Ket") + + basis = get_basis(expr, **options) + + if not isinstance(basis, StateBase): + raise NotImplementedError("Can't form this representation!") + + if "index" not in options: + options["index"] = 1 + + basis_kets = enumerate_states(basis, options["index"], 2) + + if isinstance(expr, BraBase): + bra = expr + ket = (basis_kets[1] if basis_kets[0].dual == expr else basis_kets[0]) + else: + bra = (basis_kets[1].dual if basis_kets[0] + == expr else basis_kets[0].dual) + ket = expr + + prod = InnerProduct(bra, ket) + result = prod.doit() + + format = options.get('format', 'sympy') + result = expr._format_represent(result, format) + return result + + +def rep_expectation(expr, **options): + """ + Returns an ```` type representation for the given operator. + + Parameters + ========== + + expr : Operator + Operator to be represented in the specified basis + + Examples + ======== + + >>> from sympy.physics.quantum.cartesian import XOp, PxOp, PxKet + >>> from sympy.physics.quantum.represent import rep_expectation + >>> rep_expectation(XOp()) + x_1*DiracDelta(x_1 - x_2) + >>> rep_expectation(XOp(), basis=PxOp()) + + >>> rep_expectation(XOp(), basis=PxKet()) + + + """ + + if "index" not in options: + options["index"] = 1 + + if not isinstance(expr, Operator): + raise TypeError("The passed expression is not an operator") + + basis_state = get_basis(expr, **options) + + if basis_state is None or not isinstance(basis_state, StateBase): + raise NotImplementedError("Could not get basis kets for this operator") + + basis_kets = enumerate_states(basis_state, options["index"], 2) + + bra = basis_kets[1].dual + ket = basis_kets[0] + + result = qapply(bra*expr*ket) + return result + + +def integrate_result(orig_expr, result, **options): + """ + Returns the result of integrating over any unities ``(|x>>> from sympy import symbols, DiracDelta + >>> from sympy.physics.quantum.represent import integrate_result + >>> from sympy.physics.quantum.cartesian import XOp, XKet + >>> x_ket = XKet() + >>> X_op = XOp() + >>> x, x_1, x_2 = symbols('x, x_1, x_2') + >>> integrate_result(X_op*x_ket, x*DiracDelta(x-x_1)*DiracDelta(x_1-x_2)) + x*DiracDelta(x - x_1)*DiracDelta(x_1 - x_2) + >>> integrate_result(X_op*x_ket, x*DiracDelta(x-x_1)*DiracDelta(x_1-x_2), + ... unities=[1]) + x*DiracDelta(x - x_2) + + """ + if not isinstance(result, Expr): + return result + + options['replace_none'] = True + if "basis" not in options: + arg = orig_expr.args[-1] + options["basis"] = get_basis(arg, **options) + elif not isinstance(options["basis"], StateBase): + options["basis"] = get_basis(orig_expr, **options) + + basis = options.pop("basis", None) + + if basis is None: + return result + + unities = options.pop("unities", []) + + if len(unities) == 0: + return result + + kets = enumerate_states(basis, unities) + coords = [k.label[0] for k in kets] + + for coord in coords: + if coord in result.free_symbols: + #TODO: Add support for sets of operators + basis_op = state_to_operators(basis) + start = basis_op.hilbert_space.interval.start + end = basis_op.hilbert_space.interval.end + result = integrate(result, (coord, start, end)) + + return result + + +def get_basis(expr, *, basis=None, replace_none=True, **options): + """ + Returns a basis state instance corresponding to the basis specified in + options=s. If no basis is specified, the function tries to form a default + basis state of the given expression. + + There are three behaviors: + + 1. The basis specified in options is already an instance of StateBase. If + this is the case, it is simply returned. If the class is specified but + not an instance, a default instance is returned. + + 2. The basis specified is an operator or set of operators. If this + is the case, the operator_to_state mapping method is used. + + 3. No basis is specified. If expr is a state, then a default instance of + its class is returned. If expr is an operator, then it is mapped to the + corresponding state. If it is neither, then we cannot obtain the basis + state. + + If the basis cannot be mapped, then it is not changed. + + This will be called from within represent, and represent will + only pass QExpr's. + + TODO (?): Support for Muls and other types of expressions? + + Parameters + ========== + + expr : Operator or StateBase + Expression whose basis is sought + + Examples + ======== + + >>> from sympy.physics.quantum.represent import get_basis + >>> from sympy.physics.quantum.cartesian import XOp, XKet, PxOp, PxKet + >>> x = XKet() + >>> X = XOp() + >>> get_basis(x) + |x> + >>> get_basis(X) + |x> + >>> get_basis(x, basis=PxOp()) + |px> + >>> get_basis(x, basis=PxKet) + |px> + + """ + + if basis is None and not replace_none: + return None + + if basis is None: + if isinstance(expr, KetBase): + return _make_default(expr.__class__) + elif isinstance(expr, BraBase): + return _make_default(expr.dual_class()) + elif isinstance(expr, Operator): + state_inst = operators_to_state(expr) + return (state_inst if state_inst is not None else None) + else: + return None + elif (isinstance(basis, Operator) or + (not isinstance(basis, StateBase) and issubclass(basis, Operator))): + state = operators_to_state(basis) + if state is None: + return None + elif isinstance(state, StateBase): + return state + else: + return _make_default(state) + elif isinstance(basis, StateBase): + return basis + elif issubclass(basis, StateBase): + return _make_default(basis) + else: + return None + + +def _make_default(expr): + # XXX: Catching TypeError like this is a bad way of distinguishing + # instances from classes. The logic using this function should be + # rewritten somehow. + try: + expr = expr() + except TypeError: + return expr + + return expr + + +def enumerate_states(*args, **options): + """ + Returns instances of the given state with dummy indices appended + + Operates in two different modes: + + 1. Two arguments are passed to it. The first is the base state which is to + be indexed, and the second argument is a list of indices to append. + + 2. Three arguments are passed. The first is again the base state to be + indexed. The second is the start index for counting. The final argument + is the number of kets you wish to receive. + + Tries to call state._enumerate_state. If this fails, returns an empty list + + Parameters + ========== + + args : list + See list of operation modes above for explanation + + Examples + ======== + + >>> from sympy.physics.quantum.cartesian import XBra, XKet + >>> from sympy.physics.quantum.represent import enumerate_states + >>> test = XKet('foo') + >>> enumerate_states(test, 1, 3) + [|foo_1>, |foo_2>, |foo_3>] + >>> test2 = XBra('bar') + >>> enumerate_states(test2, [4, 5, 10]) + [>> from sympy.physics.quantum.sho1d import RaisingOp + >>> from sympy.physics.quantum import Dagger + + >>> ad = RaisingOp('a') + >>> ad.rewrite('xp').doit() + sqrt(2)*(m*omega*X - I*Px)/(2*sqrt(hbar)*sqrt(m*omega)) + + >>> Dagger(ad) + a + + Taking the commutator of a^dagger with other Operators: + + >>> from sympy.physics.quantum import Commutator + >>> from sympy.physics.quantum.sho1d import RaisingOp, LoweringOp + >>> from sympy.physics.quantum.sho1d import NumberOp + + >>> ad = RaisingOp('a') + >>> a = LoweringOp('a') + >>> N = NumberOp('N') + >>> Commutator(ad, a).doit() + -1 + >>> Commutator(ad, N).doit() + -RaisingOp(a) + + Apply a^dagger to a state: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import RaisingOp, SHOKet + + >>> ad = RaisingOp('a') + >>> k = SHOKet('k') + >>> qapply(ad*k) + sqrt(k + 1)*|k + 1> + + Matrix Representation + + >>> from sympy.physics.quantum.sho1d import RaisingOp + >>> from sympy.physics.quantum.represent import represent + >>> ad = RaisingOp('a') + >>> represent(ad, basis=N, ndim=4, format='sympy') + Matrix([ + [0, 0, 0, 0], + [1, 0, 0, 0], + [0, sqrt(2), 0, 0], + [0, 0, sqrt(3), 0]]) + + """ + + def _eval_rewrite_as_xp(self, *args, **kwargs): + return (S.One/sqrt(Integer(2)*hbar*m*omega))*( + S.NegativeOne*I*Px + m*omega*X) + + def _eval_adjoint(self): + return LoweringOp(*self.args) + + def _eval_commutator_LoweringOp(self, other): + return S.NegativeOne + + def _eval_commutator_NumberOp(self, other): + return S.NegativeOne*self + + def _apply_operator_SHOKet(self, ket, **options): + temp = ket.n + S.One + return sqrt(temp)*SHOKet(temp) + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_XOp(self, basis, **options): + # This logic is good but the underlying position + # representation logic is broken. + # temp = self.rewrite('xp').doit() + # result = represent(temp, basis=X) + # return result + raise NotImplementedError('Position representation is not implemented') + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format','sympy') + matrix = matrix_zeros(ndim_info, ndim_info, **options) + for i in range(ndim_info - 1): + value = sqrt(i + 1) + if format == 'scipy.sparse': + value = float(value) + matrix[i + 1, i] = value + if format == 'scipy.sparse': + matrix = matrix.tocsr() + return matrix + + #-------------------------------------------------------------------------- + # Printing Methods + #-------------------------------------------------------------------------- + + def _print_contents(self, printer, *args): + arg0 = printer._print(self.args[0], *args) + return '%s(%s)' % (self.__class__.__name__, arg0) + + def _print_contents_pretty(self, printer, *args): + from sympy.printing.pretty.stringpict import prettyForm + pform = printer._print(self.args[0], *args) + pform = pform**prettyForm('\N{DAGGER}') + return pform + + def _print_contents_latex(self, printer, *args): + arg = printer._print(self.args[0]) + return '%s^{\\dagger}' % arg + +class LoweringOp(SHOOp): + """The Lowering Operator or 'a'. + + When 'a' acts on a state it lowers the state up by one. Taking + the adjoint of 'a' returns a^dagger, the Raising Operator. 'a' + can be rewritten in terms of position and momentum. We can + represent 'a' as a matrix, which will be its default basis. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. + + Examples + ======== + + Create a Lowering Operator and rewrite it in terms of position and + momentum, and show that taking its adjoint returns a^dagger: + + >>> from sympy.physics.quantum.sho1d import LoweringOp + >>> from sympy.physics.quantum import Dagger + + >>> a = LoweringOp('a') + >>> a.rewrite('xp').doit() + sqrt(2)*(m*omega*X + I*Px)/(2*sqrt(hbar)*sqrt(m*omega)) + + >>> Dagger(a) + RaisingOp(a) + + Taking the commutator of 'a' with other Operators: + + >>> from sympy.physics.quantum import Commutator + >>> from sympy.physics.quantum.sho1d import LoweringOp, RaisingOp + >>> from sympy.physics.quantum.sho1d import NumberOp + + >>> a = LoweringOp('a') + >>> ad = RaisingOp('a') + >>> N = NumberOp('N') + >>> Commutator(a, ad).doit() + 1 + >>> Commutator(a, N).doit() + a + + Apply 'a' to a state: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import LoweringOp, SHOKet + + >>> a = LoweringOp('a') + >>> k = SHOKet('k') + >>> qapply(a*k) + sqrt(k)*|k - 1> + + Taking 'a' of the lowest state will return 0: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import LoweringOp, SHOKet + + >>> a = LoweringOp('a') + >>> k = SHOKet(0) + >>> qapply(a*k) + 0 + + Matrix Representation + + >>> from sympy.physics.quantum.sho1d import LoweringOp + >>> from sympy.physics.quantum.represent import represent + >>> a = LoweringOp('a') + >>> represent(a, basis=N, ndim=4, format='sympy') + Matrix([ + [0, 1, 0, 0], + [0, 0, sqrt(2), 0], + [0, 0, 0, sqrt(3)], + [0, 0, 0, 0]]) + + """ + + def _eval_rewrite_as_xp(self, *args, **kwargs): + return (S.One/sqrt(Integer(2)*hbar*m*omega))*( + I*Px + m*omega*X) + + def _eval_adjoint(self): + return RaisingOp(*self.args) + + def _eval_commutator_RaisingOp(self, other): + return S.One + + def _eval_commutator_NumberOp(self, other): + return self + + def _apply_operator_SHOKet(self, ket, **options): + temp = ket.n - Integer(1) + if ket.n is S.Zero: + return S.Zero + else: + return sqrt(ket.n)*SHOKet(temp) + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_XOp(self, basis, **options): + # This logic is good but the underlying position + # representation logic is broken. + # temp = self.rewrite('xp').doit() + # result = represent(temp, basis=X) + # return result + raise NotImplementedError('Position representation is not implemented') + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + matrix = matrix_zeros(ndim_info, ndim_info, **options) + for i in range(ndim_info - 1): + value = sqrt(i + 1) + if format == 'scipy.sparse': + value = float(value) + matrix[i,i + 1] = value + if format == 'scipy.sparse': + matrix = matrix.tocsr() + return matrix + + +class NumberOp(SHOOp): + """The Number Operator is simply a^dagger*a + + It is often useful to write a^dagger*a as simply the Number Operator + because the Number Operator commutes with the Hamiltonian. And can be + expressed using the Number Operator. Also the Number Operator can be + applied to states. We can represent the Number Operator as a matrix, + which will be its default basis. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. + + Examples + ======== + + Create a Number Operator and rewrite it in terms of the ladder + operators, position and momentum operators, and Hamiltonian: + + >>> from sympy.physics.quantum.sho1d import NumberOp + + >>> N = NumberOp('N') + >>> N.rewrite('a').doit() + RaisingOp(a)*a + >>> N.rewrite('xp').doit() + -1/2 + (m**2*omega**2*X**2 + Px**2)/(2*hbar*m*omega) + >>> N.rewrite('H').doit() + -1/2 + H/(hbar*omega) + + Take the Commutator of the Number Operator with other Operators: + + >>> from sympy.physics.quantum import Commutator + >>> from sympy.physics.quantum.sho1d import NumberOp, Hamiltonian + >>> from sympy.physics.quantum.sho1d import RaisingOp, LoweringOp + + >>> N = NumberOp('N') + >>> H = Hamiltonian('H') + >>> ad = RaisingOp('a') + >>> a = LoweringOp('a') + >>> Commutator(N,H).doit() + 0 + >>> Commutator(N,ad).doit() + RaisingOp(a) + >>> Commutator(N,a).doit() + -a + + Apply the Number Operator to a state: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import NumberOp, SHOKet + + >>> N = NumberOp('N') + >>> k = SHOKet('k') + >>> qapply(N*k) + k*|k> + + Matrix Representation + + >>> from sympy.physics.quantum.sho1d import NumberOp + >>> from sympy.physics.quantum.represent import represent + >>> N = NumberOp('N') + >>> represent(N, basis=N, ndim=4, format='sympy') + Matrix([ + [0, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 2, 0], + [0, 0, 0, 3]]) + + """ + + def _eval_rewrite_as_a(self, *args, **kwargs): + return ad*a + + def _eval_rewrite_as_xp(self, *args, **kwargs): + return (S.One/(Integer(2)*m*hbar*omega))*(Px**2 + ( + m*omega*X)**2) - S.Half + + def _eval_rewrite_as_H(self, *args, **kwargs): + return H/(hbar*omega) - S.Half + + def _apply_operator_SHOKet(self, ket, **options): + return ket.n*ket + + def _eval_commutator_Hamiltonian(self, other): + return S.Zero + + def _eval_commutator_RaisingOp(self, other): + return other + + def _eval_commutator_LoweringOp(self, other): + return S.NegativeOne*other + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_XOp(self, basis, **options): + # This logic is good but the underlying position + # representation logic is broken. + # temp = self.rewrite('xp').doit() + # result = represent(temp, basis=X) + # return result + raise NotImplementedError('Position representation is not implemented') + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + matrix = matrix_zeros(ndim_info, ndim_info, **options) + for i in range(ndim_info): + value = i + if format == 'scipy.sparse': + value = float(value) + matrix[i,i] = value + if format == 'scipy.sparse': + matrix = matrix.tocsr() + return matrix + + +class Hamiltonian(SHOOp): + """The Hamiltonian Operator. + + The Hamiltonian is used to solve the time-independent Schrodinger + equation. The Hamiltonian can be expressed using the ladder operators, + as well as by position and momentum. We can represent the Hamiltonian + Operator as a matrix, which will be its default basis. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. + + Examples + ======== + + Create a Hamiltonian Operator and rewrite it in terms of the ladder + operators, position and momentum, and the Number Operator: + + >>> from sympy.physics.quantum.sho1d import Hamiltonian + + >>> H = Hamiltonian('H') + >>> H.rewrite('a').doit() + hbar*omega*(1/2 + RaisingOp(a)*a) + >>> H.rewrite('xp').doit() + (m**2*omega**2*X**2 + Px**2)/(2*m) + >>> H.rewrite('N').doit() + hbar*omega*(1/2 + N) + + Take the Commutator of the Hamiltonian and the Number Operator: + + >>> from sympy.physics.quantum import Commutator + >>> from sympy.physics.quantum.sho1d import Hamiltonian, NumberOp + + >>> H = Hamiltonian('H') + >>> N = NumberOp('N') + >>> Commutator(H,N).doit() + 0 + + Apply the Hamiltonian Operator to a state: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import Hamiltonian, SHOKet + + >>> H = Hamiltonian('H') + >>> k = SHOKet('k') + >>> qapply(H*k) + hbar*k*omega*|k> + hbar*omega*|k>/2 + + Matrix Representation + + >>> from sympy.physics.quantum.sho1d import Hamiltonian + >>> from sympy.physics.quantum.represent import represent + + >>> H = Hamiltonian('H') + >>> represent(H, basis=N, ndim=4, format='sympy') + Matrix([ + [hbar*omega/2, 0, 0, 0], + [ 0, 3*hbar*omega/2, 0, 0], + [ 0, 0, 5*hbar*omega/2, 0], + [ 0, 0, 0, 7*hbar*omega/2]]) + + """ + + def _eval_rewrite_as_a(self, *args, **kwargs): + return hbar*omega*(ad*a + S.Half) + + def _eval_rewrite_as_xp(self, *args, **kwargs): + return (S.One/(Integer(2)*m))*(Px**2 + (m*omega*X)**2) + + def _eval_rewrite_as_N(self, *args, **kwargs): + return hbar*omega*(N + S.Half) + + def _apply_operator_SHOKet(self, ket, **options): + return (hbar*omega*(ket.n + S.Half))*ket + + def _eval_commutator_NumberOp(self, other): + return S.Zero + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_XOp(self, basis, **options): + # This logic is good but the underlying position + # representation logic is broken. + # temp = self.rewrite('xp').doit() + # result = represent(temp, basis=X) + # return result + raise NotImplementedError('Position representation is not implemented') + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + matrix = matrix_zeros(ndim_info, ndim_info, **options) + for i in range(ndim_info): + value = i + S.Half + if format == 'scipy.sparse': + value = float(value) + matrix[i,i] = value + if format == 'scipy.sparse': + matrix = matrix.tocsr() + return hbar*omega*matrix + +#------------------------------------------------------------------------------ + +class SHOState(State): + """State class for SHO states""" + + @classmethod + def _eval_hilbert_space(cls, label): + return ComplexSpace(S.Infinity) + + @property + def n(self): + return self.args[0] + + +class SHOKet(SHOState, Ket): + """1D eigenket. + + Inherits from SHOState and Ket. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the ket + This is usually its quantum numbers or its symbol. + + Examples + ======== + + Ket's know about their associated bra: + + >>> from sympy.physics.quantum.sho1d import SHOKet + + >>> k = SHOKet('k') + >>> k.dual + >> k.dual_class() + + + Take the Inner Product with a bra: + + >>> from sympy.physics.quantum import InnerProduct + >>> from sympy.physics.quantum.sho1d import SHOKet, SHOBra + + >>> k = SHOKet('k') + >>> b = SHOBra('b') + >>> InnerProduct(b,k).doit() + KroneckerDelta(b, k) + + Vector representation of a numerical state ket: + + >>> from sympy.physics.quantum.sho1d import SHOKet, NumberOp + >>> from sympy.physics.quantum.represent import represent + + >>> k = SHOKet(3) + >>> N = NumberOp('N') + >>> represent(k, basis=N, ndim=4) + Matrix([ + [0], + [0], + [0], + [1]]) + + """ + + @classmethod + def dual_class(self): + return SHOBra + + def _eval_innerproduct_SHOBra(self, bra, **hints): + result = KroneckerDelta(self.n, bra.n) + return result + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + options['spmatrix'] = 'lil' + vector = matrix_zeros(ndim_info, 1, **options) + if isinstance(self.n, Integer): + if self.n >= ndim_info: + return ValueError("N-Dimension too small") + if format == 'scipy.sparse': + vector[int(self.n), 0] = 1.0 + vector = vector.tocsr() + elif format == 'numpy': + vector[int(self.n), 0] = 1.0 + else: + vector[self.n, 0] = S.One + return vector + else: + return ValueError("Not Numerical State") + + +class SHOBra(SHOState, Bra): + """A time-independent Bra in SHO. + + Inherits from SHOState and Bra. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the ket + This is usually its quantum numbers or its symbol. + + Examples + ======== + + Bra's know about their associated ket: + + >>> from sympy.physics.quantum.sho1d import SHOBra + + >>> b = SHOBra('b') + >>> b.dual + |b> + >>> b.dual_class() + + + Vector representation of a numerical state bra: + + >>> from sympy.physics.quantum.sho1d import SHOBra, NumberOp + >>> from sympy.physics.quantum.represent import represent + + >>> b = SHOBra(3) + >>> N = NumberOp('N') + >>> represent(b, basis=N, ndim=4) + Matrix([[0, 0, 0, 1]]) + + """ + + @classmethod + def dual_class(self): + return SHOKet + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + options['spmatrix'] = 'lil' + vector = matrix_zeros(1, ndim_info, **options) + if isinstance(self.n, Integer): + if self.n >= ndim_info: + return ValueError("N-Dimension too small") + if format == 'scipy.sparse': + vector[0, int(self.n)] = 1.0 + vector = vector.tocsr() + elif format == 'numpy': + vector[0, int(self.n)] = 1.0 + else: + vector[0, self.n] = S.One + return vector + else: + return ValueError("Not Numerical State") + + +ad = RaisingOp('a') +a = LoweringOp('a') +H = Hamiltonian('H') +N = NumberOp('N') +omega = Symbol('omega') +m = Symbol('m') diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/shor.py b/phivenv/Lib/site-packages/sympy/physics/quantum/shor.py new file mode 100644 index 0000000000000000000000000000000000000000..fc9e55229d74634bdb82efc03c2d1649e088efb3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/shor.py @@ -0,0 +1,173 @@ +"""Shor's algorithm and helper functions. + +Todo: + +* Get the CMod gate working again using the new Gate API. +* Fix everything. +* Update docstrings and reformat. +""" + +import math +import random + +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.core.intfunc import igcd +from sympy.ntheory import continued_fraction_periodic as continued_fraction +from sympy.utilities.iterables import variations + +from sympy.physics.quantum.gate import Gate +from sympy.physics.quantum.qubit import Qubit, measure_partial_oneshot +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qft import QFT +from sympy.physics.quantum.qexpr import QuantumError + + +class OrderFindingException(QuantumError): + pass + + +class CMod(Gate): + """A controlled mod gate. + + This is black box controlled Mod function for use by shor's algorithm. + TODO: implement a decompose property that returns how to do this in terms + of elementary gates + """ + + @classmethod + def _eval_args(cls, args): + # t = args[0] + # a = args[1] + # N = args[2] + raise NotImplementedError('The CMod gate has not been completed.') + + @property + def t(self): + """Size of 1/2 input register. First 1/2 holds output.""" + return self.label[0] + + @property + def a(self): + """Base of the controlled mod function.""" + return self.label[1] + + @property + def N(self): + """N is the type of modular arithmetic we are doing.""" + return self.label[2] + + def _apply_operator_Qubit(self, qubits, **options): + """ + This directly calculates the controlled mod of the second half of + the register and puts it in the second + This will look pretty when we get Tensor Symbolically working + """ + n = 1 + k = 0 + # Determine the value stored in high memory. + for i in range(self.t): + k += n*qubits[self.t + i] + n *= 2 + + # The value to go in low memory will be out. + out = int(self.a**k % self.N) + + # Create array for new qbit-ket which will have high memory unaffected + outarray = list(qubits.args[0][:self.t]) + + # Place out in low memory + for i in reversed(range(self.t)): + outarray.append((out >> i) & 1) + + return Qubit(*outarray) + + +def shor(N): + """This function implements Shor's factoring algorithm on the Integer N + + The algorithm starts by picking a random number (a) and seeing if it is + coprime with N. If it is not, then the gcd of the two numbers is a factor + and we are done. Otherwise, it begins the period_finding subroutine which + finds the period of a in modulo N arithmetic. This period, if even, can + be used to calculate factors by taking a**(r/2)-1 and a**(r/2)+1. + These values are returned. + """ + a = random.randrange(N - 2) + 2 + if igcd(N, a) != 1: + return igcd(N, a) + r = period_find(a, N) + if r % 2 == 1: + shor(N) + answer = (igcd(a**(r/2) - 1, N), igcd(a**(r/2) + 1, N)) + return answer + + +def getr(x, y, N): + fraction = continued_fraction(x, y) + # Now convert into r + total = ratioize(fraction, N) + return total + + +def ratioize(list, N): + if list[0] > N: + return S.Zero + if len(list) == 1: + return list[0] + return list[0] + ratioize(list[1:], N) + + +def period_find(a, N): + """Finds the period of a in modulo N arithmetic + + This is quantum part of Shor's algorithm. It takes two registers, + puts first in superposition of states with Hadamards so: ``|k>|0>`` + with k being all possible choices. It then does a controlled mod and + a QFT to determine the order of a. + """ + epsilon = .5 + # picks out t's such that maintains accuracy within epsilon + t = int(2*math.ceil(log(N, 2))) + # make the first half of register be 0's |000...000> + start = [0 for x in range(t)] + # Put second half into superposition of states so we have |1>x|0> + |2>x|0> + ... |k>x>|0> + ... + |2**n-1>x|0> + factor = 1/sqrt(2**t) + qubits = 0 + for arr in variations(range(2), t, repetition=True): + qbitArray = list(arr) + start + qubits = qubits + Qubit(*qbitArray) + circuit = (factor*qubits).expand() + # Controlled second half of register so that we have: + # |1>x|a**1 %N> + |2>x|a**2 %N> + ... + |k>x|a**k %N >+ ... + |2**n-1=k>x|a**k % n> + circuit = CMod(t, a, N)*circuit + # will measure first half of register giving one of the a**k%N's + + circuit = qapply(circuit) + for i in range(t): + circuit = measure_partial_oneshot(circuit, i) + # Now apply Inverse Quantum Fourier Transform on the second half of the register + + circuit = qapply(QFT(t, t*2).decompose()*circuit, floatingPoint=True) + for i in range(t): + circuit = measure_partial_oneshot(circuit, i + t) + if isinstance(circuit, Qubit): + register = circuit + elif isinstance(circuit, Mul): + register = circuit.args[-1] + else: + register = circuit.args[-1].args[-1] + + n = 1 + answer = 0 + for i in range(len(register)/2): + answer += n*register[i + t] + n = n << 1 + if answer == 0: + raise OrderFindingException( + "Order finder returned 0. Happens with chance %f" % epsilon) + #turn answer into r using continued fractions + g = getr(answer, 2**t, N) + return g diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/spin.py b/phivenv/Lib/site-packages/sympy/physics/quantum/spin.py new file mode 100644 index 0000000000000000000000000000000000000000..6be53d01711adbed8c078fffca1d618c1aa3c6e6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/spin.py @@ -0,0 +1,2150 @@ +"""Quantum mechanical angular momentum.""" + +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.numbers import int_valued +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.simplify.simplify import simplify +from sympy.matrices import zeros +from sympy.printing.pretty.stringpict import prettyForm, stringPict +from sympy.printing.pretty.pretty_symbology import pretty_symbol + +from sympy.physics.quantum.qexpr import QExpr +from sympy.physics.quantum.operator import (HermitianOperator, Operator, + UnitaryOperator) +from sympy.physics.quantum.state import Bra, Ket, State +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.physics.quantum.constants import hbar +from sympy.physics.quantum.hilbert import ComplexSpace, DirectSumHilbertSpace +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.cg import CG +from sympy.physics.quantum.qapply import qapply + + +__all__ = [ + 'm_values', + 'Jplus', + 'Jminus', + 'Jx', + 'Jy', + 'Jz', + 'J2', + 'Rotation', + 'WignerD', + 'JxKet', + 'JxBra', + 'JyKet', + 'JyBra', + 'JzKet', + 'JzBra', + 'JzOp', + 'J2Op', + 'JxKetCoupled', + 'JxBraCoupled', + 'JyKetCoupled', + 'JyBraCoupled', + 'JzKetCoupled', + 'JzBraCoupled', + 'couple', + 'uncouple' +] + + +def m_values(j): + j = sympify(j) + size = 2*j + 1 + if not size.is_Integer or not size > 0: + raise ValueError( + 'Only integer or half-integer values allowed for j, got: : %r' % j + ) + return size, [j - i for i in range(int(2*j + 1))] + + +#----------------------------------------------------------------------------- +# Spin Operators +#----------------------------------------------------------------------------- + + +class SpinOpBase: + """Base class for spin operators.""" + + @classmethod + def _eval_hilbert_space(cls, label): + # We consider all j values so our space is infinite. + return ComplexSpace(S.Infinity) + + @property + def name(self): + return self.args[0] + + def _print_contents(self, printer, *args): + return '%s%s' % (self.name, self._coord) + + def _print_contents_pretty(self, printer, *args): + a = stringPict(str(self.name)) + b = stringPict(self._coord) + return self._print_subscript_pretty(a, b) + + def _print_contents_latex(self, printer, *args): + return r'%s_%s' % ((self.name, self._coord)) + + def _represent_base(self, basis, **options): + j = options.get('j', S.Half) + size, mvals = m_values(j) + result = zeros(size, size) + for p in range(size): + for q in range(size): + me = self.matrix_element(j, mvals[p], j, mvals[q]) + result[p, q] = me + return result + + def _apply_op(self, ket, orig_basis, **options): + state = ket.rewrite(self.basis) + # If the state has only one term + if isinstance(state, State): + ret = (hbar*state.m)*state + # state is a linear combination of states + elif isinstance(state, Sum): + ret = self._apply_operator_Sum(state, **options) + else: + ret = qapply(self*state) + if ret == self*state: + raise NotImplementedError + return ret.rewrite(orig_basis) + + def _apply_operator_JxKet(self, ket, **options): + return self._apply_op(ket, 'Jx', **options) + + def _apply_operator_JxKetCoupled(self, ket, **options): + return self._apply_op(ket, 'Jx', **options) + + def _apply_operator_JyKet(self, ket, **options): + return self._apply_op(ket, 'Jy', **options) + + def _apply_operator_JyKetCoupled(self, ket, **options): + return self._apply_op(ket, 'Jy', **options) + + def _apply_operator_JzKet(self, ket, **options): + return self._apply_op(ket, 'Jz', **options) + + def _apply_operator_JzKetCoupled(self, ket, **options): + return self._apply_op(ket, 'Jz', **options) + + def _apply_operator_TensorProduct(self, tp, **options): + # Uncoupling operator is only easily found for coordinate basis spin operators + # TODO: add methods for uncoupling operators + if not isinstance(self, (JxOp, JyOp, JzOp)): + raise NotImplementedError + result = [] + for n in range(len(tp.args)): + arg = [] + arg.extend(tp.args[:n]) + arg.append(self._apply_operator(tp.args[n])) + arg.extend(tp.args[n + 1:]) + result.append(tp.__class__(*arg)) + return Add(*result).expand() + + # TODO: move this to qapply_Mul + def _apply_operator_Sum(self, s, **options): + new_func = qapply(self*s.function) + if new_func == self*s.function: + raise NotImplementedError + return Sum(new_func, *s.limits) + + def _eval_trace(self, **options): + #TODO: use options to use different j values + #For now eval at default basis + + # is it efficient to represent each time + # to do a trace? + return self._represent_default_basis().trace() + + +class JplusOp(SpinOpBase, Operator): + """The J+ operator.""" + + _coord = '+' + + basis = 'Jz' + + def _eval_commutator_JminusOp(self, other): + return 2*hbar*JzOp(self.name) + + def _apply_operator_JzKet(self, ket, **options): + j = ket.j + m = ket.m + if m.is_Number and j.is_Number: + if m >= j: + return S.Zero + return hbar*sqrt(j*(j + S.One) - m*(m + S.One))*JzKet(j, m + S.One) + + def _apply_operator_JzKetCoupled(self, ket, **options): + j = ket.j + m = ket.m + jn = ket.jn + coupling = ket.coupling + if m.is_Number and j.is_Number: + if m >= j: + return S.Zero + return hbar*sqrt(j*(j + S.One) - m*(m + S.One))*JzKetCoupled(j, m + S.One, jn, coupling) + + def matrix_element(self, j, m, jp, mp): + result = hbar*sqrt(j*(j + S.One) - mp*(mp + S.One)) + result *= KroneckerDelta(m, mp + 1) + result *= KroneckerDelta(j, jp) + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + def _eval_rewrite_as_xyz(self, *args, **kwargs): + return JxOp(args[0]) + I*JyOp(args[0]) + + +class JminusOp(SpinOpBase, Operator): + """The J- operator.""" + + _coord = '-' + + basis = 'Jz' + + def _apply_operator_JzKet(self, ket, **options): + j = ket.j + m = ket.m + if m.is_Number and j.is_Number: + if m <= -j: + return S.Zero + return hbar*sqrt(j*(j + S.One) - m*(m - S.One))*JzKet(j, m - S.One) + + def _apply_operator_JzKetCoupled(self, ket, **options): + j = ket.j + m = ket.m + jn = ket.jn + coupling = ket.coupling + if m.is_Number and j.is_Number: + if m <= -j: + return S.Zero + return hbar*sqrt(j*(j + S.One) - m*(m - S.One))*JzKetCoupled(j, m - S.One, jn, coupling) + + def matrix_element(self, j, m, jp, mp): + result = hbar*sqrt(j*(j + S.One) - mp*(mp - S.One)) + result *= KroneckerDelta(m, mp - 1) + result *= KroneckerDelta(j, jp) + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + def _eval_rewrite_as_xyz(self, *args, **kwargs): + return JxOp(args[0]) - I*JyOp(args[0]) + + +class JxOp(SpinOpBase, HermitianOperator): + """The Jx operator.""" + + _coord = 'x' + + basis = 'Jx' + + def _eval_commutator_JyOp(self, other): + return I*hbar*JzOp(self.name) + + def _eval_commutator_JzOp(self, other): + return -I*hbar*JyOp(self.name) + + def _apply_operator_JzKet(self, ket, **options): + jp = JplusOp(self.name)._apply_operator_JzKet(ket, **options) + jm = JminusOp(self.name)._apply_operator_JzKet(ket, **options) + return (jp + jm)/Integer(2) + + def _apply_operator_JzKetCoupled(self, ket, **options): + jp = JplusOp(self.name)._apply_operator_JzKetCoupled(ket, **options) + jm = JminusOp(self.name)._apply_operator_JzKetCoupled(ket, **options) + return (jp + jm)/Integer(2) + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + jp = JplusOp(self.name)._represent_JzOp(basis, **options) + jm = JminusOp(self.name)._represent_JzOp(basis, **options) + return (jp + jm)/Integer(2) + + def _eval_rewrite_as_plusminus(self, *args, **kwargs): + return (JplusOp(args[0]) + JminusOp(args[0]))/2 + + +class JyOp(SpinOpBase, HermitianOperator): + """The Jy operator.""" + + _coord = 'y' + + basis = 'Jy' + + def _eval_commutator_JzOp(self, other): + return I*hbar*JxOp(self.name) + + def _eval_commutator_JxOp(self, other): + return -I*hbar*J2Op(self.name) + + def _apply_operator_JzKet(self, ket, **options): + jp = JplusOp(self.name)._apply_operator_JzKet(ket, **options) + jm = JminusOp(self.name)._apply_operator_JzKet(ket, **options) + return (jp - jm)/(Integer(2)*I) + + def _apply_operator_JzKetCoupled(self, ket, **options): + jp = JplusOp(self.name)._apply_operator_JzKetCoupled(ket, **options) + jm = JminusOp(self.name)._apply_operator_JzKetCoupled(ket, **options) + return (jp - jm)/(Integer(2)*I) + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + jp = JplusOp(self.name)._represent_JzOp(basis, **options) + jm = JminusOp(self.name)._represent_JzOp(basis, **options) + return (jp - jm)/(Integer(2)*I) + + def _eval_rewrite_as_plusminus(self, *args, **kwargs): + return (JplusOp(args[0]) - JminusOp(args[0]))/(2*I) + + +class JzOp(SpinOpBase, HermitianOperator): + """The Jz operator.""" + + _coord = 'z' + + basis = 'Jz' + + def _eval_commutator_JxOp(self, other): + return I*hbar*JyOp(self.name) + + def _eval_commutator_JyOp(self, other): + return -I*hbar*JxOp(self.name) + + def _eval_commutator_JplusOp(self, other): + return hbar*JplusOp(self.name) + + def _eval_commutator_JminusOp(self, other): + return -hbar*JminusOp(self.name) + + def matrix_element(self, j, m, jp, mp): + result = hbar*mp + result *= KroneckerDelta(m, mp) + result *= KroneckerDelta(j, jp) + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + +class J2Op(SpinOpBase, HermitianOperator): + """The J^2 operator.""" + + _coord = '2' + + def _eval_commutator_JxOp(self, other): + return S.Zero + + def _eval_commutator_JyOp(self, other): + return S.Zero + + def _eval_commutator_JzOp(self, other): + return S.Zero + + def _eval_commutator_JplusOp(self, other): + return S.Zero + + def _eval_commutator_JminusOp(self, other): + return S.Zero + + def _apply_operator_JxKet(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JxKetCoupled(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JyKet(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JyKetCoupled(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JzKet(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JzKetCoupled(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def matrix_element(self, j, m, jp, mp): + result = (hbar**2)*j*(j + 1) + result *= KroneckerDelta(m, mp) + result *= KroneckerDelta(j, jp) + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + def _print_contents_pretty(self, printer, *args): + a = prettyForm(str(self.name)) + b = prettyForm('2') + return a**b + + def _print_contents_latex(self, printer, *args): + return r'%s^2' % str(self.name) + + def _eval_rewrite_as_xyz(self, *args, **kwargs): + return JxOp(args[0])**2 + JyOp(args[0])**2 + JzOp(args[0])**2 + + def _eval_rewrite_as_plusminus(self, *args, **kwargs): + a = args[0] + return JzOp(a)**2 + \ + S.Half*(JplusOp(a)*JminusOp(a) + JminusOp(a)*JplusOp(a)) + + +class Rotation(UnitaryOperator): + """Wigner D operator in terms of Euler angles. + + Defines the rotation operator in terms of the Euler angles defined by + the z-y-z convention for a passive transformation. That is the coordinate + axes are rotated first about the z-axis, giving the new x'-y'-z' axes. Then + this new coordinate system is rotated about the new y'-axis, giving new + x''-y''-z'' axes. Then this new coordinate system is rotated about the + z''-axis. Conventions follow those laid out in [1]_. + + Parameters + ========== + + alpha : Number, Symbol + First Euler Angle + beta : Number, Symbol + Second Euler angle + gamma : Number, Symbol + Third Euler angle + + Examples + ======== + + A simple example rotation operator: + + >>> from sympy import pi + >>> from sympy.physics.quantum.spin import Rotation + >>> Rotation(pi, 0, pi/2) + R(pi,0,pi/2) + + With symbolic Euler angles and calculating the inverse rotation operator: + + >>> from sympy import symbols + >>> a, b, c = symbols('a b c') + >>> Rotation(a, b, c) + R(a,b,c) + >>> Rotation(a, b, c).inverse() + R(-c,-b,-a) + + See Also + ======== + + WignerD: Symbolic Wigner-D function + D: Wigner-D function + d: Wigner small-d function + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + """ + + @classmethod + def _eval_args(cls, args): + args = QExpr._eval_args(args) + if len(args) != 3: + raise ValueError('3 Euler angles required, got: %r' % args) + return args + + @classmethod + def _eval_hilbert_space(cls, label): + # We consider all j values so our space is infinite. + return ComplexSpace(S.Infinity) + + @property + def alpha(self): + return self.label[0] + + @property + def beta(self): + return self.label[1] + + @property + def gamma(self): + return self.label[2] + + def _print_operator_name(self, printer, *args): + return 'R' + + def _print_operator_name_pretty(self, printer, *args): + if printer._use_unicode: + return prettyForm('\N{SCRIPT CAPITAL R}' + ' ') + else: + return prettyForm("R ") + + def _print_operator_name_latex(self, printer, *args): + return r'\mathcal{R}' + + def _eval_inverse(self): + return Rotation(-self.gamma, -self.beta, -self.alpha) + + @classmethod + def D(cls, j, m, mp, alpha, beta, gamma): + """Wigner D-function. + + Returns an instance of the WignerD class corresponding to the Wigner-D + function specified by the parameters. + + Parameters + =========== + + j : Number + Total angular momentum + m : Number + Eigenvalue of angular momentum along axis after rotation + mp : Number + Eigenvalue of angular momentum along rotated axis + alpha : Number, Symbol + First Euler angle of rotation + beta : Number, Symbol + Second Euler angle of rotation + gamma : Number, Symbol + Third Euler angle of rotation + + Examples + ======== + + Return the Wigner-D matrix element for a defined rotation, both + numerical and symbolic: + + >>> from sympy.physics.quantum.spin import Rotation + >>> from sympy import pi, symbols + >>> alpha, beta, gamma = symbols('alpha beta gamma') + >>> Rotation.D(1, 1, 0,pi, pi/2,-pi) + WignerD(1, 1, 0, pi, pi/2, -pi) + + See Also + ======== + + WignerD: Symbolic Wigner-D function + + """ + return WignerD(j, m, mp, alpha, beta, gamma) + + @classmethod + def d(cls, j, m, mp, beta): + """Wigner small-d function. + + Returns an instance of the WignerD class corresponding to the Wigner-D + function specified by the parameters with the alpha and gamma angles + given as 0. + + Parameters + =========== + + j : Number + Total angular momentum + m : Number + Eigenvalue of angular momentum along axis after rotation + mp : Number + Eigenvalue of angular momentum along rotated axis + beta : Number, Symbol + Second Euler angle of rotation + + Examples + ======== + + Return the Wigner-D matrix element for a defined rotation, both + numerical and symbolic: + + >>> from sympy.physics.quantum.spin import Rotation + >>> from sympy import pi, symbols + >>> beta = symbols('beta') + >>> Rotation.d(1, 1, 0, pi/2) + WignerD(1, 1, 0, 0, pi/2, 0) + + See Also + ======== + + WignerD: Symbolic Wigner-D function + + """ + return WignerD(j, m, mp, 0, beta, 0) + + def matrix_element(self, j, m, jp, mp): + result = self.__class__.D( + jp, m, mp, self.alpha, self.beta, self.gamma + ) + result *= KroneckerDelta(j, jp) + return result + + def _represent_base(self, basis, **options): + j = sympify(options.get('j', S.Half)) + # TODO: move evaluation up to represent function/implement elsewhere + evaluate = sympify(options.get('doit')) + size, mvals = m_values(j) + result = zeros(size, size) + for p in range(size): + for q in range(size): + me = self.matrix_element(j, mvals[p], j, mvals[q]) + if evaluate: + result[p, q] = me.doit() + else: + result[p, q] = me + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + def _apply_operator_uncoupled(self, state, ket, *, dummy=True, **options): + a = self.alpha + b = self.beta + g = self.gamma + j = ket.j + m = ket.m + if j.is_number: + s = [] + size = m_values(j) + sz = size[1] + for mp in sz: + r = Rotation.D(j, m, mp, a, b, g) + z = r.doit() + s.append(z*state(j, mp)) + return Add(*s) + else: + if dummy: + mp = Dummy('mp') + else: + mp = symbols('mp') + return Sum(Rotation.D(j, m, mp, a, b, g)*state(j, mp), (mp, -j, j)) + + def _apply_operator_JxKet(self, ket, **options): + return self._apply_operator_uncoupled(JxKet, ket, **options) + + def _apply_operator_JyKet(self, ket, **options): + return self._apply_operator_uncoupled(JyKet, ket, **options) + + def _apply_operator_JzKet(self, ket, **options): + return self._apply_operator_uncoupled(JzKet, ket, **options) + + def _apply_operator_coupled(self, state, ket, *, dummy=True, **options): + a = self.alpha + b = self.beta + g = self.gamma + j = ket.j + m = ket.m + jn = ket.jn + coupling = ket.coupling + if j.is_number: + s = [] + size = m_values(j) + sz = size[1] + for mp in sz: + r = Rotation.D(j, m, mp, a, b, g) + z = r.doit() + s.append(z*state(j, mp, jn, coupling)) + return Add(*s) + else: + if dummy: + mp = Dummy('mp') + else: + mp = symbols('mp') + return Sum(Rotation.D(j, m, mp, a, b, g)*state( + j, mp, jn, coupling), (mp, -j, j)) + + def _apply_operator_JxKetCoupled(self, ket, **options): + return self._apply_operator_coupled(JxKetCoupled, ket, **options) + + def _apply_operator_JyKetCoupled(self, ket, **options): + return self._apply_operator_coupled(JyKetCoupled, ket, **options) + + def _apply_operator_JzKetCoupled(self, ket, **options): + return self._apply_operator_coupled(JzKetCoupled, ket, **options) + +class WignerD(Expr): + r"""Wigner-D function + + The Wigner D-function gives the matrix elements of the rotation + operator in the jm-representation. For the Euler angles `\alpha`, + `\beta`, `\gamma`, the D-function is defined such that: + + .. math :: + = \delta_{jj'} D(j, m, m', \alpha, \beta, \gamma) + + Where the rotation operator is as defined by the Rotation class [1]_. + + The Wigner D-function defined in this way gives: + + .. math :: + D(j, m, m', \alpha, \beta, \gamma) = e^{-i m \alpha} d(j, m, m', \beta) e^{-i m' \gamma} + + Where d is the Wigner small-d function, which is given by Rotation.d. + + The Wigner small-d function gives the component of the Wigner + D-function that is determined by the second Euler angle. That is the + Wigner D-function is: + + .. math :: + D(j, m, m', \alpha, \beta, \gamma) = e^{-i m \alpha} d(j, m, m', \beta) e^{-i m' \gamma} + + Where d is the small-d function. The Wigner D-function is given by + Rotation.D. + + Note that to evaluate the D-function, the j, m and mp parameters must + be integer or half integer numbers. + + Parameters + ========== + + j : Number + Total angular momentum + m : Number + Eigenvalue of angular momentum along axis after rotation + mp : Number + Eigenvalue of angular momentum along rotated axis + alpha : Number, Symbol + First Euler angle of rotation + beta : Number, Symbol + Second Euler angle of rotation + gamma : Number, Symbol + Third Euler angle of rotation + + Examples + ======== + + Evaluate the Wigner-D matrix elements of a simple rotation: + + >>> from sympy.physics.quantum.spin import Rotation + >>> from sympy import pi + >>> rot = Rotation.D(1, 1, 0, pi, pi/2, 0) + >>> rot + WignerD(1, 1, 0, pi, pi/2, 0) + >>> rot.doit() + sqrt(2)/2 + + Evaluate the Wigner-d matrix elements of a simple rotation + + >>> rot = Rotation.d(1, 1, 0, pi/2) + >>> rot + WignerD(1, 1, 0, 0, pi/2, 0) + >>> rot.doit() + -sqrt(2)/2 + + See Also + ======== + + Rotation: Rotation operator + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + """ + + is_commutative = True + + def __new__(cls, *args, **hints): + if not len(args) == 6: + raise ValueError('6 parameters expected, got %s' % args) + args = sympify(args) + evaluate = hints.get('evaluate', False) + if evaluate: + return Expr.__new__(cls, *args)._eval_wignerd() + return Expr.__new__(cls, *args) + + @property + def j(self): + return self.args[0] + + @property + def m(self): + return self.args[1] + + @property + def mp(self): + return self.args[2] + + @property + def alpha(self): + return self.args[3] + + @property + def beta(self): + return self.args[4] + + @property + def gamma(self): + return self.args[5] + + def _latex(self, printer, *args): + if self.alpha == 0 and self.gamma == 0: + return r'd^{%s}_{%s,%s}\left(%s\right)' % \ + ( + printer._print(self.j), printer._print( + self.m), printer._print(self.mp), + printer._print(self.beta) ) + return r'D^{%s}_{%s,%s}\left(%s,%s,%s\right)' % \ + ( + printer._print( + self.j), printer._print(self.m), printer._print(self.mp), + printer._print(self.alpha), printer._print(self.beta), printer._print(self.gamma) ) + + def _pretty(self, printer, *args): + top = printer._print(self.j) + + bot = printer._print(self.m) + bot = prettyForm(*bot.right(',')) + bot = prettyForm(*bot.right(printer._print(self.mp))) + + pad = max(top.width(), bot.width()) + top = prettyForm(*top.left(' ')) + bot = prettyForm(*bot.left(' ')) + if pad > top.width(): + top = prettyForm(*top.right(' '*(pad - top.width()))) + if pad > bot.width(): + bot = prettyForm(*bot.right(' '*(pad - bot.width()))) + if self.alpha == 0 and self.gamma == 0: + args = printer._print(self.beta) + s = stringPict('d' + ' '*pad) + else: + args = printer._print(self.alpha) + args = prettyForm(*args.right(',')) + args = prettyForm(*args.right(printer._print(self.beta))) + args = prettyForm(*args.right(',')) + args = prettyForm(*args.right(printer._print(self.gamma))) + + s = stringPict('D' + ' '*pad) + + args = prettyForm(*args.parens()) + s = prettyForm(*s.above(top)) + s = prettyForm(*s.below(bot)) + s = prettyForm(*s.right(args)) + return s + + def doit(self, **hints): + hints['evaluate'] = True + return WignerD(*self.args, **hints) + + def _eval_wignerd(self): + j = self.j + m = self.m + mp = self.mp + alpha = self.alpha + beta = self.beta + gamma = self.gamma + if alpha == 0 and beta == 0 and gamma == 0: + return KroneckerDelta(m, mp) + if not j.is_number: + raise ValueError( + 'j parameter must be numerical to evaluate, got %s' % j) + r = 0 + if beta == pi/2: + # Varshalovich Equation (5), Section 4.16, page 113, setting + # alpha=gamma=0. + for k in range(2*j + 1): + if k > j + mp or k > j - m or k < mp - m: + continue + r += (S.NegativeOne)**k*binomial(j + mp, k)*binomial(j - mp, k + m - mp) + r *= (S.NegativeOne)**(m - mp) / 2**j*sqrt(factorial(j + m) * + factorial(j - m) / (factorial(j + mp)*factorial(j - mp))) + else: + # Varshalovich Equation(5), Section 4.7.2, page 87, where we set + # beta1=beta2=pi/2, and we get alpha=gamma=pi/2 and beta=phi+pi, + # then we use the Eq. (1), Section 4.4. page 79, to simplify: + # d(j, m, mp, beta+pi) = (-1)**(j-mp)*d(j, m, -mp, beta) + # This happens to be almost the same as in Eq.(10), Section 4.16, + # except that we need to substitute -mp for mp. + size, mvals = m_values(j) + for mpp in mvals: + r += Rotation.d(j, m, mpp, pi/2).doit()*(cos(-mpp*beta) + I*sin(-mpp*beta))*\ + Rotation.d(j, mpp, -mp, pi/2).doit() + # Empirical normalization factor so results match Varshalovich + # Tables 4.3-4.12 + # Note that this exact normalization does not follow from the + # above equations + r = r*I**(2*j - m - mp)*(-1)**(2*m) + # Finally, simplify the whole expression + r = simplify(r) + r *= exp(-I*m*alpha)*exp(-I*mp*gamma) + return r + + +Jx = JxOp('J') +Jy = JyOp('J') +Jz = JzOp('J') +J2 = J2Op('J') +Jplus = JplusOp('J') +Jminus = JminusOp('J') + + +#----------------------------------------------------------------------------- +# Spin States +#----------------------------------------------------------------------------- + + +class SpinState(State): + """Base class for angular momentum states.""" + + _label_separator = ',' + + def __new__(cls, j, m): + j = sympify(j) + m = sympify(m) + if j.is_number: + if 2*j != int(2*j): + raise ValueError( + 'j must be integer or half-integer, got: %s' % j) + if j < 0: + raise ValueError('j must be >= 0, got: %s' % j) + if m.is_number: + if 2*m != int(2*m): + raise ValueError( + 'm must be integer or half-integer, got: %s' % m) + if j.is_number and m.is_number: + if abs(m) > j: + raise ValueError('Allowed values for m are -j <= m <= j, got j, m: %s, %s' % (j, m)) + if int(j - m) != j - m: + raise ValueError('Both j and m must be integer or half-integer, got j, m: %s, %s' % (j, m)) + return State.__new__(cls, j, m) + + @property + def j(self): + return self.label[0] + + @property + def m(self): + return self.label[1] + + @classmethod + def _eval_hilbert_space(cls, label): + return ComplexSpace(2*label[0] + 1) + + def _represent_base(self, **options): + j = self.j + m = self.m + alpha = sympify(options.get('alpha', 0)) + beta = sympify(options.get('beta', 0)) + gamma = sympify(options.get('gamma', 0)) + size, mvals = m_values(j) + result = zeros(size, 1) + # breaks finding angles on L930 + for p, mval in enumerate(mvals): + if m.is_number: + result[p, 0] = Rotation.D( + self.j, mval, self.m, alpha, beta, gamma).doit() + else: + result[p, 0] = Rotation.D(self.j, mval, + self.m, alpha, beta, gamma) + return result + + def _eval_rewrite_as_Jx(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jx, JxBra, **options) + return self._rewrite_basis(Jx, JxKet, **options) + + def _eval_rewrite_as_Jy(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jy, JyBra, **options) + return self._rewrite_basis(Jy, JyKet, **options) + + def _eval_rewrite_as_Jz(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jz, JzBra, **options) + return self._rewrite_basis(Jz, JzKet, **options) + + def _rewrite_basis(self, basis, evect, **options): + from sympy.physics.quantum.represent import represent + j = self.j + args = self.args[2:] + if j.is_number: + if isinstance(self, CoupledSpinState): + if j == int(j): + start = j**2 + else: + start = (2*j - 1)*(2*j + 1)/4 + else: + start = 0 + vect = represent(self, basis=basis, **options) + result = Add( + *[vect[start + i]*evect(j, j - i, *args) for i in range(2*j + 1)]) + if isinstance(self, CoupledSpinState) and options.get('coupled') is False: + return uncouple(result) + return result + else: + i = 0 + mi = symbols('mi') + # make sure not to introduce a symbol already in the state + while self.subs(mi, 0) != self: + i += 1 + mi = symbols('mi%d' % i) + break + # TODO: better way to get angles of rotation + if isinstance(self, CoupledSpinState): + test_args = (0, mi, (0, 0)) + else: + test_args = (0, mi) + if isinstance(self, Ket): + angles = represent( + self.__class__(*test_args), basis=basis)[0].args[3:6] + else: + angles = represent(self.__class__( + *test_args), basis=basis)[0].args[0].args[3:6] + if angles == (0, 0, 0): + return self + else: + state = evect(j, mi, *args) + lt = Rotation.D(j, mi, self.m, *angles) + return Sum(lt*state, (mi, -j, j)) + + def _eval_innerproduct_JxBra(self, bra, **hints): + result = KroneckerDelta(self.j, bra.j) + if bra.dual_class() is not self.__class__: + result *= self._represent_JxOp(None)[bra.j - bra.m] + else: + result *= KroneckerDelta( + self.j, bra.j)*KroneckerDelta(self.m, bra.m) + return result + + def _eval_innerproduct_JyBra(self, bra, **hints): + result = KroneckerDelta(self.j, bra.j) + if bra.dual_class() is not self.__class__: + result *= self._represent_JyOp(None)[bra.j - bra.m] + else: + result *= KroneckerDelta( + self.j, bra.j)*KroneckerDelta(self.m, bra.m) + return result + + def _eval_innerproduct_JzBra(self, bra, **hints): + result = KroneckerDelta(self.j, bra.j) + if bra.dual_class() is not self.__class__: + result *= self._represent_JzOp(None)[bra.j - bra.m] + else: + result *= KroneckerDelta( + self.j, bra.j)*KroneckerDelta(self.m, bra.m) + return result + + def _eval_trace(self, bra, **hints): + + # One way to implement this method is to assume the basis set k is + # passed. + # Then we can apply the discrete form of Trace formula here + # Tr(|i> + #then we do qapply() on each each inner product and sum over them. + + # OR + + # Inner product of |i>>> from sympy.physics.quantum.spin import JzKet, JxKet + >>> from sympy import symbols + >>> JzKet(1, 0) + |1,0> + >>> j, m = symbols('j m') + >>> JzKet(j, m) + |j,m> + + Rewriting the JzKet in terms of eigenkets of the Jx operator: + Note: that the resulting eigenstates are JxKet's + + >>> JzKet(1,1).rewrite("Jx") + |1,-1>/2 - sqrt(2)*|1,0>/2 + |1,1>/2 + + Get the vector representation of a state in terms of the basis elements + of the Jx operator: + + >>> from sympy.physics.quantum.represent import represent + >>> from sympy.physics.quantum.spin import Jx, Jz + >>> represent(JzKet(1,-1), basis=Jx) + Matrix([ + [ 1/2], + [sqrt(2)/2], + [ 1/2]]) + + Apply innerproducts between states: + + >>> from sympy.physics.quantum.innerproduct import InnerProduct + >>> from sympy.physics.quantum.spin import JxBra + >>> i = InnerProduct(JxBra(1,1), JzKet(1,1)) + >>> i + <1,1|1,1> + >>> i.doit() + 1/2 + + *Uncoupled States:* + + Define an uncoupled state as a TensorProduct between two Jz eigenkets: + + >>> from sympy.physics.quantum.tensorproduct import TensorProduct + >>> j1,m1,j2,m2 = symbols('j1 m1 j2 m2') + >>> TensorProduct(JzKet(1,0), JzKet(1,1)) + |1,0>x|1,1> + >>> TensorProduct(JzKet(j1,m1), JzKet(j2,m2)) + |j1,m1>x|j2,m2> + + A TensorProduct can be rewritten, in which case the eigenstates that make + up the tensor product is rewritten to the new basis: + + >>> TensorProduct(JzKet(1,1),JxKet(1,1)).rewrite('Jz') + |1,1>x|1,-1>/2 + sqrt(2)*|1,1>x|1,0>/2 + |1,1>x|1,1>/2 + + The represent method for TensorProduct's gives the vector representation of + the state. Note that the state in the product basis is the equivalent of the + tensor product of the vector representation of the component eigenstates: + + >>> represent(TensorProduct(JzKet(1,0),JzKet(1,1))) + Matrix([ + [0], + [0], + [0], + [1], + [0], + [0], + [0], + [0], + [0]]) + >>> represent(TensorProduct(JzKet(1,1),JxKet(1,1)), basis=Jz) + Matrix([ + [ 1/2], + [sqrt(2)/2], + [ 1/2], + [ 0], + [ 0], + [ 0], + [ 0], + [ 0], + [ 0]]) + + See Also + ======== + + JzKetCoupled: Coupled eigenstates + sympy.physics.quantum.tensorproduct.TensorProduct: Used to specify uncoupled states + uncouple: Uncouples states given coupling parameters + couple: Couples uncoupled states + + """ + + @classmethod + def dual_class(self): + return JzBra + + @classmethod + def coupled_class(self): + return JzKetCoupled + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JxOp(self, basis, **options): + return self._represent_base(beta=pi*Rational(3, 2), **options) + + def _represent_JyOp(self, basis, **options): + return self._represent_base(alpha=pi*Rational(3, 2), beta=pi/2, gamma=pi/2, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(**options) + + +class JzBra(SpinState, Bra): + """Eigenbra of Jz. + + See the JzKet for the usage of spin eigenstates. + + See Also + ======== + + JzKet: Usage of spin states + + """ + + @classmethod + def dual_class(self): + return JzKet + + @classmethod + def coupled_class(self): + return JzBraCoupled + + +# Method used primarily to create coupled_n and coupled_jn by __new__ in +# CoupledSpinState +# This same method is also used by the uncouple method, and is separated from +# the CoupledSpinState class to maintain consistency in defining coupling +def _build_coupled(jcoupling, length): + n_list = [ [n + 1] for n in range(length) ] + coupled_jn = [] + coupled_n = [] + for n1, n2, j_new in jcoupling: + coupled_jn.append(j_new) + coupled_n.append( (n_list[n1 - 1], n_list[n2 - 1]) ) + n_sort = sorted(n_list[n1 - 1] + n_list[n2 - 1]) + n_list[n_sort[0] - 1] = n_sort + return coupled_n, coupled_jn + + +class CoupledSpinState(SpinState): + """Base class for coupled angular momentum states.""" + + def __new__(cls, j, m, jn, *jcoupling): + # Check j and m values using SpinState + SpinState(j, m) + # Build and check coupling scheme from arguments + if len(jcoupling) == 0: + # Use default coupling scheme + jcoupling = [] + for n in range(2, len(jn)): + jcoupling.append( (1, n, Add(*[jn[i] for i in range(n)])) ) + jcoupling.append( (1, len(jn), j) ) + elif len(jcoupling) == 1: + # Use specified coupling scheme + jcoupling = jcoupling[0] + else: + raise TypeError("CoupledSpinState only takes 3 or 4 arguments, got: %s" % (len(jcoupling) + 3) ) + # Check arguments have correct form + if not isinstance(jn, (list, tuple, Tuple)): + raise TypeError('jn must be Tuple, list or tuple, got %s' % + jn.__class__.__name__) + if not isinstance(jcoupling, (list, tuple, Tuple)): + raise TypeError('jcoupling must be Tuple, list or tuple, got %s' % + jcoupling.__class__.__name__) + if not all(isinstance(term, (list, tuple, Tuple)) for term in jcoupling): + raise TypeError( + 'All elements of jcoupling must be list, tuple or Tuple') + if not len(jn) - 1 == len(jcoupling): + raise ValueError('jcoupling must have length of %d, got %d' % + (len(jn) - 1, len(jcoupling))) + if not all(len(x) == 3 for x in jcoupling): + raise ValueError('All elements of jcoupling must have length 3') + # Build sympified args + j = sympify(j) + m = sympify(m) + jn = Tuple( *[sympify(ji) for ji in jn] ) + jcoupling = Tuple( *[Tuple(sympify( + n1), sympify(n2), sympify(ji)) for (n1, n2, ji) in jcoupling] ) + # Check values in coupling scheme give physical state + if any(2*ji != int(2*ji) for ji in jn if ji.is_number): + raise ValueError('All elements of jn must be integer or half-integer, got: %s' % jn) + if any(n1 != int(n1) or n2 != int(n2) for (n1, n2, _) in jcoupling): + raise ValueError('Indices in jcoupling must be integers') + if any(n1 < 1 or n2 < 1 or n1 > len(jn) or n2 > len(jn) for (n1, n2, _) in jcoupling): + raise ValueError('Indices must be between 1 and the number of coupled spin spaces') + if any(2*ji != int(2*ji) for (_, _, ji) in jcoupling if ji.is_number): + raise ValueError('All coupled j values in coupling scheme must be integer or half-integer') + coupled_n, coupled_jn = _build_coupled(jcoupling, len(jn)) + jvals = list(jn) + for n, (n1, n2) in enumerate(coupled_n): + j1 = jvals[min(n1) - 1] + j2 = jvals[min(n2) - 1] + j3 = coupled_jn[n] + if sympify(j1).is_number and sympify(j2).is_number and sympify(j3).is_number: + if j1 + j2 < j3: + raise ValueError('All couplings must have j1+j2 >= j3, ' + 'in coupling number %d got j1,j2,j3: %d,%d,%d' % (n + 1, j1, j2, j3)) + if abs(j1 - j2) > j3: + raise ValueError("All couplings must have |j1+j2| <= j3, " + "in coupling number %d got j1,j2,j3: %d,%d,%d" % (n + 1, j1, j2, j3)) + if int_valued(j1 + j2): + pass + jvals[min(n1 + n2) - 1] = j3 + if len(jcoupling) > 0 and jcoupling[-1][2] != j: + raise ValueError('Last j value coupled together must be the final j of the state') + # Return state + return State.__new__(cls, j, m, jn, jcoupling) + + def _print_label(self, printer, *args): + label = [printer._print(self.j), printer._print(self.m)] + for i, ji in enumerate(self.jn, start=1): + label.append('j%d=%s' % ( + i, printer._print(ji) + )) + for jn, (n1, n2) in zip(self.coupled_jn[:-1], self.coupled_n[:-1]): + label.append('j(%s)=%s' % ( + ','.join(str(i) for i in sorted(n1 + n2)), printer._print(jn) + )) + return ','.join(label) + + def _print_label_pretty(self, printer, *args): + label = [self.j, self.m] + for i, ji in enumerate(self.jn, start=1): + symb = 'j%d' % i + symb = pretty_symbol(symb) + symb = prettyForm(symb + '=') + item = prettyForm(*symb.right(printer._print(ji))) + label.append(item) + for jn, (n1, n2) in zip(self.coupled_jn[:-1], self.coupled_n[:-1]): + n = ','.join(pretty_symbol("j%d" % i)[-1] for i in sorted(n1 + n2)) + symb = prettyForm('j' + n + '=') + item = prettyForm(*symb.right(printer._print(jn))) + label.append(item) + return self._print_sequence_pretty( + label, self._label_separator, printer, *args + ) + + def _print_label_latex(self, printer, *args): + label = [ + printer._print(self.j, *args), + printer._print(self.m, *args) + ] + for i, ji in enumerate(self.jn, start=1): + label.append('j_{%d}=%s' % (i, printer._print(ji, *args)) ) + for jn, (n1, n2) in zip(self.coupled_jn[:-1], self.coupled_n[:-1]): + n = ','.join(str(i) for i in sorted(n1 + n2)) + label.append('j_{%s}=%s' % (n, printer._print(jn, *args)) ) + return self._label_separator.join(label) + + @property + def jn(self): + return self.label[2] + + @property + def coupling(self): + return self.label[3] + + @property + def coupled_jn(self): + return _build_coupled(self.label[3], len(self.label[2]))[1] + + @property + def coupled_n(self): + return _build_coupled(self.label[3], len(self.label[2]))[0] + + @classmethod + def _eval_hilbert_space(cls, label): + j = Add(*label[2]) + if j.is_number: + return DirectSumHilbertSpace(*[ ComplexSpace(x) for x in range(int(2*j + 1), 0, -2) ]) + else: + # TODO: Need hilbert space fix, see issue 5732 + # Desired behavior: + #ji = symbols('ji') + #ret = Sum(ComplexSpace(2*ji + 1), (ji, 0, j)) + # Temporary fix: + return ComplexSpace(2*j + 1) + + def _represent_coupled_base(self, **options): + evect = self.uncoupled_class() + if not self.j.is_number: + raise ValueError( + 'State must not have symbolic j value to represent') + if not self.hilbert_space.dimension.is_number: + raise ValueError( + 'State must not have symbolic j values to represent') + result = zeros(self.hilbert_space.dimension, 1) + if self.j == int(self.j): + start = self.j**2 + else: + start = (2*self.j - 1)*(1 + 2*self.j)/4 + result[start:start + 2*self.j + 1, 0] = evect( + self.j, self.m)._represent_base(**options) + return result + + def _eval_rewrite_as_Jx(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jx, JxBraCoupled, **options) + return self._rewrite_basis(Jx, JxKetCoupled, **options) + + def _eval_rewrite_as_Jy(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jy, JyBraCoupled, **options) + return self._rewrite_basis(Jy, JyKetCoupled, **options) + + def _eval_rewrite_as_Jz(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jz, JzBraCoupled, **options) + return self._rewrite_basis(Jz, JzKetCoupled, **options) + + +class JxKetCoupled(CoupledSpinState, Ket): + """Coupled eigenket of Jx. + + See JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JxBraCoupled + + @classmethod + def uncoupled_class(self): + return JxKet + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JxOp(self, basis, **options): + return self._represent_coupled_base(**options) + + def _represent_JyOp(self, basis, **options): + return self._represent_coupled_base(alpha=pi*Rational(3, 2), **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_coupled_base(beta=pi/2, **options) + + +class JxBraCoupled(CoupledSpinState, Bra): + """Coupled eigenbra of Jx. + + See JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JxKetCoupled + + @classmethod + def uncoupled_class(self): + return JxBra + + +class JyKetCoupled(CoupledSpinState, Ket): + """Coupled eigenket of Jy. + + See JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JyBraCoupled + + @classmethod + def uncoupled_class(self): + return JyKet + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JxOp(self, basis, **options): + return self._represent_coupled_base(gamma=pi/2, **options) + + def _represent_JyOp(self, basis, **options): + return self._represent_coupled_base(**options) + + def _represent_JzOp(self, basis, **options): + return self._represent_coupled_base(alpha=pi*Rational(3, 2), beta=-pi/2, gamma=pi/2, **options) + + +class JyBraCoupled(CoupledSpinState, Bra): + """Coupled eigenbra of Jy. + + See JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JyKetCoupled + + @classmethod + def uncoupled_class(self): + return JyBra + + +class JzKetCoupled(CoupledSpinState, Ket): + r"""Coupled eigenket of Jz + + Spin state that is an eigenket of Jz which represents the coupling of + separate spin spaces. + + The arguments for creating instances of JzKetCoupled are ``j``, ``m``, + ``jn`` and an optional ``jcoupling`` argument. The ``j`` and ``m`` options + are the total angular momentum quantum numbers, as used for normal states + (e.g. JzKet). + + The other required parameter in ``jn``, which is a tuple defining the `j_n` + angular momentum quantum numbers of the product spaces. So for example, if + a state represented the coupling of the product basis state + `\left|j_1,m_1\right\rangle\times\left|j_2,m_2\right\rangle`, the ``jn`` + for this state would be ``(j1,j2)``. + + The final option is ``jcoupling``, which is used to define how the spaces + specified by ``jn`` are coupled, which includes both the order these spaces + are coupled together and the quantum numbers that arise from these + couplings. The ``jcoupling`` parameter itself is a list of lists, such that + each of the sublists defines a single coupling between the spin spaces. If + there are N coupled angular momentum spaces, that is ``jn`` has N elements, + then there must be N-1 sublists. Each of these sublists making up the + ``jcoupling`` parameter have length 3. The first two elements are the + indices of the product spaces that are considered to be coupled together. + For example, if we want to couple `j_1` and `j_4`, the indices would be 1 + and 4. If a state has already been coupled, it is referenced by the + smallest index that is coupled, so if `j_2` and `j_4` has already been + coupled to some `j_{24}`, then this value can be coupled by referencing it + with index 2. The final element of the sublist is the quantum number of the + coupled state. So putting everything together, into a valid sublist for + ``jcoupling``, if `j_1` and `j_2` are coupled to an angular momentum space + with quantum number `j_{12}` with the value ``j12``, the sublist would be + ``(1,2,j12)``, N-1 of these sublists are used in the list for + ``jcoupling``. + + Note the ``jcoupling`` parameter is optional, if it is not specified, the + default coupling is taken. This default value is to coupled the spaces in + order and take the quantum number of the coupling to be the maximum value. + For example, if the spin spaces are `j_1`, `j_2`, `j_3`, `j_4`, then the + default coupling couples `j_1` and `j_2` to `j_{12}=j_1+j_2`, then, + `j_{12}` and `j_3` are coupled to `j_{123}=j_{12}+j_3`, and finally + `j_{123}` and `j_4` to `j=j_{123}+j_4`. The jcoupling value that would + correspond to this is: + + ``((1,2,j1+j2),(1,3,j1+j2+j3))`` + + Parameters + ========== + + args : tuple + The arguments that must be passed are ``j``, ``m``, ``jn``, and + ``jcoupling``. The ``j`` value is the total angular momentum. The ``m`` + value is the eigenvalue of the Jz spin operator. The ``jn`` list are + the j values of argular momentum spaces coupled together. The + ``jcoupling`` parameter is an optional parameter defining how the spaces + are coupled together. See the above description for how these coupling + parameters are defined. + + Examples + ======== + + Defining simple spin states, both numerical and symbolic: + + >>> from sympy.physics.quantum.spin import JzKetCoupled + >>> from sympy import symbols + >>> JzKetCoupled(1, 0, (1, 1)) + |1,0,j1=1,j2=1> + >>> j, m, j1, j2 = symbols('j m j1 j2') + >>> JzKetCoupled(j, m, (j1, j2)) + |j,m,j1=j1,j2=j2> + + Defining coupled spin states for more than 2 coupled spaces with various + coupling parameters: + + >>> JzKetCoupled(2, 1, (1, 1, 1)) + |2,1,j1=1,j2=1,j3=1,j(1,2)=2> + >>> JzKetCoupled(2, 1, (1, 1, 1), ((1,2,2),(1,3,2)) ) + |2,1,j1=1,j2=1,j3=1,j(1,2)=2> + >>> JzKetCoupled(2, 1, (1, 1, 1), ((2,3,1),(1,2,2)) ) + |2,1,j1=1,j2=1,j3=1,j(2,3)=1> + + Rewriting the JzKetCoupled in terms of eigenkets of the Jx operator: + Note: that the resulting eigenstates are JxKetCoupled + + >>> JzKetCoupled(1,1,(1,1)).rewrite("Jx") + |1,-1,j1=1,j2=1>/2 - sqrt(2)*|1,0,j1=1,j2=1>/2 + |1,1,j1=1,j2=1>/2 + + The rewrite method can be used to convert a coupled state to an uncoupled + state. This is done by passing coupled=False to the rewrite function: + + >>> JzKetCoupled(1, 0, (1, 1)).rewrite('Jz', coupled=False) + -sqrt(2)*|1,-1>x|1,1>/2 + sqrt(2)*|1,1>x|1,-1>/2 + + Get the vector representation of a state in terms of the basis elements + of the Jx operator: + + >>> from sympy.physics.quantum.represent import represent + >>> from sympy.physics.quantum.spin import Jx + >>> from sympy import S + >>> represent(JzKetCoupled(1,-1,(S(1)/2,S(1)/2)), basis=Jx) + Matrix([ + [ 0], + [ 1/2], + [sqrt(2)/2], + [ 1/2]]) + + See Also + ======== + + JzKet: Normal spin eigenstates + uncouple: Uncoupling of coupling spin states + couple: Coupling of uncoupled spin states + + """ + + @classmethod + def dual_class(self): + return JzBraCoupled + + @classmethod + def uncoupled_class(self): + return JzKet + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JxOp(self, basis, **options): + return self._represent_coupled_base(beta=pi*Rational(3, 2), **options) + + def _represent_JyOp(self, basis, **options): + return self._represent_coupled_base(alpha=pi*Rational(3, 2), beta=pi/2, gamma=pi/2, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_coupled_base(**options) + + +class JzBraCoupled(CoupledSpinState, Bra): + """Coupled eigenbra of Jz. + + See the JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JzKetCoupled + + @classmethod + def uncoupled_class(self): + return JzBra + +#----------------------------------------------------------------------------- +# Coupling/uncoupling +#----------------------------------------------------------------------------- + + +def couple(expr, jcoupling_list=None): + """ Couple a tensor product of spin states + + This function can be used to couple an uncoupled tensor product of spin + states. All of the eigenstates to be coupled must be of the same class. It + will return a linear combination of eigenstates that are subclasses of + CoupledSpinState determined by Clebsch-Gordan angular momentum coupling + coefficients. + + Parameters + ========== + + expr : Expr + An expression involving TensorProducts of spin states to be coupled. + Each state must be a subclass of SpinState and they all must be the + same class. + + jcoupling_list : list or tuple + Elements of this list are sub-lists of length 2 specifying the order of + the coupling of the spin spaces. The length of this must be N-1, where N + is the number of states in the tensor product to be coupled. The + elements of this sublist are the same as the first two elements of each + sublist in the ``jcoupling`` parameter defined for JzKetCoupled. If this + parameter is not specified, the default value is taken, which couples + the first and second product basis spaces, then couples this new coupled + space to the third product space, etc + + Examples + ======== + + Couple a tensor product of numerical states for two spaces: + + >>> from sympy.physics.quantum.spin import JzKet, couple + >>> from sympy.physics.quantum.tensorproduct import TensorProduct + >>> couple(TensorProduct(JzKet(1,0), JzKet(1,1))) + -sqrt(2)*|1,1,j1=1,j2=1>/2 + sqrt(2)*|2,1,j1=1,j2=1>/2 + + + Numerical coupling of three spaces using the default coupling method, i.e. + first and second spaces couple, then this couples to the third space: + + >>> couple(TensorProduct(JzKet(1,1), JzKet(1,1), JzKet(1,0))) + sqrt(6)*|2,2,j1=1,j2=1,j3=1,j(1,2)=2>/3 + sqrt(3)*|3,2,j1=1,j2=1,j3=1,j(1,2)=2>/3 + + Perform this same coupling, but we define the coupling to first couple + the first and third spaces: + + >>> couple(TensorProduct(JzKet(1,1), JzKet(1,1), JzKet(1,0)), ((1,3),(1,2)) ) + sqrt(2)*|2,2,j1=1,j2=1,j3=1,j(1,3)=1>/2 - sqrt(6)*|2,2,j1=1,j2=1,j3=1,j(1,3)=2>/6 + sqrt(3)*|3,2,j1=1,j2=1,j3=1,j(1,3)=2>/3 + + Couple a tensor product of symbolic states: + + >>> from sympy import symbols + >>> j1,m1,j2,m2 = symbols('j1 m1 j2 m2') + >>> couple(TensorProduct(JzKet(j1,m1), JzKet(j2,m2))) + Sum(CG(j1, m1, j2, m2, j, m1 + m2)*|j,m1 + m2,j1=j1,j2=j2>, (j, m1 + m2, j1 + j2)) + + """ + a = expr.atoms(TensorProduct) + for tp in a: + # Allow other tensor products to be in expression + if not all(isinstance(state, SpinState) for state in tp.args): + continue + # If tensor product has all spin states, raise error for invalid tensor product state + if not all(state.__class__ is tp.args[0].__class__ for state in tp.args): + raise TypeError('All states must be the same basis') + expr = expr.subs(tp, _couple(tp, jcoupling_list)) + return expr + + +def _couple(tp, jcoupling_list): + states = tp.args + coupled_evect = states[0].coupled_class() + + # Define default coupling if none is specified + if jcoupling_list is None: + jcoupling_list = [] + for n in range(1, len(states)): + jcoupling_list.append( (1, n + 1) ) + + # Check jcoupling_list valid + if not len(jcoupling_list) == len(states) - 1: + raise TypeError('jcoupling_list must be length %d, got %d' % + (len(states) - 1, len(jcoupling_list))) + if not all( len(coupling) == 2 for coupling in jcoupling_list): + raise ValueError('Each coupling must define 2 spaces') + if any(n1 == n2 for n1, n2 in jcoupling_list): + raise ValueError('Spin spaces cannot couple to themselves') + if all(sympify(n1).is_number and sympify(n2).is_number for n1, n2 in jcoupling_list): + j_test = [0]*len(states) + for n1, n2 in jcoupling_list: + if j_test[n1 - 1] == -1 or j_test[n2 - 1] == -1: + raise ValueError('Spaces coupling j_n\'s are referenced by smallest n value') + j_test[max(n1, n2) - 1] = -1 + + # j values of states to be coupled together + jn = [state.j for state in states] + mn = [state.m for state in states] + + # Create coupling_list, which defines all the couplings between all + # the spaces from jcoupling_list + coupling_list = [] + n_list = [ [i + 1] for i in range(len(states)) ] + for j_coupling in jcoupling_list: + # Least n for all j_n which is coupled as first and second spaces + n1, n2 = j_coupling + # List of all n's coupled in first and second spaces + j1_n = list(n_list[n1 - 1]) + j2_n = list(n_list[n2 - 1]) + coupling_list.append( (j1_n, j2_n) ) + # Set new j_n to be coupling of all j_n in both first and second spaces + n_list[ min(n1, n2) - 1 ] = sorted(j1_n + j2_n) + + if all(state.j.is_number and state.m.is_number for state in states): + # Numerical coupling + # Iterate over difference between maximum possible j value of each coupling and the actual value + diff_max = [ Add( *[ jn[n - 1] - mn[n - 1] for n in coupling[0] + + coupling[1] ] ) for coupling in coupling_list ] + result = [] + for diff in range(diff_max[-1] + 1): + # Determine available configurations + n = len(coupling_list) + tot = binomial(diff + n - 1, diff) + + for config_num in range(tot): + diff_list = _confignum_to_difflist(config_num, diff, n) + + # Skip the configuration if non-physical + # This is a lazy check for physical states given the loose restrictions of diff_max + if any(d > m for d, m in zip(diff_list, diff_max)): + continue + + # Determine term + cg_terms = [] + coupled_j = list(jn) + jcoupling = [] + for (j1_n, j2_n), coupling_diff in zip(coupling_list, diff_list): + j1 = coupled_j[ min(j1_n) - 1 ] + j2 = coupled_j[ min(j2_n) - 1 ] + j3 = j1 + j2 - coupling_diff + coupled_j[ min(j1_n + j2_n) - 1 ] = j3 + m1 = Add( *[ mn[x - 1] for x in j1_n] ) + m2 = Add( *[ mn[x - 1] for x in j2_n] ) + m3 = m1 + m2 + cg_terms.append( (j1, m1, j2, m2, j3, m3) ) + jcoupling.append( (min(j1_n), min(j2_n), j3) ) + # Better checks that state is physical + if any(abs(term[5]) > term[4] for term in cg_terms): + continue + if any(term[0] + term[2] < term[4] for term in cg_terms): + continue + if any(abs(term[0] - term[2]) > term[4] for term in cg_terms): + continue + coeff = Mul( *[ CG(*term).doit() for term in cg_terms] ) + state = coupled_evect(j3, m3, jn, jcoupling) + result.append(coeff*state) + return Add(*result) + else: + # Symbolic coupling + cg_terms = [] + jcoupling = [] + sum_terms = [] + coupled_j = list(jn) + for j1_n, j2_n in coupling_list: + j1 = coupled_j[ min(j1_n) - 1 ] + j2 = coupled_j[ min(j2_n) - 1 ] + if len(j1_n + j2_n) == len(states): + j3 = symbols('j') + else: + j3_name = 'j' + ''.join(["%s" % n for n in j1_n + j2_n]) + j3 = symbols(j3_name) + coupled_j[ min(j1_n + j2_n) - 1 ] = j3 + m1 = Add( *[ mn[x - 1] for x in j1_n] ) + m2 = Add( *[ mn[x - 1] for x in j2_n] ) + m3 = m1 + m2 + cg_terms.append( (j1, m1, j2, m2, j3, m3) ) + jcoupling.append( (min(j1_n), min(j2_n), j3) ) + sum_terms.append((j3, m3, j1 + j2)) + coeff = Mul( *[ CG(*term) for term in cg_terms] ) + state = coupled_evect(j3, m3, jn, jcoupling) + return Sum(coeff*state, *sum_terms) + + +def uncouple(expr, jn=None, jcoupling_list=None): + """ Uncouple a coupled spin state + + Gives the uncoupled representation of a coupled spin state. Arguments must + be either a spin state that is a subclass of CoupledSpinState or a spin + state that is a subclass of SpinState and an array giving the j values + of the spaces that are to be coupled + + Parameters + ========== + + expr : Expr + The expression containing states that are to be coupled. If the states + are a subclass of SpinState, the ``jn`` and ``jcoupling`` parameters + must be defined. If the states are a subclass of CoupledSpinState, + ``jn`` and ``jcoupling`` will be taken from the state. + + jn : list or tuple + The list of the j-values that are coupled. If state is a + CoupledSpinState, this parameter is ignored. This must be defined if + state is not a subclass of CoupledSpinState. The syntax of this + parameter is the same as the ``jn`` parameter of JzKetCoupled. + + jcoupling_list : list or tuple + The list defining how the j-values are coupled together. If state is a + CoupledSpinState, this parameter is ignored. This must be defined if + state is not a subclass of CoupledSpinState. The syntax of this + parameter is the same as the ``jcoupling`` parameter of JzKetCoupled. + + Examples + ======== + + Uncouple a numerical state using a CoupledSpinState state: + + >>> from sympy.physics.quantum.spin import JzKetCoupled, uncouple + >>> from sympy import S + >>> uncouple(JzKetCoupled(1, 0, (S(1)/2, S(1)/2))) + sqrt(2)*|1/2,-1/2>x|1/2,1/2>/2 + sqrt(2)*|1/2,1/2>x|1/2,-1/2>/2 + + Perform the same calculation using a SpinState state: + + >>> from sympy.physics.quantum.spin import JzKet + >>> uncouple(JzKet(1, 0), (S(1)/2, S(1)/2)) + sqrt(2)*|1/2,-1/2>x|1/2,1/2>/2 + sqrt(2)*|1/2,1/2>x|1/2,-1/2>/2 + + Uncouple a numerical state of three coupled spaces using a CoupledSpinState state: + + >>> uncouple(JzKetCoupled(1, 1, (1, 1, 1), ((1,3,1),(1,2,1)) )) + |1,-1>x|1,1>x|1,1>/2 - |1,0>x|1,0>x|1,1>/2 + |1,1>x|1,0>x|1,0>/2 - |1,1>x|1,1>x|1,-1>/2 + + Perform the same calculation using a SpinState state: + + >>> uncouple(JzKet(1, 1), (1, 1, 1), ((1,3,1),(1,2,1)) ) + |1,-1>x|1,1>x|1,1>/2 - |1,0>x|1,0>x|1,1>/2 + |1,1>x|1,0>x|1,0>/2 - |1,1>x|1,1>x|1,-1>/2 + + Uncouple a symbolic state using a CoupledSpinState state: + + >>> from sympy import symbols + >>> j,m,j1,j2 = symbols('j m j1 j2') + >>> uncouple(JzKetCoupled(j, m, (j1, j2))) + Sum(CG(j1, m1, j2, m2, j, m)*|j1,m1>x|j2,m2>, (m1, -j1, j1), (m2, -j2, j2)) + + Perform the same calculation using a SpinState state + + >>> uncouple(JzKet(j, m), (j1, j2)) + Sum(CG(j1, m1, j2, m2, j, m)*|j1,m1>x|j2,m2>, (m1, -j1, j1), (m2, -j2, j2)) + + """ + a = expr.atoms(SpinState) + for state in a: + expr = expr.subs(state, _uncouple(state, jn, jcoupling_list)) + return expr + + +def _uncouple(state, jn, jcoupling_list): + if isinstance(state, CoupledSpinState): + jn = state.jn + coupled_n = state.coupled_n + coupled_jn = state.coupled_jn + evect = state.uncoupled_class() + elif isinstance(state, SpinState): + if jn is None: + raise ValueError("Must specify j-values for coupled state") + if not isinstance(jn, (list, tuple)): + raise TypeError("jn must be list or tuple") + if jcoupling_list is None: + # Use default + jcoupling_list = [] + for i in range(1, len(jn)): + jcoupling_list.append( + (1, 1 + i, Add(*[jn[j] for j in range(i + 1)])) ) + if not isinstance(jcoupling_list, (list, tuple)): + raise TypeError("jcoupling must be a list or tuple") + if not len(jcoupling_list) == len(jn) - 1: + raise ValueError("Must specify 2 fewer coupling terms than the number of j values") + coupled_n, coupled_jn = _build_coupled(jcoupling_list, len(jn)) + evect = state.__class__ + else: + raise TypeError("state must be a spin state") + j = state.j + m = state.m + coupling_list = [] + j_list = list(jn) + + # Create coupling, which defines all the couplings between all the spaces + for j3, (n1, n2) in zip(coupled_jn, coupled_n): + # j's which are coupled as first and second spaces + j1 = j_list[n1[0] - 1] + j2 = j_list[n2[0] - 1] + # Build coupling list + coupling_list.append( (n1, n2, j1, j2, j3) ) + # Set new value in j_list + j_list[min(n1 + n2) - 1] = j3 + + if j.is_number and m.is_number: + diff_max = [ 2*x for x in jn ] + diff = Add(*jn) - m + + n = len(jn) + tot = binomial(diff + n - 1, diff) + + result = [] + for config_num in range(tot): + diff_list = _confignum_to_difflist(config_num, diff, n) + if any(d > p for d, p in zip(diff_list, diff_max)): + continue + + cg_terms = [] + for coupling in coupling_list: + j1_n, j2_n, j1, j2, j3 = coupling + m1 = Add( *[ jn[x - 1] - diff_list[x - 1] for x in j1_n ] ) + m2 = Add( *[ jn[x - 1] - diff_list[x - 1] for x in j2_n ] ) + m3 = m1 + m2 + cg_terms.append( (j1, m1, j2, m2, j3, m3) ) + coeff = Mul( *[ CG(*term).doit() for term in cg_terms ] ) + state = TensorProduct( + *[ evect(j, j - d) for j, d in zip(jn, diff_list) ] ) + result.append(coeff*state) + return Add(*result) + else: + # Symbolic coupling + m_str = "m1:%d" % (len(jn) + 1) + mvals = symbols(m_str) + cg_terms = [(j1, Add(*[mvals[n - 1] for n in j1_n]), + j2, Add(*[mvals[n - 1] for n in j2_n]), + j3, Add(*[mvals[n - 1] for n in j1_n + j2_n])) for j1_n, j2_n, j1, j2, j3 in coupling_list[:-1] ] + cg_terms.append(*[(j1, Add(*[mvals[n - 1] for n in j1_n]), + j2, Add(*[mvals[n - 1] for n in j2_n]), + j, m) for j1_n, j2_n, j1, j2, j3 in [coupling_list[-1]] ]) + cg_coeff = Mul(*[CG(*cg_term) for cg_term in cg_terms]) + sum_terms = [ (m, -j, j) for j, m in zip(jn, mvals) ] + state = TensorProduct( *[ evect(j, m) for j, m in zip(jn, mvals) ] ) + return Sum(cg_coeff*state, *sum_terms) + + +def _confignum_to_difflist(config_num, diff, list_len): + # Determines configuration of diffs into list_len number of slots + diff_list = [] + for n in range(list_len): + prev_diff = diff + # Number of spots after current one + rem_spots = list_len - n - 1 + # Number of configurations of distributing diff among the remaining spots + rem_configs = binomial(diff + rem_spots - 1, diff) + while config_num >= rem_configs: + config_num -= rem_configs + diff -= 1 + rem_configs = binomial(diff + rem_spots - 1, diff) + diff_list.append(prev_diff - diff) + return diff_list diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/state.py b/phivenv/Lib/site-packages/sympy/physics/quantum/state.py new file mode 100644 index 0000000000000000000000000000000000000000..4ccd1ce9b9875b59a5d1293ab3026808bdc85b27 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/state.py @@ -0,0 +1,987 @@ +"""Dirac notation for states.""" + +from sympy.core.cache import cacheit +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import Function +from sympy.core.numbers import oo, equal_valued +from sympy.core.singleton import S +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals.integrals import integrate +from sympy.printing.pretty.stringpict import stringPict +from sympy.physics.quantum.qexpr import QExpr, dispatch_method +from sympy.physics.quantum.kind import KetKind, BraKind + + +__all__ = [ + 'KetBase', + 'BraBase', + 'StateBase', + 'State', + 'Ket', + 'Bra', + 'TimeDepState', + 'TimeDepBra', + 'TimeDepKet', + 'OrthogonalKet', + 'OrthogonalBra', + 'OrthogonalState', + 'Wavefunction' +] + + +#----------------------------------------------------------------------------- +# States, bras and kets. +#----------------------------------------------------------------------------- + +# ASCII brackets +_lbracket = "<" +_rbracket = ">" +_straight_bracket = "|" + + +# Unicode brackets +# MATHEMATICAL ANGLE BRACKETS +_lbracket_ucode = "\N{MATHEMATICAL LEFT ANGLE BRACKET}" +_rbracket_ucode = "\N{MATHEMATICAL RIGHT ANGLE BRACKET}" +# LIGHT VERTICAL BAR +_straight_bracket_ucode = "\N{LIGHT VERTICAL BAR}" + +# Other options for unicode printing of <, > and | for Dirac notation. + +# LEFT-POINTING ANGLE BRACKET +# _lbracket = "\u2329" +# _rbracket = "\u232A" + +# LEFT ANGLE BRACKET +# _lbracket = "\u3008" +# _rbracket = "\u3009" + +# VERTICAL LINE +# _straight_bracket = "\u007C" + + +class StateBase(QExpr): + """Abstract base class for general abstract states in quantum mechanics. + + All other state classes defined will need to inherit from this class. It + carries the basic structure for all other states such as dual, _eval_adjoint + and label. + + This is an abstract base class and you should not instantiate it directly, + instead use State. + """ + + @classmethod + def _operators_to_state(self, ops, **options): + """ Returns the eigenstate instance for the passed operators. + + This method should be overridden in subclasses. It will handle being + passed either an Operator instance or set of Operator instances. It + should return the corresponding state INSTANCE or simply raise a + NotImplementedError. See cartesian.py for an example. + """ + + raise NotImplementedError("Cannot map operators to states in this class. Method not implemented!") + + def _state_to_operators(self, op_classes, **options): + """ Returns the operators which this state instance is an eigenstate + of. + + This method should be overridden in subclasses. It will be called on + state instances and be passed the operator classes that we wish to make + into instances. The state instance will then transform the classes + appropriately, or raise a NotImplementedError if it cannot return + operator instances. See cartesian.py for examples, + """ + + raise NotImplementedError( + "Cannot map this state to operators. Method not implemented!") + + @property + def operators(self): + """Return the operator(s) that this state is an eigenstate of""" + from .operatorset import state_to_operators # import internally to avoid circular import errors + return state_to_operators(self) + + def _enumerate_state(self, num_states, **options): + raise NotImplementedError("Cannot enumerate this state!") + + def _represent_default_basis(self, **options): + return self._represent(basis=self.operators) + + def _apply_operator(self, op, **options): + return None + + #------------------------------------------------------------------------- + # Dagger/dual + #------------------------------------------------------------------------- + + @property + def dual(self): + """Return the dual state of this one.""" + return self.dual_class()._new_rawargs(self.hilbert_space, *self.args) + + @classmethod + def dual_class(self): + """Return the class used to construct the dual.""" + raise NotImplementedError( + 'dual_class must be implemented in a subclass' + ) + + def _eval_adjoint(self): + """Compute the dagger of this state using the dual.""" + return self.dual + + #------------------------------------------------------------------------- + # Printing + #------------------------------------------------------------------------- + + def _pretty_brackets(self, height, use_unicode=True): + # Return pretty printed brackets for the state + # Ideally, this could be done by pform.parens but it does not support the angled < and > + + # Setup for unicode vs ascii + if use_unicode: + lbracket, rbracket = getattr(self, 'lbracket_ucode', ""), getattr(self, 'rbracket_ucode', "") + slash, bslash, vert = '\N{BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT}', \ + '\N{BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT}', \ + '\N{BOX DRAWINGS LIGHT VERTICAL}' + else: + lbracket, rbracket = getattr(self, 'lbracket', ""), getattr(self, 'rbracket', "") + slash, bslash, vert = '/', '\\', '|' + + # If height is 1, just return brackets + if height == 1: + return stringPict(lbracket), stringPict(rbracket) + # Make height even + height += (height % 2) + + brackets = [] + for bracket in lbracket, rbracket: + # Create left bracket + if bracket in {_lbracket, _lbracket_ucode}: + bracket_args = [ ' ' * (height//2 - i - 1) + + slash for i in range(height // 2)] + bracket_args.extend( + [' ' * i + bslash for i in range(height // 2)]) + # Create right bracket + elif bracket in {_rbracket, _rbracket_ucode}: + bracket_args = [ ' ' * i + bslash for i in range(height // 2)] + bracket_args.extend([ ' ' * ( + height//2 - i - 1) + slash for i in range(height // 2)]) + # Create straight bracket + elif bracket in {_straight_bracket, _straight_bracket_ucode}: + bracket_args = [vert] * height + else: + raise ValueError(bracket) + brackets.append( + stringPict('\n'.join(bracket_args), baseline=height//2)) + return brackets + + def _sympystr(self, printer, *args): + contents = self._print_contents(printer, *args) + return '%s%s%s' % (getattr(self, 'lbracket', ""), contents, getattr(self, 'rbracket', "")) + + def _pretty(self, printer, *args): + from sympy.printing.pretty.stringpict import prettyForm + # Get brackets + pform = self._print_contents_pretty(printer, *args) + lbracket, rbracket = self._pretty_brackets( + pform.height(), printer._use_unicode) + # Put together state + pform = prettyForm(*pform.left(lbracket)) + pform = prettyForm(*pform.right(rbracket)) + return pform + + def _latex(self, printer, *args): + contents = self._print_contents_latex(printer, *args) + # The extra {} brackets are needed to get matplotlib's latex + # rendered to render this properly. + return '{%s%s%s}' % (getattr(self, 'lbracket_latex', ""), contents, getattr(self, 'rbracket_latex', "")) + + +class KetBase(StateBase): + """Base class for Kets. + + This class defines the dual property and the brackets for printing. This is + an abstract base class and you should not instantiate it directly, instead + use Ket. + """ + + kind = KetKind + + lbracket = _straight_bracket + rbracket = _rbracket + lbracket_ucode = _straight_bracket_ucode + rbracket_ucode = _rbracket_ucode + lbracket_latex = r'\left|' + rbracket_latex = r'\right\rangle ' + + @classmethod + def default_args(self): + return ("psi",) + + @classmethod + def dual_class(self): + return BraBase + + #------------------------------------------------------------------------- + # _eval_* methods + #------------------------------------------------------------------------- + + def _eval_innerproduct(self, bra, **hints): + """Evaluate the inner product between this ket and a bra. + + This is called to compute , where the ket is ``self``. + + This method will dispatch to sub-methods having the format:: + + ``def _eval_innerproduct_BraClass(self, **hints):`` + + Subclasses should define these methods (one for each BraClass) to + teach the ket how to take inner products with bras. + """ + return dispatch_method(self, '_eval_innerproduct', bra, **hints) + + def _apply_from_right_to(self, op, **options): + """Apply an Operator to this Ket as Operator*Ket + + This method will dispatch to methods having the format:: + + ``def _apply_from_right_to_OperatorName(op, **options):`` + + Subclasses should define these methods (one for each OperatorName) to + teach the Ket how to implement OperatorName*Ket + + Parameters + ========== + + op : Operator + The Operator that is acting on the Ket as op*Ket + options : dict + A dict of key/value pairs that control how the operator is applied + to the Ket. + """ + return dispatch_method(self, '_apply_from_right_to', op, **options) + + +class BraBase(StateBase): + """Base class for Bras. + + This class defines the dual property and the brackets for printing. This + is an abstract base class and you should not instantiate it directly, + instead use Bra. + """ + + kind = BraKind + + lbracket = _lbracket + rbracket = _straight_bracket + lbracket_ucode = _lbracket_ucode + rbracket_ucode = _straight_bracket_ucode + lbracket_latex = r'\left\langle ' + rbracket_latex = r'\right|' + + @classmethod + def _operators_to_state(self, ops, **options): + state = self.dual_class()._operators_to_state(ops, **options) + return state.dual + + def _state_to_operators(self, op_classes, **options): + return self.dual._state_to_operators(op_classes, **options) + + def _enumerate_state(self, num_states, **options): + dual_states = self.dual._enumerate_state(num_states, **options) + return [x.dual for x in dual_states] + + @classmethod + def default_args(self): + return self.dual_class().default_args() + + @classmethod + def dual_class(self): + return KetBase + + def _represent(self, **options): + """A default represent that uses the Ket's version.""" + from sympy.physics.quantum.dagger import Dagger + return Dagger(self.dual._represent(**options)) + + +class State(StateBase): + """General abstract quantum state used as a base class for Ket and Bra.""" + pass + + +class Ket(State, KetBase): + """A general time-independent Ket in quantum mechanics. + + Inherits from State and KetBase. This class should be used as the base + class for all physical, time-independent Kets in a system. This class + and its subclasses will be the main classes that users will use for + expressing Kets in Dirac notation [1]_. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + ket. This will usually be its symbol or its quantum numbers. For + time-dependent state, this will include the time. + + Examples + ======== + + Create a simple Ket and looking at its properties:: + + >>> from sympy.physics.quantum import Ket + >>> from sympy import symbols, I + >>> k = Ket('psi') + >>> k + |psi> + >>> k.hilbert_space + H + >>> k.is_commutative + False + >>> k.label + (psi,) + + Ket's know about their associated bra:: + + >>> k.dual + >> k.dual_class() + + + Take a linear combination of two kets:: + + >>> k0 = Ket(0) + >>> k1 = Ket(1) + >>> 2*I*k0 - 4*k1 + 2*I*|0> - 4*|1> + + Compound labels are passed as tuples:: + + >>> n, m = symbols('n,m') + >>> k = Ket(n,m) + >>> k + |nm> + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bra-ket_notation + """ + + @classmethod + def dual_class(self): + return Bra + + +class Bra(State, BraBase): + """A general time-independent Bra in quantum mechanics. + + Inherits from State and BraBase. A Bra is the dual of a Ket [1]_. This + class and its subclasses will be the main classes that users will use for + expressing Bras in Dirac notation. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + ket. This will usually be its symbol or its quantum numbers. For + time-dependent state, this will include the time. + + Examples + ======== + + Create a simple Bra and look at its properties:: + + >>> from sympy.physics.quantum import Bra + >>> from sympy import symbols, I + >>> b = Bra('psi') + >>> b + >> b.hilbert_space + H + >>> b.is_commutative + False + + Bra's know about their dual Ket's:: + + >>> b.dual + |psi> + >>> b.dual_class() + + + Like Kets, Bras can have compound labels and be manipulated in a similar + manner:: + + >>> n, m = symbols('n,m') + >>> b = Bra(n,m) - I*Bra(m,n) + >>> b + -I*>> b.subs(n,m) + >> from sympy.physics.quantum import TimeDepKet + >>> k = TimeDepKet('psi', 't') + >>> k + |psi;t> + >>> k.time + t + >>> k.label + (psi,) + >>> k.hilbert_space + H + + TimeDepKets know about their dual bra:: + + >>> k.dual + >> k.dual_class() + + """ + + @classmethod + def dual_class(self): + return TimeDepBra + + +class TimeDepBra(TimeDepState, BraBase): + """General time-dependent Bra in quantum mechanics. + + This inherits from TimeDepState and BraBase and is the main class that + should be used for Bras that vary with time. Its dual is a TimeDepBra. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the ket. This + will usually be its symbol or its quantum numbers. For time-dependent + state, this will include the time as the final argument. + + Examples + ======== + + >>> from sympy.physics.quantum import TimeDepBra + >>> b = TimeDepBra('psi', 't') + >>> b + >> b.time + t + >>> b.label + (psi,) + >>> b.hilbert_space + H + >>> b.dual + |psi;t> + """ + + @classmethod + def dual_class(self): + return TimeDepKet + + +class OrthogonalState(State): + """General abstract quantum state used as a base class for Ket and Bra.""" + pass + +class OrthogonalKet(OrthogonalState, KetBase): + """Orthogonal Ket in quantum mechanics. + + The inner product of two states with different labels will give zero, + states with the same label will give one. + + >>> from sympy.physics.quantum import OrthogonalBra, OrthogonalKet + >>> from sympy.abc import m, n + >>> (OrthogonalBra(n)*OrthogonalKet(n)).doit() + 1 + >>> (OrthogonalBra(n)*OrthogonalKet(n+1)).doit() + 0 + >>> (OrthogonalBra(n)*OrthogonalKet(m)).doit() + + """ + + @classmethod + def dual_class(self): + return OrthogonalBra + + def _eval_innerproduct(self, bra, **hints): + + if len(self.args) != len(bra.args): + raise ValueError('Cannot multiply a ket that has a different number of labels.') + + for arg, bra_arg in zip(self.args, bra.args): + diff = arg - bra_arg + diff = diff.expand() + + is_zero = diff.is_zero + + if is_zero is False: + return S.Zero # i.e. Integer(0) + + if is_zero is None: + return None + + return S.One # i.e. Integer(1) + + +class OrthogonalBra(OrthogonalState, BraBase): + """Orthogonal Bra in quantum mechanics. + """ + + @classmethod + def dual_class(self): + return OrthogonalKet + + +class Wavefunction(Function): + """Class for representations in continuous bases + + This class takes an expression and coordinates in its constructor. It can + be used to easily calculate normalizations and probabilities. + + Parameters + ========== + + expr : Expr + The expression representing the functional form of the w.f. + + coords : Symbol or tuple + The coordinates to be integrated over, and their bounds + + Examples + ======== + + Particle in a box, specifying bounds in the more primitive way of using + Piecewise: + + >>> from sympy import Symbol, Piecewise, pi, N + >>> from sympy.functions import sqrt, sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x = Symbol('x', real=True) + >>> n = 1 + >>> L = 1 + >>> g = Piecewise((0, x < 0), (0, x > L), (sqrt(2//L)*sin(n*pi*x/L), True)) + >>> f = Wavefunction(g, x) + >>> f.norm + 1 + >>> f.is_normalized + True + >>> p = f.prob() + >>> p(0) + 0 + >>> p(L) + 0 + >>> p(0.5) + 2 + >>> p(0.85*L) + 2*sin(0.85*pi)**2 + >>> N(p(0.85*L)) + 0.412214747707527 + + Additionally, you can specify the bounds of the function and the indices in + a more compact way: + + >>> from sympy import symbols, pi, diff + >>> from sympy.functions import sqrt, sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x, L = symbols('x,L', positive=True) + >>> n = symbols('n', integer=True, positive=True) + >>> g = sqrt(2/L)*sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.norm + 1 + >>> f(L+1) + 0 + >>> f(L-1) + sqrt(2)*sin(pi*n*(L - 1)/L)/sqrt(L) + >>> f(-1) + 0 + >>> f(0.85) + sqrt(2)*sin(0.85*pi*n/L)/sqrt(L) + >>> f(0.85, n=1, L=1) + sqrt(2)*sin(0.85*pi) + >>> f.is_commutative + False + + All arguments are automatically sympified, so you can define the variables + as strings rather than symbols: + + >>> expr = x**2 + >>> f = Wavefunction(expr, 'x') + >>> type(f.variables[0]) + + + Derivatives of Wavefunctions will return Wavefunctions: + + >>> diff(f, x) + Wavefunction(2*x, x) + + """ + + #Any passed tuples for coordinates and their bounds need to be + #converted to Tuples before Function's constructor is called, to + #avoid errors from calling is_Float in the constructor + def __new__(cls, *args, **options): + new_args = [None for i in args] + ct = 0 + for arg in args: + if isinstance(arg, tuple): + new_args[ct] = Tuple(*arg) + else: + new_args[ct] = arg + ct += 1 + + return super().__new__(cls, *new_args, **options) + + def __call__(self, *args, **options): + var = self.variables + + if len(args) != len(var): + raise NotImplementedError( + "Incorrect number of arguments to function!") + + ct = 0 + #If the passed value is outside the specified bounds, return 0 + for v in var: + lower, upper = self.limits[v] + + #Do the comparison to limits only if the passed symbol is actually + #a symbol present in the limits; + #Had problems with a comparison of x > L + if isinstance(args[ct], Expr) and \ + not (lower in args[ct].free_symbols + or upper in args[ct].free_symbols): + continue + + if (args[ct] < lower) == True or (args[ct] > upper) == True: + return S.Zero + + ct += 1 + + expr = self.expr + + #Allows user to make a call like f(2, 4, m=1, n=1) + for symbol in list(expr.free_symbols): + if str(symbol) in options.keys(): + val = options[str(symbol)] + expr = expr.subs(symbol, val) + + return expr.subs(zip(var, args)) + + def _eval_derivative(self, symbol): + expr = self.expr + deriv = expr._eval_derivative(symbol) + + return Wavefunction(deriv, *self.args[1:]) + + def _eval_conjugate(self): + return Wavefunction(conjugate(self.expr), *self.args[1:]) + + def _eval_transpose(self): + return self + + @property + def is_commutative(self): + """ + Override Function's is_commutative so that order is preserved in + represented expressions + """ + return False + + @classmethod + def eval(self, *args): + return None + + @property + def variables(self): + """ + Return the coordinates which the wavefunction depends on + + Examples + ======== + + >>> from sympy.physics.quantum.state import Wavefunction + >>> from sympy import symbols + >>> x,y = symbols('x,y') + >>> f = Wavefunction(x*y, x, y) + >>> f.variables + (x, y) + >>> g = Wavefunction(x*y, x) + >>> g.variables + (x,) + + """ + var = [g[0] if isinstance(g, Tuple) else g for g in self._args[1:]] + return tuple(var) + + @property + def limits(self): + """ + Return the limits of the coordinates which the w.f. depends on If no + limits are specified, defaults to ``(-oo, oo)``. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Wavefunction + >>> from sympy import symbols + >>> x, y = symbols('x, y') + >>> f = Wavefunction(x**2, (x, 0, 1)) + >>> f.limits + {x: (0, 1)} + >>> f = Wavefunction(x**2, x) + >>> f.limits + {x: (-oo, oo)} + >>> f = Wavefunction(x**2 + y**2, x, (y, -1, 2)) + >>> f.limits + {x: (-oo, oo), y: (-1, 2)} + + """ + limits = [(g[1], g[2]) if isinstance(g, Tuple) else (-oo, oo) + for g in self._args[1:]] + return dict(zip(self.variables, tuple(limits))) + + @property + def expr(self): + """ + Return the expression which is the functional form of the Wavefunction + + Examples + ======== + + >>> from sympy.physics.quantum.state import Wavefunction + >>> from sympy import symbols + >>> x, y = symbols('x, y') + >>> f = Wavefunction(x**2, x) + >>> f.expr + x**2 + + """ + return self._args[0] + + @property + def is_normalized(self): + """ + Returns true if the Wavefunction is properly normalized + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.functions import sqrt, sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x, L = symbols('x,L', positive=True) + >>> n = symbols('n', integer=True, positive=True) + >>> g = sqrt(2/L)*sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.is_normalized + True + + """ + + return equal_valued(self.norm, 1) + + @property # type: ignore + @cacheit + def norm(self): + """ + Return the normalization of the specified functional form. + + This function integrates over the coordinates of the Wavefunction, with + the bounds specified. + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.functions import sqrt, sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x, L = symbols('x,L', positive=True) + >>> n = symbols('n', integer=True, positive=True) + >>> g = sqrt(2/L)*sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.norm + 1 + >>> g = sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.norm + sqrt(2)*sqrt(L)/2 + + """ + + exp = self.expr*conjugate(self.expr) + var = self.variables + limits = self.limits + + for v in var: + curr_limits = limits[v] + exp = integrate(exp, (v, curr_limits[0], curr_limits[1])) + + return sqrt(exp) + + def normalize(self): + """ + Return a normalized version of the Wavefunction + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.functions import sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x = symbols('x', real=True) + >>> L = symbols('L', positive=True) + >>> n = symbols('n', integer=True, positive=True) + >>> g = sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.normalize() + Wavefunction(sqrt(2)*sin(pi*n*x/L)/sqrt(L), (x, 0, L)) + + """ + const = self.norm + + if const is oo: + raise NotImplementedError("The function is not normalizable!") + else: + return Wavefunction((const)**(-1)*self.expr, *self.args[1:]) + + def prob(self): + r""" + Return the absolute magnitude of the w.f., `|\psi(x)|^2` + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.functions import sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x, L = symbols('x,L', real=True) + >>> n = symbols('n', integer=True) + >>> g = sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.prob() + Wavefunction(sin(pi*n*x/L)**2, x) + + """ + + return Wavefunction(self.expr*conjugate(self.expr), *self.variables) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tensorproduct.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tensorproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..058b3459227e5a020e2d0397fc66f56a2f917293 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tensorproduct.py @@ -0,0 +1,363 @@ +"""Abstract tensor product.""" + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.kind import KindDispatcher +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.sympify import sympify +from sympy.matrices.dense import DenseMatrix as Matrix +from sympy.matrices.immutable import ImmutableDenseMatrix as ImmutableMatrix +from sympy.printing.pretty.stringpict import prettyForm +from sympy.utilities.exceptions import sympy_deprecation_warning + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.kind import ( + KetKind, _KetKind, + BraKind, _BraKind, + OperatorKind, _OperatorKind +) +from sympy.physics.quantum.matrixutils import ( + numpy_ndarray, + scipy_sparse_matrix, + matrix_tensor_product +) +from sympy.physics.quantum.state import Ket, Bra +from sympy.physics.quantum.trace import Tr + + +__all__ = [ + 'TensorProduct', + 'tensor_product_simp' +] + +#----------------------------------------------------------------------------- +# Tensor product +#----------------------------------------------------------------------------- + +_combined_printing = False + + +def combined_tensor_printing(combined): + """Set flag controlling whether tensor products of states should be + printed as a combined bra/ket or as an explicit tensor product of different + bra/kets. This is a global setting for all TensorProduct class instances. + + Parameters + ---------- + combine : bool + When true, tensor product states are combined into one ket/bra, and + when false explicit tensor product notation is used between each + ket/bra. + """ + global _combined_printing + _combined_printing = combined + + +class TensorProduct(Expr): + """The tensor product of two or more arguments. + + For matrices, this uses ``matrix_tensor_product`` to compute the Kronecker + or tensor product matrix. For other objects a symbolic ``TensorProduct`` + instance is returned. The tensor product is a non-commutative + multiplication that is used primarily with operators and states in quantum + mechanics. + + Currently, the tensor product distinguishes between commutative and + non-commutative arguments. Commutative arguments are assumed to be scalars + and are pulled out in front of the ``TensorProduct``. Non-commutative + arguments remain in the resulting ``TensorProduct``. + + Parameters + ========== + + args : tuple + A sequence of the objects to take the tensor product of. + + Examples + ======== + + Start with a simple tensor product of SymPy matrices:: + + >>> from sympy import Matrix + >>> from sympy.physics.quantum import TensorProduct + + >>> m1 = Matrix([[1,2],[3,4]]) + >>> m2 = Matrix([[1,0],[0,1]]) + >>> TensorProduct(m1, m2) + Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2], + [3, 0, 4, 0], + [0, 3, 0, 4]]) + >>> TensorProduct(m2, m1) + Matrix([ + [1, 2, 0, 0], + [3, 4, 0, 0], + [0, 0, 1, 2], + [0, 0, 3, 4]]) + + We can also construct tensor products of non-commutative symbols: + + >>> from sympy import Symbol + >>> A = Symbol('A',commutative=False) + >>> B = Symbol('B',commutative=False) + >>> tp = TensorProduct(A, B) + >>> tp + AxB + + We can take the dagger of a tensor product (note the order does NOT reverse + like the dagger of a normal product): + + >>> from sympy.physics.quantum import Dagger + >>> Dagger(tp) + Dagger(A)xDagger(B) + + Expand can be used to distribute a tensor product across addition: + + >>> C = Symbol('C',commutative=False) + >>> tp = TensorProduct(A+B,C) + >>> tp + (A + B)xC + >>> tp.expand(tensorproduct=True) + AxC + BxC + """ + is_commutative = False + + _kind_dispatcher = KindDispatcher("TensorProduct_kind_dispatcher", commutative=True) + + @property + def kind(self): + """Calculate the kind of a tensor product by looking at its children.""" + arg_kinds = (a.kind for a in self.args) + return self._kind_dispatcher(*arg_kinds) + + def __new__(cls, *args): + if isinstance(args[0], (Matrix, ImmutableMatrix, numpy_ndarray, + scipy_sparse_matrix)): + return matrix_tensor_product(*args) + c_part, new_args = cls.flatten(sympify(args)) + c_part = Mul(*c_part) + if len(new_args) == 0: + return c_part + elif len(new_args) == 1: + return c_part * new_args[0] + else: + tp = Expr.__new__(cls, *new_args) + return c_part * tp + + @classmethod + def flatten(cls, args): + # TODO: disallow nested TensorProducts. + c_part = [] + nc_parts = [] + for arg in args: + cp, ncp = arg.args_cnc() + c_part.extend(list(cp)) + nc_parts.append(Mul._from_args(ncp)) + return c_part, nc_parts + + def _eval_adjoint(self): + return TensorProduct(*[Dagger(i) for i in self.args]) + + def _eval_rewrite(self, rule, args, **hints): + return TensorProduct(*args).expand(tensorproduct=True) + + def _sympystr(self, printer, *args): + length = len(self.args) + s = '' + for i in range(length): + if isinstance(self.args[i], (Add, Pow, Mul)): + s = s + '(' + s = s + printer._print(self.args[i]) + if isinstance(self.args[i], (Add, Pow, Mul)): + s = s + ')' + if i != length - 1: + s = s + 'x' + return s + + def _pretty(self, printer, *args): + + if (_combined_printing and + (all(isinstance(arg, Ket) for arg in self.args) or + all(isinstance(arg, Bra) for arg in self.args))): + + length = len(self.args) + pform = printer._print('', *args) + for i in range(length): + next_pform = printer._print('', *args) + length_i = len(self.args[i].args) + for j in range(length_i): + part_pform = printer._print(self.args[i].args[j], *args) + next_pform = prettyForm(*next_pform.right(part_pform)) + if j != length_i - 1: + next_pform = prettyForm(*next_pform.right(', ')) + + if len(self.args[i].args) > 1: + next_pform = prettyForm( + *next_pform.parens(left='{', right='}')) + pform = prettyForm(*pform.right(next_pform)) + if i != length - 1: + pform = prettyForm(*pform.right(',' + ' ')) + + pform = prettyForm(*pform.left(self.args[0].lbracket)) + pform = prettyForm(*pform.right(self.args[0].rbracket)) + return pform + + length = len(self.args) + pform = printer._print('', *args) + for i in range(length): + next_pform = printer._print(self.args[i], *args) + if isinstance(self.args[i], (Add, Mul)): + next_pform = prettyForm( + *next_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(next_pform)) + if i != length - 1: + if printer._use_unicode: + pform = prettyForm(*pform.right('\N{N-ARY CIRCLED TIMES OPERATOR}' + ' ')) + else: + pform = prettyForm(*pform.right('x' + ' ')) + return pform + + def _latex(self, printer, *args): + + if (_combined_printing and + (all(isinstance(arg, Ket) for arg in self.args) or + all(isinstance(arg, Bra) for arg in self.args))): + + def _label_wrap(label, nlabels): + return label if nlabels == 1 else r"\left\{%s\right\}" % label + + s = r", ".join([_label_wrap(arg._print_label_latex(printer, *args), + len(arg.args)) for arg in self.args]) + + return r"{%s%s%s}" % (self.args[0].lbracket_latex, s, + self.args[0].rbracket_latex) + + length = len(self.args) + s = '' + for i in range(length): + if isinstance(self.args[i], (Add, Mul)): + s = s + '\\left(' + # The extra {} brackets are needed to get matplotlib's latex + # rendered to render this properly. + s = s + '{' + printer._print(self.args[i], *args) + '}' + if isinstance(self.args[i], (Add, Mul)): + s = s + '\\right)' + if i != length - 1: + s = s + '\\otimes ' + return s + + def doit(self, **hints): + return TensorProduct(*[item.doit(**hints) for item in self.args]) + + def _eval_expand_tensorproduct(self, **hints): + """Distribute TensorProducts across addition.""" + args = self.args + add_args = [] + for i in range(len(args)): + if isinstance(args[i], Add): + for aa in args[i].args: + tp = TensorProduct(*args[:i] + (aa,) + args[i + 1:]) + c_part, nc_part = tp.args_cnc() + # Check for TensorProduct object: is the one object in nc_part, if any: + # (Note: any other object type to be expanded must be added here) + if len(nc_part) == 1 and isinstance(nc_part[0], TensorProduct): + nc_part = (nc_part[0]._eval_expand_tensorproduct(), ) + add_args.append(Mul(*c_part)*Mul(*nc_part)) + break + + if add_args: + return Add(*add_args) + else: + return self + + def _eval_trace(self, **kwargs): + indices = kwargs.get('indices', None) + exp = self + + if indices is None or len(indices) == 0: + return Mul(*[Tr(arg).doit() for arg in exp.args]) + else: + return Mul(*[Tr(value).doit() if idx in indices else value + for idx, value in enumerate(exp.args)]) + + +def tensor_product_simp_Mul(e): + """Simplify a Mul with tensor products. + + .. deprecated:: 1.14. + The transformations applied by this function are not done automatically + when tensor products are combined. + + Originally, the main use of this function is to simplify a ``Mul`` of + ``TensorProduct``s to a ``TensorProduct`` of ``Muls``. + """ + sympy_deprecation_warning( + """ + tensor_product_simp_Mul has been deprecated. The transformations + performed by this function are now done automatically when + tensor products are multiplied. + """, + deprecated_since_version="1.14", + active_deprecations_target='deprecated-tensorproduct-simp' + ) + return e + +def tensor_product_simp_Pow(e): + """Evaluates ``Pow`` expressions whose base is ``TensorProduct`` + + .. deprecated:: 1.14. + The transformations applied by this function are not done automatically + when tensor products are combined. + """ + sympy_deprecation_warning( + """ + tensor_product_simp_Pow has been deprecated. The transformations + performed by this function are now done automatically when + tensor products are exponentiated. + """, + deprecated_since_version="1.14", + active_deprecations_target='deprecated-tensorproduct-simp' + ) + return e + + +def tensor_product_simp(e, **hints): + """Try to simplify and combine tensor products. + + .. deprecated:: 1.14. + The transformations applied by this function are not done automatically + when tensor products are combined. + + Originally, this function tried to pull expressions inside of ``TensorProducts``. + It only worked for relatively simple cases where the products have + only scalars, raw ``TensorProducts``, not ``Add``, ``Pow``, ``Commutators`` + of ``TensorProducts``. + """ + sympy_deprecation_warning( + """ + tensor_product_simp has been deprecated. The transformations + performed by this function are now done automatically when + tensor products are combined. + """, + deprecated_since_version="1.14", + active_deprecations_target='deprecated-tensorproduct-simp' + ) + return e + + +@TensorProduct._kind_dispatcher.register(_OperatorKind, _OperatorKind) +def find_op_kind(e1, e2): + return OperatorKind + + +@TensorProduct._kind_dispatcher.register(_KetKind, _KetKind) +def find_ket_kind(e1, e2): + return KetKind + + +@TensorProduct._kind_dispatcher.register(_BraKind, _BraKind) +def find_bra_kind(e1, e2): + return BraKind diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__init__.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebc000b5d42e26ad1b76667213ff7b829c466aef Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_anticommutator.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_anticommutator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d470ff0601a55e14c054702492560cdc514519b1 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_anticommutator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_boson.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_boson.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..357affdae2378d534f9d65bff3f56c82fc996d8e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_boson.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_cartesian.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_cartesian.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25a841d9d1a03b5597e0c00dbc07652ae80535a2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_cartesian.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_cg.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_cg.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f1138d7bf5ee6f79ae96f61fcf1bf52bb494356 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_cg.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_circuitplot.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_circuitplot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4a0a77fd1cbbdad40d9928471cc77ba741b4ea4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_circuitplot.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_circuitutils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_circuitutils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30a2b898c209497bd63326d844a50d9ef71b13f6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_circuitutils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_commutator.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_commutator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a3848522217f2043db2e4fc69e3d1ab3e7aefb67 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_commutator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_constants.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1803215131254e4686ce21ddec9641fecc79283a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_constants.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_dagger.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_dagger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af3642ca954ed650771e65ddcf4be72eb88a36f9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_dagger.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_density.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_density.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f3ae8bb5eb3ff7b6454adfd592bddc1c7944c02 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_density.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_fermion.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_fermion.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca8cb1317a5f300f3239828fc14155214469d40c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_fermion.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_gate.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_gate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1539a351299776ef384bebf488051097674a2998 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_gate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_grover.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_grover.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3c834e5f486382cae3e40364c1c7e7d2f0096c6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_grover.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_hilbert.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_hilbert.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45f398d6460e80e4f2830908e5036ca3a8f5dfc3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_hilbert.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_identitysearch.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_identitysearch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94c964273cbd7b76a4bd73a5ad806d1cda113f79 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_identitysearch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_innerproduct.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_innerproduct.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ea0314922d8b03741dd366d2a04f8dac80ad06f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_innerproduct.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_kind.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_kind.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d97520f67c72472b3c2a2e841d8d40e88efe0976 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_kind.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_matrixutils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_matrixutils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebb5e8107cd4c72174e75c41a6e55fb80ce76469 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_matrixutils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_operator.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_operator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b689b94a56d73922bee2b13269066fa07a7d7a05 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_operator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_operatorordering.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_operatorordering.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..652fa75c62342774957815271ea078615a64b243 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_operatorordering.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_operatorset.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_operatorset.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94ac6fbd6574e5217b8f4a748b53b1aa4128af9a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_operatorset.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_pauli.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_pauli.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f406c847627898ce641b22ed144bfa99ded4837 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_pauli.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_piab.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_piab.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..264994858628cb8e472dad8bd3bff4fc75290f93 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_piab.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_printing.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_printing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ec8efea9985ad2c04b48467da26050f2179cc4d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_printing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qapply.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qapply.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7db0323d1555470b2a01c0ef63bc77e0877cabd8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qapply.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qasm.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qasm.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67f1e746dd51825c172d6717d8b23d372e98d380 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qasm.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qexpr.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qexpr.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98babe25eec5e3abdb6ea4bc64102aade99a01fe Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qexpr.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qft.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qft.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46b8c70d476874865f1b9d3d7dd1097c12cdb87e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qft.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qubit.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qubit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ac47cbda15ee83c74f7bfdd400ff13a047d04f4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_qubit.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_represent.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_represent.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab2ed18f85d88a3214984a0d125627082389ea8d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_represent.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_sho1d.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_sho1d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee64047eb57d9bf60552f63eb8af0284441c867e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_sho1d.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_shor.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_shor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..508ae9adf682739eec67993af993b0e61dcf68a5 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_shor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_state.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_state.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97889c22f2fa80b068dddbae93788f36e73ea3b8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_state.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_tensorproduct.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_tensorproduct.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..246e99c3267a3d3c4b969a8ee7a03d6f6f7c07fc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_tensorproduct.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_trace.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_trace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ee0a0c248f4818f1df6dbed07485ac10f70ed1e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_trace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_transforms.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_transforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caf51c5468cec1d64db7525f0005bb9cc7a6bed9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/__pycache__/test_transforms.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_anticommutator.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_anticommutator.py new file mode 100644 index 0000000000000000000000000000000000000000..0e6b6cbc50651742fcbbbe6adce3f20dfadc2ec5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_anticommutator.py @@ -0,0 +1,56 @@ +from sympy.core.numbers import Integer +from sympy.core.symbol import symbols + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.anticommutator import AntiCommutator as AComm +from sympy.physics.quantum.operator import Operator + + +a, b, c = symbols('a,b,c') +A, B, C, D = symbols('A,B,C,D', commutative=False) + + +def test_anticommutator(): + ac = AComm(A, B) + assert isinstance(ac, AComm) + assert ac.is_commutative is False + assert ac.subs(A, C) == AComm(C, B) + + +def test_commutator_identities(): + assert AComm(a*A, b*B) == a*b*AComm(A, B) + assert AComm(A, A) == 2*A**2 + assert AComm(A, B) == AComm(B, A) + assert AComm(a, b) == 2*a*b + assert AComm(A, B).doit() == A*B + B*A + + +def test_anticommutator_dagger(): + assert Dagger(AComm(A, B)) == AComm(Dagger(A), Dagger(B)) + + +class Foo(Operator): + + def _eval_anticommutator_Bar(self, bar): + return Integer(0) + + +class Bar(Operator): + pass + + +class Tam(Operator): + + def _eval_anticommutator_Foo(self, foo): + return Integer(1) + + +def test_eval_commutator(): + F = Foo('F') + B = Bar('B') + T = Tam('T') + assert AComm(F, B).doit() == 0 + assert AComm(B, F).doit() == 0 + assert AComm(F, T).doit() == 1 + assert AComm(T, F).doit() == 1 + assert AComm(B, T).doit() == B*T + T*B diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_boson.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_boson.py new file mode 100644 index 0000000000000000000000000000000000000000..cd8dab745bede8b1c70303917dae81146fc03395 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_boson.py @@ -0,0 +1,50 @@ +from math import prod + +from sympy.core.numbers import Rational +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.quantum import Dagger, Commutator, qapply +from sympy.physics.quantum.boson import BosonOp +from sympy.physics.quantum.boson import ( + BosonFockKet, BosonFockBra, BosonCoherentKet, BosonCoherentBra) + + +def test_bosonoperator(): + a = BosonOp('a') + b = BosonOp('b') + + assert isinstance(a, BosonOp) + assert isinstance(Dagger(a), BosonOp) + + assert a.is_annihilation + assert not Dagger(a).is_annihilation + + assert BosonOp("a") == BosonOp("a", True) + assert BosonOp("a") != BosonOp("c") + assert BosonOp("a", True) != BosonOp("a", False) + + assert Commutator(a, Dagger(a)).doit() == 1 + + assert Commutator(a, Dagger(b)).doit() == a * Dagger(b) - Dagger(b) * a + + assert Dagger(exp(a)) == exp(Dagger(a)) + + +def test_boson_states(): + a = BosonOp("a") + + # Fock states + n = 3 + assert (BosonFockBra(0) * BosonFockKet(1)).doit() == 0 + assert (BosonFockBra(1) * BosonFockKet(1)).doit() == 1 + assert qapply(BosonFockBra(n) * Dagger(a)**n * BosonFockKet(0)) \ + == sqrt(prod(range(1, n+1))) + + # Coherent states + alpha1, alpha2 = 1.2, 4.3 + assert (BosonCoherentBra(alpha1) * BosonCoherentKet(alpha1)).doit() == 1 + assert (BosonCoherentBra(alpha2) * BosonCoherentKet(alpha2)).doit() == 1 + assert abs((BosonCoherentBra(alpha1) * BosonCoherentKet(alpha2)).doit() - + exp((alpha1 - alpha2) ** 2 * Rational(-1, 2))) < 1e-12 + assert qapply(a * BosonCoherentKet(alpha1)) == \ + alpha1 * BosonCoherentKet(alpha1) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_cartesian.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_cartesian.py new file mode 100644 index 0000000000000000000000000000000000000000..f1dd435fab68c9c71ac3602bc4c53847cbe39d57 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_cartesian.py @@ -0,0 +1,113 @@ +"""Tests for cartesian.py""" + +from sympy.core.numbers import (I, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.delta_functions import DiracDelta +from sympy.sets.sets import Interval +from sympy.testing.pytest import XFAIL + +from sympy.physics.quantum import qapply, represent, L2, Dagger +from sympy.physics.quantum import Commutator, hbar +from sympy.physics.quantum.cartesian import ( + XOp, YOp, ZOp, PxOp, X, Y, Z, Px, XKet, XBra, PxKet, PxBra, + PositionKet3D, PositionBra3D +) +from sympy.physics.quantum.operator import DifferentialOperator + +x, y, z, x_1, x_2, x_3, y_1, z_1 = symbols('x,y,z,x_1,x_2,x_3,y_1,z_1') +px, py, px_1, px_2 = symbols('px py px_1 px_2') + + +def test_x(): + assert X.hilbert_space == L2(Interval(S.NegativeInfinity, S.Infinity)) + assert Commutator(X, Px).doit() == I*hbar + assert qapply(X*XKet(x)) == x*XKet(x) + assert XKet(x).dual_class() == XBra + assert XBra(x).dual_class() == XKet + assert (Dagger(XKet(y))*XKet(x)).doit() == DiracDelta(x - y) + assert (PxBra(px)*XKet(x)).doit() == \ + exp(-I*x*px/hbar)/sqrt(2*pi*hbar) + assert represent(XKet(x)) == DiracDelta(x - x_1) + assert represent(XBra(x)) == DiracDelta(-x + x_1) + assert XBra(x).position == x + assert represent(XOp()*XKet()) == x*DiracDelta(x - x_2) + assert represent(XBra("y")*XKet()) == DiracDelta(x - y) + assert represent( + XKet()*XBra()) == DiracDelta(x - x_2) * DiracDelta(x_1 - x) + + rep_p = represent(XOp(), basis=PxOp) + assert rep_p == hbar*I*DiracDelta(px_1 - px_2)*DifferentialOperator(px_1) + assert rep_p == represent(XOp(), basis=PxOp()) + assert rep_p == represent(XOp(), basis=PxKet) + assert rep_p == represent(XOp(), basis=PxKet()) + + assert represent(XOp()*PxKet(), basis=PxKet) == \ + hbar*I*DiracDelta(px - px_2)*DifferentialOperator(px) + + +@XFAIL +def _text_x_broken(): + # represent has some broken logic that is relying in particular + # forms of input, rather than a full and proper handling of + # all valid quantum expressions. Marking this test as XFAIL until + # we can refactor represent. + assert represent(XOp()*XKet()*XBra('y')) == \ + x*DiracDelta(x - x_3)*DiracDelta(x_1 - y) + + +def test_p(): + assert Px.hilbert_space == L2(Interval(S.NegativeInfinity, S.Infinity)) + assert qapply(Px*PxKet(px)) == px*PxKet(px) + assert PxKet(px).dual_class() == PxBra + assert PxBra(x).dual_class() == PxKet + assert (Dagger(PxKet(py))*PxKet(px)).doit() == DiracDelta(px - py) + assert (XBra(x)*PxKet(px)).doit() == \ + exp(I*x*px/hbar)/sqrt(2*pi*hbar) + assert represent(PxKet(px)) == DiracDelta(px - px_1) + + rep_x = represent(PxOp(), basis=XOp) + assert rep_x == -hbar*I*DiracDelta(x_1 - x_2)*DifferentialOperator(x_1) + assert rep_x == represent(PxOp(), basis=XOp()) + assert rep_x == represent(PxOp(), basis=XKet) + assert rep_x == represent(PxOp(), basis=XKet()) + + assert represent(PxOp()*XKet(), basis=XKet) == \ + -hbar*I*DiracDelta(x - x_2)*DifferentialOperator(x) + assert represent(XBra("y")*PxOp()*XKet(), basis=XKet) == \ + -hbar*I*DiracDelta(x - y)*DifferentialOperator(x) + + +def test_3dpos(): + assert Y.hilbert_space == L2(Interval(S.NegativeInfinity, S.Infinity)) + assert Z.hilbert_space == L2(Interval(S.NegativeInfinity, S.Infinity)) + + test_ket = PositionKet3D(x, y, z) + assert qapply(X*test_ket) == x*test_ket + assert qapply(Y*test_ket) == y*test_ket + assert qapply(Z*test_ket) == z*test_ket + assert qapply(X*Y*test_ket) == x*y*test_ket + assert qapply(X*Y*Z*test_ket) == x*y*z*test_ket + assert qapply(Y*Z*test_ket) == y*z*test_ket + + assert PositionKet3D() == test_ket + assert YOp() == Y + assert ZOp() == Z + + assert PositionKet3D.dual_class() == PositionBra3D + assert PositionBra3D.dual_class() == PositionKet3D + + other_ket = PositionKet3D(x_1, y_1, z_1) + assert (Dagger(other_ket)*test_ket).doit() == \ + DiracDelta(x - x_1)*DiracDelta(y - y_1)*DiracDelta(z - z_1) + + assert test_ket.position_x == x + assert test_ket.position_y == y + assert test_ket.position_z == z + assert other_ket.position_x == x_1 + assert other_ket.position_y == y_1 + assert other_ket.position_z == z_1 + + # TODO: Add tests for representations diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_cg.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_cg.py new file mode 100644 index 0000000000000000000000000000000000000000..384512aaac7a8d984ff2a733e6349161dc9414a0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_cg.py @@ -0,0 +1,183 @@ +from sympy.concrete.summations import Sum +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.quantum.cg import Wigner3j, Wigner6j, Wigner9j, CG, cg_simp +from sympy.functions.special.tensor_functions import KroneckerDelta + + +def test_cg_simp_add(): + j, m1, m1p, m2, m2p = symbols('j m1 m1p m2 m2p') + # Test Varshalovich 8.7.1 Eq 1 + a = CG(S.Half, S.Half, 0, 0, S.Half, S.Half) + b = CG(S.Half, Rational(-1, 2), 0, 0, S.Half, Rational(-1, 2)) + c = CG(1, 1, 0, 0, 1, 1) + d = CG(1, 0, 0, 0, 1, 0) + e = CG(1, -1, 0, 0, 1, -1) + assert cg_simp(a + b) == 2 + assert cg_simp(c + d + e) == 3 + assert cg_simp(a + b + c + d + e) == 5 + assert cg_simp(a + b + c) == 2 + c + assert cg_simp(2*a + b) == 2 + a + assert cg_simp(2*c + d + e) == 3 + c + assert cg_simp(5*a + 5*b) == 10 + assert cg_simp(5*c + 5*d + 5*e) == 15 + assert cg_simp(-a - b) == -2 + assert cg_simp(-c - d - e) == -3 + assert cg_simp(-6*a - 6*b) == -12 + assert cg_simp(-4*c - 4*d - 4*e) == -12 + a = CG(S.Half, S.Half, j, 0, S.Half, S.Half) + b = CG(S.Half, Rational(-1, 2), j, 0, S.Half, Rational(-1, 2)) + c = CG(1, 1, j, 0, 1, 1) + d = CG(1, 0, j, 0, 1, 0) + e = CG(1, -1, j, 0, 1, -1) + assert cg_simp(a + b) == 2*KroneckerDelta(j, 0) + assert cg_simp(c + d + e) == 3*KroneckerDelta(j, 0) + assert cg_simp(a + b + c + d + e) == 5*KroneckerDelta(j, 0) + assert cg_simp(a + b + c) == 2*KroneckerDelta(j, 0) + c + assert cg_simp(2*a + b) == 2*KroneckerDelta(j, 0) + a + assert cg_simp(2*c + d + e) == 3*KroneckerDelta(j, 0) + c + assert cg_simp(5*a + 5*b) == 10*KroneckerDelta(j, 0) + assert cg_simp(5*c + 5*d + 5*e) == 15*KroneckerDelta(j, 0) + assert cg_simp(-a - b) == -2*KroneckerDelta(j, 0) + assert cg_simp(-c - d - e) == -3*KroneckerDelta(j, 0) + assert cg_simp(-6*a - 6*b) == -12*KroneckerDelta(j, 0) + assert cg_simp(-4*c - 4*d - 4*e) == -12*KroneckerDelta(j, 0) + # Test Varshalovich 8.7.1 Eq 2 + a = CG(S.Half, S.Half, S.Half, Rational(-1, 2), 0, 0) + b = CG(S.Half, Rational(-1, 2), S.Half, S.Half, 0, 0) + c = CG(1, 1, 1, -1, 0, 0) + d = CG(1, 0, 1, 0, 0, 0) + e = CG(1, -1, 1, 1, 0, 0) + assert cg_simp(a - b) == sqrt(2) + assert cg_simp(c - d + e) == sqrt(3) + assert cg_simp(a - b + c - d + e) == sqrt(2) + sqrt(3) + assert cg_simp(a - b + c) == sqrt(2) + c + assert cg_simp(2*a - b) == sqrt(2) + a + assert cg_simp(2*c - d + e) == sqrt(3) + c + assert cg_simp(5*a - 5*b) == 5*sqrt(2) + assert cg_simp(5*c - 5*d + 5*e) == 5*sqrt(3) + assert cg_simp(-a + b) == -sqrt(2) + assert cg_simp(-c + d - e) == -sqrt(3) + assert cg_simp(-6*a + 6*b) == -6*sqrt(2) + assert cg_simp(-4*c + 4*d - 4*e) == -4*sqrt(3) + a = CG(S.Half, S.Half, S.Half, Rational(-1, 2), j, 0) + b = CG(S.Half, Rational(-1, 2), S.Half, S.Half, j, 0) + c = CG(1, 1, 1, -1, j, 0) + d = CG(1, 0, 1, 0, j, 0) + e = CG(1, -1, 1, 1, j, 0) + assert cg_simp(a - b) == sqrt(2)*KroneckerDelta(j, 0) + assert cg_simp(c - d + e) == sqrt(3)*KroneckerDelta(j, 0) + assert cg_simp(a - b + c - d + e) == sqrt( + 2)*KroneckerDelta(j, 0) + sqrt(3)*KroneckerDelta(j, 0) + assert cg_simp(a - b + c) == sqrt(2)*KroneckerDelta(j, 0) + c + assert cg_simp(2*a - b) == sqrt(2)*KroneckerDelta(j, 0) + a + assert cg_simp(2*c - d + e) == sqrt(3)*KroneckerDelta(j, 0) + c + assert cg_simp(5*a - 5*b) == 5*sqrt(2)*KroneckerDelta(j, 0) + assert cg_simp(5*c - 5*d + 5*e) == 5*sqrt(3)*KroneckerDelta(j, 0) + assert cg_simp(-a + b) == -sqrt(2)*KroneckerDelta(j, 0) + assert cg_simp(-c + d - e) == -sqrt(3)*KroneckerDelta(j, 0) + assert cg_simp(-6*a + 6*b) == -6*sqrt(2)*KroneckerDelta(j, 0) + assert cg_simp(-4*c + 4*d - 4*e) == -4*sqrt(3)*KroneckerDelta(j, 0) + # Test Varshalovich 8.7.2 Eq 9 + # alpha=alphap,beta=betap case + # numerical + a = CG(S.Half, S.Half, S.Half, Rational(-1, 2), 1, 0)**2 + b = CG(S.Half, S.Half, S.Half, Rational(-1, 2), 0, 0)**2 + c = CG(1, 0, 1, 1, 1, 1)**2 + d = CG(1, 0, 1, 1, 2, 1)**2 + assert cg_simp(a + b) == 1 + assert cg_simp(c + d) == 1 + assert cg_simp(a + b + c + d) == 2 + assert cg_simp(4*a + 4*b) == 4 + assert cg_simp(4*c + 4*d) == 4 + assert cg_simp(5*a + 3*b) == 3 + 2*a + assert cg_simp(5*c + 3*d) == 3 + 2*c + assert cg_simp(-a - b) == -1 + assert cg_simp(-c - d) == -1 + # symbolic + a = CG(S.Half, m1, S.Half, m2, 1, 1)**2 + b = CG(S.Half, m1, S.Half, m2, 1, 0)**2 + c = CG(S.Half, m1, S.Half, m2, 1, -1)**2 + d = CG(S.Half, m1, S.Half, m2, 0, 0)**2 + assert cg_simp(a + b + c + d) == 1 + assert cg_simp(4*a + 4*b + 4*c + 4*d) == 4 + assert cg_simp(3*a + 5*b + 3*c + 4*d) == 3 + 2*b + d + assert cg_simp(-a - b - c - d) == -1 + a = CG(1, m1, 1, m2, 2, 2)**2 + b = CG(1, m1, 1, m2, 2, 1)**2 + c = CG(1, m1, 1, m2, 2, 0)**2 + d = CG(1, m1, 1, m2, 2, -1)**2 + e = CG(1, m1, 1, m2, 2, -2)**2 + f = CG(1, m1, 1, m2, 1, 1)**2 + g = CG(1, m1, 1, m2, 1, 0)**2 + h = CG(1, m1, 1, m2, 1, -1)**2 + i = CG(1, m1, 1, m2, 0, 0)**2 + assert cg_simp(a + b + c + d + e + f + g + h + i) == 1 + assert cg_simp(4*(a + b + c + d + e + f + g + h + i)) == 4 + assert cg_simp(a + b + 2*c + d + 4*e + f + g + h + i) == 1 + c + 3*e + assert cg_simp(-a - b - c - d - e - f - g - h - i) == -1 + # alpha!=alphap or beta!=betap case + # numerical + a = CG(S.Half, S( + 1)/2, S.Half, Rational(-1, 2), 1, 0)*CG(S.Half, Rational(-1, 2), S.Half, S.Half, 1, 0) + b = CG(S.Half, S( + 1)/2, S.Half, Rational(-1, 2), 0, 0)*CG(S.Half, Rational(-1, 2), S.Half, S.Half, 0, 0) + c = CG(1, 1, 1, 0, 2, 1)*CG(1, 0, 1, 1, 2, 1) + d = CG(1, 1, 1, 0, 1, 1)*CG(1, 0, 1, 1, 1, 1) + assert cg_simp(a + b) == 0 + assert cg_simp(c + d) == 0 + # symbolic + a = CG(S.Half, m1, S.Half, m2, 1, 1)*CG(S.Half, m1p, S.Half, m2p, 1, 1) + b = CG(S.Half, m1, S.Half, m2, 1, 0)*CG(S.Half, m1p, S.Half, m2p, 1, 0) + c = CG(S.Half, m1, S.Half, m2, 1, -1)*CG(S.Half, m1p, S.Half, m2p, 1, -1) + d = CG(S.Half, m1, S.Half, m2, 0, 0)*CG(S.Half, m1p, S.Half, m2p, 0, 0) + assert cg_simp(a + b + c + d) == KroneckerDelta(m1, m1p)*KroneckerDelta(m2, m2p) + a = CG(1, m1, 1, m2, 2, 2)*CG(1, m1p, 1, m2p, 2, 2) + b = CG(1, m1, 1, m2, 2, 1)*CG(1, m1p, 1, m2p, 2, 1) + c = CG(1, m1, 1, m2, 2, 0)*CG(1, m1p, 1, m2p, 2, 0) + d = CG(1, m1, 1, m2, 2, -1)*CG(1, m1p, 1, m2p, 2, -1) + e = CG(1, m1, 1, m2, 2, -2)*CG(1, m1p, 1, m2p, 2, -2) + f = CG(1, m1, 1, m2, 1, 1)*CG(1, m1p, 1, m2p, 1, 1) + g = CG(1, m1, 1, m2, 1, 0)*CG(1, m1p, 1, m2p, 1, 0) + h = CG(1, m1, 1, m2, 1, -1)*CG(1, m1p, 1, m2p, 1, -1) + i = CG(1, m1, 1, m2, 0, 0)*CG(1, m1p, 1, m2p, 0, 0) + assert cg_simp( + a + b + c + d + e + f + g + h + i) == KroneckerDelta(m1, m1p)*KroneckerDelta(m2, m2p) + + +def test_cg_simp_sum(): + x, a, b, c, cp, alpha, beta, gamma, gammap = symbols( + 'x a b c cp alpha beta gamma gammap') + # Varshalovich 8.7.1 Eq 1 + assert cg_simp(x * Sum(CG(a, alpha, b, 0, a, alpha), (alpha, -a, a) + )) == x*(2*a + 1)*KroneckerDelta(b, 0) + assert cg_simp(x * Sum(CG(a, alpha, b, 0, a, alpha), (alpha, -a, a)) + CG(1, 0, 1, 0, 1, 0)) == x*(2*a + 1)*KroneckerDelta(b, 0) + CG(1, 0, 1, 0, 1, 0) + assert cg_simp(2 * Sum(CG(1, alpha, 0, 0, 1, alpha), (alpha, -1, 1))) == 6 + # Varshalovich 8.7.1 Eq 2 + assert cg_simp(x*Sum((-1)**(a - alpha) * CG(a, alpha, a, -alpha, c, + 0), (alpha, -a, a))) == x*sqrt(2*a + 1)*KroneckerDelta(c, 0) + assert cg_simp(3*Sum((-1)**(2 - alpha) * CG( + 2, alpha, 2, -alpha, 0, 0), (alpha, -2, 2))) == 3*sqrt(5) + # Varshalovich 8.7.2 Eq 4 + assert cg_simp(Sum(CG(a, alpha, b, beta, c, gamma)*CG(a, alpha, b, beta, cp, gammap), (alpha, -a, a), (beta, -b, b))) == KroneckerDelta(c, cp)*KroneckerDelta(gamma, gammap) + assert cg_simp(Sum(CG(a, alpha, b, beta, c, gamma)*CG(a, alpha, b, beta, c, gammap), (alpha, -a, a), (beta, -b, b))) == KroneckerDelta(gamma, gammap) + assert cg_simp(Sum(CG(a, alpha, b, beta, c, gamma)*CG(a, alpha, b, beta, cp, gamma), (alpha, -a, a), (beta, -b, b))) == KroneckerDelta(c, cp) + assert cg_simp(Sum(CG( + a, alpha, b, beta, c, gamma)**2, (alpha, -a, a), (beta, -b, b))) == 1 + assert cg_simp(Sum(CG(2, alpha, 1, beta, 2, gamma)*CG(2, alpha, 1, beta, 2, gammap), (alpha, -2, 2), (beta, -1, 1))) == KroneckerDelta(gamma, gammap) + + +def test_doit(): + assert Wigner3j(S.Half, Rational(-1, 2), S.Half, S.Half, 0, 0).doit() == -sqrt(2)/2 + assert Wigner3j(1/2,1/2,1/2,1/2,1/2,1/2).doit() == 0 + assert Wigner3j(9/2,9/2,9/2,9/2,9/2,9/2).doit() == 0 + assert Wigner6j(1, 2, 3, 2, 1, 2).doit() == sqrt(21)/105 + assert Wigner6j(3, 1, 2, 2, 2, 1).doit() == sqrt(21) / 105 + assert Wigner9j( + 2, 1, 1, Rational(3, 2), S.Half, 1, S.Half, S.Half, 0).doit() == sqrt(2)/12 + assert CG(S.Half, S.Half, S.Half, Rational(-1, 2), 1, 0).doit() == sqrt(2)/2 + # J minus M is not integer + assert Wigner3j(1, -1, S.Half, S.Half, 1, S.Half).doit() == 0 + assert CG(4, -1, S.Half, S.Half, 4, Rational(-1, 2)).doit() == 0 diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_circuitplot.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_circuitplot.py new file mode 100644 index 0000000000000000000000000000000000000000..fcc89f77047450ad3f8663f371f483654dc70ea9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_circuitplot.py @@ -0,0 +1,69 @@ +from sympy.physics.quantum.circuitplot import labeller, render_label, Mz, CreateOneQubitGate,\ + CreateCGate +from sympy.physics.quantum.gate import CNOT, H, SWAP, CGate, S, T +from sympy.external import import_module +from sympy.testing.pytest import skip + +mpl = import_module('matplotlib') + +def test_render_label(): + assert render_label('q0') == r'$\left|q0\right\rangle$' + assert render_label('q0', {'q0': '0'}) == r'$\left|q0\right\rangle=\left|0\right\rangle$' + +def test_Mz(): + assert str(Mz(0)) == 'Mz(0)' + +def test_create1(): + Qgate = CreateOneQubitGate('Q') + assert str(Qgate(0)) == 'Q(0)' + +def test_createc(): + Qgate = CreateCGate('Q') + assert str(Qgate([1],0)) == 'C((1),Q(0))' + +def test_labeller(): + """Test the labeller utility""" + assert labeller(2) == ['q_1', 'q_0'] + assert labeller(3,'j') == ['j_2', 'j_1', 'j_0'] + +def test_cnot(): + """Test a simple cnot circuit. Right now this only makes sure the code doesn't + raise an exception, and some simple properties + """ + if not mpl: + skip("matplotlib not installed") + else: + from sympy.physics.quantum.circuitplot import CircuitPlot + + c = CircuitPlot(CNOT(1,0),2,labels=labeller(2)) + assert c.ngates == 2 + assert c.nqubits == 2 + assert c.labels == ['q_1', 'q_0'] + + c = CircuitPlot(CNOT(1,0),2) + assert c.ngates == 2 + assert c.nqubits == 2 + assert c.labels == [] + +def test_ex1(): + if not mpl: + skip("matplotlib not installed") + else: + from sympy.physics.quantum.circuitplot import CircuitPlot + + c = CircuitPlot(CNOT(1,0)*H(1),2,labels=labeller(2)) + assert c.ngates == 2 + assert c.nqubits == 2 + assert c.labels == ['q_1', 'q_0'] + +def test_ex4(): + if not mpl: + skip("matplotlib not installed") + else: + from sympy.physics.quantum.circuitplot import CircuitPlot + + c = CircuitPlot(SWAP(0,2)*H(0)* CGate((0,),S(1)) *H(1)*CGate((0,),T(2))\ + *CGate((1,),S(2))*H(2),3,labels=labeller(3,'j')) + assert c.ngates == 7 + assert c.nqubits == 3 + assert c.labels == ['j_2', 'j_1', 'j_0'] diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_circuitutils.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_circuitutils.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea7232320417db8bf745871cff0e77aaf1901e7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_circuitutils.py @@ -0,0 +1,402 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.symbol import Symbol +from sympy.utilities import numbered_symbols +from sympy.physics.quantum.gate import X, Y, Z, H, CNOT, CGate +from sympy.physics.quantum.identitysearch import bfs_identity_search +from sympy.physics.quantum.circuitutils import (kmp_table, find_subcircuit, + replace_subcircuit, convert_to_symbolic_indices, + convert_to_real_indices, random_reduce, random_insert, + flatten_ids) +from sympy.testing.pytest import slow + + +def create_gate_sequence(qubit=0): + gates = (X(qubit), Y(qubit), Z(qubit), H(qubit)) + return gates + + +def test_kmp_table(): + word = ('a', 'b', 'c', 'd', 'a', 'b', 'd') + expected_table = [-1, 0, 0, 0, 0, 1, 2] + assert expected_table == kmp_table(word) + + word = ('P', 'A', 'R', 'T', 'I', 'C', 'I', 'P', 'A', 'T', 'E', ' ', + 'I', 'N', ' ', 'P', 'A', 'R', 'A', 'C', 'H', 'U', 'T', 'E') + expected_table = [-1, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, + 0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0] + assert expected_table == kmp_table(word) + + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + word = (x, y, y, x, z) + expected_table = [-1, 0, 0, 0, 1] + assert expected_table == kmp_table(word) + + word = (x, x, y, h, z) + expected_table = [-1, 0, 1, 0, 0] + assert expected_table == kmp_table(word) + + +def test_find_subcircuit(): + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + x1 = X(1) + y1 = Y(1) + + i0 = Symbol('i0') + x_i0 = X(i0) + y_i0 = Y(i0) + z_i0 = Z(i0) + h_i0 = H(i0) + + circuit = (x, y, z) + + assert find_subcircuit(circuit, (x,)) == 0 + assert find_subcircuit(circuit, (x1,)) == -1 + assert find_subcircuit(circuit, (y,)) == 1 + assert find_subcircuit(circuit, (h,)) == -1 + assert find_subcircuit(circuit, Mul(x, h)) == -1 + assert find_subcircuit(circuit, Mul(x, y, z)) == 0 + assert find_subcircuit(circuit, Mul(y, z)) == 1 + assert find_subcircuit(Mul(*circuit), (x, y, z, h)) == -1 + assert find_subcircuit(Mul(*circuit), (z, y, x)) == -1 + assert find_subcircuit(circuit, (x,), start=2, end=1) == -1 + + circuit = (x, y, x, y, z) + assert find_subcircuit(Mul(*circuit), Mul(x, y, z)) == 2 + assert find_subcircuit(circuit, (x,), start=1) == 2 + assert find_subcircuit(circuit, (x, y), start=1, end=2) == -1 + assert find_subcircuit(Mul(*circuit), (x, y), start=1, end=3) == -1 + assert find_subcircuit(circuit, (x, y), start=1, end=4) == 2 + assert find_subcircuit(circuit, (x, y), start=2, end=4) == 2 + + circuit = (x, y, z, x1, x, y, z, h, x, y, x1, + x, y, z, h, y1, h) + assert find_subcircuit(circuit, (x, y, z, h, y1)) == 11 + + circuit = (x, y, x_i0, y_i0, z_i0, z) + assert find_subcircuit(circuit, (x_i0, y_i0, z_i0)) == 2 + + circuit = (x_i0, y_i0, z_i0, x_i0, y_i0, h_i0) + subcircuit = (x_i0, y_i0, z_i0) + result = find_subcircuit(circuit, subcircuit) + assert result == 0 + + +def test_replace_subcircuit(): + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + + # Standard cases + circuit = (z, y, x, x) + remove = (z, y, x) + assert replace_subcircuit(circuit, Mul(*remove)) == (x,) + assert replace_subcircuit(circuit, remove + (x,)) == () + assert replace_subcircuit(circuit, remove, pos=1) == circuit + assert replace_subcircuit(circuit, remove, pos=0) == (x,) + assert replace_subcircuit(circuit, (x, x), pos=2) == (z, y) + assert replace_subcircuit(circuit, (h,)) == circuit + + circuit = (x, y, x, y, z) + remove = (x, y, z) + assert replace_subcircuit(Mul(*circuit), Mul(*remove)) == (x, y) + remove = (x, y, x, y) + assert replace_subcircuit(circuit, remove) == (z,) + + circuit = (x, h, cgate_z, h, cnot) + remove = (x, h, cgate_z) + assert replace_subcircuit(circuit, Mul(*remove), pos=-1) == (h, cnot) + assert replace_subcircuit(circuit, remove, pos=1) == circuit + remove = (h, h) + assert replace_subcircuit(circuit, remove) == circuit + remove = (h, cgate_z, h, cnot) + assert replace_subcircuit(circuit, remove) == (x,) + + replace = (h, x) + actual = replace_subcircuit(circuit, remove, + replace=replace) + assert actual == (x, h, x) + + circuit = (x, y, h, x, y, z) + remove = (x, y) + replace = (cnot, cgate_z) + actual = replace_subcircuit(circuit, remove, + replace=Mul(*replace)) + assert actual == (cnot, cgate_z, h, x, y, z) + + actual = replace_subcircuit(circuit, remove, + replace=replace, pos=1) + assert actual == (x, y, h, cnot, cgate_z, z) + + +def test_convert_to_symbolic_indices(): + (x, y, z, h) = create_gate_sequence() + + i0 = Symbol('i0') + exp_map = {i0: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x,)) + assert actual == (X(i0),) + assert act_map == exp_map + + expected = (X(i0), Y(i0), Z(i0), H(i0)) + exp_map = {i0: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x, y, z, h)) + assert actual == expected + assert exp_map == act_map + + (x1, y1, z1, h1) = create_gate_sequence(1) + i1 = Symbol('i1') + + expected = (X(i0), Y(i0), Z(i0), H(i0)) + exp_map = {i0: Integer(1)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x1, y1, z1, h1)) + assert actual == expected + assert act_map == exp_map + + expected = (X(i0), Y(i0), Z(i0), H(i0), X(i1), Y(i1), Z(i1), H(i1)) + exp_map = {i0: Integer(0), i1: Integer(1)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x, y, z, h, + x1, y1, z1, h1)) + assert actual == expected + assert act_map == exp_map + + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(Mul(x1, y1, + z1, h1, x, y, z, h)) + assert actual == expected + assert act_map == exp_map + + expected = (X(i0), X(i1), Y(i0), Y(i1), Z(i0), Z(i1), H(i0), H(i1)) + exp_map = {i0: Integer(0), i1: Integer(1)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(Mul(x, x1, + y, y1, z, z1, h, h1)) + assert actual == expected + assert act_map == exp_map + + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x1, x, y1, y, + z1, z, h1, h)) + assert actual == expected + assert act_map == exp_map + + cnot_10 = CNOT(1, 0) + cnot_01 = CNOT(0, 1) + cgate_z_10 = CGate(1, Z(0)) + cgate_z_01 = CGate(0, Z(1)) + + expected = (X(i0), X(i1), Y(i0), Y(i1), Z(i0), Z(i1), + H(i0), H(i1), CNOT(i1, i0), CNOT(i0, i1), + CGate(i1, Z(i0)), CGate(i0, Z(i1))) + exp_map = {i0: Integer(0), i1: Integer(1)} + args = (x, x1, y, y1, z, z1, h, h1, cnot_10, cnot_01, + cgate_z_10, cgate_z_01) + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + args = (x1, x, y1, y, z1, z, h1, h, cnot_10, cnot_01, + cgate_z_10, cgate_z_01) + expected = (X(i0), X(i1), Y(i0), Y(i1), Z(i0), Z(i1), + H(i0), H(i1), CNOT(i0, i1), CNOT(i1, i0), + CGate(i0, Z(i1)), CGate(i1, Z(i0))) + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + args = (cnot_10, h, cgate_z_01, h) + expected = (CNOT(i0, i1), H(i1), CGate(i1, Z(i0)), H(i1)) + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + args = (cnot_01, h1, cgate_z_10, h1) + exp_map = {i0: Integer(0), i1: Integer(1)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + args = (cnot_10, h1, cgate_z_01, h1) + expected = (CNOT(i0, i1), H(i0), CGate(i1, Z(i0)), H(i0)) + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + i2 = Symbol('i2') + ccgate_z = CGate(0, CGate(1, Z(2))) + ccgate_x = CGate(1, CGate(2, X(0))) + args = (ccgate_z, ccgate_x) + + expected = (CGate(i0, CGate(i1, Z(i2))), CGate(i1, CGate(i2, X(i0)))) + exp_map = {i0: Integer(0), i1: Integer(1), i2: Integer(2)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + ndx_map = {i0: Integer(0)} + index_gen = numbered_symbols(prefix='i', start=1) + actual, act_map, sndx, gen = convert_to_symbolic_indices(args, + qubit_map=ndx_map, + start=i0, + gen=index_gen) + assert actual == expected + assert act_map == exp_map + + i3 = Symbol('i3') + cgate_x0_c321 = CGate((3, 2, 1), X(0)) + exp_map = {i0: Integer(3), i1: Integer(2), + i2: Integer(1), i3: Integer(0)} + expected = (CGate((i0, i1, i2), X(i3)),) + args = (cgate_x0_c321,) + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + +def test_convert_to_real_indices(): + i0 = Symbol('i0') + i1 = Symbol('i1') + + (x, y, z, h) = create_gate_sequence() + + x_i0 = X(i0) + y_i0 = Y(i0) + z_i0 = Z(i0) + + qubit_map = {i0: 0} + args = (z_i0, y_i0, x_i0) + expected = (z, y, x) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + cnot_10 = CNOT(1, 0) + cnot_01 = CNOT(0, 1) + cgate_z_10 = CGate(1, Z(0)) + cgate_z_01 = CGate(0, Z(1)) + + cnot_i1_i0 = CNOT(i1, i0) + cnot_i0_i1 = CNOT(i0, i1) + cgate_z_i1_i0 = CGate(i1, Z(i0)) + + qubit_map = {i0: 0, i1: 1} + args = (cnot_i1_i0,) + expected = (cnot_10,) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + args = (cgate_z_i1_i0,) + expected = (cgate_z_10,) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + args = (cnot_i0_i1,) + expected = (cnot_01,) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + qubit_map = {i0: 1, i1: 0} + args = (cgate_z_i1_i0,) + expected = (cgate_z_01,) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + i2 = Symbol('i2') + ccgate_z = CGate(i0, CGate(i1, Z(i2))) + ccgate_x = CGate(i1, CGate(i2, X(i0))) + + qubit_map = {i0: 0, i1: 1, i2: 2} + args = (ccgate_z, ccgate_x) + expected = (CGate(0, CGate(1, Z(2))), CGate(1, CGate(2, X(0)))) + actual = convert_to_real_indices(Mul(*args), qubit_map) + assert actual == expected + + qubit_map = {i0: 1, i2: 0, i1: 2} + args = (ccgate_x, ccgate_z) + expected = (CGate(2, CGate(0, X(1))), CGate(1, CGate(2, Z(0)))) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + +@slow +def test_random_reduce(): + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + + gate_list = [x, y, z] + ids = list(bfs_identity_search(gate_list, 1, max_depth=4)) + + circuit = (x, y, h, z, cnot) + assert random_reduce(circuit, []) == circuit + assert random_reduce(circuit, ids) == circuit + + seq = [2, 11, 9, 3, 5] + circuit = (x, y, z, x, y, h) + assert random_reduce(circuit, ids, seed=seq) == (x, y, h) + + circuit = (x, x, y, y, z, z) + assert random_reduce(circuit, ids, seed=seq) == (x, x, y, y) + + seq = [14, 13, 0] + assert random_reduce(circuit, ids, seed=seq) == (y, y, z, z) + + gate_list = [x, y, z, h, cnot, cgate_z] + ids = list(bfs_identity_search(gate_list, 2, max_depth=4)) + + seq = [25] + circuit = (x, y, z, y, h, y, h, cgate_z, h, cnot) + expected = (x, y, z, cgate_z, h, cnot) + assert random_reduce(circuit, ids, seed=seq) == expected + circuit = Mul(*circuit) + assert random_reduce(circuit, ids, seed=seq) == expected + + +@slow +def test_random_insert(): + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + + choices = [(x, x)] + circuit = (y, y) + loc, choice = 0, 0 + actual = random_insert(circuit, choices, seed=[loc, choice]) + assert actual == (x, x, y, y) + + circuit = (x, y, z, h) + choices = [(h, h), (x, y, z)] + expected = (x, x, y, z, y, z, h) + loc, choice = 1, 1 + actual = random_insert(circuit, choices, seed=[loc, choice]) + assert actual == expected + + gate_list = [x, y, z, h, cnot, cgate_z] + ids = list(bfs_identity_search(gate_list, 2, max_depth=4)) + + eq_ids = flatten_ids(ids) + + circuit = (x, y, h, cnot, cgate_z) + expected = (x, z, x, z, x, y, h, cnot, cgate_z) + loc, choice = 1, 30 + actual = random_insert(circuit, eq_ids, seed=[loc, choice]) + assert actual == expected + circuit = Mul(*circuit) + actual = random_insert(circuit, eq_ids, seed=[loc, choice]) + assert actual == expected diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_commutator.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_commutator.py new file mode 100644 index 0000000000000000000000000000000000000000..04f45feddaca63d7306363a9235c63f534d11430 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_commutator.py @@ -0,0 +1,81 @@ +from sympy.core.numbers import Integer +from sympy.core.symbol import symbols + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.commutator import Commutator as Comm +from sympy.physics.quantum.operator import Operator + + +a, b, c = symbols('a,b,c') +n = symbols('n', integer=True) +A, B, C, D = symbols('A,B,C,D', commutative=False) + + +def test_commutator(): + c = Comm(A, B) + assert c.is_commutative is False + assert isinstance(c, Comm) + assert c.subs(A, C) == Comm(C, B) + + +def test_commutator_identities(): + assert Comm(a*A, b*B) == a*b*Comm(A, B) + assert Comm(A, A) == 0 + assert Comm(a, b) == 0 + assert Comm(A, B) == -Comm(B, A) + assert Comm(A, B).doit() == A*B - B*A + assert Comm(A, B*C).expand(commutator=True) == Comm(A, B)*C + B*Comm(A, C) + assert Comm(A*B, C*D).expand(commutator=True) == \ + A*C*Comm(B, D) + A*Comm(B, C)*D + C*Comm(A, D)*B + Comm(A, C)*D*B + assert Comm(A, B**2).expand(commutator=True) == Comm(A, B)*B + B*Comm(A, B) + assert Comm(A**2, C**2).expand(commutator=True) == \ + Comm(A*B, C*D).expand(commutator=True).replace(B, A).replace(D, C) == \ + A*C*Comm(A, C) + A*Comm(A, C)*C + C*Comm(A, C)*A + Comm(A, C)*C*A + assert Comm(A, C**-2).expand(commutator=True) == \ + Comm(A, (1/C)*(1/D)).expand(commutator=True).replace(D, C) + assert Comm(A + B, C + D).expand(commutator=True) == \ + Comm(A, C) + Comm(A, D) + Comm(B, C) + Comm(B, D) + assert Comm(A, B + C).expand(commutator=True) == Comm(A, B) + Comm(A, C) + assert Comm(A**n, B).expand(commutator=True) == Comm(A**n, B) + + e = Comm(A, Comm(B, C)) + Comm(B, Comm(C, A)) + Comm(C, Comm(A, B)) + assert e.doit().expand() == 0 + + +def test_commutator_dagger(): + comm = Comm(A*B, C) + assert Dagger(comm).expand(commutator=True) == \ + - Comm(Dagger(B), Dagger(C))*Dagger(A) - \ + Dagger(B)*Comm(Dagger(A), Dagger(C)) + + +class Foo(Operator): + + def _eval_commutator_Bar(self, bar): + return Integer(0) + + +class Bar(Operator): + pass + + +class Tam(Operator): + + def _eval_commutator_Foo(self, foo): + return Integer(1) + + +def test_eval_commutator(): + F = Foo('F') + B = Bar('B') + T = Tam('T') + assert Comm(F, B).doit() == 0 + assert Comm(B, F).doit() == 0 + assert Comm(F, T).doit() == -1 + assert Comm(T, F).doit() == 1 + assert Comm(B, T).doit() == B*T - T*B + assert Comm(F**2, B).expand(commutator=True).doit() == 0 + assert Comm(F**2, T).expand(commutator=True).doit() == -2*F + assert Comm(F, T**2).expand(commutator=True).doit() == -2*T + assert Comm(T**2, F).expand(commutator=True).doit() == 2*T + assert Comm(T**2, F**3).expand(commutator=True).doit() == 2*F*T*F + 2*F**2*T + 2*T*F**2 diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_constants.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..48a773ea6b5afbaf956143b50b16b3b18aaf5beb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_constants.py @@ -0,0 +1,13 @@ +from sympy.core.numbers import Float + +from sympy.physics.quantum.constants import hbar + + +def test_hbar(): + assert hbar.is_commutative is True + assert hbar.is_real is True + assert hbar.is_positive is True + assert hbar.is_negative is False + assert hbar.is_irrational is True + + assert hbar.evalf() == Float(1.05457162e-34) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_dagger.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_dagger.py new file mode 100644 index 0000000000000000000000000000000000000000..1357c9320a20afa2ba905a117d90ed1ac2e9642c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_dagger.py @@ -0,0 +1,103 @@ +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer) +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import conjugate +from sympy.matrices.dense import Matrix + +from sympy.physics.quantum.dagger import adjoint, Dagger +from sympy.external import import_module +from sympy.testing.pytest import skip, warns_deprecated_sympy +from sympy.physics.quantum.operator import Operator, IdentityOperator + + +def test_scalars(): + x = symbols('x', complex=True) + assert Dagger(x) == conjugate(x) + assert Dagger(I*x) == -I*conjugate(x) + + i = symbols('i', real=True) + assert Dagger(i) == i + + p = symbols('p') + assert isinstance(Dagger(p), conjugate) + + i = Integer(3) + assert Dagger(i) == i + + A = symbols('A', commutative=False) + assert Dagger(A).is_commutative is False + + +def test_matrix(): + x = symbols('x') + m = Matrix([[I, x*I], [2, 4]]) + assert Dagger(m) == m.H + + +def test_dagger_mul(): + O = Operator('O') + assert Dagger(O)*O == Dagger(O)*O + with warns_deprecated_sympy(): + I = IdentityOperator() + assert Dagger(O)*O*I == Mul(Dagger(O), O)*I + assert Dagger(O)*Dagger(O) == Dagger(O)**2 + assert Dagger(O)*Dagger(I) == Dagger(O) + + +class Foo(Expr): + + def _eval_adjoint(self): + return I + + +def test_eval_adjoint(): + f = Foo() + d = Dagger(f) + assert d == I + +np = import_module('numpy') + + +def test_numpy_dagger(): + if not np: + skip("numpy not installed.") + + a = np.array([[1.0, 2.0j], [-1.0j, 2.0]]) + adag = a.copy().transpose().conjugate() + assert (Dagger(a) == adag).all() + + +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + + +def test_scipy_sparse_dagger(): + if not np: + skip("numpy not installed.") + if not scipy: + skip("scipy not installed.") + else: + sparse = scipy.sparse + + a = sparse.csr_matrix([[1.0 + 0.0j, 2.0j], [-1.0j, 2.0 + 0.0j]]) + adag = a.copy().transpose().conjugate() + assert np.linalg.norm((Dagger(a) - adag).todense()) == 0.0 + + +def test_unknown(): + """Check treatment of unknown objects. + Objects without adjoint or conjugate/transpose methods + are sympified and wrapped in dagger. + """ + x = symbols("x", commutative=False) + result = Dagger(x) + assert result.args == (x,) and isinstance(result, adjoint) + + +def test_unevaluated(): + """Check that evaluate=False returns unevaluated Dagger. + """ + x = symbols("x", real=True) + assert Dagger(x) == x + result = Dagger(x, evaluate=False) + assert result.args == (x,) and isinstance(result, adjoint) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_density.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_density.py new file mode 100644 index 0000000000000000000000000000000000000000..399acce6e201b39f65ea674048198fd2f087b4d0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_density.py @@ -0,0 +1,289 @@ +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import log +from sympy.external import import_module +from sympy.physics.quantum.density import Density, entropy, fidelity +from sympy.physics.quantum.state import Ket, TimeDepKet +from sympy.physics.quantum.qubit import Qubit +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.cartesian import XKet, PxKet, PxOp, XOp +from sympy.physics.quantum.spin import JzKet +from sympy.physics.quantum.operator import OuterProduct +from sympy.physics.quantum.trace import Tr +from sympy.functions import sqrt +from sympy.testing.pytest import raises +from sympy.physics.quantum.matrixutils import scipy_sparse_matrix +from sympy.physics.quantum.tensorproduct import TensorProduct + + +def test_eval_args(): + # check instance created + assert isinstance(Density([Ket(0), 0.5], [Ket(1), 0.5]), Density) + assert isinstance(Density([Qubit('00'), 1/sqrt(2)], + [Qubit('11'), 1/sqrt(2)]), Density) + + #test if Qubit object type preserved + d = Density([Qubit('00'), 1/sqrt(2)], [Qubit('11'), 1/sqrt(2)]) + for (state, prob) in d.args: + assert isinstance(state, Qubit) + + # check for value error, when prob is not provided + raises(ValueError, lambda: Density([Ket(0)], [Ket(1)])) + + +def test_doit(): + + x, y = symbols('x y') + A, B, C, D, E, F = symbols('A B C D E F', commutative=False) + d = Density([XKet(), 0.5], [PxKet(), 0.5]) + assert (0.5*(PxKet()*Dagger(PxKet())) + + 0.5*(XKet()*Dagger(XKet()))) == d.doit() + + # check for kets with expr in them + d_with_sym = Density([XKet(x*y), 0.5], [PxKet(x*y), 0.5]) + assert (0.5*(PxKet(x*y)*Dagger(PxKet(x*y))) + + 0.5*(XKet(x*y)*Dagger(XKet(x*y)))) == d_with_sym.doit() + + d = Density([(A + B)*C, 1.0]) + assert d.doit() == (1.0*A*C*Dagger(C)*Dagger(A) + + 1.0*A*C*Dagger(C)*Dagger(B) + + 1.0*B*C*Dagger(C)*Dagger(A) + + 1.0*B*C*Dagger(C)*Dagger(B)) + + # With TensorProducts as args + # Density with simple tensor products as args + t = TensorProduct(A, B, C) + d = Density([t, 1.0]) + assert d.doit() == \ + 1.0 * TensorProduct(A*Dagger(A), B*Dagger(B), C*Dagger(C)) + + # Density with multiple Tensorproducts as states + t2 = TensorProduct(A, B) + t3 = TensorProduct(C, D) + + d = Density([t2, 0.5], [t3, 0.5]) + assert d.doit() == (0.5 * TensorProduct(A*Dagger(A), B*Dagger(B)) + + 0.5 * TensorProduct(C*Dagger(C), D*Dagger(D))) + + #Density with mixed states + d = Density([t2 + t3, 1.0]) + assert d.doit() == (1.0 * TensorProduct(A*Dagger(A), B*Dagger(B)) + + 1.0 * TensorProduct(A*Dagger(C), B*Dagger(D)) + + 1.0 * TensorProduct(C*Dagger(A), D*Dagger(B)) + + 1.0 * TensorProduct(C*Dagger(C), D*Dagger(D))) + + #Density operators with spin states + tp1 = TensorProduct(JzKet(1, 1), JzKet(1, -1)) + d = Density([tp1, 1]) + + # full trace + t = Tr(d) + assert t.doit() == 1 + + #Partial trace on density operators with spin states + t = Tr(d, [0]) + assert t.doit() == JzKet(1, -1) * Dagger(JzKet(1, -1)) + t = Tr(d, [1]) + assert t.doit() == JzKet(1, 1) * Dagger(JzKet(1, 1)) + + # with another spin state + tp2 = TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) + d = Density([tp2, 1]) + + #full trace + t = Tr(d) + assert t.doit() == 1 + + #Partial trace on density operators with spin states + t = Tr(d, [0]) + assert t.doit() == JzKet(S.Half, Rational(-1, 2)) * Dagger(JzKet(S.Half, Rational(-1, 2))) + t = Tr(d, [1]) + assert t.doit() == JzKet(S.Half, S.Half) * Dagger(JzKet(S.Half, S.Half)) + + +def test_apply_op(): + d = Density([Ket(0), 0.5], [Ket(1), 0.5]) + assert d.apply_op(XOp()) == Density([XOp()*Ket(0), 0.5], + [XOp()*Ket(1), 0.5]) + + +def test_represent(): + x, y = symbols('x y') + d = Density([XKet(), 0.5], [PxKet(), 0.5]) + assert (represent(0.5*(PxKet()*Dagger(PxKet()))) + + represent(0.5*(XKet()*Dagger(XKet())))) == represent(d) + + # check for kets with expr in them + d_with_sym = Density([XKet(x*y), 0.5], [PxKet(x*y), 0.5]) + assert (represent(0.5*(PxKet(x*y)*Dagger(PxKet(x*y)))) + + represent(0.5*(XKet(x*y)*Dagger(XKet(x*y))))) == \ + represent(d_with_sym) + + # check when given explicit basis + assert (represent(0.5*(XKet()*Dagger(XKet())), basis=PxOp()) + + represent(0.5*(PxKet()*Dagger(PxKet())), basis=PxOp())) == \ + represent(d, basis=PxOp()) + + +def test_states(): + d = Density([Ket(0), 0.5], [Ket(1), 0.5]) + states = d.states() + assert states[0] == Ket(0) and states[1] == Ket(1) + + +def test_probs(): + d = Density([Ket(0), .75], [Ket(1), 0.25]) + probs = d.probs() + assert probs[0] == 0.75 and probs[1] == 0.25 + + #probs can be symbols + x, y = symbols('x y') + d = Density([Ket(0), x], [Ket(1), y]) + probs = d.probs() + assert probs[0] == x and probs[1] == y + + +def test_get_state(): + x, y = symbols('x y') + d = Density([Ket(0), x], [Ket(1), y]) + states = (d.get_state(0), d.get_state(1)) + assert states[0] == Ket(0) and states[1] == Ket(1) + + +def test_get_prob(): + x, y = symbols('x y') + d = Density([Ket(0), x], [Ket(1), y]) + probs = (d.get_prob(0), d.get_prob(1)) + assert probs[0] == x and probs[1] == y + + +def test_entropy(): + up = JzKet(S.Half, S.Half) + down = JzKet(S.Half, Rational(-1, 2)) + d = Density((up, S.Half), (down, S.Half)) + + # test for density object + ent = entropy(d) + assert entropy(d) == log(2)/2 + assert d.entropy() == log(2)/2 + + np = import_module('numpy', min_module_version='1.4.0') + if np: + #do this test only if 'numpy' is available on test machine + np_mat = represent(d, format='numpy') + ent = entropy(np_mat) + assert isinstance(np_mat, np.ndarray) + assert ent.real == 0.69314718055994529 + assert ent.imag == 0 + + scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + if scipy and np: + #do this test only if numpy and scipy are available + mat = represent(d, format="scipy.sparse") + assert isinstance(mat, scipy_sparse_matrix) + assert ent.real == 0.69314718055994529 + assert ent.imag == 0 + + +def test_eval_trace(): + up = JzKet(S.Half, S.Half) + down = JzKet(S.Half, Rational(-1, 2)) + d = Density((up, 0.5), (down, 0.5)) + + t = Tr(d) + assert t.doit() == 1.0 + + #test dummy time dependent states + class TestTimeDepKet(TimeDepKet): + def _eval_trace(self, bra, **options): + return 1 + + x, t = symbols('x t') + k1 = TestTimeDepKet(0, 0.5) + k2 = TestTimeDepKet(0, 1) + d = Density([k1, 0.5], [k2, 0.5]) + assert d.doit() == (0.5 * OuterProduct(k1, k1.dual) + + 0.5 * OuterProduct(k2, k2.dual)) + + t = Tr(d) + assert t.doit() == 1.0 + + +def test_fidelity(): + #test with kets + up = JzKet(S.Half, S.Half) + down = JzKet(S.Half, Rational(-1, 2)) + updown = (S.One/sqrt(2))*up + (S.One/sqrt(2))*down + + #check with matrices + up_dm = represent(up * Dagger(up)) + down_dm = represent(down * Dagger(down)) + updown_dm = represent(updown * Dagger(updown)) + + assert abs(fidelity(up_dm, up_dm) - 1) < 1e-3 + assert fidelity(up_dm, down_dm) < 1e-3 + assert abs(fidelity(up_dm, updown_dm) - (S.One/sqrt(2))) < 1e-3 + assert abs(fidelity(updown_dm, down_dm) - (S.One/sqrt(2))) < 1e-3 + + #check with density + up_dm = Density([up, 1.0]) + down_dm = Density([down, 1.0]) + updown_dm = Density([updown, 1.0]) + + assert abs(fidelity(up_dm, up_dm) - 1) < 1e-3 + assert abs(fidelity(up_dm, down_dm)) < 1e-3 + assert abs(fidelity(up_dm, updown_dm) - (S.One/sqrt(2))) < 1e-3 + assert abs(fidelity(updown_dm, down_dm) - (S.One/sqrt(2))) < 1e-3 + + #check mixed states with density + updown2 = sqrt(3)/2*up + S.Half*down + d1 = Density([updown, 0.25], [updown2, 0.75]) + d2 = Density([updown, 0.75], [updown2, 0.25]) + assert abs(fidelity(d1, d2) - 0.991) < 1e-3 + assert abs(fidelity(d2, d1) - fidelity(d1, d2)) < 1e-3 + + #using qubits/density(pure states) + state1 = Qubit('0') + state2 = Qubit('1') + state3 = S.One/sqrt(2)*state1 + S.One/sqrt(2)*state2 + state4 = sqrt(Rational(2, 3))*state1 + S.One/sqrt(3)*state2 + + state1_dm = Density([state1, 1]) + state2_dm = Density([state2, 1]) + state3_dm = Density([state3, 1]) + + assert fidelity(state1_dm, state1_dm) == 1 + assert fidelity(state1_dm, state2_dm) == 0 + assert abs(fidelity(state1_dm, state3_dm) - 1/sqrt(2)) < 1e-3 + assert abs(fidelity(state3_dm, state2_dm) - 1/sqrt(2)) < 1e-3 + + #using qubits/density(mixed states) + d1 = Density([state3, 0.70], [state4, 0.30]) + d2 = Density([state3, 0.20], [state4, 0.80]) + assert abs(fidelity(d1, d1) - 1) < 1e-3 + assert abs(fidelity(d1, d2) - 0.996) < 1e-3 + assert abs(fidelity(d1, d2) - fidelity(d2, d1)) < 1e-3 + + #TODO: test for invalid arguments + # non-square matrix + mat1 = [[0, 0], + [0, 0], + [0, 0]] + + mat2 = [[0, 0], + [0, 0]] + raises(ValueError, lambda: fidelity(mat1, mat2)) + + # unequal dimensions + mat1 = [[0, 0], + [0, 0]] + mat2 = [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]] + raises(ValueError, lambda: fidelity(mat1, mat2)) + + # unsupported data-type + x, y = 1, 2 # random values that is not a matrix + raises(ValueError, lambda: fidelity(x, y)) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_fermion.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_fermion.py new file mode 100644 index 0000000000000000000000000000000000000000..061648c2d5578481196949c38e90ff169fcea972 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_fermion.py @@ -0,0 +1,62 @@ +from pytest import raises + +import sympy +from sympy.physics.quantum import Dagger, AntiCommutator, qapply +from sympy.physics.quantum.fermion import FermionOp +from sympy.physics.quantum.fermion import FermionFockKet, FermionFockBra +from sympy import Symbol + + +def test_fermionoperator(): + c = FermionOp('c') + d = FermionOp('d') + + assert isinstance(c, FermionOp) + assert isinstance(Dagger(c), FermionOp) + + assert c.is_annihilation + assert not Dagger(c).is_annihilation + + assert FermionOp("c") == FermionOp("c", True) + assert FermionOp("c") != FermionOp("d") + assert FermionOp("c", True) != FermionOp("c", False) + + assert AntiCommutator(c, Dagger(c)).doit() == 1 + + assert AntiCommutator(c, Dagger(d)).doit() == c * Dagger(d) + Dagger(d) * c + + +def test_fermion_states(): + c = FermionOp("c") + + # Fock states + assert (FermionFockBra(0) * FermionFockKet(1)).doit() == 0 + assert (FermionFockBra(1) * FermionFockKet(1)).doit() == 1 + + assert qapply(c * FermionFockKet(1)) == FermionFockKet(0) + assert qapply(c * FermionFockKet(0)) == 0 + + assert qapply(Dagger(c) * FermionFockKet(0)) == FermionFockKet(1) + assert qapply(Dagger(c) * FermionFockKet(1)) == 0 + + +def test_power(): + c = FermionOp("c") + assert c**0 == 1 + assert c**1 == c + assert c**2 == 0 + assert c**3 == 0 + assert Dagger(c)**1 == Dagger(c) + assert Dagger(c)**2 == 0 + + assert (c**Symbol('a')).func == sympy.core.power.Pow + assert (c**Symbol('a')).args == (c, Symbol('a')) + + with raises(ValueError): + c**-1 + + with raises(ValueError): + c**3.2 + + with raises(TypeError): + c**1j diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_gate.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..2d7bf1d624faca8afe4b10699d23acc161ca0cdd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_gate.py @@ -0,0 +1,360 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational, pi) +from sympy.core.symbol import (Wild, symbols) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices import Matrix, ImmutableMatrix + +from sympy.physics.quantum.gate import (XGate, YGate, ZGate, random_circuit, + CNOT, IdentityGate, H, X, Y, S, T, Z, SwapGate, gate_simp, gate_sort, + CNotGate, TGate, HadamardGate, PhaseGate, UGate, CGate) +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qubit import Qubit, IntQubit, qubit_to_matrix, \ + matrix_to_qubit +from sympy.physics.quantum.matrixutils import matrix_to_zero +from sympy.physics.quantum.matrixcache import sqrt2_inv +from sympy.physics.quantum import Dagger + + +def test_gate(): + """Test a basic gate.""" + h = HadamardGate(1) + assert h.min_qubits == 2 + assert h.nqubits == 1 + + i0 = Wild('i0') + i1 = Wild('i1') + h0_w1 = HadamardGate(i0) + h0_w2 = HadamardGate(i0) + h1_w1 = HadamardGate(i1) + + assert h0_w1 == h0_w2 + assert h0_w1 != h1_w1 + assert h1_w1 != h0_w2 + + cnot_10_w1 = CNOT(i1, i0) + cnot_10_w2 = CNOT(i1, i0) + cnot_01_w1 = CNOT(i0, i1) + + assert cnot_10_w1 == cnot_10_w2 + assert cnot_10_w1 != cnot_01_w1 + assert cnot_10_w2 != cnot_01_w1 + + +def test_UGate(): + a, b, c, d = symbols('a,b,c,d') + uMat = Matrix([[a, b], [c, d]]) + + # Test basic case where gate exists in 1-qubit space + u1 = UGate((0,), uMat) + assert represent(u1, nqubits=1) == uMat + assert qapply(u1*Qubit('0')) == a*Qubit('0') + c*Qubit('1') + assert qapply(u1*Qubit('1')) == b*Qubit('0') + d*Qubit('1') + + # Test case where gate exists in a larger space + u2 = UGate((1,), uMat) + u2Rep = represent(u2, nqubits=2) + for i in range(4): + assert u2Rep*qubit_to_matrix(IntQubit(i, 2)) == \ + qubit_to_matrix(qapply(u2*IntQubit(i, 2))) + + +def test_cgate(): + """Test the general CGate.""" + # Test single control functionality + CNOTMatrix = Matrix( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) + assert represent(CGate(1, XGate(0)), nqubits=2) == CNOTMatrix + + # Test multiple control bit functionality + ToffoliGate = CGate((1, 2), XGate(0)) + assert represent(ToffoliGate, nqubits=3) == \ + Matrix( + [[1, 0, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, + 1, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 1, 0]]) + + ToffoliGate = CGate((3, 0), XGate(1)) + assert qapply(ToffoliGate*Qubit('1001')) == \ + matrix_to_qubit(represent(ToffoliGate*Qubit('1001'), nqubits=4)) + assert qapply(ToffoliGate*Qubit('0000')) == \ + matrix_to_qubit(represent(ToffoliGate*Qubit('0000'), nqubits=4)) + + CYGate = CGate(1, YGate(0)) + CYGate_matrix = Matrix( + ((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 0, -I), (0, 0, I, 0))) + # Test 2 qubit controlled-Y gate decompose method. + assert represent(CYGate.decompose(), nqubits=2) == CYGate_matrix + + CZGate = CGate(0, ZGate(1)) + CZGate_matrix = Matrix( + ((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, -1))) + assert qapply(CZGate*Qubit('11')) == -Qubit('11') + assert matrix_to_qubit(represent(CZGate*Qubit('11'), nqubits=2)) == \ + -Qubit('11') + # Test 2 qubit controlled-Z gate decompose method. + assert represent(CZGate.decompose(), nqubits=2) == CZGate_matrix + + CPhaseGate = CGate(0, PhaseGate(1)) + assert qapply(CPhaseGate*Qubit('11')) == \ + I*Qubit('11') + assert matrix_to_qubit(represent(CPhaseGate*Qubit('11'), nqubits=2)) == \ + I*Qubit('11') + + # Test that the dagger, inverse, and power of CGate is evaluated properly + assert Dagger(CZGate) == CZGate + assert pow(CZGate, 1) == Dagger(CZGate) + assert Dagger(CZGate) == CZGate.inverse() + assert Dagger(CPhaseGate) != CPhaseGate + assert Dagger(CPhaseGate) == CPhaseGate.inverse() + assert Dagger(CPhaseGate) == pow(CPhaseGate, -1) + assert pow(CPhaseGate, -1) == CPhaseGate.inverse() + + +def test_UGate_CGate_combo(): + a, b, c, d = symbols('a,b,c,d') + uMat = Matrix([[a, b], [c, d]]) + cMat = Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, a, b], [0, 0, c, d]]) + + # Test basic case where gate exists in 1-qubit space. + u1 = UGate((0,), uMat) + cu1 = CGate(1, u1) + assert represent(cu1, nqubits=2) == cMat + assert qapply(cu1*Qubit('10')) == a*Qubit('10') + c*Qubit('11') + assert qapply(cu1*Qubit('11')) == b*Qubit('10') + d*Qubit('11') + assert qapply(cu1*Qubit('01')) == Qubit('01') + assert qapply(cu1*Qubit('00')) == Qubit('00') + + # Test case where gate exists in a larger space. + u2 = UGate((1,), uMat) + u2Rep = represent(u2, nqubits=2) + for i in range(4): + assert u2Rep*qubit_to_matrix(IntQubit(i, 2)) == \ + qubit_to_matrix(qapply(u2*IntQubit(i, 2))) + +def test_UGate_OneQubitGate_combo(): + v, w, f, g = symbols('v w f g') + uMat1 = ImmutableMatrix([[v, w], [f, g]]) + cMat1 = Matrix([[v, w + 1, 0, 0], [f + 1, g, 0, 0], [0, 0, v, w + 1], [0, 0, f + 1, g]]) + u1 = X(0) + UGate(0, uMat1) + assert represent(u1, nqubits=2) == cMat1 + + uMat2 = ImmutableMatrix([[1/sqrt(2), 1/sqrt(2)], [I/sqrt(2), -I/sqrt(2)]]) + cMat2_1 = Matrix([[Rational(1, 2) + I/2, Rational(1, 2) - I/2], + [Rational(1, 2) - I/2, Rational(1, 2) + I/2]]) + cMat2_2 = Matrix([[1, 0], [0, I]]) + u2 = UGate(0, uMat2) + assert represent(H(0)*u2, nqubits=1) == cMat2_1 + assert represent(u2*H(0), nqubits=1) == cMat2_2 + +def test_represent_hadamard(): + """Test the representation of the hadamard gate.""" + circuit = HadamardGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + # Check that the answers are same to within an epsilon. + assert answer == Matrix([sqrt2_inv, sqrt2_inv, 0, 0]) + + +def test_represent_xgate(): + """Test the representation of the X gate.""" + circuit = XGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + assert Matrix([0, 1, 0, 0]) == answer + + +def test_represent_ygate(): + """Test the representation of the Y gate.""" + circuit = YGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + assert answer[0] == 0 and answer[1] == I and \ + answer[2] == 0 and answer[3] == 0 + + +def test_represent_zgate(): + """Test the representation of the Z gate.""" + circuit = ZGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + assert Matrix([1, 0, 0, 0]) == answer + + +def test_represent_phasegate(): + """Test the representation of the S gate.""" + circuit = PhaseGate(0)*Qubit('01') + answer = represent(circuit, nqubits=2) + assert Matrix([0, I, 0, 0]) == answer + + +def test_represent_tgate(): + """Test the representation of the T gate.""" + circuit = TGate(0)*Qubit('01') + assert Matrix([0, exp(I*pi/4), 0, 0]) == represent(circuit, nqubits=2) + + +def test_compound_gates(): + """Test a compound gate representation.""" + circuit = YGate(0)*ZGate(0)*XGate(0)*HadamardGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + assert Matrix([I/sqrt(2), I/sqrt(2), 0, 0]) == answer + + +def test_cnot_gate(): + """Test the CNOT gate.""" + circuit = CNotGate(1, 0) + assert represent(circuit, nqubits=2) == \ + Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) + circuit = circuit*Qubit('111') + assert matrix_to_qubit(represent(circuit, nqubits=3)) == \ + qapply(circuit) + + circuit = CNotGate(1, 0) + assert Dagger(circuit) == circuit + assert Dagger(Dagger(circuit)) == circuit + assert circuit*circuit == 1 + + +def test_gate_sort(): + """Test gate_sort.""" + for g in (X, Y, Z, H, S, T): + assert gate_sort(g(2)*g(1)*g(0)) == g(0)*g(1)*g(2) + e = gate_sort(X(1)*H(0)**2*CNOT(0, 1)*X(1)*X(0)) + assert e == H(0)**2*CNOT(0, 1)*X(0)*X(1)**2 + assert gate_sort(Z(0)*X(0)) == -X(0)*Z(0) + assert gate_sort(Z(0)*X(0)**2) == X(0)**2*Z(0) + assert gate_sort(Y(0)*H(0)) == -H(0)*Y(0) + assert gate_sort(Y(0)*X(0)) == -X(0)*Y(0) + assert gate_sort(Z(0)*Y(0)) == -Y(0)*Z(0) + assert gate_sort(T(0)*S(0)) == S(0)*T(0) + assert gate_sort(Z(0)*S(0)) == S(0)*Z(0) + assert gate_sort(Z(0)*T(0)) == T(0)*Z(0) + assert gate_sort(Z(0)*CNOT(0, 1)) == CNOT(0, 1)*Z(0) + assert gate_sort(S(0)*CNOT(0, 1)) == CNOT(0, 1)*S(0) + assert gate_sort(T(0)*CNOT(0, 1)) == CNOT(0, 1)*T(0) + assert gate_sort(X(1)*CNOT(0, 1)) == CNOT(0, 1)*X(1) + # This takes a long time and should only be uncommented once in a while. + # nqubits = 5 + # ngates = 10 + # trials = 10 + # for i in range(trials): + # c = random_circuit(ngates, nqubits) + # assert represent(c, nqubits=nqubits) == \ + # represent(gate_sort(c), nqubits=nqubits) + + +def test_gate_simp(): + """Test gate_simp.""" + e = H(0)*X(1)*H(0)**2*CNOT(0, 1)*X(1)**3*X(0)*Z(3)**2*S(4)**3 + assert gate_simp(e) == H(0)*CNOT(0, 1)*S(4)*X(0)*Z(4) + assert gate_simp(X(0)*X(0)) == 1 + assert gate_simp(Y(0)*Y(0)) == 1 + assert gate_simp(Z(0)*Z(0)) == 1 + assert gate_simp(H(0)*H(0)) == 1 + assert gate_simp(T(0)*T(0)) == S(0) + assert gate_simp(S(0)*S(0)) == Z(0) + assert gate_simp(Integer(1)) == Integer(1) + assert gate_simp(X(0)**2 + Y(0)**2) == Integer(2) + + +def test_swap_gate(): + """Test the SWAP gate.""" + swap_gate_matrix = Matrix( + ((1, 0, 0, 0), (0, 0, 1, 0), (0, 1, 0, 0), (0, 0, 0, 1))) + assert represent(SwapGate(1, 0).decompose(), nqubits=2) == swap_gate_matrix + assert qapply(SwapGate(1, 3)*Qubit('0010')) == Qubit('1000') + nqubits = 4 + for i in range(nqubits): + for j in range(i): + assert represent(SwapGate(i, j), nqubits=nqubits) == \ + represent(SwapGate(i, j).decompose(), nqubits=nqubits) + + +def test_one_qubit_commutators(): + """Test single qubit gate commutation relations.""" + for g1 in (IdentityGate, X, Y, Z, H, T, S): + for g2 in (IdentityGate, X, Y, Z, H, T, S): + e = Commutator(g1(0), g2(0)) + a = matrix_to_zero(represent(e, nqubits=1, format='sympy')) + b = matrix_to_zero(represent(e.doit(), nqubits=1, format='sympy')) + assert a == b + + e = Commutator(g1(0), g2(1)) + assert e.doit() == 0 + + +def test_one_qubit_anticommutators(): + """Test single qubit gate anticommutation relations.""" + for g1 in (IdentityGate, X, Y, Z, H): + for g2 in (IdentityGate, X, Y, Z, H): + e = AntiCommutator(g1(0), g2(0)) + a = matrix_to_zero(represent(e, nqubits=1, format='sympy')) + b = matrix_to_zero(represent(e.doit(), nqubits=1, format='sympy')) + assert a == b + e = AntiCommutator(g1(0), g2(1)) + a = matrix_to_zero(represent(e, nqubits=2, format='sympy')) + b = matrix_to_zero(represent(e.doit(), nqubits=2, format='sympy')) + assert a == b + + +def test_cnot_commutators(): + """Test commutators of involving CNOT gates.""" + assert Commutator(CNOT(0, 1), Z(0)).doit() == 0 + assert Commutator(CNOT(0, 1), T(0)).doit() == 0 + assert Commutator(CNOT(0, 1), S(0)).doit() == 0 + assert Commutator(CNOT(0, 1), X(1)).doit() == 0 + assert Commutator(CNOT(0, 1), CNOT(0, 1)).doit() == 0 + assert Commutator(CNOT(0, 1), CNOT(0, 2)).doit() == 0 + assert Commutator(CNOT(0, 2), CNOT(0, 1)).doit() == 0 + assert Commutator(CNOT(1, 2), CNOT(1, 0)).doit() == 0 + + +def test_random_circuit(): + c = random_circuit(10, 3) + assert isinstance(c, Mul) + m = represent(c, nqubits=3) + assert m.shape == (8, 8) + assert isinstance(m, Matrix) + + +def test_hermitian_XGate(): + x = XGate(1, 2) + x_dagger = Dagger(x) + + assert (x == x_dagger) + + +def test_hermitian_YGate(): + y = YGate(1, 2) + y_dagger = Dagger(y) + + assert (y == y_dagger) + + +def test_hermitian_ZGate(): + z = ZGate(1, 2) + z_dagger = Dagger(z) + + assert (z == z_dagger) + + +def test_unitary_XGate(): + x = XGate(1, 2) + x_dagger = Dagger(x) + + assert (x*x_dagger == 1) + + +def test_unitary_YGate(): + y = YGate(1, 2) + y_dagger = Dagger(y) + + assert (y*y_dagger == 1) + + +def test_unitary_ZGate(): + z = ZGate(1, 2) + z_dagger = Dagger(z) + + assert (z*z_dagger == 1) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_grover.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_grover.py new file mode 100644 index 0000000000000000000000000000000000000000..b93a5bc5e59380a993dc34e4a160e75f799b3493 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_grover.py @@ -0,0 +1,92 @@ +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qubit import IntQubit +from sympy.physics.quantum.grover import (apply_grover, superposition_basis, + OracleGate, grover_iteration, WGate) + + +def return_one_on_two(qubits): + return qubits == IntQubit(2, qubits.nqubits) + + +def return_one_on_one(qubits): + return qubits == IntQubit(1, nqubits=qubits.nqubits) + + +def test_superposition_basis(): + nbits = 2 + first_half_state = IntQubit(0, nqubits=nbits)/2 + IntQubit(1, nqubits=nbits)/2 + second_half_state = IntQubit(2, nbits)/2 + IntQubit(3, nbits)/2 + assert first_half_state + second_half_state == superposition_basis(nbits) + + nbits = 3 + firstq = (1/sqrt(8))*IntQubit(0, nqubits=nbits) + (1/sqrt(8))*IntQubit(1, nqubits=nbits) + secondq = (1/sqrt(8))*IntQubit(2, nbits) + (1/sqrt(8))*IntQubit(3, nbits) + thirdq = (1/sqrt(8))*IntQubit(4, nbits) + (1/sqrt(8))*IntQubit(5, nbits) + fourthq = (1/sqrt(8))*IntQubit(6, nbits) + (1/sqrt(8))*IntQubit(7, nbits) + assert firstq + secondq + thirdq + fourthq == superposition_basis(nbits) + + +def test_OracleGate(): + v = OracleGate(1, lambda qubits: qubits == IntQubit(0)) + assert qapply(v*IntQubit(0)) == -IntQubit(0) + assert qapply(v*IntQubit(1)) == IntQubit(1) + + nbits = 2 + v = OracleGate(2, return_one_on_two) + assert qapply(v*IntQubit(0, nbits)) == IntQubit(0, nqubits=nbits) + assert qapply(v*IntQubit(1, nbits)) == IntQubit(1, nqubits=nbits) + assert qapply(v*IntQubit(2, nbits)) == -IntQubit(2, nbits) + assert qapply(v*IntQubit(3, nbits)) == IntQubit(3, nbits) + + assert represent(OracleGate(1, lambda qubits: qubits == IntQubit(0)), nqubits=1) == \ + Matrix([[-1, 0], [0, 1]]) + assert represent(v, nqubits=2) == Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) + + +def test_WGate(): + nqubits = 2 + basis_states = superposition_basis(nqubits) + assert qapply(WGate(nqubits)*basis_states) == basis_states + + expected = ((2/sqrt(pow(2, nqubits)))*basis_states) - IntQubit(1, nqubits=nqubits) + assert qapply(WGate(nqubits)*IntQubit(1, nqubits=nqubits)) == expected + + +def test_grover_iteration_1(): + numqubits = 2 + basis_states = superposition_basis(numqubits) + v = OracleGate(numqubits, return_one_on_one) + expected = IntQubit(1, nqubits=numqubits) + assert qapply(grover_iteration(basis_states, v)) == expected + + +def test_grover_iteration_2(): + numqubits = 4 + basis_states = superposition_basis(numqubits) + v = OracleGate(numqubits, return_one_on_two) + # After (pi/4)sqrt(pow(2, n)), IntQubit(2) should have highest prob + # In this case, after around pi times (3 or 4) + iterated = grover_iteration(basis_states, v) + iterated = qapply(iterated) + iterated = grover_iteration(iterated, v) + iterated = qapply(iterated) + iterated = grover_iteration(iterated, v) + iterated = qapply(iterated) + # In this case, probability was highest after 3 iterations + # Probability of Qubit('0010') was 251/256 (3) vs 781/1024 (4) + # Ask about measurement + expected = (-13*basis_states)/64 + 264*IntQubit(2, numqubits)/256 + assert qapply(expected) == iterated + + +def test_grover(): + nqubits = 2 + assert apply_grover(return_one_on_one, nqubits) == IntQubit(1, nqubits=nqubits) + + nqubits = 4 + basis_states = superposition_basis(nqubits) + expected = (-13*basis_states)/64 + 264*IntQubit(2, nqubits)/256 + assert apply_grover(return_one_on_two, 4) == qapply(expected) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_hilbert.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_hilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0e5c4187c6c62e14505efb1597a5cd63c23fea --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_hilbert.py @@ -0,0 +1,110 @@ +from sympy.physics.quantum.hilbert import ( + HilbertSpace, ComplexSpace, L2, FockSpace, TensorProductHilbertSpace, + DirectSumHilbertSpace, TensorPowerHilbertSpace +) + +from sympy.core.numbers import oo +from sympy.core.symbol import Symbol +from sympy.printing.repr import srepr +from sympy.printing.str import sstr +from sympy.sets.sets import Interval + + +def test_hilbert_space(): + hs = HilbertSpace() + assert isinstance(hs, HilbertSpace) + assert sstr(hs) == 'H' + assert srepr(hs) == 'HilbertSpace()' + + +def test_complex_space(): + c1 = ComplexSpace(2) + assert isinstance(c1, ComplexSpace) + assert c1.dimension == 2 + assert sstr(c1) == 'C(2)' + assert srepr(c1) == 'ComplexSpace(Integer(2))' + + n = Symbol('n') + c2 = ComplexSpace(n) + assert isinstance(c2, ComplexSpace) + assert c2.dimension == n + assert sstr(c2) == 'C(n)' + assert srepr(c2) == "ComplexSpace(Symbol('n'))" + assert c2.subs(n, 2) == ComplexSpace(2) + + +def test_L2(): + b1 = L2(Interval(-oo, 1)) + assert isinstance(b1, L2) + assert b1.dimension is oo + assert b1.interval == Interval(-oo, 1) + + x = Symbol('x', real=True) + y = Symbol('y', real=True) + b2 = L2(Interval(x, y)) + assert b2.dimension is oo + assert b2.interval == Interval(x, y) + assert b2.subs(x, -1) == L2(Interval(-1, y)) + + +def test_fock_space(): + f1 = FockSpace() + f2 = FockSpace() + assert isinstance(f1, FockSpace) + assert f1.dimension is oo + assert f1 == f2 + + +def test_tensor_product(): + n = Symbol('n') + hs1 = ComplexSpace(2) + hs2 = ComplexSpace(n) + + h = hs1*hs2 + assert isinstance(h, TensorProductHilbertSpace) + assert h.dimension == 2*n + assert h.spaces == (hs1, hs2) + + h = hs2*hs2 + assert isinstance(h, TensorPowerHilbertSpace) + assert h.base == hs2 + assert h.exp == 2 + assert h.dimension == n**2 + + f = FockSpace() + h = hs1*hs2*f + assert h.dimension is oo + + +def test_tensor_power(): + n = Symbol('n') + hs1 = ComplexSpace(2) + hs2 = ComplexSpace(n) + + h = hs1**2 + assert isinstance(h, TensorPowerHilbertSpace) + assert h.base == hs1 + assert h.exp == 2 + assert h.dimension == 4 + + h = hs2**3 + assert isinstance(h, TensorPowerHilbertSpace) + assert h.base == hs2 + assert h.exp == 3 + assert h.dimension == n**3 + + +def test_direct_sum(): + n = Symbol('n') + hs1 = ComplexSpace(2) + hs2 = ComplexSpace(n) + + h = hs1 + hs2 + assert isinstance(h, DirectSumHilbertSpace) + assert h.dimension == 2 + n + assert h.spaces == (hs1, hs2) + + f = FockSpace() + h = hs1 + f + hs2 + assert h.dimension is oo + assert h.spaces == (hs1, f, hs2) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_identitysearch.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_identitysearch.py new file mode 100644 index 0000000000000000000000000000000000000000..8747b1f9d9630e699695f67734333f9d61581fb8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_identitysearch.py @@ -0,0 +1,492 @@ +from sympy.external import import_module +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.gate import (X, Y, Z, H, CNOT, + IdentityGate, CGate, PhaseGate, TGate) +from sympy.physics.quantum.identitysearch import (generate_gate_rules, + generate_equivalent_ids, GateIdentity, bfs_identity_search, + is_scalar_sparse_matrix, + is_scalar_nonsparse_matrix, is_degenerate, is_reducible) +from sympy.testing.pytest import skip + + +def create_gate_sequence(qubit=0): + gates = (X(qubit), Y(qubit), Z(qubit), H(qubit)) + return gates + + +def test_generate_gate_rules_1(): + # Test with tuples + (x, y, z, h) = create_gate_sequence() + ph = PhaseGate(0) + cgate_t = CGate(0, TGate(1)) + + assert generate_gate_rules((x,)) == {((x,), ())} + + gate_rules = {((x, x), ()), + ((x,), (x,))} + assert generate_gate_rules((x, x)) == gate_rules + + gate_rules = {((x, y, x), ()), + ((y, x, x), ()), + ((x, x, y), ()), + ((y, x), (x,)), + ((x, y), (x,)), + ((y,), (x, x))} + assert generate_gate_rules((x, y, x)) == gate_rules + + gate_rules = {((x, y, z), ()), ((y, z, x), ()), ((z, x, y), ()), + ((), (x, z, y)), ((), (y, x, z)), ((), (z, y, x)), + ((x,), (z, y)), ((y, z), (x,)), ((y,), (x, z)), + ((z, x), (y,)), ((z,), (y, x)), ((x, y), (z,))} + actual = generate_gate_rules((x, y, z)) + assert actual == gate_rules + + gate_rules = { + ((), (h, z, y, x)), ((), (x, h, z, y)), ((), (y, x, h, z)), + ((), (z, y, x, h)), ((h,), (z, y, x)), ((x,), (h, z, y)), + ((y,), (x, h, z)), ((z,), (y, x, h)), ((h, x), (z, y)), + ((x, y), (h, z)), ((y, z), (x, h)), ((z, h), (y, x)), + ((h, x, y), (z,)), ((x, y, z), (h,)), ((y, z, h), (x,)), + ((z, h, x), (y,)), ((h, x, y, z), ()), ((x, y, z, h), ()), + ((y, z, h, x), ()), ((z, h, x, y), ())} + actual = generate_gate_rules((x, y, z, h)) + assert actual == gate_rules + + gate_rules = {((), (cgate_t**(-1), ph**(-1), x)), + ((), (ph**(-1), x, cgate_t**(-1))), + ((), (x, cgate_t**(-1), ph**(-1))), + ((cgate_t,), (ph**(-1), x)), + ((ph,), (x, cgate_t**(-1))), + ((x,), (cgate_t**(-1), ph**(-1))), + ((cgate_t, x), (ph**(-1),)), + ((ph, cgate_t), (x,)), + ((x, ph), (cgate_t**(-1),)), + ((cgate_t, x, ph), ()), + ((ph, cgate_t, x), ()), + ((x, ph, cgate_t), ())} + actual = generate_gate_rules((x, ph, cgate_t)) + assert actual == gate_rules + + gate_rules = {(Integer(1), cgate_t**(-1)*ph**(-1)*x), + (Integer(1), ph**(-1)*x*cgate_t**(-1)), + (Integer(1), x*cgate_t**(-1)*ph**(-1)), + (cgate_t, ph**(-1)*x), + (ph, x*cgate_t**(-1)), + (x, cgate_t**(-1)*ph**(-1)), + (cgate_t*x, ph**(-1)), + (ph*cgate_t, x), + (x*ph, cgate_t**(-1)), + (cgate_t*x*ph, Integer(1)), + (ph*cgate_t*x, Integer(1)), + (x*ph*cgate_t, Integer(1))} + actual = generate_gate_rules((x, ph, cgate_t), return_as_muls=True) + assert actual == gate_rules + + +def test_generate_gate_rules_2(): + # Test with Muls + (x, y, z, h) = create_gate_sequence() + ph = PhaseGate(0) + cgate_t = CGate(0, TGate(1)) + + # Note: 1 (type int) is not the same as 1 (type One) + expected = {(x, Integer(1))} + assert generate_gate_rules((x,), return_as_muls=True) == expected + + expected = {(Integer(1), Integer(1))} + assert generate_gate_rules(x*x, return_as_muls=True) == expected + + expected = {((), ())} + assert generate_gate_rules(x*x, return_as_muls=False) == expected + + gate_rules = {(x*y*x, Integer(1)), + (y, Integer(1)), + (y*x, x), + (x*y, x)} + assert generate_gate_rules(x*y*x, return_as_muls=True) == gate_rules + + gate_rules = {(x*y*z, Integer(1)), + (y*z*x, Integer(1)), + (z*x*y, Integer(1)), + (Integer(1), x*z*y), + (Integer(1), y*x*z), + (Integer(1), z*y*x), + (x, z*y), + (y*z, x), + (y, x*z), + (z*x, y), + (z, y*x), + (x*y, z)} + actual = generate_gate_rules(x*y*z, return_as_muls=True) + assert actual == gate_rules + + gate_rules = {(Integer(1), h*z*y*x), + (Integer(1), x*h*z*y), + (Integer(1), y*x*h*z), + (Integer(1), z*y*x*h), + (h, z*y*x), (x, h*z*y), + (y, x*h*z), (z, y*x*h), + (h*x, z*y), (z*h, y*x), + (x*y, h*z), (y*z, x*h), + (h*x*y, z), (x*y*z, h), + (y*z*h, x), (z*h*x, y), + (h*x*y*z, Integer(1)), + (x*y*z*h, Integer(1)), + (y*z*h*x, Integer(1)), + (z*h*x*y, Integer(1))} + actual = generate_gate_rules(x*y*z*h, return_as_muls=True) + assert actual == gate_rules + + gate_rules = {(Integer(1), cgate_t**(-1)*ph**(-1)*x), + (Integer(1), ph**(-1)*x*cgate_t**(-1)), + (Integer(1), x*cgate_t**(-1)*ph**(-1)), + (cgate_t, ph**(-1)*x), + (ph, x*cgate_t**(-1)), + (x, cgate_t**(-1)*ph**(-1)), + (cgate_t*x, ph**(-1)), + (ph*cgate_t, x), + (x*ph, cgate_t**(-1)), + (cgate_t*x*ph, Integer(1)), + (ph*cgate_t*x, Integer(1)), + (x*ph*cgate_t, Integer(1))} + actual = generate_gate_rules(x*ph*cgate_t, return_as_muls=True) + assert actual == gate_rules + + gate_rules = {((), (cgate_t**(-1), ph**(-1), x)), + ((), (ph**(-1), x, cgate_t**(-1))), + ((), (x, cgate_t**(-1), ph**(-1))), + ((cgate_t,), (ph**(-1), x)), + ((ph,), (x, cgate_t**(-1))), + ((x,), (cgate_t**(-1), ph**(-1))), + ((cgate_t, x), (ph**(-1),)), + ((ph, cgate_t), (x,)), + ((x, ph), (cgate_t**(-1),)), + ((cgate_t, x, ph), ()), + ((ph, cgate_t, x), ()), + ((x, ph, cgate_t), ())} + actual = generate_gate_rules(x*ph*cgate_t) + assert actual == gate_rules + + +def test_generate_equivalent_ids_1(): + # Test with tuples + (x, y, z, h) = create_gate_sequence() + + assert generate_equivalent_ids((x,)) == {(x,)} + assert generate_equivalent_ids((x, x)) == {(x, x)} + assert generate_equivalent_ids((x, y)) == {(x, y), (y, x)} + + gate_seq = (x, y, z) + gate_ids = {(x, y, z), (y, z, x), (z, x, y), (z, y, x), + (y, x, z), (x, z, y)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + gate_ids = {Mul(x, y, z), Mul(y, z, x), Mul(z, x, y), + Mul(z, y, x), Mul(y, x, z), Mul(x, z, y)} + assert generate_equivalent_ids(gate_seq, return_as_muls=True) == gate_ids + + gate_seq = (x, y, z, h) + gate_ids = {(x, y, z, h), (y, z, h, x), + (h, x, y, z), (h, z, y, x), + (z, y, x, h), (y, x, h, z), + (z, h, x, y), (x, h, z, y)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + gate_seq = (x, y, x, y) + gate_ids = {(x, y, x, y), (y, x, y, x)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + cgate_y = CGate((1,), y) + gate_seq = (y, cgate_y, y, cgate_y) + gate_ids = {(y, cgate_y, y, cgate_y), (cgate_y, y, cgate_y, y)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + gate_seq = (cnot, h, cgate_z, h) + gate_ids = {(cnot, h, cgate_z, h), (h, cgate_z, h, cnot), + (h, cnot, h, cgate_z), (cgate_z, h, cnot, h)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + +def test_generate_equivalent_ids_2(): + # Test with Muls + (x, y, z, h) = create_gate_sequence() + + assert generate_equivalent_ids((x,), return_as_muls=True) == {x} + + gate_ids = {Integer(1)} + assert generate_equivalent_ids(x*x, return_as_muls=True) == gate_ids + + gate_ids = {x*y, y*x} + assert generate_equivalent_ids(x*y, return_as_muls=True) == gate_ids + + gate_ids = {(x, y), (y, x)} + assert generate_equivalent_ids(x*y) == gate_ids + + circuit = Mul(*(x, y, z)) + gate_ids = {x*y*z, y*z*x, z*x*y, z*y*x, + y*x*z, x*z*y} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + circuit = Mul(*(x, y, z, h)) + gate_ids = {x*y*z*h, y*z*h*x, + h*x*y*z, h*z*y*x, + z*y*x*h, y*x*h*z, + z*h*x*y, x*h*z*y} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + circuit = Mul(*(x, y, x, y)) + gate_ids = {x*y*x*y, y*x*y*x} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + cgate_y = CGate((1,), y) + circuit = Mul(*(y, cgate_y, y, cgate_y)) + gate_ids = {y*cgate_y*y*cgate_y, cgate_y*y*cgate_y*y} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + circuit = Mul(*(cnot, h, cgate_z, h)) + gate_ids = {cnot*h*cgate_z*h, h*cgate_z*h*cnot, + h*cnot*h*cgate_z, cgate_z*h*cnot*h} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + +def test_is_scalar_nonsparse_matrix(): + numqubits = 2 + id_only = False + + id_gate = (IdentityGate(1),) + actual = is_scalar_nonsparse_matrix(id_gate, numqubits, id_only) + assert actual is True + + x0 = X(0) + xx_circuit = (x0, x0) + actual = is_scalar_nonsparse_matrix(xx_circuit, numqubits, id_only) + assert actual is True + + x1 = X(1) + y1 = Y(1) + xy_circuit = (x1, y1) + actual = is_scalar_nonsparse_matrix(xy_circuit, numqubits, id_only) + assert actual is False + + z1 = Z(1) + xyz_circuit = (x1, y1, z1) + actual = is_scalar_nonsparse_matrix(xyz_circuit, numqubits, id_only) + assert actual is True + + cnot = CNOT(1, 0) + cnot_circuit = (cnot, cnot) + actual = is_scalar_nonsparse_matrix(cnot_circuit, numqubits, id_only) + assert actual is True + + h = H(0) + hh_circuit = (h, h) + actual = is_scalar_nonsparse_matrix(hh_circuit, numqubits, id_only) + assert actual is True + + h1 = H(1) + xhzh_circuit = (x1, h1, z1, h1) + actual = is_scalar_nonsparse_matrix(xhzh_circuit, numqubits, id_only) + assert actual is True + + id_only = True + actual = is_scalar_nonsparse_matrix(xhzh_circuit, numqubits, id_only) + assert actual is True + actual = is_scalar_nonsparse_matrix(xyz_circuit, numqubits, id_only) + assert actual is False + actual = is_scalar_nonsparse_matrix(cnot_circuit, numqubits, id_only) + assert actual is True + actual = is_scalar_nonsparse_matrix(hh_circuit, numqubits, id_only) + assert actual is True + + +def test_is_scalar_sparse_matrix(): + np = import_module('numpy') + if not np: + skip("numpy not installed.") + + scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + if not scipy: + skip("scipy not installed.") + + numqubits = 2 + id_only = False + + id_gate = (IdentityGate(1),) + assert is_scalar_sparse_matrix(id_gate, numqubits, id_only) is True + + x0 = X(0) + xx_circuit = (x0, x0) + assert is_scalar_sparse_matrix(xx_circuit, numqubits, id_only) is True + + x1 = X(1) + y1 = Y(1) + xy_circuit = (x1, y1) + assert is_scalar_sparse_matrix(xy_circuit, numqubits, id_only) is False + + z1 = Z(1) + xyz_circuit = (x1, y1, z1) + assert is_scalar_sparse_matrix(xyz_circuit, numqubits, id_only) is True + + cnot = CNOT(1, 0) + cnot_circuit = (cnot, cnot) + assert is_scalar_sparse_matrix(cnot_circuit, numqubits, id_only) is True + + h = H(0) + hh_circuit = (h, h) + assert is_scalar_sparse_matrix(hh_circuit, numqubits, id_only) is True + + # NOTE: + # The elements of the sparse matrix for the following circuit + # is actually 1.0000000000000002+0.0j. + h1 = H(1) + xhzh_circuit = (x1, h1, z1, h1) + assert is_scalar_sparse_matrix(xhzh_circuit, numqubits, id_only) is True + + id_only = True + assert is_scalar_sparse_matrix(xhzh_circuit, numqubits, id_only) is True + assert is_scalar_sparse_matrix(xyz_circuit, numqubits, id_only) is False + assert is_scalar_sparse_matrix(cnot_circuit, numqubits, id_only) is True + assert is_scalar_sparse_matrix(hh_circuit, numqubits, id_only) is True + + +def test_is_degenerate(): + (x, y, z, h) = create_gate_sequence() + + gate_id = GateIdentity(x, y, z) + ids = {gate_id} + + another_id = (z, y, x) + assert is_degenerate(ids, another_id) is True + + +def test_is_reducible(): + nqubits = 2 + (x, y, z, h) = create_gate_sequence() + + circuit = (x, y, y) + assert is_reducible(circuit, nqubits, 1, 3) is True + + circuit = (x, y, x) + assert is_reducible(circuit, nqubits, 1, 3) is False + + circuit = (x, y, y, x) + assert is_reducible(circuit, nqubits, 0, 4) is True + + circuit = (x, y, y, x) + assert is_reducible(circuit, nqubits, 1, 3) is True + + circuit = (x, y, z, y, y) + assert is_reducible(circuit, nqubits, 1, 5) is True + + +def test_bfs_identity_search(): + assert bfs_identity_search([], 1) == set() + + (x, y, z, h) = create_gate_sequence() + + gate_list = [x] + id_set = {GateIdentity(x, x)} + assert bfs_identity_search(gate_list, 1, max_depth=2) == id_set + + # Set should not contain degenerate quantum circuits + gate_list = [x, y, z] + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(x, y, z)} + assert bfs_identity_search(gate_list, 1) == id_set + + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(x, y, z), + GateIdentity(x, y, x, y), + GateIdentity(x, z, x, z), + GateIdentity(y, z, y, z)} + assert bfs_identity_search(gate_list, 1, max_depth=4) == id_set + assert bfs_identity_search(gate_list, 1, max_depth=5) == id_set + + gate_list = [x, y, z, h] + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(h, h), + GateIdentity(x, y, z), + GateIdentity(x, y, x, y), + GateIdentity(x, z, x, z), + GateIdentity(x, h, z, h), + GateIdentity(y, z, y, z), + GateIdentity(y, h, y, h)} + assert bfs_identity_search(gate_list, 1) == id_set + + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(h, h)} + assert id_set == bfs_identity_search(gate_list, 1, max_depth=3, + identity_only=True) + + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(h, h), + GateIdentity(x, y, z), + GateIdentity(x, y, x, y), + GateIdentity(x, z, x, z), + GateIdentity(x, h, z, h), + GateIdentity(y, z, y, z), + GateIdentity(y, h, y, h), + GateIdentity(x, y, h, x, h), + GateIdentity(x, z, h, y, h), + GateIdentity(y, z, h, z, h)} + assert bfs_identity_search(gate_list, 1, max_depth=5) == id_set + + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(h, h), + GateIdentity(x, h, z, h)} + assert id_set == bfs_identity_search(gate_list, 1, max_depth=4, + identity_only=True) + + cnot = CNOT(1, 0) + gate_list = [x, cnot] + id_set = {GateIdentity(x, x), + GateIdentity(cnot, cnot), + GateIdentity(x, cnot, x, cnot)} + assert bfs_identity_search(gate_list, 2, max_depth=4) == id_set + + cgate_x = CGate((1,), x) + gate_list = [x, cgate_x] + id_set = {GateIdentity(x, x), + GateIdentity(cgate_x, cgate_x), + GateIdentity(x, cgate_x, x, cgate_x)} + assert bfs_identity_search(gate_list, 2, max_depth=4) == id_set + + cgate_z = CGate((0,), Z(1)) + gate_list = [cnot, cgate_z, h] + id_set = {GateIdentity(h, h), + GateIdentity(cgate_z, cgate_z), + GateIdentity(cnot, cnot), + GateIdentity(cnot, h, cgate_z, h)} + assert bfs_identity_search(gate_list, 2, max_depth=4) == id_set + + s = PhaseGate(0) + t = TGate(0) + gate_list = [s, t] + id_set = {GateIdentity(s, s, s, s)} + assert bfs_identity_search(gate_list, 1, max_depth=4) == id_set + + +def test_bfs_identity_search_xfail(): + s = PhaseGate(0) + t = TGate(0) + gate_list = [Dagger(s), t] + id_set = {GateIdentity(Dagger(s), t, t)} + assert bfs_identity_search(gate_list, 1, max_depth=3) == id_set diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_innerproduct.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_innerproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..2632031f8a9a9ec65dfab6d834eb704a00b621d3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_innerproduct.py @@ -0,0 +1,71 @@ +from sympy.core.numbers import (I, Integer) + +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.state import Bra, Ket, StateBase + + +def test_innerproduct(): + k = Ket('k') + b = Bra('b') + ip = InnerProduct(b, k) + assert isinstance(ip, InnerProduct) + assert ip.bra == b + assert ip.ket == k + assert b*k == InnerProduct(b, k) + assert k*(b*k)*b == k*InnerProduct(b, k)*b + assert InnerProduct(b, k).subs(b, Dagger(k)) == Dagger(k)*k + + +def test_innerproduct_dagger(): + k = Ket('k') + b = Bra('b') + ip = b*k + assert Dagger(ip) == Dagger(k)*Dagger(b) + + +class FooState(StateBase): + pass + + +class FooKet(Ket, FooState): + + @classmethod + def dual_class(self): + return FooBra + + def _eval_innerproduct_FooBra(self, bra): + return Integer(1) + + def _eval_innerproduct_BarBra(self, bra): + return I + + +class FooBra(Bra, FooState): + @classmethod + def dual_class(self): + return FooKet + + +class BarState(StateBase): + pass + + +class BarKet(Ket, BarState): + @classmethod + def dual_class(self): + return BarBra + + +class BarBra(Bra, BarState): + @classmethod + def dual_class(self): + return BarKet + + +def test_doit(): + f = FooKet('foo') + b = BarBra('bar') + assert InnerProduct(b, f).doit() == I + assert InnerProduct(Dagger(f), Dagger(b)).doit() == -I + assert InnerProduct(Dagger(f), f).doit() == Integer(1) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_kind.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_kind.py new file mode 100644 index 0000000000000000000000000000000000000000..e50467db4c2d9bd8e19f4ea883c26bd5ac5bc8d8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_kind.py @@ -0,0 +1,75 @@ +"""Tests for sympy.physics.quantum.kind.""" + +from sympy.core.kind import NumberKind, UndefinedKind +from sympy.core.symbol import symbols + +from sympy.physics.quantum.kind import ( + OperatorKind, KetKind, BraKind +) +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.operator import Operator +from sympy.physics.quantum.state import Ket, Bra +from sympy.physics.quantum.tensorproduct import TensorProduct + +k = Ket('k') +b = Bra('k') +A = Operator('A') +B = Operator('B') +x, y, z = symbols('x y z', integer=True) + +def test_bra_ket(): + assert k.kind == KetKind + assert b.kind == BraKind + assert (b*k).kind == NumberKind # inner product + assert (x*k).kind == KetKind + assert (x*b).kind == BraKind + + +def test_operator_kind(): + assert A.kind == OperatorKind + assert (A*B).kind == OperatorKind + assert (x*A).kind == OperatorKind + assert (x*A*B).kind == OperatorKind + assert (x*k*b).kind == OperatorKind # outer product + + +def test_undefind_kind(): + # Because of limitations in the kind dispatcher API, we are currently + # unable to have OperatorKind*KetKind -> KetKind (and similar for bras). + assert (A*k).kind == UndefinedKind + assert (b*A).kind == UndefinedKind + assert (x*b*A*k).kind == UndefinedKind + + +def test_dagger_kind(): + assert Dagger(k).kind == BraKind + assert Dagger(b).kind == KetKind + assert Dagger(A).kind == OperatorKind + + +def test_commutator_kind(): + assert Commutator(A, B).kind == OperatorKind + assert Commutator(A, x*B).kind == OperatorKind + assert Commutator(x*A, B).kind == OperatorKind + assert Commutator(x*A, x*B).kind == OperatorKind + + +def test_anticommutator_kind(): + assert AntiCommutator(A, B).kind == OperatorKind + assert AntiCommutator(A, x*B).kind == OperatorKind + assert AntiCommutator(x*A, B).kind == OperatorKind + assert AntiCommutator(x*A, x*B).kind == OperatorKind + + +def test_tensorproduct_kind(): + assert TensorProduct(k,k).kind == KetKind + assert TensorProduct(b,b).kind == BraKind + assert TensorProduct(x*k,y*k).kind == KetKind + assert TensorProduct(x*b,y*b).kind == BraKind + assert TensorProduct(x*b*k, y*b*k).kind == NumberKind + assert TensorProduct(x*k*b, y*k*b).kind == OperatorKind + assert TensorProduct(A, B).kind == OperatorKind + assert TensorProduct(A, x*B).kind == OperatorKind + assert TensorProduct(x*A, B).kind == OperatorKind diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_matrixutils.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_matrixutils.py new file mode 100644 index 0000000000000000000000000000000000000000..4d4fa8a0a2a4374d200473fa03c68fc453262a4c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_matrixutils.py @@ -0,0 +1,136 @@ +from sympy.core.random import randint + +from sympy.core.numbers import Integer +from sympy.matrices.dense import (Matrix, ones, zeros) + +from sympy.physics.quantum.matrixutils import ( + to_sympy, to_numpy, to_scipy_sparse, matrix_tensor_product, + matrix_to_zero, matrix_zeros, numpy_ndarray, scipy_sparse_matrix +) + +from sympy.external import import_module +from sympy.testing.pytest import skip + +m = Matrix([[1, 2], [3, 4]]) + + +def test_sympy_to_sympy(): + assert to_sympy(m) == m + + +def test_matrix_to_zero(): + assert matrix_to_zero(m) == m + assert matrix_to_zero(Matrix([[0, 0], [0, 0]])) == Integer(0) + +np = import_module('numpy') + + +def test_to_numpy(): + if not np: + skip("numpy not installed.") + + result = np.array([[1, 2], [3, 4]], dtype='complex') + assert (to_numpy(m) == result).all() + + +def test_matrix_tensor_product(): + if not np: + skip("numpy not installed.") + + l1 = zeros(4) + for i in range(16): + l1[i] = 2**i + l2 = zeros(4) + for i in range(16): + l2[i] = i + l3 = zeros(2) + for i in range(4): + l3[i] = i + vec = Matrix([1, 2, 3]) + + #test for Matrix known 4x4 matrices + numpyl1 = np.array(l1.tolist()) + numpyl2 = np.array(l2.tolist()) + numpy_product = np.kron(numpyl1, numpyl2) + args = [l1, l2] + sympy_product = matrix_tensor_product(*args) + assert numpy_product.tolist() == sympy_product.tolist() + numpy_product = np.kron(numpyl2, numpyl1) + args = [l2, l1] + sympy_product = matrix_tensor_product(*args) + assert numpy_product.tolist() == sympy_product.tolist() + + #test for other known matrix of different dimensions + numpyl2 = np.array(l3.tolist()) + numpy_product = np.kron(numpyl1, numpyl2) + args = [l1, l3] + sympy_product = matrix_tensor_product(*args) + assert numpy_product.tolist() == sympy_product.tolist() + numpy_product = np.kron(numpyl2, numpyl1) + args = [l3, l1] + sympy_product = matrix_tensor_product(*args) + assert numpy_product.tolist() == sympy_product.tolist() + + #test for non square matrix + numpyl2 = np.array(vec.tolist()) + numpy_product = np.kron(numpyl1, numpyl2) + args = [l1, vec] + sympy_product = matrix_tensor_product(*args) + assert numpy_product.tolist() == sympy_product.tolist() + numpy_product = np.kron(numpyl2, numpyl1) + args = [vec, l1] + sympy_product = matrix_tensor_product(*args) + assert numpy_product.tolist() == sympy_product.tolist() + + #test for random matrix with random values that are floats + random_matrix1 = np.random.rand(randint(1, 5), randint(1, 5)) + random_matrix2 = np.random.rand(randint(1, 5), randint(1, 5)) + numpy_product = np.kron(random_matrix1, random_matrix2) + args = [Matrix(random_matrix1.tolist()), Matrix(random_matrix2.tolist())] + sympy_product = matrix_tensor_product(*args) + assert not (sympy_product - Matrix(numpy_product.tolist())).tolist() > \ + (ones(sympy_product.rows, sympy_product.cols)*epsilon).tolist() + + #test for three matrix kronecker + sympy_product = matrix_tensor_product(l1, vec, l2) + + numpy_product = np.kron(l1, np.kron(vec, l2)) + assert numpy_product.tolist() == sympy_product.tolist() + + +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + + +def test_to_scipy_sparse(): + if not np: + skip("numpy not installed.") + if not scipy: + skip("scipy not installed.") + else: + sparse = scipy.sparse + + result = sparse.csr_matrix([[1, 2], [3, 4]], dtype='complex') + assert np.linalg.norm((to_scipy_sparse(m) - result).todense()) == 0.0 + +epsilon = .000001 + + +def test_matrix_zeros_sympy(): + sym = matrix_zeros(4, 4, format='sympy') + assert isinstance(sym, Matrix) + +def test_matrix_zeros_numpy(): + if not np: + skip("numpy not installed.") + + num = matrix_zeros(4, 4, format='numpy') + assert isinstance(num, numpy_ndarray) + +def test_matrix_zeros_scipy(): + if not np: + skip("numpy not installed.") + if not scipy: + skip("scipy not installed.") + + sci = matrix_zeros(4, 4, format='scipy.sparse') + assert isinstance(sci, scipy_sparse_matrix) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_operator.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_operator.py new file mode 100644 index 0000000000000000000000000000000000000000..100cacd9a800f7c4435b93672ef77877a3a99e5e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_operator.py @@ -0,0 +1,269 @@ +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.mul import Mul +from sympy.core.numbers import (Integer, pi) +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.trigonometric import sin +from sympy.physics.quantum.qexpr import QExpr +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.hilbert import HilbertSpace +from sympy.physics.quantum.operator import (Operator, UnitaryOperator, + HermitianOperator, OuterProduct, + DifferentialOperator, + IdentityOperator) +from sympy.physics.quantum.state import Ket, Bra, Wavefunction +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.spin import JzKet, JzBra +from sympy.physics.quantum.trace import Tr +from sympy.matrices import eye + +from sympy.testing.pytest import warns_deprecated_sympy + + +class CustomKet(Ket): + @classmethod + def default_args(self): + return ("t",) + + +class CustomOp(HermitianOperator): + @classmethod + def default_args(self): + return ("T",) + +t_ket = CustomKet() +t_op = CustomOp() + + +def test_operator(): + A = Operator('A') + B = Operator('B') + C = Operator('C') + + assert isinstance(A, Operator) + assert isinstance(A, QExpr) + + assert A.label == (Symbol('A'),) + assert A.is_commutative is False + assert A.hilbert_space == HilbertSpace() + + assert A*B != B*A + + assert (A*(B + C)).expand() == A*B + A*C + assert ((A + B)**2).expand() == A**2 + A*B + B*A + B**2 + + assert t_op.label[0] == Symbol(t_op.default_args()[0]) + + assert Operator() == Operator("O") + with warns_deprecated_sympy(): + assert A*IdentityOperator() == A + + +def test_operator_inv(): + A = Operator('A') + assert A*A.inv() == 1 + assert A.inv()*A == 1 + + +def test_hermitian(): + H = HermitianOperator('H') + + assert isinstance(H, HermitianOperator) + assert isinstance(H, Operator) + + assert Dagger(H) == H + assert H.inv() != H + assert H.is_commutative is False + assert Dagger(H).is_commutative is False + + +def test_unitary(): + U = UnitaryOperator('U') + + assert isinstance(U, UnitaryOperator) + assert isinstance(U, Operator) + + assert U.inv() == Dagger(U) + assert U*Dagger(U) == 1 + assert Dagger(U)*U == 1 + assert U.is_commutative is False + assert Dagger(U).is_commutative is False + + +def test_identity(): + with warns_deprecated_sympy(): + I = IdentityOperator() + O = Operator('O') + x = Symbol("x") + three = sympify(3) + + assert isinstance(I, IdentityOperator) + assert isinstance(I, Operator) + + assert I * O == O + assert O * I == O + assert I * Dagger(O) == Dagger(O) + assert Dagger(O) * I == Dagger(O) + assert isinstance(I * I, IdentityOperator) + assert three * I == three + assert I * x == x + assert I.inv() == I + assert Dagger(I) == I + assert qapply(I * O) == O + assert qapply(O * I) == O + + for n in [2, 3, 5]: + assert represent(IdentityOperator(n)) == eye(n) + + +def test_outer_product(): + k = Ket('k') + b = Bra('b') + op = OuterProduct(k, b) + + assert isinstance(op, OuterProduct) + assert isinstance(op, Operator) + + assert op.ket == k + assert op.bra == b + assert op.label == (k, b) + assert op.is_commutative is False + + op = k*b + + assert isinstance(op, OuterProduct) + assert isinstance(op, Operator) + + assert op.ket == k + assert op.bra == b + assert op.label == (k, b) + assert op.is_commutative is False + + op = 2*k*b + + assert op == Mul(Integer(2), k, b) + + op = 2*(k*b) + + assert op == Mul(Integer(2), OuterProduct(k, b)) + + assert Dagger(k*b) == OuterProduct(Dagger(b), Dagger(k)) + assert Dagger(k*b).is_commutative is False + + #test the _eval_trace + assert Tr(OuterProduct(JzKet(1, 1), JzBra(1, 1))).doit() == 1 + + # test scaled kets and bras + assert OuterProduct(2 * k, b) == 2 * OuterProduct(k, b) + assert OuterProduct(k, 2 * b) == 2 * OuterProduct(k, b) + + # test sums of kets and bras + k1, k2 = Ket('k1'), Ket('k2') + b1, b2 = Bra('b1'), Bra('b2') + assert (OuterProduct(k1 + k2, b1) == + OuterProduct(k1, b1) + OuterProduct(k2, b1)) + assert (OuterProduct(k1, b1 + b2) == + OuterProduct(k1, b1) + OuterProduct(k1, b2)) + assert (OuterProduct(1 * k1 + 2 * k2, 3 * b1 + 4 * b2) == + 3 * OuterProduct(k1, b1) + + 4 * OuterProduct(k1, b2) + + 6 * OuterProduct(k2, b1) + + 8 * OuterProduct(k2, b2)) + + +def test_operator_dagger(): + A = Operator('A') + B = Operator('B') + assert Dagger(A*B) == Dagger(B)*Dagger(A) + assert Dagger(A + B) == Dagger(A) + Dagger(B) + assert Dagger(A**2) == Dagger(A)**2 + + +def test_differential_operator(): + x = Symbol('x') + f = Function('f') + d = DifferentialOperator(Derivative(f(x), x), f(x)) + g = Wavefunction(x**2, x) + assert qapply(d*g) == Wavefunction(2*x, x) + assert d.expr == Derivative(f(x), x) + assert d.function == f(x) + assert d.variables == (x,) + assert diff(d, x) == DifferentialOperator(Derivative(f(x), x, 2), f(x)) + + d = DifferentialOperator(Derivative(f(x), x, 2), f(x)) + g = Wavefunction(x**3, x) + assert qapply(d*g) == Wavefunction(6*x, x) + assert d.expr == Derivative(f(x), x, 2) + assert d.function == f(x) + assert d.variables == (x,) + assert diff(d, x) == DifferentialOperator(Derivative(f(x), x, 3), f(x)) + + d = DifferentialOperator(1/x*Derivative(f(x), x), f(x)) + assert d.expr == 1/x*Derivative(f(x), x) + assert d.function == f(x) + assert d.variables == (x,) + assert diff(d, x) == \ + DifferentialOperator(Derivative(1/x*Derivative(f(x), x), x), f(x)) + assert qapply(d*g) == Wavefunction(3*x, x) + + # 2D cartesian Laplacian + y = Symbol('y') + d = DifferentialOperator(Derivative(f(x, y), x, 2) + + Derivative(f(x, y), y, 2), f(x, y)) + w = Wavefunction(x**3*y**2 + y**3*x**2, x, y) + assert d.expr == Derivative(f(x, y), x, 2) + Derivative(f(x, y), y, 2) + assert d.function == f(x, y) + assert d.variables == (x, y) + assert diff(d, x) == \ + DifferentialOperator(Derivative(d.expr, x), f(x, y)) + assert diff(d, y) == \ + DifferentialOperator(Derivative(d.expr, y), f(x, y)) + assert qapply(d*w) == Wavefunction(2*x**3 + 6*x*y**2 + 6*x**2*y + 2*y**3, + x, y) + + # 2D polar Laplacian (th = theta) + r, th = symbols('r th') + d = DifferentialOperator(1/r*Derivative(r*Derivative(f(r, th), r), r) + + 1/(r**2)*Derivative(f(r, th), th, 2), f(r, th)) + w = Wavefunction(r**2*sin(th), r, (th, 0, pi)) + assert d.expr == \ + 1/r*Derivative(r*Derivative(f(r, th), r), r) + \ + 1/(r**2)*Derivative(f(r, th), th, 2) + assert d.function == f(r, th) + assert d.variables == (r, th) + assert diff(d, r) == \ + DifferentialOperator(Derivative(d.expr, r), f(r, th)) + assert diff(d, th) == \ + DifferentialOperator(Derivative(d.expr, th), f(r, th)) + assert qapply(d*w) == Wavefunction(3*sin(th), r, (th, 0, pi)) + + +def test_eval_power(): + from sympy.core import Pow + from sympy.core.expr import unchanged + O = Operator('O') + U = UnitaryOperator('U') + H = HermitianOperator('H') + assert O**-1 == O.inv() # same as doc test + assert U**-1 == U.inv() + assert H**-1 == H.inv() + x = symbols("x", commutative = True) + assert unchanged(Pow, H, x) # verify Pow(H,x)=="X^n" + assert H**x == Pow(H, x) + assert Pow(H,x) == Pow(H, x, evaluate=False) # Just check + from sympy.physics.quantum.gate import XGate + X = XGate(0) # is hermitian and unitary + assert unchanged(Pow, X, x) # verify Pow(X,x)=="X^x" + assert X**x == Pow(X, x) + assert Pow(X, x, evaluate=False) == Pow(X, x) # Just check + n = symbols("n", integer=True, even=True) + assert X**n == 1 + n = symbols("n", integer=True, odd=True) + assert X**n == X + n = symbols("n", integer=True) + assert unchanged(Pow, X, n) # verify Pow(X,n)=="X^n" + assert X**n == Pow(X, n) + assert Pow(X, n, evaluate=False)==Pow(X, n) # Just check + assert X**4 == 1 + assert X**7 == X diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_operatorordering.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_operatorordering.py new file mode 100644 index 0000000000000000000000000000000000000000..f5255d555d1582b694dfe4ed681d894136ea0b70 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_operatorordering.py @@ -0,0 +1,50 @@ +from sympy.physics.quantum import Dagger +from sympy.physics.quantum.boson import BosonOp +from sympy.physics.quantum.fermion import FermionOp +from sympy.physics.quantum.operatorordering import (normal_order, + normal_ordered_form) + + +def test_normal_order(): + a = BosonOp('a') + + c = FermionOp('c') + + assert normal_order(a * Dagger(a)) == Dagger(a) * a + assert normal_order(Dagger(a) * a) == Dagger(a) * a + assert normal_order(a * Dagger(a) ** 2) == Dagger(a) ** 2 * a + + assert normal_order(c * Dagger(c)) == - Dagger(c) * c + assert normal_order(Dagger(c) * c) == Dagger(c) * c + assert normal_order(c * Dagger(c) ** 2) == Dagger(c) ** 2 * c + + +def test_normal_ordered_form(): + a = BosonOp('a') + b = BosonOp('b') + + c = FermionOp('c') + d = FermionOp('d') + + assert normal_ordered_form(Dagger(a) * a) == Dagger(a) * a + assert normal_ordered_form(a * Dagger(a)) == 1 + Dagger(a) * a + assert normal_ordered_form(a ** 2 * Dagger(a)) == \ + 2 * a + Dagger(a) * a ** 2 + assert normal_ordered_form(a ** 3 * Dagger(a)) == \ + 3 * a ** 2 + Dagger(a) * a ** 3 + + assert normal_ordered_form(Dagger(c) * c) == Dagger(c) * c + assert normal_ordered_form(c * Dagger(c)) == 1 - Dagger(c) * c + assert normal_ordered_form(c ** 2 * Dagger(c)) == Dagger(c) * c ** 2 + assert normal_ordered_form(c ** 3 * Dagger(c)) == \ + c ** 2 - Dagger(c) * c ** 3 + + assert normal_ordered_form(a * Dagger(b), True) == Dagger(b) * a + assert normal_ordered_form(Dagger(a) * b, True) == Dagger(a) * b + assert normal_ordered_form(b * a, True) == a * b + assert normal_ordered_form(Dagger(b) * Dagger(a), True) == Dagger(a) * Dagger(b) + + assert normal_ordered_form(c * Dagger(d), True) == -Dagger(d) * c + assert normal_ordered_form(Dagger(c) * d, True) == Dagger(c) * d + assert normal_ordered_form(d * c, True) == -c * d + assert normal_ordered_form(Dagger(d) * Dagger(c), True) == -Dagger(c) * Dagger(d) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_operatorset.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_operatorset.py new file mode 100644 index 0000000000000000000000000000000000000000..fff038bb12a7e6aa100ac00b0e145dc323a77e4d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_operatorset.py @@ -0,0 +1,68 @@ +from sympy.core.singleton import S + +from sympy.physics.quantum.operatorset import ( + operators_to_state, state_to_operators +) + +from sympy.physics.quantum.cartesian import ( + XOp, XKet, PxOp, PxKet, XBra, PxBra +) + +from sympy.physics.quantum.state import Ket, Bra +from sympy.physics.quantum.operator import Operator +from sympy.physics.quantum.spin import ( + JxKet, JyKet, JzKet, JxBra, JyBra, JzBra, + JxOp, JyOp, JzOp, J2Op +) + +from sympy.testing.pytest import raises + + +def test_spin(): + assert operators_to_state({J2Op, JxOp}) == JxKet + assert operators_to_state({J2Op, JyOp}) == JyKet + assert operators_to_state({J2Op, JzOp}) == JzKet + assert operators_to_state({J2Op(), JxOp()}) == JxKet + assert operators_to_state({J2Op(), JyOp()}) == JyKet + assert operators_to_state({J2Op(), JzOp()}) == JzKet + + assert state_to_operators(JxKet) == {J2Op, JxOp} + assert state_to_operators(JyKet) == {J2Op, JyOp} + assert state_to_operators(JzKet) == {J2Op, JzOp} + assert state_to_operators(JxBra) == {J2Op, JxOp} + assert state_to_operators(JyBra) == {J2Op, JyOp} + assert state_to_operators(JzBra) == {J2Op, JzOp} + + assert state_to_operators(JxKet(S.Half, S.Half)) == {J2Op(), JxOp()} + assert state_to_operators(JyKet(S.Half, S.Half)) == {J2Op(), JyOp()} + assert state_to_operators(JzKet(S.Half, S.Half)) == {J2Op(), JzOp()} + assert state_to_operators(JxBra(S.Half, S.Half)) == {J2Op(), JxOp()} + assert state_to_operators(JyBra(S.Half, S.Half)) == {J2Op(), JyOp()} + assert state_to_operators(JzBra(S.Half, S.Half)) == {J2Op(), JzOp()} + + +def test_op_to_state(): + assert operators_to_state(XOp) == XKet() + assert operators_to_state(PxOp) == PxKet() + assert operators_to_state(Operator) == Ket() + + assert state_to_operators(operators_to_state(XOp("Q"))) == XOp("Q") + assert state_to_operators(operators_to_state(XOp())) == XOp() + + raises(NotImplementedError, lambda: operators_to_state(XKet)) + + +def test_state_to_op(): + assert state_to_operators(XKet) == XOp() + assert state_to_operators(PxKet) == PxOp() + assert state_to_operators(XBra) == XOp() + assert state_to_operators(PxBra) == PxOp() + assert state_to_operators(Ket) == Operator() + assert state_to_operators(Bra) == Operator() + + assert operators_to_state(state_to_operators(XKet("test"))) == XKet("test") + assert operators_to_state(state_to_operators(XBra("test"))) == XKet("test") + assert operators_to_state(state_to_operators(XKet())) == XKet() + assert operators_to_state(state_to_operators(XBra())) == XKet() + + raises(NotImplementedError, lambda: state_to_operators(XOp)) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_pauli.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_pauli.py new file mode 100644 index 0000000000000000000000000000000000000000..77bbed93ac5b4b49680be01aefa2f779b62fc7ee --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_pauli.py @@ -0,0 +1,159 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import I +from sympy.matrices.dense import Matrix +from sympy.printing.latex import latex +from sympy.physics.quantum import (Dagger, Commutator, AntiCommutator, qapply, + Operator, represent) +from sympy.physics.quantum.pauli import (SigmaOpBase, SigmaX, SigmaY, SigmaZ, + SigmaMinus, SigmaPlus, + qsimplify_pauli) +from sympy.physics.quantum.pauli import SigmaZKet, SigmaZBra +from sympy.testing.pytest import raises + + +sx, sy, sz = SigmaX(), SigmaY(), SigmaZ() +sx1, sy1, sz1 = SigmaX(1), SigmaY(1), SigmaZ(1) +sx2, sy2, sz2 = SigmaX(2), SigmaY(2), SigmaZ(2) + +sm, sp = SigmaMinus(), SigmaPlus() +sm1, sp1 = SigmaMinus(1), SigmaPlus(1) +A, B = Operator("A"), Operator("B") + + +def test_pauli_operators_types(): + + assert isinstance(sx, SigmaOpBase) and isinstance(sx, SigmaX) + assert isinstance(sy, SigmaOpBase) and isinstance(sy, SigmaY) + assert isinstance(sz, SigmaOpBase) and isinstance(sz, SigmaZ) + assert isinstance(sm, SigmaOpBase) and isinstance(sm, SigmaMinus) + assert isinstance(sp, SigmaOpBase) and isinstance(sp, SigmaPlus) + + +def test_pauli_operators_commutator(): + + assert Commutator(sx, sy).doit() == 2 * I * sz + assert Commutator(sy, sz).doit() == 2 * I * sx + assert Commutator(sz, sx).doit() == 2 * I * sy + + +def test_pauli_operators_commutator_with_labels(): + + assert Commutator(sx1, sy1).doit() == 2 * I * sz1 + assert Commutator(sy1, sz1).doit() == 2 * I * sx1 + assert Commutator(sz1, sx1).doit() == 2 * I * sy1 + + assert Commutator(sx2, sy2).doit() == 2 * I * sz2 + assert Commutator(sy2, sz2).doit() == 2 * I * sx2 + assert Commutator(sz2, sx2).doit() == 2 * I * sy2 + + assert Commutator(sx1, sy2).doit() == 0 + assert Commutator(sy1, sz2).doit() == 0 + assert Commutator(sz1, sx2).doit() == 0 + + +def test_pauli_operators_anticommutator(): + + assert AntiCommutator(sy, sz).doit() == 0 + assert AntiCommutator(sz, sx).doit() == 0 + assert AntiCommutator(sx, sm).doit() == 1 + assert AntiCommutator(sx, sp).doit() == 1 + + +def test_pauli_operators_adjoint(): + + assert Dagger(sx) == sx + assert Dagger(sy) == sy + assert Dagger(sz) == sz + + +def test_pauli_operators_adjoint_with_labels(): + + assert Dagger(sx1) == sx1 + assert Dagger(sy1) == sy1 + assert Dagger(sz1) == sz1 + + assert Dagger(sx1) != sx2 + assert Dagger(sy1) != sy2 + assert Dagger(sz1) != sz2 + + +def test_pauli_operators_multiplication(): + + assert qsimplify_pauli(sx * sx) == 1 + assert qsimplify_pauli(sy * sy) == 1 + assert qsimplify_pauli(sz * sz) == 1 + + assert qsimplify_pauli(sx * sy) == I * sz + assert qsimplify_pauli(sy * sz) == I * sx + assert qsimplify_pauli(sz * sx) == I * sy + + assert qsimplify_pauli(sy * sx) == - I * sz + assert qsimplify_pauli(sz * sy) == - I * sx + assert qsimplify_pauli(sx * sz) == - I * sy + + +def test_pauli_operators_multiplication_with_labels(): + + assert qsimplify_pauli(sx1 * sx1) == 1 + assert qsimplify_pauli(sy1 * sy1) == 1 + assert qsimplify_pauli(sz1 * sz1) == 1 + + assert isinstance(sx1 * sx2, Mul) + assert isinstance(sy1 * sy2, Mul) + assert isinstance(sz1 * sz2, Mul) + + assert qsimplify_pauli(sx1 * sy1 * sx2 * sy2) == - sz1 * sz2 + assert qsimplify_pauli(sy1 * sz1 * sz2 * sx2) == - sx1 * sy2 + + +def test_pauli_states(): + sx, sz = SigmaX(), SigmaZ() + + up = SigmaZKet(0) + down = SigmaZKet(1) + + assert qapply(sx * up) == down + assert qapply(sx * down) == up + assert qapply(sz * up) == up + assert qapply(sz * down) == - down + + up = SigmaZBra(0) + down = SigmaZBra(1) + + assert qapply(up * sx, dagger=True) == down + assert qapply(down * sx, dagger=True) == up + assert qapply(up * sz, dagger=True) == up + assert qapply(down * sz, dagger=True) == - down + + assert Dagger(SigmaZKet(0)) == SigmaZBra(0) + assert Dagger(SigmaZBra(1)) == SigmaZKet(1) + raises(ValueError, lambda: SigmaZBra(2)) + raises(ValueError, lambda: SigmaZKet(2)) + + +def test_use_name(): + assert sm.use_name is False + assert sm1.use_name is True + assert sx.use_name is False + assert sx1.use_name is True + + +def test_printing(): + assert latex(sx) == r'{\sigma_x}' + assert latex(sx1) == r'{\sigma_x^{(1)}}' + assert latex(sy) == r'{\sigma_y}' + assert latex(sy1) == r'{\sigma_y^{(1)}}' + assert latex(sz) == r'{\sigma_z}' + assert latex(sz1) == r'{\sigma_z^{(1)}}' + assert latex(sm) == r'{\sigma_-}' + assert latex(sm1) == r'{\sigma_-^{(1)}}' + assert latex(sp) == r'{\sigma_+}' + assert latex(sp1) == r'{\sigma_+^{(1)}}' + + +def test_represent(): + assert represent(sx) == Matrix([[0, 1], [1, 0]]) + assert represent(sy) == Matrix([[0, -I], [I, 0]]) + assert represent(sz) == Matrix([[1, 0], [0, -1]]) + assert represent(sm) == Matrix([[0, 0], [1, 0]]) + assert represent(sp) == Matrix([[0, 1], [0, 0]]) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_piab.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_piab.py new file mode 100644 index 0000000000000000000000000000000000000000..3a4c2540b3269593c74bdbae93bf72d131a94ed9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_piab.py @@ -0,0 +1,29 @@ +"""Tests for piab.py""" + +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.sets.sets import Interval +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.physics.quantum import L2, qapply, hbar, represent +from sympy.physics.quantum.piab import PIABHamiltonian, PIABKet, PIABBra, m, L + +i, j, n, x = symbols('i j n x') + + +def test_H(): + assert PIABHamiltonian('H').hilbert_space == \ + L2(Interval(S.NegativeInfinity, S.Infinity)) + assert qapply(PIABHamiltonian('H')*PIABKet(n)) == \ + (n**2*pi**2*hbar**2)/(2*m*L**2)*PIABKet(n) + + +def test_states(): + assert PIABKet(n).dual_class() == PIABBra + assert PIABKet(n).hilbert_space == \ + L2(Interval(S.NegativeInfinity, S.Infinity)) + assert represent(PIABKet(n)) == sqrt(2/L)*sin(n*pi*x/L) + assert (PIABBra(i)*PIABKet(j)).doit() == KroneckerDelta(i, j) + assert PIABBra(n).dual_class() == PIABKet diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_printing.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_printing.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4004cee2f9e57b1c9e435f13a6850b92d929b3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_printing.py @@ -0,0 +1,900 @@ +# -*- encoding: utf-8 -*- +""" +TODO: +* Address Issue 2251, printing of spin states +""" +from __future__ import annotations +from typing import Any + +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.cg import CG, Wigner3j, Wigner6j, Wigner9j +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.constants import hbar +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.gate import CGate, CNotGate, IdentityGate, UGate, XGate +from sympy.physics.quantum.hilbert import ComplexSpace, FockSpace, HilbertSpace, L2 +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.operator import Operator, OuterProduct, DifferentialOperator +from sympy.physics.quantum.qexpr import QExpr +from sympy.physics.quantum.qubit import Qubit, IntQubit +from sympy.physics.quantum.spin import Jz, J2, JzBra, JzBraCoupled, JzKet, JzKetCoupled, Rotation, WignerD +from sympy.physics.quantum.state import Bra, Ket, TimeDepBra, TimeDepKet +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.sho1d import RaisingOp + +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import oo +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.matrices.dense import Matrix +from sympy.sets.sets import Interval +from sympy.testing.pytest import XFAIL + +# Imports used in srepr strings +from sympy.physics.quantum.spin import JzOp + +from sympy.printing import srepr +from sympy.printing.pretty import pretty as xpretty +from sympy.printing.latex import latex + +MutableDenseMatrix = Matrix + + +ENV: dict[str, Any] = {} +exec('from sympy import *', ENV) +exec('from sympy.physics.quantum import *', ENV) +exec('from sympy.physics.quantum.cg import *', ENV) +exec('from sympy.physics.quantum.spin import *', ENV) +exec('from sympy.physics.quantum.hilbert import *', ENV) +exec('from sympy.physics.quantum.qubit import *', ENV) +exec('from sympy.physics.quantum.qexpr import *', ENV) +exec('from sympy.physics.quantum.gate import *', ENV) +exec('from sympy.physics.quantum.constants import *', ENV) + + +def sT(expr, string): + """ + sT := sreprTest + from sympy/printing/tests/test_repr.py + """ + assert srepr(expr) == string + assert eval(string, ENV) == expr + + +def pretty(expr): + """ASCII pretty-printing""" + return xpretty(expr, use_unicode=False, wrap_line=False) + + +def upretty(expr): + """Unicode pretty-printing""" + return xpretty(expr, use_unicode=True, wrap_line=False) + + +def test_anticommutator(): + A = Operator('A') + B = Operator('B') + ac = AntiCommutator(A, B) + ac_tall = AntiCommutator(A**2, B) + assert str(ac) == '{A,B}' + assert pretty(ac) == '{A,B}' + assert upretty(ac) == '{A,B}' + assert latex(ac) == r'\left\{A,B\right\}' + sT(ac, "AntiCommutator(Operator(Symbol('A')),Operator(Symbol('B')))") + assert str(ac_tall) == '{A**2,B}' + ascii_str = \ +"""\ +/ 2 \\\n\ +\n\ +\\ /\ +""" + ucode_str = \ +"""\ +⎧ 2 ⎫\n\ +⎨A ,B⎬\n\ +⎩ ⎭\ +""" + assert pretty(ac_tall) == ascii_str + assert upretty(ac_tall) == ucode_str + assert latex(ac_tall) == r'\left\{A^{2},B\right\}' + sT(ac_tall, "AntiCommutator(Pow(Operator(Symbol('A')), Integer(2)),Operator(Symbol('B')))") + + +def test_cg(): + cg = CG(1, 2, 3, 4, 5, 6) + wigner3j = Wigner3j(1, 2, 3, 4, 5, 6) + wigner6j = Wigner6j(1, 2, 3, 4, 5, 6) + wigner9j = Wigner9j(1, 2, 3, 4, 5, 6, 7, 8, 9) + assert str(cg) == 'CG(1, 2, 3, 4, 5, 6)' + ascii_str = \ +"""\ + 5,6 \n\ +C \n\ + 1,2,3,4\ +""" + ucode_str = \ +"""\ + 5,6 \n\ +C \n\ + 1,2,3,4\ +""" + assert pretty(cg) == ascii_str + assert upretty(cg) == ucode_str + assert latex(cg) == 'C^{5,6}_{1,2,3,4}' + assert latex(cg ** 2) == R'\left(C^{5,6}_{1,2,3,4}\right)^{2}' + sT(cg, "CG(Integer(1), Integer(2), Integer(3), Integer(4), Integer(5), Integer(6))") + assert str(wigner3j) == 'Wigner3j(1, 2, 3, 4, 5, 6)' + ascii_str = \ +"""\ +/1 3 5\\\n\ +| |\n\ +\\2 4 6/\ +""" + ucode_str = \ +"""\ +⎛1 3 5⎞\n\ +⎜ ⎟\n\ +⎝2 4 6⎠\ +""" + assert pretty(wigner3j) == ascii_str + assert upretty(wigner3j) == ucode_str + assert latex(wigner3j) == \ + r'\left(\begin{array}{ccc} 1 & 3 & 5 \\ 2 & 4 & 6 \end{array}\right)' + sT(wigner3j, "Wigner3j(Integer(1), Integer(2), Integer(3), Integer(4), Integer(5), Integer(6))") + assert str(wigner6j) == 'Wigner6j(1, 2, 3, 4, 5, 6)' + ascii_str = \ +"""\ +/1 2 3\\\n\ +< >\n\ +\\4 5 6/\ +""" + ucode_str = \ +"""\ +⎧1 2 3⎫\n\ +⎨ ⎬\n\ +⎩4 5 6⎭\ +""" + assert pretty(wigner6j) == ascii_str + assert upretty(wigner6j) == ucode_str + assert latex(wigner6j) == \ + r'\left\{\begin{array}{ccc} 1 & 2 & 3 \\ 4 & 5 & 6 \end{array}\right\}' + sT(wigner6j, "Wigner6j(Integer(1), Integer(2), Integer(3), Integer(4), Integer(5), Integer(6))") + assert str(wigner9j) == 'Wigner9j(1, 2, 3, 4, 5, 6, 7, 8, 9)' + ascii_str = \ +"""\ +/1 2 3\\\n\ +| |\n\ +<4 5 6>\n\ +| |\n\ +\\7 8 9/\ +""" + ucode_str = \ +"""\ +⎧1 2 3⎫\n\ +⎪ ⎪\n\ +⎨4 5 6⎬\n\ +⎪ ⎪\n\ +⎩7 8 9⎭\ +""" + assert pretty(wigner9j) == ascii_str + assert upretty(wigner9j) == ucode_str + assert latex(wigner9j) == \ + r'\left\{\begin{array}{ccc} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{array}\right\}' + sT(wigner9j, "Wigner9j(Integer(1), Integer(2), Integer(3), Integer(4), Integer(5), Integer(6), Integer(7), Integer(8), Integer(9))") + + +def test_commutator(): + A = Operator('A') + B = Operator('B') + c = Commutator(A, B) + c_tall = Commutator(A**2, B) + assert str(c) == '[A,B]' + assert pretty(c) == '[A,B]' + assert upretty(c) == '[A,B]' + assert latex(c) == r'\left[A,B\right]' + sT(c, "Commutator(Operator(Symbol('A')),Operator(Symbol('B')))") + assert str(c_tall) == '[A**2,B]' + ascii_str = \ +"""\ +[ 2 ]\n\ +[A ,B]\ +""" + ucode_str = \ +"""\ +⎡ 2 ⎤\n\ +⎣A ,B⎦\ +""" + assert pretty(c_tall) == ascii_str + assert upretty(c_tall) == ucode_str + assert latex(c_tall) == r'\left[A^{2},B\right]' + sT(c_tall, "Commutator(Pow(Operator(Symbol('A')), Integer(2)),Operator(Symbol('B')))") + + +def test_constants(): + assert str(hbar) == 'hbar' + assert pretty(hbar) == 'hbar' + assert upretty(hbar) == 'ℏ' + assert latex(hbar) == r'\hbar' + sT(hbar, "HBar()") + + +def test_dagger(): + x = symbols('x', commutative=False) + expr = Dagger(x) + assert str(expr) == 'Dagger(x)' + ascii_str = \ +"""\ + +\n\ +x \ +""" + ucode_str = \ +"""\ + †\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + assert latex(expr) == r'x^{\dagger}' + sT(expr, "Dagger(Symbol('x', commutative=False))") + + +@XFAIL +def test_gate_failing(): + a, b, c, d = symbols('a,b,c,d') + uMat = Matrix([[a, b], [c, d]]) + g = UGate((0,), uMat) + assert str(g) == 'U(0)' + + +def test_gate(): + a, b, c, d = symbols('a,b,c,d') + uMat = Matrix([[a, b], [c, d]]) + q = Qubit(1, 0, 1, 0, 1) + g1 = IdentityGate(2) + g2 = CGate((3, 0), XGate(1)) + g3 = CNotGate(1, 0) + g4 = UGate((0,), uMat) + assert str(g1) == '1(2)' + assert pretty(g1) == '1 \n 2' + assert upretty(g1) == '1 \n 2' + assert latex(g1) == r'1_{2}' + sT(g1, "IdentityGate(Integer(2))") + assert str(g1*q) == '1(2)*|10101>' + ascii_str = \ +"""\ +1 *|10101>\n\ + 2 \ +""" + ucode_str = \ +"""\ +1 ⋅❘10101⟩\n\ + 2 \ +""" + assert pretty(g1*q) == ascii_str + assert upretty(g1*q) == ucode_str + assert latex(g1*q) == r'1_{2} {\left|10101\right\rangle }' + sT(g1*q, "Mul(IdentityGate(Integer(2)), Qubit(Integer(1),Integer(0),Integer(1),Integer(0),Integer(1)))") + assert str(g2) == 'C((3,0),X(1))' + ascii_str = \ +"""\ +C /X \\\n\ + 3,0\\ 1/\ +""" + ucode_str = \ +"""\ +C ⎛X ⎞\n\ + 3,0⎝ 1⎠\ +""" + assert pretty(g2) == ascii_str + assert upretty(g2) == ucode_str + assert latex(g2) == r'C_{3,0}{\left(X_{1}\right)}' + sT(g2, "CGate(Tuple(Integer(3), Integer(0)),XGate(Integer(1)))") + assert str(g3) == 'CNOT(1,0)' + ascii_str = \ +"""\ +CNOT \n\ + 1,0\ +""" + ucode_str = \ +"""\ +CNOT \n\ + 1,0\ +""" + assert pretty(g3) == ascii_str + assert upretty(g3) == ucode_str + assert latex(g3) == r'\text{CNOT}_{1,0}' + sT(g3, "CNotGate(Integer(1),Integer(0))") + ascii_str = \ +"""\ +U \n\ + 0\ +""" + ucode_str = \ +"""\ +U \n\ + 0\ +""" + assert str(g4) == \ +"""\ +U((0,),Matrix([\n\ +[a, b],\n\ +[c, d]]))\ +""" + assert pretty(g4) == ascii_str + assert upretty(g4) == ucode_str + assert latex(g4) == r'U_{0}' + sT(g4, "UGate(Tuple(Integer(0)),ImmutableDenseMatrix([[Symbol('a'), Symbol('b')], [Symbol('c'), Symbol('d')]]))") + + +def test_hilbert(): + h1 = HilbertSpace() + h2 = ComplexSpace(2) + h3 = FockSpace() + h4 = L2(Interval(0, oo)) + assert str(h1) == 'H' + assert pretty(h1) == 'H' + assert upretty(h1) == 'H' + assert latex(h1) == r'\mathcal{H}' + sT(h1, "HilbertSpace()") + assert str(h2) == 'C(2)' + ascii_str = \ +"""\ + 2\n\ +C \ +""" + ucode_str = \ +"""\ + 2\n\ +C \ +""" + assert pretty(h2) == ascii_str + assert upretty(h2) == ucode_str + assert latex(h2) == r'\mathcal{C}^{2}' + sT(h2, "ComplexSpace(Integer(2))") + assert str(h3) == 'F' + assert pretty(h3) == 'F' + assert upretty(h3) == 'F' + assert latex(h3) == r'\mathcal{F}' + sT(h3, "FockSpace()") + assert str(h4) == 'L2(Interval(0, oo))' + ascii_str = \ +"""\ + 2\n\ +L \ +""" + ucode_str = \ +"""\ + 2\n\ +L \ +""" + assert pretty(h4) == ascii_str + assert upretty(h4) == ucode_str + assert latex(h4) == r'{\mathcal{L}^2}\left( \left[0, \infty\right) \right)' + sT(h4, "L2(Interval(Integer(0), oo, false, true))") + assert str(h1 + h2) == 'H+C(2)' + ascii_str = \ +"""\ + 2\n\ +H + C \ +""" + ucode_str = \ +"""\ + 2\n\ +H ⊕ C \ +""" + assert pretty(h1 + h2) == ascii_str + assert upretty(h1 + h2) == ucode_str + assert latex(h1 + h2) + sT(h1 + h2, "DirectSumHilbertSpace(HilbertSpace(),ComplexSpace(Integer(2)))") + assert str(h1*h2) == "H*C(2)" + ascii_str = \ +"""\ + 2\n\ +H x C \ +""" + ucode_str = \ +"""\ + 2\n\ +H ⨂ C \ +""" + assert pretty(h1*h2) == ascii_str + assert upretty(h1*h2) == ucode_str + assert latex(h1*h2) + sT(h1*h2, + "TensorProductHilbertSpace(HilbertSpace(),ComplexSpace(Integer(2)))") + assert str(h1**2) == 'H**2' + ascii_str = \ +"""\ + x2\n\ +H \ +""" + ucode_str = \ +"""\ + ⨂2\n\ +H \ +""" + assert pretty(h1**2) == ascii_str + assert upretty(h1**2) == ucode_str + assert latex(h1**2) == r'{\mathcal{H}}^{\otimes 2}' + sT(h1**2, "TensorPowerHilbertSpace(HilbertSpace(),Integer(2))") + + +def test_innerproduct(): + x = symbols('x') + ip1 = InnerProduct(Bra(), Ket()) + ip2 = InnerProduct(TimeDepBra(), TimeDepKet()) + ip3 = InnerProduct(JzBra(1, 1), JzKet(1, 1)) + ip4 = InnerProduct(JzBraCoupled(1, 1, (1, 1)), JzKetCoupled(1, 1, (1, 1))) + ip_tall1 = InnerProduct(Bra(x/2), Ket(x/2)) + ip_tall2 = InnerProduct(Bra(x), Ket(x/2)) + ip_tall3 = InnerProduct(Bra(x/2), Ket(x)) + assert str(ip1) == '' + assert pretty(ip1) == '' + assert upretty(ip1) == '⟨ψ❘ψ⟩' + assert latex( + ip1) == r'\left\langle \psi \right. {\left|\psi\right\rangle }' + sT(ip1, "InnerProduct(Bra(Symbol('psi')),Ket(Symbol('psi')))") + assert str(ip2) == '' + assert pretty(ip2) == '' + assert upretty(ip2) == '⟨ψ;t❘ψ;t⟩' + assert latex(ip2) == \ + r'\left\langle \psi;t \right. {\left|\psi;t\right\rangle }' + sT(ip2, "InnerProduct(TimeDepBra(Symbol('psi'),Symbol('t')),TimeDepKet(Symbol('psi'),Symbol('t')))") + assert str(ip3) == "<1,1|1,1>" + assert pretty(ip3) == '<1,1|1,1>' + assert upretty(ip3) == '⟨1,1❘1,1⟩' + assert latex(ip3) == r'\left\langle 1,1 \right. {\left|1,1\right\rangle }' + sT(ip3, "InnerProduct(JzBra(Integer(1),Integer(1)),JzKet(Integer(1),Integer(1)))") + assert str(ip4) == "<1,1,j1=1,j2=1|1,1,j1=1,j2=1>" + assert pretty(ip4) == '<1,1,j1=1,j2=1|1,1,j1=1,j2=1>' + assert upretty(ip4) == '⟨1,1,j₁=1,j₂=1❘1,1,j₁=1,j₂=1⟩' + assert latex(ip4) == \ + r'\left\langle 1,1,j_{1}=1,j_{2}=1 \right. {\left|1,1,j_{1}=1,j_{2}=1\right\rangle }' + sT(ip4, "InnerProduct(JzBraCoupled(Integer(1),Integer(1),Tuple(Integer(1), Integer(1)),Tuple(Tuple(Integer(1), Integer(2), Integer(1)))),JzKetCoupled(Integer(1),Integer(1),Tuple(Integer(1), Integer(1)),Tuple(Tuple(Integer(1), Integer(2), Integer(1)))))") + assert str(ip_tall1) == '' + ascii_str = \ +"""\ + / | \\ \n\ +/ x|x \\\n\ +\\ -|- /\n\ + \\2|2/ \ +""" + ucode_str = \ +"""\ + ╱ │ ╲ \n\ +╱ x│x ╲\n\ +╲ ─│─ ╱\n\ + ╲2│2╱ \ +""" + assert pretty(ip_tall1) == ascii_str + assert upretty(ip_tall1) == ucode_str + assert latex(ip_tall1) == \ + r'\left\langle \frac{x}{2} \right. {\left|\frac{x}{2}\right\rangle }' + sT(ip_tall1, "InnerProduct(Bra(Mul(Rational(1, 2), Symbol('x'))),Ket(Mul(Rational(1, 2), Symbol('x'))))") + assert str(ip_tall2) == '' + ascii_str = \ +"""\ + / | \\ \n\ +/ |x \\\n\ +\\ x|- /\n\ + \\ |2/ \ +""" + ucode_str = \ +"""\ + ╱ │ ╲ \n\ +╱ │x ╲\n\ +╲ x│─ ╱\n\ + ╲ │2╱ \ +""" + assert pretty(ip_tall2) == ascii_str + assert upretty(ip_tall2) == ucode_str + assert latex(ip_tall2) == \ + r'\left\langle x \right. {\left|\frac{x}{2}\right\rangle }' + sT(ip_tall2, + "InnerProduct(Bra(Symbol('x')),Ket(Mul(Rational(1, 2), Symbol('x'))))") + assert str(ip_tall3) == '' + ascii_str = \ +"""\ + / | \\ \n\ +/ x| \\\n\ +\\ -|x /\n\ + \\2| / \ +""" + ucode_str = \ +"""\ + ╱ │ ╲ \n\ +╱ x│ ╲\n\ +╲ ─│x ╱\n\ + ╲2│ ╱ \ +""" + assert pretty(ip_tall3) == ascii_str + assert upretty(ip_tall3) == ucode_str + assert latex(ip_tall3) == \ + r'\left\langle \frac{x}{2} \right. {\left|x\right\rangle }' + sT(ip_tall3, + "InnerProduct(Bra(Mul(Rational(1, 2), Symbol('x'))),Ket(Symbol('x')))") + + +def test_operator(): + a = Operator('A') + b = Operator('B', Symbol('t'), S.Half) + inv = a.inv() + f = Function('f') + x = symbols('x') + d = DifferentialOperator(Derivative(f(x), x), f(x)) + op = OuterProduct(Ket(), Bra()) + assert str(a) == 'A' + assert pretty(a) == 'A' + assert upretty(a) == 'A' + assert latex(a) == 'A' + sT(a, "Operator(Symbol('A'))") + assert str(inv) == 'A**(-1)' + ascii_str = \ +"""\ + -1\n\ +A \ +""" + ucode_str = \ +"""\ + -1\n\ +A \ +""" + assert pretty(inv) == ascii_str + assert upretty(inv) == ucode_str + assert latex(inv) == r'A^{-1}' + sT(inv, "Pow(Operator(Symbol('A')), Integer(-1))") + assert str(d) == 'DifferentialOperator(Derivative(f(x), x),f(x))' + ascii_str = \ +"""\ + /d \\\n\ +DifferentialOperator|--(f(x)),f(x)|\n\ + \\dx /\ +""" + ucode_str = \ +"""\ + ⎛d ⎞\n\ +DifferentialOperator⎜──(f(x)),f(x)⎟\n\ + ⎝dx ⎠\ +""" + assert pretty(d) == ascii_str + assert upretty(d) == ucode_str + assert latex(d) == \ + r'DifferentialOperator\left(\frac{d}{d x} f{\left(x \right)},f{\left(x \right)}\right)' + sT(d, "DifferentialOperator(Derivative(Function('f')(Symbol('x')), Tuple(Symbol('x'), Integer(1))),Function('f')(Symbol('x')))") + assert str(b) == 'Operator(B,t,1/2)' + assert pretty(b) == 'Operator(B,t,1/2)' + assert upretty(b) == 'Operator(B,t,1/2)' + assert latex(b) == r'Operator\left(B,t,\frac{1}{2}\right)' + sT(b, "Operator(Symbol('B'),Symbol('t'),Rational(1, 2))") + assert str(op) == '|psi>' + assert pretty(q1) == '|0101>' + assert upretty(q1) == '❘0101⟩' + assert latex(q1) == r'{\left|0101\right\rangle }' + sT(q1, "Qubit(Integer(0),Integer(1),Integer(0),Integer(1))") + assert str(q2) == '|8>' + assert pretty(q2) == '|8>' + assert upretty(q2) == '❘8⟩' + assert latex(q2) == r'{\left|8\right\rangle }' + sT(q2, "IntQubit(8)") + + +def test_spin(): + lz = JzOp('L') + ket = JzKet(1, 0) + bra = JzBra(1, 0) + cket = JzKetCoupled(1, 0, (1, 2)) + cbra = JzBraCoupled(1, 0, (1, 2)) + cket_big = JzKetCoupled(1, 0, (1, 2, 3)) + cbra_big = JzBraCoupled(1, 0, (1, 2, 3)) + rot = Rotation(1, 2, 3) + bigd = WignerD(1, 2, 3, 4, 5, 6) + smalld = WignerD(1, 2, 3, 0, 4, 0) + assert str(lz) == 'Lz' + ascii_str = \ +"""\ +L \n\ + z\ +""" + ucode_str = \ +"""\ +L \n\ + z\ +""" + assert pretty(lz) == ascii_str + assert upretty(lz) == ucode_str + assert latex(lz) == 'L_z' + sT(lz, "JzOp(Symbol('L'))") + assert str(J2) == 'J2' + ascii_str = \ +"""\ + 2\n\ +J \ +""" + ucode_str = \ +"""\ + 2\n\ +J \ +""" + assert pretty(J2) == ascii_str + assert upretty(J2) == ucode_str + assert latex(J2) == r'J^2' + sT(J2, "J2Op(Symbol('J'))") + assert str(Jz) == 'Jz' + ascii_str = \ +"""\ +J \n\ + z\ +""" + ucode_str = \ +"""\ +J \n\ + z\ +""" + assert pretty(Jz) == ascii_str + assert upretty(Jz) == ucode_str + assert latex(Jz) == 'J_z' + sT(Jz, "JzOp(Symbol('J'))") + assert str(ket) == '|1,0>' + assert pretty(ket) == '|1,0>' + assert upretty(ket) == '❘1,0⟩' + assert latex(ket) == r'{\left|1,0\right\rangle }' + sT(ket, "JzKet(Integer(1),Integer(0))") + assert str(bra) == '<1,0|' + assert pretty(bra) == '<1,0|' + assert upretty(bra) == '⟨1,0❘' + assert latex(bra) == r'{\left\langle 1,0\right|}' + sT(bra, "JzBra(Integer(1),Integer(0))") + assert str(cket) == '|1,0,j1=1,j2=2>' + assert pretty(cket) == '|1,0,j1=1,j2=2>' + assert upretty(cket) == '❘1,0,j₁=1,j₂=2⟩' + assert latex(cket) == r'{\left|1,0,j_{1}=1,j_{2}=2\right\rangle }' + sT(cket, "JzKetCoupled(Integer(1),Integer(0),Tuple(Integer(1), Integer(2)),Tuple(Tuple(Integer(1), Integer(2), Integer(1))))") + assert str(cbra) == '<1,0,j1=1,j2=2|' + assert pretty(cbra) == '<1,0,j1=1,j2=2|' + assert upretty(cbra) == '⟨1,0,j₁=1,j₂=2❘' + assert latex(cbra) == r'{\left\langle 1,0,j_{1}=1,j_{2}=2\right|}' + sT(cbra, "JzBraCoupled(Integer(1),Integer(0),Tuple(Integer(1), Integer(2)),Tuple(Tuple(Integer(1), Integer(2), Integer(1))))") + assert str(cket_big) == '|1,0,j1=1,j2=2,j3=3,j(1,2)=3>' + # TODO: Fix non-unicode pretty printing + # i.e. j1,2 -> j(1,2) + assert pretty(cket_big) == '|1,0,j1=1,j2=2,j3=3,j1,2=3>' + assert upretty(cket_big) == '❘1,0,j₁=1,j₂=2,j₃=3,j₁,₂=3⟩' + assert latex(cket_big) == \ + r'{\left|1,0,j_{1}=1,j_{2}=2,j_{3}=3,j_{1,2}=3\right\rangle }' + sT(cket_big, "JzKetCoupled(Integer(1),Integer(0),Tuple(Integer(1), Integer(2), Integer(3)),Tuple(Tuple(Integer(1), Integer(2), Integer(3)), Tuple(Integer(1), Integer(3), Integer(1))))") + assert str(cbra_big) == '<1,0,j1=1,j2=2,j3=3,j(1,2)=3|' + assert pretty(cbra_big) == '<1,0,j1=1,j2=2,j3=3,j1,2=3|' + assert upretty(cbra_big) == '⟨1,0,j₁=1,j₂=2,j₃=3,j₁,₂=3❘' + assert latex(cbra_big) == \ + r'{\left\langle 1,0,j_{1}=1,j_{2}=2,j_{3}=3,j_{1,2}=3\right|}' + sT(cbra_big, "JzBraCoupled(Integer(1),Integer(0),Tuple(Integer(1), Integer(2), Integer(3)),Tuple(Tuple(Integer(1), Integer(2), Integer(3)), Tuple(Integer(1), Integer(3), Integer(1))))") + assert str(rot) == 'R(1,2,3)' + assert pretty(rot) == 'R (1,2,3)' + assert upretty(rot) == 'ℛ (1,2,3)' + assert latex(rot) == r'\mathcal{R}\left(1,2,3\right)' + sT(rot, "Rotation(Integer(1),Integer(2),Integer(3))") + assert str(bigd) == 'WignerD(1, 2, 3, 4, 5, 6)' + ascii_str = \ +"""\ + 1 \n\ +D (4,5,6)\n\ + 2,3 \ +""" + ucode_str = \ +"""\ + 1 \n\ +D (4,5,6)\n\ + 2,3 \ +""" + assert pretty(bigd) == ascii_str + assert upretty(bigd) == ucode_str + assert latex(bigd) == r'D^{1}_{2,3}\left(4,5,6\right)' + sT(bigd, "WignerD(Integer(1), Integer(2), Integer(3), Integer(4), Integer(5), Integer(6))") + assert str(smalld) == 'WignerD(1, 2, 3, 0, 4, 0)' + ascii_str = \ +"""\ + 1 \n\ +d (4)\n\ + 2,3 \ +""" + ucode_str = \ +"""\ + 1 \n\ +d (4)\n\ + 2,3 \ +""" + assert pretty(smalld) == ascii_str + assert upretty(smalld) == ucode_str + assert latex(smalld) == r'd^{1}_{2,3}\left(4\right)' + sT(smalld, "WignerD(Integer(1), Integer(2), Integer(3), Integer(0), Integer(4), Integer(0))") + + +def test_state(): + x = symbols('x') + bra = Bra() + ket = Ket() + bra_tall = Bra(x/2) + ket_tall = Ket(x/2) + tbra = TimeDepBra() + tket = TimeDepKet() + assert str(bra) == '' + assert pretty(ket) == '|psi>' + assert upretty(ket) == '❘ψ⟩' + assert latex(ket) == r'{\left|\psi\right\rangle }' + sT(ket, "Ket(Symbol('psi'))") + assert str(bra_tall) == '' + ascii_str = \ +"""\ +| \\ \n\ +|x \\\n\ +|- /\n\ +|2/ \ +""" + ucode_str = \ +"""\ +│ ╲ \n\ +│x ╲\n\ +│─ ╱\n\ +│2╱ \ +""" + assert pretty(ket_tall) == ascii_str + assert upretty(ket_tall) == ucode_str + assert latex(ket_tall) == r'{\left|\frac{x}{2}\right\rangle }' + sT(ket_tall, "Ket(Mul(Rational(1, 2), Symbol('x')))") + assert str(tbra) == '' + assert pretty(tket) == '|psi;t>' + assert upretty(tket) == '❘ψ;t⟩' + assert latex(tket) == r'{\left|\psi;t\right\rangle }' + sT(tket, "TimeDepKet(Symbol('psi'),Symbol('t'))") + + +def test_tensorproduct(): + tp = TensorProduct(JzKet(1, 1), JzKet(1, 0)) + assert str(tp) == '|1,1>x|1,0>' + assert pretty(tp) == '|1,1>x |1,0>' + assert upretty(tp) == '❘1,1⟩⨂ ❘1,0⟩' + assert latex(tp) == \ + r'{{\left|1,1\right\rangle }}\otimes {{\left|1,0\right\rangle }}' + sT(tp, "TensorProduct(JzKet(Integer(1),Integer(1)), JzKet(Integer(1),Integer(0)))") + + +def test_big_expr(): + f = Function('f') + x = symbols('x') + e1 = Dagger(AntiCommutator(Operator('A') + Operator('B'), Pow(DifferentialOperator(Derivative(f(x), x), f(x)), 3))*TensorProduct(Jz**2, Operator('A') + Operator('B')))*(JzBra(1, 0) + JzBra(1, 1))*(JzKet(0, 0) + JzKet(1, -1)) + e2 = Commutator(Jz**2, Operator('A') + Operator('B'))*AntiCommutator(Dagger(Operator('C')*Operator('D')), Operator('E').inv()**2)*Dagger(Commutator(Jz, J2)) + e3 = Wigner3j(1, 2, 3, 4, 5, 6)*TensorProduct(Commutator(Operator('A') + Dagger(Operator('B')), Operator('C') + Operator('D')), Jz - J2)*Dagger(OuterProduct(Dagger(JzBra(1, 1)), JzBra(1, 0)))*TensorProduct(JzKetCoupled(1, 1, (1, 1)) + JzKetCoupled(1, 0, (1, 1)), JzKetCoupled(1, -1, (1, 1))) + e4 = (ComplexSpace(1)*ComplexSpace(2) + FockSpace()**2)*(L2(Interval( + 0, oo)) + HilbertSpace()) + assert str(e1) == '(Jz**2)x(Dagger(A) + Dagger(B))*{Dagger(DifferentialOperator(Derivative(f(x), x),f(x)))**3,Dagger(A) + Dagger(B)}*(<1,0| + <1,1|)*(|0,0> + |1,-1>)' + ascii_str = \ +"""\ + / 3 \\ \n\ + |/ +\\ | \n\ + 2 / + +\\ <| /d \\ | + +> \n\ +/J \\ x \\A + B /*||DifferentialOperator|--(f(x)),f(x)| | ,A + B |*(<1,0| + <1,1|)*(|0,0> + |1,-1>)\n\ +\\ z/ \\\\ \\dx / / / \ +""" + ucode_str = \ +"""\ + ⎧ 3 ⎫ \n\ + ⎪⎛ †⎞ ⎪ \n\ + 2 ⎛ † †⎞ ⎨⎜ ⎛d ⎞ ⎟ † †⎬ \n\ +⎛J ⎞ ⨂ ⎝A + B ⎠⋅⎪⎜DifferentialOperator⎜──(f(x)),f(x)⎟ ⎟ ,A + B ⎪⋅(⟨1,0❘ + ⟨1,1❘)⋅(❘0,0⟩ + ❘1,-1⟩)\n\ +⎝ z⎠ ⎩⎝ ⎝dx ⎠ ⎠ ⎭ \ +""" + assert pretty(e1) == ascii_str + assert upretty(e1) == ucode_str + assert latex(e1) == \ + r'{J_z^{2}}\otimes \left({A^{\dagger} + B^{\dagger}}\right) \left\{\left(DifferentialOperator\left(\frac{d}{d x} f{\left(x \right)},f{\left(x \right)}\right)^{\dagger}\right)^{3},A^{\dagger} + B^{\dagger}\right\} \left({\left\langle 1,0\right|} + {\left\langle 1,1\right|}\right) \left({\left|0,0\right\rangle } + {\left|1,-1\right\rangle }\right)' + sT(e1, "Mul(TensorProduct(Pow(JzOp(Symbol('J')), Integer(2)), Add(Dagger(Operator(Symbol('A'))), Dagger(Operator(Symbol('B'))))), AntiCommutator(Pow(Dagger(DifferentialOperator(Derivative(Function('f')(Symbol('x')), Tuple(Symbol('x'), Integer(1))),Function('f')(Symbol('x')))), Integer(3)),Add(Dagger(Operator(Symbol('A'))), Dagger(Operator(Symbol('B'))))), Add(JzBra(Integer(1),Integer(0)), JzBra(Integer(1),Integer(1))), Add(JzKet(Integer(0),Integer(0)), JzKet(Integer(1),Integer(-1))))") + assert str(e2) == '[Jz**2,A + B]*{E**(-2),Dagger(D)*Dagger(C)}*[J2,Jz]' + ascii_str = \ +"""\ +[ 2 ] / -2 + +\\ [ 2 ]\n\ +[/J \\ ,A + B]**[J ,J ]\n\ +[\\ z/ ] \\ / [ z]\ +""" + ucode_str = \ +"""\ +⎡ 2 ⎤ ⎧ -2 † †⎫ ⎡ 2 ⎤\n\ +⎢⎛J ⎞ ,A + B⎥⋅⎨E ,D ⋅C ⎬⋅⎢J ,J ⎥\n\ +⎣⎝ z⎠ ⎦ ⎩ ⎭ ⎣ z⎦\ +""" + assert pretty(e2) == ascii_str + assert upretty(e2) == ucode_str + assert latex(e2) == \ + r'\left[J_z^{2},A + B\right] \left\{E^{-2},D^{\dagger} C^{\dagger}\right\} \left[J^2,J_z\right]' + sT(e2, "Mul(Commutator(Pow(JzOp(Symbol('J')), Integer(2)),Add(Operator(Symbol('A')), Operator(Symbol('B')))), AntiCommutator(Pow(Operator(Symbol('E')), Integer(-2)),Mul(Dagger(Operator(Symbol('D'))), Dagger(Operator(Symbol('C'))))), Commutator(J2Op(Symbol('J')),JzOp(Symbol('J'))))") + assert str(e3) == \ + "Wigner3j(1, 2, 3, 4, 5, 6)*[Dagger(B) + A,C + D]x(-J2 + Jz)*|1,0><1,1|*(|1,0,j1=1,j2=1> + |1,1,j1=1,j2=1>)x|1,-1,j1=1,j2=1>" + ascii_str = \ +"""\ + [ + ] / 2 \\ \n\ +/1 3 5\\*[B + A,C + D]x |- J + J |*|1,0><1,1|*(|1,0,j1=1,j2=1> + |1,1,j1=1,j2=1>)x |1,-1,j1=1,j2=1>\n\ +| | \\ z/ \n\ +\\2 4 6/ \ +""" + ucode_str = \ +"""\ + ⎡ † ⎤ ⎛ 2 ⎞ \n\ +⎛1 3 5⎞⋅⎣B + A,C + D⎦⨂ ⎜- J + J ⎟⋅❘1,0⟩⟨1,1❘⋅(❘1,0,j₁=1,j₂=1⟩ + ❘1,1,j₁=1,j₂=1⟩)⨂ ❘1,-1,j₁=1,j₂=1⟩\n\ +⎜ ⎟ ⎝ z⎠ \n\ +⎝2 4 6⎠ \ +""" + assert pretty(e3) == ascii_str + assert upretty(e3) == ucode_str + assert latex(e3) == \ + r'\left(\begin{array}{ccc} 1 & 3 & 5 \\ 2 & 4 & 6 \end{array}\right) {\left[B^{\dagger} + A,C + D\right]}\otimes \left({- J^2 + J_z}\right) {\left|1,0\right\rangle }{\left\langle 1,1\right|} \left({{\left|1,0,j_{1}=1,j_{2}=1\right\rangle } + {\left|1,1,j_{1}=1,j_{2}=1\right\rangle }}\right)\otimes {{\left|1,-1,j_{1}=1,j_{2}=1\right\rangle }}' + sT(e3, "Mul(Wigner3j(Integer(1), Integer(2), Integer(3), Integer(4), Integer(5), Integer(6)), TensorProduct(Commutator(Add(Dagger(Operator(Symbol('B'))), Operator(Symbol('A'))),Add(Operator(Symbol('C')), Operator(Symbol('D')))), Add(Mul(Integer(-1), J2Op(Symbol('J'))), JzOp(Symbol('J')))), OuterProduct(JzKet(Integer(1),Integer(0)),JzBra(Integer(1),Integer(1))), TensorProduct(Add(JzKetCoupled(Integer(1),Integer(0),Tuple(Integer(1), Integer(1)),Tuple(Tuple(Integer(1), Integer(2), Integer(1)))), JzKetCoupled(Integer(1),Integer(1),Tuple(Integer(1), Integer(1)),Tuple(Tuple(Integer(1), Integer(2), Integer(1))))), JzKetCoupled(Integer(1),Integer(-1),Tuple(Integer(1), Integer(1)),Tuple(Tuple(Integer(1), Integer(2), Integer(1))))))") + assert str(e4) == '(C(1)*C(2)+F**2)*(L2(Interval(0, oo))+H)' + ascii_str = \ +"""\ +// 1 2\\ x2\\ / 2 \\\n\ +\\\\C x C / + F / x \\L + H/\ +""" + ucode_str = \ +"""\ +⎛⎛ 1 2⎞ ⨂2⎞ ⎛ 2 ⎞\n\ +⎝⎝C ⨂ C ⎠ ⊕ F ⎠ ⨂ ⎝L ⊕ H⎠\ +""" + assert pretty(e4) == ascii_str + assert upretty(e4) == ucode_str + assert latex(e4) == \ + r'\left(\left(\mathcal{C}^{1}\otimes \mathcal{C}^{2}\right)\oplus {\mathcal{F}}^{\otimes 2}\right)\otimes \left({\mathcal{L}^2}\left( \left[0, \infty\right) \right)\oplus \mathcal{H}\right)' + sT(e4, "TensorProductHilbertSpace((DirectSumHilbertSpace(TensorProductHilbertSpace(ComplexSpace(Integer(1)),ComplexSpace(Integer(2))),TensorPowerHilbertSpace(FockSpace(),Integer(2)))),(DirectSumHilbertSpace(L2(Interval(Integer(0), oo, false, true)),HilbertSpace())))") + + +def _test_sho1d(): + ad = RaisingOp('a') + assert pretty(ad) == ' \N{DAGGER}\na ' + assert latex(ad) == 'a^{\\dagger}' diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qapply.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qapply.py new file mode 100644 index 0000000000000000000000000000000000000000..be6f68d9869df84bc25bd0ebdfcde9ff49adc508 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qapply.py @@ -0,0 +1,152 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt + +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.constants import hbar +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.gate import H, XGate, IdentityGate +from sympy.physics.quantum.operator import Operator, IdentityOperator +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.spin import Jx, Jy, Jz, Jplus, Jminus, J2, JzKet +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.state import Ket +from sympy.physics.quantum.density import Density +from sympy.physics.quantum.qubit import Qubit, QubitBra +from sympy.physics.quantum.boson import BosonOp, BosonFockKet, BosonFockBra +from sympy.testing.pytest import warns_deprecated_sympy + + +j, jp, m, mp = symbols("j j' m m'") + +z = JzKet(1, 0) +po = JzKet(1, 1) +mo = JzKet(1, -1) + +A = Operator('A') + + +class Foo(Operator): + def _apply_operator_JzKet(self, ket, **options): + return ket + + +def test_basic(): + assert qapply(Jz*po) == hbar*po + assert qapply(Jx*z) == hbar*po/sqrt(2) + hbar*mo/sqrt(2) + assert qapply((Jplus + Jminus)*z/sqrt(2)) == hbar*po + hbar*mo + assert qapply(Jz*(po + mo)) == hbar*po - hbar*mo + assert qapply(Jz*po + Jz*mo) == hbar*po - hbar*mo + assert qapply(Jminus*Jminus*po) == 2*hbar**2*mo + assert qapply(Jplus**2*mo) == 2*hbar**2*po + assert qapply(Jplus**2*Jminus**2*po) == 4*hbar**4*po + + +def test_extra(): + extra = z.dual*A*z + assert qapply(Jz*po*extra) == hbar*po*extra + assert qapply(Jx*z*extra) == (hbar*po/sqrt(2) + hbar*mo/sqrt(2))*extra + assert qapply( + (Jplus + Jminus)*z/sqrt(2)*extra) == hbar*po*extra + hbar*mo*extra + assert qapply(Jz*(po + mo)*extra) == hbar*po*extra - hbar*mo*extra + assert qapply(Jz*po*extra + Jz*mo*extra) == hbar*po*extra - hbar*mo*extra + assert qapply(Jminus*Jminus*po*extra) == 2*hbar**2*mo*extra + assert qapply(Jplus**2*mo*extra) == 2*hbar**2*po*extra + assert qapply(Jplus**2*Jminus**2*po*extra) == 4*hbar**4*po*extra + + +def test_innerproduct(): + assert qapply(po.dual*Jz*po, ip_doit=False) == hbar*(po.dual*po) + assert qapply(po.dual*Jz*po) == hbar + + +def test_zero(): + assert qapply(0) == 0 + assert qapply(Integer(0)) == 0 + + +def test_commutator(): + assert qapply(Commutator(Jx, Jy)*Jz*po) == I*hbar**3*po + assert qapply(Commutator(J2, Jz)*Jz*po) == 0 + assert qapply(Commutator(Jz, Foo('F'))*po) == 0 + assert qapply(Commutator(Foo('F'), Jz)*po) == 0 + + +def test_anticommutator(): + assert qapply(AntiCommutator(Jz, Foo('F'))*po) == 2*hbar*po + assert qapply(AntiCommutator(Foo('F'), Jz)*po) == 2*hbar*po + + +def test_outerproduct(): + e = Jz*(mo*po.dual)*Jz*po + assert qapply(e) == -hbar**2*mo + assert qapply(e, ip_doit=False) == -hbar**2*(po.dual*po)*mo + assert qapply(e).doit() == -hbar**2*mo + + +def test_tensorproduct(): + a = BosonOp("a") + b = BosonOp("b") + ket1 = TensorProduct(BosonFockKet(1), BosonFockKet(2)) + ket2 = TensorProduct(BosonFockKet(0), BosonFockKet(0)) + ket3 = TensorProduct(BosonFockKet(0), BosonFockKet(2)) + bra1 = TensorProduct(BosonFockBra(0), BosonFockBra(0)) + bra2 = TensorProduct(BosonFockBra(1), BosonFockBra(2)) + assert qapply(TensorProduct(a, b ** 2) * ket1) == sqrt(2) * ket2 + assert qapply(TensorProduct(a, Dagger(b) * b) * ket1) == 2 * ket3 + assert qapply(bra1 * TensorProduct(a, b * b), + dagger=True) == sqrt(2) * bra2 + assert qapply(bra2 * ket1).doit() == S.One + assert qapply(TensorProduct(a, b * b) * ket1) == sqrt(2) * ket2 + assert qapply(Dagger(TensorProduct(a, b * b) * ket1), + dagger=True) == sqrt(2) * Dagger(ket2) + + +def test_dagger(): + lhs = Dagger(Qubit(0))*Dagger(H(0)) + rhs = Dagger(Qubit(1))/sqrt(2) + Dagger(Qubit(0))/sqrt(2) + assert qapply(lhs, dagger=True) == rhs + + +def test_issue_6073(): + x, y = symbols('x y', commutative=False) + A = Ket(x, y) + B = Operator('B') + assert qapply(A) == A + assert qapply(A.dual*B) == A.dual*B + + +def test_density(): + d = Density([Jz*mo, 0.5], [Jz*po, 0.5]) + assert qapply(d) == Density([-hbar*mo, 0.5], [hbar*po, 0.5]) + + +def test_issue3044(): + expr1 = TensorProduct(Jz*JzKet(S(2),S.NegativeOne)/sqrt(2), Jz*JzKet(S.Half,S.Half)) + result = Mul(S.NegativeOne, Rational(1, 4), 2**S.Half, hbar**2) + result *= TensorProduct(JzKet(2,-1), JzKet(S.Half,S.Half)) + assert qapply(expr1) == result + + +# Issue 24158: Tests whether qapply incorrectly evaluates some ket*op as op*ket +def test_issue24158_ket_times_op(): + P = BosonFockKet(0) * BosonOp("a") # undefined term + # Does lhs._apply_operator_BosonOp(rhs) still evaluate ket*op as op*ket? + assert qapply(P) == P # qapply(P) -> BosonOp("a")*BosonFockKet(0) = 0 before fix + P = Qubit(1) * XGate(0) # undefined term + # Does rhs._apply_operator_Qubit(lhs) still evaluate ket*op as op*ket? + assert qapply(P) == P # qapply(P) -> Qubit(0) before fix + P1 = Mul(QubitBra(0), Mul(QubitBra(0), Qubit(0)), XGate(0)) # legal expr <0| * (<1|*|1>) * X + assert qapply(P1) == QubitBra(0) * XGate(0) # qapply(P1) -> 0 before fix + P1 = qapply(P1, dagger = True) # unsatisfactorily -> <0|*X(0), expect <1| since dagger=True + assert qapply(P1, dagger = True) == QubitBra(1) # qapply(P1, dagger=True) -> 0 before fix + P2 = QubitBra(0) * (QubitBra(0) * Qubit(0)) * XGate(0) # 'forgot' to set brackets + P2 = qapply(P2, dagger = True) # unsatisfactorily -> <0|*X(0), expect <1| since dagger=True + assert P2 == QubitBra(1) # qapply(P1) -> 0 before fix + # Pull Request 24237: IdentityOperator from the right without dagger=True option + with warns_deprecated_sympy(): + assert qapply(QubitBra(1)*IdentityOperator()) == QubitBra(1) + assert qapply(IdentityGate(0)*(Qubit(0) + Qubit(1))) == Qubit(0) + Qubit(1) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qasm.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qasm.py new file mode 100644 index 0000000000000000000000000000000000000000..81c7ee8523e732d336211f7739a6e8f7fbab5220 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qasm.py @@ -0,0 +1,89 @@ +from sympy.physics.quantum.qasm import Qasm, flip_index, trim,\ + get_index, nonblank, fullsplit, fixcommand, stripquotes, read_qasm +from sympy.physics.quantum.gate import X, Z, H, S, T +from sympy.physics.quantum.gate import CNOT, SWAP, CPHASE, CGate, CGateS +from sympy.physics.quantum.circuitplot import Mz + +def test_qasm_readqasm(): + qasm_lines = """\ + qubit q_0 + qubit q_1 + h q_0 + cnot q_0,q_1 + """ + q = read_qasm(qasm_lines) + assert q.get_circuit() == CNOT(1,0)*H(1) + +def test_qasm_ex1(): + q = Qasm('qubit q0', 'qubit q1', 'h q0', 'cnot q0,q1') + assert q.get_circuit() == CNOT(1,0)*H(1) + +def test_qasm_ex1_methodcalls(): + q = Qasm() + q.qubit('q_0') + q.qubit('q_1') + q.h('q_0') + q.cnot('q_0', 'q_1') + assert q.get_circuit() == CNOT(1,0)*H(1) + +def test_qasm_swap(): + q = Qasm('qubit q0', 'qubit q1', 'cnot q0,q1', 'cnot q1,q0', 'cnot q0,q1') + assert q.get_circuit() == CNOT(1,0)*CNOT(0,1)*CNOT(1,0) + + +def test_qasm_ex2(): + q = Qasm('qubit q_0', 'qubit q_1', 'qubit q_2', 'h q_1', + 'cnot q_1,q_2', 'cnot q_0,q_1', 'h q_0', + 'measure q_1', 'measure q_0', + 'c-x q_1,q_2', 'c-z q_0,q_2') + assert q.get_circuit() == CGate(2,Z(0))*CGate(1,X(0))*Mz(2)*Mz(1)*H(2)*CNOT(2,1)*CNOT(1,0)*H(1) + +def test_qasm_1q(): + for symbol, gate in [('x', X), ('z', Z), ('h', H), ('s', S), ('t', T), ('measure', Mz)]: + q = Qasm('qubit q_0', '%s q_0' % symbol) + assert q.get_circuit() == gate(0) + +def test_qasm_2q(): + for symbol, gate in [('cnot', CNOT), ('swap', SWAP), ('cphase', CPHASE)]: + q = Qasm('qubit q_0', 'qubit q_1', '%s q_0,q_1' % symbol) + assert q.get_circuit() == gate(1,0) + +def test_qasm_3q(): + q = Qasm('qubit q0', 'qubit q1', 'qubit q2', 'toffoli q2,q1,q0') + assert q.get_circuit() == CGateS((0,1),X(2)) + +def test_qasm_flip_index(): + assert flip_index(0, 2) == 1 + assert flip_index(1, 2) == 0 + +def test_qasm_trim(): + assert trim('nothing happens here') == 'nothing happens here' + assert trim("Something #happens here") == "Something " + +def test_qasm_get_index(): + assert get_index('q0', ['q0', 'q1']) == 1 + assert get_index('q1', ['q0', 'q1']) == 0 + +def test_qasm_nonblank(): + assert list(nonblank('abcd')) == list('abcd') + assert list(nonblank('abc ')) == list('abc') + +def test_qasm_fullsplit(): + assert fullsplit('g q0,q1,q2, q3') == ('g', ['q0', 'q1', 'q2', 'q3']) + +def test_qasm_fixcommand(): + assert fixcommand('foo') == 'foo' + assert fixcommand('def') == 'qdef' + +def test_qasm_stripquotes(): + assert stripquotes("'S'") == 'S' + assert stripquotes('"S"') == 'S' + assert stripquotes('S') == 'S' + +def test_qasm_qdef(): + # weaker test condition (str) since we don't have access to the actual class + q = Qasm("def Q,0,Q",'qubit q0','Q q0') + assert str(q.get_circuit()) == 'Q(0)' + + q = Qasm("def CQ,1,Q", 'qubit q0', 'qubit q1', 'CQ q0,q1') + assert str(q.get_circuit()) == 'C((1),Q(0))' diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qexpr.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..c01817935a0f977e44c8e0dc29746e070b2cb693 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qexpr.py @@ -0,0 +1,64 @@ +from sympy.core.numbers import Integer +from sympy.core.symbol import Symbol +from sympy.concrete import Sum +from sympy.physics.quantum.qexpr import QExpr, _qsympify_sequence +from sympy.physics.quantum.hilbert import HilbertSpace +from sympy.core.containers import Tuple + +x = Symbol('x') +y = Symbol('y') +n = Symbol('n', integer=True) +m = Symbol('m', integer=True) + + +def test_qexpr_new(): + q = QExpr(0) + assert q.label == (0,) + assert q.hilbert_space == HilbertSpace() + assert q.is_commutative is False + + q = QExpr(0, 1) + assert q.label == (Integer(0), Integer(1)) + + q = QExpr._new_rawargs(HilbertSpace(), Integer(0), Integer(1)) + assert q.label == (Integer(0), Integer(1)) + assert q.hilbert_space == HilbertSpace() + + +def test_qexpr_commutative(): + q1 = QExpr(x) + q2 = QExpr(y) + assert q1.is_commutative is False + assert q2.is_commutative is False + assert q1*q2 != q2*q1 + + q = QExpr._new_rawargs(Integer(0), Integer(1), HilbertSpace()) + assert q.is_commutative is False + + +def test_qexpr_free_symbols(): + q1 = QExpr(x, y) + assert q1.free_symbols == {x, y} + + +def test_qexpr_sum(): + q1 = Sum(QExpr(n), (n,0,2)) + assert q1.doit() == QExpr(0) + QExpr(1) + QExpr(2) + + q2 = Sum(QExpr(n, m), (n, 0, 2), (m, 0, 2)) + assert q2.doit() == QExpr(0, 0) + QExpr(0, 1) + QExpr(0, 2) + \ + QExpr(1, 0) + QExpr(1, 1) + QExpr(1, 2) + \ + QExpr(2, 0) + QExpr(2, 1) + QExpr(2, 2) + + +def test_qexpr_subs(): + q1 = QExpr(x, y) + assert q1.subs(x, y) == QExpr(y, y) + assert q1.subs({x: 1, y: 2}) == QExpr(1, 2) + + +def test_qsympify(): + assert _qsympify_sequence([[1, 2], [1, 3]]) == (Tuple(1, 2), Tuple(1, 3)) + assert _qsympify_sequence(([1, 2, [3, 4, [2, ]], 1], 3)) == \ + (Tuple(1, 2, Tuple(3, 4, Tuple(2,)), 1), 3) + assert _qsympify_sequence((1,)) == (1,) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qft.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qft.py new file mode 100644 index 0000000000000000000000000000000000000000..832f0194702b2031cfdff9d061a259e85476a88d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qft.py @@ -0,0 +1,52 @@ +from sympy.core.numbers import (I, pi) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix + +from sympy.physics.quantum.qft import QFT, IQFT, RkGate +from sympy.physics.quantum.gate import (ZGate, SwapGate, HadamardGate, CGate, + PhaseGate, TGate) +from sympy.physics.quantum.qubit import Qubit +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.represent import represent + +from sympy.functions.elementary.complexes import sign + + +def test_RkGate(): + x = Symbol('x') + assert RkGate(1, x).k == x + assert RkGate(1, x).targets == (1,) + assert RkGate(1, 1) == ZGate(1) + assert RkGate(2, 2) == PhaseGate(2) + assert RkGate(3, 3) == TGate(3) + + assert represent( + RkGate(0, x), nqubits=1) == Matrix([[1, 0], [0, exp(sign(x)*2*pi*I/(2**abs(x)))]]) + + +def test_quantum_fourier(): + assert QFT(0, 3).decompose() == \ + SwapGate(0, 2)*HadamardGate(0)*CGate((0,), PhaseGate(1)) * \ + HadamardGate(1)*CGate((0,), TGate(2))*CGate((1,), PhaseGate(2)) * \ + HadamardGate(2) + + assert IQFT(0, 3).decompose() == \ + HadamardGate(2)*CGate((1,), RkGate(2, -2))*CGate((0,), RkGate(2, -3)) * \ + HadamardGate(1)*CGate((0,), RkGate(1, -2))*HadamardGate(0)*SwapGate(0, 2) + + assert represent(QFT(0, 3), nqubits=3) == \ + Matrix([[exp(2*pi*I/8)**(i*j % 8)/sqrt(8) for i in range(8)] for j in range(8)]) + + assert QFT(0, 4).decompose() # non-trivial decomposition + assert qapply(QFT(0, 3).decompose()*Qubit(0, 0, 0)).expand() == qapply( + HadamardGate(0)*HadamardGate(1)*HadamardGate(2)*Qubit(0, 0, 0) + ).expand() + + +def test_qft_represent(): + c = QFT(0, 3) + a = represent(c, nqubits=3) + b = represent(c.decompose(), nqubits=3) + assert a.evalf(n=10) == b.evalf(n=10) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qubit.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qubit.py new file mode 100644 index 0000000000000000000000000000000000000000..b4c236008a6b8cf85b5a45c5167b9dc36fb21019 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_qubit.py @@ -0,0 +1,264 @@ +import random + +from sympy.core.numbers import (Integer, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix +from sympy.physics.quantum.qubit import (measure_all, measure_all_oneshot, measure_partial, + matrix_to_qubit, matrix_to_density, + qubit_to_matrix, IntQubit, + IntQubitBra, QubitBra) +from sympy.physics.quantum.gate import (HadamardGate, CNOT, XGate, YGate, + ZGate, PhaseGate) +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.shor import Qubit +from sympy.testing.pytest import raises +from sympy.physics.quantum.density import Density +from sympy.physics.quantum.trace import Tr + +x, y = symbols('x,y') + +epsilon = .000001 + + +def test_Qubit(): + array = [0, 0, 1, 1, 0] + qb = Qubit('00110') + assert qb.flip(0) == Qubit('00111') + assert qb.flip(1) == Qubit('00100') + assert qb.flip(4) == Qubit('10110') + assert qb.qubit_values == (0, 0, 1, 1, 0) + assert qb.dimension == 5 + for i in range(5): + assert qb[i] == array[4 - i] + assert len(qb) == 5 + qb = Qubit('110') + + +def test_QubitBra(): + qb = Qubit(0) + qb_bra = QubitBra(0) + assert qb.dual_class() == QubitBra + assert qb_bra.dual_class() == Qubit + + qb = Qubit(1, 1, 0) + qb_bra = QubitBra(1, 1, 0) + assert represent(qb, nqubits=3).H == represent(qb_bra, nqubits=3) + + qb = Qubit(0, 1) + qb_bra = QubitBra(1,0) + assert qb._eval_innerproduct_QubitBra(qb_bra) == Integer(0) + + qb_bra = QubitBra(0, 1) + assert qb._eval_innerproduct_QubitBra(qb_bra) == Integer(1) + + +def test_IntQubit(): + # issue 9136 + iqb = IntQubit(0, nqubits=1) + assert qubit_to_matrix(Qubit('0')) == qubit_to_matrix(iqb) + + qb = Qubit('1010') + assert qubit_to_matrix(IntQubit(qb)) == qubit_to_matrix(qb) + + iqb = IntQubit(1, nqubits=1) + assert qubit_to_matrix(Qubit('1')) == qubit_to_matrix(iqb) + assert qubit_to_matrix(IntQubit(1)) == qubit_to_matrix(iqb) + + iqb = IntQubit(7, nqubits=4) + assert qubit_to_matrix(Qubit('0111')) == qubit_to_matrix(iqb) + assert qubit_to_matrix(IntQubit(7, 4)) == qubit_to_matrix(iqb) + + iqb = IntQubit(8) + assert iqb.as_int() == 8 + assert iqb.qubit_values == (1, 0, 0, 0) + + iqb = IntQubit(7, 4) + assert iqb.qubit_values == (0, 1, 1, 1) + assert IntQubit(3) == IntQubit(3, 2) + + #test Dual Classes + iqb = IntQubit(3) + iqb_bra = IntQubitBra(3) + assert iqb.dual_class() == IntQubitBra + assert iqb_bra.dual_class() == IntQubit + + iqb = IntQubit(5) + iqb_bra = IntQubitBra(5) + assert iqb._eval_innerproduct_IntQubitBra(iqb_bra) == Integer(1) + + iqb = IntQubit(4) + iqb_bra = IntQubitBra(5) + assert iqb._eval_innerproduct_IntQubitBra(iqb_bra) == Integer(0) + raises(ValueError, lambda: IntQubit(4, 1)) + + raises(ValueError, lambda: IntQubit('5')) + raises(ValueError, lambda: IntQubit(5, '5')) + raises(ValueError, lambda: IntQubit(5, nqubits='5')) + raises(TypeError, lambda: IntQubit(5, bad_arg=True)) + +def test_superposition_of_states(): + state = 1/sqrt(2)*Qubit('01') + 1/sqrt(2)*Qubit('10') + state_gate = CNOT(0, 1)*HadamardGate(0)*state + state_expanded = Qubit('01')/2 + Qubit('00')/2 - Qubit('11')/2 + Qubit('10')/2 + assert qapply(state_gate).expand() == state_expanded + assert matrix_to_qubit(represent(state_gate, nqubits=2)) == state_expanded + + +#test apply methods +def test_apply_represent_equality(): + gates = [HadamardGate(int(3*random.random())), + XGate(int(3*random.random())), ZGate(int(3*random.random())), + YGate(int(3*random.random())), ZGate(int(3*random.random())), + PhaseGate(int(3*random.random()))] + + circuit = Qubit(int(random.random()*2), int(random.random()*2), + int(random.random()*2), int(random.random()*2), int(random.random()*2), + int(random.random()*2)) + for i in range(int(random.random()*6)): + circuit = gates[int(random.random()*6)]*circuit + + mat = represent(circuit, nqubits=6) + states = qapply(circuit) + state_rep = matrix_to_qubit(mat) + states = states.expand() + state_rep = state_rep.expand() + assert state_rep == states + + +def test_matrix_to_qubits(): + qb = Qubit(0, 0, 0, 0) + mat = Matrix([1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]) + assert matrix_to_qubit(mat) == qb + assert qubit_to_matrix(qb) == mat + + state = 2*sqrt(2)*(Qubit(0, 0, 0) + Qubit(0, 0, 1) + Qubit(0, 1, 0) + + Qubit(0, 1, 1) + Qubit(1, 0, 0) + Qubit(1, 0, 1) + + Qubit(1, 1, 0) + Qubit(1, 1, 1)) + ones = sqrt(2)*2*Matrix([1, 1, 1, 1, 1, 1, 1, 1]) + assert matrix_to_qubit(ones) == state.expand() + assert qubit_to_matrix(state) == ones + + +def test_measure_normalize(): + a, b = symbols('a b') + state = a*Qubit('110') + b*Qubit('111') + assert measure_partial(state, (0,), normalize=False) == \ + [(a*Qubit('110'), a*a.conjugate()), (b*Qubit('111'), b*b.conjugate())] + assert measure_all(state, normalize=False) == \ + [(Qubit('110'), a*a.conjugate()), (Qubit('111'), b*b.conjugate())] + + +def test_measure_partial(): + #Basic test of collapse of entangled two qubits (Bell States) + state = Qubit('01') + Qubit('10') + assert measure_partial(state, (0,)) == \ + [(Qubit('10'), S.Half), (Qubit('01'), S.Half)] + assert measure_partial(state, int(0)) == \ + [(Qubit('10'), S.Half), (Qubit('01'), S.Half)] + assert measure_partial(state, (0,)) == \ + measure_partial(state, (1,))[::-1] + + #Test of more complex collapse and probability calculation + state1 = sqrt(2)/sqrt(3)*Qubit('00001') + 1/sqrt(3)*Qubit('11111') + assert measure_partial(state1, (0,)) == \ + [(sqrt(2)/sqrt(3)*Qubit('00001') + 1/sqrt(3)*Qubit('11111'), 1)] + assert measure_partial(state1, (1, 2)) == measure_partial(state1, (3, 4)) + assert measure_partial(state1, (1, 2, 3)) == \ + [(Qubit('00001'), Rational(2, 3)), (Qubit('11111'), Rational(1, 3))] + + #test of measuring multiple bits at once + state2 = Qubit('1111') + Qubit('1101') + Qubit('1011') + Qubit('1000') + assert measure_partial(state2, (0, 1, 3)) == \ + [(Qubit('1000'), Rational(1, 4)), (Qubit('1101'), Rational(1, 4)), + (Qubit('1011')/sqrt(2) + Qubit('1111')/sqrt(2), S.Half)] + assert measure_partial(state2, (0,)) == \ + [(Qubit('1000'), Rational(1, 4)), + (Qubit('1111')/sqrt(3) + Qubit('1101')/sqrt(3) + + Qubit('1011')/sqrt(3), Rational(3, 4))] + + +def test_measure_all(): + assert measure_all(Qubit('11')) == [(Qubit('11'), 1)] + state = Qubit('11') + Qubit('10') + assert measure_all(state) == [(Qubit('10'), S.Half), + (Qubit('11'), S.Half)] + state2 = Qubit('11')/sqrt(5) + 2*Qubit('00')/sqrt(5) + assert measure_all(state2) == \ + [(Qubit('00'), Rational(4, 5)), (Qubit('11'), Rational(1, 5))] + + # from issue #12585 + assert measure_all(qapply(Qubit('0'))) == [(Qubit('0'), 1)] + + +def test_measure_all_oneshot(): + random.seed(42) + # for issue #27092 + assert measure_all_oneshot(Qubit('11')) == Qubit('11') + assert measure_all_oneshot(Qubit('1')) == Qubit('1') + assert measure_all_oneshot(Qubit('0')/sqrt(2) + Qubit('1')/sqrt(2)) == \ + Qubit('0') + + +def test_eval_trace(): + q1 = Qubit('10110') + q2 = Qubit('01010') + d = Density([q1, 0.6], [q2, 0.4]) + + t = Tr(d) + assert t.doit() == 1.0 + + # extreme bits + t = Tr(d, 0) + assert t.doit() == (0.4*Density([Qubit('0101'), 1]) + + 0.6*Density([Qubit('1011'), 1])) + t = Tr(d, 4) + assert t.doit() == (0.4*Density([Qubit('1010'), 1]) + + 0.6*Density([Qubit('0110'), 1])) + # index somewhere in between + t = Tr(d, 2) + assert t.doit() == (0.4*Density([Qubit('0110'), 1]) + + 0.6*Density([Qubit('1010'), 1])) + #trace all indices + t = Tr(d, [0, 1, 2, 3, 4]) + assert t.doit() == 1.0 + + # trace some indices, initialized in + # non-canonical order + t = Tr(d, [2, 1, 3]) + assert t.doit() == (0.4*Density([Qubit('00'), 1]) + + 0.6*Density([Qubit('10'), 1])) + + # mixed states + q = (1/sqrt(2)) * (Qubit('00') + Qubit('11')) + d = Density( [q, 1.0] ) + t = Tr(d, 0) + assert t.doit() == (0.5*Density([Qubit('0'), 1]) + + 0.5*Density([Qubit('1'), 1])) + + +def test_matrix_to_density(): + mat = Matrix([[0, 0], [0, 1]]) + assert matrix_to_density(mat) == Density([Qubit('1'), 1]) + + mat = Matrix([[1, 0], [0, 0]]) + assert matrix_to_density(mat) == Density([Qubit('0'), 1]) + + mat = Matrix([[0, 0], [0, 0]]) + assert matrix_to_density(mat) == 0 + + mat = Matrix([[0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 0]]) + + assert matrix_to_density(mat) == Density([Qubit('10'), 1]) + + mat = Matrix([[1, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + + assert matrix_to_density(mat) == Density([Qubit('00'), 1]) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_represent.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_represent.py new file mode 100644 index 0000000000000000000000000000000000000000..c49dcbd7e7876f30cbe8e5426c91419903add5ff --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_represent.py @@ -0,0 +1,186 @@ +from sympy.core.numbers import (Float, I, Integer) +from sympy.matrices.dense import Matrix +from sympy.external import import_module +from sympy.testing.pytest import skip + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.represent import (represent, rep_innerproduct, + rep_expectation, enumerate_states) +from sympy.physics.quantum.state import Bra, Ket +from sympy.physics.quantum.operator import Operator, OuterProduct +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.tensorproduct import matrix_tensor_product +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.matrixutils import (numpy_ndarray, + scipy_sparse_matrix, to_numpy, + to_scipy_sparse, to_sympy) +from sympy.physics.quantum.cartesian import XKet, XOp, XBra +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.operatorset import operators_to_state +from sympy.testing.pytest import raises + +Amat = Matrix([[1, I], [-I, 1]]) +Bmat = Matrix([[1, 2], [3, 4]]) +Avec = Matrix([[1], [I]]) + + +class AKet(Ket): + + @classmethod + def dual_class(self): + return ABra + + def _represent_default_basis(self, **options): + return self._represent_AOp(None, **options) + + def _represent_AOp(self, basis, **options): + return Avec + + +class ABra(Bra): + + @classmethod + def dual_class(self): + return AKet + + +class AOp(Operator): + + def _represent_default_basis(self, **options): + return self._represent_AOp(None, **options) + + def _represent_AOp(self, basis, **options): + return Amat + + +class BOp(Operator): + + def _represent_default_basis(self, **options): + return self._represent_AOp(None, **options) + + def _represent_AOp(self, basis, **options): + return Bmat + + +k = AKet('a') +b = ABra('a') +A = AOp('A') +B = BOp('B') + +_tests = [ + # Bra + (b, Dagger(Avec)), + (Dagger(b), Avec), + # Ket + (k, Avec), + (Dagger(k), Dagger(Avec)), + # Operator + (A, Amat), + (Dagger(A), Dagger(Amat)), + # OuterProduct + (OuterProduct(k, b), Avec*Avec.H), + # TensorProduct + (TensorProduct(A, B), matrix_tensor_product(Amat, Bmat)), + # Pow + (A**2, Amat**2), + # Add/Mul + (A*B + 2*A, Amat*Bmat + 2*Amat), + # Commutator + (Commutator(A, B), Amat*Bmat - Bmat*Amat), + # AntiCommutator + (AntiCommutator(A, B), Amat*Bmat + Bmat*Amat), + # InnerProduct + (InnerProduct(b, k), (Avec.H*Avec)[0]) +] + + +def test_format_sympy(): + for test in _tests: + lhs = represent(test[0], basis=A, format='sympy') + rhs = to_sympy(test[1]) + assert lhs == rhs + + +def test_scalar_sympy(): + assert represent(Integer(1)) == Integer(1) + assert represent(Float(1.0)) == Float(1.0) + assert represent(1.0 + I) == 1.0 + I + + +np = import_module('numpy') + + +def test_format_numpy(): + if not np: + skip("numpy not installed.") + + for test in _tests: + lhs = represent(test[0], basis=A, format='numpy') + rhs = to_numpy(test[1]) + if isinstance(lhs, numpy_ndarray): + assert (lhs == rhs).all() + else: + assert lhs == rhs + + +def test_scalar_numpy(): + if not np: + skip("numpy not installed.") + + assert represent(Integer(1), format='numpy') == 1 + assert represent(Float(1.0), format='numpy') == 1.0 + assert represent(1.0 + I, format='numpy') == 1.0 + 1.0j + + +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + + +def test_format_scipy_sparse(): + if not np: + skip("numpy not installed.") + if not scipy: + skip("scipy not installed.") + + for test in _tests: + lhs = represent(test[0], basis=A, format='scipy.sparse') + rhs = to_scipy_sparse(test[1]) + if isinstance(lhs, scipy_sparse_matrix): + assert np.linalg.norm((lhs - rhs).todense()) == 0.0 + else: + assert lhs == rhs + + +def test_scalar_scipy_sparse(): + if not np: + skip("numpy not installed.") + if not scipy: + skip("scipy not installed.") + + assert represent(Integer(1), format='scipy.sparse') == 1 + assert represent(Float(1.0), format='scipy.sparse') == 1.0 + assert represent(1.0 + I, format='scipy.sparse') == 1.0 + 1.0j + +x_ket = XKet('x') +x_bra = XBra('x') +x_op = XOp('X') + + +def test_innerprod_represent(): + assert rep_innerproduct(x_ket) == InnerProduct(XBra("x_1"), x_ket).doit() + assert rep_innerproduct(x_bra) == InnerProduct(x_bra, XKet("x_1")).doit() + raises(TypeError, lambda: rep_innerproduct(x_op)) + + +def test_operator_represent(): + basis_kets = enumerate_states(operators_to_state(x_op), 1, 2) + assert rep_expectation( + x_op) == qapply(basis_kets[1].dual*x_op*basis_kets[0]) + + +def test_enumerate_states(): + test = XKet("foo") + assert enumerate_states(test, 1, 1) == [XKet("foo_1")] + assert enumerate_states( + test, [1, 2, 4]) == [XKet("foo_1"), XKet("foo_2"), XKet("foo_4")] diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_sho1d.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_sho1d.py new file mode 100644 index 0000000000000000000000000000000000000000..6acb1f1e7044ac278061cf3b4f04c3c8c09d1848 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_sho1d.py @@ -0,0 +1,176 @@ +"""Tests for sho1d.py""" + +from sympy.concrete import Sum +from sympy.core import oo +from sympy.core.numbers import (I, Integer) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol, symbols +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.complexes import Abs +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.physics.quantum import Dagger +from sympy.physics.quantum.constants import hbar +from sympy.physics.quantum import Commutator +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.cartesian import X, Px +from sympy.physics.quantum.hilbert import ComplexSpace +from sympy.physics.quantum.represent import represent +from sympy.simplify import simplify +from sympy.external import import_module +from sympy.tensor import IndexedBase, Idx +from sympy.testing.pytest import skip, raises + +from sympy.physics.quantum.sho1d import (RaisingOp, LoweringOp, + SHOKet, SHOBra, + Hamiltonian, NumberOp) + +ad = RaisingOp('a') +a = LoweringOp('a') +k = SHOKet('k') +kz = SHOKet(0) +kf = SHOKet(1) +k3 = SHOKet(3) +b = SHOBra('b') +b3 = SHOBra(3) +H = Hamiltonian('H') +N = NumberOp('N') +omega = Symbol('omega') +m = Symbol('m') +ndim = Integer(4) +p = Symbol('p', integer=True) +q = Symbol('q', nonnegative=True, integer=True) + + +np = import_module('numpy') +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + +ad_rep_sympy = represent(ad, basis=N, ndim=4, format='sympy') +a_rep = represent(a, basis=N, ndim=4, format='sympy') +N_rep = represent(N, basis=N, ndim=4, format='sympy') +H_rep = represent(H, basis=N, ndim=4, format='sympy') +k3_rep = represent(k3, basis=N, ndim=4, format='sympy') +b3_rep = represent(b3, basis=N, ndim=4, format='sympy') + +def test_RaisingOp(): + assert Dagger(ad) == a + assert Commutator(ad, a).doit() == Integer(-1) + assert Commutator(ad, N).doit() == Integer(-1)*ad + assert qapply(ad*k) == (sqrt(k.n + 1)*SHOKet(k.n + 1)).expand() + assert qapply(ad*kz) == (sqrt(kz.n + 1)*SHOKet(kz.n + 1)).expand() + assert qapply(ad*kf) == (sqrt(kf.n + 1)*SHOKet(kf.n + 1)).expand() + assert ad.rewrite('xp').doit() == \ + (Integer(1)/sqrt(Integer(2)*hbar*m*omega))*(Integer(-1)*I*Px + m*omega*X) + assert ad.hilbert_space == ComplexSpace(S.Infinity) + for i in range(ndim - 1): + assert ad_rep_sympy[i + 1,i] == sqrt(i + 1) + + if not np: + skip("numpy not installed.") + + ad_rep_numpy = represent(ad, basis=N, ndim=4, format='numpy') + for i in range(ndim - 1): + assert ad_rep_numpy[i + 1,i] == float(sqrt(i + 1)) + + if not np: + skip("numpy not installed.") + if not scipy: + skip("scipy not installed.") + + ad_rep_scipy = represent(ad, basis=N, ndim=4, format='scipy.sparse', spmatrix='lil') + for i in range(ndim - 1): + assert ad_rep_scipy[i + 1,i] == float(sqrt(i + 1)) + + assert ad_rep_numpy.dtype == 'float64' + assert ad_rep_scipy.dtype == 'float64' + +def test_LoweringOp(): + assert Dagger(a) == ad + assert Commutator(a, ad).doit() == Integer(1) + assert Commutator(a, N).doit() == a + assert qapply(a*k) == (sqrt(k.n)*SHOKet(k.n-Integer(1))).expand() + assert qapply(a*kz) == Integer(0) + assert qapply(a*kf) == (sqrt(kf.n)*SHOKet(kf.n-Integer(1))).expand() + assert a.rewrite('xp').doit() == \ + (Integer(1)/sqrt(Integer(2)*hbar*m*omega))*(I*Px + m*omega*X) + for i in range(ndim - 1): + assert a_rep[i,i + 1] == sqrt(i + 1) + +def test_NumberOp(): + assert Commutator(N, ad).doit() == ad + assert Commutator(N, a).doit() == Integer(-1)*a + assert Commutator(N, H).doit() == Integer(0) + assert qapply(N*k) == (k.n*k).expand() + assert N.rewrite('a').doit() == ad*a + assert N.rewrite('xp').doit() == (Integer(1)/(Integer(2)*m*hbar*omega))*( + Px**2 + (m*omega*X)**2) - Integer(1)/Integer(2) + assert N.rewrite('H').doit() == H/(hbar*omega) - Integer(1)/Integer(2) + for i in range(ndim): + assert N_rep[i,i] == i + assert N_rep == ad_rep_sympy*a_rep + +def test_Hamiltonian(): + assert Commutator(H, N).doit() == Integer(0) + assert qapply(H*k) == ((hbar*omega*(k.n + Integer(1)/Integer(2)))*k).expand() + assert H.rewrite('a').doit() == hbar*omega*(ad*a + Integer(1)/Integer(2)) + assert H.rewrite('xp').doit() == \ + (Integer(1)/(Integer(2)*m))*(Px**2 + (m*omega*X)**2) + assert H.rewrite('N').doit() == hbar*omega*(N + Integer(1)/Integer(2)) + for i in range(ndim): + assert H_rep[i,i] == hbar*omega*(i + Integer(1)/Integer(2)) + +def test_SHOKet(): + assert SHOKet('k').dual_class() == SHOBra + assert SHOBra('b').dual_class() == SHOKet + assert InnerProduct(b,k).doit() == KroneckerDelta(k.n, b.n) + assert k.hilbert_space == ComplexSpace(S.Infinity) + assert k3_rep[k3.n, 0] == Integer(1) + assert b3_rep[0, b3.n] == Integer(1) + +def test_sho_sums(): + e1 = Sum(SHOKet(p)*SHOBra(p), (p, 0, 1)) + assert e1.doit() == SHOKet(0)*SHOBra(0) + SHOKet(1)*SHOBra(1) + + # Test qapply with Sum on the left + assert qapply( + Sum(SHOKet(p)*SHOBra(p), (p, 0, oo))*SHOKet(q), + sum_doit=True + ) == SHOKet(q) + + # Test qapply with Sum on the right + a = IndexedBase('a') + n = symbols('n', cls=Idx) + result = qapply(SHOBra(q)*Sum(a[n]*SHOKet(n), (n,0,oo)), sum_doit=True) + assert result == a[q] + + # Test qapply with a product of Sums + result = qapply( + SHOBra(q)*Sum(SHOKet(p)*SHOBra(p), (p, 0, oo))*Sum(a[n]*SHOKet(n), (n,0,oo)), + sum_doit=True + ) + assert result == a[q] + + with raises(ValueError): + result = qapply( + SHOBra(q)*Sum(SHOKet(p)*SHOBra(p), (p, 0, oo))*Sum(a[p]*SHOKet(p), (p,0,oo)), + sum_doit=True + ) + +def test_sho_coherant_state(): + alpha = Symbol('alpha', is_complex=True) + cstate = exp(-Abs(alpha)**2/S(2))*Sum(((alpha**p)/sqrt(factorial(p)))*SHOKet(p), (p,0,oo)) + # Projection onto the number eigenstate + assert qapply(SHOBra(q)*cstate, sum_doit=True) == exp(-Abs(alpha)**2/S(2))*alpha**q/sqrt(factorial(q)) + # Ensure that the coherent state is an eigenstate of annihilation operator + assert simplify(qapply(SHOBra(q)*a*cstate, sum_doit=True)) == simplify(qapply(SHOBra(q)*alpha*cstate, sum_doit=True)) + +def test_issue_26495(): + nbar = Symbol('nbar', real=True, nonnegative=True) + n = Symbol('n', integer=True) + i = Symbol('i', integer=True, nonnegative=True) + j = Symbol('j', integer=True, nonnegative=True) + rho = Sum((nbar/(1+nbar))**n*SHOKet(n)*SHOBra(n), (n,0,oo)) + result = qapply(SHOBra(i)*rho*SHOKet(j), sum_doit=True) + assert simplify(result) == (nbar/(nbar+1))**i*KroneckerDelta(i,j) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_shor.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_shor.py new file mode 100644 index 0000000000000000000000000000000000000000..0ebccbc199be8640f2021933abbe58716c68f788 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_shor.py @@ -0,0 +1,21 @@ +from sympy.testing.pytest import XFAIL + +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qubit import Qubit +from sympy.physics.quantum.shor import CMod, getr + + +@XFAIL +def test_CMod(): + assert qapply(CMod(4, 2, 2)*Qubit(0, 0, 1, 0, 0, 0, 0, 0)) == \ + Qubit(0, 0, 1, 0, 0, 0, 0, 0) + assert qapply(CMod(5, 5, 7)*Qubit(0, 0, 1, 0, 0, 0, 0, 0, 0, 0)) == \ + Qubit(0, 0, 1, 0, 0, 0, 0, 0, 1, 0) + assert qapply(CMod(3, 2, 3)*Qubit(0, 1, 0, 0, 0, 0)) == \ + Qubit(0, 1, 0, 0, 0, 1) + + +def test_continued_frac(): + assert getr(513, 1024, 10) == 2 + assert getr(169, 1024, 11) == 6 + assert getr(314, 4096, 16) == 13 diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_spin.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_spin.py new file mode 100644 index 0000000000000000000000000000000000000000..f905a7de5aed31e24a6d7c882b6a768a787c61cb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_spin.py @@ -0,0 +1,4333 @@ +from sympy.concrete.summations import Sum +from sympy.core.function import expand +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import Matrix +from sympy.abc import alpha, beta, gamma, j, m +from sympy.simplify import simplify + +from sympy.physics.quantum import hbar, represent, Commutator, InnerProduct +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.cg import CG +from sympy.physics.quantum.spin import ( + Jx, Jy, Jz, Jplus, Jminus, J2, + JxBra, JyBra, JzBra, + JxKet, JyKet, JzKet, + JxKetCoupled, JyKetCoupled, JzKetCoupled, + couple, uncouple, + Rotation, WignerD +) + +from sympy.testing.pytest import raises, slow + +j1, j2, j3, j4, m1, m2, m3, m4 = symbols('j1:5 m1:5') +j12, j13, j24, j34, j123, j134, mi, mi1, mp = symbols( + 'j12 j13 j24 j34 j123 j134 mi mi1 mp') + + +def assert_simplify_expand(e1, e2): + """Helper for simplifying and expanding results. + + This is needed to help us test complex expressions whose form + might change in subtle ways as the rest of sympy evolves. + """ + assert simplify(e1.expand(tensorproduct=True)) == \ + simplify(e2.expand(tensorproduct=True)) + + +def test_represent_spin_operators(): + assert represent(Jx) == hbar*Matrix([[0, 1], [1, 0]])/2 + assert represent( + Jx, j=1) == hbar*sqrt(2)*Matrix([[0, 1, 0], [1, 0, 1], [0, 1, 0]])/2 + assert represent(Jy) == hbar*I*Matrix([[0, -1], [1, 0]])/2 + assert represent(Jy, j=1) == hbar*I*sqrt(2)*Matrix([[0, -1, 0], [1, + 0, -1], [0, 1, 0]])/2 + assert represent(Jz) == hbar*Matrix([[1, 0], [0, -1]])/2 + assert represent( + Jz, j=1) == hbar*Matrix([[1, 0, 0], [0, 0, 0], [0, 0, -1]]) + + +def test_represent_spin_states(): + # Jx basis + assert represent(JxKet(S.Half, S.Half), basis=Jx) == Matrix([1, 0]) + assert represent(JxKet(S.Half, Rational(-1, 2)), basis=Jx) == Matrix([0, 1]) + assert represent(JxKet(1, 1), basis=Jx) == Matrix([1, 0, 0]) + assert represent(JxKet(1, 0), basis=Jx) == Matrix([0, 1, 0]) + assert represent(JxKet(1, -1), basis=Jx) == Matrix([0, 0, 1]) + assert represent( + JyKet(S.Half, S.Half), basis=Jx) == Matrix([exp(-I*pi/4), 0]) + assert represent( + JyKet(S.Half, Rational(-1, 2)), basis=Jx) == Matrix([0, exp(I*pi/4)]) + assert represent(JyKet(1, 1), basis=Jx) == Matrix([-I, 0, 0]) + assert represent(JyKet(1, 0), basis=Jx) == Matrix([0, 1, 0]) + assert represent(JyKet(1, -1), basis=Jx) == Matrix([0, 0, I]) + assert represent( + JzKet(S.Half, S.Half), basis=Jx) == sqrt(2)*Matrix([-1, 1])/2 + assert represent( + JzKet(S.Half, Rational(-1, 2)), basis=Jx) == sqrt(2)*Matrix([-1, -1])/2 + assert represent(JzKet(1, 1), basis=Jx) == Matrix([1, -sqrt(2), 1])/2 + assert represent(JzKet(1, 0), basis=Jx) == sqrt(2)*Matrix([1, 0, -1])/2 + assert represent(JzKet(1, -1), basis=Jx) == Matrix([1, sqrt(2), 1])/2 + # Jy basis + assert represent( + JxKet(S.Half, S.Half), basis=Jy) == Matrix([exp(I*pi*Rational(-3, 4)), 0]) + assert represent( + JxKet(S.Half, Rational(-1, 2)), basis=Jy) == Matrix([0, exp(I*pi*Rational(3, 4))]) + assert represent(JxKet(1, 1), basis=Jy) == Matrix([I, 0, 0]) + assert represent(JxKet(1, 0), basis=Jy) == Matrix([0, 1, 0]) + assert represent(JxKet(1, -1), basis=Jy) == Matrix([0, 0, -I]) + assert represent(JyKet(S.Half, S.Half), basis=Jy) == Matrix([1, 0]) + assert represent(JyKet(S.Half, Rational(-1, 2)), basis=Jy) == Matrix([0, 1]) + assert represent(JyKet(1, 1), basis=Jy) == Matrix([1, 0, 0]) + assert represent(JyKet(1, 0), basis=Jy) == Matrix([0, 1, 0]) + assert represent(JyKet(1, -1), basis=Jy) == Matrix([0, 0, 1]) + assert represent( + JzKet(S.Half, S.Half), basis=Jy) == sqrt(2)*Matrix([-1, I])/2 + assert represent( + JzKet(S.Half, Rational(-1, 2)), basis=Jy) == sqrt(2)*Matrix([I, -1])/2 + assert represent(JzKet(1, 1), basis=Jy) == Matrix([1, -I*sqrt(2), -1])/2 + assert represent( + JzKet(1, 0), basis=Jy) == Matrix([-sqrt(2)*I, 0, -sqrt(2)*I])/2 + assert represent(JzKet(1, -1), basis=Jy) == Matrix([-1, -sqrt(2)*I, 1])/2 + # Jz basis + assert represent( + JxKet(S.Half, S.Half), basis=Jz) == sqrt(2)*Matrix([1, 1])/2 + assert represent( + JxKet(S.Half, Rational(-1, 2)), basis=Jz) == sqrt(2)*Matrix([-1, 1])/2 + assert represent(JxKet(1, 1), basis=Jz) == Matrix([1, sqrt(2), 1])/2 + assert represent(JxKet(1, 0), basis=Jz) == sqrt(2)*Matrix([-1, 0, 1])/2 + assert represent(JxKet(1, -1), basis=Jz) == Matrix([1, -sqrt(2), 1])/2 + assert represent( + JyKet(S.Half, S.Half), basis=Jz) == sqrt(2)*Matrix([-1, -I])/2 + assert represent( + JyKet(S.Half, Rational(-1, 2)), basis=Jz) == sqrt(2)*Matrix([-I, -1])/2 + assert represent(JyKet(1, 1), basis=Jz) == Matrix([1, sqrt(2)*I, -1])/2 + assert represent(JyKet(1, 0), basis=Jz) == sqrt(2)*Matrix([I, 0, I])/2 + assert represent(JyKet(1, -1), basis=Jz) == Matrix([-1, sqrt(2)*I, 1])/2 + assert represent(JzKet(S.Half, S.Half), basis=Jz) == Matrix([1, 0]) + assert represent(JzKet(S.Half, Rational(-1, 2)), basis=Jz) == Matrix([0, 1]) + assert represent(JzKet(1, 1), basis=Jz) == Matrix([1, 0, 0]) + assert represent(JzKet(1, 0), basis=Jz) == Matrix([0, 1, 0]) + assert represent(JzKet(1, -1), basis=Jz) == Matrix([0, 0, 1]) + + +def test_represent_uncoupled_states(): + # Jx basis + assert represent(TensorProduct(JxKet(S.Half, S.Half), JxKet(S.Half, S.Half)), basis=Jx) == \ + Matrix([1, 0, 0, 0]) + assert represent(TensorProduct(JxKet(S.Half, S.Half), JxKet(S.Half, Rational(-1, 2))), basis=Jx) == \ + Matrix([0, 1, 0, 0]) + assert represent(TensorProduct(JxKet(S.Half, Rational(-1, 2)), JxKet(S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 0, 1, 0]) + assert represent(TensorProduct(JxKet(S.Half, Rational(-1, 2)), JxKet(S.Half, Rational(-1, 2))), basis=Jx) == \ + Matrix([0, 0, 0, 1]) + assert represent(TensorProduct(JyKet(S.Half, S.Half), JyKet(S.Half, S.Half)), basis=Jx) == \ + Matrix([-I, 0, 0, 0]) + assert represent(TensorProduct(JyKet(S.Half, S.Half), JyKet(S.Half, Rational(-1, 2))), basis=Jx) == \ + Matrix([0, 1, 0, 0]) + assert represent(TensorProduct(JyKet(S.Half, Rational(-1, 2)), JyKet(S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 0, 1, 0]) + assert represent(TensorProduct(JyKet(S.Half, Rational(-1, 2)), JyKet(S.Half, Rational(-1, 2))), basis=Jx) == \ + Matrix([0, 0, 0, I]) + assert represent(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), basis=Jx) == \ + Matrix([S.Half, Rational(-1, 2), Rational(-1, 2), S.Half]) + assert represent(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), basis=Jx) == \ + Matrix([S.Half, S.Half, Rational(-1, 2), Rational(-1, 2)]) + assert represent(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), basis=Jx) == \ + Matrix([S.Half, Rational(-1, 2), S.Half, Rational(-1, 2)]) + assert represent(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), basis=Jx) == \ + Matrix([S.Half, S.Half, S.Half, S.Half]) + # Jy basis + assert represent(TensorProduct(JxKet(S.Half, S.Half), JxKet(S.Half, S.Half)), basis=Jy) == \ + Matrix([I, 0, 0, 0]) + assert represent(TensorProduct(JxKet(S.Half, S.Half), JxKet(S.Half, Rational(-1, 2))), basis=Jy) == \ + Matrix([0, 1, 0, 0]) + assert represent(TensorProduct(JxKet(S.Half, Rational(-1, 2)), JxKet(S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 0, 1, 0]) + assert represent(TensorProduct(JxKet(S.Half, Rational(-1, 2)), JxKet(S.Half, Rational(-1, 2))), basis=Jy) == \ + Matrix([0, 0, 0, -I]) + assert represent(TensorProduct(JyKet(S.Half, S.Half), JyKet(S.Half, S.Half)), basis=Jy) == \ + Matrix([1, 0, 0, 0]) + assert represent(TensorProduct(JyKet(S.Half, S.Half), JyKet(S.Half, Rational(-1, 2))), basis=Jy) == \ + Matrix([0, 1, 0, 0]) + assert represent(TensorProduct(JyKet(S.Half, Rational(-1, 2)), JyKet(S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 0, 1, 0]) + assert represent(TensorProduct(JyKet(S.Half, Rational(-1, 2)), JyKet(S.Half, Rational(-1, 2))), basis=Jy) == \ + Matrix([0, 0, 0, 1]) + assert represent(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), basis=Jy) == \ + Matrix([S.Half, -I/2, -I/2, Rational(-1, 2)]) + assert represent(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), basis=Jy) == \ + Matrix([-I/2, S.Half, Rational(-1, 2), -I/2]) + assert represent(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), basis=Jy) == \ + Matrix([-I/2, Rational(-1, 2), S.Half, -I/2]) + assert represent(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), basis=Jy) == \ + Matrix([Rational(-1, 2), -I/2, -I/2, S.Half]) + # Jz basis + assert represent(TensorProduct(JxKet(S.Half, S.Half), JxKet(S.Half, S.Half)), basis=Jz) == \ + Matrix([S.Half, S.Half, S.Half, S.Half]) + assert represent(TensorProduct(JxKet(S.Half, S.Half), JxKet(S.Half, Rational(-1, 2))), basis=Jz) == \ + Matrix([Rational(-1, 2), S.Half, Rational(-1, 2), S.Half]) + assert represent(TensorProduct(JxKet(S.Half, Rational(-1, 2)), JxKet(S.Half, S.Half)), basis=Jz) == \ + Matrix([Rational(-1, 2), Rational(-1, 2), S.Half, S.Half]) + assert represent(TensorProduct(JxKet(S.Half, Rational(-1, 2)), JxKet(S.Half, Rational(-1, 2))), basis=Jz) == \ + Matrix([S.Half, Rational(-1, 2), Rational(-1, 2), S.Half]) + assert represent(TensorProduct(JyKet(S.Half, S.Half), JyKet(S.Half, S.Half)), basis=Jz) == \ + Matrix([S.Half, I/2, I/2, Rational(-1, 2)]) + assert represent(TensorProduct(JyKet(S.Half, S.Half), JyKet(S.Half, Rational(-1, 2))), basis=Jz) == \ + Matrix([I/2, S.Half, Rational(-1, 2), I/2]) + assert represent(TensorProduct(JyKet(S.Half, Rational(-1, 2)), JyKet(S.Half, S.Half)), basis=Jz) == \ + Matrix([I/2, Rational(-1, 2), S.Half, I/2]) + assert represent(TensorProduct(JyKet(S.Half, Rational(-1, 2)), JyKet(S.Half, Rational(-1, 2))), basis=Jz) == \ + Matrix([Rational(-1, 2), I/2, I/2, S.Half]) + assert represent(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), basis=Jz) == \ + Matrix([1, 0, 0, 0]) + assert represent(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), basis=Jz) == \ + Matrix([0, 1, 0, 0]) + assert represent(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), basis=Jz) == \ + Matrix([0, 0, 1, 0]) + assert represent(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), basis=Jz) == \ + Matrix([0, 0, 0, 1]) + + +def test_represent_coupled_states(): + # Jx basis + assert represent(JxKetCoupled(0, 0, (S.Half, S.Half)), basis=Jx) == \ + Matrix([1, 0, 0, 0]) + assert represent(JxKetCoupled(1, 1, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 1, 0, 0]) + assert represent(JxKetCoupled(1, 0, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 0, 1, 0]) + assert represent(JxKetCoupled(1, -1, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 0, 0, 1]) + assert represent(JyKetCoupled(0, 0, (S.Half, S.Half)), basis=Jx) == \ + Matrix([1, 0, 0, 0]) + assert represent(JyKetCoupled(1, 1, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, -I, 0, 0]) + assert represent(JyKetCoupled(1, 0, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 0, 1, 0]) + assert represent(JyKetCoupled(1, -1, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, 0, 0, I]) + assert represent(JzKetCoupled(0, 0, (S.Half, S.Half)), basis=Jx) == \ + Matrix([1, 0, 0, 0]) + assert represent(JzKetCoupled(1, 1, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, S.Half, -sqrt(2)/2, S.Half]) + assert represent(JzKetCoupled(1, 0, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, sqrt(2)/2, 0, -sqrt(2)/2]) + assert represent(JzKetCoupled(1, -1, (S.Half, S.Half)), basis=Jx) == \ + Matrix([0, S.Half, sqrt(2)/2, S.Half]) + # Jy basis + assert represent(JxKetCoupled(0, 0, (S.Half, S.Half)), basis=Jy) == \ + Matrix([1, 0, 0, 0]) + assert represent(JxKetCoupled(1, 1, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, I, 0, 0]) + assert represent(JxKetCoupled(1, 0, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 0, 1, 0]) + assert represent(JxKetCoupled(1, -1, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 0, 0, -I]) + assert represent(JyKetCoupled(0, 0, (S.Half, S.Half)), basis=Jy) == \ + Matrix([1, 0, 0, 0]) + assert represent(JyKetCoupled(1, 1, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 1, 0, 0]) + assert represent(JyKetCoupled(1, 0, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 0, 1, 0]) + assert represent(JyKetCoupled(1, -1, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, 0, 0, 1]) + assert represent(JzKetCoupled(0, 0, (S.Half, S.Half)), basis=Jy) == \ + Matrix([1, 0, 0, 0]) + assert represent(JzKetCoupled(1, 1, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, S.Half, -I*sqrt(2)/2, Rational(-1, 2)]) + assert represent(JzKetCoupled(1, 0, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, -I*sqrt(2)/2, 0, -I*sqrt(2)/2]) + assert represent(JzKetCoupled(1, -1, (S.Half, S.Half)), basis=Jy) == \ + Matrix([0, Rational(-1, 2), -I*sqrt(2)/2, S.Half]) + # Jz basis + assert represent(JxKetCoupled(0, 0, (S.Half, S.Half)), basis=Jz) == \ + Matrix([1, 0, 0, 0]) + assert represent(JxKetCoupled(1, 1, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, S.Half, sqrt(2)/2, S.Half]) + assert represent(JxKetCoupled(1, 0, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, -sqrt(2)/2, 0, sqrt(2)/2]) + assert represent(JxKetCoupled(1, -1, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, S.Half, -sqrt(2)/2, S.Half]) + assert represent(JyKetCoupled(0, 0, (S.Half, S.Half)), basis=Jz) == \ + Matrix([1, 0, 0, 0]) + assert represent(JyKetCoupled(1, 1, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, S.Half, I*sqrt(2)/2, Rational(-1, 2)]) + assert represent(JyKetCoupled(1, 0, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, I*sqrt(2)/2, 0, I*sqrt(2)/2]) + assert represent(JyKetCoupled(1, -1, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, Rational(-1, 2), I*sqrt(2)/2, S.Half]) + assert represent(JzKetCoupled(0, 0, (S.Half, S.Half)), basis=Jz) == \ + Matrix([1, 0, 0, 0]) + assert represent(JzKetCoupled(1, 1, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, 1, 0, 0]) + assert represent(JzKetCoupled(1, 0, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, 0, 1, 0]) + assert represent(JzKetCoupled(1, -1, (S.Half, S.Half)), basis=Jz) == \ + Matrix([0, 0, 0, 1]) + + +def test_represent_rotation(): + assert represent(Rotation(0, pi/2, 0)) == \ + Matrix( + [[WignerD( + S( + 1)/2, S( + 1)/2, S( + 1)/2, 0, pi/2, 0), WignerD( + S.Half, S.Half, Rational(-1, 2), 0, pi/2, 0)], + [WignerD(S.Half, Rational(-1, 2), S.Half, 0, pi/2, 0), WignerD(S.Half, Rational(-1, 2), Rational(-1, 2), 0, pi/2, 0)]]) + assert represent(Rotation(0, pi/2, 0), doit=True) == \ + Matrix([[sqrt(2)/2, -sqrt(2)/2], + [sqrt(2)/2, sqrt(2)/2]]) + + +def test_rewrite_same(): + # Rewrite to same basis + assert JxBra(1, 1).rewrite('Jx') == JxBra(1, 1) + assert JxBra(j, m).rewrite('Jx') == JxBra(j, m) + assert JxKet(1, 1).rewrite('Jx') == JxKet(1, 1) + assert JxKet(j, m).rewrite('Jx') == JxKet(j, m) + + +def test_rewrite_Bra(): + # Numerical + assert JxBra(1, 1).rewrite('Jy') == -I*JyBra(1, 1) + assert JxBra(1, 0).rewrite('Jy') == JyBra(1, 0) + assert JxBra(1, -1).rewrite('Jy') == I*JyBra(1, -1) + assert JxBra(1, 1).rewrite( + 'Jz') == JzBra(1, 1)/2 + JzBra(1, 0)/sqrt(2) + JzBra(1, -1)/2 + assert JxBra( + 1, 0).rewrite('Jz') == -sqrt(2)*JzBra(1, 1)/2 + sqrt(2)*JzBra(1, -1)/2 + assert JxBra(1, -1).rewrite( + 'Jz') == JzBra(1, 1)/2 - JzBra(1, 0)/sqrt(2) + JzBra(1, -1)/2 + assert JyBra(1, 1).rewrite('Jx') == I*JxBra(1, 1) + assert JyBra(1, 0).rewrite('Jx') == JxBra(1, 0) + assert JyBra(1, -1).rewrite('Jx') == -I*JxBra(1, -1) + assert JyBra(1, 1).rewrite( + 'Jz') == JzBra(1, 1)/2 - sqrt(2)*I*JzBra(1, 0)/2 - JzBra(1, -1)/2 + assert JyBra(1, 0).rewrite( + 'Jz') == -sqrt(2)*I*JzBra(1, 1)/2 - sqrt(2)*I*JzBra(1, -1)/2 + assert JyBra(1, -1).rewrite( + 'Jz') == -JzBra(1, 1)/2 - sqrt(2)*I*JzBra(1, 0)/2 + JzBra(1, -1)/2 + assert JzBra(1, 1).rewrite( + 'Jx') == JxBra(1, 1)/2 - sqrt(2)*JxBra(1, 0)/2 + JxBra(1, -1)/2 + assert JzBra( + 1, 0).rewrite('Jx') == sqrt(2)*JxBra(1, 1)/2 - sqrt(2)*JxBra(1, -1)/2 + assert JzBra(1, -1).rewrite( + 'Jx') == JxBra(1, 1)/2 + sqrt(2)*JxBra(1, 0)/2 + JxBra(1, -1)/2 + assert JzBra(1, 1).rewrite( + 'Jy') == JyBra(1, 1)/2 + sqrt(2)*I*JyBra(1, 0)/2 - JyBra(1, -1)/2 + assert JzBra(1, 0).rewrite( + 'Jy') == sqrt(2)*I*JyBra(1, 1)/2 + sqrt(2)*I*JyBra(1, -1)/2 + assert JzBra(1, -1).rewrite( + 'Jy') == -JyBra(1, 1)/2 + sqrt(2)*I*JyBra(1, 0)/2 + JyBra(1, -1)/2 + # Symbolic + assert JxBra(j, m).rewrite('Jy') == Sum( + WignerD(j, mi, m, pi*Rational(3, 2), 0, 0) * JyBra(j, mi), (mi, -j, j)) + assert JxBra(j, m).rewrite('Jz') == Sum( + WignerD(j, mi, m, 0, pi/2, 0) * JzBra(j, mi), (mi, -j, j)) + assert JyBra(j, m).rewrite('Jx') == Sum( + WignerD(j, mi, m, 0, 0, pi/2) * JxBra(j, mi), (mi, -j, j)) + assert JyBra(j, m).rewrite('Jz') == Sum( + WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * JzBra(j, mi), (mi, -j, j)) + assert JzBra(j, m).rewrite('Jx') == Sum( + WignerD(j, mi, m, 0, pi*Rational(3, 2), 0) * JxBra(j, mi), (mi, -j, j)) + assert JzBra(j, m).rewrite('Jy') == Sum( + WignerD(j, mi, m, pi*Rational(3, 2), pi/2, pi/2) * JyBra(j, mi), (mi, -j, j)) + + +def test_rewrite_Ket(): + # Numerical + assert JxKet(1, 1).rewrite('Jy') == I*JyKet(1, 1) + assert JxKet(1, 0).rewrite('Jy') == JyKet(1, 0) + assert JxKet(1, -1).rewrite('Jy') == -I*JyKet(1, -1) + assert JxKet(1, 1).rewrite( + 'Jz') == JzKet(1, 1)/2 + JzKet(1, 0)/sqrt(2) + JzKet(1, -1)/2 + assert JxKet( + 1, 0).rewrite('Jz') == -sqrt(2)*JzKet(1, 1)/2 + sqrt(2)*JzKet(1, -1)/2 + assert JxKet(1, -1).rewrite( + 'Jz') == JzKet(1, 1)/2 - JzKet(1, 0)/sqrt(2) + JzKet(1, -1)/2 + assert JyKet(1, 1).rewrite('Jx') == -I*JxKet(1, 1) + assert JyKet(1, 0).rewrite('Jx') == JxKet(1, 0) + assert JyKet(1, -1).rewrite('Jx') == I*JxKet(1, -1) + assert JyKet(1, 1).rewrite( + 'Jz') == JzKet(1, 1)/2 + sqrt(2)*I*JzKet(1, 0)/2 - JzKet(1, -1)/2 + assert JyKet(1, 0).rewrite( + 'Jz') == sqrt(2)*I*JzKet(1, 1)/2 + sqrt(2)*I*JzKet(1, -1)/2 + assert JyKet(1, -1).rewrite( + 'Jz') == -JzKet(1, 1)/2 + sqrt(2)*I*JzKet(1, 0)/2 + JzKet(1, -1)/2 + assert JzKet(1, 1).rewrite( + 'Jx') == JxKet(1, 1)/2 - sqrt(2)*JxKet(1, 0)/2 + JxKet(1, -1)/2 + assert JzKet( + 1, 0).rewrite('Jx') == sqrt(2)*JxKet(1, 1)/2 - sqrt(2)*JxKet(1, -1)/2 + assert JzKet(1, -1).rewrite( + 'Jx') == JxKet(1, 1)/2 + sqrt(2)*JxKet(1, 0)/2 + JxKet(1, -1)/2 + assert JzKet(1, 1).rewrite( + 'Jy') == JyKet(1, 1)/2 - sqrt(2)*I*JyKet(1, 0)/2 - JyKet(1, -1)/2 + assert JzKet(1, 0).rewrite( + 'Jy') == -sqrt(2)*I*JyKet(1, 1)/2 - sqrt(2)*I*JyKet(1, -1)/2 + assert JzKet(1, -1).rewrite( + 'Jy') == -JyKet(1, 1)/2 - sqrt(2)*I*JyKet(1, 0)/2 + JyKet(1, -1)/2 + # Symbolic + assert JxKet(j, m).rewrite('Jy') == Sum( + WignerD(j, mi, m, pi*Rational(3, 2), 0, 0) * JyKet(j, mi), (mi, -j, j)) + assert JxKet(j, m).rewrite('Jz') == Sum( + WignerD(j, mi, m, 0, pi/2, 0) * JzKet(j, mi), (mi, -j, j)) + assert JyKet(j, m).rewrite('Jx') == Sum( + WignerD(j, mi, m, 0, 0, pi/2) * JxKet(j, mi), (mi, -j, j)) + assert JyKet(j, m).rewrite('Jz') == Sum( + WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * JzKet(j, mi), (mi, -j, j)) + assert JzKet(j, m).rewrite('Jx') == Sum( + WignerD(j, mi, m, 0, pi*Rational(3, 2), 0) * JxKet(j, mi), (mi, -j, j)) + assert JzKet(j, m).rewrite('Jy') == Sum( + WignerD(j, mi, m, pi*Rational(3, 2), pi/2, pi/2) * JyKet(j, mi), (mi, -j, j)) + + +def test_rewrite_uncoupled_state(): + # Numerical + assert TensorProduct(JyKet(1, 1), JxKet( + 1, 1)).rewrite('Jx') == -I*TensorProduct(JxKet(1, 1), JxKet(1, 1)) + assert TensorProduct(JyKet(1, 0), JxKet( + 1, 1)).rewrite('Jx') == TensorProduct(JxKet(1, 0), JxKet(1, 1)) + assert TensorProduct(JyKet(1, -1), JxKet( + 1, 1)).rewrite('Jx') == I*TensorProduct(JxKet(1, -1), JxKet(1, 1)) + assert TensorProduct(JzKet(1, 1), JxKet(1, 1)).rewrite('Jx') == \ + TensorProduct(JxKet(1, -1), JxKet(1, 1))/2 - sqrt(2)*TensorProduct(JxKet( + 1, 0), JxKet(1, 1))/2 + TensorProduct(JxKet(1, 1), JxKet(1, 1))/2 + assert TensorProduct(JzKet(1, 0), JxKet(1, 1)).rewrite('Jx') == \ + -sqrt(2)*TensorProduct(JxKet(1, -1), JxKet(1, 1))/2 + sqrt( + 2)*TensorProduct(JxKet(1, 1), JxKet(1, 1))/2 + assert TensorProduct(JzKet(1, -1), JxKet(1, 1)).rewrite('Jx') == \ + TensorProduct(JxKet(1, -1), JxKet(1, 1))/2 + sqrt(2)*TensorProduct(JxKet(1, 0), JxKet(1, 1))/2 + TensorProduct(JxKet(1, 1), JxKet(1, 1))/2 + assert TensorProduct(JxKet(1, 1), JyKet( + 1, 1)).rewrite('Jy') == I*TensorProduct(JyKet(1, 1), JyKet(1, 1)) + assert TensorProduct(JxKet(1, 0), JyKet( + 1, 1)).rewrite('Jy') == TensorProduct(JyKet(1, 0), JyKet(1, 1)) + assert TensorProduct(JxKet(1, -1), JyKet( + 1, 1)).rewrite('Jy') == -I*TensorProduct(JyKet(1, -1), JyKet(1, 1)) + assert TensorProduct(JzKet(1, 1), JyKet(1, 1)).rewrite('Jy') == \ + -TensorProduct(JyKet(1, -1), JyKet(1, 1))/2 - sqrt(2)*I*TensorProduct(JyKet(1, 0), JyKet(1, 1))/2 + TensorProduct(JyKet(1, 1), JyKet(1, 1))/2 + assert TensorProduct(JzKet(1, 0), JyKet(1, 1)).rewrite('Jy') == \ + -sqrt(2)*I*TensorProduct(JyKet(1, -1), JyKet( + 1, 1))/2 - sqrt(2)*I*TensorProduct(JyKet(1, 1), JyKet(1, 1))/2 + assert TensorProduct(JzKet(1, -1), JyKet(1, 1)).rewrite('Jy') == \ + TensorProduct(JyKet(1, -1), JyKet(1, 1))/2 - sqrt(2)*I*TensorProduct(JyKet(1, 0), JyKet(1, 1))/2 - TensorProduct(JyKet(1, 1), JyKet(1, 1))/2 + assert TensorProduct(JxKet(1, 1), JzKet(1, 1)).rewrite('Jz') == \ + TensorProduct(JzKet(1, -1), JzKet(1, 1))/2 + sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + TensorProduct(JzKet(1, 1), JzKet(1, 1))/2 + assert TensorProduct(JxKet(1, 0), JzKet(1, 1)).rewrite('Jz') == \ + sqrt(2)*TensorProduct(JzKet(1, -1), JzKet( + 1, 1))/2 - sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, 1))/2 + assert TensorProduct(JxKet(1, -1), JzKet(1, 1)).rewrite('Jz') == \ + TensorProduct(JzKet(1, -1), JzKet(1, 1))/2 - sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + TensorProduct(JzKet(1, 1), JzKet(1, 1))/2 + assert TensorProduct(JyKet(1, 1), JzKet(1, 1)).rewrite('Jz') == \ + -TensorProduct(JzKet(1, -1), JzKet(1, 1))/2 + sqrt(2)*I*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + TensorProduct(JzKet(1, 1), JzKet(1, 1))/2 + assert TensorProduct(JyKet(1, 0), JzKet(1, 1)).rewrite('Jz') == \ + sqrt(2)*I*TensorProduct(JzKet(1, -1), JzKet( + 1, 1))/2 + sqrt(2)*I*TensorProduct(JzKet(1, 1), JzKet(1, 1))/2 + assert TensorProduct(JyKet(1, -1), JzKet(1, 1)).rewrite('Jz') == \ + TensorProduct(JzKet(1, -1), JzKet(1, 1))/2 + sqrt(2)*I*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 - TensorProduct(JzKet(1, 1), JzKet(1, 1))/2 + # Symbolic + assert TensorProduct(JyKet(j1, m1), JxKet(j2, m2)).rewrite('Jy') == \ + TensorProduct(JyKet(j1, m1), Sum( + WignerD(j2, mi, m2, pi*Rational(3, 2), 0, 0) * JyKet(j2, mi), (mi, -j2, j2))) + assert TensorProduct(JzKet(j1, m1), JxKet(j2, m2)).rewrite('Jz') == \ + TensorProduct(JzKet(j1, m1), Sum( + WignerD(j2, mi, m2, 0, pi/2, 0) * JzKet(j2, mi), (mi, -j2, j2))) + assert TensorProduct(JxKet(j1, m1), JyKet(j2, m2)).rewrite('Jx') == \ + TensorProduct(JxKet(j1, m1), Sum( + WignerD(j2, mi, m2, 0, 0, pi/2) * JxKet(j2, mi), (mi, -j2, j2))) + assert TensorProduct(JzKet(j1, m1), JyKet(j2, m2)).rewrite('Jz') == \ + TensorProduct(JzKet(j1, m1), Sum(WignerD( + j2, mi, m2, pi*Rational(3, 2), -pi/2, pi/2) * JzKet(j2, mi), (mi, -j2, j2))) + assert TensorProduct(JxKet(j1, m1), JzKet(j2, m2)).rewrite('Jx') == \ + TensorProduct(JxKet(j1, m1), Sum( + WignerD(j2, mi, m2, 0, pi*Rational(3, 2), 0) * JxKet(j2, mi), (mi, -j2, j2))) + assert TensorProduct(JyKet(j1, m1), JzKet(j2, m2)).rewrite('Jy') == \ + TensorProduct(JyKet(j1, m1), Sum(WignerD( + j2, mi, m2, pi*Rational(3, 2), pi/2, pi/2) * JyKet(j2, mi), (mi, -j2, j2))) + + +def test_rewrite_coupled_state(): + # Numerical + assert JyKetCoupled(0, 0, (S.Half, S.Half)).rewrite('Jx') == \ + JxKetCoupled(0, 0, (S.Half, S.Half)) + assert JyKetCoupled(1, 1, (S.Half, S.Half)).rewrite('Jx') == \ + -I*JxKetCoupled(1, 1, (S.Half, S.Half)) + assert JyKetCoupled(1, 0, (S.Half, S.Half)).rewrite('Jx') == \ + JxKetCoupled(1, 0, (S.Half, S.Half)) + assert JyKetCoupled(1, -1, (S.Half, S.Half)).rewrite('Jx') == \ + I*JxKetCoupled(1, -1, (S.Half, S.Half)) + assert JzKetCoupled(0, 0, (S.Half, S.Half)).rewrite('Jx') == \ + JxKetCoupled(0, 0, (S.Half, S.Half)) + assert JzKetCoupled(1, 1, (S.Half, S.Half)).rewrite('Jx') == \ + JxKetCoupled(1, 1, (S.Half, S.Half))/2 - sqrt(2)*JxKetCoupled(1, 0, ( + S.Half, S.Half))/2 + JxKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JzKetCoupled(1, 0, (S.Half, S.Half)).rewrite('Jx') == \ + sqrt(2)*JxKetCoupled(1, 1, (S( + 1)/2, S.Half))/2 - sqrt(2)*JxKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JzKetCoupled(1, -1, (S.Half, S.Half)).rewrite('Jx') == \ + JxKetCoupled(1, 1, (S.Half, S.Half))/2 + sqrt(2)*JxKetCoupled(1, 0, ( + S.Half, S.Half))/2 + JxKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JxKetCoupled(0, 0, (S.Half, S.Half)).rewrite('Jy') == \ + JyKetCoupled(0, 0, (S.Half, S.Half)) + assert JxKetCoupled(1, 1, (S.Half, S.Half)).rewrite('Jy') == \ + I*JyKetCoupled(1, 1, (S.Half, S.Half)) + assert JxKetCoupled(1, 0, (S.Half, S.Half)).rewrite('Jy') == \ + JyKetCoupled(1, 0, (S.Half, S.Half)) + assert JxKetCoupled(1, -1, (S.Half, S.Half)).rewrite('Jy') == \ + -I*JyKetCoupled(1, -1, (S.Half, S.Half)) + assert JzKetCoupled(0, 0, (S.Half, S.Half)).rewrite('Jy') == \ + JyKetCoupled(0, 0, (S.Half, S.Half)) + assert JzKetCoupled(1, 1, (S.Half, S.Half)).rewrite('Jy') == \ + JyKetCoupled(1, 1, (S.Half, S.Half))/2 - I*sqrt(2)*JyKetCoupled(1, 0, ( + S.Half, S.Half))/2 - JyKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JzKetCoupled(1, 0, (S.Half, S.Half)).rewrite('Jy') == \ + -I*sqrt(2)*JyKetCoupled(1, 1, (S.Half, S.Half))/2 - I*sqrt( + 2)*JyKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JzKetCoupled(1, -1, (S.Half, S.Half)).rewrite('Jy') == \ + -JyKetCoupled(1, 1, (S.Half, S.Half))/2 - I*sqrt(2)*JyKetCoupled(1, 0, (S.Half, S.Half))/2 + JyKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JxKetCoupled(0, 0, (S.Half, S.Half)).rewrite('Jz') == \ + JzKetCoupled(0, 0, (S.Half, S.Half)) + assert JxKetCoupled(1, 1, (S.Half, S.Half)).rewrite('Jz') == \ + JzKetCoupled(1, 1, (S.Half, S.Half))/2 + sqrt(2)*JzKetCoupled(1, 0, ( + S.Half, S.Half))/2 + JzKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JxKetCoupled(1, 0, (S.Half, S.Half)).rewrite('Jz') == \ + -sqrt(2)*JzKetCoupled(1, 1, (S( + 1)/2, S.Half))/2 + sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JxKetCoupled(1, -1, (S.Half, S.Half)).rewrite('Jz') == \ + JzKetCoupled(1, 1, (S.Half, S.Half))/2 - sqrt(2)*JzKetCoupled(1, 0, ( + S.Half, S.Half))/2 + JzKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JyKetCoupled(0, 0, (S.Half, S.Half)).rewrite('Jz') == \ + JzKetCoupled(0, 0, (S.Half, S.Half)) + assert JyKetCoupled(1, 1, (S.Half, S.Half)).rewrite('Jz') == \ + JzKetCoupled(1, 1, (S.Half, S.Half))/2 + I*sqrt(2)*JzKetCoupled(1, 0, ( + S.Half, S.Half))/2 - JzKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JyKetCoupled(1, 0, (S.Half, S.Half)).rewrite('Jz') == \ + I*sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half))/2 + I*sqrt( + 2)*JzKetCoupled(1, -1, (S.Half, S.Half))/2 + assert JyKetCoupled(1, -1, (S.Half, S.Half)).rewrite('Jz') == \ + -JzKetCoupled(1, 1, (S.Half, S.Half))/2 + I*sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half))/2 + JzKetCoupled(1, -1, (S.Half, S.Half))/2 + # Symbolic + assert JyKetCoupled(j, m, (j1, j2)).rewrite('Jx') == \ + Sum(WignerD(j, mi, m, 0, 0, pi/2) * JxKetCoupled(j, mi, ( + j1, j2)), (mi, -j, j)) + assert JzKetCoupled(j, m, (j1, j2)).rewrite('Jx') == \ + Sum(WignerD(j, mi, m, 0, pi*Rational(3, 2), 0) * JxKetCoupled(j, mi, ( + j1, j2)), (mi, -j, j)) + assert JxKetCoupled(j, m, (j1, j2)).rewrite('Jy') == \ + Sum(WignerD(j, mi, m, pi*Rational(3, 2), 0, 0) * JyKetCoupled(j, mi, ( + j1, j2)), (mi, -j, j)) + assert JzKetCoupled(j, m, (j1, j2)).rewrite('Jy') == \ + Sum(WignerD(j, mi, m, pi*Rational(3, 2), pi/2, pi/2) * JyKetCoupled(j, + mi, (j1, j2)), (mi, -j, j)) + assert JxKetCoupled(j, m, (j1, j2)).rewrite('Jz') == \ + Sum(WignerD(j, mi, m, 0, pi/2, 0) * JzKetCoupled(j, mi, ( + j1, j2)), (mi, -j, j)) + assert JyKetCoupled(j, m, (j1, j2)).rewrite('Jz') == \ + Sum(WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * JzKetCoupled( + j, mi, (j1, j2)), (mi, -j, j)) + + +def test_innerproducts_of_rewritten_states(): + # Numerical + assert qapply(JxBra(1, 1)*JxKet(1, 1).rewrite('Jy')).doit() == 1 + assert qapply(JxBra(1, 0)*JxKet(1, 0).rewrite('Jy')).doit() == 1 + assert qapply(JxBra(1, -1)*JxKet(1, -1).rewrite('Jy')).doit() == 1 + assert qapply(JxBra(1, 1)*JxKet(1, 1).rewrite('Jz')).doit() == 1 + assert qapply(JxBra(1, 0)*JxKet(1, 0).rewrite('Jz')).doit() == 1 + assert qapply(JxBra(1, -1)*JxKet(1, -1).rewrite('Jz')).doit() == 1 + assert qapply(JyBra(1, 1)*JyKet(1, 1).rewrite('Jx')).doit() == 1 + assert qapply(JyBra(1, 0)*JyKet(1, 0).rewrite('Jx')).doit() == 1 + assert qapply(JyBra(1, -1)*JyKet(1, -1).rewrite('Jx')).doit() == 1 + assert qapply(JyBra(1, 1)*JyKet(1, 1).rewrite('Jz')).doit() == 1 + assert qapply(JyBra(1, 0)*JyKet(1, 0).rewrite('Jz')).doit() == 1 + assert qapply(JyBra(1, -1)*JyKet(1, -1).rewrite('Jz')).doit() == 1 + assert qapply(JyBra(1, 1)*JyKet(1, 1).rewrite('Jz')).doit() == 1 + assert qapply(JyBra(1, 0)*JyKet(1, 0).rewrite('Jz')).doit() == 1 + assert qapply(JyBra(1, -1)*JyKet(1, -1).rewrite('Jz')).doit() == 1 + assert qapply(JzBra(1, 1)*JzKet(1, 1).rewrite('Jy')).doit() == 1 + assert qapply(JzBra(1, 0)*JzKet(1, 0).rewrite('Jy')).doit() == 1 + assert qapply(JzBra(1, -1)*JzKet(1, -1).rewrite('Jy')).doit() == 1 + assert qapply(JxBra(1, 1)*JxKet(1, 0).rewrite('Jy')).doit() == 0 + assert qapply(JxBra(1, 1)*JxKet(1, -1).rewrite('Jy')) == 0 + assert qapply(JxBra(1, 1)*JxKet(1, 0).rewrite('Jz')).doit() == 0 + assert qapply(JxBra(1, 1)*JxKet(1, -1).rewrite('Jz')) == 0 + assert qapply(JyBra(1, 1)*JyKet(1, 0).rewrite('Jx')).doit() == 0 + assert qapply(JyBra(1, 1)*JyKet(1, -1).rewrite('Jx')) == 0 + assert qapply(JyBra(1, 1)*JyKet(1, 0).rewrite('Jz')).doit() == 0 + assert qapply(JyBra(1, 1)*JyKet(1, -1).rewrite('Jz')) == 0 + assert qapply(JzBra(1, 1)*JzKet(1, 0).rewrite('Jx')).doit() == 0 + assert qapply(JzBra(1, 1)*JzKet(1, -1).rewrite('Jx')) == 0 + assert qapply(JzBra(1, 1)*JzKet(1, 0).rewrite('Jy')).doit() == 0 + assert qapply(JzBra(1, 1)*JzKet(1, -1).rewrite('Jy')) == 0 + assert qapply(JxBra(1, 0)*JxKet(1, 1).rewrite('Jy')) == 0 + assert qapply(JxBra(1, 0)*JxKet(1, -1).rewrite('Jy')) == 0 + assert qapply(JxBra(1, 0)*JxKet(1, 1).rewrite('Jz')) == 0 + assert qapply(JxBra(1, 0)*JxKet(1, -1).rewrite('Jz')) == 0 + assert qapply(JyBra(1, 0)*JyKet(1, 1).rewrite('Jx')) == 0 + assert qapply(JyBra(1, 0)*JyKet(1, -1).rewrite('Jx')) == 0 + assert qapply(JyBra(1, 0)*JyKet(1, 1).rewrite('Jz')) == 0 + assert qapply(JyBra(1, 0)*JyKet(1, -1).rewrite('Jz')) == 0 + assert qapply(JzBra(1, 0)*JzKet(1, 1).rewrite('Jx')) == 0 + assert qapply(JzBra(1, 0)*JzKet(1, -1).rewrite('Jx')) == 0 + assert qapply(JzBra(1, 0)*JzKet(1, 1).rewrite('Jy')) == 0 + assert qapply(JzBra(1, 0)*JzKet(1, -1).rewrite('Jy')) == 0 + assert qapply(JxBra(1, -1)*JxKet(1, 1).rewrite('Jy')) == 0 + assert qapply(JxBra(1, -1)*JxKet(1, 0).rewrite('Jy')).doit() == 0 + assert qapply(JxBra(1, -1)*JxKet(1, 1).rewrite('Jz')) == 0 + assert qapply(JxBra(1, -1)*JxKet(1, 0).rewrite('Jz')).doit() == 0 + assert qapply(JyBra(1, -1)*JyKet(1, 1).rewrite('Jx')) == 0 + assert qapply(JyBra(1, -1)*JyKet(1, 0).rewrite('Jx')).doit() == 0 + assert qapply(JyBra(1, -1)*JyKet(1, 1).rewrite('Jz')) == 0 + assert qapply(JyBra(1, -1)*JyKet(1, 0).rewrite('Jz')).doit() == 0 + assert qapply(JzBra(1, -1)*JzKet(1, 1).rewrite('Jx')) == 0 + assert qapply(JzBra(1, -1)*JzKet(1, 0).rewrite('Jx')).doit() == 0 + assert qapply(JzBra(1, -1)*JzKet(1, 1).rewrite('Jy')) == 0 + assert qapply(JzBra(1, -1)*JzKet(1, 0).rewrite('Jy')).doit() == 0 + + +def test_uncouple_2_coupled_states(): + # j1=1/2, j2=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + # j1=1/2, j2=1 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1)) == \ + expand(uncouple( + couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0)) == \ + expand(uncouple( + couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1)) == \ + expand(uncouple( + couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)) == \ + expand(uncouple( + couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)) == \ + expand(uncouple( + couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)) == \ + expand(uncouple( + couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)) ))) + # j1=1, j2=1 + assert TensorProduct(JzKet(1, 1), JzKet(1, 1)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, 1), JzKet(1, 1)) ))) + assert TensorProduct(JzKet(1, 1), JzKet(1, 0)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, 1), JzKet(1, 0)) ))) + assert TensorProduct(JzKet(1, 1), JzKet(1, -1)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, 1), JzKet(1, -1)) ))) + assert TensorProduct(JzKet(1, 0), JzKet(1, 1)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, 0), JzKet(1, 1)) ))) + assert TensorProduct(JzKet(1, 0), JzKet(1, 0)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, 0), JzKet(1, 0)) ))) + assert TensorProduct(JzKet(1, 0), JzKet(1, -1)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, 0), JzKet(1, -1)) ))) + assert TensorProduct(JzKet(1, -1), JzKet(1, 1)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, -1), JzKet(1, 1)) ))) + assert TensorProduct(JzKet(1, -1), JzKet(1, 0)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, -1), JzKet(1, 0)) ))) + assert TensorProduct(JzKet(1, -1), JzKet(1, -1)) == \ + expand(uncouple(couple( TensorProduct(JzKet(1, -1), JzKet(1, -1)) ))) + + +def test_uncouple_3_coupled_states(): + # Default coupling + # j1=1/2, j2=1/2, j3=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.NegativeOne/ + 2), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + # j1=1/2, j2=1, j3=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) ))) + # Coupling j1+j3=j13, j13+j2=j + # j1=1/2, j2=1/2, j3=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + # j1=1/2, j2=1, j3=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + 1)/2), JzKet(1, 1), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + 1)/2), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + 1)/2), JzKet(1, 0), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + 1)/2), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + 1)/2), JzKet(1, -1), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + 1)/2), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + -1)/2), JzKet(1, 1), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + -1)/2), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + -1)/2), JzKet(1, 0), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + -1)/2), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S( + -1)/2), JzKet(1, -1), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.NegativeOne/ + 2), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ))) + + +@slow +def test_uncouple_4_coupled_states(): + # j1=1/2, j2=1/2, j3=1/2, j4=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S( + 1)/2, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S( + 1)/2, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S( + 1)/2, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S( + 1)/2, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S( + 1)/2, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) ))) + # j1=1/2, j2=1/2, j3=1, j4=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet( + S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) ))) + # Couple j1+j3=j13, j2+j4=j24, j13+j24=j + # j1=1/2, j2=1/2, j3=1/2, j4=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + # j1=1/2, j2=1/2, j3=1, j4=1/2 + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, S.Half)), ((1, 3), (2, 4), (1, 2)) ))) + assert TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) == \ + expand(uncouple(couple( TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (2, 4), (1, 2)) ))) + + +def test_uncouple_2_coupled_states_numerical(): + # j1=1/2, j2=1/2 + assert uncouple(JzKetCoupled(0, 0, (S.Half, S.Half))) == \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))/2 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))/2 + assert uncouple(JzKetCoupled(1, 1, (S.Half, S.Half))) == \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) + assert uncouple(JzKetCoupled(1, 0, (S.Half, S.Half))) == \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))/2 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))/2 + assert uncouple(JzKetCoupled(1, -1, (S.Half, S.Half))) == \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) + # j1=1, j2=1/2 + assert uncouple(JzKetCoupled(S.Half, S.Half, (1, S.Half))) == \ + -sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(S.Half, S.Half))/3 + \ + sqrt(6)*TensorProduct(JzKet(1, 1), JzKet(S.Half, Rational(-1, 2)))/3 + assert uncouple(JzKetCoupled(S.Half, Rational(-1, 2), (1, S.Half))) == \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(S.Half, Rational(-1, 2)))/3 - \ + sqrt(6)*TensorProduct(JzKet(1, -1), JzKet(S.Half, S.Half))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(3, 2), (1, S.Half))) == \ + TensorProduct(JzKet(1, 1), JzKet(S.Half, S.Half)) + assert uncouple(JzKetCoupled(Rational(3, 2), S.Half, (1, S.Half))) == \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(S.Half, Rational(-1, 2)))/3 + \ + sqrt(6)*TensorProduct(JzKet(1, 0), JzKet(S.Half, S.Half))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-1, 2), (1, S.Half))) == \ + sqrt(6)*TensorProduct(JzKet(1, 0), JzKet(S.Half, Rational(-1, 2)))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(S.Half, S.Half))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-3, 2), (1, S.Half))) == \ + TensorProduct(JzKet(1, -1), JzKet(S.Half, Rational(-1, 2))) + # j1=1, j2=1 + assert uncouple(JzKetCoupled(0, 0, (1, 1))) == \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, -1))/3 - \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, 0))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 1))/3 + assert uncouple(JzKetCoupled(1, 1, (1, 1))) == \ + sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, 0))/2 - \ + sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + assert uncouple(JzKetCoupled(1, 0, (1, 1))) == \ + sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, -1))/2 - \ + sqrt(2)*TensorProduct(JzKet(1, -1), JzKet(1, 1))/2 + assert uncouple(JzKetCoupled(1, -1, (1, 1))) == \ + sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, -1))/2 - \ + sqrt(2)*TensorProduct(JzKet(1, -1), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(2, 2, (1, 1))) == \ + TensorProduct(JzKet(1, 1), JzKet(1, 1)) + assert uncouple(JzKetCoupled(2, 1, (1, 1))) == \ + sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, 0))/2 + \ + sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + assert uncouple(JzKetCoupled(2, 0, (1, 1))) == \ + sqrt(6)*TensorProduct(JzKet(1, 1), JzKet(1, -1))/6 + \ + sqrt(6)*TensorProduct(JzKet(1, 0), JzKet(1, 0))/3 + \ + sqrt(6)*TensorProduct(JzKet(1, -1), JzKet(1, 1))/6 + assert uncouple(JzKetCoupled(2, -1, (1, 1))) == \ + sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, -1))/2 + \ + sqrt(2)*TensorProduct(JzKet(1, -1), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(2, -2, (1, 1))) == \ + TensorProduct(JzKet(1, -1), JzKet(1, -1)) + + +def test_uncouple_3_coupled_states_numerical(): + # Default coupling + # j1=1/2, j2=1/2, j3=1/2 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half))) == \ + TensorProduct(JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)) + assert uncouple(JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half))) == \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))/3 + \ + sqrt(3)*TensorProduct(JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half))) == \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))/3 + \ + sqrt(3)*TensorProduct(JzKet( + S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half))) == \ + TensorProduct(JzKet( + S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))) + # j1=1/2, j2=1/2, j3=1 + assert uncouple(JzKetCoupled(2, 2, (S.Half, S.Half, 1))) == \ + TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1)) + assert uncouple(JzKetCoupled(2, 1, (S.Half, S.Half, 1))) == \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))/2 + \ + sqrt(2)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(2, 0, (S.Half, S.Half, 1))) == \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))/3 + \ + sqrt(6)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(2, -1, (S.Half, S.Half, 1))) == \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))/2 + \ + TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, -2, (S.Half, S.Half, 1))) == \ + TensorProduct( + JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)) + assert uncouple(JzKetCoupled(1, 1, (S.Half, S.Half, 1))) == \ + -TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))/2 + \ + sqrt(2)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(1, 0, (S.Half, S.Half, 1))) == \ + -sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))/2 + \ + sqrt(2)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(1, -1, (S.Half, S.Half, 1))) == \ + -sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1))/2 + \ + TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))/2 + # j1=1/2, j2=1, j3=1 + assert uncouple(JzKetCoupled(Rational(5, 2), Rational(5, 2), (S.Half, 1, 1))) == \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1)) + assert uncouple(JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, 1, 1))) == \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/5 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/5 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), + JzKet(1, 0))/5 + assert uncouple(JzKetCoupled(Rational(5, 2), S.Half, (S.Half, 1, 1))) == \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/5 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/5 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(10)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1))) == \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/5 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), + JzKet(1, -1))/5 + assert uncouple(JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1))) == \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/5 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/5 + \ + sqrt(5)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/5 + assert uncouple(JzKetCoupled(Rational(5, 2), Rational(-5, 2), (S.Half, 1, 1))) == \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1)) + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1))) == \ + -sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/15 - \ + 2*sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), + JzKet(1, 0))/5 + assert uncouple(JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1))) == \ + -4*sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/15 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/15 - \ + 2*sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/15 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), + JzKet(1, -1))/5 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1))) == \ + -sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/5 - \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/15 + \ + 2*sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/15 - \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/15 + \ + 4*sqrt(5)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1))) == \ + -sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/5 + \ + 2*sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(30)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1))) == \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/3 - \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/3 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/6 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/3 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), + JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1))) == \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/2 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/3 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/6 - \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/3 + \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/3 + # j1=1, j2=1, j3=1 + assert uncouple(JzKetCoupled(3, 3, (1, 1, 1))) == \ + TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 1)) + assert uncouple(JzKetCoupled(3, 2, (1, 1, 1))) == \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(3, 1, (1, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(3, 0, (1, 1, 1))) == \ + sqrt(10)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 1))/10 + \ + sqrt(10)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 0))/10 + \ + sqrt(10)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(10)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(10)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1))/10 + \ + sqrt(10)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 0))/10 + \ + sqrt(10)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(3, -1, (1, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, -1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 0))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(3, -2, (1, 1, 1))) == \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 0))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, -1))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(3, -3, (1, 1, 1))) == \ + TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, -1)) + assert uncouple(JzKetCoupled(2, 2, (1, 1, 1))) == \ + -sqrt(6)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 1))/6 - \ + sqrt(6)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(2, 1, (1, 1, 1))) == \ + -sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1))/6 - \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 0))/6 - \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(2, 0, (1, 1, 1))) == \ + -TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, -1, (1, 1, 1))) == \ + -sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1))/3 - \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, -1))/6 - \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1))/3 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(2, -2, (1, 1, 1))) == \ + -sqrt(6)*TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 0))/3 + \ + sqrt(6)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, -1))/6 + \ + sqrt(6)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(1, 1, (1, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1))/30 + \ + sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1))/15 - \ + sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 1))/30 - \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1))/5 + assert uncouple(JzKetCoupled(1, 0, (1, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 1))/10 - \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1))/10 - \ + 2*sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1))/10 - \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(1, -1, (1, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1))/5 - \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, -1))/30 - \ + sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(15)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1))/30 + # Defined j13 + # j1=1/2, j2=1/2, j3=1, j13=1/2 + assert uncouple(JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )) == \ + -sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))/3 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))/3 + \ + sqrt(6)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))/3 + # j1=1/2, j2=1, j3=1, j13=1/2 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))))) == \ + -sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), + JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))))) == \ + -2*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/3 - \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/3 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/3 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), + JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))))) == \ + -sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/3 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/3 + \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/3 + \ + 2*TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/3 + \ + sqrt(6)*TensorProduct( + JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/3 + # j1=1, j2=1, j3=1, j13=1 + assert uncouple(JzKetCoupled(2, 2, (1, 1, 1), ((1, 3, 1), (1, 2, 2)))) == \ + -sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 1))/2 + \ + sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)))) == \ + -TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)))) == \ + -sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 1))/3 - \ + sqrt(3)*TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 0))/6 - \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)))) == \ + -TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, -2, (1, 1, 1), ((1, 3, 1), (1, 2, 2)))) == \ + -sqrt(2)*TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 0))/2 + \ + sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)))) == \ + TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0))/2 - \ + TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 1)))) == \ + TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 0))/2 - \ + TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)))) == \ + -TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0))/2 - \ + TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1))/2 + \ + TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1))/2 + + +def test_uncouple_4_coupled_states_numerical(): + # j1=1/2, j2=1/2, j3=1, j4=1, default coupling + assert uncouple(JzKetCoupled(3, 3, (S.Half, S.Half, 1, 1))) == \ + TensorProduct(JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1)) + assert uncouple(JzKetCoupled(3, 2, (S.Half, S.Half, 1, 1))) == \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(3, 1, (S.Half, S.Half, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(3, 0, (S.Half, S.Half, 1, 1))) == \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 0), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(3, -1, (S.Half, S.Half, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, -1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(3, -2, (S.Half, S.Half, 1, 1))) == \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/3 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(3, -3, (S.Half, S.Half, 1, 1))) == \ + TensorProduct(JzKet(S.Half, -S( + 1)/2), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1)) + assert uncouple(JzKetCoupled(2, 2, (S.Half, S.Half, 1, 1))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1))/6 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(2, 1, (S.Half, S.Half, 1, 1))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/12 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/12 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(2, 0, (S.Half, S.Half, 1, 1))) == \ + -TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/2 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/4 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/4 + \ + TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, -1, (S.Half, S.Half, 1, 1))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/3 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/12 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/12 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, -1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(2, -2, (S.Half, S.Half, 1, 1))) == \ + -sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/3 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(1, 1, (S.Half, S.Half, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/30 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/20 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/20 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/30 - \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/5 + assert uncouple(JzKetCoupled(1, 0, (S.Half, S.Half, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/10 - \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/20 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/20 - \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 0), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(1, -1, (S.Half, S.Half, 1, 1))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/5 - \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/20 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/20 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/30 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, -1), JzKet(1, -1))/30 + # j1=1/2, j2=1/2, j3=1, j4=1, j12=1, j34=1 + assert uncouple(JzKetCoupled(2, 2, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 2)))) == \ + -sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/2 + \ + sqrt(2)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/2 + assert uncouple(JzKetCoupled(2, 1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 2)))) == \ + -sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/4 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/4 - \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, 0, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 2)))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/6 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/6 + \ + sqrt(3)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(2, -1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 2)))) == \ + -TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/2 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/4 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/4 + assert uncouple(JzKetCoupled(2, -2, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 2)))) == \ + -sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/2 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(1, 1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 1)))) == \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/4 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/4 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/4 - \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/2 + \ + TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(1, 0, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 1)))) == \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/2 - \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(1, -1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 1), (1, 3, 1)))) == \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/2 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/4 - \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/4 + \ + sqrt(2)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/4 + # j1=1/2, j2=1/2, j3=1, j4=1, j12=1, j34=2 + assert uncouple(JzKetCoupled(3, 3, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + TensorProduct(JzKet( + S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1)) + assert uncouple(JzKetCoupled(3, 2, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/3 + \ + sqrt(3)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/3 + assert uncouple(JzKetCoupled(3, 1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(3, 0, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/10 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/5 + \ + sqrt(5)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/10 + \ + sqrt(10)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 0), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(3, -1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/15 + \ + 2*sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/15 + \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, -1), JzKet(1, -1))/15 + assert uncouple(JzKetCoupled(3, -2, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/3 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(3, -3, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 3)))) == \ + TensorProduct(JzKet(S.Half, -S( + 1)/2), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1)) + assert uncouple(JzKetCoupled(2, 2, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 2)))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1))/3 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/3 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/6 + assert uncouple(JzKetCoupled(2, 1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 2)))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/3 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/12 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/12 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/12 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/12 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/3 + \ + sqrt(3)*TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/6 + assert uncouple(JzKetCoupled(2, 0, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 2)))) == \ + -TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/2 - \ + TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/2 + \ + TensorProduct(JzKet(S( + 1)/2, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/2 + assert uncouple(JzKetCoupled(2, -1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 2)))) == \ + -sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/6 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/3 - \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/6 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/12 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/12 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/12 + \ + sqrt(6)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/12 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, -1), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(2, -2, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 2)))) == \ + -sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/6 - \ + sqrt(6)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/6 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))/3 + \ + sqrt(3)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1))/3 + assert uncouple(JzKetCoupled(1, 1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 1)))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))/5 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/20 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/30 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 1), JzKet(1, -1))/30 + assert uncouple(JzKetCoupled(1, 0, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 1)))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))/10 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))/10 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))/15 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/15 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/30 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/10 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, 0), JzKet(1, -1))/10 + assert uncouple(JzKetCoupled(1, -1, (S.Half, S.Half, 1, 1), ((1, 2, 1), (3, 4, 2), (1, 3, 1)))) == \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))/30 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))/15 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))/30 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))/20 - \ + sqrt(30)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))/20 + \ + sqrt(15)*TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, + S.Half), JzKet(1, -1), JzKet(1, -1))/5 + + +def test_uncouple_symbolic(): + assert uncouple(JzKetCoupled(j, m, (j1, j2) )) == \ + Sum(CG(j1, m1, j2, m2, j, m) * + TensorProduct(JzKet(j1, m1), JzKet(j2, m2)), + (m1, -j1, j1), (m2, -j2, j2)) + assert uncouple(JzKetCoupled(j, m, (j1, j2, j3) )) == \ + Sum(CG(j1, m1, j2, m2, j1 + j2, m1 + m2) * CG(j1 + j2, m1 + m2, j3, m3, j, m) * + TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3)), + (m1, -j1, j1), (m2, -j2, j2), (m3, -j3, j3)) + assert uncouple(JzKetCoupled(j, m, (j1, j2, j3), ((1, 3, j13), (1, 2, j)) )) == \ + Sum(CG(j1, m1, j3, m3, j13, m1 + m3) * CG(j13, m1 + m3, j2, m2, j, m) * + TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3)), + (m1, -j1, j1), (m2, -j2, j2), (m3, -j3, j3)) + assert uncouple(JzKetCoupled(j, m, (j1, j2, j3, j4) )) == \ + Sum(CG(j1, m1, j2, m2, j1 + j2, m1 + m2) * CG(j1 + j2, m1 + m2, j3, m3, j1 + j2 + j3, m1 + m2 + m3) * CG(j1 + j2 + j3, m1 + m2 + m3, j4, m4, j, m) * + TensorProduct( + JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3), JzKet(j4, m4)), + (m1, -j1, j1), (m2, -j2, j2), (m3, -j3, j3), (m4, -j4, j4)) + assert uncouple(JzKetCoupled(j, m, (j1, j2, j3, j4), ((1, 3, j13), (2, 4, j24), (1, 2, j)) )) == \ + Sum(CG(j1, m1, j3, m3, j13, m1 + m3) * CG(j2, m2, j4, m4, j24, m2 + m4) * CG(j13, m1 + m3, j24, m2 + m4, j, m) * + TensorProduct( + JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3), JzKet(j4, m4)), + (m1, -j1, j1), (m2, -j2, j2), (m3, -j3, j3), (m4, -j4, j4)) + + +def test_couple_2_states(): + # j1=1/2, j2=1/2 + assert JzKetCoupled(0, 0, (S.Half, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(0, 0, (S.Half, S.Half)) ))) + assert JzKetCoupled(1, 1, (S.Half, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (S.Half, S.Half)) ))) + assert JzKetCoupled(1, 0, (S.Half, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (S.Half, S.Half)) ))) + assert JzKetCoupled(1, -1, (S.Half, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (S.Half, S.Half)) ))) + # j1=1, j2=1/2 + assert JzKetCoupled(S.Half, S.Half, (1, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, S.Half, (1, S.Half)) ))) + assert JzKetCoupled(S.Half, Rational(-1, 2), (1, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, Rational(-1, 2), (1, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(3, 2), (1, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(3, 2), (1, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), S.Half, (1, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), S.Half, (1, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(-1, 2), (1, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(-1, 2), (1, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(-3, 2), (1, S.Half)) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(-3, 2), (1, S.Half)) ))) + # j1=1, j2=1 + assert JzKetCoupled(0, 0, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(0, 0, (1, 1)) ))) + assert JzKetCoupled(1, 1, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (1, 1)) ))) + assert JzKetCoupled(1, 0, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (1, 1)) ))) + assert JzKetCoupled(1, -1, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (1, 1)) ))) + assert JzKetCoupled(2, 2, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, 2, (1, 1)) ))) + assert JzKetCoupled(2, 1, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, 1, (1, 1)) ))) + assert JzKetCoupled(2, 0, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, 0, (1, 1)) ))) + assert JzKetCoupled(2, -1, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, -1, (1, 1)) ))) + assert JzKetCoupled(2, -2, (1, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, -2, (1, 1)) ))) + # j1=1/2, j2=3/2 + assert JzKetCoupled(1, 1, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(1, 0, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(1, -1, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(2, 2, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(2, 2, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(2, 1, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(2, 1, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(2, 0, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(2, 0, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(2, -1, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(2, -1, (S.Half, Rational(3, 2))) ))) + assert JzKetCoupled(2, -2, (S.Half, Rational(3, 2))) == \ + expand(couple(uncouple( JzKetCoupled(2, -2, (S.Half, Rational(3, 2))) ))) + + +def test_couple_3_states(): + # Default coupling + # j1=1/2, j2=1/2, j3=1/2 + assert JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half)) ))) + # j1=1/2, j2=1/2, j3=1 + assert JzKetCoupled(0, 0, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(0, 0, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(1, 1, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(1, 0, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(1, -1, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(2, 2, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, 2, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(2, 1, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, 1, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(2, 0, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, 0, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(2, -1, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, -1, (S.Half, S.Half, 1)) ))) + assert JzKetCoupled(2, -2, (S.Half, S.Half, 1)) == \ + expand(couple(uncouple( JzKetCoupled(2, -2, (S.Half, S.Half, 1)) ))) + # Couple j1+j3=j13, j13+j2=j + # j1=1/2, j2=1/2, j3=1/2, j13=0 + assert JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 3, 0), (1, 2, S.Half))) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, S.Half, (S.Half, S( + 1)/2, S.Half), ((1, 3, 0), (1, 2, S.Half))) ), ((1, 3), (1, 2)) )) + assert JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 3, 0), (1, 2, S.Half))) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S( + 1)/2, S.Half), ((1, 3, 0), (1, 2, S.Half))) ), ((1, 3), (1, 2)) )) + # j1=1, j2=1/2, j3=1, j13=1 + assert JzKetCoupled(S.Half, S.Half, (1, S.Half, 1), ((1, 3, 1), (1, 2, S.Half))) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, S.Half, ( + 1, S.Half, 1), ((1, 3, 1), (1, 2, S.Half))) ), ((1, 3), (1, 2)) )) + assert JzKetCoupled(S.Half, Rational(-1, 2), (1, S.Half, 1), ((1, 3, 1), (1, 2, S.Half))) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, Rational(-1, 2), ( + 1, S.Half, 1), ((1, 3, 1), (1, 2, S.Half))) ), ((1, 3), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), Rational(3, 2), (1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(3, 2), ( + 1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) ), ((1, 3), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), S.Half, (1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), S.Half, ( + 1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) ), ((1, 3), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), Rational(-1, 2), (1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(-1, 2), ( + 1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) ), ((1, 3), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), Rational(-3, 2), (1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(-3, 2), ( + 1, S.Half, 1), ((1, 3, 1), (1, 2, Rational(3, 2)))) ), ((1, 3), (1, 2)) )) + + +def test_couple_4_states(): + # Default coupling + # j1=1/2, j2=1/2, j3=1/2, j4=1/2 + assert JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple( + uncouple( JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple( + uncouple( JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(2, 2, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple( + uncouple( JzKetCoupled(2, 2, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(2, 1, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple( + uncouple( JzKetCoupled(2, 1, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple( + uncouple( JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(2, -1, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(2, -1, (S.Half, S.Half, S.Half, S.Half)) ))) + assert JzKetCoupled(2, -2, (S.Half, S.Half, S.Half, S.Half)) == \ + expand(couple(uncouple( + JzKetCoupled(2, -2, (S.Half, S.Half, S.Half, S.Half)) ))) + # j1=1/2, j2=1/2, j3=1/2, j4=1 + assert JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(5, 2), Rational(5, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(5, 2), Rational(5, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1)) ))) + assert JzKetCoupled(Rational(5, 2), Rational(-5, 2), (S.Half, S.Half, S.Half, 1)) == \ + expand(couple(uncouple( + JzKetCoupled(Rational(5, 2), Rational(-5, 2), (S.Half, S.Half, S.Half, 1)) ))) + # Coupling j1+j3=j13, j2+j4=j24, j13+j24=j + # j1=1/2, j2=1/2, j3=1/2, j4=1/2, j13=1, j24=0 + assert JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 3, 1), (2, 4, 0), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 3, 1), (2, 4, 0), (1, 2, 1)) ) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 3, 1), (2, 4, 0), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 3, 1), (2, 4, 0), (1, 2, 1)) ) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 3, 1), (2, 4, 0), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 3, 1), (2, 4, 0), (1, 2, 1)) ) ), ((1, 3), (2, 4), (1, 2)) )) + # j1=1/2, j2=1/2, j3=1/2, j4=1, j13=1, j24=1/2 + assert JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, S.Half)) ) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, S.Half)) )), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, S.Half)) ) == \ + expand(couple(uncouple( JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, S.Half)) ) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) == \ + expand(couple(uncouple( JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 3, 1), (2, 4, S.Half), (1, 2, Rational(3, 2))) ) ), ((1, 3), (2, 4), (1, 2)) )) + # j1=1/2, j2=1, j3=1/2, j4=1, j13=0, j24=1 + assert JzKetCoupled(1, 1, (S.Half, 1, S.Half, 1), ((1, 3, 0), (2, 4, 1), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (S.Half, 1, S.Half, 1), ( + (1, 3, 0), (2, 4, 1), (1, 2, 1))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, 0, (S.Half, 1, S.Half, 1), ((1, 3, 0), (2, 4, 1), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (S.Half, 1, S.Half, 1), ( + (1, 3, 0), (2, 4, 1), (1, 2, 1))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, -1, (S.Half, 1, S.Half, 1), ((1, 3, 0), (2, 4, 1), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (S.Half, 1, S.Half, 1), ( + (1, 3, 0), (2, 4, 1), (1, 2, 1))) ), ((1, 3), (2, 4), (1, 2)) )) + # j1=1/2, j2=1, j3=1/2, j4=1, j13=1, j24=1 + assert JzKetCoupled(0, 0, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 0)) ) == \ + expand(couple(uncouple( JzKetCoupled(0, 0, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 0))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, 1, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, 1, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 1))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, 0, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, 0, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 1))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(1, -1, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 1)) ) == \ + expand(couple(uncouple( JzKetCoupled(1, -1, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 1))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(2, 2, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 2)) ) == \ + expand(couple(uncouple( JzKetCoupled(2, 2, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 2))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(2, 1, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 2)) ) == \ + expand(couple(uncouple( JzKetCoupled(2, 1, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 2))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(2, 0, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 2)) ) == \ + expand(couple(uncouple( JzKetCoupled(2, 0, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 2))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(2, -1, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 2)) ) == \ + expand(couple(uncouple( JzKetCoupled(2, -1, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 2))) ), ((1, 3), (2, 4), (1, 2)) )) + assert JzKetCoupled(2, -2, (S.Half, 1, S.Half, 1), ((1, 3, 1), (2, 4, 1), (1, 2, 2)) ) == \ + expand(couple(uncouple( JzKetCoupled(2, -2, (S.Half, 1, S.Half, 1), ( + (1, 3, 1), (2, 4, 1), (1, 2, 2))) ), ((1, 3), (2, 4), (1, 2)) )) + + +def test_couple_2_states_numerical(): + # j1=1/2, j2=1/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + JzKetCoupled(1, 1, (S.Half, S.Half)) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(2)*JzKetCoupled(0, 0, (S( + 1)/2, S.Half))/2 + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half))/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + -sqrt(2)*JzKetCoupled(0, 0, (S( + 1)/2, S.Half))/2 + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half))/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(1, -1, (S.Half, S.Half)) + # j1=1, j2=1/2 + assert couple(TensorProduct(JzKet(1, 1), JzKet(S.Half, S.Half))) == \ + JzKetCoupled(Rational(3, 2), Rational(3, 2), (1, S.Half)) + assert couple(TensorProduct(JzKet(1, 1), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (1, S.Half))/3 + sqrt( + 3)*JzKetCoupled(Rational(3, 2), S.Half, (1, S.Half))/3 + assert couple(TensorProduct(JzKet(1, 0), JzKet(S.Half, S.Half))) == \ + -sqrt(3)*JzKetCoupled(S.Half, S.Half, (1, S.Half))/3 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), S.Half, (1, S.Half))/3 + assert couple(TensorProduct(JzKet(1, 0), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(3)*JzKetCoupled(S.Half, Rational(-1, 2), (1, S.Half))/3 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (1, S.Half))/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(S.Half, S.Half))) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (1, S( + 1)/2))/3 + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (1, S.Half))/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(Rational(3, 2), Rational(-3, 2), (1, S.Half)) + # j1=1, j2=1 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1))) == \ + JzKetCoupled(2, 2, (1, 1)) + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0))) == \ + sqrt(2)*JzKetCoupled( + 1, 1, (1, 1))/2 + sqrt(2)*JzKetCoupled(2, 1, (1, 1))/2 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + sqrt(3)*JzKetCoupled(0, 0, (1, 1))/3 + sqrt(2)*JzKetCoupled( + 1, 0, (1, 1))/2 + sqrt(6)*JzKetCoupled(2, 0, (1, 1))/6 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1))) == \ + -sqrt(2)*JzKetCoupled( + 1, 1, (1, 1))/2 + sqrt(2)*JzKetCoupled(2, 1, (1, 1))/2 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0))) == \ + -sqrt(3)*JzKetCoupled( + 0, 0, (1, 1))/3 + sqrt(6)*JzKetCoupled(2, 0, (1, 1))/3 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled( + 1, -1, (1, 1))/2 + sqrt(2)*JzKetCoupled(2, -1, (1, 1))/2 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1))) == \ + sqrt(3)*JzKetCoupled(0, 0, (1, 1))/3 - sqrt(2)*JzKetCoupled( + 1, 0, (1, 1))/2 + sqrt(6)*JzKetCoupled(2, 0, (1, 1))/6 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0))) == \ + -sqrt(2)*JzKetCoupled( + 1, -1, (1, 1))/2 + sqrt(2)*JzKetCoupled(2, -1, (1, 1))/2 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1))) == \ + JzKetCoupled(2, -2, (1, 1)) + # j1=3/2, j2=1/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), Rational(3, 2)), JzKet(S.Half, S.Half))) == \ + JzKetCoupled(2, 2, (Rational(3, 2), S.Half)) + assert couple(TensorProduct(JzKet(Rational(3, 2), Rational(3, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(3)*JzKetCoupled( + 1, 1, (Rational(3, 2), S.Half))/2 + JzKetCoupled(2, 1, (Rational(3, 2), S.Half))/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), S.Half), JzKet(S.Half, S.Half))) == \ + -JzKetCoupled(1, 1, (S( + 3)/2, S.Half))/2 + sqrt(3)*JzKetCoupled(2, 1, (Rational(3, 2), S.Half))/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(2)*JzKetCoupled(1, 0, (S( + 3)/2, S.Half))/2 + sqrt(2)*JzKetCoupled(2, 0, (Rational(3, 2), S.Half))/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + -sqrt(2)*JzKetCoupled(1, 0, (S( + 3)/2, S.Half))/2 + sqrt(2)*JzKetCoupled(2, 0, (Rational(3, 2), S.Half))/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(1, -1, (S( + 3)/2, S.Half))/2 + sqrt(3)*JzKetCoupled(2, -1, (Rational(3, 2), S.Half))/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), Rational(-3, 2)), JzKet(S.Half, S.Half))) == \ + -sqrt(3)*JzKetCoupled(1, -1, (Rational(3, 2), S.Half))/2 + \ + JzKetCoupled(2, -1, (Rational(3, 2), S.Half))/2 + assert couple(TensorProduct(JzKet(Rational(3, 2), Rational(-3, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(2, -2, (Rational(3, 2), S.Half)) + + +def test_couple_3_states_numerical(): + # Default coupling + # j1=1/2,j2=1/2,j3=1/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + JzKetCoupled(Rational(3, 2), S( + 3)/2, (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (1, 3, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half)) )/2 - \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (1, 3, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half)) )/2 + \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.One + /2), ((1, 2, 1), (1, 3, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + -sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half)) )/2 - \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (1, 3, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + -sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half)) )/2 + \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.One + /2), ((1, 2, 1), (1, 3, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.One + /2), ((1, 2, 1), (1, 3, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(Rational(3, 2), -S( + 3)/2, (S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2))) ) + # j1=S.Half, j2=S.Half, j3=1 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1))) == \ + JzKetCoupled(2, 2, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0))) == \ + sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(2)*JzKetCoupled( + 2, 1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1))) == \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 0)) )/3 + \ + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))) == \ + sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 2, 0), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))) == \ + -sqrt(6)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 0)) )/6 + \ + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 2, 0), (1, 3, 1)) )/2 + \ + sqrt(3)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/3 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 2, 0), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1))) == \ + -sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 2, 0), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0))) == \ + -sqrt(6)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 0)) )/6 - \ + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 2, 0), (1, 3, 1)) )/2 + \ + sqrt(3)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1))) == \ + -sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 2, 0), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))) == \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 0)) )/3 - \ + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))) == \ + -sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(2)*JzKetCoupled( + 2, -1, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))) == \ + JzKetCoupled(2, -2, (S.Half, S.Half, 1), ((1, 2, 1), (1, 3, 2)) ) + # j1=S.Half, j2=1, j3=1 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1))) == \ + JzKetCoupled( + Rational(5, 2), Rational(5, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0))) == \ + sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(S( + 5)/2, Rational(3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/2 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1))) == \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + 2*sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(S( + 5)/2, Rational(3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0))) == \ + JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(S( + 5)/2, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + 4*sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1))) == \ + -2*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0))) == \ + -sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/3 + \ + 2*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1))) == \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1))) == \ + -sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(S( + 5)/2, Rational(3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0))) == \ + -sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/3 - \ + 2*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(S( + 5)/2, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1))) == \ + -2*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1))) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/3 - \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + 4*sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(S( + 5)/2, S.Half, (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0))) == \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1))) == \ + -sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 2, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + 2*sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1))) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, S.Half)) )/2 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0))) == \ + -sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, + 2, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1))) == \ + JzKetCoupled(S( + 5)/2, Rational(-5, 2), (S.Half, 1, 1), ((1, 2, Rational(3, 2)), (1, 3, Rational(5, 2))) ) + # j1=1, j2=1, j3=1 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 1))) == \ + JzKetCoupled(3, 3, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) ) + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 0))) == \ + sqrt(6)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/3 + \ + sqrt(3)*JzKetCoupled(3, 2, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/3 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1))) == \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/5 + \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/3 + \ + sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 1))) == \ + sqrt(2)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 - \ + sqrt(6)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, 2, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/3 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0))) == \ + JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 + \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, -1))) == \ + sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 0)) )/6 + \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 + \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/6 + \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 1))) == \ + sqrt(3)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 - \ + JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/30 + \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 0))) == \ + -sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 0)) )/6 + \ + sqrt(3)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 - \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/15 + \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/3 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1))) == \ + sqrt(3)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 + \ + JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/30 + \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 1))) == \ + -sqrt(2)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 - \ + sqrt(6)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, 2, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/3 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 0))) == \ + -JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 - \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1))) == \ + -sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 0)) )/6 - \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 - \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/6 + \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1))) == \ + -sqrt(3)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 + \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/15 - \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/3 + \ + 2*sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 0))) == \ + -sqrt(3)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 - \ + 2*sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/15 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/5 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1))) == \ + -sqrt(3)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 + \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/15 + \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/3 + \ + 2*sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1))) == \ + sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 0)) )/6 - \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 + \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/6 - \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 0))) == \ + -JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 + \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 + \ + sqrt(6)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, -2, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1))) == \ + sqrt(3)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 + \ + JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/30 - \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 0))) == \ + sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 0)) )/6 + \ + sqrt(3)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 - \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/15 - \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/3 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/10 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, -1))) == \ + sqrt(3)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 0), (1, 3, 1)) )/3 - \ + JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/30 - \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 1))) == \ + -sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 0)) )/6 + \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 - \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/6 - \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/10 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0))) == \ + JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/10 - \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, -1))) == \ + -sqrt(2)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 2, 1), (1, 3, 2)) )/2 + \ + sqrt(6)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, -2, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1))) == \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 1)) )/5 - \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/3 + \ + sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 0))) == \ + -sqrt(6)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 2, 2), (1, 3, 2)) )/3 + \ + sqrt(3)*JzKetCoupled(3, -2, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) )/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, -1))) == \ + JzKetCoupled(3, -3, (1, 1, 1), ((1, 2, 2), (1, 3, 3)) ) + # j1=S.Half, j2=S.Half, j3=Rational(3, 2) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(3, 2)))) == \ + JzKetCoupled(Rational(5, 2), S( + 5)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), S.Half))) == \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(15)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S(3) + /2), ((1, 2, 1), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-1, 2)))) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + 2*sqrt(30)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-3, 2)))) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/2 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(3, 2)))) == \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S(3)/ + 2), ((1, 2, 1), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), S.Half))) == \ + -sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-1, 2)))) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-3, 2)))) == \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S(3) + /2), ((1, 2, 1), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(3, 2)))) == \ + -sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S(3)/ + 2), ((1, 2, 1), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), S.Half))) == \ + -sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-1, 2)))) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-3, 2)))) == \ + -sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 0), (1, 3, Rational(3, 2))) )/2 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S(3) + /2), ((1, 2, 1), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(3, 2)))) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/2 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), S.Half))) == \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, S.Half)) )/6 - \ + 2*sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-1, 2)))) == \ + -sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(15)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S( + 3)/2), ((1, 2, 1), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-3, 2)))) == \ + JzKetCoupled(Rational(5, 2), -S( + 5)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 2, 1), (1, 3, Rational(5, 2))) ) + # Couple j1 to j3 + # j1=1/2, j2=1/2, j3=1/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(Rational(3, 2), S( + 3)/2, (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, Rational(3, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 3, 0), (1, 2, S.Half)) )/2 - \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.One/ + 2), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.One/ + 2), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 3, 0), (1, 2, S.Half)) )/2 + \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.One + /2), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 3, 0), (1, 2, S.Half)) )/2 - \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.One/ + 2), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.One + /2), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 3, 0), (1, 2, S.Half)) )/2 + \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.One + /2), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(Rational(3, 2), -S( + 3)/2, (S.Half, S.Half, S.Half), ((1, 3, 1), (1, 2, Rational(3, 2))) ) + # j1=1/2, j2=1/2, j3=1 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(2, 2, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/3 - \ + sqrt(6)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/6 + \ + sqrt(2)*JzKetCoupled( + 2, 1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 0)) )/3 + \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/3 - \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/6 + \ + sqrt(6)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 0)) )/6 + \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/6 + \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/3 + \ + sqrt(3)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/3 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/3 + \ + sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/6 + \ + JzKetCoupled( + 2, -1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/3 - \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/6 + \ + JzKetCoupled(2, 1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 0)) )/6 - \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/6 - \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/3 + \ + sqrt(3)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/3 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/2 + \ + JzKetCoupled( + 2, -1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 0)) )/3 - \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/3 + \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/6 + \ + sqrt(6)*JzKetCoupled( + 2, 0, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 3, S.Half), (1, 2, 1)) )/3 + \ + sqrt(6)*JzKetCoupled(1, -1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 1)) )/6 + \ + sqrt(2)*JzKetCoupled( + 2, -1, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(2, -2, (S.Half, S.Half, 1), ((1, 3, Rational(3, 2)), (1, 2, 2)) ) + # j 1=1/2, j 2=1, j 3=1 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled( + Rational(5, 2), Rational(5, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 - \ + 2*sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(S( + 5)/2, Rational(3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -2*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 - \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(S( + 5)/2, Rational(3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(S( + 5)/2, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, 0), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/3 + \ + 2*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/2 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 + \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/3 + \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 + \ + 4*sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(1, -1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(S( + 5)/2, Rational(3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 + \ + JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/3 - \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 - \ + 4*sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(S( + 5)/2, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/2 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/3 - \ + 2*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(S( + 5)/2, S.Half, (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -2*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 + \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 3, S.Half), (1, 2, Rational(3, 2))) )/3 + \ + 2*sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, 1, 1), ((1, + 3, Rational(3, 2)), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(S( + 5)/2, Rational(-5, 2), (S.Half, 1, 1), ((1, 3, Rational(3, 2)), (1, 2, Rational(5, 2))) ) + # j1=1, 1, 1 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(3, 3, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) ) + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 - \ + sqrt(6)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, 2, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/3 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 - \ + JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/30 + \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/3 + \ + sqrt(3)*JzKetCoupled(3, 2, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/3 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 + \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, 0), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 0)) )/6 + \ + sqrt(3)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 - \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/15 + \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/3 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/5 + \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/3 + \ + sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 0)) )/6 + \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 + \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/6 + \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 1), JzKet(1, -1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 + \ + JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/30 + \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 - \ + sqrt(6)*JzKetCoupled(2, 2, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, 2, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/3 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 + \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/15 - \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/3 + \ + 2*sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 0)) )/6 - \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 + \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/6 - \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 - \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 - \ + 2*sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/15 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/5 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, 0), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 + \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 0)) )/6 - \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 - \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/6 + \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/10 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 + \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/15 + \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/3 + \ + 2*sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, 0), JzKet(1, -1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(2)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 + \ + sqrt(6)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, -2, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 + \ + JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/30 - \ + JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, 1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 0)) )/6 + \ + JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 - \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/6 - \ + JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/2 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/10 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/5 - \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/3 + \ + sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(0, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 0)) )/6 + \ + sqrt(3)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 - \ + sqrt(15)*JzKetCoupled(1, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/15 - \ + sqrt(3)*JzKetCoupled(2, 0, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/3 + \ + sqrt(10)*JzKetCoupled(3, 0, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/10 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 - \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/10 - \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 - \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + 2*sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, 0), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/3 + \ + sqrt(3)*JzKetCoupled(3, -2, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 1)), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 0), (1, 2, 1)) )/3 - \ + JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 1)) )/2 + \ + sqrt(15)*JzKetCoupled(1, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 1)) )/30 - \ + JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 + \ + sqrt(3)*JzKetCoupled(2, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(15)*JzKetCoupled(3, -1, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/15 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, 0)), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 3, 1), (1, 2, 2)) )/2 + \ + sqrt(6)*JzKetCoupled(2, -2, (1, 1, 1), ((1, 3, 2), (1, 2, 2)) )/6 + \ + sqrt(3)*JzKetCoupled(3, -2, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) )/3 + assert couple(TensorProduct(JzKet(1, -1), JzKet(1, -1), JzKet(1, -1)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(3, -3, (1, 1, 1), ((1, 3, 2), (1, 2, 3)) ) + # j1=1/2, j2=1/2, j3=3/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(3, 2))), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(Rational(5, 2), S( + 5)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), S.Half)), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/2 - \ + sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(15)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S(3) + /2), ((1, 3, 2), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-3, 2))), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/2 + \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/2 - \ + sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(3, 2))), ((1, 3), (1, 2)) ) == \ + 2*sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S(3)/ + 2), ((1, 3, 2), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), S.Half)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/6 + \ + 3*sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/6 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-3, 2))), ((1, 3), (1, 2)) ) == \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/2 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S(3) + /2), ((1, 3, 2), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(3, 2))), ((1, 3), (1, 2)) ) == \ + -sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/2 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S.Half, S(3)/ + 2), ((1, 3, 2), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), S.Half)), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/6 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/6 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/6 - \ + 3*sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(Rational(3, 2), Rational(-3, 2))), ((1, 3), (1, 2)) ) == \ + -2*sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S(3) + /2), ((1, 3, 2), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(3, 2))), ((1, 3), (1, 2)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/2 - \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/2 + \ + sqrt(15)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), S.Half)), ((1, 3), (1, 2)) ) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, S.Half)) )/6 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/5 + \ + sqrt(30)*JzKetCoupled(Rational(5, 2), -S( + 1)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-1, 2))), ((1, 3), (1, 2)) ) == \ + -JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 1), (1, 2, Rational(3, 2))) )/2 + \ + sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(3, 2))) )/10 + \ + sqrt(15)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S.Half, S( + 3)/2), ((1, 3, 2), (1, 2, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(Rational(3, 2), Rational(-3, 2))), ((1, 3), (1, 2)) ) == \ + JzKetCoupled(Rational(5, 2), -S( + 5)/2, (S.Half, S.Half, Rational(3, 2)), ((1, 3, 2), (1, 2, Rational(5, 2))) ) + + +def test_couple_4_states_numerical(): + # Default coupling + # j1=1/2, j2=1/2, j3=1/2, j4=1/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + JzKetCoupled(2, 2, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + sqrt(6)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/3 - \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 0)) )/3 + \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/3 + \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 - \ + sqrt(6)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 - \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), + JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), + ((1, 2, 0), (1, 3, S.Half), (1, 4, 0)))/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), + ((1, 2, 1), (1, 3, S.Half), (1, 4, 0)))/6 + \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), + ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)))/2 - \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), + ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)))/6 + \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), + ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)))/6 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.Half), + ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)))/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + -JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 0)) )/6 + \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 + \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 - \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 + \ + sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + -sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 - \ + sqrt(6)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 - \ + sqrt(3)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + -JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 0)) )/6 - \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 - \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 + \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 0)) )/6 - \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 + \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 - \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + -sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (1, 3, S.Half), (1, 4, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/6 + \ + sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half))) == \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 0)) )/3 - \ + sqrt(3)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/3 - \ + sqrt(6)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)))) == \ + -sqrt(6)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, S.Half), (1, 4, 1)) )/3 + \ + sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/6 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half))) == \ + -sqrt(3)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)))) == \ + JzKetCoupled(2, -2, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, 2)) ) + # j1=S.Half, S.Half, S.Half, 1 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1))) == \ + JzKetCoupled(Rational(5, 2), Rational(5, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0))) == \ + sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/2 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))) == \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/3 + \ + 2*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))) == \ + 2*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1))) == \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/2 - \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0))) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1))) == \ + sqrt(3)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/3 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/6 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 + \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))) == \ + -sqrt(3)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/3 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/6 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 - \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))) == \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/2 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1))) == \ + -sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/2 - \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0))) == \ + -sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/3 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1))) == \ + -sqrt(3)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/3 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/6 - \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 + \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))) == \ + sqrt(3)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/3 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/6 - \ + sqrt(6)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 - \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))) == \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/6 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/3 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))) == \ + -sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/2 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/6 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1))) == \ + 2*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0))) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, S.Half)) )/3 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/3 - \ + 2*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1))) == \ + -sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, S.Half), (1, 4, Rational(3, 2))) )/3 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1))) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, S.Half)) )/2 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0))) == \ + -sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1))) == \ + JzKetCoupled(Rational(5, 2), Rational(-5, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (1, 3, Rational(3, 2)), (1, 4, Rational(5, 2))) ) + # Couple j1 to j2, j3 to j4 + # j1=1/2, j2=1/2, j3=1/2, j4=1/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + JzKetCoupled(2, 2, (S( + 1)/2, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 0)) )/3 + \ + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 0), (1, 3, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 0)) )/6 + \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + -JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 0), (1, 3, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 0)) )/6 + \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, 1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, 1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + -JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 0), (1, 3, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 0)) )/6 - \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 0), (1, 3, 0)) )/2 - \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 0)) )/6 - \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 0), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(3)*JzKetCoupled(0, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 0)) )/3 - \ + sqrt(2)*JzKetCoupled(1, 0, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + sqrt(6)*JzKetCoupled(2, 0, (S.Half, S.Half, S.Half, S.One/ + 2), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/6 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 0), (1, 3, 1)) )/2 - \ + JzKetCoupled(1, -1, (S.Half, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 1)) )/2 + \ + JzKetCoupled(2, -1, (S.Half, S( + 1)/2, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) )/2 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2))), ((1, 2), (3, 4), (1, 3)) ) == \ + JzKetCoupled(2, -2, (S( + 1)/2, S.Half, S.Half, S.Half), ((1, 2, 1), (3, 4, 1), (1, 3, 2)) ) + # j1=S.Half, S.Half, S.Half, 1 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + JzKetCoupled(Rational(5, 2), Rational(5, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) ) + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + 2*sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + 2*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/3 - \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + 4*sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/2 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/2 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/3 + \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(3)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/6 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(3)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/6 + \ + sqrt(6)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/3 + \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/3 - \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/2 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/2 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(6)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/3 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/3 + \ + JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(5)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, S.Half), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(3)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/6 - \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/6 + \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(3)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 - \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/6 - \ + sqrt(6)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/30 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(6)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/6 - \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/3 - \ + sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/3 - \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 + \ + sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 0), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/2 + \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/10 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(2)*JzKetCoupled(S.Half, S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/2 - \ + sqrt(10)*JzKetCoupled(Rational(3, 2), S.Half, (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/5 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), S.Half, (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/3 + \ + JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + 4*sqrt(5)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, S.Half), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + sqrt(6)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + sqrt(30)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(5)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 1)), ((1, 2), (3, 4), (1, 3)) ) == \ + 2*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, S.Half)) )/3 + \ + sqrt(2)*JzKetCoupled(S.Half, Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, S.Half)) )/6 - \ + sqrt(2)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + 2*sqrt(10)*JzKetCoupled(Rational(3, 2), Rational(-1, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-1, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/10 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, 0)), ((1, 2), (3, 4), (1, 3)) ) == \ + -sqrt(3)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, S.Half), (1, 3, Rational(3, 2))) )/3 - \ + 2*sqrt(15)*JzKetCoupled(Rational(3, 2), Rational(-3, 2), (S.Half, S.Half, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(3, 2))) )/15 + \ + sqrt(10)*JzKetCoupled(Rational(5, 2), Rational(-3, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) )/5 + assert couple(TensorProduct(JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(S.Half, Rational(-1, 2)), JzKet(1, -1)), ((1, 2), (3, 4), (1, 3)) ) == \ + JzKetCoupled(Rational(5, 2), Rational(-5, 2), (S.Half, S( + 1)/2, S.Half, 1), ((1, 2, 1), (3, 4, Rational(3, 2)), (1, 3, Rational(5, 2))) ) + + +def test_couple_symbolic(): + assert couple(TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + Sum(CG(j1, m1, j2, m2, j, m1 + m2) * JzKetCoupled(j, m1 + m2, ( + j1, j2)), (j, m1 + m2, j1 + j2)) + assert couple(TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3))) == \ + Sum(CG(j1, m1, j2, m2, j12, m1 + m2) * CG(j12, m1 + m2, j3, m3, j, m1 + m2 + m3) * + JzKetCoupled(j, m1 + m2 + m3, (j1, j2, j3), ((1, 2, j12), (1, 3, j)) ), + (j12, m1 + m2, j1 + j2), (j, m1 + m2 + m3, j12 + j3)) + assert couple(TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3)), ((1, 3), (1, 2)) ) == \ + Sum(CG(j1, m1, j3, m3, j13, m1 + m3) * CG(j13, m1 + m3, j2, m2, j, m1 + m2 + m3) * + JzKetCoupled(j, m1 + m2 + m3, (j1, j2, j3), ((1, 3, j13), (1, 2, j)) ), + (j13, m1 + m3, j1 + j3), (j, m1 + m2 + m3, j13 + j2)) + assert couple(TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3), JzKet(j4, m4))) == \ + Sum(CG(j1, m1, j2, m2, j12, m1 + m2) * CG(j12, m1 + m2, j3, m3, j123, m1 + m2 + m3) * CG(j123, m1 + m2 + m3, j4, m4, j, m1 + m2 + m3 + m4) * + JzKetCoupled(j, m1 + m2 + m3 + m4, ( + j1, j2, j3, j4), ((1, 2, j12), (1, 3, j123), (1, 4, j)) ), + (j12, m1 + m2, j1 + j2), (j123, m1 + m2 + m3, j12 + j3), (j, m1 + m2 + m3 + m4, j123 + j4)) + assert couple(TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3), JzKet(j4, m4)), ((1, 2), (3, 4), (1, 3)) ) == \ + Sum(CG(j1, m1, j2, m2, j12, m1 + m2) * CG(j3, m3, j4, m4, j34, m3 + m4) * CG(j12, m1 + m2, j34, m3 + m4, j, m1 + m2 + m3 + m4) * + JzKetCoupled(j, m1 + m2 + m3 + m4, ( + j1, j2, j3, j4), ((1, 2, j12), (3, 4, j34), (1, 3, j)) ), + (j12, m1 + m2, j1 + j2), (j34, m3 + m4, j3 + j4), (j, m1 + m2 + m3 + m4, j12 + j34)) + assert couple(TensorProduct(JzKet(j1, m1), JzKet(j2, m2), JzKet(j3, m3), JzKet(j4, m4)), ((1, 3), (1, 4), (1, 2)) ) == \ + Sum(CG(j1, m1, j3, m3, j13, m1 + m3) * CG(j13, m1 + m3, j4, m4, j134, m1 + m3 + m4) * CG(j134, m1 + m3 + m4, j2, m2, j, m1 + m2 + m3 + m4) * + JzKetCoupled(j, m1 + m2 + m3 + m4, ( + j1, j2, j3, j4), ((1, 3, j13), (1, 4, j134), (1, 2, j)) ), + (j13, m1 + m3, j1 + j3), (j134, m1 + m3 + m4, j13 + j4), (j, m1 + m2 + m3 + m4, j134 + j2)) + + +def test_innerproduct(): + assert InnerProduct(JzBra(1, 1), JzKet(1, 1)).doit() == 1 + assert InnerProduct( + JzBra(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))).doit() == 0 + assert InnerProduct(JzBra(j, m), JzKet(j, m)).doit() == 1 + assert InnerProduct(JzBra(1, 0), JyKet(1, 1)).doit() == I/sqrt(2) + assert InnerProduct( + JxBra(S.Half, S.Half), JzKet(S.Half, S.Half)).doit() == -sqrt(2)/2 + assert InnerProduct(JyBra(1, 1), JzKet(1, 1)).doit() == S.Half + assert InnerProduct(JxBra(1, -1), JyKet(1, 1)).doit() == 0 + + +def test_rotation_small_d(): + # Symbolic tests + # j = 1/2 + assert Rotation.d(S.Half, S.Half, S.Half, beta).doit() == cos(beta/2) + assert Rotation.d(S.Half, S.Half, Rational(-1, 2), beta).doit() == -sin(beta/2) + assert Rotation.d(S.Half, Rational(-1, 2), S.Half, beta).doit() == sin(beta/2) + assert Rotation.d(S.Half, Rational(-1, 2), Rational(-1, 2), beta).doit() == cos(beta/2) + # j = 1 + assert Rotation.d(1, 1, 1, beta).doit() == (1 + cos(beta))/2 + assert Rotation.d(1, 1, 0, beta).doit() == -sin(beta)/sqrt(2) + assert Rotation.d(1, 1, -1, beta).doit() == (1 - cos(beta))/2 + assert Rotation.d(1, 0, 1, beta).doit() == sin(beta)/sqrt(2) + assert Rotation.d(1, 0, 0, beta).doit() == cos(beta) + assert Rotation.d(1, 0, -1, beta).doit() == -sin(beta)/sqrt(2) + assert Rotation.d(1, -1, 1, beta).doit() == (1 - cos(beta))/2 + assert Rotation.d(1, -1, 0, beta).doit() == sin(beta)/sqrt(2) + assert Rotation.d(1, -1, -1, beta).doit() == (1 + cos(beta))/2 + # j = 3/2 + assert Rotation.d(S( + 3)/2, Rational(3, 2), Rational(3, 2), beta).doit() == (3*cos(beta/2) + cos(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), S( + 3)/2, S.Half, beta).doit() == -sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), S( + 3)/2, Rational(-1, 2), beta).doit() == sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), S( + 3)/2, Rational(-3, 2), beta).doit() == (-3*sin(beta/2) + sin(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), S( + 1)/2, Rational(3, 2), beta).doit() == sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4 + assert Rotation.d(S( + 3)/2, S.Half, S.Half, beta).doit() == (cos(beta/2) + 3*cos(beta*Rational(3, 2)))/4 + assert Rotation.d(S( + 3)/2, S.Half, Rational(-1, 2), beta).doit() == (sin(beta/2) - 3*sin(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), S( + 1)/2, Rational(-3, 2), beta).doit() == sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 1)/2, Rational(3, 2), beta).doit() == sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 1)/2, S.Half, beta).doit() == (-sin(beta/2) + 3*sin(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 1)/2, Rational(-1, 2), beta).doit() == (cos(beta/2) + 3*cos(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 1)/2, Rational(-3, 2), beta).doit() == -sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4 + assert Rotation.d(S( + 3)/2, Rational(-3, 2), Rational(3, 2), beta).doit() == (3*sin(beta/2) - sin(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 3)/2, S.Half, beta).doit() == sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 3)/2, Rational(-1, 2), beta).doit() == sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4 + assert Rotation.d(Rational(3, 2), -S( + 3)/2, Rational(-3, 2), beta).doit() == (3*cos(beta/2) + cos(beta*Rational(3, 2)))/4 + # j = 2 + assert Rotation.d(2, 2, 2, beta).doit() == (3 + 4*cos(beta) + cos(2*beta))/8 + assert Rotation.d(2, 2, 1, beta).doit() == -((cos(beta) + 1)*sin(beta))/2 + assert Rotation.d(2, 2, 0, beta).doit() == sqrt(6)*sin(beta)**2/4 + assert Rotation.d(2, 2, -1, beta).doit() == (cos(beta) - 1)*sin(beta)/2 + assert Rotation.d(2, 2, -2, beta).doit() == (3 - 4*cos(beta) + cos(2*beta))/8 + assert Rotation.d(2, 1, 2, beta).doit() == (cos(beta) + 1)*sin(beta)/2 + assert Rotation.d(2, 1, 1, beta).doit() == (cos(beta) + cos(2*beta))/2 + assert Rotation.d(2, 1, 0, beta).doit() == -sqrt(6)*sin(2*beta)/4 + assert Rotation.d(2, 1, -1, beta).doit() == (cos(beta) - cos(2*beta))/2 + assert Rotation.d(2, 1, -2, beta).doit() == (cos(beta) - 1)*sin(beta)/2 + assert Rotation.d(2, 0, 2, beta).doit() == sqrt(6)*sin(beta)**2/4 + assert Rotation.d(2, 0, 1, beta).doit() == sqrt(6)*sin(2*beta)/4 + assert Rotation.d(2, 0, 0, beta).doit() == (1 + 3*cos(2*beta))/4 + assert Rotation.d(2, 0, -1, beta).doit() == -sqrt(6)*sin(2*beta)/4 + assert Rotation.d(2, 0, -2, beta).doit() == sqrt(6)*sin(beta)**2/4 + assert Rotation.d(2, -1, 2, beta).doit() == (2*sin(beta) - sin(2*beta))/4 + assert Rotation.d(2, -1, 1, beta).doit() == (cos(beta) - cos(2*beta))/2 + assert Rotation.d(2, -1, 0, beta).doit() == sqrt(6)*sin(2*beta)/4 + assert Rotation.d(2, -1, -1, beta).doit() == (cos(beta) + cos(2*beta))/2 + assert Rotation.d(2, -1, -2, beta).doit() == -((cos(beta) + 1)*sin(beta))/2 + assert Rotation.d(2, -2, 2, beta).doit() == (3 - 4*cos(beta) + cos(2*beta))/8 + assert Rotation.d(2, -2, 1, beta).doit() == (2*sin(beta) - sin(2*beta))/4 + assert Rotation.d(2, -2, 0, beta).doit() == sqrt(6)*sin(beta)**2/4 + assert Rotation.d(2, -2, -1, beta).doit() == (cos(beta) + 1)*sin(beta)/2 + assert Rotation.d(2, -2, -2, beta).doit() == (3 + 4*cos(beta) + cos(2*beta))/8 + # Numerical tests + # j = 1/2 + assert Rotation.d(S.Half, S.Half, S.Half, pi/2).doit() == sqrt(2)/2 + assert Rotation.d(S.Half, S.Half, Rational(-1, 2), pi/2).doit() == -sqrt(2)/2 + assert Rotation.d(S.Half, Rational(-1, 2), S.Half, pi/2).doit() == sqrt(2)/2 + assert Rotation.d(S.Half, Rational(-1, 2), Rational(-1, 2), pi/2).doit() == sqrt(2)/2 + # j = 1 + assert Rotation.d(1, 1, 1, pi/2).doit() == S.Half + assert Rotation.d(1, 1, 0, pi/2).doit() == -sqrt(2)/2 + assert Rotation.d(1, 1, -1, pi/2).doit() == S.Half + assert Rotation.d(1, 0, 1, pi/2).doit() == sqrt(2)/2 + assert Rotation.d(1, 0, 0, pi/2).doit() == 0 + assert Rotation.d(1, 0, -1, pi/2).doit() == -sqrt(2)/2 + assert Rotation.d(1, -1, 1, pi/2).doit() == S.Half + assert Rotation.d(1, -1, 0, pi/2).doit() == sqrt(2)/2 + assert Rotation.d(1, -1, -1, pi/2).doit() == S.Half + # j = 3/2 + assert Rotation.d(Rational(3, 2), Rational(3, 2), Rational(3, 2), pi/2).doit() == sqrt(2)/4 + assert Rotation.d(Rational(3, 2), Rational(3, 2), S.Half, pi/2).doit() == -sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(3, 2), Rational(-1, 2), pi/2).doit() == sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(3, 2), Rational(-3, 2), pi/2).doit() == -sqrt(2)/4 + assert Rotation.d(Rational(3, 2), S.Half, Rational(3, 2), pi/2).doit() == sqrt(6)/4 + assert Rotation.d(Rational(3, 2), S.Half, S.Half, pi/2).doit() == -sqrt(2)/4 + assert Rotation.d(Rational(3, 2), S.Half, Rational(-1, 2), pi/2).doit() == -sqrt(2)/4 + assert Rotation.d(Rational(3, 2), S.Half, Rational(-3, 2), pi/2).doit() == sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(-1, 2), Rational(3, 2), pi/2).doit() == sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(-1, 2), S.Half, pi/2).doit() == sqrt(2)/4 + assert Rotation.d(Rational(3, 2), Rational(-1, 2), Rational(-1, 2), pi/2).doit() == -sqrt(2)/4 + assert Rotation.d(Rational(3, 2), Rational(-1, 2), Rational(-3, 2), pi/2).doit() == -sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(-3, 2), Rational(3, 2), pi/2).doit() == sqrt(2)/4 + assert Rotation.d(Rational(3, 2), Rational(-3, 2), S.Half, pi/2).doit() == sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(-3, 2), Rational(-1, 2), pi/2).doit() == sqrt(6)/4 + assert Rotation.d(Rational(3, 2), Rational(-3, 2), Rational(-3, 2), pi/2).doit() == sqrt(2)/4 + # j = 2 + assert Rotation.d(2, 2, 2, pi/2).doit() == Rational(1, 4) + assert Rotation.d(2, 2, 1, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, 2, 0, pi/2).doit() == sqrt(6)/4 + assert Rotation.d(2, 2, -1, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, 2, -2, pi/2).doit() == Rational(1, 4) + assert Rotation.d(2, 1, 2, pi/2).doit() == S.Half + assert Rotation.d(2, 1, 1, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, 1, 0, pi/2).doit() == 0 + assert Rotation.d(2, 1, -1, pi/2).doit() == S.Half + assert Rotation.d(2, 1, -2, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, 0, 2, pi/2).doit() == sqrt(6)/4 + assert Rotation.d(2, 0, 1, pi/2).doit() == 0 + assert Rotation.d(2, 0, 0, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, 0, -1, pi/2).doit() == 0 + assert Rotation.d(2, 0, -2, pi/2).doit() == sqrt(6)/4 + assert Rotation.d(2, -1, 2, pi/2).doit() == S.Half + assert Rotation.d(2, -1, 1, pi/2).doit() == S.Half + assert Rotation.d(2, -1, 0, pi/2).doit() == 0 + assert Rotation.d(2, -1, -1, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, -1, -2, pi/2).doit() == Rational(-1, 2) + assert Rotation.d(2, -2, 2, pi/2).doit() == Rational(1, 4) + assert Rotation.d(2, -2, 1, pi/2).doit() == S.Half + assert Rotation.d(2, -2, 0, pi/2).doit() == sqrt(6)/4 + assert Rotation.d(2, -2, -1, pi/2).doit() == S.Half + assert Rotation.d(2, -2, -2, pi/2).doit() == Rational(1, 4) + + +def test_rotation_d(): + # Symbolic tests + # j = 1/2 + assert Rotation.D(S.Half, S.Half, S.Half, alpha, beta, gamma).doit() == \ + cos(beta/2)*exp(-I*alpha/2)*exp(-I*gamma/2) + assert Rotation.D(S.Half, S.Half, Rational(-1, 2), alpha, beta, gamma).doit() == \ + -sin(beta/2)*exp(-I*alpha/2)*exp(I*gamma/2) + assert Rotation.D(S.Half, Rational(-1, 2), S.Half, alpha, beta, gamma).doit() == \ + sin(beta/2)*exp(I*alpha/2)*exp(-I*gamma/2) + assert Rotation.D(S.Half, Rational(-1, 2), Rational(-1, 2), alpha, beta, gamma).doit() == \ + cos(beta/2)*exp(I*alpha/2)*exp(I*gamma/2) + # j = 1 + assert Rotation.D(1, 1, 1, alpha, beta, gamma).doit() == \ + (1 + cos(beta))/2*exp(-I*alpha)*exp(-I*gamma) + assert Rotation.D(1, 1, 0, alpha, beta, gamma).doit() == -sin( + beta)/sqrt(2)*exp(-I*alpha) + assert Rotation.D(1, 1, -1, alpha, beta, gamma).doit() == \ + (1 - cos(beta))/2*exp(-I*alpha)*exp(I*gamma) + assert Rotation.D(1, 0, 1, alpha, beta, gamma).doit() == \ + sin(beta)/sqrt(2)*exp(-I*gamma) + assert Rotation.D(1, 0, 0, alpha, beta, gamma).doit() == cos(beta) + assert Rotation.D(1, 0, -1, alpha, beta, gamma).doit() == \ + -sin(beta)/sqrt(2)*exp(I*gamma) + assert Rotation.D(1, -1, 1, alpha, beta, gamma).doit() == \ + (1 - cos(beta))/2*exp(I*alpha)*exp(-I*gamma) + assert Rotation.D(1, -1, 0, alpha, beta, gamma).doit() == \ + sin(beta)/sqrt(2)*exp(I*alpha) + assert Rotation.D(1, -1, -1, alpha, beta, gamma).doit() == \ + (1 + cos(beta))/2*exp(I*alpha)*exp(I*gamma) + # j = 3/2 + assert Rotation.D(Rational(3, 2), Rational(3, 2), Rational(3, 2), alpha, beta, gamma).doit() == \ + (3*cos(beta/2) + cos(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(-3, 2))*exp(I*gamma*Rational(-3, 2)) + assert Rotation.D(Rational(3, 2), Rational(3, 2), S.Half, alpha, beta, gamma).doit() == \ + -sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(-3, 2))*exp(-I*gamma/2) + assert Rotation.D(Rational(3, 2), Rational(3, 2), Rational(-1, 2), alpha, beta, gamma).doit() == \ + sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(-3, 2))*exp(I*gamma/2) + assert Rotation.D(Rational(3, 2), Rational(3, 2), Rational(-3, 2), alpha, beta, gamma).doit() == \ + (-3*sin(beta/2) + sin(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(-3, 2))*exp(I*gamma*Rational(3, 2)) + assert Rotation.D(Rational(3, 2), S.Half, Rational(3, 2), alpha, beta, gamma).doit() == \ + sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4*exp(-I*alpha/2)*exp(I*gamma*Rational(-3, 2)) + assert Rotation.D(Rational(3, 2), S.Half, S.Half, alpha, beta, gamma).doit() == \ + (cos(beta/2) + 3*cos(beta*Rational(3, 2)))/4*exp(-I*alpha/2)*exp(-I*gamma/2) + assert Rotation.D(Rational(3, 2), S.Half, Rational(-1, 2), alpha, beta, gamma).doit() == \ + (sin(beta/2) - 3*sin(beta*Rational(3, 2)))/4*exp(-I*alpha/2)*exp(I*gamma/2) + assert Rotation.D(Rational(3, 2), S.Half, Rational(-3, 2), alpha, beta, gamma).doit() == \ + sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4*exp(-I*alpha/2)*exp(I*gamma*Rational(3, 2)) + assert Rotation.D(Rational(3, 2), Rational(-1, 2), Rational(3, 2), alpha, beta, gamma).doit() == \ + sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4*exp(I*alpha/2)*exp(I*gamma*Rational(-3, 2)) + assert Rotation.D(Rational(3, 2), Rational(-1, 2), S.Half, alpha, beta, gamma).doit() == \ + (-sin(beta/2) + 3*sin(beta*Rational(3, 2)))/4*exp(I*alpha/2)*exp(-I*gamma/2) + assert Rotation.D(Rational(3, 2), Rational(-1, 2), Rational(-1, 2), alpha, beta, gamma).doit() == \ + (cos(beta/2) + 3*cos(beta*Rational(3, 2)))/4*exp(I*alpha/2)*exp(I*gamma/2) + assert Rotation.D(Rational(3, 2), Rational(-1, 2), Rational(-3, 2), alpha, beta, gamma).doit() == \ + -sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4*exp(I*alpha/2)*exp(I*gamma*Rational(3, 2)) + assert Rotation.D(Rational(3, 2), Rational(-3, 2), Rational(3, 2), alpha, beta, gamma).doit() == \ + (3*sin(beta/2) - sin(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(3, 2))*exp(I*gamma*Rational(-3, 2)) + assert Rotation.D(Rational(3, 2), Rational(-3, 2), S.Half, alpha, beta, gamma).doit() == \ + sqrt(3)*(cos(beta/2) - cos(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(3, 2))*exp(-I*gamma/2) + assert Rotation.D(Rational(3, 2), Rational(-3, 2), Rational(-1, 2), alpha, beta, gamma).doit() == \ + sqrt(3)*(sin(beta/2) + sin(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(3, 2))*exp(I*gamma/2) + assert Rotation.D(Rational(3, 2), Rational(-3, 2), Rational(-3, 2), alpha, beta, gamma).doit() == \ + (3*cos(beta/2) + cos(beta*Rational(3, 2)))/4*exp(I*alpha*Rational(3, 2))*exp(I*gamma*Rational(3, 2)) + # j = 2 + assert Rotation.D(2, 2, 2, alpha, beta, gamma).doit() == \ + (3 + 4*cos(beta) + cos(2*beta))/8*exp(-2*I*alpha)*exp(-2*I*gamma) + assert Rotation.D(2, 2, 1, alpha, beta, gamma).doit() == \ + -((cos(beta) + 1)*exp(-2*I*alpha)*exp(-I*gamma)*sin(beta))/2 + assert Rotation.D(2, 2, 0, alpha, beta, gamma).doit() == \ + sqrt(6)*sin(beta)**2/4*exp(-2*I*alpha) + assert Rotation.D(2, 2, -1, alpha, beta, gamma).doit() == \ + (cos(beta) - 1)*sin(beta)/2*exp(-2*I*alpha)*exp(I*gamma) + assert Rotation.D(2, 2, -2, alpha, beta, gamma).doit() == \ + (3 - 4*cos(beta) + cos(2*beta))/8*exp(-2*I*alpha)*exp(2*I*gamma) + assert Rotation.D(2, 1, 2, alpha, beta, gamma).doit() == \ + (cos(beta) + 1)*sin(beta)/2*exp(-I*alpha)*exp(-2*I*gamma) + assert Rotation.D(2, 1, 1, alpha, beta, gamma).doit() == \ + (cos(beta) + cos(2*beta))/2*exp(-I*alpha)*exp(-I*gamma) + assert Rotation.D(2, 1, 0, alpha, beta, gamma).doit() == -sqrt(6)* \ + sin(2*beta)/4*exp(-I*alpha) + assert Rotation.D(2, 1, -1, alpha, beta, gamma).doit() == \ + (cos(beta) - cos(2*beta))/2*exp(-I*alpha)*exp(I*gamma) + assert Rotation.D(2, 1, -2, alpha, beta, gamma).doit() == \ + (cos(beta) - 1)*sin(beta)/2*exp(-I*alpha)*exp(2*I*gamma) + assert Rotation.D(2, 0, 2, alpha, beta, gamma).doit() == \ + sqrt(6)*sin(beta)**2/4*exp(-2*I*gamma) + assert Rotation.D(2, 0, 1, alpha, beta, gamma).doit() == sqrt(6)* \ + sin(2*beta)/4*exp(-I*gamma) + assert Rotation.D( + 2, 0, 0, alpha, beta, gamma).doit() == (1 + 3*cos(2*beta))/4 + assert Rotation.D(2, 0, -1, alpha, beta, gamma).doit() == -sqrt(6)* \ + sin(2*beta)/4*exp(I*gamma) + assert Rotation.D(2, 0, -2, alpha, beta, gamma).doit() == \ + sqrt(6)*sin(beta)**2/4*exp(2*I*gamma) + assert Rotation.D(2, -1, 2, alpha, beta, gamma).doit() == \ + (2*sin(beta) - sin(2*beta))/4*exp(I*alpha)*exp(-2*I*gamma) + assert Rotation.D(2, -1, 1, alpha, beta, gamma).doit() == \ + (cos(beta) - cos(2*beta))/2*exp(I*alpha)*exp(-I*gamma) + assert Rotation.D(2, -1, 0, alpha, beta, gamma).doit() == sqrt(6)* \ + sin(2*beta)/4*exp(I*alpha) + assert Rotation.D(2, -1, -1, alpha, beta, gamma).doit() == \ + (cos(beta) + cos(2*beta))/2*exp(I*alpha)*exp(I*gamma) + assert Rotation.D(2, -1, -2, alpha, beta, gamma).doit() == \ + -((cos(beta) + 1)*sin(beta))/2*exp(I*alpha)*exp(2*I*gamma) + assert Rotation.D(2, -2, 2, alpha, beta, gamma).doit() == \ + (3 - 4*cos(beta) + cos(2*beta))/8*exp(2*I*alpha)*exp(-2*I*gamma) + assert Rotation.D(2, -2, 1, alpha, beta, gamma).doit() == \ + (2*sin(beta) - sin(2*beta))/4*exp(2*I*alpha)*exp(-I*gamma) + assert Rotation.D(2, -2, 0, alpha, beta, gamma).doit() == \ + sqrt(6)*sin(beta)**2/4*exp(2*I*alpha) + assert Rotation.D(2, -2, -1, alpha, beta, gamma).doit() == \ + (cos(beta) + 1)*sin(beta)/2*exp(2*I*alpha)*exp(I*gamma) + assert Rotation.D(2, -2, -2, alpha, beta, gamma).doit() == \ + (3 + 4*cos(beta) + cos(2*beta))/8*exp(2*I*alpha)*exp(2*I*gamma) + # Numerical tests + # j = 1/2 + assert Rotation.D( + S.Half, S.Half, S.Half, pi/2, pi/2, pi/2).doit() == -I*sqrt(2)/2 + assert Rotation.D( + S.Half, S.Half, Rational(-1, 2), pi/2, pi/2, pi/2).doit() == -sqrt(2)/2 + assert Rotation.D( + S.Half, Rational(-1, 2), S.Half, pi/2, pi/2, pi/2).doit() == sqrt(2)/2 + assert Rotation.D( + S.Half, Rational(-1, 2), Rational(-1, 2), pi/2, pi/2, pi/2).doit() == I*sqrt(2)/2 + # j = 1 + assert Rotation.D(1, 1, 1, pi/2, pi/2, pi/2).doit() == Rational(-1, 2) + assert Rotation.D(1, 1, 0, pi/2, pi/2, pi/2).doit() == I*sqrt(2)/2 + assert Rotation.D(1, 1, -1, pi/2, pi/2, pi/2).doit() == S.Half + assert Rotation.D(1, 0, 1, pi/2, pi/2, pi/2).doit() == -I*sqrt(2)/2 + assert Rotation.D(1, 0, 0, pi/2, pi/2, pi/2).doit() == 0 + assert Rotation.D(1, 0, -1, pi/2, pi/2, pi/2).doit() == -I*sqrt(2)/2 + assert Rotation.D(1, -1, 1, pi/2, pi/2, pi/2).doit() == S.Half + assert Rotation.D(1, -1, 0, pi/2, pi/2, pi/2).doit() == I*sqrt(2)/2 + assert Rotation.D(1, -1, -1, pi/2, pi/2, pi/2).doit() == Rational(-1, 2) + # j = 3/2 + assert Rotation.D( + Rational(3, 2), Rational(3, 2), Rational(3, 2), pi/2, pi/2, pi/2).doit() == I*sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), Rational(3, 2), S.Half, pi/2, pi/2, pi/2).doit() == sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(3, 2), Rational(-1, 2), pi/2, pi/2, pi/2).doit() == -I*sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(3, 2), Rational(-3, 2), pi/2, pi/2, pi/2).doit() == -sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), S.Half, Rational(3, 2), pi/2, pi/2, pi/2).doit() == -sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), S.Half, S.Half, pi/2, pi/2, pi/2).doit() == I*sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), S.Half, Rational(-1, 2), pi/2, pi/2, pi/2).doit() == -sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), S.Half, Rational(-3, 2), pi/2, pi/2, pi/2).doit() == I*sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(-1, 2), Rational(3, 2), pi/2, pi/2, pi/2).doit() == -I*sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(-1, 2), S.Half, pi/2, pi/2, pi/2).doit() == sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), Rational(-1, 2), Rational(-1, 2), pi/2, pi/2, pi/2).doit() == -I*sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), Rational(-1, 2), Rational(-3, 2), pi/2, pi/2, pi/2).doit() == sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(-3, 2), Rational(3, 2), pi/2, pi/2, pi/2).doit() == sqrt(2)/4 + assert Rotation.D( + Rational(3, 2), Rational(-3, 2), S.Half, pi/2, pi/2, pi/2).doit() == I*sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(-3, 2), Rational(-1, 2), pi/2, pi/2, pi/2).doit() == -sqrt(6)/4 + assert Rotation.D( + Rational(3, 2), Rational(-3, 2), Rational(-3, 2), pi/2, pi/2, pi/2).doit() == -I*sqrt(2)/4 + # j = 2 + assert Rotation.D(2, 2, 2, pi/2, pi/2, pi/2).doit() == Rational(1, 4) + assert Rotation.D(2, 2, 1, pi/2, pi/2, pi/2).doit() == -I/2 + assert Rotation.D(2, 2, 0, pi/2, pi/2, pi/2).doit() == -sqrt(6)/4 + assert Rotation.D(2, 2, -1, pi/2, pi/2, pi/2).doit() == I/2 + assert Rotation.D(2, 2, -2, pi/2, pi/2, pi/2).doit() == Rational(1, 4) + assert Rotation.D(2, 1, 2, pi/2, pi/2, pi/2).doit() == I/2 + assert Rotation.D(2, 1, 1, pi/2, pi/2, pi/2).doit() == S.Half + assert Rotation.D(2, 1, 0, pi/2, pi/2, pi/2).doit() == 0 + assert Rotation.D(2, 1, -1, pi/2, pi/2, pi/2).doit() == S.Half + assert Rotation.D(2, 1, -2, pi/2, pi/2, pi/2).doit() == -I/2 + assert Rotation.D(2, 0, 2, pi/2, pi/2, pi/2).doit() == -sqrt(6)/4 + assert Rotation.D(2, 0, 1, pi/2, pi/2, pi/2).doit() == 0 + assert Rotation.D(2, 0, 0, pi/2, pi/2, pi/2).doit() == Rational(-1, 2) + assert Rotation.D(2, 0, -1, pi/2, pi/2, pi/2).doit() == 0 + assert Rotation.D(2, 0, -2, pi/2, pi/2, pi/2).doit() == -sqrt(6)/4 + assert Rotation.D(2, -1, 2, pi/2, pi/2, pi/2).doit() == -I/2 + assert Rotation.D(2, -1, 1, pi/2, pi/2, pi/2).doit() == S.Half + assert Rotation.D(2, -1, 0, pi/2, pi/2, pi/2).doit() == 0 + assert Rotation.D(2, -1, -1, pi/2, pi/2, pi/2).doit() == S.Half + assert Rotation.D(2, -1, -2, pi/2, pi/2, pi/2).doit() == I/2 + assert Rotation.D(2, -2, 2, pi/2, pi/2, pi/2).doit() == Rational(1, 4) + assert Rotation.D(2, -2, 1, pi/2, pi/2, pi/2).doit() == I/2 + assert Rotation.D(2, -2, 0, pi/2, pi/2, pi/2).doit() == -sqrt(6)/4 + assert Rotation.D(2, -2, -1, pi/2, pi/2, pi/2).doit() == -I/2 + assert Rotation.D(2, -2, -2, pi/2, pi/2, pi/2).doit() == Rational(1, 4) + + +def test_wignerd(): + assert Rotation.D( + j, m, mp, alpha, beta, gamma) == WignerD(j, m, mp, alpha, beta, gamma) + assert Rotation.d(j, m, mp, beta) == WignerD(j, m, mp, 0, beta, 0) + +def test_wignerD(): + i,j=symbols('i j') + assert Rotation.D(1, 1, 1, 0, 0, 0) == WignerD(1, 1, 1, 0, 0, 0) + assert Rotation.D(1, 1, 2, 0, 0, 0) == WignerD(1, 1, 2, 0, 0, 0) + assert Rotation.D(1, i**2 - j**2, i**2 - j**2, 0, 0, 0) == WignerD(1, i**2 - j**2, i**2 - j**2, 0, 0, 0) + assert Rotation.D(1, i, i, 0, 0, 0) == WignerD(1, i, i, 0, 0, 0) + assert Rotation.D(1, i, i+1, 0, 0, 0) == WignerD(1, i, i+1, 0, 0, 0) + assert Rotation.D(1, 0, 0, 0, 0, 0) == WignerD(1, 0, 0, 0, 0, 0) + +def test_jplus(): + assert Commutator(Jplus, Jminus).doit() == 2*hbar*Jz + assert Jplus.matrix_element(1, 1, 1, 1) == 0 + assert Jplus.rewrite('xyz') == Jx + I*Jy + # Normal operators, normal states + # Numerical + assert qapply(Jplus*JxKet(1, 1)) == \ + -hbar*sqrt(2)*JxKet(1, 0)/2 + hbar*JxKet(1, 1) + assert qapply(Jplus*JyKet(1, 1)) == \ + hbar*sqrt(2)*JyKet(1, 0)/2 + I*hbar*JyKet(1, 1) + assert qapply(Jplus*JzKet(1, 1)) == 0 + # Symbolic + assert qapply(Jplus*JxKet(j, m)) == \ + Sum(hbar * sqrt(-mi**2 - mi + j**2 + j) * WignerD(j, mi, m, 0, pi/2, 0) * + Sum(WignerD(j, mi1, mi + 1, 0, pi*Rational(3, 2), 0) * JxKet(j, mi1), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jplus*JyKet(j, m)) == \ + Sum(hbar * sqrt(j**2 + j - mi**2 - mi) * WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * + Sum(WignerD(j, mi1, mi + 1, pi*Rational(3, 2), pi/2, pi/2) * JyKet(j, mi1), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jplus*JzKet(j, m)) == \ + hbar*sqrt(j**2 + j - m**2 - m)*JzKet(j, m + 1) + # Normal operators, coupled states + # Numerical + assert qapply(Jplus*JxKetCoupled(1, 1, (1, 1))) == -hbar*sqrt(2) * \ + JxKetCoupled(1, 0, (1, 1))/2 + hbar*JxKetCoupled(1, 1, (1, 1)) + assert qapply(Jplus*JyKetCoupled(1, 1, (1, 1))) == hbar*sqrt(2) * \ + JyKetCoupled(1, 0, (1, 1))/2 + I*hbar*JyKetCoupled(1, 1, (1, 1)) + assert qapply(Jplus*JzKet(1, 1)) == 0 + # Symbolic + assert qapply(Jplus*JxKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar * sqrt(-mi**2 - mi + j**2 + j) * WignerD(j, mi, m, 0, pi/2, 0) * + Sum( + WignerD( + j, mi1, mi + 1, 0, pi*Rational(3, 2), 0) * JxKetCoupled(j, mi1, (j1, j2)), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jplus*JyKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar * sqrt(j**2 + j - mi**2 - mi) * WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * + Sum( + WignerD(j, mi1, mi + 1, pi*Rational(3, 2), pi/2, pi/2) * + JyKetCoupled(j, mi1, (j1, j2)), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jplus*JzKetCoupled(j, m, (j1, j2))) == \ + hbar*sqrt(j**2 + j - m**2 - m)*JzKetCoupled(j, m + 1, (j1, j2)) + # Uncoupled operators, uncoupled states + # Numerical + e1 = qapply(TensorProduct(Jplus, 1)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) + e2 = -hbar*sqrt(2)*TensorProduct(JxKet(1, 0), JxKet(1, -1))/2 + \ + hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, Jplus)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) + e2 = -hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + \ + hbar*sqrt(2)*TensorProduct(JxKet(1, 1), JxKet(1, 0))/2 + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(Jplus, 1)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) + e2 = hbar*sqrt(2)*TensorProduct(JyKet(1, 0), JyKet(1, -1))/2 + \ + hbar*I*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, Jplus)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) + e2 = -hbar*I*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + \ + hbar*sqrt(2)*TensorProduct(JyKet(1, 1), JyKet(1, 0))/2 + assert_simplify_expand(e1, e2) + assert qapply( + TensorProduct(Jplus, 1)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == 0 + assert qapply(TensorProduct(1, Jplus)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + hbar*sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, 0)) + # Symbolic + assert qapply(TensorProduct(Jplus, 1)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(Sum(hbar * sqrt(-mi**2 - mi + j1**2 + j1) * WignerD(j1, mi, m1, 0, pi/2, 0) * + Sum(WignerD(j1, mi1, mi + 1, 0, pi*Rational(3, 2), 0) * JxKet(j1, mi1), + (mi1, -j1, j1)), (mi, -j1, j1)), JxKet(j2, m2)) + assert qapply(TensorProduct(1, Jplus)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(JxKet(j1, m1), Sum(hbar * sqrt(-mi**2 - mi + j2**2 + j2) * WignerD(j2, mi, m2, 0, pi/2, 0) * + Sum(WignerD(j2, mi1, mi + 1, 0, pi*Rational(3, 2), 0) * JxKet(j2, mi1), + (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jplus, 1)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(Sum(hbar * sqrt(j1**2 + j1 - mi**2 - mi) * WignerD(j1, mi, m1, pi*Rational(3, 2), -pi/2, pi/2) * + Sum(WignerD(j1, mi1, mi + 1, pi*Rational(3, 2), pi/2, pi/2) * JyKet(j1, mi1), + (mi1, -j1, j1)), (mi, -j1, j1)), JyKet(j2, m2)) + assert qapply(TensorProduct(1, Jplus)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(JyKet(j1, m1), Sum(hbar * sqrt(j2**2 + j2 - mi**2 - mi) * WignerD(j2, mi, m2, pi*Rational(3, 2), -pi/2, pi/2) * + Sum(WignerD(j2, mi1, mi + 1, pi*Rational(3, 2), pi/2, pi/2) * JyKet(j2, mi1), + (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jplus, 1)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*sqrt( + j1**2 + j1 - m1**2 - m1)*TensorProduct(JzKet(j1, m1 + 1), JzKet(j2, m2)) + assert qapply(TensorProduct(1, Jplus)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*sqrt( + j2**2 + j2 - m2**2 - m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 + 1)) + + +def test_jminus(): + assert qapply(Jminus*JzKet(1, -1)) == 0 + assert Jminus.matrix_element(1, 0, 1, 1) == sqrt(2)*hbar + assert Jminus.rewrite('xyz') == Jx - I*Jy + # Normal operators, normal states + # Numerical + assert qapply(Jminus*JxKet(1, 1)) == \ + hbar*sqrt(2)*JxKet(1, 0)/2 + hbar*JxKet(1, 1) + assert qapply(Jminus*JyKet(1, 1)) == \ + hbar*sqrt(2)*JyKet(1, 0)/2 - hbar*I*JyKet(1, 1) + assert qapply(Jminus*JzKet(1, 1)) == sqrt(2)*hbar*JzKet(1, 0) + # Symbolic + assert qapply(Jminus*JxKet(j, m)) == \ + Sum(hbar*sqrt(j**2 + j - mi**2 + mi)*WignerD(j, mi, m, 0, pi/2, 0) * + Sum(WignerD(j, mi1, mi - 1, 0, pi*Rational(3, 2), 0)*JxKet(j, mi1), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jminus*JyKet(j, m)) == \ + Sum(hbar*sqrt(j**2 + j - mi**2 + mi)*WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * + Sum(WignerD(j, mi1, mi - 1, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j, mi1), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jminus*JzKet(j, m)) == \ + hbar*sqrt(j**2 + j - m**2 + m)*JzKet(j, m - 1) + # Normal operators, coupled states + # Numerical + assert qapply(Jminus*JxKetCoupled(1, 1, (1, 1))) == \ + hbar*sqrt(2)*JxKetCoupled(1, 0, (1, 1))/2 + \ + hbar*JxKetCoupled(1, 1, (1, 1)) + assert qapply(Jminus*JyKetCoupled(1, 1, (1, 1))) == \ + hbar*sqrt(2)*JyKetCoupled(1, 0, (1, 1))/2 - \ + hbar*I*JyKetCoupled(1, 1, (1, 1)) + assert qapply(Jminus*JzKetCoupled(1, 1, (1, 1))) == \ + sqrt(2)*hbar*JzKetCoupled(1, 0, (1, 1)) + # Symbolic + assert qapply(Jminus*JxKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar*sqrt(j**2 + j - mi**2 + mi)*WignerD(j, mi, m, 0, pi/2, 0) * + Sum(WignerD(j, mi1, mi - 1, 0, pi*Rational(3, 2), 0)*JxKetCoupled(j, mi1, (j1, j2)), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jminus*JyKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar*sqrt(j**2 + j - mi**2 + mi)*WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2) * + Sum( + WignerD(j, mi1, mi - 1, pi*Rational(3, 2), pi/2, pi/2)* + JyKetCoupled(j, mi1, (j1, j2)), + (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jminus*JzKetCoupled(j, m, (j1, j2))) == \ + hbar*sqrt(j**2 + j - m**2 + m)*JzKetCoupled(j, m - 1, (j1, j2)) + # Uncoupled operators, uncoupled states + # Numerical + e1 = qapply(TensorProduct(Jminus, 1)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) + e2 = hbar*sqrt(2)*TensorProduct(JxKet(1, 0), JxKet(1, -1))/2 + \ + hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, Jminus)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) + e2 = -hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) - \ + hbar*sqrt(2)*TensorProduct(JxKet(1, 1), JxKet(1, 0))/2 + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(Jminus, 1)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) + e2 = hbar*sqrt(2)*TensorProduct(JyKet(1, 0), JyKet(1, -1))/2 - \ + hbar*I*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, Jminus)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) + e2 = hbar*I*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + \ + hbar*sqrt(2)*TensorProduct(JyKet(1, 1), JyKet(1, 0))/2 + assert_simplify_expand(e1, e2) + assert qapply(TensorProduct(Jminus, 1)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + sqrt(2)*hbar*TensorProduct(JzKet(1, 0), JzKet(1, -1)) + assert qapply(TensorProduct( + 1, Jminus)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == 0 + # Symbolic + assert qapply(TensorProduct(Jminus, 1)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(Sum(hbar*sqrt(j1**2 + j1 - mi**2 + mi)*WignerD(j1, mi, m1, 0, pi/2, 0) * + Sum(WignerD(j1, mi1, mi - 1, 0, pi*Rational(3, 2), 0)*JxKet(j1, mi1), + (mi1, -j1, j1)), (mi, -j1, j1)), JxKet(j2, m2)) + assert qapply(TensorProduct(1, Jminus)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(JxKet(j1, m1), Sum(hbar*sqrt(j2**2 + j2 - mi**2 + mi)*WignerD(j2, mi, m2, 0, pi/2, 0) * + Sum(WignerD(j2, mi1, mi - 1, 0, pi*Rational(3, 2), 0)*JxKet(j2, mi1), + (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jminus, 1)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(Sum(hbar*sqrt(j1**2 + j1 - mi**2 + mi)*WignerD(j1, mi, m1, pi*Rational(3, 2), -pi/2, pi/2) * + Sum(WignerD(j1, mi1, mi - 1, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j1, mi1), + (mi1, -j1, j1)), (mi, -j1, j1)), JyKet(j2, m2)) + assert qapply(TensorProduct(1, Jminus)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(JyKet(j1, m1), Sum(hbar*sqrt(j2**2 + j2 - mi**2 + mi)*WignerD(j2, mi, m2, pi*Rational(3, 2), -pi/2, pi/2) * + Sum(WignerD(j2, mi1, mi - 1, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j2, mi1), + (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jminus, 1)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*sqrt( + j1**2 + j1 - m1**2 + m1)*TensorProduct(JzKet(j1, m1 - 1), JzKet(j2, m2)) + assert qapply(TensorProduct(1, Jminus)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*sqrt( + j2**2 + j2 - m2**2 + m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 - 1)) + + +def test_j2(): + assert Commutator(J2, Jz).doit() == 0 + assert J2.matrix_element(1, 1, 1, 1) == 2*hbar**2 + # Normal operators, normal states + # Numerical + assert qapply(J2*JxKet(1, 1)) == 2*hbar**2*JxKet(1, 1) + assert qapply(J2*JyKet(1, 1)) == 2*hbar**2*JyKet(1, 1) + assert qapply(J2*JzKet(1, 1)) == 2*hbar**2*JzKet(1, 1) + # Symbolic + assert qapply(J2*JxKet(j, m)) == \ + hbar**2*j**2*JxKet(j, m) + hbar**2*j*JxKet(j, m) + assert qapply(J2*JyKet(j, m)) == \ + hbar**2*j**2*JyKet(j, m) + hbar**2*j*JyKet(j, m) + assert qapply(J2*JzKet(j, m)) == \ + hbar**2*j**2*JzKet(j, m) + hbar**2*j*JzKet(j, m) + # Normal operators, coupled states + # Numerical + assert qapply(J2*JxKetCoupled(1, 1, (1, 1))) == \ + 2*hbar**2*JxKetCoupled(1, 1, (1, 1)) + assert qapply(J2*JyKetCoupled(1, 1, (1, 1))) == \ + 2*hbar**2*JyKetCoupled(1, 1, (1, 1)) + assert qapply(J2*JzKetCoupled(1, 1, (1, 1))) == \ + 2*hbar**2*JzKetCoupled(1, 1, (1, 1)) + # Symbolic + assert qapply(J2*JxKetCoupled(j, m, (j1, j2))) == \ + hbar**2*j**2*JxKetCoupled(j, m, (j1, j2)) + \ + hbar**2*j*JxKetCoupled(j, m, (j1, j2)) + assert qapply(J2*JyKetCoupled(j, m, (j1, j2))) == \ + hbar**2*j**2*JyKetCoupled(j, m, (j1, j2)) + \ + hbar**2*j*JyKetCoupled(j, m, (j1, j2)) + assert qapply(J2*JzKetCoupled(j, m, (j1, j2))) == \ + hbar**2*j**2*JzKetCoupled(j, m, (j1, j2)) + \ + hbar**2*j*JzKetCoupled(j, m, (j1, j2)) + # Uncoupled operators, uncoupled states + # Numerical + assert qapply(TensorProduct(J2, 1)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + 2*hbar**2*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert qapply(TensorProduct(1, J2)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + 2*hbar**2*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert qapply(TensorProduct(J2, 1)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + 2*hbar**2*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert qapply(TensorProduct(1, J2)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + 2*hbar**2*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert qapply(TensorProduct(J2, 1)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + 2*hbar**2*TensorProduct(JzKet(1, 1), JzKet(1, -1)) + assert qapply(TensorProduct(1, J2)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + 2*hbar**2*TensorProduct(JzKet(1, 1), JzKet(1, -1)) + # Symbolic + e1 = qapply(TensorProduct(J2, 1)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) + e2 = hbar**2*j1**2*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + \ + hbar**2*j1*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, J2)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) + e2 = hbar**2*j2**2*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + \ + hbar**2*j2*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(J2, 1)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) + e2 = hbar**2*j1**2*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + \ + hbar**2*j1*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, J2)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) + e2 = hbar**2*j2**2*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + \ + hbar**2*j2*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(J2, 1)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) + e2 = hbar**2*j1**2*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + \ + hbar**2*j1*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, J2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) + e2 = hbar**2*j2**2*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + \ + hbar**2*j2*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + assert_simplify_expand(e1, e2) + + +def test_jx(): + assert Commutator(Jx, Jz).doit() == -I*hbar*Jy + assert Jx.rewrite('plusminus') == (Jminus + Jplus)/2 + assert represent(Jx, basis=Jz, j=1) == ( + represent(Jplus, basis=Jz, j=1) + represent(Jminus, basis=Jz, j=1))/2 + # Normal operators, normal states + # Numerical + assert qapply(Jx*JxKet(1, 1)) == hbar*JxKet(1, 1) + assert qapply(Jx*JyKet(1, 1)) == hbar*JyKet(1, 1) + assert qapply(Jx*JzKet(1, 1)) == sqrt(2)*hbar*JzKet(1, 0)/2 + # Symbolic + assert qapply(Jx*JxKet(j, m)) == hbar*m*JxKet(j, m) + assert qapply(Jx*JyKet(j, m)) == \ + Sum(hbar*mi*WignerD(j, mi, m, 0, 0, pi/2)*Sum(WignerD(j, + mi1, mi, pi*Rational(3, 2), 0, 0)*JyKet(j, mi1), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jx*JzKet(j, m)) == \ + hbar*sqrt(j**2 + j - m**2 - m)*JzKet(j, m + 1)/2 + hbar*sqrt(j**2 + + j - m**2 + m)*JzKet(j, m - 1)/2 + # Normal operators, coupled states + # Numerical + assert qapply(Jx*JxKetCoupled(1, 1, (1, 1))) == \ + hbar*JxKetCoupled(1, 1, (1, 1)) + assert qapply(Jx*JyKetCoupled(1, 1, (1, 1))) == \ + hbar*JyKetCoupled(1, 1, (1, 1)) + assert qapply(Jx*JzKetCoupled(1, 1, (1, 1))) == \ + sqrt(2)*hbar*JzKetCoupled(1, 0, (1, 1))/2 + # Symbolic + assert qapply(Jx*JxKetCoupled(j, m, (j1, j2))) == \ + hbar*m*JxKetCoupled(j, m, (j1, j2)) + assert qapply(Jx*JyKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar*mi*WignerD(j, mi, m, 0, 0, pi/2)*Sum(WignerD(j, mi1, mi, pi*Rational(3, 2), 0, 0)*JyKetCoupled(j, mi1, (j1, j2)), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jx*JzKetCoupled(j, m, (j1, j2))) == \ + hbar*sqrt(j**2 + j - m**2 - m)*JzKetCoupled(j, m + 1, (j1, j2))/2 + \ + hbar*sqrt(j**2 + j - m**2 + m)*JzKetCoupled(j, m - 1, (j1, j2))/2 + # Normal operators, uncoupled states + # Numerical + assert qapply(Jx*TensorProduct(JxKet(1, 1), JxKet(1, 1))) == \ + 2*hbar*TensorProduct(JxKet(1, 1), JxKet(1, 1)) + assert qapply(Jx*TensorProduct(JyKet(1, 1), JyKet(1, 1))) == \ + hbar*TensorProduct(JyKet(1, 1), JyKet(1, 1)) + \ + hbar*TensorProduct(JyKet(1, 1), JyKet(1, 1)) + assert qapply(Jx*TensorProduct(JzKet(1, 1), JzKet(1, 1))) == \ + sqrt(2)*hbar*TensorProduct(JzKet(1, 1), JzKet(1, 0))/2 + \ + sqrt(2)*hbar*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + assert qapply(Jx*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == 0 + # Symbolic + assert qapply(Jx*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + hbar*m1*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + \ + hbar*m2*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + assert qapply(Jx*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, 0, 0, pi/2)*Sum(WignerD(j1, mi1, mi, pi*Rational(3, 2), 0, 0)*JyKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JyKet(j2, m2)) + \ + TensorProduct(JyKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, 0, 0, pi/2)*Sum(WignerD(j2, mi1, mi, pi*Rational(3, 2), 0, 0)*JyKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(Jx*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*sqrt(j1**2 + j1 - m1**2 - m1)*TensorProduct(JzKet(j1, m1 + 1), JzKet(j2, m2))/2 + \ + hbar*sqrt(j1**2 + j1 - m1**2 + m1)*TensorProduct(JzKet(j1, m1 - 1), JzKet(j2, m2))/2 + \ + hbar*sqrt(j2**2 + j2 - m2**2 - m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 + 1))/2 + \ + hbar*sqrt( + j2**2 + j2 - m2**2 + m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 - 1))/2 + # Uncoupled operators, uncoupled states + # Numerical + assert qapply(TensorProduct(Jx, 1)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert qapply(TensorProduct(1, Jx)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + -hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert qapply(TensorProduct(Jx, 1)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + hbar*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert qapply(TensorProduct(1, Jx)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + -hbar*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert qapply(TensorProduct(Jx, 1)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + hbar*sqrt(2)*TensorProduct(JzKet(1, 0), JzKet(1, -1))/2 + assert qapply(TensorProduct(1, Jx)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + hbar*sqrt(2)*TensorProduct(JzKet(1, 1), JzKet(1, 0))/2 + # Symbolic + assert qapply(TensorProduct(Jx, 1)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + hbar*m1*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + assert qapply(TensorProduct(1, Jx)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + hbar*m2*TensorProduct(JxKet(j1, m1), JxKet(j2, m2)) + assert qapply(TensorProduct(Jx, 1)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, 0, 0, pi/2) * Sum(WignerD(j1, mi1, mi, pi*Rational(3, 2), 0, 0)*JyKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JyKet(j2, m2)) + assert qapply(TensorProduct(1, Jx)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(JyKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, 0, 0, pi/2) * Sum(WignerD(j2, mi1, mi, pi*Rational(3, 2), 0, 0)*JyKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + e1 = qapply(TensorProduct(Jx, 1)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) + e2 = hbar*sqrt(j1**2 + j1 - m1**2 - m1)*TensorProduct(JzKet(j1, m1 + 1), JzKet(j2, m2))/2 + \ + hbar*sqrt( + j1**2 + j1 - m1**2 + m1)*TensorProduct(JzKet(j1, m1 - 1), JzKet(j2, m2))/2 + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, Jx)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) + e2 = hbar*sqrt(j2**2 + j2 - m2**2 - m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 + 1))/2 + \ + hbar*sqrt( + j2**2 + j2 - m2**2 + m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 - 1))/2 + assert_simplify_expand(e1, e2) + + +def test_jy(): + assert Commutator(Jy, Jz).doit() == I*hbar*Jx + assert Jy.rewrite('plusminus') == (Jplus - Jminus)/(2*I) + assert represent(Jy, basis=Jz) == ( + represent(Jplus, basis=Jz) - represent(Jminus, basis=Jz))/(2*I) + # Normal operators, normal states + # Numerical + assert qapply(Jy*JxKet(1, 1)) == hbar*JxKet(1, 1) + assert qapply(Jy*JyKet(1, 1)) == hbar*JyKet(1, 1) + assert qapply(Jy*JzKet(1, 1)) == sqrt(2)*hbar*I*JzKet(1, 0)/2 + # Symbolic + assert qapply(Jy*JxKet(j, m)) == \ + Sum(hbar*mi*WignerD(j, mi, m, pi*Rational(3, 2), 0, 0)*Sum(WignerD( + j, mi1, mi, 0, 0, pi/2)*JxKet(j, mi1), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jy*JyKet(j, m)) == hbar*m*JyKet(j, m) + assert qapply(Jy*JzKet(j, m)) == \ + -hbar*I*sqrt(j**2 + j - m**2 - m)*JzKet( + j, m + 1)/2 + hbar*I*sqrt(j**2 + j - m**2 + m)*JzKet(j, m - 1)/2 + # Normal operators, coupled states + # Numerical + assert qapply(Jy*JxKetCoupled(1, 1, (1, 1))) == \ + hbar*JxKetCoupled(1, 1, (1, 1)) + assert qapply(Jy*JyKetCoupled(1, 1, (1, 1))) == \ + hbar*JyKetCoupled(1, 1, (1, 1)) + assert qapply(Jy*JzKetCoupled(1, 1, (1, 1))) == \ + sqrt(2)*hbar*I*JzKetCoupled(1, 0, (1, 1))/2 + # Symbolic + assert qapply(Jy*JxKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar*mi*WignerD(j, mi, m, pi*Rational(3, 2), 0, 0)*Sum(WignerD(j, mi1, mi, 0, 0, pi/2)*JxKetCoupled(j, mi1, (j1, j2)), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jy*JyKetCoupled(j, m, (j1, j2))) == \ + hbar*m*JyKetCoupled(j, m, (j1, j2)) + assert qapply(Jy*JzKetCoupled(j, m, (j1, j2))) == \ + -hbar*I*sqrt(j**2 + j - m**2 - m)*JzKetCoupled(j, m + 1, (j1, j2))/2 + \ + hbar*I*sqrt(j**2 + j - m**2 + m)*JzKetCoupled(j, m - 1, (j1, j2))/2 + # Normal operators, uncoupled states + # Numerical + assert qapply(Jy*TensorProduct(JxKet(1, 1), JxKet(1, 1))) == \ + hbar*TensorProduct(JxKet(1, 1), JxKet(1, 1)) + \ + hbar*TensorProduct(JxKet(1, 1), JxKet(1, 1)) + assert qapply(Jy*TensorProduct(JyKet(1, 1), JyKet(1, 1))) == \ + 2*hbar*TensorProduct(JyKet(1, 1), JyKet(1, 1)) + assert qapply(Jy*TensorProduct(JzKet(1, 1), JzKet(1, 1))) == \ + sqrt(2)*hbar*I*TensorProduct(JzKet(1, 1), JzKet(1, 0))/2 + \ + sqrt(2)*hbar*I*TensorProduct(JzKet(1, 0), JzKet(1, 1))/2 + assert qapply(Jy*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == 0 + # Symbolic + assert qapply(Jy*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(JxKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, pi*Rational(3, 2), 0, 0)*Sum(WignerD(j2, mi1, mi, 0, 0, pi/2)*JxKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, pi*Rational(3, 2), 0, 0)*Sum(WignerD(j1, mi1, mi, 0, 0, pi/2)*JxKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JxKet(j2, m2)) + assert qapply(Jy*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + hbar*m1*TensorProduct(JyKet(j1, m1), JyKet( + j2, m2)) + hbar*m2*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + assert qapply(Jy*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + -hbar*I*sqrt(j1**2 + j1 - m1**2 - m1)*TensorProduct(JzKet(j1, m1 + 1), JzKet(j2, m2))/2 + \ + hbar*I*sqrt(j1**2 + j1 - m1**2 + m1)*TensorProduct(JzKet(j1, m1 - 1), JzKet(j2, m2))/2 + \ + -hbar*I*sqrt(j2**2 + j2 - m2**2 - m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 + 1))/2 + \ + hbar*I*sqrt( + j2**2 + j2 - m2**2 + m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 - 1))/2 + # Uncoupled operators, uncoupled states + # Numerical + assert qapply(TensorProduct(Jy, 1)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert qapply(TensorProduct(1, Jy)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + -hbar*TensorProduct(JxKet(1, 1), JxKet(1, -1)) + assert qapply(TensorProduct(Jy, 1)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + hbar*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert qapply(TensorProduct(1, Jy)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + -hbar*TensorProduct(JyKet(1, 1), JyKet(1, -1)) + assert qapply(TensorProduct(Jy, 1)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + hbar*sqrt(2)*I*TensorProduct(JzKet(1, 0), JzKet(1, -1))/2 + assert qapply(TensorProduct(1, Jy)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + -hbar*sqrt(2)*I*TensorProduct(JzKet(1, 1), JzKet(1, 0))/2 + # Symbolic + assert qapply(TensorProduct(Jy, 1)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, pi*Rational(3, 2), 0, 0) * Sum(WignerD(j1, mi1, mi, 0, 0, pi/2)*JxKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JxKet(j2, m2)) + assert qapply(TensorProduct(1, Jy)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(JxKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, pi*Rational(3, 2), 0, 0) * Sum(WignerD(j2, mi1, mi, 0, 0, pi/2)*JxKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jy, 1)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + hbar*m1*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + assert qapply(TensorProduct(1, Jy)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + hbar*m2*TensorProduct(JyKet(j1, m1), JyKet(j2, m2)) + e1 = qapply(TensorProduct(Jy, 1)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) + e2 = -hbar*I*sqrt(j1**2 + j1 - m1**2 - m1)*TensorProduct(JzKet(j1, m1 + 1), JzKet(j2, m2))/2 + \ + hbar*I*sqrt( + j1**2 + j1 - m1**2 + m1)*TensorProduct(JzKet(j1, m1 - 1), JzKet(j2, m2))/2 + assert_simplify_expand(e1, e2) + e1 = qapply(TensorProduct(1, Jy)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) + e2 = -hbar*I*sqrt(j2**2 + j2 - m2**2 - m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 + 1))/2 + \ + hbar*I*sqrt( + j2**2 + j2 - m2**2 + m2)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2 - 1))/2 + assert_simplify_expand(e1, e2) + + +def test_jz(): + assert Commutator(Jz, Jminus).doit() == -hbar*Jminus + # Normal operators, normal states + # Numerical + assert qapply(Jz*JxKet(1, 1)) == -sqrt(2)*hbar*JxKet(1, 0)/2 + assert qapply(Jz*JyKet(1, 1)) == -sqrt(2)*hbar*I*JyKet(1, 0)/2 + assert qapply(Jz*JzKet(2, 1)) == hbar*JzKet(2, 1) + # Symbolic + assert qapply(Jz*JxKet(j, m)) == \ + Sum(hbar*mi*WignerD(j, mi, m, 0, pi/2, 0)*Sum(WignerD(j, + mi1, mi, 0, pi*Rational(3, 2), 0)*JxKet(j, mi1), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jz*JyKet(j, m)) == \ + Sum(hbar*mi*WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2)*Sum(WignerD(j, mi1, + mi, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j, mi1), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jz*JzKet(j, m)) == hbar*m*JzKet(j, m) + # Normal operators, coupled states + # Numerical + assert qapply(Jz*JxKetCoupled(1, 1, (1, 1))) == \ + -sqrt(2)*hbar*JxKetCoupled(1, 0, (1, 1))/2 + assert qapply(Jz*JyKetCoupled(1, 1, (1, 1))) == \ + -sqrt(2)*hbar*I*JyKetCoupled(1, 0, (1, 1))/2 + assert qapply(Jz*JzKetCoupled(1, 1, (1, 1))) == \ + hbar*JzKetCoupled(1, 1, (1, 1)) + # Symbolic + assert qapply(Jz*JxKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar*mi*WignerD(j, mi, m, 0, pi/2, 0)*Sum(WignerD(j, mi1, mi, 0, pi*Rational(3, 2), 0)*JxKetCoupled(j, mi1, (j1, j2)), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jz*JyKetCoupled(j, m, (j1, j2))) == \ + Sum(hbar*mi*WignerD(j, mi, m, pi*Rational(3, 2), -pi/2, pi/2)*Sum(WignerD(j, mi1, mi, pi*Rational(3, 2), pi/2, pi/2)*JyKetCoupled(j, mi1, (j1, j2)), (mi1, -j, j)), (mi, -j, j)) + assert qapply(Jz*JzKetCoupled(j, m, (j1, j2))) == \ + hbar*m*JzKetCoupled(j, m, (j1, j2)) + # Normal operators, uncoupled states + # Numerical + assert qapply(Jz*TensorProduct(JxKet(1, 1), JxKet(1, 1))) == \ + -sqrt(2)*hbar*TensorProduct(JxKet(1, 1), JxKet(1, 0))/2 - \ + sqrt(2)*hbar*TensorProduct(JxKet(1, 0), JxKet(1, 1))/2 + assert qapply(Jz*TensorProduct(JyKet(1, 1), JyKet(1, 1))) == \ + -sqrt(2)*hbar*I*TensorProduct(JyKet(1, 1), JyKet(1, 0))/2 - \ + sqrt(2)*hbar*I*TensorProduct(JyKet(1, 0), JyKet(1, 1))/2 + assert qapply(Jz*TensorProduct(JzKet(1, 1), JzKet(1, 1))) == \ + 2*hbar*TensorProduct(JzKet(1, 1), JzKet(1, 1)) + assert qapply(Jz*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == 0 + # Symbolic + assert qapply(Jz*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(JxKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, 0, pi/2, 0)*Sum(WignerD(j2, mi1, mi, 0, pi*Rational(3, 2), 0)*JxKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, 0, pi/2, 0)*Sum(WignerD(j1, mi1, mi, 0, pi*Rational(3, 2), 0)*JxKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JxKet(j2, m2)) + assert qapply(Jz*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(JyKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, pi*Rational(3, 2), -pi/2, pi/2)*Sum(WignerD(j2, mi1, mi, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, pi*Rational(3, 2), -pi/2, pi/2)*Sum(WignerD(j1, mi1, mi, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JyKet(j2, m2)) + assert qapply(Jz*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*m1*TensorProduct(JzKet(j1, m1), JzKet( + j2, m2)) + hbar*m2*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + # Uncoupled Operators + # Numerical + assert qapply(TensorProduct(Jz, 1)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + -sqrt(2)*hbar*TensorProduct(JxKet(1, 0), JxKet(1, -1))/2 + assert qapply(TensorProduct(1, Jz)*TensorProduct(JxKet(1, 1), JxKet(1, -1))) == \ + -sqrt(2)*hbar*TensorProduct(JxKet(1, 1), JxKet(1, 0))/2 + assert qapply(TensorProduct(Jz, 1)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + -sqrt(2)*I*hbar*TensorProduct(JyKet(1, 0), JyKet(1, -1))/2 + assert qapply(TensorProduct(1, Jz)*TensorProduct(JyKet(1, 1), JyKet(1, -1))) == \ + sqrt(2)*I*hbar*TensorProduct(JyKet(1, 1), JyKet(1, 0))/2 + assert qapply(TensorProduct(Jz, 1)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + hbar*TensorProduct(JzKet(1, 1), JzKet(1, -1)) + assert qapply(TensorProduct(1, Jz)*TensorProduct(JzKet(1, 1), JzKet(1, -1))) == \ + -hbar*TensorProduct(JzKet(1, 1), JzKet(1, -1)) + # Symbolic + assert qapply(TensorProduct(Jz, 1)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, 0, pi/2, 0)*Sum(WignerD(j1, mi1, mi, 0, pi*Rational(3, 2), 0)*JxKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JxKet(j2, m2)) + assert qapply(TensorProduct(1, Jz)*TensorProduct(JxKet(j1, m1), JxKet(j2, m2))) == \ + TensorProduct(JxKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, 0, pi/2, 0)*Sum(WignerD(j2, mi1, mi, 0, pi*Rational(3, 2), 0)*JxKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jz, 1)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(Sum(hbar*mi*WignerD(j1, mi, m1, pi*Rational(3, 2), -pi/2, pi/2)*Sum(WignerD(j1, mi1, mi, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j1, mi1), (mi1, -j1, j1)), (mi, -j1, j1)), JyKet(j2, m2)) + assert qapply(TensorProduct(1, Jz)*TensorProduct(JyKet(j1, m1), JyKet(j2, m2))) == \ + TensorProduct(JyKet(j1, m1), Sum(hbar*mi*WignerD(j2, mi, m2, pi*Rational(3, 2), -pi/2, pi/2)*Sum(WignerD(j2, mi1, mi, pi*Rational(3, 2), pi/2, pi/2)*JyKet(j2, mi1), (mi1, -j2, j2)), (mi, -j2, j2))) + assert qapply(TensorProduct(Jz, 1)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*m1*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + assert qapply(TensorProduct(1, Jz)*TensorProduct(JzKet(j1, m1), JzKet(j2, m2))) == \ + hbar*m2*TensorProduct(JzKet(j1, m1), JzKet(j2, m2)) + + +def test_rotation(): + a, b, g = symbols('a b g') + j, m = symbols('j m') + #Uncoupled + answ = [JxKet(1,-1)/2 - sqrt(2)*JxKet(1,0)/2 + JxKet(1,1)/2 , + JyKet(1,-1)/2 - sqrt(2)*JyKet(1,0)/2 + JyKet(1,1)/2 , + JzKet(1,-1)/2 - sqrt(2)*JzKet(1,0)/2 + JzKet(1,1)/2] + fun = [state(1, 1) for state in (JxKet, JyKet, JzKet)] + for state in fun: + got = qapply(Rotation(0, pi/2, 0)*state) + assert got in answ + answ.remove(got) + assert not answ + arg = Rotation(a, b, g)*fun[0] + assert qapply(arg) == (-exp(-I*a)*exp(I*g)*cos(b)*JxKet(1,-1)/2 + + exp(-I*a)*exp(I*g)*JxKet(1,-1)/2 - sqrt(2)*exp(-I*a)*sin(b)*JxKet(1,0)/2 + + exp(-I*a)*exp(-I*g)*cos(b)*JxKet(1,1)/2 + exp(-I*a)*exp(-I*g)*JxKet(1,1)/2) + #dummy effective + assert str(qapply(Rotation(a, b, g)*JzKet(j, m), dummy=False)) == str( + qapply(Rotation(a, b, g)*JzKet(j, m), dummy=True)).replace('_','') + #Coupled + ans = [JxKetCoupled(1,-1,(1,1))/2 - sqrt(2)*JxKetCoupled(1,0,(1,1))/2 + + JxKetCoupled(1,1,(1,1))/2 , + JyKetCoupled(1,-1,(1,1))/2 - sqrt(2)*JyKetCoupled(1,0,(1,1))/2 + + JyKetCoupled(1,1,(1,1))/2 , + JzKetCoupled(1,-1,(1,1))/2 - sqrt(2)*JzKetCoupled(1,0,(1,1))/2 + + JzKetCoupled(1,1,(1,1))/2] + fun = [state(1, 1, (1,1)) for state in (JxKetCoupled, JyKetCoupled, JzKetCoupled)] + for state in fun: + got = qapply(Rotation(0, pi/2, 0)*state) + assert got in ans + ans.remove(got) + assert not ans + arg = Rotation(a, b, g)*fun[0] + assert qapply(arg) == ( + -exp(-I*a)*exp(I*g)*cos(b)*JxKetCoupled(1,-1,(1,1))/2 + + exp(-I*a)*exp(I*g)*JxKetCoupled(1,-1,(1,1))/2 - + sqrt(2)*exp(-I*a)*sin(b)*JxKetCoupled(1,0,(1,1))/2 + + exp(-I*a)*exp(-I*g)*cos(b)*JxKetCoupled(1,1,(1,1))/2 + + exp(-I*a)*exp(-I*g)*JxKetCoupled(1,1,(1,1))/2) + #dummy effective + assert str(qapply(Rotation(a,b,g)*JzKetCoupled(j,m,(j1,j2)), dummy=False)) == str( + qapply(Rotation(a,b,g)*JzKetCoupled(j,m,(j1,j2)), dummy=True)).replace('_','') + + +def test_jzket(): + j, m = symbols('j m') + # j not integer or half integer + raises(ValueError, lambda: JzKet(Rational(2, 3), Rational(-1, 3))) + raises(ValueError, lambda: JzKet(Rational(2, 3), m)) + # j < 0 + raises(ValueError, lambda: JzKet(-1, 1)) + raises(ValueError, lambda: JzKet(-1, m)) + # m not integer or half integer + raises(ValueError, lambda: JzKet(j, Rational(-1, 3))) + # abs(m) > j + raises(ValueError, lambda: JzKet(1, 2)) + raises(ValueError, lambda: JzKet(1, -2)) + # j-m not integer + raises(ValueError, lambda: JzKet(1, S.Half)) + + +def test_jzketcoupled(): + j, m = symbols('j m') + # j not integer or half integer + raises(ValueError, lambda: JzKetCoupled(Rational(2, 3), Rational(-1, 3), (1,))) + raises(ValueError, lambda: JzKetCoupled(Rational(2, 3), m, (1,))) + # j < 0 + raises(ValueError, lambda: JzKetCoupled(-1, 1, (1,))) + raises(ValueError, lambda: JzKetCoupled(-1, m, (1,))) + # m not integer or half integer + raises(ValueError, lambda: JzKetCoupled(j, Rational(-1, 3), (1,))) + # abs(m) > j + raises(ValueError, lambda: JzKetCoupled(1, 2, (1,))) + raises(ValueError, lambda: JzKetCoupled(1, -2, (1,))) + # j-m not integer + raises(ValueError, lambda: JzKetCoupled(1, S.Half, (1,))) + # checks types on coupling scheme + raises(TypeError, lambda: JzKetCoupled(1, 1, 1)) + raises(TypeError, lambda: JzKetCoupled(1, 1, (1,), 1)) + raises(TypeError, lambda: JzKetCoupled(1, 1, (1, 1), (1,))) + raises(TypeError, lambda: JzKetCoupled(1, 1, (1, 1, 1), (1, 2, 1), + (1, 3, 1))) + # checks length of coupling terms + raises(ValueError, lambda: JzKetCoupled(1, 1, (1,), ((1, 2, 1),))) + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((1, 2),))) + # all jn are integer or half-integer + raises(ValueError, lambda: JzKetCoupled(1, 1, (Rational(1, 3), Rational(2, 3)))) + # indices in coupling scheme must be integers + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((S.Half, 1, 2),) )) + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((1, S.Half, 2),) )) + # indices out of range + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((0, 2, 1),) )) + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((3, 2, 1),) )) + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((1, 0, 1),) )) + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((1, 3, 1),) )) + # all j values in coupling scheme must by integer or half-integer + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1, 1), ((1, 2, S( + 4)/3), (1, 3, 1)) )) + # each coupling must satisfy |j1-j2| <= j3 <= j1+j2 + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 5))) + raises(ValueError, lambda: JzKetCoupled(5, 1, (1, 1))) + # final j of coupling must be j of the state + raises(ValueError, lambda: JzKetCoupled(1, 1, (1, 1), ((1, 2, 2),) )) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_state.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_state.py new file mode 100644 index 0000000000000000000000000000000000000000..c9fd5029fa3d77c2ddfc6899187624da02796ffa --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_state.py @@ -0,0 +1,248 @@ +from sympy.core.add import Add +from sympy.core.function import diff +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational, oo, pi) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.testing.pytest import raises + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.qexpr import QExpr +from sympy.physics.quantum.state import ( + Ket, Bra, TimeDepKet, TimeDepBra, + KetBase, BraBase, StateBase, Wavefunction, + OrthogonalKet, OrthogonalBra +) +from sympy.physics.quantum.hilbert import HilbertSpace + +x, y, t = symbols('x,y,t') + + +class CustomKet(Ket): + @classmethod + def default_args(self): + return ("test",) + + +class CustomKetMultipleLabels(Ket): + @classmethod + def default_args(self): + return ("r", "theta", "phi") + + +class CustomTimeDepKet(TimeDepKet): + @classmethod + def default_args(self): + return ("test", "t") + + +class CustomTimeDepKetMultipleLabels(TimeDepKet): + @classmethod + def default_args(self): + return ("r", "theta", "phi", "t") + + +def test_ket(): + k = Ket('0') + + assert isinstance(k, Ket) + assert isinstance(k, KetBase) + assert isinstance(k, StateBase) + assert isinstance(k, QExpr) + + assert k.label == (Symbol('0'),) + assert k.hilbert_space == HilbertSpace() + assert k.is_commutative is False + + # Make sure this doesn't get converted to the number pi. + k = Ket('pi') + assert k.label == (Symbol('pi'),) + + k = Ket(x, y) + assert k.label == (x, y) + assert k.hilbert_space == HilbertSpace() + assert k.is_commutative is False + + assert k.dual_class() == Bra + assert k.dual == Bra(x, y) + assert k.subs(x, y) == Ket(y, y) + + k = CustomKet() + assert k == CustomKet("test") + + k = CustomKetMultipleLabels() + assert k == CustomKetMultipleLabels("r", "theta", "phi") + + assert Ket() == Ket('psi') + + +def test_bra(): + b = Bra('0') + + assert isinstance(b, Bra) + assert isinstance(b, BraBase) + assert isinstance(b, StateBase) + assert isinstance(b, QExpr) + + assert b.label == (Symbol('0'),) + assert b.hilbert_space == HilbertSpace() + assert b.is_commutative is False + + # Make sure this doesn't get converted to the number pi. + b = Bra('pi') + assert b.label == (Symbol('pi'),) + + b = Bra(x, y) + assert b.label == (x, y) + assert b.hilbert_space == HilbertSpace() + assert b.is_commutative is False + + assert b.dual_class() == Ket + assert b.dual == Ket(x, y) + assert b.subs(x, y) == Bra(y, y) + + assert Bra() == Bra('psi') + + +def test_ops(): + k0 = Ket(0) + k1 = Ket(1) + k = 2*I*k0 - (x/sqrt(2))*k1 + assert k == Add(Mul(2, I, k0), + Mul(Rational(-1, 2), x, Pow(2, S.Half), k1)) + + +def test_time_dep_ket(): + k = TimeDepKet(0, t) + + assert isinstance(k, TimeDepKet) + assert isinstance(k, KetBase) + assert isinstance(k, StateBase) + assert isinstance(k, QExpr) + + assert k.label == (Integer(0),) + assert k.args == (Integer(0), t) + assert k.time == t + + assert k.dual_class() == TimeDepBra + assert k.dual == TimeDepBra(0, t) + + assert k.subs(t, 2) == TimeDepKet(0, 2) + + k = TimeDepKet(x, 0.5) + assert k.label == (x,) + assert k.args == (x, sympify(0.5)) + + k = CustomTimeDepKet() + assert k.label == (Symbol("test"),) + assert k.time == Symbol("t") + assert k == CustomTimeDepKet("test", "t") + + k = CustomTimeDepKetMultipleLabels() + assert k.label == (Symbol("r"), Symbol("theta"), Symbol("phi")) + assert k.time == Symbol("t") + assert k == CustomTimeDepKetMultipleLabels("r", "theta", "phi", "t") + + assert TimeDepKet() == TimeDepKet("psi", "t") + + +def test_time_dep_bra(): + b = TimeDepBra(0, t) + + assert isinstance(b, TimeDepBra) + assert isinstance(b, BraBase) + assert isinstance(b, StateBase) + assert isinstance(b, QExpr) + + assert b.label == (Integer(0),) + assert b.args == (Integer(0), t) + assert b.time == t + + assert b.dual_class() == TimeDepKet + assert b.dual == TimeDepKet(0, t) + + k = TimeDepBra(x, 0.5) + assert k.label == (x,) + assert k.args == (x, sympify(0.5)) + + assert TimeDepBra() == TimeDepBra("psi", "t") + + +def test_bra_ket_dagger(): + x = symbols('x', complex=True) + k = Ket('k') + b = Bra('b') + assert Dagger(k) == Bra('k') + assert Dagger(b) == Ket('b') + assert Dagger(k).is_commutative is False + + k2 = Ket('k2') + e = 2*I*k + x*k2 + assert Dagger(e) == conjugate(x)*Dagger(k2) - 2*I*Dagger(k) + + +def test_wavefunction(): + x, y = symbols('x y', real=True) + L = symbols('L', positive=True) + n = symbols('n', integer=True, positive=True) + + f = Wavefunction(x**2, x) + p = f.prob() + lims = f.limits + + assert f.is_normalized is False + assert f.norm is oo + assert f(10) == 100 + assert p(10) == 10000 + assert lims[x] == (-oo, oo) + assert diff(f, x) == Wavefunction(2*x, x) + raises(NotImplementedError, lambda: f.normalize()) + assert conjugate(f) == Wavefunction(conjugate(f.expr), x) + assert conjugate(f) == Dagger(f) + + g = Wavefunction(x**2*y + y**2*x, (x, 0, 1), (y, 0, 2)) + lims_g = g.limits + + assert lims_g[x] == (0, 1) + assert lims_g[y] == (0, 2) + assert g.is_normalized is False + assert g.norm == sqrt(42)/3 + assert g(2, 4) == 0 + assert g(1, 1) == 2 + assert diff(diff(g, x), y) == Wavefunction(2*x + 2*y, (x, 0, 1), (y, 0, 2)) + assert conjugate(g) == Wavefunction(conjugate(g.expr), *g.args[1:]) + assert conjugate(g) == Dagger(g) + + h = Wavefunction(sqrt(5)*x**2, (x, 0, 1)) + assert h.is_normalized is True + assert h.normalize() == h + assert conjugate(h) == Wavefunction(conjugate(h.expr), (x, 0, 1)) + assert conjugate(h) == Dagger(h) + + piab = Wavefunction(sin(n*pi*x/L), (x, 0, L)) + assert piab.norm == sqrt(L/2) + assert piab(L + 1) == 0 + assert piab(0.5) == sin(0.5*n*pi/L) + assert piab(0.5, n=1, L=1) == sin(0.5*pi) + assert piab.normalize() == \ + Wavefunction(sqrt(2)/sqrt(L)*sin(n*pi*x/L), (x, 0, L)) + assert conjugate(piab) == Wavefunction(conjugate(piab.expr), (x, 0, L)) + assert conjugate(piab) == Dagger(piab) + + k = Wavefunction(x**2, 'x') + assert type(k.variables[0]) == Symbol + +def test_orthogonal_states(): + bracket = OrthogonalBra(x) * OrthogonalKet(x) + assert bracket.doit() == 1 + + bracket = OrthogonalBra(x) * OrthogonalKet(x+1) + assert bracket.doit() == 0 + + bracket = OrthogonalBra(x) * OrthogonalKet(y) + assert bracket.doit() == bracket diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_tensorproduct.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_tensorproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..c17d533ae6d4ae97cb313eb345219fd82c6e483c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_tensorproduct.py @@ -0,0 +1,142 @@ +from sympy.core.numbers import I +from sympy.core.symbol import symbols +from sympy.core.expr import unchanged +from sympy.matrices import Matrix, SparseMatrix, ImmutableMatrix +from sympy.testing.pytest import warns_deprecated_sympy + +from sympy.physics.quantum.commutator import Commutator as Comm +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.tensorproduct import TensorProduct as TP +from sympy.physics.quantum.tensorproduct import tensor_product_simp +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.qubit import Qubit, QubitBra +from sympy.physics.quantum.operator import OuterProduct, Operator +from sympy.physics.quantum.density import Density +from sympy.physics.quantum.trace import Tr + +A = Operator('A') +B = Operator('B') +C = Operator('C') +D = Operator('D') +x = symbols('x') +y = symbols('y', integer=True, positive=True) + +mat1 = Matrix([[1, 2*I], [1 + I, 3]]) +mat2 = Matrix([[2*I, 3], [4*I, 2]]) + + +def test_sparse_matrices(): + spm = SparseMatrix.diag(1, 0) + assert unchanged(TensorProduct, spm, spm) + + +def test_tensor_product_dagger(): + assert Dagger(TensorProduct(I*A, B)) == \ + -I*TensorProduct(Dagger(A), Dagger(B)) + assert Dagger(TensorProduct(mat1, mat2)) == \ + TensorProduct(Dagger(mat1), Dagger(mat2)) + + +def test_tensor_product_abstract(): + + assert TP(x*A, 2*B) == x*2*TP(A, B) + assert TP(A, B) != TP(B, A) + assert TP(A, B).is_commutative is False + assert isinstance(TP(A, B), TP) + assert TP(A, B).subs(A, C) == TP(C, B) + + +def test_tensor_product_expand(): + assert TP(A + B, B + C).expand(tensorproduct=True) == \ + TP(A, B) + TP(A, C) + TP(B, B) + TP(B, C) + #Tests for fix of issue #24142 + assert TP(A-B, B-A).expand(tensorproduct=True) == \ + TP(A, B) - TP(A, A) - TP(B, B) + TP(B, A) + assert TP(2*A + B, A + B).expand(tensorproduct=True) == \ + 2 * TP(A, A) + 2 * TP(A, B) + TP(B, A) + TP(B, B) + assert TP(2 * A * B + A, A + B).expand(tensorproduct=True) == \ + 2 * TP(A*B, A) + 2 * TP(A*B, B) + TP(A, A) + TP(A, B) + + +def test_tensor_product_commutator(): + assert TP(Comm(A, B), C).doit().expand(tensorproduct=True) == \ + TP(A*B, C) - TP(B*A, C) + assert Comm(TP(A, B), TP(B, C)).doit() == \ + TP(A, B)*TP(B, C) - TP(B, C)*TP(A, B) + + +def test_tensor_product_simp(): + with warns_deprecated_sympy(): + assert tensor_product_simp(TP(A, B)*TP(B, C)) == TP(A*B, B*C) + # tests for Pow-expressions + assert TP(A, B)**y == TP(A**y, B**y) + assert tensor_product_simp(TP(A, B)**y) == TP(A**y, B**y) + assert tensor_product_simp(x*TP(A, B)**2) == x*TP(A**2,B**2) + assert tensor_product_simp(x*(TP(A, B)**2)*TP(C,D)) == x*TP(A**2*C,B**2*D) + assert tensor_product_simp(TP(A,B)-TP(C,D)**y) == TP(A,B)-TP(C**y,D**y) + + +def test_issue_5923(): + # most of the issue regarding sympification of args has been handled + # and is tested internally by the use of args_cnc through the quantum + # module, but the following is a test from the issue that used to raise. + assert TensorProduct(1, Qubit('1')*Qubit('1').dual) == \ + TensorProduct(1, OuterProduct(Qubit(1), QubitBra(1))) + + +def test_eval_trace(): + # This test includes tests with dependencies between TensorProducts + #and density operators. Since, the test is more to test the behavior of + #TensorProducts it remains here + + # Density with simple tensor products as args + t = TensorProduct(A, B) + d = Density([t, 1.0]) + tr = Tr(d) + assert tr.doit() == 1.0*Tr(A*Dagger(A))*Tr(B*Dagger(B)) + + ## partial trace with simple tensor products as args + t = TensorProduct(A, B, C) + d = Density([t, 1.0]) + tr = Tr(d, [1]) + assert tr.doit() == 1.0*A*Dagger(A)*Tr(B*Dagger(B))*C*Dagger(C) + + tr = Tr(d, [0, 2]) + assert tr.doit() == 1.0*Tr(A*Dagger(A))*B*Dagger(B)*Tr(C*Dagger(C)) + + # Density with multiple Tensorproducts as states + t2 = TensorProduct(A, B) + t3 = TensorProduct(C, D) + + d = Density([t2, 0.5], [t3, 0.5]) + t = Tr(d) + assert t.doit() == (0.5*Tr(A*Dagger(A))*Tr(B*Dagger(B)) + + 0.5*Tr(C*Dagger(C))*Tr(D*Dagger(D))) + + t = Tr(d, [0]) + assert t.doit() == (0.5*Tr(A*Dagger(A))*B*Dagger(B) + + 0.5*Tr(C*Dagger(C))*D*Dagger(D)) + + #Density with mixed states + d = Density([t2 + t3, 1.0]) + t = Tr(d) + assert t.doit() == ( 1.0*Tr(A*Dagger(A))*Tr(B*Dagger(B)) + + 1.0*Tr(A*Dagger(C))*Tr(B*Dagger(D)) + + 1.0*Tr(C*Dagger(A))*Tr(D*Dagger(B)) + + 1.0*Tr(C*Dagger(C))*Tr(D*Dagger(D))) + + t = Tr(d, [1] ) + assert t.doit() == ( 1.0*A*Dagger(A)*Tr(B*Dagger(B)) + + 1.0*A*Dagger(C)*Tr(B*Dagger(D)) + + 1.0*C*Dagger(A)*Tr(D*Dagger(B)) + + 1.0*C*Dagger(C)*Tr(D*Dagger(D))) + + +def test_pr24993(): + from sympy.matrices.expressions.kronecker import matrix_kronecker_product + from sympy.physics.quantum.matrixutils import matrix_tensor_product + X = Matrix([[0, 1], [1, 0]]) + Xi = ImmutableMatrix(X) + assert TensorProduct(Xi, Xi) == TensorProduct(X, X) + assert TensorProduct(Xi, Xi) == matrix_tensor_product(X, X) + assert TensorProduct(Xi, Xi) == matrix_kronecker_product(X, X) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_trace.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..85db6c60ad9d2bd1fbfafcf5d84b97d2fe304250 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_trace.py @@ -0,0 +1,109 @@ +from sympy.core.containers import Tuple +from sympy.core.symbol import symbols +from sympy.matrices.dense import Matrix +from sympy.physics.quantum.trace import Tr +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +def test_trace_new(): + a, b, c, d, Y = symbols('a b c d Y') + A, B, C, D = symbols('A B C D', commutative=False) + + assert Tr(a + b) == a + b + assert Tr(A + B) == Tr(A) + Tr(B) + + #check trace args not implicitly permuted + assert Tr(C*D*A*B).args[0].args == (C, D, A, B) + + # check for mul and adds + assert Tr((a*b) + ( c*d)) == (a*b) + (c*d) + # Tr(scalar*A) = scalar*Tr(A) + assert Tr(a*A) == a*Tr(A) + assert Tr(a*A*B*b) == a*b*Tr(A*B) + + # since A is symbol and not commutative + assert isinstance(Tr(A), Tr) + + #POW + assert Tr(pow(a, b)) == a**b + assert isinstance(Tr(pow(A, a)), Tr) + + #Matrix + M = Matrix([[1, 1], [2, 2]]) + assert Tr(M) == 3 + + ##test indices in different forms + #no index + t = Tr(A) + assert t.args[1] == Tuple() + + #single index + t = Tr(A, 0) + assert t.args[1] == Tuple(0) + + #index in a list + t = Tr(A, [0]) + assert t.args[1] == Tuple(0) + + t = Tr(A, [0, 1, 2]) + assert t.args[1] == Tuple(0, 1, 2) + + #index is tuple + t = Tr(A, (0)) + assert t.args[1] == Tuple(0) + + t = Tr(A, (1, 2)) + assert t.args[1] == Tuple(1, 2) + + #trace indices test + t = Tr((A + B), [2]) + assert t.args[0].args[1] == Tuple(2) and t.args[1].args[1] == Tuple(2) + + t = Tr(a*A, [2, 3]) + assert t.args[1].args[1] == Tuple(2, 3) + + #class with trace method defined + #to simulate numpy objects + class Foo: + def trace(self): + return 1 + assert Tr(Foo()) == 1 + + #argument test + # check for value error, when either/both arguments are not provided + raises(ValueError, lambda: Tr()) + raises(ValueError, lambda: Tr(A, 1, 2)) + + +def test_trace_doit(): + a, b, c, d = symbols('a b c d') + A, B, C, D = symbols('A B C D', commutative=False) + + #TODO: needed while testing reduced density operations, etc. + + +def test_permute(): + A, B, C, D, E, F, G = symbols('A B C D E F G', commutative=False) + t = Tr(A*B*C*D*E*F*G) + + assert t.permute(0).args[0].args == (A, B, C, D, E, F, G) + assert t.permute(2).args[0].args == (F, G, A, B, C, D, E) + assert t.permute(4).args[0].args == (D, E, F, G, A, B, C) + assert t.permute(6).args[0].args == (B, C, D, E, F, G, A) + assert t.permute(8).args[0].args == t.permute(1).args[0].args + + assert t.permute(-1).args[0].args == (B, C, D, E, F, G, A) + assert t.permute(-3).args[0].args == (D, E, F, G, A, B, C) + assert t.permute(-5).args[0].args == (F, G, A, B, C, D, E) + assert t.permute(-8).args[0].args == t.permute(-1).args[0].args + + t = Tr((A + B)*(B*B)*C*D) + assert t.permute(2).args[0].args == (C, D, (A + B), (B**2)) + + t1 = Tr(A*B) + t2 = t1.permute(1) + assert id(t1) != id(t2) and t1 == t2 + +def test_deprecated_core_trace(): + with warns_deprecated_sympy(): + from sympy.core.trace import Tr # noqa:F401 diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_transforms.py b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..55349ebe3b8003b5a107648516706034beaf22af --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/tests/test_transforms.py @@ -0,0 +1,75 @@ +"""Tests of transforms of quantum expressions for Mul and Pow.""" + +from sympy.core.symbol import symbols +from sympy.testing.pytest import raises + +from sympy.physics.quantum.operator import ( + Operator, OuterProduct +) +from sympy.physics.quantum.state import Ket, Bra +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.tensorproduct import TensorProduct + + +k1 = Ket('k1') +k2 = Ket('k2') +k3 = Ket('k3') +b1 = Bra('b1') +b2 = Bra('b2') +b3 = Bra('b3') +A = Operator('A') +B = Operator('B') +C = Operator('C') +x, y, z = symbols('x y z') + + +def test_bra_ket(): + assert b1*k1 == InnerProduct(b1, k1) + assert k1*b1 == OuterProduct(k1, b1) + # Test priority of inner product + assert OuterProduct(k1, b1)*k2 == InnerProduct(b1, k2)*k1 + assert b1*OuterProduct(k1, b2) == InnerProduct(b1, k1)*b2 + + +def test_tensor_product(): + # We are attempting to be rigourous and raise TypeError when a user tries + # to combine bras, kets, and operators in a manner that doesn't make sense. + # In particular, we are not trying to interpret regular ``*`` multiplication + # as a tensor product. + with raises(TypeError): + k1*k1 + with raises(TypeError): + b1*b1 + with raises(TypeError): + k1*TensorProduct(k2, k3) + with raises(TypeError): + b1*TensorProduct(b2, b3) + with raises(TypeError): + TensorProduct(k2, k3)*k1 + with raises(TypeError): + TensorProduct(b2, b3)*b1 + + assert TensorProduct(A, B, C)*TensorProduct(k1, k2, k3) == \ + TensorProduct(A*k1, B*k2, C*k3) + assert TensorProduct(b1, b2, b3)*TensorProduct(A, B, C) == \ + TensorProduct(b1*A, b2*B, b3*C) + assert TensorProduct(b1, b2, b3)*TensorProduct(k1, k2, k3) == \ + InnerProduct(b1, k1)*InnerProduct(b2, k2)*InnerProduct(b3, k3) + assert TensorProduct(b1, b2, b3)*TensorProduct(A, B, C)*TensorProduct(k1, k2, k3) == \ + TensorProduct(b1*A*k1, b2*B*k2, b3*C*k3) + + +def test_outer_product(): + assert OuterProduct(k1, b1)*OuterProduct(k2, b2) == \ + InnerProduct(b1, k2)*OuterProduct(k1, b2) + + +def test_compound(): + e1 = b1*A*B*k1*b2*k2*b3 + assert e1 == InnerProduct(b2, k2)*b1*A*B*OuterProduct(k1, b3) + + e2 = TensorProduct(k1, k2)*TensorProduct(b1, b2) + assert e2 == TensorProduct( + OuterProduct(k1, b1), + OuterProduct(k2, b2) + ) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/trace.py b/phivenv/Lib/site-packages/sympy/physics/quantum/trace.py new file mode 100644 index 0000000000000000000000000000000000000000..03ab18f78a1bfcf5bfcd679f00eac8685144fd8c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/trace.py @@ -0,0 +1,230 @@ +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import sympify +from sympy.matrices import Matrix + + +def _is_scalar(e): + """ Helper method used in Tr""" + + # sympify to set proper attributes + e = sympify(e) + if isinstance(e, Expr): + if (e.is_Integer or e.is_Float or + e.is_Rational or e.is_Number or + (e.is_Symbol and e.is_commutative) + ): + return True + + return False + + +def _cycle_permute(l): + """ Cyclic permutations based on canonical ordering + + Explanation + =========== + + This method does the sort based ascii values while + a better approach would be to used lexicographic sort. + + TODO: Handle condition such as symbols have subscripts/superscripts + in case of lexicographic sort + + """ + + if len(l) == 1: + return l + + min_item = min(l, key=default_sort_key) + indices = [i for i, x in enumerate(l) if x == min_item] + + le = list(l) + le.extend(l) # duplicate and extend string for easy processing + + # adding the first min_item index back for easier looping + indices.append(len(l) + indices[0]) + + # create sublist of items with first item as min_item and last_item + # in each of the sublist is item just before the next occurrence of + # minitem in the cycle formed. + sublist = [[le[indices[i]:indices[i + 1]]] for i in + range(len(indices) - 1)] + + # we do comparison of strings by comparing elements + # in each sublist + idx = sublist.index(min(sublist)) + ordered_l = le[indices[idx]:indices[idx] + len(l)] + + return ordered_l + + +def _rearrange_args(l): + """ this just moves the last arg to first position + to enable expansion of args + A,B,A ==> A**2,B + """ + if len(l) == 1: + return l + + x = list(l[-1:]) + x.extend(l[0:-1]) + return Mul(*x).args + + +class Tr(Expr): + """ Generic Trace operation than can trace over: + + a) SymPy matrix + b) operators + c) outer products + + Parameters + ========== + o : operator, matrix, expr + i : tuple/list indices (optional) + + Examples + ======== + + # TODO: Need to handle printing + + a) Trace(A+B) = Tr(A) + Tr(B) + b) Trace(scalar*Operator) = scalar*Trace(Operator) + + >>> from sympy.physics.quantum.trace import Tr + >>> from sympy import symbols, Matrix + >>> a, b = symbols('a b', commutative=True) + >>> A, B = symbols('A B', commutative=False) + >>> Tr(a*A,[2]) + a*Tr(A) + >>> m = Matrix([[1,2],[1,1]]) + >>> Tr(m) + 2 + + """ + def __new__(cls, *args): + """ Construct a Trace object. + + Parameters + ========== + args = SymPy expression + indices = tuple/list if indices, optional + + """ + + # expect no indices,int or a tuple/list/Tuple + if (len(args) == 2): + if not isinstance(args[1], (list, Tuple, tuple)): + indices = Tuple(args[1]) + else: + indices = Tuple(*args[1]) + + expr = args[0] + elif (len(args) == 1): + indices = Tuple() + expr = args[0] + else: + raise ValueError("Arguments to Tr should be of form " + "(expr[, [indices]])") + + if isinstance(expr, Matrix): + return expr.trace() + elif hasattr(expr, 'trace') and callable(expr.trace): + #for any objects that have trace() defined e.g numpy + return expr.trace() + elif isinstance(expr, Add): + return Add(*[Tr(arg, indices) for arg in expr.args]) + elif isinstance(expr, Mul): + c_part, nc_part = expr.args_cnc() + if len(nc_part) == 0: + return Mul(*c_part) + else: + obj = Expr.__new__(cls, Mul(*nc_part), indices ) + #this check is needed to prevent cached instances + #being returned even if len(c_part)==0 + return Mul(*c_part)*obj if len(c_part) > 0 else obj + elif isinstance(expr, Pow): + if (_is_scalar(expr.args[0]) and + _is_scalar(expr.args[1])): + return expr + else: + return Expr.__new__(cls, expr, indices) + else: + if (_is_scalar(expr)): + return expr + + return Expr.__new__(cls, expr, indices) + + @property + def kind(self): + expr = self.args[0] + expr_kind = expr.kind + return expr_kind.element_kind + + def doit(self, **hints): + """ Perform the trace operation. + + #TODO: Current version ignores the indices set for partial trace. + + >>> from sympy.physics.quantum.trace import Tr + >>> from sympy.physics.quantum.operator import OuterProduct + >>> from sympy.physics.quantum.spin import JzKet, JzBra + >>> t = Tr(OuterProduct(JzKet(1,1), JzBra(1,1))) + >>> t.doit() + 1 + + """ + if hasattr(self.args[0], '_eval_trace'): + return self.args[0]._eval_trace(indices=self.args[1]) + + return self + + @property + def is_number(self): + # TODO : improve this implementation + return True + + #TODO: Review if the permute method is needed + # and if it needs to return a new instance + def permute(self, pos): + """ Permute the arguments cyclically. + + Parameters + ========== + + pos : integer, if positive, shift-right, else shift-left + + Examples + ======== + + >>> from sympy.physics.quantum.trace import Tr + >>> from sympy import symbols + >>> A, B, C, D = symbols('A B C D', commutative=False) + >>> t = Tr(A*B*C*D) + >>> t.permute(2) + Tr(C*D*A*B) + >>> t.permute(-2) + Tr(C*D*A*B) + + """ + if pos > 0: + pos = pos % len(self.args[0].args) + else: + pos = -(abs(pos) % len(self.args[0].args)) + + args = list(self.args[0].args[-pos:] + self.args[0].args[0:-pos]) + + return Tr(Mul(*(args))) + + def _hashable_content(self): + if isinstance(self.args[0], Mul): + args = _cycle_permute(_rearrange_args(self.args[0].args)) + else: + args = [self.args[0]] + + return tuple(args) + (self.args[1], ) diff --git a/phivenv/Lib/site-packages/sympy/physics/quantum/transforms.py b/phivenv/Lib/site-packages/sympy/physics/quantum/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..dcbbcd9040b4f8f987375c2f903031610d6f9061 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/quantum/transforms.py @@ -0,0 +1,291 @@ +"""Transforms that are always applied to quantum expressions. + +This module uses the kind and _constructor_postprocessor_mapping APIs +to transform different combinations of Operators, Bras, and Kets into +Inner/Outer/TensorProducts. These transformations are registered +with the postprocessing API of core classes like `Mul` and `Pow` and +are always applied to any expression involving Bras, Kets, and +Operators. This API replaces the custom `__mul__` and `__pow__` +methods of the quantum classes, which were found to be inconsistent. + +THIS IS EXPERIMENTAL. +""" +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.multipledispatch.dispatcher import ( + Dispatcher, ambiguity_register_error_ignore_dup +) +from sympy.utilities.misc import debug + +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.kind import KetKind, BraKind, OperatorKind +from sympy.physics.quantum.operator import ( + OuterProduct, IdentityOperator, Operator +) +from sympy.physics.quantum.state import BraBase, KetBase, StateBase +from sympy.physics.quantum.tensorproduct import TensorProduct + + +#----------------------------------------------------------------------------- +# Multipledispatch based transformed for Mul and Pow +#----------------------------------------------------------------------------- + +_transform_state_pair = Dispatcher('_transform_state_pair') +"""Transform a pair of expression in a Mul to their canonical form. + +All functions that are registered with this dispatcher need to take +two inputs and return either tuple of transformed outputs, or None if no +transform is applied. The output tuple is inserted into the right place +of the ``Mul`` that is being put into canonical form. It works something like +the following: + +``Mul(a, b, c, d, e, f) -> Mul(*(_transform_state_pair(a, b) + (c, d, e, f))))`` + +The transforms here are always applied when quantum objects are multiplied. + +THIS IS EXPERIMENTAL. + +However, users of ``sympy.physics.quantum`` can import this dispatcher and +register their own transforms to control the canonical form of products +of quantum expressions. +""" + +@_transform_state_pair.register(Expr, Expr) +def _transform_expr(a, b): + """Default transformer that does nothing for base types.""" + return None + + +# The identity times anything is the anything. +_transform_state_pair.add( + (IdentityOperator, Expr), + lambda x, y: (y,), + on_ambiguity=ambiguity_register_error_ignore_dup +) +_transform_state_pair.add( + (Expr, IdentityOperator), + lambda x, y: (x,), + on_ambiguity=ambiguity_register_error_ignore_dup +) +_transform_state_pair.add( + (IdentityOperator, IdentityOperator), + lambda x, y: S.One, + on_ambiguity=ambiguity_register_error_ignore_dup +) + +@_transform_state_pair.register(BraBase, KetBase) +def _transform_bra_ket(a, b): + """Transform a bra*ket -> InnerProduct(bra, ket).""" + return (InnerProduct(a, b),) + +@_transform_state_pair.register(KetBase, BraBase) +def _transform_ket_bra(a, b): + """Transform a keT*bra -> OuterProduct(ket, bra).""" + return (OuterProduct(a, b),) + +@_transform_state_pair.register(KetBase, KetBase) +def _transform_ket_ket(a, b): + """Raise a TypeError if a user tries to multiply two kets. + + Multiplication based on `*` is not a shorthand for tensor products. + """ + raise TypeError( + 'Multiplication of two kets is not allowed. Use TensorProduct instead.' + ) + +@_transform_state_pair.register(BraBase, BraBase) +def _transform_bra_bra(a, b): + """Raise a TypeError if a user tries to multiply two bras. + + Multiplication based on `*` is not a shorthand for tensor products. + """ + raise TypeError( + 'Multiplication of two bras is not allowed. Use TensorProduct instead.' + ) + +@_transform_state_pair.register(OuterProduct, KetBase) +def _transform_op_ket(a, b): + return (InnerProduct(a.bra, b), a.ket) + +@_transform_state_pair.register(BraBase, OuterProduct) +def _transform_bra_op(a, b): + return (InnerProduct(a, b.ket), b.bra) + +@_transform_state_pair.register(TensorProduct, KetBase) +def _transform_tp_ket(a, b): + """Raise a TypeError if a user tries to multiply TensorProduct(*kets)*ket. + + Multiplication based on `*` is not a shorthand for tensor products. + """ + if a.kind == KetKind: + raise TypeError( + 'Multiplication of TensorProduct(*kets)*ket is invalid.' + ) + +@_transform_state_pair.register(KetBase, TensorProduct) +def _transform_ket_tp(a, b): + """Raise a TypeError if a user tries to multiply ket*TensorProduct(*kets). + + Multiplication based on `*` is not a shorthand for tensor products. + """ + if b.kind == KetKind: + raise TypeError( + 'Multiplication of ket*TensorProduct(*kets) is invalid.' + ) + +@_transform_state_pair.register(TensorProduct, BraBase) +def _transform_tp_bra(a, b): + """Raise a TypeError if a user tries to multiply TensorProduct(*bras)*bra. + + Multiplication based on `*` is not a shorthand for tensor products. + """ + if a.kind == BraKind: + raise TypeError( + 'Multiplication of TensorProduct(*bras)*bra is invalid.' + ) + +@_transform_state_pair.register(BraBase, TensorProduct) +def _transform_bra_tp(a, b): + """Raise a TypeError if a user tries to multiply bra*TensorProduct(*bras). + + Multiplication based on `*` is not a shorthand for tensor products. + """ + if b.kind == BraKind: + raise TypeError( + 'Multiplication of bra*TensorProduct(*bras) is invalid.' + ) + +@_transform_state_pair.register(TensorProduct, TensorProduct) +def _transform_tp_tp(a, b): + """Combine a product of tensor products if their number of args matches.""" + debug('_transform_tp_tp', a, b) + if len(a.args) == len(b.args): + if a.kind == BraKind and b.kind == KetKind: + return tuple([InnerProduct(i, j) for (i, j) in zip(a.args, b.args)]) + else: + return (TensorProduct(*(i*j for (i, j) in zip(a.args, b.args))), ) + +@_transform_state_pair.register(OuterProduct, OuterProduct) +def _transform_op_op(a, b): + """Extract an inner produt from a product of outer products.""" + return (InnerProduct(a.bra, b.ket), OuterProduct(a.ket, b.bra)) + + +#----------------------------------------------------------------------------- +# Postprocessing transforms for Mul and Pow +#----------------------------------------------------------------------------- + + +def _postprocess_state_mul(expr): + """Transform a ``Mul`` of quantum expressions into canonical form. + + This function is registered ``_constructor_postprocessor_mapping`` as a + transformer for ``Mul``. This means that every time a quantum expression + is multiplied, this function will be called to transform it into canonical + form as defined by the binary functions registered with + ``_transform_state_pair``. + + The algorithm of this function is as follows. It walks the args + of the input ``Mul`` from left to right and calls ``_transform_state_pair`` + on every overlapping pair of args. Each time ``_transform_state_pair`` + is called it can return a tuple of items or None. If None, the pair isn't + transformed. If a tuple, then the last element of the tuple goes back into + the args to be transformed again and the others are extended onto the result + args list. + + The algorithm can be visualized in the following table: + + step result args + ============================================================================ + #0 [] [a, b, c, d, e, f] + #1 [] [T(a,b), c, d, e, f] + #2 [T(a,b)[:-1]] [T(a,b)[-1], c, d, e, f] + #3 [T(a,b)[:-1]] [T(T(a,b)[-1], c), d, e, f] + #4 [T(a,b)[:-1], T(T(a,b)[-1], c)[:-1]] [T(T(T(a,b)[-1], c)[-1], d), e, f] + #5 ... + + One limitation of the current implementation is that we assume that only the + last item of the transformed tuple goes back into the args to be transformed + again. These seems to handle the cases needed for Mul. However, we may need + to extend the algorithm to have the entire tuple go back into the args for + further transformation. + """ + args = list(expr.args) + result = [] + + # Continue as long as we have at least 2 elements + while len(args) > 1: + # Get first two elements + first = args.pop(0) + second = args[0] # Look at second element without popping yet + + transformed = _transform_state_pair(first, second) + + if transformed is None: + # If transform returns None, append first element + result.append(first) + else: + # This item was transformed, pop and discard + args.pop(0) + # The last item goes back to be transformed again + args.insert(0, transformed[-1]) + # All other items go directly into the result + result.extend(transformed[:-1]) + + # Append any remaining element + if args: + result.append(args[0]) + + return Mul._from_args(result, is_commutative=False) + + +def _postprocess_state_pow(expr): + """Handle bras and kets raised to powers. + + Under ``*`` multiplication this is invalid. Users should use a + TensorProduct instead. + """ + base, exp = expr.as_base_exp() + if base.kind == KetKind or base.kind == BraKind: + raise TypeError( + 'A bra or ket to a power is invalid, use TensorProduct instead.' + ) + + +def _postprocess_tp_pow(expr): + """Handle TensorProduct(*operators)**(positive integer). + + This handles a tensor product of operators, to an integer power. + The power here is interpreted as regular multiplication, not + tensor product exponentiation. The form of exponentiation performed + here leaves the space and dimension of the object the same. + + This operation does not make sense for tensor product's of states. + """ + base, exp = expr.as_base_exp() + debug('_postprocess_tp_pow: ', base, exp, expr.args) + if isinstance(base, TensorProduct) and exp.is_integer and exp.is_positive and base.kind == OperatorKind: + new_args = [a**exp for a in base.args] + return TensorProduct(*new_args) + + +#----------------------------------------------------------------------------- +# Register the transformers with Basic._constructor_postprocessor_mapping +#----------------------------------------------------------------------------- + + +Basic._constructor_postprocessor_mapping[StateBase] = { + "Mul": [_postprocess_state_mul], + "Pow": [_postprocess_state_pow] +} + +Basic._constructor_postprocessor_mapping[TensorProduct] = { + "Mul": [_postprocess_state_mul], + "Pow": [_postprocess_tp_pow] +} + +Basic._constructor_postprocessor_mapping[Operator] = { + "Mul": [_postprocess_state_mul] +} diff --git a/phivenv/Lib/site-packages/sympy/physics/secondquant.py b/phivenv/Lib/site-packages/sympy/physics/secondquant.py new file mode 100644 index 0000000000000000000000000000000000000000..189e8e8b50c785759b03f19f28285f7988cfca75 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/secondquant.py @@ -0,0 +1,3125 @@ +""" +Second quantization operators and states for bosons. + +This follow the formulation of Fetter and Welecka, "Quantum Theory +of Many-Particle Systems." +""" +from collections import defaultdict + +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.cache import cacheit +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Dummy, Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.matrices.dense import zeros +from sympy.printing.str import StrPrinter +from sympy.utilities.iterables import has_dups + +__all__ = [ + 'Dagger', + 'KroneckerDelta', + 'BosonicOperator', + 'AnnihilateBoson', + 'CreateBoson', + 'AnnihilateFermion', + 'CreateFermion', + 'FockState', + 'FockStateBra', + 'FockStateKet', + 'FockStateBosonKet', + 'FockStateBosonBra', + 'FockStateFermionKet', + 'FockStateFermionBra', + 'BBra', + 'BKet', + 'FBra', + 'FKet', + 'F', + 'Fd', + 'B', + 'Bd', + 'apply_operators', + 'InnerProduct', + 'BosonicBasis', + 'VarBosonicBasis', + 'FixedBosonicBasis', + 'Commutator', + 'matrix_rep', + 'contraction', + 'wicks', + 'NO', + 'evaluate_deltas', + 'AntiSymmetricTensor', + 'substitute_dummies', + 'PermutationOperator', + 'simplify_index_permutations', +] + + +class SecondQuantizationError(Exception): + pass + + +class AppliesOnlyToSymbolicIndex(SecondQuantizationError): + pass + + +class ContractionAppliesOnlyToFermions(SecondQuantizationError): + pass + + +class ViolationOfPauliPrinciple(SecondQuantizationError): + pass + + +class SubstitutionOfAmbigousOperatorFailed(SecondQuantizationError): + pass + + +class WicksTheoremDoesNotApply(SecondQuantizationError): + pass + + +class Dagger(Expr): + """ + Hermitian conjugate of creation/annihilation operators. + + Examples + ======== + + >>> from sympy import I + >>> from sympy.physics.secondquant import Dagger, B, Bd + >>> Dagger(2*I) + -2*I + >>> Dagger(B(0)) + CreateBoson(0) + >>> Dagger(Bd(0)) + AnnihilateBoson(0) + + """ + + def __new__(cls, arg): + arg = sympify(arg) + r = cls.eval(arg) + if isinstance(r, Basic): + return r + obj = Basic.__new__(cls, arg) + return obj + + @classmethod + def eval(cls, arg): + """ + Evaluates the Dagger instance. + + Examples + ======== + + >>> from sympy import I + >>> from sympy.physics.secondquant import Dagger, B, Bd + >>> Dagger(2*I) + -2*I + >>> Dagger(B(0)) + CreateBoson(0) + >>> Dagger(Bd(0)) + AnnihilateBoson(0) + + The eval() method is called automatically. + + """ + dagger = getattr(arg, '_dagger_', None) + if dagger is not None: + return dagger() + if isinstance(arg, Symbol) and arg.is_commutative: + return conjugate(arg) + if isinstance(arg, Basic): + if arg.is_Add: + return Add(*tuple(map(Dagger, arg.args))) + if arg.is_Mul: + return Mul(*tuple(map(Dagger, reversed(arg.args)))) + if arg.is_Number: + return arg + if arg.is_Pow: + return Pow(Dagger(arg.args[0]), arg.args[1]) + if arg == I: + return -arg + if isinstance(arg, Function): + if all(a.is_commutative for a in arg.args): + return arg.func(*[Dagger(a) for a in arg.args]) + else: + return None + + def _dagger_(self): + return self.args[0] + + +class TensorSymbol(Expr): + + is_commutative = True + + +class AntiSymmetricTensor(TensorSymbol): + """Stores upper and lower indices in separate Tuple's. + + Each group of indices is assumed to be antisymmetric. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import AntiSymmetricTensor + >>> i, j = symbols('i j', below_fermi=True) + >>> a, b = symbols('a b', above_fermi=True) + >>> AntiSymmetricTensor('v', (a, i), (b, j)) + AntiSymmetricTensor(v, (a, i), (b, j)) + >>> AntiSymmetricTensor('v', (i, a), (b, j)) + -AntiSymmetricTensor(v, (a, i), (b, j)) + + As you can see, the indices are automatically sorted to a canonical form. + + """ + + def __new__(cls, symbol, upper, lower): + + try: + upper, signu = _sort_anticommuting_fermions( + upper, key=_sqkey_index) + lower, signl = _sort_anticommuting_fermions( + lower, key=_sqkey_index) + + except ViolationOfPauliPrinciple: + return S.Zero + + symbol = sympify(symbol) + upper = Tuple(*upper) + lower = Tuple(*lower) + + if (signu + signl) % 2: + return -TensorSymbol.__new__(cls, symbol, upper, lower) + else: + + return TensorSymbol.__new__(cls, symbol, upper, lower) + + def _latex(self, printer): + return "{%s^{%s}_{%s}}" % ( + self.symbol, + "".join([ printer._print(i) for i in self.args[1]]), + "".join([ printer._print(i) for i in self.args[2]]) + ) + + @property + def symbol(self): + """ + Returns the symbol of the tensor. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import AntiSymmetricTensor + >>> i, j = symbols('i,j', below_fermi=True) + >>> a, b = symbols('a,b', above_fermi=True) + >>> AntiSymmetricTensor('v', (a, i), (b, j)) + AntiSymmetricTensor(v, (a, i), (b, j)) + >>> AntiSymmetricTensor('v', (a, i), (b, j)).symbol + v + + """ + return self.args[0] + + @property + def upper(self): + """ + Returns the upper indices. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import AntiSymmetricTensor + >>> i, j = symbols('i,j', below_fermi=True) + >>> a, b = symbols('a,b', above_fermi=True) + >>> AntiSymmetricTensor('v', (a, i), (b, j)) + AntiSymmetricTensor(v, (a, i), (b, j)) + >>> AntiSymmetricTensor('v', (a, i), (b, j)).upper + (a, i) + + + """ + return self.args[1] + + @property + def lower(self): + """ + Returns the lower indices. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import AntiSymmetricTensor + >>> i, j = symbols('i,j', below_fermi=True) + >>> a, b = symbols('a,b', above_fermi=True) + >>> AntiSymmetricTensor('v', (a, i), (b, j)) + AntiSymmetricTensor(v, (a, i), (b, j)) + >>> AntiSymmetricTensor('v', (a, i), (b, j)).lower + (b, j) + + """ + return self.args[2] + + def __str__(self): + return "%s(%s,%s)" % self.args + + +class SqOperator(Expr): + """ + Base class for Second Quantization operators. + """ + + op_symbol = 'sq' + + is_commutative = False + + def __new__(cls, k): + obj = Basic.__new__(cls, sympify(k)) + return obj + + @property + def state(self): + """ + Returns the state index related to this operator. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F, Fd, B, Bd + >>> p = Symbol('p') + >>> F(p).state + p + >>> Fd(p).state + p + >>> B(p).state + p + >>> Bd(p).state + p + + """ + return self.args[0] + + @property + def is_symbolic(self): + """ + Returns True if the state is a symbol (as opposed to a number). + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> p = Symbol('p') + >>> F(p).is_symbolic + True + >>> F(1).is_symbolic + False + + """ + if self.state.is_Integer: + return False + else: + return True + + def __repr__(self): + return NotImplemented + + def __str__(self): + return "%s(%r)" % (self.op_symbol, self.state) + + def apply_operator(self, state): + """ + Applies an operator to itself. + """ + raise NotImplementedError('implement apply_operator in a subclass') + + +class BosonicOperator(SqOperator): + pass + + +class Annihilator(SqOperator): + pass + + +class Creator(SqOperator): + pass + + +class AnnihilateBoson(BosonicOperator, Annihilator): + """ + Bosonic annihilation operator. + + Examples + ======== + + >>> from sympy.physics.secondquant import B + >>> from sympy.abc import x + >>> B(x) + AnnihilateBoson(x) + """ + + op_symbol = 'b' + + def _dagger_(self): + return CreateBoson(self.state) + + def apply_operator(self, state): + """ + Apply state to self if self is not symbolic and state is a FockStateKet, else + multiply self by state. + + Examples + ======== + + >>> from sympy.physics.secondquant import B, BKet + >>> from sympy.abc import x, y, n + >>> B(x).apply_operator(y) + y*AnnihilateBoson(x) + >>> B(0).apply_operator(BKet((n,))) + sqrt(n)*FockStateBosonKet((n - 1,)) + + """ + if not self.is_symbolic and isinstance(state, FockStateKet): + element = self.state + amp = sqrt(state[element]) + return amp*state.down(element) + else: + return Mul(self, state) + + def __repr__(self): + return "AnnihilateBoson(%s)" % self.state + + def _latex(self, printer): + if self.state is S.Zero: + return "b_{0}" + else: + return "b_{%s}" % printer._print(self.state) + +class CreateBoson(BosonicOperator, Creator): + """ + Bosonic creation operator. + """ + + op_symbol = 'b+' + + def _dagger_(self): + return AnnihilateBoson(self.state) + + def apply_operator(self, state): + """ + Apply state to self if self is not symbolic and state is a FockStateKet, else + multiply self by state. + + Examples + ======== + + >>> from sympy.physics.secondquant import B, Dagger, BKet + >>> from sympy.abc import x, y, n + >>> Dagger(B(x)).apply_operator(y) + y*CreateBoson(x) + >>> B(0).apply_operator(BKet((n,))) + sqrt(n)*FockStateBosonKet((n - 1,)) + """ + if not self.is_symbolic and isinstance(state, FockStateKet): + element = self.state + amp = sqrt(state[element] + 1) + return amp*state.up(element) + else: + return Mul(self, state) + + def __repr__(self): + return "CreateBoson(%s)" % self.state + + def _latex(self, printer): + if self.state is S.Zero: + return "{b^\\dagger_{0}}" + else: + return "{b^\\dagger_{%s}}" % printer._print(self.state) + +B = AnnihilateBoson +Bd = CreateBoson + + +class FermionicOperator(SqOperator): + + @property + def is_restricted(self): + """ + Is this FermionicOperator restricted with respect to fermi level? + + Returns + ======= + + 1 : restricted to orbits above fermi + 0 : no restriction + -1 : restricted to orbits below fermi + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F, Fd + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_restricted + 1 + >>> Fd(a).is_restricted + 1 + >>> F(i).is_restricted + -1 + >>> Fd(i).is_restricted + -1 + >>> F(p).is_restricted + 0 + >>> Fd(p).is_restricted + 0 + + """ + ass = self.args[0].assumptions0 + if ass.get("below_fermi"): + return -1 + if ass.get("above_fermi"): + return 1 + return 0 + + @property + def is_above_fermi(self): + """ + Does the index of this FermionicOperator allow values above fermi? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_above_fermi + True + >>> F(i).is_above_fermi + False + >>> F(p).is_above_fermi + True + + Note + ==== + + The same applies to creation operators Fd + + """ + return not self.args[0].assumptions0.get("below_fermi") + + @property + def is_below_fermi(self): + """ + Does the index of this FermionicOperator allow values below fermi? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_below_fermi + False + >>> F(i).is_below_fermi + True + >>> F(p).is_below_fermi + True + + The same applies to creation operators Fd + + """ + return not self.args[0].assumptions0.get("above_fermi") + + @property + def is_only_below_fermi(self): + """ + Is the index of this FermionicOperator restricted to values below fermi? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_only_below_fermi + False + >>> F(i).is_only_below_fermi + True + >>> F(p).is_only_below_fermi + False + + The same applies to creation operators Fd + """ + return self.is_below_fermi and not self.is_above_fermi + + @property + def is_only_above_fermi(self): + """ + Is the index of this FermionicOperator restricted to values above fermi? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_only_above_fermi + True + >>> F(i).is_only_above_fermi + False + >>> F(p).is_only_above_fermi + False + + The same applies to creation operators Fd + """ + return self.is_above_fermi and not self.is_below_fermi + + def _sortkey(self): + h = hash(self) + label = str(self.args[0]) + + if self.is_only_q_creator: + return 1, label, h + if self.is_only_q_annihilator: + return 4, label, h + if isinstance(self, Annihilator): + return 3, label, h + if isinstance(self, Creator): + return 2, label, h + + +class AnnihilateFermion(FermionicOperator, Annihilator): + """ + Fermionic annihilation operator. + """ + + op_symbol = 'f' + + def _dagger_(self): + return CreateFermion(self.state) + + def apply_operator(self, state): + """ + Apply state to self if self is not symbolic and state is a FockStateKet, else + multiply self by state. + + Examples + ======== + + >>> from sympy.physics.secondquant import B, Dagger, BKet + >>> from sympy.abc import x, y, n + >>> Dagger(B(x)).apply_operator(y) + y*CreateBoson(x) + >>> B(0).apply_operator(BKet((n,))) + sqrt(n)*FockStateBosonKet((n - 1,)) + """ + if isinstance(state, FockStateFermionKet): + element = self.state + return state.down(element) + + elif isinstance(state, Mul): + c_part, nc_part = state.args_cnc() + if isinstance(nc_part[0], FockStateFermionKet): + element = self.state + return Mul(*(c_part + [nc_part[0].down(element)] + nc_part[1:])) + else: + return Mul(self, state) + + else: + return Mul(self, state) + + @property + def is_q_creator(self): + """ + Can we create a quasi-particle? (create hole or create particle) + If so, would that be above or below the fermi surface? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_q_creator + 0 + >>> F(i).is_q_creator + -1 + >>> F(p).is_q_creator + -1 + + """ + if self.is_below_fermi: + return -1 + return 0 + + @property + def is_q_annihilator(self): + """ + Can we destroy a quasi-particle? (annihilate hole or annihilate particle) + If so, would that be above or below the fermi surface? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=1) + >>> i = Symbol('i', below_fermi=1) + >>> p = Symbol('p') + + >>> F(a).is_q_annihilator + 1 + >>> F(i).is_q_annihilator + 0 + >>> F(p).is_q_annihilator + 1 + + """ + if self.is_above_fermi: + return 1 + return 0 + + @property + def is_only_q_creator(self): + """ + Always create a quasi-particle? (create hole or create particle) + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_only_q_creator + False + >>> F(i).is_only_q_creator + True + >>> F(p).is_only_q_creator + False + + """ + return self.is_only_below_fermi + + @property + def is_only_q_annihilator(self): + """ + Always destroy a quasi-particle? (annihilate hole or annihilate particle) + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import F + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> F(a).is_only_q_annihilator + True + >>> F(i).is_only_q_annihilator + False + >>> F(p).is_only_q_annihilator + False + + """ + return self.is_only_above_fermi + + def __repr__(self): + return "AnnihilateFermion(%s)" % self.state + + def _latex(self, printer): + if self.state is S.Zero: + return "a_{0}" + else: + return "a_{%s}" % printer._print(self.state) + + +class CreateFermion(FermionicOperator, Creator): + """ + Fermionic creation operator. + """ + + op_symbol = 'f+' + + def _dagger_(self): + return AnnihilateFermion(self.state) + + def apply_operator(self, state): + """ + Apply state to self if self is not symbolic and state is a FockStateKet, else + multiply self by state. + + Examples + ======== + + >>> from sympy.physics.secondquant import B, Dagger, BKet + >>> from sympy.abc import x, y, n + >>> Dagger(B(x)).apply_operator(y) + y*CreateBoson(x) + >>> B(0).apply_operator(BKet((n,))) + sqrt(n)*FockStateBosonKet((n - 1,)) + """ + if isinstance(state, FockStateFermionKet): + element = self.state + return state.up(element) + + elif isinstance(state, Mul): + c_part, nc_part = state.args_cnc() + if isinstance(nc_part[0], FockStateFermionKet): + element = self.state + return Mul(*(c_part + [nc_part[0].up(element)] + nc_part[1:])) + + return Mul(self, state) + + @property + def is_q_creator(self): + """ + Can we create a quasi-particle? (create hole or create particle) + If so, would that be above or below the fermi surface? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import Fd + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> Fd(a).is_q_creator + 1 + >>> Fd(i).is_q_creator + 0 + >>> Fd(p).is_q_creator + 1 + + """ + if self.is_above_fermi: + return 1 + return 0 + + @property + def is_q_annihilator(self): + """ + Can we destroy a quasi-particle? (annihilate hole or annihilate particle) + If so, would that be above or below the fermi surface? + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import Fd + >>> a = Symbol('a', above_fermi=1) + >>> i = Symbol('i', below_fermi=1) + >>> p = Symbol('p') + + >>> Fd(a).is_q_annihilator + 0 + >>> Fd(i).is_q_annihilator + -1 + >>> Fd(p).is_q_annihilator + -1 + + """ + if self.is_below_fermi: + return -1 + return 0 + + @property + def is_only_q_creator(self): + """ + Always create a quasi-particle? (create hole or create particle) + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import Fd + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> Fd(a).is_only_q_creator + True + >>> Fd(i).is_only_q_creator + False + >>> Fd(p).is_only_q_creator + False + + """ + return self.is_only_above_fermi + + @property + def is_only_q_annihilator(self): + """ + Always destroy a quasi-particle? (annihilate hole or annihilate particle) + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import Fd + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> Fd(a).is_only_q_annihilator + False + >>> Fd(i).is_only_q_annihilator + True + >>> Fd(p).is_only_q_annihilator + False + + """ + return self.is_only_below_fermi + + def __repr__(self): + return "CreateFermion(%s)" % self.state + + def _latex(self, printer): + if self.state is S.Zero: + return "{a^\\dagger_{0}}" + else: + return "{a^\\dagger_{%s}}" % printer._print(self.state) + +Fd = CreateFermion +F = AnnihilateFermion + + +class FockState(Expr): + """ + Many particle Fock state with a sequence of occupation numbers. + + Anywhere you can have a FockState, you can also have S.Zero. + All code must check for this! + + Base class to represent FockStates. + """ + is_commutative = False + + def __new__(cls, occupations): + """ + occupations is a list with two possible meanings: + + - For bosons it is a list of occupation numbers. + Element i is the number of particles in state i. + + - For fermions it is a list of occupied orbits. + Element 0 is the state that was occupied first, element i + is the i'th occupied state. + """ + occupations = list(map(sympify, occupations)) + obj = Basic.__new__(cls, Tuple(*occupations)) + return obj + + def __getitem__(self, i): + i = int(i) + return self.args[0][i] + + def __repr__(self): + return ("FockState(%r)") % (self.args) + + def __str__(self): + return "%s%r%s" % (getattr(self, 'lbracket', ""), self._labels(), getattr(self, 'rbracket', "")) + + def _labels(self): + return self.args[0] + + def __len__(self): + return len(self.args[0]) + + def _latex(self, printer): + return "%s%s%s" % (getattr(self, 'lbracket_latex', ""), printer._print(self._labels()), getattr(self, 'rbracket_latex', "")) + + +class BosonState(FockState): + """ + Base class for FockStateBoson(Ket/Bra). + """ + + def up(self, i): + """ + Performs the action of a creation operator. + + Examples + ======== + + >>> from sympy.physics.secondquant import BBra + >>> b = BBra([1, 2]) + >>> b + FockStateBosonBra((1, 2)) + >>> b.up(1) + FockStateBosonBra((1, 3)) + """ + i = int(i) + new_occs = list(self.args[0]) + new_occs[i] = new_occs[i] + S.One + return self.__class__(new_occs) + + def down(self, i): + """ + Performs the action of an annihilation operator. + + Examples + ======== + + >>> from sympy.physics.secondquant import BBra + >>> b = BBra([1, 2]) + >>> b + FockStateBosonBra((1, 2)) + >>> b.down(1) + FockStateBosonBra((1, 1)) + """ + i = int(i) + new_occs = list(self.args[0]) + if new_occs[i] == S.Zero: + return S.Zero + else: + new_occs[i] = new_occs[i] - S.One + return self.__class__(new_occs) + + +class FermionState(FockState): + """ + Base class for FockStateFermion(Ket/Bra). + """ + + fermi_level = 0 + + def __new__(cls, occupations, fermi_level=0): + occupations = list(map(sympify, occupations)) + if len(occupations) > 1: + try: + (occupations, sign) = _sort_anticommuting_fermions( + occupations, key=_sqkey_index) + except ViolationOfPauliPrinciple: + return S.Zero + else: + sign = 0 + + cls.fermi_level = fermi_level + + if cls._count_holes(occupations) > fermi_level: + return S.Zero + + if sign % 2: + return S.NegativeOne*FockState.__new__(cls, occupations) + else: + return FockState.__new__(cls, occupations) + + def up(self, i): + """ + Performs the action of a creation operator. + + Explanation + =========== + + If below fermi we try to remove a hole, + if above fermi we try to create a particle. + + If general index p we return ``Kronecker(p,i)*self`` + where ``i`` is a new symbol with restriction above or below. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import FKet + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + >>> FKet([]).up(a) + FockStateFermionKet((a,)) + + A creator acting on vacuum below fermi vanishes + + >>> FKet([]).up(i) + 0 + + + """ + present = i in self.args[0] + + if self._only_above_fermi(i): + if present: + return S.Zero + else: + return self._add_orbit(i) + elif self._only_below_fermi(i): + if present: + return self._remove_orbit(i) + else: + return S.Zero + else: + if present: + hole = Dummy("i", below_fermi=True) + return KroneckerDelta(i, hole)*self._remove_orbit(i) + else: + particle = Dummy("a", above_fermi=True) + return KroneckerDelta(i, particle)*self._add_orbit(i) + + def down(self, i): + """ + Performs the action of an annihilation operator. + + Explanation + =========== + + If below fermi we try to create a hole, + If above fermi we try to remove a particle. + + If general index p we return ``Kronecker(p,i)*self`` + where ``i`` is a new symbol with restriction above or below. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.secondquant import FKet + >>> a = Symbol('a', above_fermi=True) + >>> i = Symbol('i', below_fermi=True) + >>> p = Symbol('p') + + An annihilator acting on vacuum above fermi vanishes + + >>> FKet([]).down(a) + 0 + + Also below fermi, it vanishes, unless we specify a fermi level > 0 + + >>> FKet([]).down(i) + 0 + >>> FKet([],4).down(i) + FockStateFermionKet((i,)) + + """ + present = i in self.args[0] + + if self._only_above_fermi(i): + if present: + return self._remove_orbit(i) + else: + return S.Zero + + elif self._only_below_fermi(i): + if present: + return S.Zero + else: + return self._add_orbit(i) + else: + if present: + hole = Dummy("i", below_fermi=True) + return KroneckerDelta(i, hole)*self._add_orbit(i) + else: + particle = Dummy("a", above_fermi=True) + return KroneckerDelta(i, particle)*self._remove_orbit(i) + + @classmethod + def _only_below_fermi(cls, i): + """ + Tests if given orbit is only below fermi surface. + + If nothing can be concluded we return a conservative False. + """ + if i.is_number: + return i <= cls.fermi_level + if i.assumptions0.get('below_fermi'): + return True + return False + + @classmethod + def _only_above_fermi(cls, i): + """ + Tests if given orbit is only above fermi surface. + + If fermi level has not been set we return True. + If nothing can be concluded we return a conservative False. + """ + if i.is_number: + return i > cls.fermi_level + if i.assumptions0.get('above_fermi'): + return True + return not cls.fermi_level + + def _remove_orbit(self, i): + """ + Removes particle/fills hole in orbit i. No input tests performed here. + """ + new_occs = list(self.args[0]) + pos = new_occs.index(i) + del new_occs[pos] + if (pos) % 2: + return S.NegativeOne*self.__class__(new_occs, self.fermi_level) + else: + return self.__class__(new_occs, self.fermi_level) + + def _add_orbit(self, i): + """ + Adds particle/creates hole in orbit i. No input tests performed here. + """ + return self.__class__((i,) + self.args[0], self.fermi_level) + + @classmethod + def _count_holes(cls, occupations): + """ + Returns the number of identified hole states in occupations list. + """ + return len([i for i in occupations if cls._only_below_fermi(i)]) + + def _negate_holes(self, occupations): + """ + Returns the occupations list where states below the fermi level have negative labels. + + For symbolic state labels, no sign is included. + """ + return tuple([-i if self._only_below_fermi(i) and i.is_number else i for i in occupations]) + + def __repr__(self): + if self.fermi_level: + return "FockStateKet(%r, fermi_level=%s)" % (self.args[0], self.fermi_level) + else: + return "FockStateKet(%r)" % (self.args[0],) + + def _labels(self): + return self._negate_holes(self.args[0]) + + +class FockStateKet(FockState): + """ + Representation of a ket. + """ + lbracket = '|' + rbracket = '>' + lbracket_latex = r'\left|' + rbracket_latex = r'\right\rangle' + + +class FockStateBra(FockState): + """ + Representation of a bra. + """ + lbracket = '<' + rbracket = '|' + lbracket_latex = r'\left\langle' + rbracket_latex = r'\right|' + + def __mul__(self, other): + if isinstance(other, FockStateKet): + return InnerProduct(self, other) + else: + return Expr.__mul__(self, other) + + +class FockStateBosonKet(BosonState, FockStateKet): + """ + Many particle Fock state with a sequence of occupation numbers. + + Occupation numbers can be any integer >= 0. + + Examples + ======== + + >>> from sympy.physics.secondquant import BKet + >>> BKet([1, 2]) + FockStateBosonKet((1, 2)) + """ + def _dagger_(self): + return FockStateBosonBra(*self.args) + + +class FockStateBosonBra(BosonState, FockStateBra): + """ + Describes a collection of BosonBra particles. + + Examples + ======== + + >>> from sympy.physics.secondquant import BBra + >>> BBra([1, 2]) + FockStateBosonBra((1, 2)) + """ + def _dagger_(self): + return FockStateBosonKet(*self.args) + + +class FockStateFermionKet(FermionState, FockStateKet): + """ + Many-particle Fock state with a sequence of occupied orbits. + + Explanation + =========== + + Each state can only have one particle, so we choose to store a list of + occupied orbits rather than a tuple with occupation numbers (zeros and ones). + + states below fermi level are holes, and are represented by negative labels + in the occupation list. + + For symbolic state labels, the fermi_level caps the number of allowed hole- + states. + + Examples + ======== + + >>> from sympy.physics.secondquant import FKet + >>> FKet([1, 2]) + FockStateFermionKet((1, 2)) + """ + def _dagger_(self): + return FockStateFermionBra(*self.args) + + +class FockStateFermionBra(FermionState, FockStateBra): + """ + See Also + ======== + + FockStateFermionKet + + Examples + ======== + + >>> from sympy.physics.secondquant import FBra + >>> FBra([1, 2]) + FockStateFermionBra((1, 2)) + """ + def _dagger_(self): + return FockStateFermionKet(*self.args) + +BBra = FockStateBosonBra +BKet = FockStateBosonKet +FBra = FockStateFermionBra +FKet = FockStateFermionKet + + +def _apply_Mul(m): + """ + Take a Mul instance with operators and apply them to states. + + Explanation + =========== + + This method applies all operators with integer state labels + to the actual states. For symbolic state labels, nothing is done. + When inner products of FockStates are encountered (like ), + they are converted to instances of InnerProduct. + + This does not currently work on double inner products like, + . + + If the argument is not a Mul, it is simply returned as is. + """ + if not isinstance(m, Mul): + return m + c_part, nc_part = m.args_cnc() + n_nc = len(nc_part) + if n_nc in (0, 1): + return m + else: + last = nc_part[-1] + next_to_last = nc_part[-2] + if isinstance(last, FockStateKet): + if isinstance(next_to_last, SqOperator): + if next_to_last.is_symbolic: + return m + else: + result = next_to_last.apply_operator(last) + if result == 0: + return S.Zero + else: + return _apply_Mul(Mul(*(c_part + nc_part[:-2] + [result]))) + elif isinstance(next_to_last, Pow): + if isinstance(next_to_last.base, SqOperator) and \ + next_to_last.exp.is_Integer: + if next_to_last.base.is_symbolic: + return m + else: + result = last + for i in range(next_to_last.exp): + result = next_to_last.base.apply_operator(result) + if result == 0: + break + if result == 0: + return S.Zero + else: + return _apply_Mul(Mul(*(c_part + nc_part[:-2] + [result]))) + else: + return m + elif isinstance(next_to_last, FockStateBra): + result = InnerProduct(next_to_last, last) + if result == 0: + return S.Zero + else: + return _apply_Mul(Mul(*(c_part + nc_part[:-2] + [result]))) + else: + return m + else: + return m + + +def apply_operators(e): + """ + Take a SymPy expression with operators and states and apply the operators. + + Examples + ======== + + >>> from sympy.physics.secondquant import apply_operators + >>> from sympy import sympify + >>> apply_operators(sympify(3)+4) + 7 + """ + e = e.expand() + muls = e.atoms(Mul) + subs_list = [(m, _apply_Mul(m)) for m in iter(muls)] + return e.subs(subs_list) + + +class InnerProduct(Basic): + """ + An unevaluated inner product between a bra and ket. + + Explanation + =========== + + Currently this class just reduces things to a product of + Kronecker Deltas. In the future, we could introduce abstract + states like ``|a>`` and ``|b>``, and leave the inner product unevaluated as + ````. + + """ + is_commutative = True + + def __new__(cls, bra, ket): + if not isinstance(bra, FockStateBra): + raise TypeError("must be a bra") + if not isinstance(ket, FockStateKet): + raise TypeError("must be a ket") + return cls.eval(bra, ket) + + @classmethod + def eval(cls, bra, ket): + result = S.One + for i, j in zip(bra.args[0], ket.args[0]): + result *= KroneckerDelta(i, j) + if result == 0: + break + return result + + @property + def bra(self): + """Returns the bra part of the state""" + return self.args[0] + + @property + def ket(self): + """Returns the ket part of the state""" + return self.args[1] + + def __repr__(self): + sbra = repr(self.bra) + sket = repr(self.ket) + return "%s|%s" % (sbra[:-1], sket[1:]) + + def __str__(self): + return self.__repr__() + + +def matrix_rep(op, basis): + """ + Find the representation of an operator in a basis. + + Examples + ======== + + >>> from sympy.physics.secondquant import VarBosonicBasis, B, matrix_rep + >>> b = VarBosonicBasis(5) + >>> o = B(0) + >>> matrix_rep(o, b) + Matrix([ + [0, 1, 0, 0, 0], + [0, 0, sqrt(2), 0, 0], + [0, 0, 0, sqrt(3), 0], + [0, 0, 0, 0, 2], + [0, 0, 0, 0, 0]]) + """ + a = zeros(len(basis)) + for i in range(len(basis)): + for j in range(len(basis)): + a[i, j] = apply_operators(Dagger(basis[i])*op*basis[j]) + return a + + +class BosonicBasis: + """ + Base class for a basis set of bosonic Fock states. + """ + pass + + +class VarBosonicBasis: + """ + A single state, variable particle number basis set. + + Examples + ======== + + >>> from sympy.physics.secondquant import VarBosonicBasis + >>> b = VarBosonicBasis(5) + >>> b + [FockState((0,)), FockState((1,)), FockState((2,)), + FockState((3,)), FockState((4,))] + """ + + def __init__(self, n_max): + self.n_max = n_max + self._build_states() + + def _build_states(self): + self.basis = [] + for i in range(self.n_max): + self.basis.append(FockStateBosonKet([i])) + self.n_basis = len(self.basis) + + def index(self, state): + """ + Returns the index of state in basis. + + Examples + ======== + + >>> from sympy.physics.secondquant import VarBosonicBasis + >>> b = VarBosonicBasis(3) + >>> state = b.state(1) + >>> b + [FockState((0,)), FockState((1,)), FockState((2,))] + >>> state + FockStateBosonKet((1,)) + >>> b.index(state) + 1 + """ + return self.basis.index(state) + + def state(self, i): + """ + The state of a single basis. + + Examples + ======== + + >>> from sympy.physics.secondquant import VarBosonicBasis + >>> b = VarBosonicBasis(5) + >>> b.state(3) + FockStateBosonKet((3,)) + """ + return self.basis[i] + + def __getitem__(self, i): + return self.state(i) + + def __len__(self): + return len(self.basis) + + def __repr__(self): + return repr(self.basis) + + +class FixedBosonicBasis(BosonicBasis): + """ + Fixed particle number basis set. + + Examples + ======== + + >>> from sympy.physics.secondquant import FixedBosonicBasis + >>> b = FixedBosonicBasis(2, 2) + >>> state = b.state(1) + >>> b + [FockState((2, 0)), FockState((1, 1)), FockState((0, 2))] + >>> state + FockStateBosonKet((1, 1)) + >>> b.index(state) + 1 + """ + def __init__(self, n_particles, n_levels): + self.n_particles = n_particles + self.n_levels = n_levels + self._build_particle_locations() + self._build_states() + + def _build_particle_locations(self): + tup = ["i%i" % i for i in range(self.n_particles)] + first_loop = "for i0 in range(%i)" % self.n_levels + other_loops = '' + for cur, prev in zip(tup[1:], tup): + temp = "for %s in range(%s + 1) " % (cur, prev) + other_loops = other_loops + temp + tup_string = "(%s)" % ", ".join(tup) + list_comp = "[%s %s %s]" % (tup_string, first_loop, other_loops) + result = eval(list_comp) + if self.n_particles == 1: + result = [(item,) for item in result] + self.particle_locations = result + + def _build_states(self): + self.basis = [] + for tuple_of_indices in self.particle_locations: + occ_numbers = self.n_levels*[0] + for level in tuple_of_indices: + occ_numbers[level] += 1 + self.basis.append(FockStateBosonKet(occ_numbers)) + self.n_basis = len(self.basis) + + def index(self, state): + """Returns the index of state in basis. + + Examples + ======== + + >>> from sympy.physics.secondquant import FixedBosonicBasis + >>> b = FixedBosonicBasis(2, 3) + >>> b.index(b.state(3)) + 3 + """ + return self.basis.index(state) + + def state(self, i): + """Returns the state that lies at index i of the basis + + Examples + ======== + + >>> from sympy.physics.secondquant import FixedBosonicBasis + >>> b = FixedBosonicBasis(2, 3) + >>> b.state(3) + FockStateBosonKet((1, 0, 1)) + """ + return self.basis[i] + + def __getitem__(self, i): + return self.state(i) + + def __len__(self): + return len(self.basis) + + def __repr__(self): + return repr(self.basis) + + +class Commutator(Function): + """ + The Commutator: [A, B] = A*B - B*A + + The arguments are ordered according to .__cmp__() + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import Commutator + >>> A, B = symbols('A,B', commutative=False) + >>> Commutator(B, A) + -Commutator(A, B) + + Evaluate the commutator with .doit() + + >>> comm = Commutator(A,B); comm + Commutator(A, B) + >>> comm.doit() + A*B - B*A + + + For two second quantization operators the commutator is evaluated + immediately: + + >>> from sympy.physics.secondquant import Fd, F + >>> a = symbols('a', above_fermi=True) + >>> i = symbols('i', below_fermi=True) + >>> p,q = symbols('p,q') + + >>> Commutator(Fd(a),Fd(i)) + 2*NO(CreateFermion(a)*CreateFermion(i)) + + But for more complicated expressions, the evaluation is triggered by + a call to .doit() + + >>> comm = Commutator(Fd(p)*Fd(q),F(i)); comm + Commutator(CreateFermion(p)*CreateFermion(q), AnnihilateFermion(i)) + >>> comm.doit(wicks=True) + -KroneckerDelta(i, p)*CreateFermion(q) + + KroneckerDelta(i, q)*CreateFermion(p) + + """ + + is_commutative = False + + @classmethod + def eval(cls, a, b): + """ + The Commutator [A,B] is on canonical form if A < B. + + Examples + ======== + + >>> from sympy.physics.secondquant import Commutator, F, Fd + >>> from sympy.abc import x + >>> c1 = Commutator(F(x), Fd(x)) + >>> c2 = Commutator(Fd(x), F(x)) + >>> Commutator.eval(c1, c2) + 0 + """ + if not (a and b): + return S.Zero + if a == b: + return S.Zero + if a.is_commutative or b.is_commutative: + return S.Zero + + # + # [A+B,C] -> [A,C] + [B,C] + # + a = a.expand() + if isinstance(a, Add): + return Add(*[cls(term, b) for term in a.args]) + b = b.expand() + if isinstance(b, Add): + return Add(*[cls(a, term) for term in b.args]) + + # + # [xA,yB] -> xy*[A,B] + # + ca, nca = a.args_cnc() + cb, ncb = b.args_cnc() + c_part = list(ca) + list(cb) + if c_part: + return Mul(Mul(*c_part), cls(Mul._from_args(nca), Mul._from_args(ncb))) + + # + # single second quantization operators + # + if isinstance(a, BosonicOperator) and isinstance(b, BosonicOperator): + if isinstance(b, CreateBoson) and isinstance(a, AnnihilateBoson): + return KroneckerDelta(a.state, b.state) + if isinstance(a, CreateBoson) and isinstance(b, AnnihilateBoson): + return S.NegativeOne*KroneckerDelta(a.state, b.state) + else: + return S.Zero + if isinstance(a, FermionicOperator) and isinstance(b, FermionicOperator): + return wicks(a*b) - wicks(b*a) + + # + # Canonical ordering of arguments + # + if a.sort_key() > b.sort_key(): + return S.NegativeOne*cls(b, a) + + def doit(self, **hints): + """ + Enables the computation of complex expressions. + + Examples + ======== + + >>> from sympy.physics.secondquant import Commutator, F, Fd + >>> from sympy import symbols + >>> i, j = symbols('i,j', below_fermi=True) + >>> a, b = symbols('a,b', above_fermi=True) + >>> c = Commutator(Fd(a)*F(i),Fd(b)*F(j)) + >>> c.doit(wicks=True) + 0 + """ + a = self.args[0] + b = self.args[1] + + if hints.get("wicks"): + a = a.doit(**hints) + b = b.doit(**hints) + try: + return wicks(a*b) - wicks(b*a) + except ContractionAppliesOnlyToFermions: + pass + except WicksTheoremDoesNotApply: + pass + + return (a*b - b*a).doit(**hints) + + def __repr__(self): + return "Commutator(%s,%s)" % (self.args[0], self.args[1]) + + def __str__(self): + return "[%s,%s]" % (self.args[0], self.args[1]) + + def _latex(self, printer): + return "\\left[%s,%s\\right]" % tuple([ + printer._print(arg) for arg in self.args]) + + +class NO(Expr): + """ + This Object is used to represent normal ordering brackets. + + i.e. {abcd} sometimes written :abcd: + + Explanation + =========== + + Applying the function NO(arg) to an argument means that all operators in + the argument will be assumed to anticommute, and have vanishing + contractions. This allows an immediate reordering to canonical form + upon object creation. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import NO, F, Fd + >>> p,q = symbols('p,q') + >>> NO(Fd(p)*F(q)) + NO(CreateFermion(p)*AnnihilateFermion(q)) + >>> NO(F(q)*Fd(p)) + -NO(CreateFermion(p)*AnnihilateFermion(q)) + + + Note + ==== + + If you want to generate a normal ordered equivalent of an expression, you + should use the function wicks(). This class only indicates that all + operators inside the brackets anticommute, and have vanishing contractions. + Nothing more, nothing less. + + """ + is_commutative = False + + def __new__(cls, arg): + """ + Use anticommutation to get canonical form of operators. + + Explanation + =========== + + Employ associativity of normal ordered product: {ab{cd}} = {abcd} + but note that {ab}{cd} /= {abcd}. + + We also employ distributivity: {ab + cd} = {ab} + {cd}. + + Canonical form also implies expand() {ab(c+d)} = {abc} + {abd}. + + """ + + # {ab + cd} = {ab} + {cd} + arg = sympify(arg) + arg = arg.expand() + if arg.is_Add: + return Add(*[ cls(term) for term in arg.args]) + + if arg.is_Mul: + + # take coefficient outside of normal ordering brackets + c_part, seq = arg.args_cnc() + if c_part: + coeff = Mul(*c_part) + if not seq: + return coeff + else: + coeff = S.One + + # {ab{cd}} = {abcd} + newseq = [] + foundit = False + for fac in seq: + if isinstance(fac, NO): + newseq.extend(fac.args) + foundit = True + else: + newseq.append(fac) + if foundit: + return coeff*cls(Mul(*newseq)) + + # We assume that the user don't mix B and F operators + if isinstance(seq[0], BosonicOperator): + raise NotImplementedError + + try: + newseq, sign = _sort_anticommuting_fermions(seq) + except ViolationOfPauliPrinciple: + return S.Zero + + if sign % 2: + return (S.NegativeOne*coeff)*cls(Mul(*newseq)) + elif sign: + return coeff*cls(Mul(*newseq)) + else: + pass # since sign==0, no permutations was necessary + + # if we couldn't do anything with Mul object, we just + # mark it as normal ordered + if coeff != S.One: + return coeff*cls(Mul(*newseq)) + return Expr.__new__(cls, Mul(*newseq)) + + if isinstance(arg, NO): + return arg + + # if object was not Mul or Add, normal ordering does not apply + return arg + + @property + def has_q_creators(self): + """ + Return 0 if the leftmost argument of the first argument is a not a + q_creator, else 1 if it is above fermi or -1 if it is below fermi. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import NO, F, Fd + + >>> a = symbols('a', above_fermi=True) + >>> i = symbols('i', below_fermi=True) + >>> NO(Fd(a)*Fd(i)).has_q_creators + 1 + >>> NO(F(i)*F(a)).has_q_creators + -1 + >>> NO(Fd(i)*F(a)).has_q_creators #doctest: +SKIP + 0 + + """ + return self.args[0].args[0].is_q_creator + + @property + def has_q_annihilators(self): + """ + Return 0 if the rightmost argument of the first argument is a not a + q_annihilator, else 1 if it is above fermi or -1 if it is below fermi. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import NO, F, Fd + + >>> a = symbols('a', above_fermi=True) + >>> i = symbols('i', below_fermi=True) + >>> NO(Fd(a)*Fd(i)).has_q_annihilators + -1 + >>> NO(F(i)*F(a)).has_q_annihilators + 1 + >>> NO(Fd(a)*F(i)).has_q_annihilators + 0 + + """ + return self.args[0].args[-1].is_q_annihilator + + def doit(self, **hints): + """ + Either removes the brackets or enables complex computations + in its arguments. + + Examples + ======== + + >>> from sympy.physics.secondquant import NO, Fd, F + >>> from textwrap import fill + >>> from sympy import symbols, Dummy + >>> p,q = symbols('p,q', cls=Dummy) + >>> print(fill(str(NO(Fd(p)*F(q)).doit()))) + KroneckerDelta(_a, _p)*KroneckerDelta(_a, + _q)*CreateFermion(_a)*AnnihilateFermion(_a) + KroneckerDelta(_a, + _p)*KroneckerDelta(_i, _q)*CreateFermion(_a)*AnnihilateFermion(_i) - + KroneckerDelta(_a, _q)*KroneckerDelta(_i, + _p)*AnnihilateFermion(_a)*CreateFermion(_i) - KroneckerDelta(_i, + _p)*KroneckerDelta(_i, _q)*AnnihilateFermion(_i)*CreateFermion(_i) + """ + if hints.get("remove_brackets", True): + return self._remove_brackets() + else: + return self.__new__(type(self), self.args[0].doit(**hints)) + + def _remove_brackets(self): + """ + Returns the sorted string without normal order brackets. + + The returned string have the property that no nonzero + contractions exist. + """ + + # check if any creator is also an annihilator + subslist = [] + for i in self.iter_q_creators(): + if self[i].is_q_annihilator: + assume = self[i].state.assumptions0 + + # only operators with a dummy index can be split in two terms + if isinstance(self[i].state, Dummy): + + # create indices with fermi restriction + assume.pop("above_fermi", None) + assume["below_fermi"] = True + below = Dummy('i', **assume) + assume.pop("below_fermi", None) + assume["above_fermi"] = True + above = Dummy('a', **assume) + + cls = type(self[i]) + split = ( + self[i].__new__(cls, below) + * KroneckerDelta(below, self[i].state) + + self[i].__new__(cls, above) + * KroneckerDelta(above, self[i].state) + ) + subslist.append((self[i], split)) + else: + raise SubstitutionOfAmbigousOperatorFailed(self[i]) + if subslist: + result = NO(self.subs(subslist)) + if isinstance(result, Add): + return Add(*[term.doit() for term in result.args]) + else: + return self.args[0] + + def _expand_operators(self): + """ + Returns a sum of NO objects that contain no ambiguous q-operators. + + Explanation + =========== + + If an index q has range both above and below fermi, the operator F(q) + is ambiguous in the sense that it can be both a q-creator and a q-annihilator. + If q is dummy, it is assumed to be a summation variable and this method + rewrites it into a sum of NO terms with unambiguous operators: + + {Fd(p)*F(q)} = {Fd(a)*F(b)} + {Fd(a)*F(i)} + {Fd(j)*F(b)} -{F(i)*Fd(j)} + + where a,b are above and i,j are below fermi level. + """ + return NO(self._remove_brackets) + + def __getitem__(self, i): + if isinstance(i, slice): + indices = i.indices(len(self)) + return [self.args[0].args[i] for i in range(*indices)] + else: + return self.args[0].args[i] + + def __len__(self): + return len(self.args[0].args) + + def iter_q_annihilators(self): + """ + Iterates over the annihilation operators. + + Examples + ======== + + >>> from sympy import symbols + >>> i, j = symbols('i j', below_fermi=True) + >>> a, b = symbols('a b', above_fermi=True) + >>> from sympy.physics.secondquant import NO, F, Fd + >>> no = NO(Fd(a)*F(i)*F(b)*Fd(j)) + + >>> no.iter_q_creators() + + >>> list(no.iter_q_creators()) + [0, 1] + >>> list(no.iter_q_annihilators()) + [3, 2] + + """ + ops = self.args[0].args + iter = range(len(ops) - 1, -1, -1) + for i in iter: + if ops[i].is_q_annihilator: + yield i + else: + break + + def iter_q_creators(self): + """ + Iterates over the creation operators. + + Examples + ======== + + >>> from sympy import symbols + >>> i, j = symbols('i j', below_fermi=True) + >>> a, b = symbols('a b', above_fermi=True) + >>> from sympy.physics.secondquant import NO, F, Fd + >>> no = NO(Fd(a)*F(i)*F(b)*Fd(j)) + + >>> no.iter_q_creators() + + >>> list(no.iter_q_creators()) + [0, 1] + >>> list(no.iter_q_annihilators()) + [3, 2] + + """ + + ops = self.args[0].args + iter = range(0, len(ops)) + for i in iter: + if ops[i].is_q_creator: + yield i + else: + break + + def get_subNO(self, i): + """ + Returns a NO() without FermionicOperator at index i. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import F, NO + >>> p, q, r = symbols('p,q,r') + + >>> NO(F(p)*F(q)*F(r)).get_subNO(1) + NO(AnnihilateFermion(p)*AnnihilateFermion(r)) + + """ + arg0 = self.args[0] # it's a Mul by definition of how it's created + mul = arg0._new_rawargs(*(arg0.args[:i] + arg0.args[i + 1:])) + return NO(mul) + + def _latex(self, printer): + return "\\left\\{%s\\right\\}" % printer._print(self.args[0]) + + def __repr__(self): + return "NO(%s)" % self.args[0] + + def __str__(self): + return ":%s:" % self.args[0] + + +def contraction(a, b): + """ + Calculates contraction of Fermionic operators a and b. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.secondquant import F, Fd, contraction + >>> p, q = symbols('p,q') + >>> a, b = symbols('a,b', above_fermi=True) + >>> i, j = symbols('i,j', below_fermi=True) + + A contraction is non-zero only if a quasi-creator is to the right of a + quasi-annihilator: + + >>> contraction(F(a),Fd(b)) + KroneckerDelta(a, b) + >>> contraction(Fd(i),F(j)) + KroneckerDelta(i, j) + + For general indices a non-zero result restricts the indices to below/above + the fermi surface: + + >>> contraction(Fd(p),F(q)) + KroneckerDelta(_i, q)*KroneckerDelta(p, q) + >>> contraction(F(p),Fd(q)) + KroneckerDelta(_a, q)*KroneckerDelta(p, q) + + Two creators or two annihilators always vanishes: + + >>> contraction(F(p),F(q)) + 0 + >>> contraction(Fd(p),Fd(q)) + 0 + + """ + if isinstance(b, FermionicOperator) and isinstance(a, FermionicOperator): + if isinstance(a, AnnihilateFermion) and isinstance(b, CreateFermion): + if b.state.assumptions0.get("below_fermi"): + return S.Zero + if a.state.assumptions0.get("below_fermi"): + return S.Zero + if b.state.assumptions0.get("above_fermi"): + return KroneckerDelta(a.state, b.state) + if a.state.assumptions0.get("above_fermi"): + return KroneckerDelta(a.state, b.state) + + return (KroneckerDelta(a.state, b.state)* + KroneckerDelta(b.state, Dummy('a', above_fermi=True))) + if isinstance(b, AnnihilateFermion) and isinstance(a, CreateFermion): + if b.state.assumptions0.get("above_fermi"): + return S.Zero + if a.state.assumptions0.get("above_fermi"): + return S.Zero + if b.state.assumptions0.get("below_fermi"): + return KroneckerDelta(a.state, b.state) + if a.state.assumptions0.get("below_fermi"): + return KroneckerDelta(a.state, b.state) + + return (KroneckerDelta(a.state, b.state)* + KroneckerDelta(b.state, Dummy('i', below_fermi=True))) + + # vanish if 2xAnnihilator or 2xCreator + return S.Zero + + else: + #not fermion operators + t = ( isinstance(i, FermionicOperator) for i in (a, b) ) + raise ContractionAppliesOnlyToFermions(*t) + + +def _sqkey_operator(sq_operator): + """Generates key for canonical sorting of SQ operators.""" + return sq_operator._sortkey() + +def _sqkey_index(index): + """Key for sorting of indices. + + particle < hole < general + + FIXME: This is a bottle-neck, can we do it faster? + """ + h = hash(index) + label = str(index) + if isinstance(index, Dummy): + if index.assumptions0.get('above_fermi'): + return (20, label, h) + elif index.assumptions0.get('below_fermi'): + return (21, label, h) + else: + return (22, label, h) + + if index.assumptions0.get('above_fermi'): + return (10, label, h) + elif index.assumptions0.get('below_fermi'): + return (11, label, h) + else: + return (12, label, h) + + + +def _sort_anticommuting_fermions(string1, key=_sqkey_operator): + """Sort fermionic operators to canonical order, assuming all pairs anticommute. + + Explanation + =========== + + Uses a bidirectional bubble sort. Items in string1 are not referenced + so in principle they may be any comparable objects. The sorting depends on the + operators '>' and '=='. + + If the Pauli principle is violated, an exception is raised. + + Returns + ======= + + tuple (sorted_str, sign) + + sorted_str: list containing the sorted operators + sign: int telling how many times the sign should be changed + (if sign==0 the string was already sorted) + """ + + verified = False + sign = 0 + rng = list(range(len(string1) - 1)) + rev = list(range(len(string1) - 3, -1, -1)) + + keys = list(map(key, string1)) + key_val = dict(list(zip(keys, string1))) + + while not verified: + verified = True + for i in rng: + left = keys[i] + right = keys[i + 1] + if left == right: + raise ViolationOfPauliPrinciple([left, right]) + if left > right: + verified = False + keys[i:i + 2] = [right, left] + sign = sign + 1 + if verified: + break + for i in rev: + left = keys[i] + right = keys[i + 1] + if left == right: + raise ViolationOfPauliPrinciple([left, right]) + if left > right: + verified = False + keys[i:i + 2] = [right, left] + sign = sign + 1 + string1 = [ key_val[k] for k in keys ] + return (string1, sign) + + +def evaluate_deltas(e): + """ + We evaluate KroneckerDelta symbols in the expression assuming Einstein summation. + + Explanation + =========== + + If one index is repeated it is summed over and in effect substituted with + the other one. If both indices are repeated we substitute according to what + is the preferred index. this is determined by + KroneckerDelta.preferred_index and KroneckerDelta.killable_index. + + In case there are no possible substitutions or if a substitution would + imply a loss of information, nothing is done. + + In case an index appears in more than one KroneckerDelta, the resulting + substitution depends on the order of the factors. Since the ordering is platform + dependent, the literal expression resulting from this function may be hard to + predict. + + Examples + ======== + + We assume the following: + + >>> from sympy import symbols, Function, Dummy, KroneckerDelta + >>> from sympy.physics.secondquant import evaluate_deltas + >>> i,j = symbols('i j', below_fermi=True, cls=Dummy) + >>> a,b = symbols('a b', above_fermi=True, cls=Dummy) + >>> p,q = symbols('p q', cls=Dummy) + >>> f = Function('f') + >>> t = Function('t') + + The order of preference for these indices according to KroneckerDelta is + (a, b, i, j, p, q). + + Trivial cases: + + >>> evaluate_deltas(KroneckerDelta(i,j)*f(i)) # d_ij f(i) -> f(j) + f(_j) + >>> evaluate_deltas(KroneckerDelta(i,j)*f(j)) # d_ij f(j) -> f(i) + f(_i) + >>> evaluate_deltas(KroneckerDelta(i,p)*f(p)) # d_ip f(p) -> f(i) + f(_i) + >>> evaluate_deltas(KroneckerDelta(q,p)*f(p)) # d_qp f(p) -> f(q) + f(_q) + >>> evaluate_deltas(KroneckerDelta(q,p)*f(q)) # d_qp f(q) -> f(p) + f(_p) + + More interesting cases: + + >>> evaluate_deltas(KroneckerDelta(i,p)*t(a,i)*f(p,q)) + f(_i, _q)*t(_a, _i) + >>> evaluate_deltas(KroneckerDelta(a,p)*t(a,i)*f(p,q)) + f(_a, _q)*t(_a, _i) + >>> evaluate_deltas(KroneckerDelta(p,q)*f(p,q)) + f(_p, _p) + + Finally, here are some cases where nothing is done, because that would + imply a loss of information: + + >>> evaluate_deltas(KroneckerDelta(i,p)*f(q)) + f(_q)*KroneckerDelta(_i, _p) + >>> evaluate_deltas(KroneckerDelta(i,p)*f(i)) + f(_i)*KroneckerDelta(_i, _p) + """ + + # We treat Deltas only in mul objects + # for general function objects we don't evaluate KroneckerDeltas in arguments, + # but here we hard code exceptions to this rule + accepted_functions = ( + Add, + ) + if isinstance(e, accepted_functions): + return e.func(*[evaluate_deltas(arg) for arg in e.args]) + + elif isinstance(e, Mul): + # find all occurrences of delta function and count each index present in + # expression. + deltas = [] + indices = {} + for i in e.args: + for s in i.free_symbols: + if s in indices: + indices[s] += 1 + else: + indices[s] = 0 # geek counting simplifies logic below + if isinstance(i, KroneckerDelta): + deltas.append(i) + + for d in deltas: + # If we do something, and there are more deltas, we should recurse + # to treat the resulting expression properly + if d.killable_index.is_Symbol and indices[d.killable_index]: + e = e.subs(d.killable_index, d.preferred_index) + if len(deltas) > 1: + return evaluate_deltas(e) + elif (d.preferred_index.is_Symbol and indices[d.preferred_index] + and d.indices_contain_equal_information): + e = e.subs(d.preferred_index, d.killable_index) + if len(deltas) > 1: + return evaluate_deltas(e) + else: + pass + + return e + # nothing to do, maybe we hit a Symbol or a number + else: + return e + + +def substitute_dummies(expr, new_indices=False, pretty_indices={}): + """ + Collect terms by substitution of dummy variables. + + Explanation + =========== + + This routine allows simplification of Add expressions containing terms + which differ only due to dummy variables. + + The idea is to substitute all dummy variables consistently depending on + the structure of the term. For each term, we obtain a sequence of all + dummy variables, where the order is determined by the index range, what + factors the index belongs to and its position in each factor. See + _get_ordered_dummies() for more information about the sorting of dummies. + The index sequence is then substituted consistently in each term. + + Examples + ======== + + >>> from sympy import symbols, Function, Dummy + >>> from sympy.physics.secondquant import substitute_dummies + >>> a,b,c,d = symbols('a b c d', above_fermi=True, cls=Dummy) + >>> i,j = symbols('i j', below_fermi=True, cls=Dummy) + >>> f = Function('f') + + >>> expr = f(a,b) + f(c,d); expr + f(_a, _b) + f(_c, _d) + + Since a, b, c and d are equivalent summation indices, the expression can be + simplified to a single term (for which the dummy indices are still summed over) + + >>> substitute_dummies(expr) + 2*f(_a, _b) + + + Controlling output: + + By default the dummy symbols that are already present in the expression + will be reused in a different permutation. However, if new_indices=True, + new dummies will be generated and inserted. The keyword 'pretty_indices' + can be used to control this generation of new symbols. + + By default the new dummies will be generated on the form i_1, i_2, a_1, + etc. If you supply a dictionary with key:value pairs in the form: + + { index_group: string_of_letters } + + The letters will be used as labels for the new dummy symbols. The + index_groups must be one of 'above', 'below' or 'general'. + + >>> expr = f(a,b,i,j) + >>> my_dummies = { 'above':'st', 'below':'uv' } + >>> substitute_dummies(expr, new_indices=True, pretty_indices=my_dummies) + f(_s, _t, _u, _v) + + If we run out of letters, or if there is no keyword for some index_group + the default dummy generator will be used as a fallback: + + >>> p,q = symbols('p q', cls=Dummy) # general indices + >>> expr = f(p,q) + >>> substitute_dummies(expr, new_indices=True, pretty_indices=my_dummies) + f(_p_0, _p_1) + + """ + + # setup the replacing dummies + if new_indices: + letters_above = pretty_indices.get('above', "") + letters_below = pretty_indices.get('below', "") + letters_general = pretty_indices.get('general', "") + len_above = len(letters_above) + len_below = len(letters_below) + len_general = len(letters_general) + + def _i(number): + try: + return letters_below[number] + except IndexError: + return 'i_' + str(number - len_below) + + def _a(number): + try: + return letters_above[number] + except IndexError: + return 'a_' + str(number - len_above) + + def _p(number): + try: + return letters_general[number] + except IndexError: + return 'p_' + str(number - len_general) + + aboves = [] + belows = [] + generals = [] + + dummies = expr.atoms(Dummy) + if not new_indices: + dummies = sorted(dummies, key=default_sort_key) + + # generate lists with the dummies we will insert + a = i = p = 0 + for d in dummies: + assum = d.assumptions0 + + if assum.get("above_fermi"): + if new_indices: + sym = _a(a) + a += 1 + l1 = aboves + elif assum.get("below_fermi"): + if new_indices: + sym = _i(i) + i += 1 + l1 = belows + else: + if new_indices: + sym = _p(p) + p += 1 + l1 = generals + + if new_indices: + l1.append(Dummy(sym, **assum)) + else: + l1.append(d) + + expr = expr.expand() + terms = Add.make_args(expr) + new_terms = [] + for term in terms: + i = iter(belows) + a = iter(aboves) + p = iter(generals) + ordered = _get_ordered_dummies(term) + subsdict = {} + for d in ordered: + if d.assumptions0.get('below_fermi'): + subsdict[d] = next(i) + elif d.assumptions0.get('above_fermi'): + subsdict[d] = next(a) + else: + subsdict[d] = next(p) + subslist = [] + final_subs = [] + for k, v in subsdict.items(): + if k == v: + continue + if v in subsdict: + # We check if the sequence of substitutions end quickly. In + # that case, we can avoid temporary symbols if we ensure the + # correct substitution order. + if subsdict[v] in subsdict: + # (x, y) -> (y, x), we need a temporary variable + x = Dummy('x') + subslist.append((k, x)) + final_subs.append((x, v)) + else: + # (x, y) -> (y, a), x->y must be done last + # but before temporary variables are resolved + final_subs.insert(0, (k, v)) + else: + subslist.append((k, v)) + subslist.extend(final_subs) + new_terms.append(term.subs(subslist)) + return Add(*new_terms) + + +class KeyPrinter(StrPrinter): + """Printer for which only equal objects are equal in print""" + def _print_Dummy(self, expr): + return "(%s_%i)" % (expr.name, expr.dummy_index) + + +def __kprint(expr): + p = KeyPrinter() + return p.doprint(expr) + + +def _get_ordered_dummies(mul, verbose=False): + """Returns all dummies in the mul sorted in canonical order. + + Explanation + =========== + + The purpose of the canonical ordering is that dummies can be substituted + consistently across terms with the result that equivalent terms can be + simplified. + + It is not possible to determine if two terms are equivalent based solely on + the dummy order. However, a consistent substitution guided by the ordered + dummies should lead to trivially (non-)equivalent terms, thereby revealing + the equivalence. This also means that if two terms have identical sequences of + dummies, the (non-)equivalence should already be apparent. + + Strategy + -------- + + The canonical order is given by an arbitrary sorting rule. A sort key + is determined for each dummy as a tuple that depends on all factors where + the index is present. The dummies are thereby sorted according to the + contraction structure of the term, instead of sorting based solely on the + dummy symbol itself. + + After all dummies in the term has been assigned a key, we check for identical + keys, i.e. unorderable dummies. If any are found, we call a specialized + method, _determine_ambiguous(), that will determine a unique order based + on recursive calls to _get_ordered_dummies(). + + Key description + --------------- + + A high level description of the sort key: + + 1. Range of the dummy index + 2. Relation to external (non-dummy) indices + 3. Position of the index in the first factor + 4. Position of the index in the second factor + + The sort key is a tuple with the following components: + + 1. A single character indicating the range of the dummy (above, below + or general.) + 2. A list of strings with fully masked string representations of all + factors where the dummy is present. By masked, we mean that dummies + are represented by a symbol to indicate either below fermi, above or + general. No other information is displayed about the dummies at + this point. The list is sorted stringwise. + 3. An integer number indicating the position of the index, in the first + factor as sorted in 2. + 4. An integer number indicating the position of the index, in the second + factor as sorted in 2. + + If a factor is either of type AntiSymmetricTensor or SqOperator, the index + position in items 3 and 4 is indicated as 'upper' or 'lower' only. + (Creation operators are considered upper and annihilation operators lower.) + + If the masked factors are identical, the two factors cannot be ordered + unambiguously in item 2. In this case, items 3, 4 are left out. If several + indices are contracted between the unorderable factors, it will be handled by + _determine_ambiguous() + + + """ + # setup dicts to avoid repeated calculations in key() + args = Mul.make_args(mul) + fac_dum = { fac: fac.atoms(Dummy) for fac in args } + fac_repr = { fac: __kprint(fac) for fac in args } + all_dums = set().union(*fac_dum.values()) + mask = {} + for d in all_dums: + if d.assumptions0.get('below_fermi'): + mask[d] = '0' + elif d.assumptions0.get('above_fermi'): + mask[d] = '1' + else: + mask[d] = '2' + dum_repr = {d: __kprint(d) for d in all_dums} + + def _key(d): + dumstruct = [ fac for fac in fac_dum if d in fac_dum[fac] ] + other_dums = set().union(*[fac_dum[fac] for fac in dumstruct]) + fac = dumstruct[-1] + if other_dums is fac_dum[fac]: + other_dums = fac_dum[fac].copy() + other_dums.remove(d) + masked_facs = [ fac_repr[fac] for fac in dumstruct ] + for d2 in other_dums: + masked_facs = [ fac.replace(dum_repr[d2], mask[d2]) + for fac in masked_facs ] + all_masked = [ fac.replace(dum_repr[d], mask[d]) + for fac in masked_facs ] + masked_facs = dict(list(zip(dumstruct, masked_facs))) + + # dummies for which the ordering cannot be determined + if has_dups(all_masked): + all_masked.sort() + return mask[d], tuple(all_masked) # positions are ambiguous + + # sort factors according to fully masked strings + keydict = dict(list(zip(dumstruct, all_masked))) + dumstruct.sort(key=lambda x: keydict[x]) + all_masked.sort() + + pos_val = [] + for fac in dumstruct: + if isinstance(fac, AntiSymmetricTensor): + if d in fac.upper: + pos_val.append('u') + if d in fac.lower: + pos_val.append('l') + elif isinstance(fac, Creator): + pos_val.append('u') + elif isinstance(fac, Annihilator): + pos_val.append('l') + elif isinstance(fac, NO): + ops = [ op for op in fac if op.has(d) ] + for op in ops: + if isinstance(op, Creator): + pos_val.append('u') + else: + pos_val.append('l') + else: + # fallback to position in string representation + facpos = -1 + while 1: + facpos = masked_facs[fac].find(dum_repr[d], facpos + 1) + if facpos == -1: + break + pos_val.append(facpos) + return (mask[d], tuple(all_masked), pos_val[0], pos_val[-1]) + dumkey = dict(list(zip(all_dums, list(map(_key, all_dums))))) + result = sorted(all_dums, key=lambda x: dumkey[x]) + if has_dups(iter(dumkey.values())): + # We have ambiguities + unordered = defaultdict(set) + for d, k in dumkey.items(): + unordered[k].add(d) + for k in [ k for k in unordered if len(unordered[k]) < 2 ]: + del unordered[k] + + unordered = [ unordered[k] for k in sorted(unordered) ] + result = _determine_ambiguous(mul, result, unordered) + return result + + +def _determine_ambiguous(term, ordered, ambiguous_groups): + # We encountered a term for which the dummy substitution is ambiguous. + # This happens for terms with 2 or more contractions between factors that + # cannot be uniquely ordered independent of summation indices. For + # example: + # + # Sum(p, q) v^{p, .}_{q, .}v^{q, .}_{p, .} + # + # Assuming that the indices represented by . are dummies with the + # same range, the factors cannot be ordered, and there is no + # way to determine a consistent ordering of p and q. + # + # The strategy employed here, is to relabel all unambiguous dummies with + # non-dummy symbols and call _get_ordered_dummies again. This procedure is + # applied to the entire term so there is a possibility that + # _determine_ambiguous() is called again from a deeper recursion level. + + # break recursion if there are no ordered dummies + all_ambiguous = set() + for dummies in ambiguous_groups: + all_ambiguous |= dummies + all_ordered = set(ordered) - all_ambiguous + if not all_ordered: + # FIXME: If we arrive here, there are no ordered dummies. A method to + # handle this needs to be implemented. In order to return something + # useful nevertheless, we choose arbitrarily the first dummy and + # determine the rest from this one. This method is dependent on the + # actual dummy labels which violates an assumption for the + # canonicalization procedure. A better implementation is needed. + group = [ d for d in ordered if d in ambiguous_groups[0] ] + d = group[0] + all_ordered.add(d) + ambiguous_groups[0].remove(d) + + stored_counter = _symbol_factory._counter + subslist = [] + for d in [ d for d in ordered if d in all_ordered ]: + nondum = _symbol_factory._next() + subslist.append((d, nondum)) + newterm = term.subs(subslist) + neworder = _get_ordered_dummies(newterm) + _symbol_factory._set_counter(stored_counter) + + # update ordered list with new information + for group in ambiguous_groups: + ordered_group = [ d for d in neworder if d in group ] + ordered_group.reverse() + result = [] + for d in ordered: + if d in group: + result.append(ordered_group.pop()) + else: + result.append(d) + ordered = result + return ordered + + +class _SymbolFactory: + def __init__(self, label): + self._counterVar = 0 + self._label = label + + def _set_counter(self, value): + """ + Sets counter to value. + """ + self._counterVar = value + + @property + def _counter(self): + """ + What counter is currently at. + """ + return self._counterVar + + def _next(self): + """ + Generates the next symbols and increments counter by 1. + """ + s = Symbol("%s%i" % (self._label, self._counterVar)) + self._counterVar += 1 + return s +_symbol_factory = _SymbolFactory('_]"]_') # most certainly a unique label + + +@cacheit +def _get_contractions(string1, keep_only_fully_contracted=False): + """ + Returns Add-object with contracted terms. + + Uses recursion to find all contractions. -- Internal helper function -- + + Will find nonzero contractions in string1 between indices given in + leftrange and rightrange. + + """ + + # Should we store current level of contraction? + if keep_only_fully_contracted and string1: + result = [] + else: + result = [NO(Mul(*string1))] + + for i in range(len(string1) - 1): + for j in range(i + 1, len(string1)): + + c = contraction(string1[i], string1[j]) + + if c: + sign = (j - i + 1) % 2 + if sign: + coeff = S.NegativeOne*c + else: + coeff = c + + # + # Call next level of recursion + # ============================ + # + # We now need to find more contractions among operators + # + # oplist = string1[:i]+ string1[i+1:j] + string1[j+1:] + # + # To prevent overcounting, we don't allow contractions + # we have already encountered. i.e. contractions between + # string1[:i] <---> string1[i+1:j] + # and string1[:i] <---> string1[j+1:]. + # + # This leaves the case: + oplist = string1[i + 1:j] + string1[j + 1:] + + if oplist: + + result.append(coeff*NO( + Mul(*string1[:i])*_get_contractions( oplist, + keep_only_fully_contracted=keep_only_fully_contracted))) + + else: + result.append(coeff*NO( Mul(*string1[:i]))) + + if keep_only_fully_contracted: + break # next iteration over i leaves leftmost operator string1[0] uncontracted + + return Add(*result) + + +def wicks(e, **kw_args): + """ + Returns the normal ordered equivalent of an expression using Wicks Theorem. + + Examples + ======== + + >>> from sympy import symbols, Dummy + >>> from sympy.physics.secondquant import wicks, F, Fd + >>> p, q, r = symbols('p,q,r') + >>> wicks(Fd(p)*F(q)) + KroneckerDelta(_i, q)*KroneckerDelta(p, q) + NO(CreateFermion(p)*AnnihilateFermion(q)) + + By default, the expression is expanded: + + >>> wicks(F(p)*(F(q)+F(r))) + NO(AnnihilateFermion(p)*AnnihilateFermion(q)) + NO(AnnihilateFermion(p)*AnnihilateFermion(r)) + + With the keyword 'keep_only_fully_contracted=True', only fully contracted + terms are returned. + + By request, the result can be simplified in the following order: + -- KroneckerDelta functions are evaluated + -- Dummy variables are substituted consistently across terms + + >>> p, q, r = symbols('p q r', cls=Dummy) + >>> wicks(Fd(p)*(F(q)+F(r)), keep_only_fully_contracted=True) + KroneckerDelta(_i, _q)*KroneckerDelta(_p, _q) + KroneckerDelta(_i, _r)*KroneckerDelta(_p, _r) + + """ + + if not e: + return S.Zero + + opts = { + 'simplify_kronecker_deltas': False, + 'expand': True, + 'simplify_dummies': False, + 'keep_only_fully_contracted': False + } + opts.update(kw_args) + + # check if we are already normally ordered + if isinstance(e, NO): + if opts['keep_only_fully_contracted']: + return S.Zero + else: + return e + elif isinstance(e, FermionicOperator): + if opts['keep_only_fully_contracted']: + return S.Zero + else: + return e + + # break up any NO-objects, and evaluate commutators + e = e.doit(wicks=True) + + # make sure we have only one term to consider + e = e.expand() + if isinstance(e, Add): + if opts['simplify_dummies']: + return substitute_dummies(Add(*[ wicks(term, **kw_args) for term in e.args])) + else: + return Add(*[ wicks(term, **kw_args) for term in e.args]) + + # For Mul-objects we can actually do something + if isinstance(e, Mul): + + # we don't want to mess around with commuting part of Mul + # so we factorize it out before starting recursion + c_part = [] + string1 = [] + for factor in e.args: + if factor.is_commutative: + c_part.append(factor) + else: + string1.append(factor) + n = len(string1) + + # catch trivial cases + if n == 0: + result = e + elif n == 1: + if opts['keep_only_fully_contracted']: + return S.Zero + else: + result = e + + else: # non-trivial + + if isinstance(string1[0], BosonicOperator): + raise NotImplementedError + + string1 = tuple(string1) + + # recursion over higher order contractions + result = _get_contractions(string1, + keep_only_fully_contracted=opts['keep_only_fully_contracted'] ) + result = Mul(*c_part)*result + + if opts['expand']: + result = result.expand() + if opts['simplify_kronecker_deltas']: + result = evaluate_deltas(result) + + return result + + # there was nothing to do + return e + + +class PermutationOperator(Expr): + """ + Represents the index permutation operator P(ij). + + P(ij)*f(i)*g(j) = f(i)*g(j) - f(j)*g(i) + """ + is_commutative = True + + def __new__(cls, i, j): + i, j = sorted(map(sympify, (i, j)), key=default_sort_key) + obj = Basic.__new__(cls, i, j) + return obj + + def get_permuted(self, expr): + """ + Returns -expr with permuted indices. + + Explanation + =========== + + >>> from sympy import symbols, Function + >>> from sympy.physics.secondquant import PermutationOperator + >>> p,q = symbols('p,q') + >>> f = Function('f') + >>> PermutationOperator(p,q).get_permuted(f(p,q)) + -f(q, p) + + """ + i = self.args[0] + j = self.args[1] + if expr.has(i) and expr.has(j): + tmp = Dummy() + expr = expr.subs(i, tmp) + expr = expr.subs(j, i) + expr = expr.subs(tmp, j) + return S.NegativeOne*expr + else: + return expr + + def _latex(self, printer): + return "P(%s%s)" % tuple(printer._print(i) for i in self.args) + + +def simplify_index_permutations(expr, permutation_operators): + """ + Performs simplification by introducing PermutationOperators where appropriate. + + Explanation + =========== + + Schematically: + [abij] - [abji] - [baij] + [baji] -> P(ab)*P(ij)*[abij] + + permutation_operators is a list of PermutationOperators to consider. + + If permutation_operators=[P(ab),P(ij)] we will try to introduce the + permutation operators P(ij) and P(ab) in the expression. If there are other + possible simplifications, we ignore them. + + >>> from sympy import symbols, Function + >>> from sympy.physics.secondquant import simplify_index_permutations + >>> from sympy.physics.secondquant import PermutationOperator + >>> p,q,r,s = symbols('p,q,r,s') + >>> f = Function('f') + >>> g = Function('g') + + >>> expr = f(p)*g(q) - f(q)*g(p); expr + f(p)*g(q) - f(q)*g(p) + >>> simplify_index_permutations(expr,[PermutationOperator(p,q)]) + f(p)*g(q)*PermutationOperator(p, q) + + >>> PermutList = [PermutationOperator(p,q),PermutationOperator(r,s)] + >>> expr = f(p,r)*g(q,s) - f(q,r)*g(p,s) + f(q,s)*g(p,r) - f(p,s)*g(q,r) + >>> simplify_index_permutations(expr,PermutList) + f(p, r)*g(q, s)*PermutationOperator(p, q)*PermutationOperator(r, s) + + """ + + def _get_indices(expr, ind): + """ + Collects indices recursively in predictable order. + """ + result = [] + for arg in expr.args: + if arg in ind: + result.append(arg) + else: + if arg.args: + result.extend(_get_indices(arg, ind)) + return result + + def _choose_one_to_keep(a, b, ind): + # we keep the one where indices in ind are in order ind[0] < ind[1] + return min(a, b, key=lambda x: default_sort_key(_get_indices(x, ind))) + + expr = expr.expand() + if isinstance(expr, Add): + terms = set(expr.args) + + for P in permutation_operators: + new_terms = set() + on_hold = set() + while terms: + term = terms.pop() + permuted = P.get_permuted(term) + if permuted in terms | on_hold: + try: + terms.remove(permuted) + except KeyError: + on_hold.remove(permuted) + keep = _choose_one_to_keep(term, permuted, P.args) + new_terms.add(P*keep) + else: + + # Some terms must get a second chance because the permuted + # term may already have canonical dummy ordering. Then + # substitute_dummies() does nothing. However, the other + # term, if it exists, will be able to match with us. + permuted1 = permuted + permuted = substitute_dummies(permuted) + if permuted1 == permuted: + on_hold.add(term) + elif permuted in terms | on_hold: + try: + terms.remove(permuted) + except KeyError: + on_hold.remove(permuted) + keep = _choose_one_to_keep(term, permuted, P.args) + new_terms.add(P*keep) + else: + new_terms.add(term) + terms = new_terms | on_hold + return Add(*terms) + return expr diff --git a/phivenv/Lib/site-packages/sympy/physics/sho.py b/phivenv/Lib/site-packages/sympy/physics/sho.py new file mode 100644 index 0000000000000000000000000000000000000000..c55b31b3fa9fca4fa33a9f8e91c90c2174fe81a5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/sho.py @@ -0,0 +1,95 @@ +from sympy.core import S, pi, Rational +from sympy.functions import assoc_laguerre, sqrt, exp, factorial, factorial2 + + +def R_nl(n, l, nu, r): + """ + Returns the radial wavefunction R_{nl} for a 3d isotropic harmonic + oscillator. + + Parameters + ========== + + n : + The "nodal" quantum number. Corresponds to the number of nodes in + the wavefunction. ``n >= 0`` + l : + The quantum number for orbital angular momentum. + nu : + mass-scaled frequency: nu = m*omega/(2*hbar) where `m` is the mass + and `omega` the frequency of the oscillator. + (in atomic units ``nu == omega/2``) + r : + Radial coordinate. + + Examples + ======== + + >>> from sympy.physics.sho import R_nl + >>> from sympy.abc import r, nu, l + >>> R_nl(0, 0, 1, r) + 2*2**(3/4)*exp(-r**2)/pi**(1/4) + >>> R_nl(1, 0, 1, r) + 4*2**(1/4)*sqrt(3)*(3/2 - 2*r**2)*exp(-r**2)/(3*pi**(1/4)) + + l, nu and r may be symbolic: + + >>> R_nl(0, 0, nu, r) + 2*2**(3/4)*sqrt(nu**(3/2))*exp(-nu*r**2)/pi**(1/4) + >>> R_nl(0, l, 1, r) + r**l*sqrt(2**(l + 3/2)*2**(l + 2)/factorial2(2*l + 1))*exp(-r**2)/pi**(1/4) + + The normalization of the radial wavefunction is: + + >>> from sympy import Integral, oo + >>> Integral(R_nl(0, 0, 1, r)**2*r**2, (r, 0, oo)).n() + 1.00000000000000 + >>> Integral(R_nl(1, 0, 1, r)**2*r**2, (r, 0, oo)).n() + 1.00000000000000 + >>> Integral(R_nl(1, 1, 1, r)**2*r**2, (r, 0, oo)).n() + 1.00000000000000 + + """ + n, l, nu, r = map(S, [n, l, nu, r]) + + # formula uses n >= 1 (instead of nodal n >= 0) + n = n + 1 + C = sqrt( + ((2*nu)**(l + Rational(3, 2))*2**(n + l + 1)*factorial(n - 1))/ + (sqrt(pi)*(factorial2(2*n + 2*l - 1))) + ) + return C*r**(l)*exp(-nu*r**2)*assoc_laguerre(n - 1, l + S.Half, 2*nu*r**2) + + +def E_nl(n, l, hw): + """ + Returns the Energy of an isotropic harmonic oscillator. + + Parameters + ========== + + n : + The "nodal" quantum number. + l : + The orbital angular momentum. + hw : + The harmonic oscillator parameter. + + Notes + ===== + + The unit of the returned value matches the unit of hw, since the energy is + calculated as: + + E_nl = (2*n + l + 3/2)*hw + + Examples + ======== + + >>> from sympy.physics.sho import E_nl + >>> from sympy import symbols + >>> x, y, z = symbols('x, y, z') + >>> E_nl(x, y, z) + z*(2*x + y + 3/2) + """ + return (2*n + l + Rational(3, 2))*hw diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/__init__.py b/phivenv/Lib/site-packages/sympy/physics/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..813d9e7a5b848e6de3666a5ed991cda61083894b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_clebsch_gordan.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_clebsch_gordan.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5894eaae7ede302a2d193d54394bb90f2641ca30 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_clebsch_gordan.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_hydrogen.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_hydrogen.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5a6922a7d545c8f349e9344bb38999f75692b13 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_hydrogen.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_paulialgebra.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_paulialgebra.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2846439637a37f0ee32d3300dba76a66e4d9014b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_paulialgebra.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_physics_matrices.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_physics_matrices.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72ec5ee55fc2d9eebd8a444c21e8472796399262 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_physics_matrices.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_pring.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_pring.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0a4615d9e178075c7f1ab79274c65c60f17fe82 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_pring.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_qho_1d.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_qho_1d.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbbae24dd786d3147593a3623fa14d9ba316427c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_qho_1d.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_secondquant.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_secondquant.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20caec17a1e0ca2219b11918f8fc07ee4328cbc8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_secondquant.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_sho.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_sho.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fa053bf14a24e2ddb4033ca49b8492a95171940 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/tests/__pycache__/test_sho.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/test_clebsch_gordan.py b/phivenv/Lib/site-packages/sympy/physics/tests/test_clebsch_gordan.py new file mode 100644 index 0000000000000000000000000000000000000000..e4313e3e412d6d1883efaf693c13e0f967daf9da --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/tests/test_clebsch_gordan.py @@ -0,0 +1,223 @@ +from sympy.core.numbers import (I, pi, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.spherical_harmonics import Ynm +from sympy.matrices.dense import Matrix +from sympy.physics.wigner import (clebsch_gordan, wigner_9j, wigner_6j, gaunt, + real_gaunt, racah, dot_rot_grad_Ynm, wigner_3j, wigner_d_small, wigner_d) +from sympy.testing.pytest import raises, skip + +# for test cases, refer : https://en.wikipedia.org/wiki/Table_of_Clebsch%E2%80%93Gordan_coefficients + +def test_clebsch_gordan_docs(): + assert clebsch_gordan(Rational(3, 2), S.Half, 2, Rational(3, 2), S.Half, 2) == 1 + assert clebsch_gordan(Rational(3, 2), S.Half, 1, Rational(3, 2), Rational(-1, 2), 1) == sqrt(3)/2 + assert clebsch_gordan(Rational(3, 2), S.Half, 1, Rational(-1, 2), S.Half, 0) == -sqrt(2)/2 + + +def test_clebsch_gordan(): + # Argument order: (j_1, j_2, j, m_1, m_2, m) + + h = S.One + k = S.Half + l = Rational(3, 2) + i = Rational(-1, 2) + n = Rational(7, 2) + p = Rational(5, 2) + assert clebsch_gordan(k, k, 1, k, k, 1) == 1 + assert clebsch_gordan(k, k, 1, k, k, 0) == 0 + assert clebsch_gordan(k, k, 1, i, i, -1) == 1 + assert clebsch_gordan(k, k, 1, k, i, 0) == sqrt(2)/2 + assert clebsch_gordan(k, k, 0, k, i, 0) == sqrt(2)/2 + assert clebsch_gordan(k, k, 1, i, k, 0) == sqrt(2)/2 + assert clebsch_gordan(k, k, 0, i, k, 0) == -sqrt(2)/2 + assert clebsch_gordan(h, k, l, 1, k, l) == 1 + assert clebsch_gordan(h, k, l, 1, i, k) == 1/sqrt(3) + assert clebsch_gordan(h, k, k, 1, i, k) == sqrt(2)/sqrt(3) + assert clebsch_gordan(h, k, k, 0, k, k) == -1/sqrt(3) + assert clebsch_gordan(h, k, l, 0, k, k) == sqrt(2)/sqrt(3) + assert clebsch_gordan(h, h, S(2), 1, 1, S(2)) == 1 + assert clebsch_gordan(h, h, S(2), 1, 0, 1) == 1/sqrt(2) + assert clebsch_gordan(h, h, S(2), 0, 1, 1) == 1/sqrt(2) + assert clebsch_gordan(h, h, 1, 1, 0, 1) == 1/sqrt(2) + assert clebsch_gordan(h, h, 1, 0, 1, 1) == -1/sqrt(2) + assert clebsch_gordan(l, l, S(3), l, l, S(3)) == 1 + assert clebsch_gordan(l, l, S(2), l, k, S(2)) == 1/sqrt(2) + assert clebsch_gordan(l, l, S(3), l, k, S(2)) == 1/sqrt(2) + assert clebsch_gordan(S(2), S(2), S(4), S(2), S(2), S(4)) == 1 + assert clebsch_gordan(S(2), S(2), S(3), S(2), 1, S(3)) == 1/sqrt(2) + assert clebsch_gordan(S(2), S(2), S(3), 1, 1, S(2)) == 0 + assert clebsch_gordan(p, h, n, p, 1, n) == 1 + assert clebsch_gordan(p, h, p, p, 0, p) == sqrt(5)/sqrt(7) + assert clebsch_gordan(p, h, l, k, 1, l) == 1/sqrt(15) + + +def test_clebsch_gordan_numpy(): + try: + import numpy as np + except ImportError: + skip("numpy not installed") + assert clebsch_gordan(*np.zeros(6).astype(np.int64)) == 1 + assert wigner_3j(2, np.float64(6.0), 4.0, 0, 0, 0) == sqrt(715)/143 + assert wigner_3j(0, 0.5, 0.5, 0, 0.5, -0.5) == sqrt(2)/2 + raises(ValueError, lambda: wigner_3j(2.1, 6, 4, 0, 0, 0)) + + +def test_wigner(): + try: + import numpy as np + except ImportError: + skip("numpy not installed") + def tn(a, b): + return (a - b).n(64) < S('1e-64') + assert tn(wigner_9j(1, 1, 1, 1, 1, 1, 1, 1, 0, prec=64), Rational(1, 18)) + assert wigner_9j(3, 3, 2, 3, 3, 2, 3, 3, 2) == 3221*sqrt( + 70)/(246960*sqrt(105)) - 365/(3528*sqrt(70)*sqrt(105)) + assert wigner_6j(5, 5, 5, 5, 5, 5) == Rational(1, 52) + assert tn(wigner_6j(8, 8, 8, 8, 8, 8, prec=64), Rational(-12219, 965770)) + assert wigner_6j(1, 1, 1, 1.0, np.float64(1.0), 1) == Rational(1, 6) + assert wigner_6j(3.0, np.float32(3), 3.0, 3, 3, 3) == Rational(-1, 14) + # regression test for #8747 + half = S.Half + assert wigner_9j(0, 0, 0, 0, half, half, 0, half, half) == half + assert (wigner_9j(3, 5, 4, + 7 * half, 5 * half, 4, + 9 * half, 9 * half, 0) + == -sqrt(Rational(361, 205821000))) + assert (wigner_9j(1, 4, 3, + 5 * half, 4, 5 * half, + 5 * half, 2, 7 * half) + == -sqrt(Rational(3971, 373403520))) + assert (wigner_9j(4, 9 * half, 5 * half, + 2, 4, 4, + 5, 7 * half, 7 * half) + == -sqrt(Rational(3481, 5042614500))) + assert (wigner_9j(5, 5, 5.0, + np.float64(5.0), 5, 5, + 5, 5, 5) + == 0) + assert (wigner_9j(1.0, 2.0, 3.0, + 3, 2, 1, + 2, 1, 3) + == -4*sqrt(70)/11025) + + +def test_gaunt(): + def tn(a, b): + return (a - b).n(64) < S('1e-64') + assert gaunt(1, 0, 1, 1, 0, -1) == -1/(2*sqrt(pi)) + assert isinstance(gaunt(1, 1, 0, -1, 1, 0).args[0], Rational) + assert isinstance(gaunt(0, 1, 1, 0, -1, 1).args[0], Rational) + + assert tn(gaunt( + 10, 10, 12, 9, 3, -12, prec=64), (Rational(-98, 62031)) * sqrt(6279)/sqrt(pi)) + def gaunt_ref(l1, l2, l3, m1, m2, m3): + return ( + sqrt((2 * l1 + 1) * (2 * l2 + 1) * (2 * l3 + 1) / (4 * pi)) * + wigner_3j(l1, l2, l3, 0, 0, 0) * + wigner_3j(l1, l2, l3, m1, m2, m3) + ) + threshold = 1e-10 + l_max = 3 + l3_max = 24 + for l1 in range(l_max + 1): + for l2 in range(l_max + 1): + for l3 in range(l3_max + 1): + for m1 in range(-l1, l1 + 1): + for m2 in range(-l2, l2 + 1): + for m3 in range(-l3, l3 + 1): + args = l1, l2, l3, m1, m2, m3 + g = gaunt(*args) + g0 = gaunt_ref(*args) + assert abs(g - g0) < threshold + if m1 + m2 + m3 != 0: + assert abs(g) < threshold + if (l1 + l2 + l3) % 2: + assert abs(g) < threshold + assert gaunt(1, 1, 0, 0, 2, -2) is S.Zero + + +def test_realgaunt(): + # All non-zero values corresponding to l values from 0 to 2 + for l in range(3): + for m in range(-l, l+1): + assert real_gaunt(0, l, l, 0, m, m) == 1/(2*sqrt(pi)) + assert real_gaunt(1, 1, 2, 0, 0, 0) == sqrt(5)/(5*sqrt(pi)) + assert real_gaunt(1, 1, 2, 1, 1, 0) == -sqrt(5)/(10*sqrt(pi)) + assert real_gaunt(2, 2, 2, 0, 0, 0) == sqrt(5)/(7*sqrt(pi)) + assert real_gaunt(2, 2, 2, 0, 2, 2) == -sqrt(5)/(7*sqrt(pi)) + assert real_gaunt(2, 2, 2, -2, -2, 0) == -sqrt(5)/(7*sqrt(pi)) + assert real_gaunt(1, 1, 2, -1, 0, -1) == sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(1, 1, 2, 0, 1, 1) == sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(1, 1, 2, 1, 1, 2) == sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(1, 1, 2, -1, 1, -2) == sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(1, 1, 2, -1, -1, 2) == -sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(2, 2, 2, 0, 1, 1) == sqrt(5)/(14*sqrt(pi)) + assert real_gaunt(2, 2, 2, 1, 1, 2) == sqrt(15)/(14*sqrt(pi)) + assert real_gaunt(2, 2, 2, -1, -1, 2) == -sqrt(15)/(14*sqrt(pi)) + + assert real_gaunt(-2, -2, -2, -2, -2, 0) is S.Zero # m test + assert real_gaunt(-2, 1, 0, 1, 1, 1) is S.Zero # l test + assert real_gaunt(-2, -1, -2, -1, -1, 0) is S.Zero # m and l test + assert real_gaunt(-2, -2, -2, -2, -2, -2) is S.Zero # m and k test + assert real_gaunt(-2, -1, -2, -1, -1, -1) is S.Zero # m, l and k test + + x = symbols('x', integer=True) + v = [0]*6 + for i in range(len(v)): + v[i] = x # non literal ints fail + raises(ValueError, lambda: real_gaunt(*v)) + v[i] = 0 + + +def test_racah(): + assert racah(3,3,3,3,3,3) == Rational(-1,14) + assert racah(2,2,2,2,2,2) == Rational(-3,70) + assert racah(7,8,7,1,7,7, prec=4).is_Float + assert racah(5.5,7.5,9.5,6.5,8,9) == -719*sqrt(598)/1158924 + assert abs(racah(5.5,7.5,9.5,6.5,8,9, prec=4) - (-0.01517)) < S('1e-4') + + +def test_dot_rota_grad_SH(): + theta, phi = symbols("theta phi") + assert dot_rot_grad_Ynm(1, 1, 1, 1, 1, 0) != \ + sqrt(30)*Ynm(2, 2, 1, 0)/(10*sqrt(pi)) + assert dot_rot_grad_Ynm(1, 1, 1, 1, 1, 0).doit() == \ + sqrt(30)*Ynm(2, 2, 1, 0)/(10*sqrt(pi)) + assert dot_rot_grad_Ynm(1, 5, 1, 1, 1, 2) != \ + 0 + assert dot_rot_grad_Ynm(1, 5, 1, 1, 1, 2).doit() == \ + 0 + assert dot_rot_grad_Ynm(3, 3, 3, 3, theta, phi).doit() == \ + 15*sqrt(3003)*Ynm(6, 6, theta, phi)/(143*sqrt(pi)) + assert dot_rot_grad_Ynm(3, 3, 1, 1, theta, phi).doit() == \ + sqrt(3)*Ynm(4, 4, theta, phi)/sqrt(pi) + assert dot_rot_grad_Ynm(3, 2, 2, 0, theta, phi).doit() == \ + 3*sqrt(55)*Ynm(5, 2, theta, phi)/(11*sqrt(pi)) + assert dot_rot_grad_Ynm(3, 2, 3, 2, theta, phi).doit().expand() == \ + -sqrt(70)*Ynm(4, 4, theta, phi)/(11*sqrt(pi)) + \ + 45*sqrt(182)*Ynm(6, 4, theta, phi)/(143*sqrt(pi)) + + +def test_wigner_d(): + half = S(1)/2 + assert wigner_d_small(half, 0) == Matrix([[1, 0], [0, 1]]) + assert wigner_d_small(half, pi/2) == Matrix([[1, 1], [-1, 1]])/sqrt(2) + assert wigner_d_small(half, pi) == Matrix([[0, 1], [-1, 0]]) + + alpha, beta, gamma = symbols("alpha, beta, gamma", real=True) + D = wigner_d(half, alpha, beta, gamma) + assert D[0, 0] == exp(I*alpha/2)*exp(I*gamma/2)*cos(beta/2) + assert D[0, 1] == exp(I*alpha/2)*exp(-I*gamma/2)*sin(beta/2) + assert D[1, 0] == -exp(-I*alpha/2)*exp(I*gamma/2)*sin(beta/2) + assert D[1, 1] == exp(-I*alpha/2)*exp(-I*gamma/2)*cos(beta/2) + + # Test Y_{n mi}(g*x)=\sum_{mj}D^n_{mi mj}*Y_{n mj}(x) + theta, phi = symbols("theta phi", real=True) + v = Matrix([Ynm(1, mj, theta, phi) for mj in range(1, -2, -1)]) + w = wigner_d(1, -pi/2, pi/2, -pi/2)@v.subs({theta: pi/4, phi: pi}) + w_ = v.subs({theta: pi/2, phi: pi/4}) + assert w.expand(func=True).as_real_imag() == w_.expand(func=True).as_real_imag() diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/test_hydrogen.py b/phivenv/Lib/site-packages/sympy/physics/tests/test_hydrogen.py new file mode 100644 index 0000000000000000000000000000000000000000..eb11744dd8e731f24fcd6f6be2a92ada4fffc554 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/tests/test_hydrogen.py @@ -0,0 +1,126 @@ +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.integrals.integrals import integrate +from sympy.simplify.simplify import simplify +from sympy.physics.hydrogen import R_nl, E_nl, E_nl_dirac, Psi_nlm +from sympy.testing.pytest import raises + +n, r, Z = symbols('n r Z') + + +def feq(a, b, max_relative_error=1e-12, max_absolute_error=1e-12): + a = float(a) + b = float(b) + # if the numbers are close enough (absolutely), then they are equal + if abs(a - b) < max_absolute_error: + return True + # if not, they can still be equal if their relative error is small + if abs(b) > abs(a): + relative_error = abs((a - b)/b) + else: + relative_error = abs((a - b)/a) + return relative_error <= max_relative_error + + +def test_wavefunction(): + a = 1/Z + R = { + (1, 0): 2*sqrt(1/a**3) * exp(-r/a), + (2, 0): sqrt(1/(2*a**3)) * exp(-r/(2*a)) * (1 - r/(2*a)), + (2, 1): S.Half * sqrt(1/(6*a**3)) * exp(-r/(2*a)) * r/a, + (3, 0): Rational(2, 3) * sqrt(1/(3*a**3)) * exp(-r/(3*a)) * + (1 - 2*r/(3*a) + Rational(2, 27) * (r/a)**2), + (3, 1): Rational(4, 27) * sqrt(2/(3*a**3)) * exp(-r/(3*a)) * + (1 - r/(6*a)) * r/a, + (3, 2): Rational(2, 81) * sqrt(2/(15*a**3)) * exp(-r/(3*a)) * (r/a)**2, + (4, 0): Rational(1, 4) * sqrt(1/a**3) * exp(-r/(4*a)) * + (1 - 3*r/(4*a) + Rational(1, 8) * (r/a)**2 - Rational(1, 192) * (r/a)**3), + (4, 1): Rational(1, 16) * sqrt(5/(3*a**3)) * exp(-r/(4*a)) * + (1 - r/(4*a) + Rational(1, 80) * (r/a)**2) * (r/a), + (4, 2): Rational(1, 64) * sqrt(1/(5*a**3)) * exp(-r/(4*a)) * + (1 - r/(12*a)) * (r/a)**2, + (4, 3): Rational(1, 768) * sqrt(1/(35*a**3)) * exp(-r/(4*a)) * (r/a)**3, + } + for n, l in R: + assert simplify(R_nl(n, l, r, Z) - R[(n, l)]) == 0 + + +def test_norm(): + # Maximum "n" which is tested: + n_max = 2 # it works, but is slow, for n_max > 2 + for n in range(n_max + 1): + for l in range(n): + assert integrate(R_nl(n, l, r)**2 * r**2, (r, 0, oo)) == 1 + +def test_psi_nlm(): + r=S('r') + phi=S('phi') + theta=S('theta') + assert (Psi_nlm(1, 0, 0, r, phi, theta) == exp(-r) / sqrt(pi)) + assert (Psi_nlm(2, 1, -1, r, phi, theta)) == S.Half * exp(-r / (2)) * r \ + * (sin(theta) * exp(-I * phi) / (4 * sqrt(pi))) + assert (Psi_nlm(3, 2, 1, r, phi, theta, 2) == -sqrt(2) * sin(theta) \ + * exp(I * phi) * cos(theta) / (4 * sqrt(pi)) * S(2) / 81 \ + * sqrt(2 * 2 ** 3) * exp(-2 * r / (3)) * (r * 2) ** 2) + +def test_hydrogen_energies(): + assert E_nl(n, Z) == -Z**2/(2*n**2) + assert E_nl(n) == -1/(2*n**2) + + assert E_nl(1, 47) == -S(47)**2/(2*1**2) + assert E_nl(2, 47) == -S(47)**2/(2*2**2) + + assert E_nl(1) == -S.One/(2*1**2) + assert E_nl(2) == -S.One/(2*2**2) + assert E_nl(3) == -S.One/(2*3**2) + assert E_nl(4) == -S.One/(2*4**2) + assert E_nl(100) == -S.One/(2*100**2) + + raises(ValueError, lambda: E_nl(0)) + + +def test_hydrogen_energies_relat(): + # First test exact formulas for small "c" so that we get nice expressions: + assert E_nl_dirac(2, 0, Z=1, c=1) == 1/sqrt(2) - 1 + assert simplify(E_nl_dirac(2, 0, Z=1, c=2) - ( (8*sqrt(3) + 16) + / sqrt(16*sqrt(3) + 32) - 4)) == 0 + assert simplify(E_nl_dirac(2, 0, Z=1, c=3) - ( (54*sqrt(2) + 81) + / sqrt(108*sqrt(2) + 162) - 9)) == 0 + + # Now test for almost the correct speed of light, without floating point + # numbers: + assert simplify(E_nl_dirac(2, 0, Z=1, c=137) - ( (352275361 + 10285412 * + sqrt(1173)) / sqrt(704550722 + 20570824 * sqrt(1173)) - 18769)) == 0 + assert simplify(E_nl_dirac(2, 0, Z=82, c=137) - ( (352275361 + 2571353 * + sqrt(12045)) / sqrt(704550722 + 5142706*sqrt(12045)) - 18769)) == 0 + + # Test using exact speed of light, and compare against the nonrelativistic + # energies: + for n in range(1, 5): + for l in range(n): + assert feq(E_nl_dirac(n, l), E_nl(n), 1e-5, 1e-5) + if l > 0: + assert feq(E_nl_dirac(n, l, False), E_nl(n), 1e-5, 1e-5) + + Z = 2 + for n in range(1, 5): + for l in range(n): + assert feq(E_nl_dirac(n, l, Z=Z), E_nl(n, Z), 1e-4, 1e-4) + if l > 0: + assert feq(E_nl_dirac(n, l, False, Z), E_nl(n, Z), 1e-4, 1e-4) + + Z = 3 + for n in range(1, 5): + for l in range(n): + assert feq(E_nl_dirac(n, l, Z=Z), E_nl(n, Z), 1e-3, 1e-3) + if l > 0: + assert feq(E_nl_dirac(n, l, False, Z), E_nl(n, Z), 1e-3, 1e-3) + + # Test the exceptions: + raises(ValueError, lambda: E_nl_dirac(0, 0)) + raises(ValueError, lambda: E_nl_dirac(1, -1)) + raises(ValueError, lambda: E_nl_dirac(1, 0, False)) diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/test_paulialgebra.py b/phivenv/Lib/site-packages/sympy/physics/tests/test_paulialgebra.py new file mode 100644 index 0000000000000000000000000000000000000000..f773470a1802f2864b79f56d38be1de030ff86dc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/tests/test_paulialgebra.py @@ -0,0 +1,57 @@ +from sympy.core.numbers import I +from sympy.core.symbol import symbols +from sympy.physics.paulialgebra import Pauli +from sympy.testing.pytest import XFAIL +from sympy.physics.quantum import TensorProduct + +sigma1 = Pauli(1) +sigma2 = Pauli(2) +sigma3 = Pauli(3) + +tau1 = symbols("tau1", commutative = False) + + +def test_Pauli(): + + assert sigma1 == sigma1 + assert sigma1 != sigma2 + + assert sigma1*sigma2 == I*sigma3 + assert sigma3*sigma1 == I*sigma2 + assert sigma2*sigma3 == I*sigma1 + + assert sigma1*sigma1 == 1 + assert sigma2*sigma2 == 1 + assert sigma3*sigma3 == 1 + + assert sigma1**0 == 1 + assert sigma1**1 == sigma1 + assert sigma1**2 == 1 + assert sigma1**3 == sigma1 + assert sigma1**4 == 1 + + assert sigma3**2 == 1 + + assert sigma1*2*sigma1 == 2 + + +def test_evaluate_pauli_product(): + from sympy.physics.paulialgebra import evaluate_pauli_product + + assert evaluate_pauli_product(I*sigma2*sigma3) == -sigma1 + + # Check issue 6471 + assert evaluate_pauli_product(-I*4*sigma1*sigma2) == 4*sigma3 + + assert evaluate_pauli_product( + 1 + I*sigma1*sigma2*sigma1*sigma2 + \ + I*sigma1*sigma2*tau1*sigma1*sigma3 + \ + ((tau1**2).subs(tau1, I*sigma1)) + \ + sigma3*((tau1**2).subs(tau1, I*sigma1)) + \ + TensorProduct(I*sigma1*sigma2*sigma1*sigma2, 1) + ) == 1 -I + I*sigma3*tau1*sigma2 - 1 - sigma3 - I*TensorProduct(1,1) + + +@XFAIL +def test_Pauli_should_work(): + assert sigma1*sigma3*sigma1 == -sigma3 diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/test_physics_matrices.py b/phivenv/Lib/site-packages/sympy/physics/tests/test_physics_matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..14fa47668d0760826e0354c8cafae787a24256eb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/tests/test_physics_matrices.py @@ -0,0 +1,84 @@ +from sympy.physics.matrices import msigma, mgamma, minkowski_tensor, pat_matrix, mdft +from sympy.core.numbers import (I, Rational) +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import (Matrix, eye, zeros) +from sympy.testing.pytest import warns_deprecated_sympy + + +def test_parallel_axis_theorem(): + # This tests the parallel axis theorem matrix by comparing to test + # matrices. + + # First case, 1 in all directions. + mat1 = Matrix(((2, -1, -1), (-1, 2, -1), (-1, -1, 2))) + assert pat_matrix(1, 1, 1, 1) == mat1 + assert pat_matrix(2, 1, 1, 1) == 2*mat1 + + # Second case, 1 in x, 0 in all others + mat2 = Matrix(((0, 0, 0), (0, 1, 0), (0, 0, 1))) + assert pat_matrix(1, 1, 0, 0) == mat2 + assert pat_matrix(2, 1, 0, 0) == 2*mat2 + + # Third case, 1 in y, 0 in all others + mat3 = Matrix(((1, 0, 0), (0, 0, 0), (0, 0, 1))) + assert pat_matrix(1, 0, 1, 0) == mat3 + assert pat_matrix(2, 0, 1, 0) == 2*mat3 + + # Fourth case, 1 in z, 0 in all others + mat4 = Matrix(((1, 0, 0), (0, 1, 0), (0, 0, 0))) + assert pat_matrix(1, 0, 0, 1) == mat4 + assert pat_matrix(2, 0, 0, 1) == 2*mat4 + + +def test_Pauli(): + #this and the following test are testing both Pauli and Dirac matrices + #and also that the general Matrix class works correctly in a real world + #situation + sigma1 = msigma(1) + sigma2 = msigma(2) + sigma3 = msigma(3) + + assert sigma1 == sigma1 + assert sigma1 != sigma2 + + # sigma*I -> I*sigma (see #354) + assert sigma1*sigma2 == sigma3*I + assert sigma3*sigma1 == sigma2*I + assert sigma2*sigma3 == sigma1*I + + assert sigma1*sigma1 == eye(2) + assert sigma2*sigma2 == eye(2) + assert sigma3*sigma3 == eye(2) + + assert sigma1*2*sigma1 == 2*eye(2) + assert sigma1*sigma3*sigma1 == -sigma3 + + +def test_Dirac(): + gamma0 = mgamma(0) + gamma1 = mgamma(1) + gamma2 = mgamma(2) + gamma3 = mgamma(3) + gamma5 = mgamma(5) + + # gamma*I -> I*gamma (see #354) + assert gamma5 == gamma0 * gamma1 * gamma2 * gamma3 * I + assert gamma1 * gamma2 + gamma2 * gamma1 == zeros(4) + assert gamma0 * gamma0 == eye(4) * minkowski_tensor[0, 0] + assert gamma2 * gamma2 != eye(4) * minkowski_tensor[0, 0] + assert gamma2 * gamma2 == eye(4) * minkowski_tensor[2, 2] + + assert mgamma(5, True) == \ + mgamma(0, True)*mgamma(1, True)*mgamma(2, True)*mgamma(3, True)*I + +def test_mdft(): + with warns_deprecated_sympy(): + assert mdft(1) == Matrix([[1]]) + with warns_deprecated_sympy(): + assert mdft(2) == 1/sqrt(2)*Matrix([[1,1],[1,-1]]) + with warns_deprecated_sympy(): + assert mdft(4) == Matrix([[S.Half, S.Half, S.Half, S.Half], + [S.Half, -I/2, Rational(-1,2), I/2], + [S.Half, Rational(-1,2), S.Half, Rational(-1,2)], + [S.Half, I/2, Rational(-1,2), -I/2]]) diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/test_pring.py b/phivenv/Lib/site-packages/sympy/physics/tests/test_pring.py new file mode 100644 index 0000000000000000000000000000000000000000..ed7398eac4a8bb1cd4af810825caf3fcefb5f18f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/tests/test_pring.py @@ -0,0 +1,41 @@ +from sympy.physics.pring import wavefunction, energy +from sympy.core.numbers import (I, pi) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals.integrals import integrate +from sympy.simplify.simplify import simplify +from sympy.abc import m, x, r +from sympy.physics.quantum.constants import hbar + + +def test_wavefunction(): + Psi = { + 0: (1/sqrt(2 * pi)), + 1: (1/sqrt(2 * pi)) * exp(I * x), + 2: (1/sqrt(2 * pi)) * exp(2 * I * x), + 3: (1/sqrt(2 * pi)) * exp(3 * I * x) + } + for n in Psi: + assert simplify(wavefunction(n, x) - Psi[n]) == 0 + + +def test_norm(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + assert integrate( + wavefunction(i, x) * wavefunction(-i, x), (x, 0, 2 * pi)) == 1 + + +def test_orthogonality(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + for j in range(i+1, n+1): + assert integrate( + wavefunction(i, x) * wavefunction(j, x), (x, 0, 2 * pi)) == 0 + + +def test_energy(n=1): + # Maximum "n" which is tested: + for i in range(n+1): + assert simplify( + energy(i, m, r) - ((i**2 * hbar**2) / (2 * m * r**2))) == 0 diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/test_qho_1d.py b/phivenv/Lib/site-packages/sympy/physics/tests/test_qho_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..34e52c9e3a721496fc61f7d2b31414db15caa7a8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/tests/test_qho_1d.py @@ -0,0 +1,50 @@ +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals.integrals import integrate +from sympy.simplify.simplify import simplify +from sympy.abc import omega, m, x +from sympy.physics.qho_1d import psi_n, E_n, coherent_state +from sympy.physics.quantum.constants import hbar + +nu = m * omega / hbar + + +def test_wavefunction(): + Psi = { + 0: (nu/pi)**Rational(1, 4) * exp(-nu * x**2 /2), + 1: (nu/pi)**Rational(1, 4) * sqrt(2*nu) * x * exp(-nu * x**2 /2), + 2: (nu/pi)**Rational(1, 4) * (2 * nu * x**2 - 1)/sqrt(2) * exp(-nu * x**2 /2), + 3: (nu/pi)**Rational(1, 4) * sqrt(nu/3) * (2 * nu * x**3 - 3 * x) * exp(-nu * x**2 /2) + } + for n in Psi: + assert simplify(psi_n(n, x, m, omega) - Psi[n]) == 0 + + +def test_norm(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + assert integrate(psi_n(i, x, 1, 1)**2, (x, -oo, oo)) == 1 + + +def test_orthogonality(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + for j in range(i + 1, n + 1): + assert integrate( + psi_n(i, x, 1, 1)*psi_n(j, x, 1, 1), (x, -oo, oo)) == 0 + + +def test_energies(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + assert E_n(i, omega) == hbar * omega * (i + S.Half) + +def test_coherent_state(n=10): + # Maximum "n" which is tested: + # test whether coherent state is the eigenstate of annihilation operator + alpha = Symbol("alpha") + for i in range(n + 1): + assert simplify(sqrt(n + 1) * coherent_state(n + 1, alpha)) == simplify(alpha * coherent_state(n, alpha)) diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/test_secondquant.py b/phivenv/Lib/site-packages/sympy/physics/tests/test_secondquant.py new file mode 100644 index 0000000000000000000000000000000000000000..e7f60fab05497aead65ad748460802c9c29740ce --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/tests/test_secondquant.py @@ -0,0 +1,1301 @@ +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import exp +from sympy.physics.secondquant import ( + Dagger, Bd, VarBosonicBasis, BBra, B, BKet, FixedBosonicBasis, + matrix_rep, apply_operators, InnerProduct, Commutator, KroneckerDelta, + AnnihilateBoson, CreateBoson, BosonicOperator, + F, Fd, FKet, BosonState, CreateFermion, AnnihilateFermion, + evaluate_deltas, AntiSymmetricTensor, contraction, NO, wicks, + PermutationOperator, simplify_index_permutations, + _sort_anticommuting_fermions, _get_ordered_dummies, + substitute_dummies, FockStateBosonKet, + ContractionAppliesOnlyToFermions +) + +from sympy.concrete.summations import Sum +from sympy.core.function import (Function, expand) +from sympy.core.numbers import (I, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.printing.repr import srepr +from sympy.simplify.simplify import simplify + +from sympy.testing.pytest import slow, raises +from sympy.printing.latex import latex + + +def test_PermutationOperator(): + p, q, r, s = symbols('p,q,r,s') + f, g, h, i = map(Function, 'fghi') + P = PermutationOperator + assert P(p, q).get_permuted(f(p)*g(q)) == -f(q)*g(p) + assert P(p, q).get_permuted(f(p, q)) == -f(q, p) + assert P(p, q).get_permuted(f(p)) == f(p) + expr = (f(p)*g(q)*h(r)*i(s) + - f(q)*g(p)*h(r)*i(s) + - f(p)*g(q)*h(s)*i(r) + + f(q)*g(p)*h(s)*i(r)) + perms = [P(p, q), P(r, s)] + assert (simplify_index_permutations(expr, perms) == + P(p, q)*P(r, s)*f(p)*g(q)*h(r)*i(s)) + assert latex(P(p, q)) == 'P(pq)' + + p1, p2 = symbols('p1,p2') + assert latex(P(p1,p2) == 'P(p_{1}p_{2})') + +def test_index_permutations_with_dummies(): + a, b, c, d = symbols('a b c d') + p, q, r, s = symbols('p q r s', cls=Dummy) + f, g = map(Function, 'fg') + P = PermutationOperator + + # No dummy substitution necessary + expr = f(a, b, p, q) - f(b, a, p, q) + assert simplify_index_permutations( + expr, [P(a, b)]) == P(a, b)*f(a, b, p, q) + + # Cases where dummy substitution is needed + expected = P(a, b)*substitute_dummies(f(a, b, p, q)) + + expr = f(a, b, p, q) - f(b, a, q, p) + result = simplify_index_permutations(expr, [P(a, b)]) + assert expected == substitute_dummies(result) + + expr = f(a, b, q, p) - f(b, a, p, q) + result = simplify_index_permutations(expr, [P(a, b)]) + assert expected == substitute_dummies(result) + + # A case where nothing can be done + expr = f(a, b, q, p) - g(b, a, p, q) + result = simplify_index_permutations(expr, [P(a, b)]) + assert expr == result + + +def test_dagger(): + i, j, n, m = symbols('i,j,n,m') + assert Dagger(1) == 1 + assert Dagger(1.0) == 1.0 + assert Dagger(2*I) == -2*I + assert Dagger(S.Half*I/3.0) == I*Rational(-1, 2)/3.0 + assert Dagger(BKet([n])) == BBra([n]) + assert Dagger(B(0)) == Bd(0) + assert Dagger(Bd(0)) == B(0) + assert Dagger(B(n)) == Bd(n) + assert Dagger(Bd(n)) == B(n) + assert Dagger(B(0) + B(1)) == Bd(0) + Bd(1) + assert Dagger(n*m) == Dagger(n)*Dagger(m) # n, m commute + assert Dagger(B(n)*B(m)) == Bd(m)*Bd(n) + assert Dagger(B(n)**10) == Dagger(B(n))**10 + assert Dagger('a') == Dagger(Symbol('a')) + assert Dagger(Dagger('a')) == Symbol('a') + assert Dagger(exp(2 * I)) == exp(-2 * I) + assert Dagger(i) == conjugate(i) + + +def test_operator(): + i, j = symbols('i,j') + o = BosonicOperator(i) + assert o.state == i + assert o.is_symbolic + o = BosonicOperator(1) + assert o.state == 1 + assert not o.is_symbolic + + +def test_create(): + i, j, n, m, p1 = symbols('i,j,n,m,p1') + o = Bd(i) + assert latex(o) == "{b^\\dagger_{i}}" + assert latex(Bd(p1)) == "{b^\\dagger_{p_{1}}}" + assert isinstance(o, CreateBoson) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = Bd(0) + assert o.apply_operator(BKet([n])) == sqrt(n + 1)*BKet([n + 1]) + o = Bd(n) + assert o.apply_operator(BKet([n])) == o*BKet([n]) + + +def test_annihilate(): + i, j, n, m, p1 = symbols('i,j,n,m,p1') + o = B(i) + assert latex(o) == "b_{i}" + assert latex(B(p1)) == "b_{p_{1}}" + assert isinstance(o, AnnihilateBoson) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = B(0) + assert o.apply_operator(BKet([n])) == sqrt(n)*BKet([n - 1]) + o = B(n) + assert o.apply_operator(BKet([n])) == o*BKet([n]) + + +def test_basic_state(): + i, j, n, m = symbols('i,j,n,m') + s = BosonState([0, 1, 2, 3, 4]) + assert len(s) == 5 + assert s.args[0] == tuple(range(5)) + assert s.up(0) == BosonState([1, 1, 2, 3, 4]) + assert s.down(4) == BosonState([0, 1, 2, 3, 3]) + for i in range(5): + assert s.up(i).down(i) == s + assert s.down(0) == 0 + for i in range(5): + assert s[i] == i + s = BosonState([n, m]) + assert s.down(0) == BosonState([n - 1, m]) + assert s.up(0) == BosonState([n + 1, m]) + + +def test_basic_apply(): + n = symbols("n") + e = B(0)*BKet([n]) + assert apply_operators(e) == sqrt(n)*BKet([n - 1]) + e = Bd(0)*BKet([n]) + assert apply_operators(e) == sqrt(n + 1)*BKet([n + 1]) + + +def test_complex_apply(): + n, m = symbols("n,m") + o = Bd(0)*B(0)*Bd(1)*B(0) + e = apply_operators(o*BKet([n, m])) + answer = sqrt(n)*sqrt(m + 1)*(-1 + n)*BKet([-1 + n, 1 + m]) + assert expand(e) == expand(answer) + + +def test_number_operator(): + n = symbols("n") + o = Bd(0)*B(0) + e = apply_operators(o*BKet([n])) + assert e == n*BKet([n]) + + +def test_inner_product(): + i, j, k, l = symbols('i,j,k,l') + s1 = BBra([0]) + s2 = BKet([1]) + assert InnerProduct(s1, Dagger(s1)) == 1 + assert InnerProduct(s1, s2) == 0 + s1 = BBra([i, j]) + s2 = BKet([k, l]) + r = InnerProduct(s1, s2) + assert r == KroneckerDelta(i, k)*KroneckerDelta(j, l) + + +def test_symbolic_matrix_elements(): + n, m = symbols('n,m') + s1 = BBra([n]) + s2 = BKet([m]) + o = B(0) + e = apply_operators(s1*o*s2) + assert e == sqrt(m)*KroneckerDelta(n, m - 1) + + +def test_matrix_elements(): + b = VarBosonicBasis(5) + o = B(0) + m = matrix_rep(o, b) + for i in range(4): + assert m[i, i + 1] == sqrt(i + 1) + o = Bd(0) + m = matrix_rep(o, b) + for i in range(4): + assert m[i + 1, i] == sqrt(i + 1) + + +def test_fixed_bosonic_basis(): + b = FixedBosonicBasis(2, 2) + # assert b == [FockState((2, 0)), FockState((1, 1)), FockState((0, 2))] + state = b.state(1) + assert state == FockStateBosonKet((1, 1)) + assert b.index(state) == 1 + assert b.state(1) == b[1] + assert len(b) == 3 + assert str(b) == '[FockState((2, 0)), FockState((1, 1)), FockState((0, 2))]' + assert repr(b) == '[FockState((2, 0)), FockState((1, 1)), FockState((0, 2))]' + assert srepr(b) == '[FockState((2, 0)), FockState((1, 1)), FockState((0, 2))]' + + +@slow +def test_sho(): + n, m = symbols('n,m') + h_n = Bd(n)*B(n)*(n + S.Half) + H = Sum(h_n, (n, 0, 5)) + o = H.doit(deep=False) + b = FixedBosonicBasis(2, 6) + m = matrix_rep(o, b) + # We need to double check these energy values to make sure that they + # are correct and have the proper degeneracies! + diag = [1, 2, 3, 3, 4, 5, 4, 5, 6, 7, 5, 6, 7, 8, 9, 6, 7, 8, 9, 10, 11] + for i in range(len(diag)): + assert diag[i] == m[i, i] + + +def test_commutation(): + n, m = symbols("n,m", above_fermi=True) + c = Commutator(B(0), Bd(0)) + assert c == 1 + c = Commutator(Bd(0), B(0)) + assert c == -1 + c = Commutator(B(n), Bd(0)) + assert c == KroneckerDelta(n, 0) + c = Commutator(B(0), B(0)) + assert c == 0 + c = Commutator(B(0), Bd(0)) + e = simplify(apply_operators(c*BKet([n]))) + assert e == BKet([n]) + c = Commutator(B(0), B(1)) + e = simplify(apply_operators(c*BKet([n, m]))) + assert e == 0 + + c = Commutator(F(m), Fd(m)) + assert c == +1 - 2*NO(Fd(m)*F(m)) + c = Commutator(Fd(m), F(m)) + assert c.expand() == -1 + 2*NO(Fd(m)*F(m)) + + C = Commutator + X, Y, Z = symbols('X,Y,Z', commutative=False) + assert C(C(X, Y), Z) != 0 + assert C(C(X, Z), Y) != 0 + assert C(Y, C(X, Z)) != 0 + + i, j, k, l = symbols('i,j,k,l', below_fermi=True) + a, b, c, d = symbols('a,b,c,d', above_fermi=True) + p, q, r, s = symbols('p,q,r,s') + D = KroneckerDelta + + assert C(Fd(a), F(i)) == -2*NO(F(i)*Fd(a)) + assert C(Fd(j), NO(Fd(a)*F(i))).doit(wicks=True) == -D(j, i)*Fd(a) + assert C(Fd(a)*F(i), Fd(b)*F(j)).doit(wicks=True) == 0 + + c1 = Commutator(F(a), Fd(a)) + assert Commutator.eval(c1, c1) == 0 + c = Commutator(Fd(a)*F(i),Fd(b)*F(j)) + assert latex(c) == r'\left[{a^\dagger_{a}} a_{i},{a^\dagger_{b}} a_{j}\right]' + assert repr(c) == 'Commutator(CreateFermion(a)*AnnihilateFermion(i),CreateFermion(b)*AnnihilateFermion(j))' + assert str(c) == '[CreateFermion(a)*AnnihilateFermion(i),CreateFermion(b)*AnnihilateFermion(j)]' + + +def test_create_f(): + i, j, n, m = symbols('i,j,n,m') + o = Fd(i) + assert isinstance(o, CreateFermion) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = Fd(1) + assert o.apply_operator(FKet([n])) == FKet([1, n]) + assert o.apply_operator(FKet([n])) == -FKet([n, 1]) + o = Fd(n) + assert o.apply_operator(FKet([])) == FKet([n]) + + vacuum = FKet([], fermi_level=4) + assert vacuum == FKet([], fermi_level=4) + + i, j, k, l = symbols('i,j,k,l', below_fermi=True) + a, b, c, d = symbols('a,b,c,d', above_fermi=True) + p, q, r, s = symbols('p,q,r,s') + p1 = symbols("p1") + + assert Fd(i).apply_operator(FKet([i, j, k], 4)) == FKet([j, k], 4) + assert Fd(a).apply_operator(FKet([i, b, k], 4)) == FKet([a, i, b, k], 4) + + assert Dagger(B(p)).apply_operator(q) == q*CreateBoson(p) + assert repr(Fd(p)) == 'CreateFermion(p)' + assert srepr(Fd(p)) == "CreateFermion(Symbol('p'))" + assert latex(Fd(p)) == r'{a^\dagger_{p}}' + assert latex(Fd(p1)) == r'{a^\dagger_{p_{1}}}' + assert latex(FKet([a,i], 1)) == r"\left|\left( a, \ i\right)\right\rangle" + assert latex(FKet([j,i,b,a], 2)) == r"\left|\left( a, \ b, \ i, \ j\right)\right\rangle" + + +def test_annihilate_f(): + i, j, n, m = symbols('i,j,n,m') + o = F(i) + assert isinstance(o, AnnihilateFermion) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = F(1) + assert o.apply_operator(FKet([1, n])) == FKet([n]) + assert o.apply_operator(FKet([n, 1])) == -FKet([n]) + o = F(n) + assert o.apply_operator(FKet([n])) == FKet([]) + + i, j, k, l = symbols('i,j,k,l', below_fermi=True) + a, b, c, d = symbols('a,b,c,d', above_fermi=True) + p, q, r, s = symbols('p,q,r,s') + p1 = symbols('p1') + + assert F(i).apply_operator(FKet([i, j, k], 4)) == 0 + assert F(a).apply_operator(FKet([i, b, k], 4)) == 0 + assert F(l).apply_operator(FKet([i, j, k], 3)) == 0 + assert F(l).apply_operator(FKet([i, j, k], 4)) == FKet([l, i, j, k], 4) + assert str(F(p)) == 'f(p)' + assert repr(F(p)) == 'AnnihilateFermion(p)' + assert srepr(F(p)) == "AnnihilateFermion(Symbol('p'))" + assert latex(F(p)) == 'a_{p}' + assert latex(F(p1)) == 'a_{p_{1}}' + + +def test_create_b(): + i, j, n, m = symbols('i,j,n,m') + o = Bd(i) + assert isinstance(o, CreateBoson) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = Bd(0) + assert o.apply_operator(BKet([n])) == sqrt(n + 1)*BKet([n + 1]) + o = Bd(n) + assert o.apply_operator(BKet([n])) == o*BKet([n]) + + +def test_annihilate_b(): + i, j, n, m = symbols('i,j,n,m') + o = B(i) + assert isinstance(o, AnnihilateBoson) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = B(0) + + +def test_wicks(): + p, q, r, s = symbols('p,q,r,s', above_fermi=True) + + # Testing for particles only + + str = F(p)*Fd(q) + assert wicks(str) == NO(F(p)*Fd(q)) + KroneckerDelta(p, q) + str = Fd(p)*F(q) + assert wicks(str) == NO(Fd(p)*F(q)) + + str = F(p)*Fd(q)*F(r)*Fd(s) + nstr = wicks(str) + fasit = NO( + KroneckerDelta(p, q)*KroneckerDelta(r, s) + + KroneckerDelta(p, q)*AnnihilateFermion(r)*CreateFermion(s) + + KroneckerDelta(r, s)*AnnihilateFermion(p)*CreateFermion(q) + - KroneckerDelta(p, s)*AnnihilateFermion(r)*CreateFermion(q) + - AnnihilateFermion(p)*AnnihilateFermion(r)*CreateFermion(q)*CreateFermion(s)) + assert nstr == fasit + + assert (p*q*nstr).expand() == wicks(p*q*str) + assert (nstr*p*q*2).expand() == wicks(str*p*q*2) + + # Testing CC equations particles and holes + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + p, q, r, s = symbols('p q r s', cls=Dummy) + + assert (wicks(F(a)*NO(F(i)*F(j))*Fd(b)) == + NO(F(a)*F(i)*F(j)*Fd(b)) + + KroneckerDelta(a, b)*NO(F(i)*F(j))) + assert (wicks(F(a)*NO(F(i)*F(j)*F(k))*Fd(b)) == + NO(F(a)*F(i)*F(j)*F(k)*Fd(b)) - + KroneckerDelta(a, b)*NO(F(i)*F(j)*F(k))) + + expr = wicks(Fd(i)*NO(Fd(j)*F(k))*F(l)) + assert (expr == + -KroneckerDelta(i, k)*NO(Fd(j)*F(l)) - + KroneckerDelta(j, l)*NO(Fd(i)*F(k)) - + KroneckerDelta(i, k)*KroneckerDelta(j, l) + + KroneckerDelta(i, l)*NO(Fd(j)*F(k)) + + NO(Fd(i)*Fd(j)*F(k)*F(l))) + expr = wicks(F(a)*NO(F(b)*Fd(c))*Fd(d)) + assert (expr == + -KroneckerDelta(a, c)*NO(F(b)*Fd(d)) - + KroneckerDelta(b, d)*NO(F(a)*Fd(c)) - + KroneckerDelta(a, c)*KroneckerDelta(b, d) + + KroneckerDelta(a, d)*NO(F(b)*Fd(c)) + + NO(F(a)*F(b)*Fd(c)*Fd(d))) + + +def test_NO(): + i, j, k, l = symbols('i j k l', below_fermi=True) + a, b, c, d = symbols('a b c d', above_fermi=True) + p, q, r, s = symbols('p q r s', cls=Dummy) + + assert (NO(Fd(p)*F(q) + Fd(a)*F(b)) == + NO(Fd(p)*F(q)) + NO(Fd(a)*F(b))) + assert (NO(Fd(i)*NO(F(j)*Fd(a))) == + NO(Fd(i)*F(j)*Fd(a))) + assert NO(1) == 1 + assert NO(i) == i + assert (NO(Fd(a)*Fd(b)*(F(c) + F(d))) == + NO(Fd(a)*Fd(b)*F(c)) + + NO(Fd(a)*Fd(b)*F(d))) + + assert NO(Fd(a)*F(b))._remove_brackets() == Fd(a)*F(b) + assert NO(F(j)*Fd(i))._remove_brackets() == F(j)*Fd(i) + + assert (NO(Fd(p)*F(q)).subs(Fd(p), Fd(a) + Fd(i)) == + NO(Fd(a)*F(q)) + NO(Fd(i)*F(q))) + assert (NO(Fd(p)*F(q)).subs(F(q), F(a) + F(i)) == + NO(Fd(p)*F(a)) + NO(Fd(p)*F(i))) + + expr = NO(Fd(p)*F(q))._remove_brackets() + assert wicks(expr) == NO(expr) + + assert NO(Fd(a)*F(b)) == - NO(F(b)*Fd(a)) + + no = NO(Fd(a)*F(i)*F(b)*Fd(j)) + l1 = list(no.iter_q_creators()) + assert l1 == [0, 1] + l2 = list(no.iter_q_annihilators()) + assert l2 == [3, 2] + no = NO(Fd(a)*Fd(i)) + assert no.has_q_creators == 1 + assert no.has_q_annihilators == -1 + assert str(no) == ':CreateFermion(a)*CreateFermion(i):' + assert repr(no) == 'NO(CreateFermion(a)*CreateFermion(i))' + assert latex(no) == r'\left\{{a^\dagger_{a}} {a^\dagger_{i}}\right\}' + raises(NotImplementedError, lambda: NO(Bd(p)*F(q))) + + +def test_sorting(): + i, j = symbols('i,j', below_fermi=True) + a, b = symbols('a,b', above_fermi=True) + p, q = symbols('p,q') + + # p, q + assert _sort_anticommuting_fermions([Fd(p), F(q)]) == ([Fd(p), F(q)], 0) + assert _sort_anticommuting_fermions([F(p), Fd(q)]) == ([Fd(q), F(p)], 1) + + # i, p + assert _sort_anticommuting_fermions([F(p), Fd(i)]) == ([F(p), Fd(i)], 0) + assert _sort_anticommuting_fermions([Fd(i), F(p)]) == ([F(p), Fd(i)], 1) + assert _sort_anticommuting_fermions([Fd(p), Fd(i)]) == ([Fd(p), Fd(i)], 0) + assert _sort_anticommuting_fermions([Fd(i), Fd(p)]) == ([Fd(p), Fd(i)], 1) + assert _sort_anticommuting_fermions([F(p), F(i)]) == ([F(i), F(p)], 1) + assert _sort_anticommuting_fermions([F(i), F(p)]) == ([F(i), F(p)], 0) + assert _sort_anticommuting_fermions([Fd(p), F(i)]) == ([F(i), Fd(p)], 1) + assert _sort_anticommuting_fermions([F(i), Fd(p)]) == ([F(i), Fd(p)], 0) + + # a, p + assert _sort_anticommuting_fermions([F(p), Fd(a)]) == ([Fd(a), F(p)], 1) + assert _sort_anticommuting_fermions([Fd(a), F(p)]) == ([Fd(a), F(p)], 0) + assert _sort_anticommuting_fermions([Fd(p), Fd(a)]) == ([Fd(a), Fd(p)], 1) + assert _sort_anticommuting_fermions([Fd(a), Fd(p)]) == ([Fd(a), Fd(p)], 0) + assert _sort_anticommuting_fermions([F(p), F(a)]) == ([F(p), F(a)], 0) + assert _sort_anticommuting_fermions([F(a), F(p)]) == ([F(p), F(a)], 1) + assert _sort_anticommuting_fermions([Fd(p), F(a)]) == ([Fd(p), F(a)], 0) + assert _sort_anticommuting_fermions([F(a), Fd(p)]) == ([Fd(p), F(a)], 1) + + # i, a + assert _sort_anticommuting_fermions([F(i), Fd(j)]) == ([F(i), Fd(j)], 0) + assert _sort_anticommuting_fermions([Fd(j), F(i)]) == ([F(i), Fd(j)], 1) + assert _sort_anticommuting_fermions([Fd(a), Fd(i)]) == ([Fd(a), Fd(i)], 0) + assert _sort_anticommuting_fermions([Fd(i), Fd(a)]) == ([Fd(a), Fd(i)], 1) + assert _sort_anticommuting_fermions([F(a), F(i)]) == ([F(i), F(a)], 1) + assert _sort_anticommuting_fermions([F(i), F(a)]) == ([F(i), F(a)], 0) + + +def test_contraction(): + i, j, k, l = symbols('i,j,k,l', below_fermi=True) + a, b, c, d = symbols('a,b,c,d', above_fermi=True) + p, q, r, s = symbols('p,q,r,s') + assert contraction(Fd(i), F(j)) == KroneckerDelta(i, j) + assert contraction(F(a), Fd(b)) == KroneckerDelta(a, b) + assert contraction(F(a), Fd(i)) == 0 + assert contraction(Fd(a), F(i)) == 0 + assert contraction(F(i), Fd(a)) == 0 + assert contraction(Fd(i), F(a)) == 0 + assert contraction(Fd(i), F(p)) == KroneckerDelta(i, p) + restr = evaluate_deltas(contraction(Fd(p), F(q))) + assert restr.is_only_below_fermi + restr = evaluate_deltas(contraction(F(p), Fd(q))) + assert restr.is_only_above_fermi + raises(ContractionAppliesOnlyToFermions, lambda: contraction(B(a), Fd(b))) + + +def test_evaluate_deltas(): + i, j, k = symbols('i,j,k') + + r = KroneckerDelta(i, j) * KroneckerDelta(j, k) + assert evaluate_deltas(r) == KroneckerDelta(i, k) + + r = KroneckerDelta(i, 0) * KroneckerDelta(j, k) + assert evaluate_deltas(r) == KroneckerDelta(i, 0) * KroneckerDelta(j, k) + + r = KroneckerDelta(1, j) * KroneckerDelta(j, k) + assert evaluate_deltas(r) == KroneckerDelta(1, k) + + r = KroneckerDelta(j, 2) * KroneckerDelta(k, j) + assert evaluate_deltas(r) == KroneckerDelta(2, k) + + r = KroneckerDelta(i, 0) * KroneckerDelta(i, j) * KroneckerDelta(j, 1) + assert evaluate_deltas(r) == 0 + + r = (KroneckerDelta(0, i) * KroneckerDelta(0, j) + * KroneckerDelta(1, j) * KroneckerDelta(1, j)) + assert evaluate_deltas(r) == 0 + + +def test_Tensors(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + p, q, r, s = symbols('p q r s') + + AT = AntiSymmetricTensor + assert AT('t', (a, b), (i, j)) == -AT('t', (b, a), (i, j)) + assert AT('t', (a, b), (i, j)) == AT('t', (b, a), (j, i)) + assert AT('t', (a, b), (i, j)) == -AT('t', (a, b), (j, i)) + assert AT('t', (a, a), (i, j)) == 0 + assert AT('t', (a, b), (i, i)) == 0 + assert AT('t', (a, b, c), (i, j)) == -AT('t', (b, a, c), (i, j)) + assert AT('t', (a, b, c), (i, j, k)) == AT('t', (b, a, c), (i, k, j)) + + tabij = AT('t', (a, b), (i, j)) + assert tabij.has(a) + assert tabij.has(b) + assert tabij.has(i) + assert tabij.has(j) + assert tabij.subs(b, c) == AT('t', (a, c), (i, j)) + assert (2*tabij).subs(i, c) == 2*AT('t', (a, b), (c, j)) + assert tabij.symbol == Symbol('t') + assert latex(tabij) == '{t^{ab}_{ij}}' + assert str(tabij) == 't((_a, _b),(_i, _j))' + + assert AT('t', (a, a), (i, j)).subs(a, b) == AT('t', (b, b), (i, j)) + assert AT('t', (a, i), (a, j)).subs(a, b) == AT('t', (b, i), (b, j)) + + a1, a2, a3, a4 = symbols('alpha1:5') + u_alpha1234 = AntiSymmetricTensor("u", (a1, a2), (a3, a4)) + + assert latex(u_alpha1234) == r'{u^{\alpha_{1}\alpha_{2}}_{\alpha_{3}\alpha_{4}}}' + assert str(u_alpha1234) == 'u((alpha1, alpha2),(alpha3, alpha4))' + + +def test_fully_contracted(): + i, j, k, l = symbols('i j k l', below_fermi=True) + a, b, c, d = symbols('a b c d', above_fermi=True) + p, q, r, s = symbols('p q r s', cls=Dummy) + + Fock = (AntiSymmetricTensor('f', (p,), (q,))* + NO(Fd(p)*F(q))) + V = (AntiSymmetricTensor('v', (p, q), (r, s))* + NO(Fd(p)*Fd(q)*F(s)*F(r)))/4 + + Fai = wicks(NO(Fd(i)*F(a))*Fock, + keep_only_fully_contracted=True, + simplify_kronecker_deltas=True) + assert Fai == AntiSymmetricTensor('f', (a,), (i,)) + Vabij = wicks(NO(Fd(i)*Fd(j)*F(b)*F(a))*V, + keep_only_fully_contracted=True, + simplify_kronecker_deltas=True) + assert Vabij == AntiSymmetricTensor('v', (a, b), (i, j)) + + +def test_substitute_dummies_without_dummies(): + i, j = symbols('i,j') + assert substitute_dummies(att(i, j) + 2) == att(i, j) + 2 + assert substitute_dummies(att(i, j) + 1) == att(i, j) + 1 + + +def test_substitute_dummies_NO_operator(): + i, j = symbols('i j', cls=Dummy) + assert substitute_dummies(att(i, j)*NO(Fd(i)*F(j)) + - att(j, i)*NO(Fd(j)*F(i))) == 0 + + +def test_substitute_dummies_SQ_operator(): + i, j = symbols('i j', cls=Dummy) + assert substitute_dummies(att(i, j)*Fd(i)*F(j) + - att(j, i)*Fd(j)*F(i)) == 0 + + +def test_substitute_dummies_new_indices(): + i, j = symbols('i j', below_fermi=True, cls=Dummy) + a, b = symbols('a b', above_fermi=True, cls=Dummy) + p, q = symbols('p q', cls=Dummy) + f = Function('f') + assert substitute_dummies(f(i, a, p) - f(j, b, q), new_indices=True) == 0 + + +def test_substitute_dummies_substitution_order(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + f = Function('f') + from sympy.utilities.iterables import variations + for permut in variations([i, j, k, l], 4): + assert substitute_dummies(f(*permut) - f(i, j, k, l)) == 0 + + +def test_dummy_order_inner_outer_lines_VT1T1T1(): + ii = symbols('i', below_fermi=True) + aa = symbols('a', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + # Coupled-Cluster T1 terms with V*T1*T1*T1 + # t^{a}_{k} t^{c}_{i} t^{d}_{l} v^{lk}_{dc} + exprs = [ + # permut v and t <=> swapping internal lines, equivalent + # irrespective of symmetries in v + v(k, l, c, d)*t(c, ii)*t(d, l)*t(aa, k), + v(l, k, c, d)*t(c, ii)*t(d, k)*t(aa, l), + v(k, l, d, c)*t(d, ii)*t(c, l)*t(aa, k), + v(l, k, d, c)*t(d, ii)*t(c, k)*t(aa, l), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_dummy_order_inner_outer_lines_VT1T1T1T1(): + ii, jj = symbols('i j', below_fermi=True) + aa, bb = symbols('a b', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + # Coupled-Cluster T2 terms with V*T1*T1*T1*T1 + exprs = [ + # permut t <=> swapping external lines, not equivalent + # except if v has certain symmetries. + v(k, l, c, d)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(k, l, c, d)*t(c, jj)*t(d, ii)*t(aa, k)*t(bb, l), + v(k, l, c, d)*t(c, ii)*t(d, jj)*t(bb, k)*t(aa, l), + v(k, l, c, d)*t(c, jj)*t(d, ii)*t(bb, k)*t(aa, l), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + exprs = [ + # permut v <=> swapping external lines, not equivalent + # except if v has certain symmetries. + # + # Note that in contrast to above, these permutations have identical + # dummy order. That is because the proximity to external indices + # has higher influence on the canonical dummy ordering than the + # position of a dummy on the factors. In fact, the terms here are + # similar in structure as the result of the dummy substitutions above. + v(k, l, c, d)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(l, k, c, d)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(k, l, d, c)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(l, k, d, c)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) == dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + exprs = [ + # permut t and v <=> swapping internal lines, equivalent. + # Canonical dummy order is different, and a consistent + # substitution reveals the equivalence. + v(k, l, c, d)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(k, l, d, c)*t(c, jj)*t(d, ii)*t(aa, k)*t(bb, l), + v(l, k, c, d)*t(c, ii)*t(d, jj)*t(bb, k)*t(aa, l), + v(l, k, d, c)*t(c, jj)*t(d, ii)*t(bb, k)*t(aa, l), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_get_subNO(): + p, q, r = symbols('p,q,r') + assert NO(F(p)*F(q)*F(r)).get_subNO(1) == NO(F(p)*F(r)) + assert NO(F(p)*F(q)*F(r)).get_subNO(0) == NO(F(q)*F(r)) + assert NO(F(p)*F(q)*F(r)).get_subNO(2) == NO(F(p)*F(q)) + + +def test_equivalent_internal_lines_VT1T1(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + exprs = [ # permute v. Different dummy order. Not equivalent. + v(i, j, a, b)*t(a, i)*t(b, j), + v(j, i, a, b)*t(a, i)*t(b, j), + v(i, j, b, a)*t(a, i)*t(b, j), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v. Different dummy order. Equivalent + v(i, j, a, b)*t(a, i)*t(b, j), + v(j, i, b, a)*t(a, i)*t(b, j), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + exprs = [ # permute t. Same dummy order, not equivalent. + v(i, j, a, b)*t(a, i)*t(b, j), + v(i, j, a, b)*t(b, i)*t(a, j), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) == dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v and t. Different dummy order, equivalent + v(i, j, a, b)*t(a, i)*t(b, j), + v(j, i, a, b)*t(a, j)*t(b, i), + v(i, j, b, a)*t(b, i)*t(a, j), + v(j, i, b, a)*t(b, j)*t(a, i), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_equivalent_internal_lines_VT2conjT2(): + # this diagram requires special handling in TCE + i, j, k, l, m, n = symbols('i j k l m n', below_fermi=True, cls=Dummy) + a, b, c, d, e, f = symbols('a b c d e f', above_fermi=True, cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + + from sympy.utilities.iterables import variations + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + # v(abcd)t(abij)t(ijcd) + template = v(p1, p2, p3, p4)*t(p1, p2, i, j)*t(i, j, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + template = v(p1, p2, p3, p4)*t(p1, p2, j, i)*t(j, i, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + + # v(abcd)t(abij)t(jicd) + template = v(p1, p2, p3, p4)*t(p1, p2, i, j)*t(j, i, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + template = v(p1, p2, p3, p4)*t(p1, p2, j, i)*t(i, j, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def test_equivalent_internal_lines_VT2conjT2_ambiguous_order(): + # These diagrams invokes _determine_ambiguous() because the + # dummies can not be ordered unambiguously by the key alone + i, j, k, l, m, n = symbols('i j k l m n', below_fermi=True, cls=Dummy) + a, b, c, d, e, f = symbols('a b c d e f', above_fermi=True, cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + + from sympy.utilities.iterables import variations + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + # v(abcd)t(abij)t(cdij) + template = v(p1, p2, p3, p4)*t(p1, p2, i, j)*t(p3, p4, i, j) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + template = v(p1, p2, p3, p4)*t(p1, p2, j, i)*t(p3, p4, i, j) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def test_equivalent_internal_lines_VT2(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + exprs = [ + # permute v. Same dummy order, not equivalent. + # + # This test show that the dummy order may not be sensitive to all + # index permutations. The following expressions have identical + # structure as the resulting terms from of the dummy substitutions + # in the test above. Here, all expressions have the same dummy + # order, so they cannot be simplified by means of dummy + # substitution. In order to simplify further, it is necessary to + # exploit symmetries in the objects, for instance if t or v is + # antisymmetric. + v(i, j, a, b)*t(a, b, i, j), + v(j, i, a, b)*t(a, b, i, j), + v(i, j, b, a)*t(a, b, i, j), + v(j, i, b, a)*t(a, b, i, j), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) == dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ + # permute t. + v(i, j, a, b)*t(a, b, i, j), + v(i, j, a, b)*t(b, a, i, j), + v(i, j, a, b)*t(a, b, j, i), + v(i, j, a, b)*t(b, a, j, i), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v and t. Relabelling of dummies should be equivalent. + v(i, j, a, b)*t(a, b, i, j), + v(j, i, a, b)*t(a, b, j, i), + v(i, j, b, a)*t(b, a, i, j), + v(j, i, b, a)*t(b, a, j, i), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_internal_external_VT2T2(): + ii, jj = symbols('i j', below_fermi=True) + aa, bb = symbols('a b', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + exprs = [ + v(k, l, c, d)*t(aa, c, ii, k)*t(bb, d, jj, l), + v(l, k, c, d)*t(aa, c, ii, l)*t(bb, d, jj, k), + v(k, l, d, c)*t(aa, d, ii, k)*t(bb, c, jj, l), + v(l, k, d, c)*t(aa, d, ii, l)*t(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + exprs = [ + v(k, l, c, d)*t(aa, c, ii, k)*t(d, bb, jj, l), + v(l, k, c, d)*t(aa, c, ii, l)*t(d, bb, jj, k), + v(k, l, d, c)*t(aa, d, ii, k)*t(c, bb, jj, l), + v(l, k, d, c)*t(aa, d, ii, l)*t(c, bb, jj, k), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + exprs = [ + v(k, l, c, d)*t(c, aa, ii, k)*t(bb, d, jj, l), + v(l, k, c, d)*t(c, aa, ii, l)*t(bb, d, jj, k), + v(k, l, d, c)*t(d, aa, ii, k)*t(bb, c, jj, l), + v(l, k, d, c)*t(d, aa, ii, l)*t(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_internal_external_pqrs(): + ii, jj = symbols('i j') + aa, bb = symbols('a b') + k, l = symbols('k l', cls=Dummy) + c, d = symbols('c d', cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + exprs = [ + v(k, l, c, d)*t(aa, c, ii, k)*t(bb, d, jj, l), + v(l, k, c, d)*t(aa, c, ii, l)*t(bb, d, jj, k), + v(k, l, d, c)*t(aa, d, ii, k)*t(bb, c, jj, l), + v(l, k, d, c)*t(aa, d, ii, l)*t(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_dummy_order_well_defined(): + aa, bb = symbols('a b', above_fermi=True) + k, l, m = symbols('k l m', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + p, q = symbols('p q', cls=Dummy) + + A = Function('A') + B = Function('B') + C = Function('C') + dums = _get_ordered_dummies + + # We go through all key components in the order of increasing priority, + # and consider only fully orderable expressions. Non-orderable expressions + # are tested elsewhere. + + # pos in first factor determines sort order + assert dums(A(k, l)*B(l, k)) == [k, l] + assert dums(A(l, k)*B(l, k)) == [l, k] + assert dums(A(k, l)*B(k, l)) == [k, l] + assert dums(A(l, k)*B(k, l)) == [l, k] + + # factors involving the index + assert dums(A(k, l)*B(l, m)*C(k, m)) == [l, k, m] + assert dums(A(k, l)*B(l, m)*C(m, k)) == [l, k, m] + assert dums(A(l, k)*B(l, m)*C(k, m)) == [l, k, m] + assert dums(A(l, k)*B(l, m)*C(m, k)) == [l, k, m] + assert dums(A(k, l)*B(m, l)*C(k, m)) == [l, k, m] + assert dums(A(k, l)*B(m, l)*C(m, k)) == [l, k, m] + assert dums(A(l, k)*B(m, l)*C(k, m)) == [l, k, m] + assert dums(A(l, k)*B(m, l)*C(m, k)) == [l, k, m] + + # same, but with factor order determined by non-dummies + assert dums(A(k, aa, l)*A(l, bb, m)*A(bb, k, m)) == [l, k, m] + assert dums(A(k, aa, l)*A(l, bb, m)*A(bb, m, k)) == [l, k, m] + assert dums(A(k, aa, l)*A(m, bb, l)*A(bb, k, m)) == [l, k, m] + assert dums(A(k, aa, l)*A(m, bb, l)*A(bb, m, k)) == [l, k, m] + assert dums(A(l, aa, k)*A(l, bb, m)*A(bb, k, m)) == [l, k, m] + assert dums(A(l, aa, k)*A(l, bb, m)*A(bb, m, k)) == [l, k, m] + assert dums(A(l, aa, k)*A(m, bb, l)*A(bb, k, m)) == [l, k, m] + assert dums(A(l, aa, k)*A(m, bb, l)*A(bb, m, k)) == [l, k, m] + + # index range + assert dums(A(p, c, k)*B(p, c, k)) == [k, c, p] + assert dums(A(p, k, c)*B(p, c, k)) == [k, c, p] + assert dums(A(c, k, p)*B(p, c, k)) == [k, c, p] + assert dums(A(c, p, k)*B(p, c, k)) == [k, c, p] + assert dums(A(k, c, p)*B(p, c, k)) == [k, c, p] + assert dums(A(k, p, c)*B(p, c, k)) == [k, c, p] + assert dums(B(p, c, k)*A(p, c, k)) == [k, c, p] + assert dums(B(p, k, c)*A(p, c, k)) == [k, c, p] + assert dums(B(c, k, p)*A(p, c, k)) == [k, c, p] + assert dums(B(c, p, k)*A(p, c, k)) == [k, c, p] + assert dums(B(k, c, p)*A(p, c, k)) == [k, c, p] + assert dums(B(k, p, c)*A(p, c, k)) == [k, c, p] + + +def test_dummy_order_ambiguous(): + aa, bb = symbols('a b', above_fermi=True) + i, j, k, l, m = symbols('i j k l m', below_fermi=True, cls=Dummy) + a, b, c, d, e = symbols('a b c d e', above_fermi=True, cls=Dummy) + p, q = symbols('p q', cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + p5, p6, p7, p8 = symbols('p5 p6 p7 p8', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + h5, h6, h7, h8 = symbols('h5 h6 h7 h8', below_fermi=True, cls=Dummy) + + A = Function('A') + B = Function('B') + + from sympy.utilities.iterables import variations + + # A*A*A*A*B -- ordering of p5 and p4 is used to figure out the rest + template = A(p1, p2)*A(p4, p1)*A(p2, p3)*A(p3, p5)*B(p5, p4) + permutator = variations([a, b, c, d, e], 5) + base = template.subs(zip([p1, p2, p3, p4, p5], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4, p5], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + # A*A*A*A*A -- an arbitrary index is assigned and the rest are figured out + template = A(p1, p2)*A(p4, p1)*A(p2, p3)*A(p3, p5)*A(p5, p4) + permutator = variations([a, b, c, d, e], 5) + base = template.subs(zip([p1, p2, p3, p4, p5], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4, p5], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + # A*A*A -- ordering of p5 and p4 is used to figure out the rest + template = A(p1, p2, p4, p1)*A(p2, p3, p3, p5)*A(p5, p4) + permutator = variations([a, b, c, d, e], 5) + base = template.subs(zip([p1, p2, p3, p4, p5], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4, p5], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def atv(*args): + return AntiSymmetricTensor('v', args[:2], args[2:] ) + + +def att(*args): + if len(args) == 4: + return AntiSymmetricTensor('t', args[:2], args[2:] ) + elif len(args) == 2: + return AntiSymmetricTensor('t', (args[0],), (args[1],)) + + +def test_dummy_order_inner_outer_lines_VT1T1T1_AT(): + ii = symbols('i', below_fermi=True) + aa = symbols('a', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + # Coupled-Cluster T1 terms with V*T1*T1*T1 + # t^{a}_{k} t^{c}_{i} t^{d}_{l} v^{lk}_{dc} + exprs = [ + # permut v and t <=> swapping internal lines, equivalent + # irrespective of symmetries in v + atv(k, l, c, d)*att(c, ii)*att(d, l)*att(aa, k), + atv(l, k, c, d)*att(c, ii)*att(d, k)*att(aa, l), + atv(k, l, d, c)*att(d, ii)*att(c, l)*att(aa, k), + atv(l, k, d, c)*att(d, ii)*att(c, k)*att(aa, l), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_dummy_order_inner_outer_lines_VT1T1T1T1_AT(): + ii, jj = symbols('i j', below_fermi=True) + aa, bb = symbols('a b', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + # Coupled-Cluster T2 terms with V*T1*T1*T1*T1 + # non-equivalent substitutions (change of sign) + exprs = [ + # permut t <=> swapping external lines + atv(k, l, c, d)*att(c, ii)*att(d, jj)*att(aa, k)*att(bb, l), + atv(k, l, c, d)*att(c, jj)*att(d, ii)*att(aa, k)*att(bb, l), + atv(k, l, c, d)*att(c, ii)*att(d, jj)*att(bb, k)*att(aa, l), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == -substitute_dummies(permut) + + # equivalent substitutions + exprs = [ + atv(k, l, c, d)*att(c, ii)*att(d, jj)*att(aa, k)*att(bb, l), + # permut t <=> swapping external lines + atv(k, l, c, d)*att(c, jj)*att(d, ii)*att(bb, k)*att(aa, l), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_equivalent_internal_lines_VT1T1_AT(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + + exprs = [ # permute v. Different dummy order. Not equivalent. + atv(i, j, a, b)*att(a, i)*att(b, j), + atv(j, i, a, b)*att(a, i)*att(b, j), + atv(i, j, b, a)*att(a, i)*att(b, j), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v. Different dummy order. Equivalent + atv(i, j, a, b)*att(a, i)*att(b, j), + atv(j, i, b, a)*att(a, i)*att(b, j), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + exprs = [ # permute t. Same dummy order, not equivalent. + atv(i, j, a, b)*att(a, i)*att(b, j), + atv(i, j, a, b)*att(b, i)*att(a, j), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v and t. Different dummy order, equivalent + atv(i, j, a, b)*att(a, i)*att(b, j), + atv(j, i, a, b)*att(a, j)*att(b, i), + atv(i, j, b, a)*att(b, i)*att(a, j), + atv(j, i, b, a)*att(b, j)*att(a, i), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_equivalent_internal_lines_VT2conjT2_AT(): + # this diagram requires special handling in TCE + i, j, k, l, m, n = symbols('i j k l m n', below_fermi=True, cls=Dummy) + a, b, c, d, e, f = symbols('a b c d e f', above_fermi=True, cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + + from sympy.utilities.iterables import variations + + # atv(abcd)att(abij)att(ijcd) + template = atv(p1, p2, p3, p4)*att(p1, p2, i, j)*att(i, j, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + template = atv(p1, p2, p3, p4)*att(p1, p2, j, i)*att(j, i, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + # atv(abcd)att(abij)att(jicd) + template = atv(p1, p2, p3, p4)*att(p1, p2, i, j)*att(j, i, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + template = atv(p1, p2, p3, p4)*att(p1, p2, j, i)*att(i, j, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def test_equivalent_internal_lines_VT2conjT2_ambiguous_order_AT(): + # These diagrams invokes _determine_ambiguous() because the + # dummies can not be ordered unambiguously by the key alone + i, j, k, l, m, n = symbols('i j k l m n', below_fermi=True, cls=Dummy) + a, b, c, d, e, f = symbols('a b c d e f', above_fermi=True, cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + + from sympy.utilities.iterables import variations + + # atv(abcd)att(abij)att(cdij) + template = atv(p1, p2, p3, p4)*att(p1, p2, i, j)*att(p3, p4, i, j) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + template = atv(p1, p2, p3, p4)*att(p1, p2, j, i)*att(p3, p4, i, j) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def test_equivalent_internal_lines_VT2_AT(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + + exprs = [ + # permute v. Same dummy order, not equivalent. + atv(i, j, a, b)*att(a, b, i, j), + atv(j, i, a, b)*att(a, b, i, j), + atv(i, j, b, a)*att(a, b, i, j), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ + # permute t. + atv(i, j, a, b)*att(a, b, i, j), + atv(i, j, a, b)*att(b, a, i, j), + atv(i, j, a, b)*att(a, b, j, i), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v and t. Relabelling of dummies should be equivalent. + atv(i, j, a, b)*att(a, b, i, j), + atv(j, i, a, b)*att(a, b, j, i), + atv(i, j, b, a)*att(b, a, i, j), + atv(j, i, b, a)*att(b, a, j, i), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_internal_external_VT2T2_AT(): + ii, jj = symbols('i j', below_fermi=True) + aa, bb = symbols('a b', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + exprs = [ + atv(k, l, c, d)*att(aa, c, ii, k)*att(bb, d, jj, l), + atv(l, k, c, d)*att(aa, c, ii, l)*att(bb, d, jj, k), + atv(k, l, d, c)*att(aa, d, ii, k)*att(bb, c, jj, l), + atv(l, k, d, c)*att(aa, d, ii, l)*att(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + exprs = [ + atv(k, l, c, d)*att(aa, c, ii, k)*att(d, bb, jj, l), + atv(l, k, c, d)*att(aa, c, ii, l)*att(d, bb, jj, k), + atv(k, l, d, c)*att(aa, d, ii, k)*att(c, bb, jj, l), + atv(l, k, d, c)*att(aa, d, ii, l)*att(c, bb, jj, k), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + exprs = [ + atv(k, l, c, d)*att(c, aa, ii, k)*att(bb, d, jj, l), + atv(l, k, c, d)*att(c, aa, ii, l)*att(bb, d, jj, k), + atv(k, l, d, c)*att(d, aa, ii, k)*att(bb, c, jj, l), + atv(l, k, d, c)*att(d, aa, ii, l)*att(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_internal_external_pqrs_AT(): + ii, jj = symbols('i j') + aa, bb = symbols('a b') + k, l = symbols('k l', cls=Dummy) + c, d = symbols('c d', cls=Dummy) + + exprs = [ + atv(k, l, c, d)*att(aa, c, ii, k)*att(bb, d, jj, l), + atv(l, k, c, d)*att(aa, c, ii, l)*att(bb, d, jj, k), + atv(k, l, d, c)*att(aa, d, ii, k)*att(bb, c, jj, l), + atv(l, k, d, c)*att(aa, d, ii, l)*att(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_issue_19661(): + a = Symbol('0') + assert latex(Commutator(Bd(a)**2, B(a)) + ) == '- \\left[b_{0},{b^\\dagger_{0}}^{2}\\right]' + + +def test_canonical_ordering_AntiSymmetricTensor(): + v = symbols("v") + + c, d = symbols(('c','d'), above_fermi=True, + cls=Dummy) + k, l = symbols(('k','l'), below_fermi=True, + cls=Dummy) + + # formerly, the left gave either the left or the right + assert AntiSymmetricTensor(v, (k, l), (d, c) + ) == -AntiSymmetricTensor(v, (l, k), (d, c)) diff --git a/phivenv/Lib/site-packages/sympy/physics/tests/test_sho.py b/phivenv/Lib/site-packages/sympy/physics/tests/test_sho.py new file mode 100644 index 0000000000000000000000000000000000000000..7248838b4bb9ad280fd4211bbe208063b65adcf5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/tests/test_sho.py @@ -0,0 +1,21 @@ +from sympy.core import symbols, Rational, Function, diff +from sympy.physics.sho import R_nl, E_nl +from sympy.simplify.simplify import simplify + + +def test_sho_R_nl(): + omega, r = symbols('omega r') + l = symbols('l', integer=True) + u = Function('u') + + # check that it obeys the Schrodinger equation + for n in range(5): + schreq = ( -diff(u(r), r, 2)/2 + ((l*(l + 1))/(2*r**2) + + omega**2*r**2/2 - E_nl(n, l, omega))*u(r) ) + result = schreq.subs(u(r), r*R_nl(n, l, omega/2, r)) + assert simplify(result.doit()) == 0 + + +def test_energy(): + n, l, hw = symbols('n l hw') + assert simplify(E_nl(n, l, hw) - (2*n + l + Rational(3, 2))*hw) == 0 diff --git a/phivenv/Lib/site-packages/sympy/physics/units/__init__.py b/phivenv/Lib/site-packages/sympy/physics/units/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf17c7f3051b03d9c0fc794d9d79885c94cc878e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/__init__.py @@ -0,0 +1,453 @@ +# isort:skip_file +""" +Dimensional analysis and unit systems. + +This module defines dimension/unit systems and physical quantities. It is +based on a group-theoretical construction where dimensions are represented as +vectors (coefficients being the exponents), and units are defined as a dimension +to which we added a scale. + +Quantities are built from a factor and a unit, and are the basic objects that +one will use when doing computations. + +All objects except systems and prefixes can be used in SymPy expressions. +Note that as part of a CAS, various objects do not combine automatically +under operations. + +Details about the implementation can be found in the documentation, and we +will not repeat all the explanations we gave there concerning our approach. +Ideas about future developments can be found on the `Github wiki +`_, and you should consult +this page if you are willing to help. + +Useful functions: + +- ``find_unit``: easily lookup pre-defined units. +- ``convert_to(expr, newunit)``: converts an expression into the same + expression expressed in another unit. + +""" + +from .dimensions import Dimension, DimensionSystem +from .unitsystem import UnitSystem +from .util import convert_to +from .quantities import Quantity + +from .definitions.dimension_definitions import ( + amount_of_substance, acceleration, action, area, + capacitance, charge, conductance, current, energy, + force, frequency, impedance, inductance, length, + luminous_intensity, magnetic_density, + magnetic_flux, mass, momentum, power, pressure, temperature, time, + velocity, voltage, volume +) + +Unit = Quantity + +speed = velocity +luminosity = luminous_intensity +magnetic_flux_density = magnetic_density +amount = amount_of_substance + +from .prefixes import ( + # 10-power based: + yotta, + zetta, + exa, + peta, + tera, + giga, + mega, + kilo, + hecto, + deca, + deci, + centi, + milli, + micro, + nano, + pico, + femto, + atto, + zepto, + yocto, + # 2-power based: + kibi, + mebi, + gibi, + tebi, + pebi, + exbi, +) + +from .definitions import ( + percent, percents, + permille, + rad, radian, radians, + deg, degree, degrees, + sr, steradian, steradians, + mil, angular_mil, angular_mils, + m, meter, meters, + kg, kilogram, kilograms, + s, second, seconds, + A, ampere, amperes, + K, kelvin, kelvins, + mol, mole, moles, + cd, candela, candelas, + g, gram, grams, + mg, milligram, milligrams, + ug, microgram, micrograms, + t, tonne, metric_ton, + newton, newtons, N, + joule, joules, J, + watt, watts, W, + pascal, pascals, Pa, pa, + hertz, hz, Hz, + coulomb, coulombs, C, + volt, volts, v, V, + ohm, ohms, + siemens, S, mho, mhos, + farad, farads, F, + henry, henrys, H, + tesla, teslas, T, + weber, webers, Wb, wb, + optical_power, dioptre, D, + lux, lx, + katal, kat, + gray, Gy, + becquerel, Bq, + km, kilometer, kilometers, + dm, decimeter, decimeters, + cm, centimeter, centimeters, + mm, millimeter, millimeters, + um, micrometer, micrometers, micron, microns, + nm, nanometer, nanometers, + pm, picometer, picometers, + ft, foot, feet, + inch, inches, + yd, yard, yards, + mi, mile, miles, + nmi, nautical_mile, nautical_miles, + angstrom, angstroms, + ha, hectare, + l, L, liter, liters, + dl, dL, deciliter, deciliters, + cl, cL, centiliter, centiliters, + ml, mL, milliliter, milliliters, + ms, millisecond, milliseconds, + us, microsecond, microseconds, + ns, nanosecond, nanoseconds, + ps, picosecond, picoseconds, + minute, minutes, + h, hour, hours, + day, days, + anomalistic_year, anomalistic_years, + sidereal_year, sidereal_years, + tropical_year, tropical_years, + common_year, common_years, + julian_year, julian_years, + draconic_year, draconic_years, + gaussian_year, gaussian_years, + full_moon_cycle, full_moon_cycles, + year, years, + G, gravitational_constant, + c, speed_of_light, + elementary_charge, + hbar, + planck, + eV, electronvolt, electronvolts, + avogadro_number, + avogadro, avogadro_constant, + boltzmann, boltzmann_constant, + stefan, stefan_boltzmann_constant, + R, molar_gas_constant, + faraday_constant, + josephson_constant, + von_klitzing_constant, + Da, dalton, amu, amus, atomic_mass_unit, atomic_mass_constant, + me, electron_rest_mass, + gee, gees, acceleration_due_to_gravity, + u0, magnetic_constant, vacuum_permeability, + e0, electric_constant, vacuum_permittivity, + Z0, vacuum_impedance, + coulomb_constant, electric_force_constant, + atmosphere, atmospheres, atm, + kPa, + bar, bars, + pound, pounds, + psi, + dHg0, + mmHg, torr, + mmu, mmus, milli_mass_unit, + quart, quarts, + ly, lightyear, lightyears, + au, astronomical_unit, astronomical_units, + planck_mass, + planck_time, + planck_temperature, + planck_length, + planck_charge, + planck_area, + planck_volume, + planck_momentum, + planck_energy, + planck_force, + planck_power, + planck_density, + planck_energy_density, + planck_intensity, + planck_angular_frequency, + planck_pressure, + planck_current, + planck_voltage, + planck_impedance, + planck_acceleration, + bit, bits, + byte, + kibibyte, kibibytes, + mebibyte, mebibytes, + gibibyte, gibibytes, + tebibyte, tebibytes, + pebibyte, pebibytes, + exbibyte, exbibytes, +) + +from .systems import ( + mks, mksa, si +) + + +def find_unit(quantity, unit_system="SI"): + """ + Return a list of matching units or dimension names. + + - If ``quantity`` is a string -- units/dimensions containing the string + `quantity`. + - If ``quantity`` is a unit or dimension -- units having matching base + units or dimensions. + + Examples + ======== + + >>> from sympy.physics import units as u + >>> u.find_unit('charge') + ['C', 'coulomb', 'coulombs', 'planck_charge', 'elementary_charge'] + >>> u.find_unit(u.charge) + ['C', 'coulomb', 'coulombs', 'planck_charge', 'elementary_charge'] + >>> u.find_unit("ampere") + ['ampere', 'amperes'] + >>> u.find_unit('angstrom') + ['angstrom', 'angstroms'] + >>> u.find_unit('volt') + ['volt', 'volts', 'electronvolt', 'electronvolts', 'planck_voltage'] + >>> u.find_unit(u.inch**3)[:9] + ['L', 'l', 'cL', 'cl', 'dL', 'dl', 'mL', 'ml', 'liter'] + """ + unit_system = UnitSystem.get_unit_system(unit_system) + + import sympy.physics.units as u + rv = [] + if isinstance(quantity, str): + rv = [i for i in dir(u) if quantity in i and isinstance(getattr(u, i), Quantity)] + dim = getattr(u, quantity) + if isinstance(dim, Dimension): + rv.extend(find_unit(dim)) + else: + for i in sorted(dir(u)): + other = getattr(u, i) + if not isinstance(other, Quantity): + continue + if isinstance(quantity, Quantity): + if quantity.dimension == other.dimension: + rv.append(str(i)) + elif isinstance(quantity, Dimension): + if other.dimension == quantity: + rv.append(str(i)) + elif other.dimension == Dimension(unit_system.get_dimensional_expr(quantity)): + rv.append(str(i)) + return sorted(set(rv), key=lambda x: (len(x), x)) + +# NOTE: the old units module had additional variables: +# 'density', 'illuminance', 'resistance'. +# They were not dimensions, but units (old Unit class). + +__all__ = [ + 'Dimension', 'DimensionSystem', + 'UnitSystem', + 'convert_to', + 'Quantity', + + 'amount_of_substance', 'acceleration', 'action', 'area', + 'capacitance', 'charge', 'conductance', 'current', 'energy', + 'force', 'frequency', 'impedance', 'inductance', 'length', + 'luminous_intensity', 'magnetic_density', + 'magnetic_flux', 'mass', 'momentum', 'power', 'pressure', 'temperature', 'time', + 'velocity', 'voltage', 'volume', + + 'Unit', + + 'speed', + 'luminosity', + 'magnetic_flux_density', + 'amount', + + 'yotta', + 'zetta', + 'exa', + 'peta', + 'tera', + 'giga', + 'mega', + 'kilo', + 'hecto', + 'deca', + 'deci', + 'centi', + 'milli', + 'micro', + 'nano', + 'pico', + 'femto', + 'atto', + 'zepto', + 'yocto', + + 'kibi', + 'mebi', + 'gibi', + 'tebi', + 'pebi', + 'exbi', + + 'percent', 'percents', + 'permille', + 'rad', 'radian', 'radians', + 'deg', 'degree', 'degrees', + 'sr', 'steradian', 'steradians', + 'mil', 'angular_mil', 'angular_mils', + 'm', 'meter', 'meters', + 'kg', 'kilogram', 'kilograms', + 's', 'second', 'seconds', + 'A', 'ampere', 'amperes', + 'K', 'kelvin', 'kelvins', + 'mol', 'mole', 'moles', + 'cd', 'candela', 'candelas', + 'g', 'gram', 'grams', + 'mg', 'milligram', 'milligrams', + 'ug', 'microgram', 'micrograms', + 't', 'tonne', 'metric_ton', + 'newton', 'newtons', 'N', + 'joule', 'joules', 'J', + 'watt', 'watts', 'W', + 'pascal', 'pascals', 'Pa', 'pa', + 'hertz', 'hz', 'Hz', + 'coulomb', 'coulombs', 'C', + 'volt', 'volts', 'v', 'V', + 'ohm', 'ohms', + 'siemens', 'S', 'mho', 'mhos', + 'farad', 'farads', 'F', + 'henry', 'henrys', 'H', + 'tesla', 'teslas', 'T', + 'weber', 'webers', 'Wb', 'wb', + 'optical_power', 'dioptre', 'D', + 'lux', 'lx', + 'katal', 'kat', + 'gray', 'Gy', + 'becquerel', 'Bq', + 'km', 'kilometer', 'kilometers', + 'dm', 'decimeter', 'decimeters', + 'cm', 'centimeter', 'centimeters', + 'mm', 'millimeter', 'millimeters', + 'um', 'micrometer', 'micrometers', 'micron', 'microns', + 'nm', 'nanometer', 'nanometers', + 'pm', 'picometer', 'picometers', + 'ft', 'foot', 'feet', + 'inch', 'inches', + 'yd', 'yard', 'yards', + 'mi', 'mile', 'miles', + 'nmi', 'nautical_mile', 'nautical_miles', + 'angstrom', 'angstroms', + 'ha', 'hectare', + 'l', 'L', 'liter', 'liters', + 'dl', 'dL', 'deciliter', 'deciliters', + 'cl', 'cL', 'centiliter', 'centiliters', + 'ml', 'mL', 'milliliter', 'milliliters', + 'ms', 'millisecond', 'milliseconds', + 'us', 'microsecond', 'microseconds', + 'ns', 'nanosecond', 'nanoseconds', + 'ps', 'picosecond', 'picoseconds', + 'minute', 'minutes', + 'h', 'hour', 'hours', + 'day', 'days', + 'anomalistic_year', 'anomalistic_years', + 'sidereal_year', 'sidereal_years', + 'tropical_year', 'tropical_years', + 'common_year', 'common_years', + 'julian_year', 'julian_years', + 'draconic_year', 'draconic_years', + 'gaussian_year', 'gaussian_years', + 'full_moon_cycle', 'full_moon_cycles', + 'year', 'years', + 'G', 'gravitational_constant', + 'c', 'speed_of_light', + 'elementary_charge', + 'hbar', + 'planck', + 'eV', 'electronvolt', 'electronvolts', + 'avogadro_number', + 'avogadro', 'avogadro_constant', + 'boltzmann', 'boltzmann_constant', + 'stefan', 'stefan_boltzmann_constant', + 'R', 'molar_gas_constant', + 'faraday_constant', + 'josephson_constant', + 'von_klitzing_constant', + 'Da', 'dalton', 'amu', 'amus', 'atomic_mass_unit', 'atomic_mass_constant', + 'me', 'electron_rest_mass', + 'gee', 'gees', 'acceleration_due_to_gravity', + 'u0', 'magnetic_constant', 'vacuum_permeability', + 'e0', 'electric_constant', 'vacuum_permittivity', + 'Z0', 'vacuum_impedance', + 'coulomb_constant', 'electric_force_constant', + 'atmosphere', 'atmospheres', 'atm', + 'kPa', + 'bar', 'bars', + 'pound', 'pounds', + 'psi', + 'dHg0', + 'mmHg', 'torr', + 'mmu', 'mmus', 'milli_mass_unit', + 'quart', 'quarts', + 'ly', 'lightyear', 'lightyears', + 'au', 'astronomical_unit', 'astronomical_units', + 'planck_mass', + 'planck_time', + 'planck_temperature', + 'planck_length', + 'planck_charge', + 'planck_area', + 'planck_volume', + 'planck_momentum', + 'planck_energy', + 'planck_force', + 'planck_power', + 'planck_density', + 'planck_energy_density', + 'planck_intensity', + 'planck_angular_frequency', + 'planck_pressure', + 'planck_current', + 'planck_voltage', + 'planck_impedance', + 'planck_acceleration', + 'bit', 'bits', + 'byte', + 'kibibyte', 'kibibytes', + 'mebibyte', 'mebibytes', + 'gibibyte', 'gibibytes', + 'tebibyte', 'tebibytes', + 'pebibyte', 'pebibytes', + 'exbibyte', 'exbibytes', + + 'mks', 'mksa', 'si', +] diff --git a/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b06acd555c1adbb02fd75871d2d08d913020037 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/dimensions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/dimensions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96453f0c2184b6babb6988b44ed9cd9a216d35ba Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/dimensions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/prefixes.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/prefixes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9d3e62f015c961b33b733b3cf8ae308c11df20a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/prefixes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/quantities.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/quantities.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c6295813dc90a73716a6ed7de57c0af93ad275f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/quantities.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/unitsystem.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/unitsystem.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81d33b41429abd258bb06eb80d75b3e8908ecb8a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/unitsystem.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/util.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e0d9ca563f4f5bb8dd3665be9557c7a38e4b3c9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/__pycache__/util.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/definitions/__init__.py b/phivenv/Lib/site-packages/sympy/physics/units/definitions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..759889695d38c6e78237cc64974da3ecca6425cd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/definitions/__init__.py @@ -0,0 +1,265 @@ +from .unit_definitions import ( + percent, percents, + permille, + rad, radian, radians, + deg, degree, degrees, + sr, steradian, steradians, + mil, angular_mil, angular_mils, + m, meter, meters, + kg, kilogram, kilograms, + s, second, seconds, + A, ampere, amperes, + K, kelvin, kelvins, + mol, mole, moles, + cd, candela, candelas, + g, gram, grams, + mg, milligram, milligrams, + ug, microgram, micrograms, + t, tonne, metric_ton, + newton, newtons, N, + joule, joules, J, + watt, watts, W, + pascal, pascals, Pa, pa, + hertz, hz, Hz, + coulomb, coulombs, C, + volt, volts, v, V, + ohm, ohms, + siemens, S, mho, mhos, + farad, farads, F, + henry, henrys, H, + tesla, teslas, T, + weber, webers, Wb, wb, + optical_power, dioptre, D, + lux, lx, + katal, kat, + gray, Gy, + becquerel, Bq, + km, kilometer, kilometers, + dm, decimeter, decimeters, + cm, centimeter, centimeters, + mm, millimeter, millimeters, + um, micrometer, micrometers, micron, microns, + nm, nanometer, nanometers, + pm, picometer, picometers, + ft, foot, feet, + inch, inches, + yd, yard, yards, + mi, mile, miles, + nmi, nautical_mile, nautical_miles, + ha, hectare, + l, L, liter, liters, + dl, dL, deciliter, deciliters, + cl, cL, centiliter, centiliters, + ml, mL, milliliter, milliliters, + ms, millisecond, milliseconds, + us, microsecond, microseconds, + ns, nanosecond, nanoseconds, + ps, picosecond, picoseconds, + minute, minutes, + h, hour, hours, + day, days, + anomalistic_year, anomalistic_years, + sidereal_year, sidereal_years, + tropical_year, tropical_years, + common_year, common_years, + julian_year, julian_years, + draconic_year, draconic_years, + gaussian_year, gaussian_years, + full_moon_cycle, full_moon_cycles, + year, years, + G, gravitational_constant, + c, speed_of_light, + elementary_charge, + hbar, + planck, + eV, electronvolt, electronvolts, + avogadro_number, + avogadro, avogadro_constant, + boltzmann, boltzmann_constant, + stefan, stefan_boltzmann_constant, + R, molar_gas_constant, + faraday_constant, + josephson_constant, + von_klitzing_constant, + Da, dalton, amu, amus, atomic_mass_unit, atomic_mass_constant, + me, electron_rest_mass, + gee, gees, acceleration_due_to_gravity, + u0, magnetic_constant, vacuum_permeability, + e0, electric_constant, vacuum_permittivity, + Z0, vacuum_impedance, + coulomb_constant, coulombs_constant, electric_force_constant, + atmosphere, atmospheres, atm, + kPa, kilopascal, + bar, bars, + pound, pounds, + psi, + dHg0, + mmHg, torr, + mmu, mmus, milli_mass_unit, + quart, quarts, + angstrom, angstroms, + ly, lightyear, lightyears, + au, astronomical_unit, astronomical_units, + planck_mass, + planck_time, + planck_temperature, + planck_length, + planck_charge, + planck_area, + planck_volume, + planck_momentum, + planck_energy, + planck_force, + planck_power, + planck_density, + planck_energy_density, + planck_intensity, + planck_angular_frequency, + planck_pressure, + planck_current, + planck_voltage, + planck_impedance, + planck_acceleration, + bit, bits, + byte, + kibibyte, kibibytes, + mebibyte, mebibytes, + gibibyte, gibibytes, + tebibyte, tebibytes, + pebibyte, pebibytes, + exbibyte, exbibytes, + curie, rutherford +) + +__all__ = [ + 'percent', 'percents', + 'permille', + 'rad', 'radian', 'radians', + 'deg', 'degree', 'degrees', + 'sr', 'steradian', 'steradians', + 'mil', 'angular_mil', 'angular_mils', + 'm', 'meter', 'meters', + 'kg', 'kilogram', 'kilograms', + 's', 'second', 'seconds', + 'A', 'ampere', 'amperes', + 'K', 'kelvin', 'kelvins', + 'mol', 'mole', 'moles', + 'cd', 'candela', 'candelas', + 'g', 'gram', 'grams', + 'mg', 'milligram', 'milligrams', + 'ug', 'microgram', 'micrograms', + 't', 'tonne', 'metric_ton', + 'newton', 'newtons', 'N', + 'joule', 'joules', 'J', + 'watt', 'watts', 'W', + 'pascal', 'pascals', 'Pa', 'pa', + 'hertz', 'hz', 'Hz', + 'coulomb', 'coulombs', 'C', + 'volt', 'volts', 'v', 'V', + 'ohm', 'ohms', + 'siemens', 'S', 'mho', 'mhos', + 'farad', 'farads', 'F', + 'henry', 'henrys', 'H', + 'tesla', 'teslas', 'T', + 'weber', 'webers', 'Wb', 'wb', + 'optical_power', 'dioptre', 'D', + 'lux', 'lx', + 'katal', 'kat', + 'gray', 'Gy', + 'becquerel', 'Bq', + 'km', 'kilometer', 'kilometers', + 'dm', 'decimeter', 'decimeters', + 'cm', 'centimeter', 'centimeters', + 'mm', 'millimeter', 'millimeters', + 'um', 'micrometer', 'micrometers', 'micron', 'microns', + 'nm', 'nanometer', 'nanometers', + 'pm', 'picometer', 'picometers', + 'ft', 'foot', 'feet', + 'inch', 'inches', + 'yd', 'yard', 'yards', + 'mi', 'mile', 'miles', + 'nmi', 'nautical_mile', 'nautical_miles', + 'ha', 'hectare', + 'l', 'L', 'liter', 'liters', + 'dl', 'dL', 'deciliter', 'deciliters', + 'cl', 'cL', 'centiliter', 'centiliters', + 'ml', 'mL', 'milliliter', 'milliliters', + 'ms', 'millisecond', 'milliseconds', + 'us', 'microsecond', 'microseconds', + 'ns', 'nanosecond', 'nanoseconds', + 'ps', 'picosecond', 'picoseconds', + 'minute', 'minutes', + 'h', 'hour', 'hours', + 'day', 'days', + 'anomalistic_year', 'anomalistic_years', + 'sidereal_year', 'sidereal_years', + 'tropical_year', 'tropical_years', + 'common_year', 'common_years', + 'julian_year', 'julian_years', + 'draconic_year', 'draconic_years', + 'gaussian_year', 'gaussian_years', + 'full_moon_cycle', 'full_moon_cycles', + 'year', 'years', + 'G', 'gravitational_constant', + 'c', 'speed_of_light', + 'elementary_charge', + 'hbar', + 'planck', + 'eV', 'electronvolt', 'electronvolts', + 'avogadro_number', + 'avogadro', 'avogadro_constant', + 'boltzmann', 'boltzmann_constant', + 'stefan', 'stefan_boltzmann_constant', + 'R', 'molar_gas_constant', + 'faraday_constant', + 'josephson_constant', + 'von_klitzing_constant', + 'Da', 'dalton', 'amu', 'amus', 'atomic_mass_unit', 'atomic_mass_constant', + 'me', 'electron_rest_mass', + 'gee', 'gees', 'acceleration_due_to_gravity', + 'u0', 'magnetic_constant', 'vacuum_permeability', + 'e0', 'electric_constant', 'vacuum_permittivity', + 'Z0', 'vacuum_impedance', + 'coulomb_constant', 'coulombs_constant', 'electric_force_constant', + 'atmosphere', 'atmospheres', 'atm', + 'kPa', 'kilopascal', + 'bar', 'bars', + 'pound', 'pounds', + 'psi', + 'dHg0', + 'mmHg', 'torr', + 'mmu', 'mmus', 'milli_mass_unit', + 'quart', 'quarts', + 'angstrom', 'angstroms', + 'ly', 'lightyear', 'lightyears', + 'au', 'astronomical_unit', 'astronomical_units', + 'planck_mass', + 'planck_time', + 'planck_temperature', + 'planck_length', + 'planck_charge', + 'planck_area', + 'planck_volume', + 'planck_momentum', + 'planck_energy', + 'planck_force', + 'planck_power', + 'planck_density', + 'planck_energy_density', + 'planck_intensity', + 'planck_angular_frequency', + 'planck_pressure', + 'planck_current', + 'planck_voltage', + 'planck_impedance', + 'planck_acceleration', + 'bit', 'bits', + 'byte', + 'kibibyte', 'kibibytes', + 'mebibyte', 'mebibytes', + 'gibibyte', 'gibibytes', + 'tebibyte', 'tebibytes', + 'pebibyte', 'pebibytes', + 'exbibyte', 'exbibytes', + 'curie', 'rutherford', +] diff --git a/phivenv/Lib/site-packages/sympy/physics/units/definitions/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/definitions/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ba3df2bcb53026517efd238f01e2aef2387200c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/definitions/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/definitions/__pycache__/dimension_definitions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/definitions/__pycache__/dimension_definitions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d4276df97ca5b90c8d7054cd2686b62f3522bfe Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/definitions/__pycache__/dimension_definitions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/definitions/__pycache__/unit_definitions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/definitions/__pycache__/unit_definitions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a5c976ea0ea731a6cd32c7f815bb629900d8c61 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/definitions/__pycache__/unit_definitions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/definitions/dimension_definitions.py b/phivenv/Lib/site-packages/sympy/physics/units/definitions/dimension_definitions.py new file mode 100644 index 0000000000000000000000000000000000000000..b2b5f1dee01f9107d966a99d1e616c79a070a5b8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/definitions/dimension_definitions.py @@ -0,0 +1,43 @@ +from sympy.physics.units import Dimension + + +angle: Dimension = Dimension(name="angle") + +# base dimensions (MKS) +length = Dimension(name="length", symbol="L") +mass = Dimension(name="mass", symbol="M") +time = Dimension(name="time", symbol="T") + +# base dimensions (MKSA not in MKS) +current: Dimension = Dimension(name='current', symbol='I') + +# other base dimensions: +temperature: Dimension = Dimension("temperature", "T") +amount_of_substance: Dimension = Dimension("amount_of_substance") +luminous_intensity: Dimension = Dimension("luminous_intensity") + +# derived dimensions (MKS) +velocity = Dimension(name="velocity") +acceleration = Dimension(name="acceleration") +momentum = Dimension(name="momentum") +force = Dimension(name="force", symbol="F") +energy = Dimension(name="energy", symbol="E") +power = Dimension(name="power") +pressure = Dimension(name="pressure") +frequency = Dimension(name="frequency", symbol="f") +action = Dimension(name="action", symbol="A") +area = Dimension("area") +volume = Dimension("volume") + +# derived dimensions (MKSA not in MKS) +voltage: Dimension = Dimension(name='voltage', symbol='U') +impedance: Dimension = Dimension(name='impedance', symbol='Z') +conductance: Dimension = Dimension(name='conductance', symbol='G') +capacitance: Dimension = Dimension(name='capacitance') +inductance: Dimension = Dimension(name='inductance') +charge: Dimension = Dimension(name='charge', symbol='Q') +magnetic_density: Dimension = Dimension(name='magnetic_density', symbol='B') +magnetic_flux: Dimension = Dimension(name='magnetic_flux') + +# Dimensions in information theory: +information: Dimension = Dimension(name='information') diff --git a/phivenv/Lib/site-packages/sympy/physics/units/definitions/unit_definitions.py b/phivenv/Lib/site-packages/sympy/physics/units/definitions/unit_definitions.py new file mode 100644 index 0000000000000000000000000000000000000000..c0a89802a444a40172a0dc70094321f07a7e396b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/definitions/unit_definitions.py @@ -0,0 +1,407 @@ +from sympy.physics.units.definitions.dimension_definitions import current, temperature, amount_of_substance, \ + luminous_intensity, angle, charge, voltage, impedance, conductance, capacitance, inductance, magnetic_density, \ + magnetic_flux, information + +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S as S_singleton +from sympy.physics.units.prefixes import kilo, mega, milli, micro, deci, centi, nano, pico, kibi, mebi, gibi, tebi, pebi, exbi +from sympy.physics.units.quantities import PhysicalConstant, Quantity + +One = S_singleton.One + +#### UNITS #### + +# Dimensionless: +percent = percents = Quantity("percent", latex_repr=r"\%") +percent.set_global_relative_scale_factor(Rational(1, 100), One) + +permille = Quantity("permille") +permille.set_global_relative_scale_factor(Rational(1, 1000), One) + + +# Angular units (dimensionless) +rad = radian = radians = Quantity("radian", abbrev="rad") +radian.set_global_dimension(angle) +deg = degree = degrees = Quantity("degree", abbrev="deg", latex_repr=r"^\circ") +degree.set_global_relative_scale_factor(pi/180, radian) +sr = steradian = steradians = Quantity("steradian", abbrev="sr") +mil = angular_mil = angular_mils = Quantity("angular_mil", abbrev="mil") + +# Base units: +m = meter = meters = Quantity("meter", abbrev="m") + +# gram; used to define its prefixed units +g = gram = grams = Quantity("gram", abbrev="g") + +# NOTE: the `kilogram` has scale factor 1000. In SI, kg is a base unit, but +# nonetheless we are trying to be compatible with the `kilo` prefix. In a +# similar manner, people using CGS or gaussian units could argue that the +# `centimeter` rather than `meter` is the fundamental unit for length, but the +# scale factor of `centimeter` will be kept as 1/100 to be compatible with the +# `centi` prefix. The current state of the code assumes SI unit dimensions, in +# the future this module will be modified in order to be unit system-neutral +# (that is, support all kinds of unit systems). +kg = kilogram = kilograms = Quantity("kilogram", abbrev="kg") +kg.set_global_relative_scale_factor(kilo, gram) + +s = second = seconds = Quantity("second", abbrev="s") +A = ampere = amperes = Quantity("ampere", abbrev='A') +ampere.set_global_dimension(current) +K = kelvin = kelvins = Quantity("kelvin", abbrev='K') +kelvin.set_global_dimension(temperature) +mol = mole = moles = Quantity("mole", abbrev="mol") +mole.set_global_dimension(amount_of_substance) +cd = candela = candelas = Quantity("candela", abbrev="cd") +candela.set_global_dimension(luminous_intensity) + +# derived units +newton = newtons = N = Quantity("newton", abbrev="N") + +kilonewton = kilonewtons = kN = Quantity("kilonewton", abbrev="kN") +kilonewton.set_global_relative_scale_factor(kilo, newton) + +meganewton = meganewtons = MN = Quantity("meganewton", abbrev="MN") +meganewton.set_global_relative_scale_factor(mega, newton) + +joule = joules = J = Quantity("joule", abbrev="J") +watt = watts = W = Quantity("watt", abbrev="W") +pascal = pascals = Pa = pa = Quantity("pascal", abbrev="Pa") +hertz = hz = Hz = Quantity("hertz", abbrev="Hz") + +# CGS derived units: +dyne = Quantity("dyne") +dyne.set_global_relative_scale_factor(One/10**5, newton) +erg = Quantity("erg") +erg.set_global_relative_scale_factor(One/10**7, joule) + +# MKSA extension to MKS: derived units +coulomb = coulombs = C = Quantity("coulomb", abbrev='C') +coulomb.set_global_dimension(charge) +volt = volts = v = V = Quantity("volt", abbrev='V') +volt.set_global_dimension(voltage) +ohm = ohms = Quantity("ohm", abbrev='ohm', latex_repr=r"\Omega") +ohm.set_global_dimension(impedance) +siemens = S = mho = mhos = Quantity("siemens", abbrev='S') +siemens.set_global_dimension(conductance) +farad = farads = F = Quantity("farad", abbrev='F') +farad.set_global_dimension(capacitance) +henry = henrys = H = Quantity("henry", abbrev='H') +henry.set_global_dimension(inductance) +tesla = teslas = T = Quantity("tesla", abbrev='T') +tesla.set_global_dimension(magnetic_density) +weber = webers = Wb = wb = Quantity("weber", abbrev='Wb') +weber.set_global_dimension(magnetic_flux) + +# CGS units for electromagnetic quantities: +statampere = Quantity("statampere") +statcoulomb = statC = franklin = Quantity("statcoulomb", abbrev="statC") +statvolt = Quantity("statvolt") +gauss = Quantity("gauss") +maxwell = Quantity("maxwell") +debye = Quantity("debye") +oersted = Quantity("oersted") + +# Other derived units: +optical_power = dioptre = diopter = D = Quantity("dioptre") +lux = lx = Quantity("lux", abbrev="lx") + +# katal is the SI unit of catalytic activity +katal = kat = Quantity("katal", abbrev="kat") + +# gray is the SI unit of absorbed dose +gray = Gy = Quantity("gray") + +# becquerel is the SI unit of radioactivity +becquerel = Bq = Quantity("becquerel", abbrev="Bq") + + +# Common mass units + +mg = milligram = milligrams = Quantity("milligram", abbrev="mg") +mg.set_global_relative_scale_factor(milli, gram) + +ug = microgram = micrograms = Quantity("microgram", abbrev="ug", latex_repr=r"\mu\text{g}") +ug.set_global_relative_scale_factor(micro, gram) + +# Atomic mass constant +Da = dalton = amu = amus = atomic_mass_unit = atomic_mass_constant = PhysicalConstant("atomic_mass_constant") + +t = metric_ton = tonne = Quantity("tonne", abbrev="t") +tonne.set_global_relative_scale_factor(mega, gram) + +# Electron rest mass +me = electron_rest_mass = Quantity("electron_rest_mass", abbrev="me") + + +# Common length units + +km = kilometer = kilometers = Quantity("kilometer", abbrev="km") +km.set_global_relative_scale_factor(kilo, meter) + +dm = decimeter = decimeters = Quantity("decimeter", abbrev="dm") +dm.set_global_relative_scale_factor(deci, meter) + +cm = centimeter = centimeters = Quantity("centimeter", abbrev="cm") +cm.set_global_relative_scale_factor(centi, meter) + +mm = millimeter = millimeters = Quantity("millimeter", abbrev="mm") +mm.set_global_relative_scale_factor(milli, meter) + +um = micrometer = micrometers = micron = microns = \ + Quantity("micrometer", abbrev="um", latex_repr=r'\mu\text{m}') +um.set_global_relative_scale_factor(micro, meter) + +nm = nanometer = nanometers = Quantity("nanometer", abbrev="nm") +nm.set_global_relative_scale_factor(nano, meter) + +pm = picometer = picometers = Quantity("picometer", abbrev="pm") +pm.set_global_relative_scale_factor(pico, meter) + +ft = foot = feet = Quantity("foot", abbrev="ft") +ft.set_global_relative_scale_factor(Rational(3048, 10000), meter) + +inch = inches = Quantity("inch") +inch.set_global_relative_scale_factor(Rational(1, 12), foot) + +yd = yard = yards = Quantity("yard", abbrev="yd") +yd.set_global_relative_scale_factor(3, feet) + +mi = mile = miles = Quantity("mile") +mi.set_global_relative_scale_factor(5280, feet) + +nmi = nautical_mile = nautical_miles = Quantity("nautical_mile") +nmi.set_global_relative_scale_factor(6076, feet) + +angstrom = angstroms = Quantity("angstrom", latex_repr=r'\r{A}') +angstrom.set_global_relative_scale_factor(Rational(1, 10**10), meter) + + +# Common volume and area units + +ha = hectare = Quantity("hectare", abbrev="ha") + +l = L = liter = liters = Quantity("liter", abbrev="l") + +dl = dL = deciliter = deciliters = Quantity("deciliter", abbrev="dl") +dl.set_global_relative_scale_factor(Rational(1, 10), liter) + +cl = cL = centiliter = centiliters = Quantity("centiliter", abbrev="cl") +cl.set_global_relative_scale_factor(Rational(1, 100), liter) + +ml = mL = milliliter = milliliters = Quantity("milliliter", abbrev="ml") +ml.set_global_relative_scale_factor(Rational(1, 1000), liter) + + +# Common time units + +ms = millisecond = milliseconds = Quantity("millisecond", abbrev="ms") +millisecond.set_global_relative_scale_factor(milli, second) + +us = microsecond = microseconds = Quantity("microsecond", abbrev="us", latex_repr=r'\mu\text{s}') +microsecond.set_global_relative_scale_factor(micro, second) + +ns = nanosecond = nanoseconds = Quantity("nanosecond", abbrev="ns") +nanosecond.set_global_relative_scale_factor(nano, second) + +ps = picosecond = picoseconds = Quantity("picosecond", abbrev="ps") +picosecond.set_global_relative_scale_factor(pico, second) + +minute = minutes = Quantity("minute") +minute.set_global_relative_scale_factor(60, second) + +h = hour = hours = Quantity("hour") +hour.set_global_relative_scale_factor(60, minute) + +day = days = Quantity("day") +day.set_global_relative_scale_factor(24, hour) + +anomalistic_year = anomalistic_years = Quantity("anomalistic_year") +anomalistic_year.set_global_relative_scale_factor(365.259636, day) + +sidereal_year = sidereal_years = Quantity("sidereal_year") +sidereal_year.set_global_relative_scale_factor(31558149.540, seconds) + +tropical_year = tropical_years = Quantity("tropical_year") +tropical_year.set_global_relative_scale_factor(365.24219, day) + +common_year = common_years = Quantity("common_year") +common_year.set_global_relative_scale_factor(365, day) + +julian_year = julian_years = Quantity("julian_year") +julian_year.set_global_relative_scale_factor((365 + One/4), day) + +draconic_year = draconic_years = Quantity("draconic_year") +draconic_year.set_global_relative_scale_factor(346.62, day) + +gaussian_year = gaussian_years = Quantity("gaussian_year") +gaussian_year.set_global_relative_scale_factor(365.2568983, day) + +full_moon_cycle = full_moon_cycles = Quantity("full_moon_cycle") +full_moon_cycle.set_global_relative_scale_factor(411.78443029, day) + +year = years = tropical_year + + +#### CONSTANTS #### + +# Newton constant +G = gravitational_constant = PhysicalConstant("gravitational_constant", abbrev="G") + +# speed of light +c = speed_of_light = PhysicalConstant("speed_of_light", abbrev="c") + +# elementary charge +elementary_charge = PhysicalConstant("elementary_charge", abbrev="e") + +# Planck constant +planck = PhysicalConstant("planck", abbrev="h") + +# Reduced Planck constant +hbar = PhysicalConstant("hbar", abbrev="hbar") + +# Electronvolt +eV = electronvolt = electronvolts = PhysicalConstant("electronvolt", abbrev="eV") + +# Avogadro number +avogadro_number = PhysicalConstant("avogadro_number") + +# Avogadro constant +avogadro = avogadro_constant = PhysicalConstant("avogadro_constant") + +# Boltzmann constant +boltzmann = boltzmann_constant = PhysicalConstant("boltzmann_constant") + +# Stefan-Boltzmann constant +stefan = stefan_boltzmann_constant = PhysicalConstant("stefan_boltzmann_constant") + +# Molar gas constant +R = molar_gas_constant = PhysicalConstant("molar_gas_constant", abbrev="R") + +# Faraday constant +faraday_constant = PhysicalConstant("faraday_constant") + +# Josephson constant +josephson_constant = PhysicalConstant("josephson_constant", abbrev="K_j") + +# Von Klitzing constant +von_klitzing_constant = PhysicalConstant("von_klitzing_constant", abbrev="R_k") + +# Acceleration due to gravity (on the Earth surface) +gee = gees = acceleration_due_to_gravity = PhysicalConstant("acceleration_due_to_gravity", abbrev="g") + +# magnetic constant: +u0 = magnetic_constant = vacuum_permeability = PhysicalConstant("magnetic_constant") + +# electric constat: +e0 = electric_constant = vacuum_permittivity = PhysicalConstant("vacuum_permittivity") + +# vacuum impedance: +Z0 = vacuum_impedance = PhysicalConstant("vacuum_impedance", abbrev='Z_0', latex_repr=r'Z_{0}') + +# Coulomb's constant: +coulomb_constant = coulombs_constant = electric_force_constant = \ + PhysicalConstant("coulomb_constant", abbrev="k_e") + + +atmosphere = atmospheres = atm = Quantity("atmosphere", abbrev="atm") + +kPa = kilopascal = Quantity("kilopascal", abbrev="kPa") +kilopascal.set_global_relative_scale_factor(kilo, Pa) + +bar = bars = Quantity("bar", abbrev="bar") + +pound = pounds = Quantity("pound") # exact + +psi = Quantity("psi") + +dHg0 = 13.5951 # approx value at 0 C +mmHg = torr = Quantity("mmHg") + +atmosphere.set_global_relative_scale_factor(101325, pascal) +bar.set_global_relative_scale_factor(100, kPa) +pound.set_global_relative_scale_factor(Rational(45359237, 100000000), kg) + +mmu = mmus = milli_mass_unit = Quantity("milli_mass_unit") + +quart = quarts = Quantity("quart") + + +# Other convenient units and magnitudes + +ly = lightyear = lightyears = Quantity("lightyear", abbrev="ly") + +au = astronomical_unit = astronomical_units = Quantity("astronomical_unit", abbrev="AU") + + +# Fundamental Planck units: +planck_mass = Quantity("planck_mass", abbrev="m_P", latex_repr=r'm_\text{P}') + +planck_time = Quantity("planck_time", abbrev="t_P", latex_repr=r't_\text{P}') + +planck_temperature = Quantity("planck_temperature", abbrev="T_P", + latex_repr=r'T_\text{P}') + +planck_length = Quantity("planck_length", abbrev="l_P", latex_repr=r'l_\text{P}') + +planck_charge = Quantity("planck_charge", abbrev="q_P", latex_repr=r'q_\text{P}') + + +# Derived Planck units: +planck_area = Quantity("planck_area") + +planck_volume = Quantity("planck_volume") + +planck_momentum = Quantity("planck_momentum") + +planck_energy = Quantity("planck_energy", abbrev="E_P", latex_repr=r'E_\text{P}') + +planck_force = Quantity("planck_force", abbrev="F_P", latex_repr=r'F_\text{P}') + +planck_power = Quantity("planck_power", abbrev="P_P", latex_repr=r'P_\text{P}') + +planck_density = Quantity("planck_density", abbrev="rho_P", latex_repr=r'\rho_\text{P}') + +planck_energy_density = Quantity("planck_energy_density", abbrev="rho^E_P") + +planck_intensity = Quantity("planck_intensity", abbrev="I_P", latex_repr=r'I_\text{P}') + +planck_angular_frequency = Quantity("planck_angular_frequency", abbrev="omega_P", + latex_repr=r'\omega_\text{P}') + +planck_pressure = Quantity("planck_pressure", abbrev="p_P", latex_repr=r'p_\text{P}') + +planck_current = Quantity("planck_current", abbrev="I_P", latex_repr=r'I_\text{P}') + +planck_voltage = Quantity("planck_voltage", abbrev="V_P", latex_repr=r'V_\text{P}') + +planck_impedance = Quantity("planck_impedance", abbrev="Z_P", latex_repr=r'Z_\text{P}') + +planck_acceleration = Quantity("planck_acceleration", abbrev="a_P", + latex_repr=r'a_\text{P}') + + +# Information theory units: +bit = bits = Quantity("bit") +bit.set_global_dimension(information) + +byte = bytes = Quantity("byte") + +kibibyte = kibibytes = Quantity("kibibyte") +mebibyte = mebibytes = Quantity("mebibyte") +gibibyte = gibibytes = Quantity("gibibyte") +tebibyte = tebibytes = Quantity("tebibyte") +pebibyte = pebibytes = Quantity("pebibyte") +exbibyte = exbibytes = Quantity("exbibyte") + +byte.set_global_relative_scale_factor(8, bit) +kibibyte.set_global_relative_scale_factor(kibi, byte) +mebibyte.set_global_relative_scale_factor(mebi, byte) +gibibyte.set_global_relative_scale_factor(gibi, byte) +tebibyte.set_global_relative_scale_factor(tebi, byte) +pebibyte.set_global_relative_scale_factor(pebi, byte) +exbibyte.set_global_relative_scale_factor(exbi, byte) + +# Older units for radioactivity +curie = Ci = Quantity("curie", abbrev="Ci") + +rutherford = Rd = Quantity("rutherford", abbrev="Rd") diff --git a/phivenv/Lib/site-packages/sympy/physics/units/dimensions.py b/phivenv/Lib/site-packages/sympy/physics/units/dimensions.py new file mode 100644 index 0000000000000000000000000000000000000000..de42912edca025a6cb53d457fd3e03d8fa30931e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/dimensions.py @@ -0,0 +1,590 @@ +""" +Definition of physical dimensions. + +Unit systems will be constructed on top of these dimensions. + +Most of the examples in the doc use MKS system and are presented from the +computer point of view: from a human point, adding length to time is not legal +in MKS but it is in natural system; for a computer in natural system there is +no time dimension (but a velocity dimension instead) - in the basis - so the +question of adding time to length has no meaning. +""" + +from __future__ import annotations + +import collections +from functools import reduce + +from sympy.core.basic import Basic +from sympy.core.containers import (Dict, Tuple) +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.matrices.dense import Matrix +from sympy.functions.elementary.trigonometric import TrigonometricFunction +from sympy.core.expr import Expr +from sympy.core.power import Pow + + +class _QuantityMapper: + + _quantity_scale_factors_global: dict[Expr, Expr] = {} + _quantity_dimensional_equivalence_map_global: dict[Expr, Expr] = {} + _quantity_dimension_global: dict[Expr, Expr] = {} + + def __init__(self, *args, **kwargs): + self._quantity_dimension_map = {} + self._quantity_scale_factors = {} + + def set_quantity_dimension(self, quantity, dimension): + """ + Set the dimension for the quantity in a unit system. + + If this relation is valid in every unit system, use + ``quantity.set_global_dimension(dimension)`` instead. + """ + from sympy.physics.units import Quantity + dimension = sympify(dimension) + if not isinstance(dimension, Dimension): + if dimension == 1: + dimension = Dimension(1) + else: + raise ValueError("expected dimension or 1") + elif isinstance(dimension, Quantity): + dimension = self.get_quantity_dimension(dimension) + self._quantity_dimension_map[quantity] = dimension + + def set_quantity_scale_factor(self, quantity, scale_factor): + """ + Set the scale factor of a quantity relative to another quantity. + + It should be used only once per quantity to just one other quantity, + the algorithm will then be able to compute the scale factors to all + other quantities. + + In case the scale factor is valid in every unit system, please use + ``quantity.set_global_relative_scale_factor(scale_factor)`` instead. + """ + from sympy.physics.units import Quantity + from sympy.physics.units.prefixes import Prefix + scale_factor = sympify(scale_factor) + # replace all prefixes by their ratio to canonical units: + scale_factor = scale_factor.replace( + lambda x: isinstance(x, Prefix), + lambda x: x.scale_factor + ) + # replace all quantities by their ratio to canonical units: + scale_factor = scale_factor.replace( + lambda x: isinstance(x, Quantity), + lambda x: self.get_quantity_scale_factor(x) + ) + self._quantity_scale_factors[quantity] = scale_factor + + def get_quantity_dimension(self, unit): + from sympy.physics.units import Quantity + # First look-up the local dimension map, then the global one: + if unit in self._quantity_dimension_map: + return self._quantity_dimension_map[unit] + if unit in self._quantity_dimension_global: + return self._quantity_dimension_global[unit] + if unit in self._quantity_dimensional_equivalence_map_global: + dep_unit = self._quantity_dimensional_equivalence_map_global[unit] + if isinstance(dep_unit, Quantity): + return self.get_quantity_dimension(dep_unit) + else: + return Dimension(self.get_dimensional_expr(dep_unit)) + if isinstance(unit, Quantity): + return Dimension(unit.name) + else: + return Dimension(1) + + def get_quantity_scale_factor(self, unit): + if unit in self._quantity_scale_factors: + return self._quantity_scale_factors[unit] + if unit in self._quantity_scale_factors_global: + mul_factor, other_unit = self._quantity_scale_factors_global[unit] + return mul_factor*self.get_quantity_scale_factor(other_unit) + return S.One + + +class Dimension(Expr): + """ + This class represent the dimension of a physical quantities. + + The ``Dimension`` constructor takes as parameters a name and an optional + symbol. + + For example, in classical mechanics we know that time is different from + temperature and dimensions make this difference (but they do not provide + any measure of these quantities. + + >>> from sympy.physics.units import Dimension + >>> length = Dimension('length') + >>> length + Dimension(length) + >>> time = Dimension('time') + >>> time + Dimension(time) + + Dimensions can be composed using multiplication, division and + exponentiation (by a number) to give new dimensions. Addition and + subtraction is defined only when the two objects are the same dimension. + + >>> velocity = length / time + >>> velocity + Dimension(length/time) + + It is possible to use a dimension system object to get the dimensionsal + dependencies of a dimension, for example the dimension system used by the + SI units convention can be used: + + >>> from sympy.physics.units.systems.si import dimsys_SI + >>> dimsys_SI.get_dimensional_dependencies(velocity) + {Dimension(length, L): 1, Dimension(time, T): -1} + >>> length + length + Dimension(length) + >>> l2 = length**2 + >>> l2 + Dimension(length**2) + >>> dimsys_SI.get_dimensional_dependencies(l2) + {Dimension(length, L): 2} + + """ + + _op_priority = 13.0 + + # XXX: This doesn't seem to be used anywhere... + _dimensional_dependencies = {} # type: ignore + + is_commutative = True + is_number = False + # make sqrt(M**2) --> M + is_positive = True + is_real = True + + def __new__(cls, name, symbol=None): + + if isinstance(name, str): + name = Symbol(name) + else: + name = sympify(name) + + if not isinstance(name, Expr): + raise TypeError("Dimension name needs to be a valid math expression") + + if isinstance(symbol, str): + symbol = Symbol(symbol) + elif symbol is not None: + assert isinstance(symbol, Symbol) + + obj = Expr.__new__(cls, name) + + obj._name = name + obj._symbol = symbol + return obj + + @property + def name(self): + return self._name + + @property + def symbol(self): + return self._symbol + + def __str__(self): + """ + Display the string representation of the dimension. + """ + if self.symbol is None: + return "Dimension(%s)" % (self.name) + else: + return "Dimension(%s, %s)" % (self.name, self.symbol) + + def __repr__(self): + return self.__str__() + + def __neg__(self): + return self + + def __add__(self, other): + from sympy.physics.units.quantities import Quantity + other = sympify(other) + if isinstance(other, Basic): + if other.has(Quantity): + raise TypeError("cannot sum dimension and quantity") + if isinstance(other, Dimension) and self == other: + return self + return super().__add__(other) + return self + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + # there is no notion of ordering (or magnitude) among dimension, + # subtraction is equivalent to addition when the operation is legal + return self + other + + def __rsub__(self, other): + # there is no notion of ordering (or magnitude) among dimension, + # subtraction is equivalent to addition when the operation is legal + return self + other + + def __pow__(self, other): + return self._eval_power(other) + + def _eval_power(self, other): + other = sympify(other) + return Dimension(self.name**other) + + def __mul__(self, other): + from sympy.physics.units.quantities import Quantity + if isinstance(other, Basic): + if other.has(Quantity): + raise TypeError("cannot sum dimension and quantity") + if isinstance(other, Dimension): + return Dimension(self.name*other.name) + if not other.free_symbols: # other.is_number cannot be used + return self + return super().__mul__(other) + return self + + def __rmul__(self, other): + return self.__mul__(other) + + def __truediv__(self, other): + return self*Pow(other, -1) + + def __rtruediv__(self, other): + return other * pow(self, -1) + + @classmethod + def _from_dimensional_dependencies(cls, dependencies): + return reduce(lambda x, y: x * y, ( + d**e for d, e in dependencies.items() + ), 1) + + def has_integer_powers(self, dim_sys): + """ + Check if the dimension object has only integer powers. + + All the dimension powers should be integers, but rational powers may + appear in intermediate steps. This method may be used to check that the + final result is well-defined. + """ + + return all(dpow.is_Integer for dpow in dim_sys.get_dimensional_dependencies(self).values()) + + +# Create dimensions according to the base units in MKSA. +# For other unit systems, they can be derived by transforming the base +# dimensional dependency dictionary. + + +class DimensionSystem(Basic, _QuantityMapper): + r""" + DimensionSystem represents a coherent set of dimensions. + + The constructor takes three parameters: + + - base dimensions; + - derived dimensions: these are defined in terms of the base dimensions + (for example velocity is defined from the division of length by time); + - dependency of dimensions: how the derived dimensions depend + on the base dimensions. + + Optionally either the ``derived_dims`` or the ``dimensional_dependencies`` + may be omitted. + """ + + def __new__(cls, base_dims, derived_dims=(), dimensional_dependencies={}): + dimensional_dependencies = dict(dimensional_dependencies) + + def parse_dim(dim): + if isinstance(dim, str): + dim = Dimension(Symbol(dim)) + elif isinstance(dim, Dimension): + pass + elif isinstance(dim, Symbol): + dim = Dimension(dim) + else: + raise TypeError("%s wrong type" % dim) + return dim + + base_dims = [parse_dim(i) for i in base_dims] + derived_dims = [parse_dim(i) for i in derived_dims] + + for dim in base_dims: + if (dim in dimensional_dependencies + and (len(dimensional_dependencies[dim]) != 1 or + dimensional_dependencies[dim].get(dim, None) != 1)): + raise IndexError("Repeated value in base dimensions") + dimensional_dependencies[dim] = Dict({dim: 1}) + + def parse_dim_name(dim): + if isinstance(dim, Dimension): + return dim + elif isinstance(dim, str): + return Dimension(Symbol(dim)) + elif isinstance(dim, Symbol): + return Dimension(dim) + else: + raise TypeError("unrecognized type %s for %s" % (type(dim), dim)) + + for dim in dimensional_dependencies.keys(): + dim = parse_dim(dim) + if (dim not in derived_dims) and (dim not in base_dims): + derived_dims.append(dim) + + def parse_dict(d): + return Dict({parse_dim_name(i): j for i, j in d.items()}) + + # Make sure everything is a SymPy type: + dimensional_dependencies = {parse_dim_name(i): parse_dict(j) for i, j in + dimensional_dependencies.items()} + + for dim in derived_dims: + if dim in base_dims: + raise ValueError("Dimension %s both in base and derived" % dim) + if dim not in dimensional_dependencies: + # TODO: should this raise a warning? + dimensional_dependencies[dim] = Dict({dim: 1}) + + base_dims.sort(key=default_sort_key) + derived_dims.sort(key=default_sort_key) + + base_dims = Tuple(*base_dims) + derived_dims = Tuple(*derived_dims) + dimensional_dependencies = Dict({i: Dict(j) for i, j in dimensional_dependencies.items()}) + obj = Basic.__new__(cls, base_dims, derived_dims, dimensional_dependencies) + return obj + + @property + def base_dims(self): + return self.args[0] + + @property + def derived_dims(self): + return self.args[1] + + @property + def dimensional_dependencies(self): + return self.args[2] + + def _get_dimensional_dependencies_for_name(self, dimension): + if isinstance(dimension, str): + dimension = Dimension(Symbol(dimension)) + elif not isinstance(dimension, Dimension): + dimension = Dimension(dimension) + + if dimension.name.is_Symbol: + # Dimensions not included in the dependencies are considered + # as base dimensions: + return dict(self.dimensional_dependencies.get(dimension, {dimension: 1})) + + if dimension.name.is_number or dimension.name.is_NumberSymbol: + return {} + + get_for_name = self._get_dimensional_dependencies_for_name + + if dimension.name.is_Mul: + ret = collections.defaultdict(int) + dicts = [get_for_name(i) for i in dimension.name.args] + for d in dicts: + for k, v in d.items(): + ret[k] += v + return {k: v for (k, v) in ret.items() if v != 0} + + if dimension.name.is_Add: + dicts = [get_for_name(i) for i in dimension.name.args] + if all(d == dicts[0] for d in dicts[1:]): + return dicts[0] + raise TypeError("Only equivalent dimensions can be added or subtracted.") + + if dimension.name.is_Pow: + dim_base = get_for_name(dimension.name.base) + dim_exp = get_for_name(dimension.name.exp) + if dim_exp == {} or dimension.name.exp.is_Symbol: + return {k: v * dimension.name.exp for (k, v) in dim_base.items()} + else: + raise TypeError("The exponent for the power operator must be a Symbol or dimensionless.") + + if dimension.name.is_Function: + args = (Dimension._from_dimensional_dependencies( + get_for_name(arg)) for arg in dimension.name.args) + result = dimension.name.func(*args) + + dicts = [get_for_name(i) for i in dimension.name.args] + + if isinstance(result, Dimension): + return self.get_dimensional_dependencies(result) + elif result.func == dimension.name.func: + if isinstance(dimension.name, TrigonometricFunction): + if dicts[0] in ({}, {Dimension('angle'): 1}): + return {} + else: + raise TypeError("The input argument for the function {} must be dimensionless or have dimensions of angle.".format(dimension.func)) + else: + if all(item == {} for item in dicts): + return {} + else: + raise TypeError("The input arguments for the function {} must be dimensionless.".format(dimension.func)) + else: + return get_for_name(result) + + raise TypeError("Type {} not implemented for get_dimensional_dependencies".format(type(dimension.name))) + + def get_dimensional_dependencies(self, name, mark_dimensionless=False): + dimdep = self._get_dimensional_dependencies_for_name(name) + if mark_dimensionless and dimdep == {}: + return {Dimension(1): 1} + return dict(dimdep.items()) + + def equivalent_dims(self, dim1, dim2): + deps1 = self.get_dimensional_dependencies(dim1) + deps2 = self.get_dimensional_dependencies(dim2) + return deps1 == deps2 + + def extend(self, new_base_dims, new_derived_dims=(), new_dim_deps=None): + deps = dict(self.dimensional_dependencies) + if new_dim_deps: + deps.update(new_dim_deps) + + new_dim_sys = DimensionSystem( + tuple(self.base_dims) + tuple(new_base_dims), + tuple(self.derived_dims) + tuple(new_derived_dims), + deps + ) + new_dim_sys._quantity_dimension_map.update(self._quantity_dimension_map) + new_dim_sys._quantity_scale_factors.update(self._quantity_scale_factors) + return new_dim_sys + + def is_dimensionless(self, dimension): + """ + Check if the dimension object really has a dimension. + + A dimension should have at least one component with non-zero power. + """ + if dimension.name == 1: + return True + return self.get_dimensional_dependencies(dimension) == {} + + @property + def list_can_dims(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + List all canonical dimension names. + """ + dimset = set() + for i in self.base_dims: + dimset.update(set(self.get_dimensional_dependencies(i).keys())) + return tuple(sorted(dimset, key=str)) + + @property + def inv_can_transf_matrix(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Compute the inverse transformation matrix from the base to the + canonical dimension basis. + + It corresponds to the matrix where columns are the vector of base + dimensions in canonical basis. + + This matrix will almost never be used because dimensions are always + defined with respect to the canonical basis, so no work has to be done + to get them in this basis. Nonetheless if this matrix is not square + (or not invertible) it means that we have chosen a bad basis. + """ + matrix = reduce(lambda x, y: x.row_join(y), + [self.dim_can_vector(d) for d in self.base_dims]) + return matrix + + @property + def can_transf_matrix(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Return the canonical transformation matrix from the canonical to the + base dimension basis. + + It is the inverse of the matrix computed with inv_can_transf_matrix(). + """ + + #TODO: the inversion will fail if the system is inconsistent, for + # example if the matrix is not a square + return reduce(lambda x, y: x.row_join(y), + [self.dim_can_vector(d) for d in sorted(self.base_dims, key=str)] + ).inv() + + def dim_can_vector(self, dim): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Dimensional representation in terms of the canonical base dimensions. + """ + + vec = [] + for d in self.list_can_dims: + vec.append(self.get_dimensional_dependencies(dim).get(d, 0)) + return Matrix(vec) + + def dim_vector(self, dim): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + + Vector representation in terms of the base dimensions. + """ + return self.can_transf_matrix * Matrix(self.dim_can_vector(dim)) + + def print_dim_base(self, dim): + """ + Give the string expression of a dimension in term of the basis symbols. + """ + dims = self.dim_vector(dim) + symbols = [i.symbol if i.symbol is not None else i.name for i in self.base_dims] + res = S.One + for (s, p) in zip(symbols, dims): + res *= s**p + return res + + @property + def dim(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Give the dimension of the system. + + That is return the number of dimensions forming the basis. + """ + return len(self.base_dims) + + @property + def is_consistent(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Check if the system is well defined. + """ + + # not enough or too many base dimensions compared to independent + # dimensions + # in vector language: the set of vectors do not form a basis + return self.inv_can_transf_matrix.is_square diff --git a/phivenv/Lib/site-packages/sympy/physics/units/prefixes.py b/phivenv/Lib/site-packages/sympy/physics/units/prefixes.py new file mode 100644 index 0000000000000000000000000000000000000000..44fd7cb9efe4b1d6307810af6b9cd140817126f9 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/prefixes.py @@ -0,0 +1,219 @@ +""" +Module defining unit prefixe class and some constants. + +Constant dict for SI and binary prefixes are defined as PREFIXES and +BIN_PREFIXES. +""" +from sympy.core.expr import Expr +from sympy.core.sympify import sympify +from sympy.core.singleton import S + +class Prefix(Expr): + """ + This class represent prefixes, with their name, symbol and factor. + + Prefixes are used to create derived units from a given unit. They should + always be encapsulated into units. + + The factor is constructed from a base (default is 10) to some power, and + it gives the total multiple or fraction. For example the kilometer km + is constructed from the meter (factor 1) and the kilo (10 to the power 3, + i.e. 1000). The base can be changed to allow e.g. binary prefixes. + + A prefix multiplied by something will always return the product of this + other object times the factor, except if the other object: + + - is a prefix and they can be combined into a new prefix; + - defines multiplication with prefixes (which is the case for the Unit + class). + """ + _op_priority = 13.0 + is_commutative = True + + def __new__(cls, name, abbrev, exponent, base=sympify(10), latex_repr=None): + + name = sympify(name) + abbrev = sympify(abbrev) + exponent = sympify(exponent) + base = sympify(base) + + obj = Expr.__new__(cls, name, abbrev, exponent, base) + obj._name = name + obj._abbrev = abbrev + obj._scale_factor = base**exponent + obj._exponent = exponent + obj._base = base + obj._latex_repr = latex_repr + return obj + + @property + def name(self): + return self._name + + @property + def abbrev(self): + return self._abbrev + + @property + def scale_factor(self): + return self._scale_factor + + def _latex(self, printer): + if self._latex_repr is None: + return r'\text{%s}' % self._abbrev + return self._latex_repr + + @property + def base(self): + return self._base + + def __str__(self): + return str(self._abbrev) + + def __repr__(self): + if self.base == 10: + return "Prefix(%r, %r, %r)" % ( + str(self.name), str(self.abbrev), self._exponent) + else: + return "Prefix(%r, %r, %r, %r)" % ( + str(self.name), str(self.abbrev), self._exponent, self.base) + + def __mul__(self, other): + from sympy.physics.units import Quantity + if not isinstance(other, (Quantity, Prefix)): + return super().__mul__(other) + + fact = self.scale_factor * other.scale_factor + + if isinstance(other, Prefix): + if fact == 1: + return S.One + # simplify prefix + for p in PREFIXES: + if PREFIXES[p].scale_factor == fact: + return PREFIXES[p] + return fact + + return self.scale_factor * other + + def __truediv__(self, other): + if not hasattr(other, "scale_factor"): + return super().__truediv__(other) + + fact = self.scale_factor / other.scale_factor + + if fact == 1: + return S.One + elif isinstance(other, Prefix): + for p in PREFIXES: + if PREFIXES[p].scale_factor == fact: + return PREFIXES[p] + return fact + + return self.scale_factor / other + + def __rtruediv__(self, other): + if other == 1: + for p in PREFIXES: + if PREFIXES[p].scale_factor == 1 / self.scale_factor: + return PREFIXES[p] + return other / self.scale_factor + + +def prefix_unit(unit, prefixes): + """ + Return a list of all units formed by unit and the given prefixes. + + You can use the predefined PREFIXES or BIN_PREFIXES, but you can also + pass as argument a subdict of them if you do not want all prefixed units. + + >>> from sympy.physics.units.prefixes import (PREFIXES, + ... prefix_unit) + >>> from sympy.physics.units import m + >>> pref = {"m": PREFIXES["m"], "c": PREFIXES["c"], "d": PREFIXES["d"]} + >>> prefix_unit(m, pref) # doctest: +SKIP + [millimeter, centimeter, decimeter] + """ + + from sympy.physics.units.quantities import Quantity + from sympy.physics.units import UnitSystem + + prefixed_units = [] + + for prefix in prefixes.values(): + quantity = Quantity( + "%s%s" % (prefix.name, unit.name), + abbrev=("%s%s" % (prefix.abbrev, unit.abbrev)), + is_prefixed=True, + ) + UnitSystem._quantity_dimensional_equivalence_map_global[quantity] = unit + UnitSystem._quantity_scale_factors_global[quantity] = (prefix.scale_factor, unit) + prefixed_units.append(quantity) + + return prefixed_units + + +yotta = Prefix('yotta', 'Y', 24) +zetta = Prefix('zetta', 'Z', 21) +exa = Prefix('exa', 'E', 18) +peta = Prefix('peta', 'P', 15) +tera = Prefix('tera', 'T', 12) +giga = Prefix('giga', 'G', 9) +mega = Prefix('mega', 'M', 6) +kilo = Prefix('kilo', 'k', 3) +hecto = Prefix('hecto', 'h', 2) +deca = Prefix('deca', 'da', 1) +deci = Prefix('deci', 'd', -1) +centi = Prefix('centi', 'c', -2) +milli = Prefix('milli', 'm', -3) +micro = Prefix('micro', 'mu', -6, latex_repr=r"\mu") +nano = Prefix('nano', 'n', -9) +pico = Prefix('pico', 'p', -12) +femto = Prefix('femto', 'f', -15) +atto = Prefix('atto', 'a', -18) +zepto = Prefix('zepto', 'z', -21) +yocto = Prefix('yocto', 'y', -24) + + +# https://physics.nist.gov/cuu/Units/prefixes.html +PREFIXES = { + 'Y': yotta, + 'Z': zetta, + 'E': exa, + 'P': peta, + 'T': tera, + 'G': giga, + 'M': mega, + 'k': kilo, + 'h': hecto, + 'da': deca, + 'd': deci, + 'c': centi, + 'm': milli, + 'mu': micro, + 'n': nano, + 'p': pico, + 'f': femto, + 'a': atto, + 'z': zepto, + 'y': yocto, +} + + +kibi = Prefix('kibi', 'Y', 10, 2) +mebi = Prefix('mebi', 'Y', 20, 2) +gibi = Prefix('gibi', 'Y', 30, 2) +tebi = Prefix('tebi', 'Y', 40, 2) +pebi = Prefix('pebi', 'Y', 50, 2) +exbi = Prefix('exbi', 'Y', 60, 2) + + +# https://physics.nist.gov/cuu/Units/binary.html +BIN_PREFIXES = { + 'Ki': kibi, + 'Mi': mebi, + 'Gi': gibi, + 'Ti': tebi, + 'Pi': pebi, + 'Ei': exbi, +} diff --git a/phivenv/Lib/site-packages/sympy/physics/units/quantities.py b/phivenv/Lib/site-packages/sympy/physics/units/quantities.py new file mode 100644 index 0000000000000000000000000000000000000000..cc19e72aea83b5bd8ae7cf2f63dd49388a3815ee --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/quantities.py @@ -0,0 +1,152 @@ +""" +Physical quantities. +""" + +from sympy.core.expr import AtomicExpr +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.physics.units.dimensions import _QuantityMapper +from sympy.physics.units.prefixes import Prefix + + +class Quantity(AtomicExpr): + """ + Physical quantity: can be a unit of measure, a constant or a generic quantity. + """ + + is_commutative = True + is_real = True + is_number = False + is_nonzero = True + is_physical_constant = False + _diff_wrt = True + + def __new__(cls, name, abbrev=None, + latex_repr=None, pretty_unicode_repr=None, + pretty_ascii_repr=None, mathml_presentation_repr=None, + is_prefixed=False, + **assumptions): + + if not isinstance(name, Symbol): + name = Symbol(name) + + if abbrev is None: + abbrev = name + elif isinstance(abbrev, str): + abbrev = Symbol(abbrev) + + # HACK: These are here purely for type checking. They actually get assigned below. + cls._is_prefixed = is_prefixed + + obj = AtomicExpr.__new__(cls, name, abbrev) + obj._name = name + obj._abbrev = abbrev + obj._latex_repr = latex_repr + obj._unicode_repr = pretty_unicode_repr + obj._ascii_repr = pretty_ascii_repr + obj._mathml_repr = mathml_presentation_repr + obj._is_prefixed = is_prefixed + return obj + + def set_global_dimension(self, dimension): + _QuantityMapper._quantity_dimension_global[self] = dimension + + def set_global_relative_scale_factor(self, scale_factor, reference_quantity): + """ + Setting a scale factor that is valid across all unit system. + """ + from sympy.physics.units import UnitSystem + scale_factor = sympify(scale_factor) + if isinstance(scale_factor, Prefix): + self._is_prefixed = True + # replace all prefixes by their ratio to canonical units: + scale_factor = scale_factor.replace( + lambda x: isinstance(x, Prefix), + lambda x: x.scale_factor + ) + scale_factor = sympify(scale_factor) + UnitSystem._quantity_scale_factors_global[self] = (scale_factor, reference_quantity) + UnitSystem._quantity_dimensional_equivalence_map_global[self] = reference_quantity + + @property + def name(self): + return self._name + + @property + def dimension(self): + from sympy.physics.units import UnitSystem + unit_system = UnitSystem.get_default_unit_system() + return unit_system.get_quantity_dimension(self) + + @property + def abbrev(self): + """ + Symbol representing the unit name. + + Prepend the abbreviation with the prefix symbol if it is defines. + """ + return self._abbrev + + @property + def scale_factor(self): + """ + Overall magnitude of the quantity as compared to the canonical units. + """ + from sympy.physics.units import UnitSystem + unit_system = UnitSystem.get_default_unit_system() + return unit_system.get_quantity_scale_factor(self) + + def _eval_is_positive(self): + return True + + def _eval_is_constant(self): + return True + + def _eval_Abs(self): + return self + + def _eval_subs(self, old, new): + if isinstance(new, Quantity) and self != old: + return self + + def _latex(self, printer): + if self._latex_repr: + return self._latex_repr + else: + return r'\text{{{}}}'.format(self.args[1] \ + if len(self.args) >= 2 else self.args[0]) + + def convert_to(self, other, unit_system="SI"): + """ + Convert the quantity to another quantity of same dimensions. + + Examples + ======== + + >>> from sympy.physics.units import speed_of_light, meter, second + >>> speed_of_light + speed_of_light + >>> speed_of_light.convert_to(meter/second) + 299792458*meter/second + + >>> from sympy.physics.units import liter + >>> liter.convert_to(meter**3) + meter**3/1000 + """ + from .util import convert_to + return convert_to(self, other, unit_system) + + @property + def free_symbols(self): + """Return free symbols from quantity.""" + return set() + + @property + def is_prefixed(self): + """Whether or not the quantity is prefixed. Eg. `kilogram` is prefixed, but `gram` is not.""" + return self._is_prefixed + +class PhysicalConstant(Quantity): + """Represents a physical constant, eg. `speed_of_light` or `avogadro_constant`.""" + + is_physical_constant = True diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/__init__.py b/phivenv/Lib/site-packages/sympy/physics/units/systems/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c4f28d42eec86be8d679227f7b11ed7d48e61f1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/systems/__init__.py @@ -0,0 +1,6 @@ +from sympy.physics.units.systems.mks import MKS +from sympy.physics.units.systems.mksa import MKSA +from sympy.physics.units.systems.natural import natural +from sympy.physics.units.systems.si import SI + +__all__ = ['MKS', 'MKSA', 'natural', 'SI'] diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd739daf8b3b6a665a74788a00b206f5219e41c0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/cgs.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/cgs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4307b27798de41e83c80a250f4d145d12b4ab1df Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/cgs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/length_weight_time.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/length_weight_time.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15c1655dc7ecaffea0fdbe46a0f38f823107ae5e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/length_weight_time.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/mks.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/mks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a7691f3c6a4556ba648b6b493290984901bfde7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/mks.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/mksa.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/mksa.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88ddd15629c0a3f88739d3df9157162c2c0ec54f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/mksa.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/natural.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/natural.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a8ed7749e4754e94bd4e667caa0b5cf65eab7a4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/natural.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/si.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/si.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b16d607cb3264969a1b559484eed267c5b713b9 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/systems/__pycache__/si.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/cgs.py b/phivenv/Lib/site-packages/sympy/physics/units/systems/cgs.py new file mode 100644 index 0000000000000000000000000000000000000000..1f5ee0b5454f1998672e1979ae4eaabe57a8edb4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/systems/cgs.py @@ -0,0 +1,82 @@ +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.units import UnitSystem, centimeter, gram, second, coulomb, charge, speed_of_light, current, mass, \ + length, voltage, magnetic_density, magnetic_flux +from sympy.physics.units.definitions import coulombs_constant +from sympy.physics.units.definitions.unit_definitions import statcoulomb, statampere, statvolt, volt, tesla, gauss, \ + weber, maxwell, debye, oersted, ohm, farad, henry, erg, ampere, coulomb_constant +from sympy.physics.units.systems.mks import dimsys_length_weight_time + +One = S.One + +dimsys_cgs = dimsys_length_weight_time.extend( + [], + new_dim_deps={ + # Dimensional dependencies for derived dimensions + "impedance": {"time": 1, "length": -1}, + "conductance": {"time": -1, "length": 1}, + "capacitance": {"length": 1}, + "inductance": {"time": 2, "length": -1}, + "charge": {"mass": S.Half, "length": S(3)/2, "time": -1}, + "current": {"mass": One/2, "length": 3*One/2, "time": -2}, + "voltage": {"length": -One/2, "mass": One/2, "time": -1}, + "magnetic_density": {"length": -One/2, "mass": One/2, "time": -1}, + "magnetic_flux": {"length": 3*One/2, "mass": One/2, "time": -1}, + } +) + +cgs_gauss = UnitSystem( + base_units=[centimeter, gram, second], + units=[], + name="cgs_gauss", + dimension_system=dimsys_cgs) + + +cgs_gauss.set_quantity_scale_factor(coulombs_constant, 1) + +cgs_gauss.set_quantity_dimension(statcoulomb, charge) +cgs_gauss.set_quantity_scale_factor(statcoulomb, centimeter**(S(3)/2)*gram**(S.Half)/second) + +cgs_gauss.set_quantity_dimension(coulomb, charge) + +cgs_gauss.set_quantity_dimension(statampere, current) +cgs_gauss.set_quantity_scale_factor(statampere, statcoulomb/second) + +cgs_gauss.set_quantity_dimension(statvolt, voltage) +cgs_gauss.set_quantity_scale_factor(statvolt, erg/statcoulomb) + +cgs_gauss.set_quantity_dimension(volt, voltage) + +cgs_gauss.set_quantity_dimension(gauss, magnetic_density) +cgs_gauss.set_quantity_scale_factor(gauss, sqrt(gram/centimeter)/second) + +cgs_gauss.set_quantity_dimension(tesla, magnetic_density) + +cgs_gauss.set_quantity_dimension(maxwell, magnetic_flux) +cgs_gauss.set_quantity_scale_factor(maxwell, sqrt(centimeter**3*gram)/second) + +# SI units expressed in CGS-gaussian units: +cgs_gauss.set_quantity_scale_factor(coulomb, 10*speed_of_light*statcoulomb) +cgs_gauss.set_quantity_scale_factor(ampere, 10*speed_of_light*statcoulomb/second) +cgs_gauss.set_quantity_scale_factor(volt, 10**6/speed_of_light*statvolt) +cgs_gauss.set_quantity_scale_factor(weber, 10**8*maxwell) +cgs_gauss.set_quantity_scale_factor(tesla, 10**4*gauss) +cgs_gauss.set_quantity_scale_factor(debye, One/10**18*statcoulomb*centimeter) +cgs_gauss.set_quantity_scale_factor(oersted, sqrt(gram/centimeter)/second) +cgs_gauss.set_quantity_scale_factor(ohm, 10**5/speed_of_light**2*second/centimeter) +cgs_gauss.set_quantity_scale_factor(farad, One/10**5*speed_of_light**2*centimeter) +cgs_gauss.set_quantity_scale_factor(henry, 10**5/speed_of_light**2/centimeter*second**2) + +# Coulomb's constant: +cgs_gauss.set_quantity_dimension(coulomb_constant, 1) +cgs_gauss.set_quantity_scale_factor(coulomb_constant, 1) + +__all__ = [ + 'ohm', 'tesla', 'maxwell', 'speed_of_light', 'volt', 'second', 'voltage', + 'debye', 'dimsys_length_weight_time', 'centimeter', 'coulomb_constant', + 'farad', 'sqrt', 'UnitSystem', 'current', 'charge', 'weber', 'gram', + 'statcoulomb', 'gauss', 'S', 'statvolt', 'oersted', 'statampere', + 'dimsys_cgs', 'coulomb', 'magnetic_density', 'magnetic_flux', 'One', + 'length', 'erg', 'mass', 'coulombs_constant', 'henry', 'ampere', + 'cgs_gauss', +] diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/length_weight_time.py b/phivenv/Lib/site-packages/sympy/physics/units/systems/length_weight_time.py new file mode 100644 index 0000000000000000000000000000000000000000..dca4ded82afb8ff0e45f197e51c23850ca824737 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/systems/length_weight_time.py @@ -0,0 +1,156 @@ +from sympy.core.singleton import S + +from sympy.core.numbers import pi + +from sympy.physics.units import DimensionSystem, hertz, kilogram +from sympy.physics.units.definitions import ( + G, Hz, J, N, Pa, W, c, g, kg, m, s, meter, gram, second, newton, + joule, watt, pascal) +from sympy.physics.units.definitions.dimension_definitions import ( + acceleration, action, energy, force, frequency, momentum, + power, pressure, velocity, length, mass, time) +from sympy.physics.units.prefixes import PREFIXES, prefix_unit +from sympy.physics.units.prefixes import ( + kibi, mebi, gibi, tebi, pebi, exbi +) +from sympy.physics.units.definitions import ( + cd, K, coulomb, volt, ohm, siemens, farad, henry, tesla, weber, dioptre, + lux, katal, gray, becquerel, inch, hectare, liter, julian_year, + gravitational_constant, speed_of_light, elementary_charge, planck, hbar, + electronvolt, avogadro_number, avogadro_constant, boltzmann_constant, + stefan_boltzmann_constant, atomic_mass_constant, molar_gas_constant, + faraday_constant, josephson_constant, von_klitzing_constant, + acceleration_due_to_gravity, magnetic_constant, vacuum_permittivity, + vacuum_impedance, coulomb_constant, atmosphere, bar, pound, psi, mmHg, + milli_mass_unit, quart, lightyear, astronomical_unit, planck_mass, + planck_time, planck_temperature, planck_length, planck_charge, + planck_area, planck_volume, planck_momentum, planck_energy, planck_force, + planck_power, planck_density, planck_energy_density, planck_intensity, + planck_angular_frequency, planck_pressure, planck_current, planck_voltage, + planck_impedance, planck_acceleration, bit, byte, kibibyte, mebibyte, + gibibyte, tebibyte, pebibyte, exbibyte, curie, rutherford, radian, degree, + steradian, angular_mil, atomic_mass_unit, gee, kPa, ampere, u0, kelvin, + mol, mole, candela, electric_constant, boltzmann, angstrom +) + + +dimsys_length_weight_time = DimensionSystem([ + # Dimensional dependencies for MKS base dimensions + length, + mass, + time, +], dimensional_dependencies={ + # Dimensional dependencies for derived dimensions + "velocity": {"length": 1, "time": -1}, + "acceleration": {"length": 1, "time": -2}, + "momentum": {"mass": 1, "length": 1, "time": -1}, + "force": {"mass": 1, "length": 1, "time": -2}, + "energy": {"mass": 1, "length": 2, "time": -2}, + "power": {"length": 2, "mass": 1, "time": -3}, + "pressure": {"mass": 1, "length": -1, "time": -2}, + "frequency": {"time": -1}, + "action": {"length": 2, "mass": 1, "time": -1}, + "area": {"length": 2}, + "volume": {"length": 3}, +}) + + +One = S.One + + +# Base units: +dimsys_length_weight_time.set_quantity_dimension(meter, length) +dimsys_length_weight_time.set_quantity_scale_factor(meter, One) + +# gram; used to define its prefixed units +dimsys_length_weight_time.set_quantity_dimension(gram, mass) +dimsys_length_weight_time.set_quantity_scale_factor(gram, One) + +dimsys_length_weight_time.set_quantity_dimension(second, time) +dimsys_length_weight_time.set_quantity_scale_factor(second, One) + +# derived units + +dimsys_length_weight_time.set_quantity_dimension(newton, force) +dimsys_length_weight_time.set_quantity_scale_factor(newton, kilogram*meter/second**2) + +dimsys_length_weight_time.set_quantity_dimension(joule, energy) +dimsys_length_weight_time.set_quantity_scale_factor(joule, newton*meter) + +dimsys_length_weight_time.set_quantity_dimension(watt, power) +dimsys_length_weight_time.set_quantity_scale_factor(watt, joule/second) + +dimsys_length_weight_time.set_quantity_dimension(pascal, pressure) +dimsys_length_weight_time.set_quantity_scale_factor(pascal, newton/meter**2) + +dimsys_length_weight_time.set_quantity_dimension(hertz, frequency) +dimsys_length_weight_time.set_quantity_scale_factor(hertz, One) + +# Other derived units: + +dimsys_length_weight_time.set_quantity_dimension(dioptre, 1 / length) +dimsys_length_weight_time.set_quantity_scale_factor(dioptre, 1/meter) + +# Common volume and area units + +dimsys_length_weight_time.set_quantity_dimension(hectare, length**2) +dimsys_length_weight_time.set_quantity_scale_factor(hectare, (meter**2)*(10000)) + +dimsys_length_weight_time.set_quantity_dimension(liter, length**3) +dimsys_length_weight_time.set_quantity_scale_factor(liter, meter**3/1000) + + +# Newton constant +# REF: NIST SP 959 (June 2019) + +dimsys_length_weight_time.set_quantity_dimension(gravitational_constant, length ** 3 * mass ** -1 * time ** -2) +dimsys_length_weight_time.set_quantity_scale_factor(gravitational_constant, 6.67430e-11*m**3/(kg*s**2)) + +# speed of light + +dimsys_length_weight_time.set_quantity_dimension(speed_of_light, velocity) +dimsys_length_weight_time.set_quantity_scale_factor(speed_of_light, 299792458*meter/second) + + +# Planck constant +# REF: NIST SP 959 (June 2019) + +dimsys_length_weight_time.set_quantity_dimension(planck, action) +dimsys_length_weight_time.set_quantity_scale_factor(planck, 6.62607015e-34*joule*second) + +# Reduced Planck constant +# REF: NIST SP 959 (June 2019) + +dimsys_length_weight_time.set_quantity_dimension(hbar, action) +dimsys_length_weight_time.set_quantity_scale_factor(hbar, planck / (2 * pi)) + + +__all__ = [ + 'mmHg', 'atmosphere', 'newton', 'meter', 'vacuum_permittivity', 'pascal', + 'magnetic_constant', 'angular_mil', 'julian_year', 'weber', 'exbibyte', + 'liter', 'molar_gas_constant', 'faraday_constant', 'avogadro_constant', + 'planck_momentum', 'planck_density', 'gee', 'mol', 'bit', 'gray', 'kibi', + 'bar', 'curie', 'prefix_unit', 'PREFIXES', 'planck_time', 'gram', + 'candela', 'force', 'planck_intensity', 'energy', 'becquerel', + 'planck_acceleration', 'speed_of_light', 'dioptre', 'second', 'frequency', + 'Hz', 'power', 'lux', 'planck_current', 'momentum', 'tebibyte', + 'planck_power', 'degree', 'mebi', 'K', 'planck_volume', + 'quart', 'pressure', 'W', 'joule', 'boltzmann_constant', 'c', 'g', + 'planck_force', 'exbi', 's', 'watt', 'action', 'hbar', 'gibibyte', + 'DimensionSystem', 'cd', 'volt', 'planck_charge', 'angstrom', + 'dimsys_length_weight_time', 'pebi', 'vacuum_impedance', 'planck', + 'farad', 'gravitational_constant', 'u0', 'hertz', 'tesla', 'steradian', + 'josephson_constant', 'planck_area', 'stefan_boltzmann_constant', + 'astronomical_unit', 'J', 'N', 'planck_voltage', 'planck_energy', + 'atomic_mass_constant', 'rutherford', 'elementary_charge', 'Pa', + 'planck_mass', 'henry', 'planck_angular_frequency', 'ohm', 'pound', + 'planck_pressure', 'G', 'avogadro_number', 'psi', 'von_klitzing_constant', + 'planck_length', 'radian', 'mole', 'acceleration', + 'planck_energy_density', 'mebibyte', 'length', + 'acceleration_due_to_gravity', 'planck_temperature', 'tebi', 'inch', + 'electronvolt', 'coulomb_constant', 'kelvin', 'kPa', 'boltzmann', + 'milli_mass_unit', 'gibi', 'planck_impedance', 'electric_constant', 'kg', + 'coulomb', 'siemens', 'byte', 'atomic_mass_unit', 'm', 'kibibyte', + 'kilogram', 'lightyear', 'mass', 'time', 'pebibyte', 'velocity', + 'ampere', 'katal', +] diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/mks.py b/phivenv/Lib/site-packages/sympy/physics/units/systems/mks.py new file mode 100644 index 0000000000000000000000000000000000000000..18cc4b1be5e2cbf5773845e48a0cb552fb750fae --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/systems/mks.py @@ -0,0 +1,46 @@ +""" +MKS unit system. + +MKS stands for "meter, kilogram, second". +""" + +from sympy.physics.units import UnitSystem +from sympy.physics.units.definitions import gravitational_constant, hertz, joule, newton, pascal, watt, speed_of_light, gram, kilogram, meter, second +from sympy.physics.units.definitions.dimension_definitions import ( + acceleration, action, energy, force, frequency, momentum, + power, pressure, velocity, length, mass, time) +from sympy.physics.units.prefixes import PREFIXES, prefix_unit +from sympy.physics.units.systems.length_weight_time import dimsys_length_weight_time + +dims = (velocity, acceleration, momentum, force, energy, power, pressure, + frequency, action) + +units = [meter, gram, second, joule, newton, watt, pascal, hertz] +all_units = [] + +# Prefixes of units like gram, joule, newton etc get added using `prefix_unit` +# in the for loop, but the actual units have to be added manually. +all_units.extend([gram, joule, newton, watt, pascal, hertz]) + +for u in units: + all_units.extend(prefix_unit(u, PREFIXES)) +all_units.extend([gravitational_constant, speed_of_light]) + +# unit system +MKS = UnitSystem(base_units=(meter, kilogram, second), units=all_units, name="MKS", dimension_system=dimsys_length_weight_time, derived_units={ + power: watt, + time: second, + pressure: pascal, + length: meter, + frequency: hertz, + mass: kilogram, + force: newton, + energy: joule, + velocity: meter/second, + acceleration: meter/(second**2), +}) + + +__all__ = [ + 'MKS', 'units', 'all_units', 'dims', +] diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/mksa.py b/phivenv/Lib/site-packages/sympy/physics/units/systems/mksa.py new file mode 100644 index 0000000000000000000000000000000000000000..c18c0d6ae3801358d8828e2309d091cb9cb987d8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/systems/mksa.py @@ -0,0 +1,54 @@ +""" +MKS unit system. + +MKS stands for "meter, kilogram, second, ampere". +""" + +from __future__ import annotations + +from sympy.physics.units.definitions import Z0, ampere, coulomb, farad, henry, siemens, tesla, volt, weber, ohm +from sympy.physics.units.definitions.dimension_definitions import ( + capacitance, charge, conductance, current, impedance, inductance, + magnetic_density, magnetic_flux, voltage) +from sympy.physics.units.prefixes import PREFIXES, prefix_unit +from sympy.physics.units.systems.mks import MKS, dimsys_length_weight_time +from sympy.physics.units.quantities import Quantity + +dims = (voltage, impedance, conductance, current, capacitance, inductance, charge, + magnetic_density, magnetic_flux) + +units = [ampere, volt, ohm, siemens, farad, henry, coulomb, tesla, weber] + +all_units: list[Quantity] = [] +for u in units: + all_units.extend(prefix_unit(u, PREFIXES)) +all_units.extend(units) + +all_units.append(Z0) + +dimsys_MKSA = dimsys_length_weight_time.extend([ + # Dimensional dependencies for base dimensions (MKSA not in MKS) + current, +], new_dim_deps={ + # Dimensional dependencies for derived dimensions + "voltage": {"mass": 1, "length": 2, "current": -1, "time": -3}, + "impedance": {"mass": 1, "length": 2, "current": -2, "time": -3}, + "conductance": {"mass": -1, "length": -2, "current": 2, "time": 3}, + "capacitance": {"mass": -1, "length": -2, "current": 2, "time": 4}, + "inductance": {"mass": 1, "length": 2, "current": -2, "time": -2}, + "charge": {"current": 1, "time": 1}, + "magnetic_density": {"mass": 1, "current": -1, "time": -2}, + "magnetic_flux": {"length": 2, "mass": 1, "current": -1, "time": -2}, +}) + +MKSA = MKS.extend(base=(ampere,), units=all_units, name='MKSA', dimension_system=dimsys_MKSA, derived_units={ + magnetic_flux: weber, + impedance: ohm, + current: ampere, + voltage: volt, + inductance: henry, + conductance: siemens, + magnetic_density: tesla, + charge: coulomb, + capacitance: farad, +}) diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/natural.py b/phivenv/Lib/site-packages/sympy/physics/units/systems/natural.py new file mode 100644 index 0000000000000000000000000000000000000000..13eb2c19e982438fab4b1422ddc5a25b16204be8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/systems/natural.py @@ -0,0 +1,27 @@ +""" +Naturalunit system. + +The natural system comes from "setting c = 1, hbar = 1". From the computer +point of view it means that we use velocity and action instead of length and +time. Moreover instead of mass we use energy. +""" + +from sympy.physics.units import DimensionSystem +from sympy.physics.units.definitions import c, eV, hbar +from sympy.physics.units.definitions.dimension_definitions import ( + action, energy, force, frequency, length, mass, momentum, + power, time, velocity) +from sympy.physics.units.prefixes import PREFIXES, prefix_unit +from sympy.physics.units.unitsystem import UnitSystem + + +# dimension system +_natural_dim = DimensionSystem( + base_dims=(action, energy, velocity), + derived_dims=(length, mass, time, momentum, force, power, frequency) +) + +units = prefix_unit(eV, PREFIXES) + +# unit system +natural = UnitSystem(base_units=(hbar, eV, c), units=units, name="Natural system") diff --git a/phivenv/Lib/site-packages/sympy/physics/units/systems/si.py b/phivenv/Lib/site-packages/sympy/physics/units/systems/si.py new file mode 100644 index 0000000000000000000000000000000000000000..2bfa7805871b8663c70b8af7da9ca1dc9b4afab3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/systems/si.py @@ -0,0 +1,377 @@ +""" +SI unit system. +Based on MKSA, which stands for "meter, kilogram, second, ampere". +Added kelvin, candela and mole. + +""" + +from __future__ import annotations + +from sympy.physics.units import DimensionSystem, Dimension, dHg0 + +from sympy.physics.units.quantities import Quantity + +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.units.definitions.dimension_definitions import ( + acceleration, action, current, impedance, length, mass, time, velocity, + amount_of_substance, temperature, information, frequency, force, pressure, + energy, power, charge, voltage, capacitance, conductance, magnetic_flux, + magnetic_density, inductance, luminous_intensity +) +from sympy.physics.units.definitions import ( + kilogram, newton, second, meter, gram, cd, K, joule, watt, pascal, hertz, + coulomb, volt, ohm, siemens, farad, henry, tesla, weber, dioptre, lux, + katal, gray, becquerel, inch, liter, julian_year, gravitational_constant, + speed_of_light, elementary_charge, planck, hbar, electronvolt, + avogadro_number, avogadro_constant, boltzmann_constant, electron_rest_mass, + stefan_boltzmann_constant, Da, atomic_mass_constant, molar_gas_constant, + faraday_constant, josephson_constant, von_klitzing_constant, + acceleration_due_to_gravity, magnetic_constant, vacuum_permittivity, + vacuum_impedance, coulomb_constant, atmosphere, bar, pound, psi, mmHg, + milli_mass_unit, quart, lightyear, astronomical_unit, planck_mass, + planck_time, planck_temperature, planck_length, planck_charge, planck_area, + planck_volume, planck_momentum, planck_energy, planck_force, planck_power, + planck_density, planck_energy_density, planck_intensity, + planck_angular_frequency, planck_pressure, planck_current, planck_voltage, + planck_impedance, planck_acceleration, bit, byte, kibibyte, mebibyte, + gibibyte, tebibyte, pebibyte, exbibyte, curie, rutherford, radian, degree, + steradian, angular_mil, atomic_mass_unit, gee, kPa, ampere, u0, c, kelvin, + mol, mole, candela, m, kg, s, electric_constant, G, boltzmann +) +from sympy.physics.units.prefixes import PREFIXES, prefix_unit +from sympy.physics.units.systems.mksa import MKSA, dimsys_MKSA + +derived_dims = (frequency, force, pressure, energy, power, charge, voltage, + capacitance, conductance, magnetic_flux, + magnetic_density, inductance, luminous_intensity) +base_dims = (amount_of_substance, luminous_intensity, temperature) + +units = [mol, cd, K, lux, hertz, newton, pascal, joule, watt, coulomb, volt, + farad, ohm, siemens, weber, tesla, henry, candela, lux, becquerel, + gray, katal] + +all_units: list[Quantity] = [] +for u in units: + all_units.extend(prefix_unit(u, PREFIXES)) + +all_units.extend(units) +all_units.extend([mol, cd, K, lux]) + + +dimsys_SI = dimsys_MKSA.extend( + [ + # Dimensional dependencies for other base dimensions: + temperature, + amount_of_substance, + luminous_intensity, + ]) + +dimsys_default = dimsys_SI.extend( + [information], +) + +SI = MKSA.extend(base=(mol, cd, K), units=all_units, name='SI', dimension_system=dimsys_SI, derived_units={ + power: watt, + magnetic_flux: weber, + time: second, + impedance: ohm, + pressure: pascal, + current: ampere, + voltage: volt, + length: meter, + frequency: hertz, + inductance: henry, + temperature: kelvin, + amount_of_substance: mole, + luminous_intensity: candela, + conductance: siemens, + mass: kilogram, + magnetic_density: tesla, + charge: coulomb, + force: newton, + capacitance: farad, + energy: joule, + velocity: meter/second, +}) + +One = S.One + +SI.set_quantity_dimension(radian, One) + +SI.set_quantity_scale_factor(ampere, One) + +SI.set_quantity_scale_factor(kelvin, One) + +SI.set_quantity_scale_factor(mole, One) + +SI.set_quantity_scale_factor(candela, One) + +# MKSA extension to MKS: derived units + +SI.set_quantity_scale_factor(coulomb, One) + +SI.set_quantity_scale_factor(volt, joule/coulomb) + +SI.set_quantity_scale_factor(ohm, volt/ampere) + +SI.set_quantity_scale_factor(siemens, ampere/volt) + +SI.set_quantity_scale_factor(farad, coulomb/volt) + +SI.set_quantity_scale_factor(henry, volt*second/ampere) + +SI.set_quantity_scale_factor(tesla, volt*second/meter**2) + +SI.set_quantity_scale_factor(weber, joule/ampere) + + +SI.set_quantity_dimension(lux, luminous_intensity / length ** 2) +SI.set_quantity_scale_factor(lux, steradian*candela/meter**2) + +# katal is the SI unit of catalytic activity + +SI.set_quantity_dimension(katal, amount_of_substance / time) +SI.set_quantity_scale_factor(katal, mol/second) + +# gray is the SI unit of absorbed dose + +SI.set_quantity_dimension(gray, energy / mass) +SI.set_quantity_scale_factor(gray, meter**2/second**2) + +# becquerel is the SI unit of radioactivity + +SI.set_quantity_dimension(becquerel, 1 / time) +SI.set_quantity_scale_factor(becquerel, 1/second) + +#### CONSTANTS #### + +# elementary charge +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(elementary_charge, charge) +SI.set_quantity_scale_factor(elementary_charge, 1.602176634e-19*coulomb) + +# Electronvolt +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(electronvolt, energy) +SI.set_quantity_scale_factor(electronvolt, 1.602176634e-19*joule) + +# Avogadro number +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(avogadro_number, One) +SI.set_quantity_scale_factor(avogadro_number, 6.02214076e23) + +# Avogadro constant + +SI.set_quantity_dimension(avogadro_constant, amount_of_substance ** -1) +SI.set_quantity_scale_factor(avogadro_constant, avogadro_number / mol) + +# Boltzmann constant +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(boltzmann_constant, energy / temperature) +SI.set_quantity_scale_factor(boltzmann_constant, 1.380649e-23*joule/kelvin) + +# Stefan-Boltzmann constant +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(stefan_boltzmann_constant, energy * time ** -1 * length ** -2 * temperature ** -4) +SI.set_quantity_scale_factor(stefan_boltzmann_constant, pi**2 * boltzmann_constant**4 / (60 * hbar**3 * speed_of_light ** 2)) + +# Atomic mass +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(atomic_mass_constant, mass) +SI.set_quantity_scale_factor(atomic_mass_constant, 1.66053906660e-24*gram) + +# Molar gas constant +# REF: NIST SP 959 (June 2019) + +SI.set_quantity_dimension(molar_gas_constant, energy / (temperature * amount_of_substance)) +SI.set_quantity_scale_factor(molar_gas_constant, boltzmann_constant * avogadro_constant) + +# Faraday constant + +SI.set_quantity_dimension(faraday_constant, charge / amount_of_substance) +SI.set_quantity_scale_factor(faraday_constant, elementary_charge * avogadro_constant) + +# Josephson constant + +SI.set_quantity_dimension(josephson_constant, frequency / voltage) +SI.set_quantity_scale_factor(josephson_constant, 0.5 * planck / elementary_charge) + +# Von Klitzing constant + +SI.set_quantity_dimension(von_klitzing_constant, voltage / current) +SI.set_quantity_scale_factor(von_klitzing_constant, hbar / elementary_charge ** 2) + +# Acceleration due to gravity (on the Earth surface) + +SI.set_quantity_dimension(acceleration_due_to_gravity, acceleration) +SI.set_quantity_scale_factor(acceleration_due_to_gravity, 9.80665*meter/second**2) + +# magnetic constant: + +SI.set_quantity_dimension(magnetic_constant, force / current ** 2) +SI.set_quantity_scale_factor(magnetic_constant, 4*pi/10**7 * newton/ampere**2) + +# electric constant: + +SI.set_quantity_dimension(vacuum_permittivity, capacitance / length) +SI.set_quantity_scale_factor(vacuum_permittivity, 1/(u0 * c**2)) + +# vacuum impedance: + +SI.set_quantity_dimension(vacuum_impedance, impedance) +SI.set_quantity_scale_factor(vacuum_impedance, u0 * c) + +# Electron rest mass +SI.set_quantity_dimension(electron_rest_mass, mass) +SI.set_quantity_scale_factor(electron_rest_mass, 9.1093837015e-31*kilogram) + +# Coulomb's constant: +SI.set_quantity_dimension(coulomb_constant, force * length ** 2 / charge ** 2) +SI.set_quantity_scale_factor(coulomb_constant, 1/(4*pi*vacuum_permittivity)) + +SI.set_quantity_dimension(psi, pressure) +SI.set_quantity_scale_factor(psi, pound * gee / inch ** 2) + +SI.set_quantity_dimension(mmHg, pressure) +SI.set_quantity_scale_factor(mmHg, dHg0 * acceleration_due_to_gravity * kilogram / meter**2) + +SI.set_quantity_dimension(milli_mass_unit, mass) +SI.set_quantity_scale_factor(milli_mass_unit, atomic_mass_unit/1000) + +SI.set_quantity_dimension(quart, length ** 3) +SI.set_quantity_scale_factor(quart, Rational(231, 4) * inch**3) + +# Other convenient units and magnitudes + +SI.set_quantity_dimension(lightyear, length) +SI.set_quantity_scale_factor(lightyear, speed_of_light*julian_year) + +SI.set_quantity_dimension(astronomical_unit, length) +SI.set_quantity_scale_factor(astronomical_unit, 149597870691*meter) + +# Fundamental Planck units: + +SI.set_quantity_dimension(planck_mass, mass) +SI.set_quantity_scale_factor(planck_mass, sqrt(hbar*speed_of_light/G)) + +SI.set_quantity_dimension(planck_time, time) +SI.set_quantity_scale_factor(planck_time, sqrt(hbar*G/speed_of_light**5)) + +SI.set_quantity_dimension(planck_temperature, temperature) +SI.set_quantity_scale_factor(planck_temperature, sqrt(hbar*speed_of_light**5/G/boltzmann**2)) + +SI.set_quantity_dimension(planck_length, length) +SI.set_quantity_scale_factor(planck_length, sqrt(hbar*G/speed_of_light**3)) + +SI.set_quantity_dimension(planck_charge, charge) +SI.set_quantity_scale_factor(planck_charge, sqrt(4*pi*electric_constant*hbar*speed_of_light)) + +# Derived Planck units: + +SI.set_quantity_dimension(planck_area, length ** 2) +SI.set_quantity_scale_factor(planck_area, planck_length**2) + +SI.set_quantity_dimension(planck_volume, length ** 3) +SI.set_quantity_scale_factor(planck_volume, planck_length**3) + +SI.set_quantity_dimension(planck_momentum, mass * velocity) +SI.set_quantity_scale_factor(planck_momentum, planck_mass * speed_of_light) + +SI.set_quantity_dimension(planck_energy, energy) +SI.set_quantity_scale_factor(planck_energy, planck_mass * speed_of_light**2) + +SI.set_quantity_dimension(planck_force, force) +SI.set_quantity_scale_factor(planck_force, planck_energy / planck_length) + +SI.set_quantity_dimension(planck_power, power) +SI.set_quantity_scale_factor(planck_power, planck_energy / planck_time) + +SI.set_quantity_dimension(planck_density, mass / length ** 3) +SI.set_quantity_scale_factor(planck_density, planck_mass / planck_length**3) + +SI.set_quantity_dimension(planck_energy_density, energy / length ** 3) +SI.set_quantity_scale_factor(planck_energy_density, planck_energy / planck_length**3) + +SI.set_quantity_dimension(planck_intensity, mass * time ** (-3)) +SI.set_quantity_scale_factor(planck_intensity, planck_energy_density * speed_of_light) + +SI.set_quantity_dimension(planck_angular_frequency, 1 / time) +SI.set_quantity_scale_factor(planck_angular_frequency, 1 / planck_time) + +SI.set_quantity_dimension(planck_pressure, pressure) +SI.set_quantity_scale_factor(planck_pressure, planck_force / planck_length**2) + +SI.set_quantity_dimension(planck_current, current) +SI.set_quantity_scale_factor(planck_current, planck_charge / planck_time) + +SI.set_quantity_dimension(planck_voltage, voltage) +SI.set_quantity_scale_factor(planck_voltage, planck_energy / planck_charge) + +SI.set_quantity_dimension(planck_impedance, impedance) +SI.set_quantity_scale_factor(planck_impedance, planck_voltage / planck_current) + +SI.set_quantity_dimension(planck_acceleration, acceleration) +SI.set_quantity_scale_factor(planck_acceleration, speed_of_light / planck_time) + +# Older units for radioactivity + +SI.set_quantity_dimension(curie, 1 / time) +SI.set_quantity_scale_factor(curie, 37000000000*becquerel) + +SI.set_quantity_dimension(rutherford, 1 / time) +SI.set_quantity_scale_factor(rutherford, 1000000*becquerel) + + +# check that scale factors are the right SI dimensions: +for _scale_factor, _dimension in zip( + SI._quantity_scale_factors.values(), + SI._quantity_dimension_map.values() +): + dimex = SI.get_dimensional_expr(_scale_factor) + if dimex != 1: + # XXX: equivalent_dims is an instance method taking two arguments in + # addition to self so this can not work: + if not DimensionSystem.equivalent_dims(_dimension, Dimension(dimex)): # type: ignore + raise ValueError("quantity value and dimension mismatch") +del _scale_factor, _dimension + +__all__ = [ + 'mmHg', 'atmosphere', 'inductance', 'newton', 'meter', + 'vacuum_permittivity', 'pascal', 'magnetic_constant', 'voltage', + 'angular_mil', 'luminous_intensity', 'all_units', + 'julian_year', 'weber', 'exbibyte', 'liter', + 'molar_gas_constant', 'faraday_constant', 'avogadro_constant', + 'lightyear', 'planck_density', 'gee', 'mol', 'bit', 'gray', + 'planck_momentum', 'bar', 'magnetic_density', 'prefix_unit', 'PREFIXES', + 'planck_time', 'dimex', 'gram', 'candela', 'force', 'planck_intensity', + 'energy', 'becquerel', 'planck_acceleration', 'speed_of_light', + 'conductance', 'frequency', 'coulomb_constant', 'degree', 'lux', 'planck', + 'current', 'planck_current', 'tebibyte', 'planck_power', 'MKSA', 'power', + 'K', 'planck_volume', 'quart', 'pressure', 'amount_of_substance', + 'joule', 'boltzmann_constant', 'Dimension', 'c', 'planck_force', 'length', + 'watt', 'action', 'hbar', 'gibibyte', 'DimensionSystem', 'cd', 'volt', + 'planck_charge', 'dioptre', 'vacuum_impedance', 'dimsys_default', 'farad', + 'charge', 'gravitational_constant', 'temperature', 'u0', 'hertz', + 'capacitance', 'tesla', 'steradian', 'planck_mass', 'josephson_constant', + 'planck_area', 'stefan_boltzmann_constant', 'base_dims', + 'astronomical_unit', 'radian', 'planck_voltage', 'impedance', + 'planck_energy', 'Da', 'atomic_mass_constant', 'rutherford', 'second', 'inch', + 'elementary_charge', 'SI', 'electronvolt', 'dimsys_SI', 'henry', + 'planck_angular_frequency', 'ohm', 'pound', 'planck_pressure', 'G', 'psi', + 'dHg0', 'von_klitzing_constant', 'planck_length', 'avogadro_number', + 'mole', 'acceleration', 'information', 'planck_energy_density', + 'mebibyte', 's', 'acceleration_due_to_gravity', 'electron_rest_mass', + 'planck_temperature', 'units', 'mass', 'dimsys_MKSA', 'kelvin', 'kPa', + 'boltzmann', 'milli_mass_unit', 'planck_impedance', 'electric_constant', + 'derived_dims', 'kg', 'coulomb', 'siemens', 'byte', 'magnetic_flux', + 'atomic_mass_unit', 'm', 'kibibyte', 'kilogram', 'One', 'curie', 'u', + 'time', 'pebibyte', 'velocity', 'ampere', 'katal', +] diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/__init__.py b/phivenv/Lib/site-packages/sympy/physics/units/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d7308c99c3fa3f1b28acf31006e0f52247832b4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_dimensions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_dimensions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce737539314ff37cd7d1a46ea7a30eadae550569 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_dimensions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_dimensionsystem.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_dimensionsystem.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b5935d0815796885f16c95470ca259e052b4151 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_dimensionsystem.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_prefixes.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_prefixes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20082821f425444108bc5d1990f297b804862d16 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_prefixes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_quantities.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_quantities.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82c9662ec4d29d0cc96c3bc507d0a61f7e7710c6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_quantities.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_unit_system_cgs_gauss.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_unit_system_cgs_gauss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40585de9f333f27a1dc6fcdeb15323b9ff0322ee Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_unit_system_cgs_gauss.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_unitsystem.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_unitsystem.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3893e108a7f1080c171c6ece2fc25738a211c28 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_unitsystem.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_util.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a74de4ee6f5e01407fc29007c44920aa99d36b2a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/units/tests/__pycache__/test_util.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/test_dimensions.py b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_dimensions.py new file mode 100644 index 0000000000000000000000000000000000000000..6455df41068a07c966c5f3e782e561fec4d16a97 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_dimensions.py @@ -0,0 +1,150 @@ +from sympy.physics.units.systems.si import dimsys_SI + +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, atan2, cos) +from sympy.physics.units.dimensions import Dimension +from sympy.physics.units.definitions.dimension_definitions import ( + length, time, mass, force, pressure, angle +) +from sympy.physics.units import foot +from sympy.testing.pytest import raises + + +def test_Dimension_definition(): + assert dimsys_SI.get_dimensional_dependencies(length) == {length: 1} + assert length.name == Symbol("length") + assert length.symbol == Symbol("L") + + halflength = sqrt(length) + assert dimsys_SI.get_dimensional_dependencies(halflength) == {length: S.Half} + + +def test_Dimension_error_definition(): + # tuple with more or less than two entries + raises(TypeError, lambda: Dimension(("length", 1, 2))) + raises(TypeError, lambda: Dimension(["length"])) + + # non-number power + raises(TypeError, lambda: Dimension({"length": "a"})) + + # non-number with named argument + raises(TypeError, lambda: Dimension({"length": (1, 2)})) + + # symbol should by Symbol or str + raises(AssertionError, lambda: Dimension("length", symbol=1)) + + +def test_str(): + assert str(Dimension("length")) == "Dimension(length)" + assert str(Dimension("length", "L")) == "Dimension(length, L)" + + +def test_Dimension_properties(): + assert dimsys_SI.is_dimensionless(length) is False + assert dimsys_SI.is_dimensionless(length/length) is True + assert dimsys_SI.is_dimensionless(Dimension("undefined")) is False + + assert length.has_integer_powers(dimsys_SI) is True + assert (length**(-1)).has_integer_powers(dimsys_SI) is True + assert (length**1.5).has_integer_powers(dimsys_SI) is False + + +def test_Dimension_add_sub(): + assert length + length == length + assert length - length == length + assert -length == length + + raises(TypeError, lambda: length + foot) + raises(TypeError, lambda: foot + length) + raises(TypeError, lambda: length - foot) + raises(TypeError, lambda: foot - length) + + # issue 14547 - only raise error for dimensional args; allow + # others to pass + x = Symbol('x') + e = length + x + assert e == x + length and e.is_Add and set(e.args) == {length, x} + e = length + 1 + assert e == 1 + length == 1 - length and e.is_Add and set(e.args) == {length, 1} + + assert dimsys_SI.get_dimensional_dependencies(mass * length / time**2 + force) == \ + {length: 1, mass: 1, time: -2} + assert dimsys_SI.get_dimensional_dependencies(mass * length / time**2 + force - + pressure * length**2) == \ + {length: 1, mass: 1, time: -2} + + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(mass * length / time**2 + pressure)) + +def test_Dimension_mul_div_exp(): + assert 2*length == length*2 == length/2 == length + assert 2/length == 1/length + x = Symbol('x') + m = x*length + assert m == length*x and m.is_Mul and set(m.args) == {x, length} + d = x/length + assert d == x*length**-1 and d.is_Mul and set(d.args) == {x, 1/length} + d = length/x + assert d == length*x**-1 and d.is_Mul and set(d.args) == {1/x, length} + + velo = length / time + + assert (length * length) == length ** 2 + + assert dimsys_SI.get_dimensional_dependencies(length * length) == {length: 2} + assert dimsys_SI.get_dimensional_dependencies(length ** 2) == {length: 2} + assert dimsys_SI.get_dimensional_dependencies(length * time) == {length: 1, time: 1} + assert dimsys_SI.get_dimensional_dependencies(velo) == {length: 1, time: -1} + assert dimsys_SI.get_dimensional_dependencies(velo ** 2) == {length: 2, time: -2} + + assert dimsys_SI.get_dimensional_dependencies(length / length) == {} + assert dimsys_SI.get_dimensional_dependencies(velo / length * time) == {} + assert dimsys_SI.get_dimensional_dependencies(length ** -1) == {length: -1} + assert dimsys_SI.get_dimensional_dependencies(velo ** -1.5) == {length: -1.5, time: 1.5} + + length_a = length**"a" + assert dimsys_SI.get_dimensional_dependencies(length_a) == {length: Symbol("a")} + + assert dimsys_SI.get_dimensional_dependencies(length**pi) == {length: pi} + assert dimsys_SI.get_dimensional_dependencies(length**(length/length)) == {length: Dimension(1)} + + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(length**length)) + + assert length != 1 + assert length / length != 1 + + length_0 = length ** 0 + assert dimsys_SI.get_dimensional_dependencies(length_0) == {} + + # issue 18738 + a = Symbol('a') + b = Symbol('b') + c = sqrt(a**2 + b**2) + c_dim = c.subs({a: length, b: length}) + assert dimsys_SI.equivalent_dims(c_dim, length) + +def test_Dimension_functions(): + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(cos(length))) + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(acos(angle))) + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(atan2(length, time))) + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(log(length))) + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(log(100, length))) + raises(TypeError, lambda: dimsys_SI.get_dimensional_dependencies(log(length, 10))) + + assert dimsys_SI.get_dimensional_dependencies(pi) == {} + + assert dimsys_SI.get_dimensional_dependencies(cos(1)) == {} + assert dimsys_SI.get_dimensional_dependencies(cos(angle)) == {} + + assert dimsys_SI.get_dimensional_dependencies(atan2(length, length)) == {} + + assert dimsys_SI.get_dimensional_dependencies(log(length / length, length / length)) == {} + + assert dimsys_SI.get_dimensional_dependencies(Abs(length)) == {length: 1} + assert dimsys_SI.get_dimensional_dependencies(Abs(length / length)) == {} + + assert dimsys_SI.get_dimensional_dependencies(sqrt(-1)) == {} diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/test_dimensionsystem.py b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_dimensionsystem.py new file mode 100644 index 0000000000000000000000000000000000000000..8a55ac398c38adf24d93bfa376c9cc51c1ec40fe --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_dimensionsystem.py @@ -0,0 +1,95 @@ +from sympy.core.symbol import symbols +from sympy.matrices.dense import (Matrix, eye) +from sympy.physics.units.definitions.dimension_definitions import ( + action, current, length, mass, time, + velocity) +from sympy.physics.units.dimensions import DimensionSystem + + +def test_extend(): + ms = DimensionSystem((length, time), (velocity,)) + + mks = ms.extend((mass,), (action,)) + + res = DimensionSystem((length, time, mass), (velocity, action)) + assert mks.base_dims == res.base_dims + assert mks.derived_dims == res.derived_dims + + +def test_list_dims(): + dimsys = DimensionSystem((length, time, mass)) + + assert dimsys.list_can_dims == (length, mass, time) + + +def test_dim_can_vector(): + dimsys = DimensionSystem( + [length, mass, time], + [velocity, action], + { + velocity: {length: 1, time: -1} + } + ) + + assert dimsys.dim_can_vector(length) == Matrix([1, 0, 0]) + assert dimsys.dim_can_vector(velocity) == Matrix([1, 0, -1]) + + dimsys = DimensionSystem( + (length, velocity, action), + (mass, time), + { + time: {length: 1, velocity: -1} + } + ) + + assert dimsys.dim_can_vector(length) == Matrix([0, 1, 0]) + assert dimsys.dim_can_vector(velocity) == Matrix([0, 0, 1]) + assert dimsys.dim_can_vector(time) == Matrix([0, 1, -1]) + + dimsys = DimensionSystem( + (length, mass, time), + (velocity, action), + {velocity: {length: 1, time: -1}, + action: {mass: 1, length: 2, time: -1}}) + + assert dimsys.dim_vector(length) == Matrix([1, 0, 0]) + assert dimsys.dim_vector(velocity) == Matrix([1, 0, -1]) + + +def test_inv_can_transf_matrix(): + dimsys = DimensionSystem((length, mass, time)) + assert dimsys.inv_can_transf_matrix == eye(3) + + +def test_can_transf_matrix(): + dimsys = DimensionSystem((length, mass, time)) + assert dimsys.can_transf_matrix == eye(3) + + dimsys = DimensionSystem((length, velocity, action)) + assert dimsys.can_transf_matrix == eye(3) + + dimsys = DimensionSystem((length, time), (velocity,), {velocity: {length: 1, time: -1}}) + assert dimsys.can_transf_matrix == eye(2) + + +def test_is_consistent(): + assert DimensionSystem((length, time)).is_consistent is True + + +def test_print_dim_base(): + mksa = DimensionSystem( + (length, time, mass, current), + (action,), + {action: {mass: 1, length: 2, time: -1}}) + L, M, T = symbols("L M T") + assert mksa.print_dim_base(action) == L**2*M/T + + +def test_dim(): + dimsys = DimensionSystem( + (length, mass, time), + (velocity, action), + {velocity: {length: 1, time: -1}, + action: {mass: 1, length: 2, time: -1}} + ) + assert dimsys.dim == 3 diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/test_prefixes.py b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_prefixes.py new file mode 100644 index 0000000000000000000000000000000000000000..7b180102ecd00abf3ff5f8cb4c24aa82ae76ef77 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_prefixes.py @@ -0,0 +1,86 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.physics.units import Quantity, length, meter, W +from sympy.physics.units.prefixes import PREFIXES, Prefix, prefix_unit, kilo, \ + kibi +from sympy.physics.units.systems import SI + +x = Symbol('x') + + +def test_prefix_operations(): + m = PREFIXES['m'] + k = PREFIXES['k'] + M = PREFIXES['M'] + + dodeca = Prefix('dodeca', 'dd', 1, base=12) + + assert m * k is S.One + assert m * W == W / 1000 + assert k * k == M + assert 1 / m == k + assert k / m == M + + assert dodeca * dodeca == 144 + assert 1 / dodeca == S.One / 12 + assert k / dodeca == S(1000) / 12 + assert dodeca / dodeca is S.One + + m = Quantity("fake_meter") + SI.set_quantity_dimension(m, S.One) + SI.set_quantity_scale_factor(m, S.One) + + assert dodeca * m == 12 * m + assert dodeca / m == 12 / m + + expr1 = kilo * 3 + assert isinstance(expr1, Mul) + assert expr1.args == (3, kilo) + + expr2 = kilo * x + assert isinstance(expr2, Mul) + assert expr2.args == (x, kilo) + + expr3 = kilo / 3 + assert isinstance(expr3, Mul) + assert expr3.args == (Rational(1, 3), kilo) + assert expr3.args == (S.One/3, kilo) + + expr4 = kilo / x + assert isinstance(expr4, Mul) + assert expr4.args == (1/x, kilo) + + +def test_prefix_unit(): + m = Quantity("fake_meter", abbrev="m") + m.set_global_relative_scale_factor(1, meter) + + pref = {"m": PREFIXES["m"], "c": PREFIXES["c"], "d": PREFIXES["d"]} + + q1 = Quantity("millifake_meter", abbrev="mm") + q2 = Quantity("centifake_meter", abbrev="cm") + q3 = Quantity("decifake_meter", abbrev="dm") + + SI.set_quantity_dimension(q1, length) + + SI.set_quantity_scale_factor(q1, PREFIXES["m"]) + SI.set_quantity_scale_factor(q1, PREFIXES["c"]) + SI.set_quantity_scale_factor(q1, PREFIXES["d"]) + + res = [q1, q2, q3] + + prefs = prefix_unit(m, pref) + assert set(prefs) == set(res) + assert {v.abbrev for v in prefs} == set(symbols("mm,cm,dm")) + + +def test_bases(): + assert kilo.base == 10 + assert kibi.base == 2 + + +def test_repr(): + assert eval(repr(kilo)) == kilo + assert eval(repr(kibi)) == kibi diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/test_quantities.py b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_quantities.py new file mode 100644 index 0000000000000000000000000000000000000000..4e24ca48cc858bd8afd0b3c9762c4f8b6d0c5194 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_quantities.py @@ -0,0 +1,575 @@ +import warnings + +from sympy.core.add import Add +from sympy.core.function import (Function, diff) +from sympy.core.numbers import (Number, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.integrals.integrals import integrate +from sympy.physics.units import (amount_of_substance, area, convert_to, find_unit, + volume, kilometer, joule, molar_gas_constant, + vacuum_permittivity, elementary_charge, volt, + ohm) +from sympy.physics.units.definitions import (amu, au, centimeter, coulomb, + day, foot, grams, hour, inch, kg, km, m, meter, millimeter, + minute, quart, s, second, speed_of_light, bit, + byte, kibibyte, mebibyte, gibibyte, tebibyte, pebibyte, exbibyte, + kilogram, gravitational_constant, electron_rest_mass) + +from sympy.physics.units.definitions.dimension_definitions import ( + Dimension, charge, length, time, temperature, pressure, + energy, mass +) +from sympy.physics.units.prefixes import PREFIXES, kilo +from sympy.physics.units.quantities import PhysicalConstant, Quantity +from sympy.physics.units.systems import SI +from sympy.testing.pytest import raises + +k = PREFIXES["k"] + + +def test_str_repr(): + assert str(kg) == "kilogram" + + +def test_eq(): + # simple test + assert 10*m == 10*m + assert 10*m != 10*s + + +def test_convert_to(): + q = Quantity("q1") + q.set_global_relative_scale_factor(S(5000), meter) + + assert q.convert_to(m) == 5000*m + + assert speed_of_light.convert_to(m / s) == 299792458 * m / s + assert day.convert_to(s) == 86400*s + + # Wrong dimension to convert: + assert q.convert_to(s) == q + assert speed_of_light.convert_to(m) == speed_of_light + + expr = joule*second + conv = convert_to(expr, joule) + assert conv == joule*second + + +def test_Quantity_definition(): + q = Quantity("s10", abbrev="sabbr") + q.set_global_relative_scale_factor(10, second) + u = Quantity("u", abbrev="dam") + u.set_global_relative_scale_factor(10, meter) + km = Quantity("km") + km.set_global_relative_scale_factor(kilo, meter) + v = Quantity("u") + v.set_global_relative_scale_factor(5*kilo, meter) + + assert q.scale_factor == 10 + assert q.dimension == time + assert q.abbrev == Symbol("sabbr") + + assert u.dimension == length + assert u.scale_factor == 10 + assert u.abbrev == Symbol("dam") + + assert km.scale_factor == 1000 + assert km.func(*km.args) == km + assert km.func(*km.args).args == km.args + + assert v.dimension == length + assert v.scale_factor == 5000 + + +def test_abbrev(): + u = Quantity("u") + u.set_global_relative_scale_factor(S.One, meter) + + assert u.name == Symbol("u") + assert u.abbrev == Symbol("u") + + u = Quantity("u", abbrev="om") + u.set_global_relative_scale_factor(S(2), meter) + + assert u.name == Symbol("u") + assert u.abbrev == Symbol("om") + assert u.scale_factor == 2 + assert isinstance(u.scale_factor, Number) + + u = Quantity("u", abbrev="ikm") + u.set_global_relative_scale_factor(3*kilo, meter) + + assert u.abbrev == Symbol("ikm") + assert u.scale_factor == 3000 + + +def test_print(): + u = Quantity("unitname", abbrev="dam") + assert repr(u) == "unitname" + assert str(u) == "unitname" + + +def test_Quantity_eq(): + u = Quantity("u", abbrev="dam") + v = Quantity("v1") + assert u != v + v = Quantity("v2", abbrev="ds") + assert u != v + v = Quantity("v3", abbrev="dm") + assert u != v + + +def test_add_sub(): + u = Quantity("u") + v = Quantity("v") + w = Quantity("w") + + u.set_global_relative_scale_factor(S(10), meter) + v.set_global_relative_scale_factor(S(5), meter) + w.set_global_relative_scale_factor(S(2), second) + + assert isinstance(u + v, Add) + assert (u + v.convert_to(u)) == (1 + S.Half)*u + assert isinstance(u - v, Add) + assert (u - v.convert_to(u)) == S.Half*u + + +def test_quantity_abs(): + v_w1 = Quantity('v_w1') + v_w2 = Quantity('v_w2') + v_w3 = Quantity('v_w3') + + v_w1.set_global_relative_scale_factor(1, meter/second) + v_w2.set_global_relative_scale_factor(1, meter/second) + v_w3.set_global_relative_scale_factor(1, meter/second) + + expr = v_w3 - Abs(v_w1 - v_w2) + + assert SI.get_dimensional_expr(v_w1) == (length/time).name + + Dq = Dimension(SI.get_dimensional_expr(expr)) + + assert SI.get_dimension_system().get_dimensional_dependencies(Dq) == { + length: 1, + time: -1, + } + assert meter == sqrt(meter**2) + + +def test_check_unit_consistency(): + u = Quantity("u") + v = Quantity("v") + w = Quantity("w") + + u.set_global_relative_scale_factor(S(10), meter) + v.set_global_relative_scale_factor(S(5), meter) + w.set_global_relative_scale_factor(S(2), second) + + def check_unit_consistency(expr): + SI._collect_factor_and_dimension(expr) + + raises(ValueError, lambda: check_unit_consistency(u + w)) + raises(ValueError, lambda: check_unit_consistency(u - w)) + raises(ValueError, lambda: check_unit_consistency(u + 1)) + raises(ValueError, lambda: check_unit_consistency(u - 1)) + raises(ValueError, lambda: check_unit_consistency(1 - exp(u / w))) + + +def test_mul_div(): + u = Quantity("u") + v = Quantity("v") + t = Quantity("t") + ut = Quantity("ut") + v2 = Quantity("v") + + u.set_global_relative_scale_factor(S(10), meter) + v.set_global_relative_scale_factor(S(5), meter) + t.set_global_relative_scale_factor(S(2), second) + ut.set_global_relative_scale_factor(S(20), meter*second) + v2.set_global_relative_scale_factor(S(5), meter/second) + + assert 1 / u == u**(-1) + assert u / 1 == u + + v1 = u / t + v2 = v + + # Pow only supports structural equality: + assert v1 != v2 + assert v1 == v2.convert_to(v1) + + # TODO: decide whether to allow such expression in the future + # (requires somehow manipulating the core). + # assert u / Quantity('l2', dimension=length, scale_factor=2) == 5 + + assert u * 1 == u + + ut1 = u * t + ut2 = ut + + # Mul only supports structural equality: + assert ut1 != ut2 + assert ut1 == ut2.convert_to(ut1) + + # Mul only supports structural equality: + lp1 = Quantity("lp1") + lp1.set_global_relative_scale_factor(S(2), 1/meter) + assert u * lp1 != 20 + + assert u**0 == 1 + assert u**1 == u + + # TODO: Pow only support structural equality: + u2 = Quantity("u2") + u3 = Quantity("u3") + u2.set_global_relative_scale_factor(S(100), meter**2) + u3.set_global_relative_scale_factor(Rational(1, 10), 1/meter) + + assert u ** 2 != u2 + assert u ** -1 != u3 + + assert u ** 2 == u2.convert_to(u) + assert u ** -1 == u3.convert_to(u) + + +def test_units(): + assert convert_to((5*m/s * day) / km, 1) == 432 + assert convert_to(foot / meter, meter) == Rational(3048, 10000) + # amu is a pure mass so mass/mass gives a number, not an amount (mol) + # TODO: need better simplification routine: + assert str(convert_to(grams/amu, grams).n(2)) == '6.0e+23' + + # Light from the sun needs about 8.3 minutes to reach earth + t = (1*au / speed_of_light) / minute + # TODO: need a better way to simplify expressions containing units: + t = convert_to(convert_to(t, meter / minute), meter) + assert t.simplify() == Rational(49865956897, 5995849160) + + # TODO: fix this, it should give `m` without `Abs` + assert sqrt(m**2) == m + assert (sqrt(m))**2 == m + + t = Symbol('t') + assert integrate(t*m/s, (t, 1*s, 5*s)) == 12*m*s + assert (t * m/s).integrate((t, 1*s, 5*s)) == 12*m*s + + +def test_issue_quart(): + assert convert_to(4 * quart / inch ** 3, meter) == 231 + assert convert_to(4 * quart / inch ** 3, millimeter) == 231 + +def test_electron_rest_mass(): + assert convert_to(electron_rest_mass, kilogram) == 9.1093837015e-31*kilogram + assert convert_to(electron_rest_mass, grams) == 9.1093837015e-28*grams + +def test_issue_5565(): + assert (m < s).is_Relational + + +def test_find_unit(): + assert find_unit('coulomb') == ['coulomb', 'coulombs', 'coulomb_constant'] + assert find_unit(coulomb) == ['C', 'coulomb', 'coulombs', 'planck_charge', 'elementary_charge'] + assert find_unit(charge) == ['C', 'coulomb', 'coulombs', 'planck_charge', 'elementary_charge'] + assert find_unit(inch) == [ + 'm', 'au', 'cm', 'dm', 'ft', 'km', 'ly', 'mi', 'mm', 'nm', 'pm', 'um', 'yd', + 'nmi', 'feet', 'foot', 'inch', 'mile', 'yard', 'meter', 'miles', 'yards', + 'inches', 'meters', 'micron', 'microns', 'angstrom', 'angstroms', 'decimeter', + 'kilometer', 'lightyear', 'nanometer', 'picometer', 'centimeter', 'decimeters', + 'kilometers', 'lightyears', 'micrometer', 'millimeter', 'nanometers', 'picometers', + 'centimeters', 'micrometers', 'millimeters', 'nautical_mile', 'planck_length', + 'nautical_miles', 'astronomical_unit', 'astronomical_units'] + assert find_unit(inch**-1) == ['D', 'dioptre', 'optical_power'] + assert find_unit(length**-1) == ['D', 'dioptre', 'optical_power'] + assert find_unit(inch ** 2) == ['ha', 'hectare', 'planck_area'] + assert find_unit(inch ** 3) == [ + 'L', 'l', 'cL', 'cl', 'dL', 'dl', 'mL', 'ml', 'liter', 'quart', 'liters', 'quarts', + 'deciliter', 'centiliter', 'deciliters', 'milliliter', + 'centiliters', 'milliliters', 'planck_volume'] + assert find_unit('voltage') == ['V', 'v', 'volt', 'volts', 'planck_voltage'] + assert find_unit(grams) == ['g', 't', 'Da', 'kg', 'me', 'mg', 'ug', 'amu', 'mmu', 'amus', + 'gram', 'mmus', 'grams', 'pound', 'tonne', 'dalton', 'pounds', + 'kilogram', 'kilograms', 'microgram', 'milligram', 'metric_ton', + 'micrograms', 'milligrams', 'planck_mass', 'milli_mass_unit', 'atomic_mass_unit', + 'electron_rest_mass', 'atomic_mass_constant'] + + +def test_Quantity_derivative(): + x = symbols("x") + assert diff(x*meter, x) == meter + assert diff(x**3*meter**2, x) == 3*x**2*meter**2 + assert diff(meter, meter) == 1 + assert diff(meter**2, meter) == 2*meter + + +def test_quantity_postprocessing(): + q1 = Quantity('q1') + q2 = Quantity('q2') + + SI.set_quantity_dimension(q1, length*pressure**2*temperature/time) + SI.set_quantity_dimension(q2, energy*pressure*temperature/(length**2*time)) + + assert q1 + q2 + q = q1 + q2 + Dq = Dimension(SI.get_dimensional_expr(q)) + assert SI.get_dimension_system().get_dimensional_dependencies(Dq) == { + length: -1, + mass: 2, + temperature: 1, + time: -5, + } + + +def test_factor_and_dimension(): + assert (3000, Dimension(1)) == SI._collect_factor_and_dimension(3000) + assert (1001, length) == SI._collect_factor_and_dimension(meter + km) + assert (2, length/time) == SI._collect_factor_and_dimension( + meter/second + 36*km/(10*hour)) + + x, y = symbols('x y') + assert (x + y/100, length) == SI._collect_factor_and_dimension( + x*m + y*centimeter) + + cH = Quantity('cH') + SI.set_quantity_dimension(cH, amount_of_substance/volume) + + pH = -log(cH) + + assert (1, volume/amount_of_substance) == SI._collect_factor_and_dimension( + exp(pH)) + + v_w1 = Quantity('v_w1') + v_w2 = Quantity('v_w2') + + v_w1.set_global_relative_scale_factor(Rational(3, 2), meter/second) + v_w2.set_global_relative_scale_factor(2, meter/second) + + expr = Abs(v_w1/2 - v_w2) + assert (Rational(5, 4), length/time) == \ + SI._collect_factor_and_dimension(expr) + + expr = Rational(5, 2)*second/meter*v_w1 - 3000 + assert (-(2996 + Rational(1, 4)), Dimension(1)) == \ + SI._collect_factor_and_dimension(expr) + + expr = v_w1**(v_w2/v_w1) + assert ((Rational(3, 2))**Rational(4, 3), (length/time)**Rational(4, 3)) == \ + SI._collect_factor_and_dimension(expr) + + +def test_dimensional_expr_of_derivative(): + l = Quantity('l') + t = Quantity('t') + t1 = Quantity('t1') + l.set_global_relative_scale_factor(36, km) + t.set_global_relative_scale_factor(1, hour) + t1.set_global_relative_scale_factor(1, second) + x = Symbol('x') + y = Symbol('y') + f = Function('f') + dfdx = f(x, y).diff(x, y) + dl_dt = dfdx.subs({f(x, y): l, x: t, y: t1}) + assert SI.get_dimensional_expr(dl_dt) ==\ + SI.get_dimensional_expr(l / t / t1) ==\ + Symbol("length")/Symbol("time")**2 + assert SI._collect_factor_and_dimension(dl_dt) ==\ + SI._collect_factor_and_dimension(l / t / t1) ==\ + (10, length/time**2) + + +def test_get_dimensional_expr_with_function(): + v_w1 = Quantity('v_w1') + v_w2 = Quantity('v_w2') + v_w1.set_global_relative_scale_factor(1, meter/second) + v_w2.set_global_relative_scale_factor(1, meter/second) + + assert SI.get_dimensional_expr(sin(v_w1)) == \ + sin(SI.get_dimensional_expr(v_w1)) + assert SI.get_dimensional_expr(sin(v_w1/v_w2)) == 1 + + +def test_binary_information(): + assert convert_to(kibibyte, byte) == 1024*byte + assert convert_to(mebibyte, byte) == 1024**2*byte + assert convert_to(gibibyte, byte) == 1024**3*byte + assert convert_to(tebibyte, byte) == 1024**4*byte + assert convert_to(pebibyte, byte) == 1024**5*byte + assert convert_to(exbibyte, byte) == 1024**6*byte + + assert kibibyte.convert_to(bit) == 8*1024*bit + assert byte.convert_to(bit) == 8*bit + + a = 10*kibibyte*hour + + assert convert_to(a, byte) == 10240*byte*hour + assert convert_to(a, minute) == 600*kibibyte*minute + assert convert_to(a, [byte, minute]) == 614400*byte*minute + + +def test_conversion_with_2_nonstandard_dimensions(): + good_grade = Quantity("good_grade") + kilo_good_grade = Quantity("kilo_good_grade") + centi_good_grade = Quantity("centi_good_grade") + + kilo_good_grade.set_global_relative_scale_factor(1000, good_grade) + centi_good_grade.set_global_relative_scale_factor(S.One/10**5, kilo_good_grade) + + charity_points = Quantity("charity_points") + milli_charity_points = Quantity("milli_charity_points") + missions = Quantity("missions") + + milli_charity_points.set_global_relative_scale_factor(S.One/1000, charity_points) + missions.set_global_relative_scale_factor(251, charity_points) + + assert convert_to( + kilo_good_grade*milli_charity_points*millimeter, + [centi_good_grade, missions, centimeter] + ) == S.One * 10**5 / (251*1000) / 10 * centi_good_grade*missions*centimeter + + +def test_eval_subs(): + energy, mass, force = symbols('energy mass force') + expr1 = energy/mass + units = {energy: kilogram*meter**2/second**2, mass: kilogram} + assert expr1.subs(units) == meter**2/second**2 + expr2 = force/mass + units = {force:gravitational_constant*kilogram**2/meter**2, mass:kilogram} + assert expr2.subs(units) == gravitational_constant*kilogram/meter**2 + + +def test_issue_14932(): + assert (log(inch) - log(2)).simplify() == log(inch/2) + assert (log(inch) - log(foot)).simplify() == -log(12) + p = symbols('p', positive=True) + assert (log(inch) - log(p)).simplify() == log(inch/p) + + +def test_issue_14547(): + # the root issue is that an argument with dimensions should + # not raise an error when the `arg - 1` calculation is + # performed in the assumptions system + from sympy.physics.units import foot, inch + from sympy.core.relational import Eq + assert log(foot).is_zero is None + assert log(foot).is_positive is None + assert log(foot).is_nonnegative is None + assert log(foot).is_negative is None + assert log(foot).is_algebraic is None + assert log(foot).is_rational is None + # doesn't raise error + assert Eq(log(foot), log(inch)) is not None # might be False or unevaluated + + x = Symbol('x') + e = foot + x + assert e.is_Add and set(e.args) == {foot, x} + e = foot + 1 + assert e.is_Add and set(e.args) == {foot, 1} + + +def test_issue_22164(): + warnings.simplefilter("error") + dm = Quantity("dm") + SI.set_quantity_dimension(dm, length) + SI.set_quantity_scale_factor(dm, 1) + + bad_exp = Quantity("bad_exp") + SI.set_quantity_dimension(bad_exp, length) + SI.set_quantity_scale_factor(bad_exp, 1) + + expr = dm ** bad_exp + + # deprecation warning is not expected here + SI._collect_factor_and_dimension(expr) + + +def test_issue_22819(): + from sympy.physics.units import tonne, gram, Da + from sympy.physics.units.systems.si import dimsys_SI + assert tonne.convert_to(gram) == 1000000*gram + assert dimsys_SI.get_dimensional_dependencies(area) == {length: 2} + assert Da.scale_factor == 1.66053906660000e-24 + + +def test_issue_20288(): + from sympy.core.numbers import E + from sympy.physics.units import energy + u = Quantity('u') + v = Quantity('v') + SI.set_quantity_dimension(u, energy) + SI.set_quantity_dimension(v, energy) + u.set_global_relative_scale_factor(1, joule) + v.set_global_relative_scale_factor(1, joule) + expr = 1 + exp(u**2/v**2) + assert SI._collect_factor_and_dimension(expr) == (1 + E, Dimension(1)) + + +def test_issue_24062(): + from sympy.core.numbers import E + from sympy.physics.units import impedance, capacitance, time, ohm, farad, second + + R = Quantity('R') + C = Quantity('C') + T = Quantity('T') + SI.set_quantity_dimension(R, impedance) + SI.set_quantity_dimension(C, capacitance) + SI.set_quantity_dimension(T, time) + R.set_global_relative_scale_factor(1, ohm) + C.set_global_relative_scale_factor(1, farad) + T.set_global_relative_scale_factor(1, second) + expr = T / (R * C) + dim = SI._collect_factor_and_dimension(expr)[1] + assert SI.get_dimension_system().is_dimensionless(dim) + + exp_expr = 1 + exp(expr) + assert SI._collect_factor_and_dimension(exp_expr) == (1 + E, Dimension(1)) + +def test_issue_24211(): + from sympy.physics.units import time, velocity, acceleration, second, meter + V1 = Quantity('V1') + SI.set_quantity_dimension(V1, velocity) + SI.set_quantity_scale_factor(V1, 1 * meter / second) + A1 = Quantity('A1') + SI.set_quantity_dimension(A1, acceleration) + SI.set_quantity_scale_factor(A1, 1 * meter / second**2) + T1 = Quantity('T1') + SI.set_quantity_dimension(T1, time) + SI.set_quantity_scale_factor(T1, 1 * second) + + expr = A1*T1 + V1 + # should not throw ValueError here + SI._collect_factor_and_dimension(expr) + + +def test_prefixed_property(): + assert not meter.is_prefixed + assert not joule.is_prefixed + assert not day.is_prefixed + assert not second.is_prefixed + assert not volt.is_prefixed + assert not ohm.is_prefixed + assert centimeter.is_prefixed + assert kilometer.is_prefixed + assert kilogram.is_prefixed + assert pebibyte.is_prefixed + +def test_physics_constant(): + from sympy.physics.units import definitions + + for name in dir(definitions): + quantity = getattr(definitions, name) + if not isinstance(quantity, Quantity): + continue + if name.endswith('_constant'): + assert isinstance(quantity, PhysicalConstant), f"{quantity} must be PhysicalConstant, but is {type(quantity)}" + assert quantity.is_physical_constant, f"{name} is not marked as physics constant when it should be" + + for const in [gravitational_constant, molar_gas_constant, vacuum_permittivity, speed_of_light, elementary_charge]: + assert isinstance(const, PhysicalConstant), f"{const} must be PhysicalConstant, but is {type(const)}" + assert const.is_physical_constant, f"{const} is not marked as physics constant when it should be" + + assert not meter.is_physical_constant + assert not joule.is_physical_constant diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/test_unit_system_cgs_gauss.py b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_unit_system_cgs_gauss.py new file mode 100644 index 0000000000000000000000000000000000000000..12629280785c94fa8be33bc97bdd714140a3e346 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_unit_system_cgs_gauss.py @@ -0,0 +1,55 @@ +from sympy.concrete.tests.test_sums_products import NS + +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.units import convert_to, coulomb_constant, elementary_charge, gravitational_constant, planck +from sympy.physics.units.definitions.unit_definitions import angstrom, statcoulomb, coulomb, second, gram, centimeter, erg, \ + newton, joule, dyne, speed_of_light, meter, farad, henry, statvolt, volt, ohm +from sympy.physics.units.systems import SI +from sympy.physics.units.systems.cgs import cgs_gauss + + +def test_conversion_to_from_si(): + assert convert_to(statcoulomb, coulomb, cgs_gauss) == coulomb/2997924580 + assert convert_to(coulomb, statcoulomb, cgs_gauss) == 2997924580*statcoulomb + assert convert_to(statcoulomb, sqrt(gram*centimeter**3)/second, cgs_gauss) == centimeter**(S(3)/2)*sqrt(gram)/second + assert convert_to(coulomb, sqrt(gram*centimeter**3)/second, cgs_gauss) == 2997924580*centimeter**(S(3)/2)*sqrt(gram)/second + + # SI units have an additional base unit, no conversion in case of electromagnetism: + assert convert_to(coulomb, statcoulomb, SI) == coulomb + assert convert_to(statcoulomb, coulomb, SI) == statcoulomb + + # SI without electromagnetism: + assert convert_to(erg, joule, SI) == joule/10**7 + assert convert_to(erg, joule, cgs_gauss) == joule/10**7 + assert convert_to(joule, erg, SI) == 10**7*erg + assert convert_to(joule, erg, cgs_gauss) == 10**7*erg + + + assert convert_to(dyne, newton, SI) == newton/10**5 + assert convert_to(dyne, newton, cgs_gauss) == newton/10**5 + assert convert_to(newton, dyne, SI) == 10**5*dyne + assert convert_to(newton, dyne, cgs_gauss) == 10**5*dyne + + +def test_cgs_gauss_convert_constants(): + + assert convert_to(speed_of_light, centimeter/second, cgs_gauss) == 29979245800*centimeter/second + + assert convert_to(coulomb_constant, 1, cgs_gauss) == 1 + assert convert_to(coulomb_constant, newton*meter**2/coulomb**2, cgs_gauss) == 22468879468420441*meter**2*newton/(2500000*coulomb**2) + assert convert_to(coulomb_constant, newton*meter**2/coulomb**2, SI) == 22468879468420441*meter**2*newton/(2500000*coulomb**2) + assert convert_to(coulomb_constant, dyne*centimeter**2/statcoulomb**2, cgs_gauss) == centimeter**2*dyne/statcoulomb**2 + assert convert_to(coulomb_constant, 1, SI) == coulomb_constant + assert NS(convert_to(coulomb_constant, newton*meter**2/coulomb**2, SI)) == '8987551787.36818*meter**2*newton/coulomb**2' + + assert convert_to(elementary_charge, statcoulomb, cgs_gauss) + assert convert_to(angstrom, centimeter, cgs_gauss) == 1*centimeter/10**8 + assert convert_to(gravitational_constant, dyne*centimeter**2/gram**2, cgs_gauss) + assert NS(convert_to(planck, erg*second, cgs_gauss)) == '6.62607015e-27*erg*second' + + spc = 25000*second/(22468879468420441*centimeter) + assert convert_to(ohm, second/centimeter, cgs_gauss) == spc + assert convert_to(henry, second**2/centimeter, cgs_gauss) == spc*second + assert convert_to(volt, statvolt, cgs_gauss) == 10**6*statvolt/299792458 + assert convert_to(farad, centimeter, cgs_gauss) == 299792458**2*centimeter/10**5 diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/test_unitsystem.py b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_unitsystem.py new file mode 100644 index 0000000000000000000000000000000000000000..a04f3aabb6274bed4f1b82ac0719fa618b55eed7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_unitsystem.py @@ -0,0 +1,86 @@ +from sympy.physics.units import DimensionSystem, joule, second, ampere + +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.physics.units.definitions import c, kg, m, s +from sympy.physics.units.definitions.dimension_definitions import length, time +from sympy.physics.units.quantities import Quantity +from sympy.physics.units.unitsystem import UnitSystem +from sympy.physics.units.util import convert_to + + +def test_definition(): + # want to test if the system can have several units of the same dimension + dm = Quantity("dm") + base = (m, s) + # base_dim = (m.dimension, s.dimension) + ms = UnitSystem(base, (c, dm), "MS", "MS system") + ms.set_quantity_dimension(dm, length) + ms.set_quantity_scale_factor(dm, Rational(1, 10)) + + assert set(ms._base_units) == set(base) + assert set(ms._units) == {m, s, c, dm} + # assert ms._units == DimensionSystem._sort_dims(base + (velocity,)) + assert ms.name == "MS" + assert ms.descr == "MS system" + + +def test_str_repr(): + assert str(UnitSystem((m, s), name="MS")) == "MS" + assert str(UnitSystem((m, s))) == "UnitSystem((meter, second))" + + assert repr(UnitSystem((m, s))) == "" % (m, s) + + +def test_convert_to(): + A = Quantity("A") + A.set_global_relative_scale_factor(S.One, ampere) + + Js = Quantity("Js") + Js.set_global_relative_scale_factor(S.One, joule*second) + + mksa = UnitSystem((m, kg, s, A), (Js,)) + assert convert_to(Js, mksa._base_units) == m**2*kg*s**-1/1000 + + +def test_extend(): + ms = UnitSystem((m, s), (c,)) + Js = Quantity("Js") + Js.set_global_relative_scale_factor(1, joule*second) + mks = ms.extend((kg,), (Js,)) + + res = UnitSystem((m, s, kg), (c, Js)) + assert set(mks._base_units) == set(res._base_units) + assert set(mks._units) == set(res._units) + + +def test_dim(): + dimsys = UnitSystem((m, kg, s), (c,)) + assert dimsys.dim == 3 + + +def test_is_consistent(): + dimension_system = DimensionSystem([length, time]) + us = UnitSystem([m, s], dimension_system=dimension_system) + assert us.is_consistent == True + + +def test_get_units_non_prefixed(): + from sympy.physics.units import volt, ohm + unit_system = UnitSystem.get_unit_system("SI") + units = unit_system.get_units_non_prefixed() + for prefix in ["giga", "tera", "peta", "exa", "zetta", "yotta", "kilo", "hecto", "deca", "deci", "centi", "milli", "micro", "nano", "pico", "femto", "atto", "zepto", "yocto"]: + for unit in units: + assert isinstance(unit, Quantity), f"{unit} must be a Quantity, not {type(unit)}" + assert not unit.is_prefixed, f"{unit} is marked as prefixed" + assert not unit.is_physical_constant, f"{unit} is marked as physics constant" + assert not unit.name.name.startswith(prefix), f"Unit {unit.name} has prefix {prefix}" + assert volt in units + assert ohm in units + +def test_derived_units_must_exist_in_unit_system(): + for unit_system in UnitSystem._unit_systems.values(): + for preferred_unit in unit_system.derived_units.values(): + units = preferred_unit.atoms(Quantity) + for unit in units: + assert unit in unit_system._units, f"Unit {unit} is not in unit system {unit_system}" diff --git a/phivenv/Lib/site-packages/sympy/physics/units/tests/test_util.py b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..3522af675d33275f322e2b731309e19bffde1e1d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/tests/test_util.py @@ -0,0 +1,178 @@ +from sympy.core.containers import Tuple +from sympy.core.numbers import pi +from sympy.core.power import Pow +from sympy.core.symbol import symbols +from sympy.core.sympify import sympify +from sympy.printing.str import sstr +from sympy.physics.units import ( + G, centimeter, coulomb, day, degree, gram, hbar, hour, inch, joule, kelvin, + kilogram, kilometer, length, meter, mile, minute, newton, planck, + planck_length, planck_mass, planck_temperature, planck_time, radians, + second, speed_of_light, steradian, time, km) +from sympy.physics.units.util import convert_to, check_dimensions +from sympy.testing.pytest import raises +from sympy.functions.elementary.miscellaneous import sqrt + + +def NS(e, n=15, **options): + return sstr(sympify(e).evalf(n, **options), full_prec=True) + + +L = length +T = time + + +def test_dim_simplify_add(): + # assert Add(L, L) == L + assert L + L == L + + +def test_dim_simplify_mul(): + # assert Mul(L, T) == L*T + assert L*T == L*T + + +def test_dim_simplify_pow(): + assert Pow(L, 2) == L**2 + + +def test_dim_simplify_rec(): + # assert Mul(Add(L, L), T) == L*T + assert (L + L) * T == L*T + + +def test_convert_to_quantities(): + assert convert_to(3, meter) == 3 + + assert convert_to(mile, kilometer) == 25146*kilometer/15625 + assert convert_to(meter/second, speed_of_light) == speed_of_light/299792458 + assert convert_to(299792458*meter/second, speed_of_light) == speed_of_light + assert convert_to(2*299792458*meter/second, speed_of_light) == 2*speed_of_light + assert convert_to(speed_of_light, meter/second) == 299792458*meter/second + assert convert_to(2*speed_of_light, meter/second) == 599584916*meter/second + assert convert_to(day, second) == 86400*second + assert convert_to(2*hour, minute) == 120*minute + assert convert_to(mile, meter) == 201168*meter/125 + assert convert_to(mile/hour, kilometer/hour) == 25146*kilometer/(15625*hour) + assert convert_to(3*newton, meter/second) == 3*newton + assert convert_to(3*newton, kilogram*meter/second**2) == 3*meter*kilogram/second**2 + assert convert_to(kilometer + mile, meter) == 326168*meter/125 + assert convert_to(2*kilometer + 3*mile, meter) == 853504*meter/125 + assert convert_to(inch**2, meter**2) == 16129*meter**2/25000000 + assert convert_to(3*inch**2, meter) == 48387*meter**2/25000000 + assert convert_to(2*kilometer/hour + 3*mile/hour, meter/second) == 53344*meter/(28125*second) + assert convert_to(2*kilometer/hour + 3*mile/hour, centimeter/second) == 213376*centimeter/(1125*second) + assert convert_to(kilometer * (mile + kilometer), meter) == 2609344 * meter ** 2 + + assert convert_to(steradian, coulomb) == steradian + assert convert_to(radians, degree) == 180*degree/pi + assert convert_to(radians, [meter, degree]) == 180*degree/pi + assert convert_to(pi*radians, degree) == 180*degree + assert convert_to(pi, degree) == 180*degree + + # https://github.com/sympy/sympy/issues/26263 + assert convert_to(sqrt(meter**2 + meter**2.0), meter) == sqrt(meter**2 + meter**2.0) + assert convert_to((meter**2 + meter**2.0)**2, meter) == (meter**2 + meter**2.0)**2 + + +def test_convert_to_tuples_of_quantities(): + from sympy.core.symbol import symbols + + alpha, beta = symbols('alpha beta') + + assert convert_to(speed_of_light, [meter, second]) == 299792458 * meter / second + assert convert_to(speed_of_light, (meter, second)) == 299792458 * meter / second + assert convert_to(speed_of_light, Tuple(meter, second)) == 299792458 * meter / second + assert convert_to(joule, [meter, kilogram, second]) == kilogram*meter**2/second**2 + assert convert_to(joule, [centimeter, gram, second]) == 10000000*centimeter**2*gram/second**2 + assert convert_to(299792458*meter/second, [speed_of_light]) == speed_of_light + assert convert_to(speed_of_light / 2, [meter, second, kilogram]) == meter/second*299792458 / 2 + # This doesn't make physically sense, but let's keep it as a conversion test: + assert convert_to(2 * speed_of_light, [meter, second, kilogram]) == 2 * 299792458 * meter / second + assert convert_to(G, [G, speed_of_light, planck]) == 1.0*G + + assert NS(convert_to(meter, [G, speed_of_light, hbar]), n=7) == '6.187142e+34*gravitational_constant**0.5000000*hbar**0.5000000/speed_of_light**1.500000' + assert NS(convert_to(planck_mass, kilogram), n=7) == '2.176434e-8*kilogram' + assert NS(convert_to(planck_length, meter), n=7) == '1.616255e-35*meter' + assert NS(convert_to(planck_time, second), n=6) == '5.39125e-44*second' + assert NS(convert_to(planck_temperature, kelvin), n=7) == '1.416784e+32*kelvin' + assert NS(convert_to(convert_to(meter, [G, speed_of_light, planck]), meter), n=10) == '1.000000000*meter' + + # similar to https://github.com/sympy/sympy/issues/26263 + assert convert_to(sqrt(meter**2 + second**2.0), [meter, second]) == sqrt(meter**2 + second**2.0) + assert convert_to((meter**2 + second**2.0)**2, [meter, second]) == (meter**2 + second**2.0)**2 + + # similar to https://github.com/sympy/sympy/issues/21463 + assert convert_to(1/(beta*meter + meter), 1/meter) == 1/(beta*meter + meter) + assert convert_to(1/(beta*meter + alpha*meter), 1/kilometer) == (1/(kilometer*beta/1000 + alpha*kilometer/1000)) + +def test_eval_simplify(): + from sympy.physics.units import cm, mm, km, m, K, kilo + from sympy.core.symbol import symbols + + x, y = symbols('x y') + + assert (cm/mm).simplify() == 10 + assert (km/m).simplify() == 1000 + assert (km/cm).simplify() == 100000 + assert (10*x*K*km**2/m/cm).simplify() == 1000000000*x*kelvin + assert (cm/km/m).simplify() == 1/(10000000*centimeter) + + assert (3*kilo*meter).simplify() == 3000*meter + assert (4*kilo*meter/(2*kilometer)).simplify() == 2 + assert (4*kilometer**2/(kilo*meter)**2).simplify() == 4 + + +def test_quantity_simplify(): + from sympy.physics.units.util import quantity_simplify + from sympy.physics.units import kilo, foot + from sympy.core.symbol import symbols + + x, y = symbols('x y') + + assert quantity_simplify(x*(8*kilo*newton*meter + y)) == x*(8000*meter*newton + y) + assert quantity_simplify(foot*inch*(foot + inch)) == foot**2*(foot + foot/12)/12 + assert quantity_simplify(foot*inch*(foot*foot + inch*(foot + inch))) == foot**2*(foot**2 + foot/12*(foot + foot/12))/12 + assert quantity_simplify(2**(foot/inch*kilo/1000)*inch) == 4096*foot/12 + assert quantity_simplify(foot**2*inch + inch**2*foot) == 13*foot**3/144 + +def test_quantity_simplify_across_dimensions(): + from sympy.physics.units.util import quantity_simplify + from sympy.physics.units import ampere, ohm, volt, joule, pascal, farad, second, watt, siemens, henry, tesla, weber, hour, newton + + assert quantity_simplify(ampere*ohm, across_dimensions=True, unit_system="SI") == volt + assert quantity_simplify(6*ampere*ohm, across_dimensions=True, unit_system="SI") == 6*volt + assert quantity_simplify(volt/ampere, across_dimensions=True, unit_system="SI") == ohm + assert quantity_simplify(volt/ohm, across_dimensions=True, unit_system="SI") == ampere + assert quantity_simplify(joule/meter**3, across_dimensions=True, unit_system="SI") == pascal + assert quantity_simplify(farad*ohm, across_dimensions=True, unit_system="SI") == second + assert quantity_simplify(joule/second, across_dimensions=True, unit_system="SI") == watt + assert quantity_simplify(meter**3/second, across_dimensions=True, unit_system="SI") == meter**3/second + assert quantity_simplify(joule/second, across_dimensions=True, unit_system="SI") == watt + + assert quantity_simplify(joule/coulomb, across_dimensions=True, unit_system="SI") == volt + assert quantity_simplify(volt/ampere, across_dimensions=True, unit_system="SI") == ohm + assert quantity_simplify(ampere/volt, across_dimensions=True, unit_system="SI") == siemens + assert quantity_simplify(coulomb/volt, across_dimensions=True, unit_system="SI") == farad + assert quantity_simplify(volt*second/ampere, across_dimensions=True, unit_system="SI") == henry + assert quantity_simplify(volt*second/meter**2, across_dimensions=True, unit_system="SI") == tesla + assert quantity_simplify(joule/ampere, across_dimensions=True, unit_system="SI") == weber + + assert quantity_simplify(5*kilometer/hour, across_dimensions=True, unit_system="SI") == 25*meter/(18*second) + assert quantity_simplify(5*kilogram*meter/second**2, across_dimensions=True, unit_system="SI") == 5*newton + +def test_check_dimensions(): + x = symbols('x') + assert check_dimensions(inch + x) == inch + x + assert check_dimensions(length + x) == length + x + # after subs we get 2*length; check will clear the constant + assert check_dimensions((length + x).subs(x, length)) == length + assert check_dimensions(newton*meter + joule) == joule + meter*newton + raises(ValueError, lambda: check_dimensions(inch + 1)) + raises(ValueError, lambda: check_dimensions(length + 1)) + raises(ValueError, lambda: check_dimensions(length + time)) + raises(ValueError, lambda: check_dimensions(meter + second)) + raises(ValueError, lambda: check_dimensions(2 * meter + second)) + raises(ValueError, lambda: check_dimensions(2 * meter + 3 * second)) + raises(ValueError, lambda: check_dimensions(1 / second + 1 / meter)) + raises(ValueError, lambda: check_dimensions(2 * meter*(mile + centimeter) + km)) diff --git a/phivenv/Lib/site-packages/sympy/physics/units/unitsystem.py b/phivenv/Lib/site-packages/sympy/physics/units/unitsystem.py new file mode 100644 index 0000000000000000000000000000000000000000..795f8026e9df7236fdb2abf882043a843797219d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/unitsystem.py @@ -0,0 +1,204 @@ +""" +Unit system for physical quantities; include definition of constants. +""" +from __future__ import annotations + +from sympy.core.add import Add +from sympy.core.function import (Derivative, Function) +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.physics.units.dimensions import _QuantityMapper +from sympy.physics.units.quantities import Quantity + +from .dimensions import Dimension + + +class UnitSystem(_QuantityMapper): + """ + UnitSystem represents a coherent set of units. + + A unit system is basically a dimension system with notions of scales. Many + of the methods are defined in the same way. + + It is much better if all base units have a symbol. + """ + + _unit_systems: dict[str, UnitSystem] = {} + + def __init__(self, base_units, units=(), name="", descr="", dimension_system=None, derived_units: dict[Dimension, Quantity]={}): + + UnitSystem._unit_systems[name] = self + + self.name = name + self.descr = descr + + self._base_units = base_units + self._dimension_system = dimension_system + self._units = tuple(set(base_units) | set(units)) + self._base_units = tuple(base_units) + self._derived_units = derived_units + + super().__init__() + + def __str__(self): + """ + Return the name of the system. + + If it does not exist, then it makes a list of symbols (or names) of + the base dimensions. + """ + + if self.name != "": + return self.name + else: + return "UnitSystem((%s))" % ", ".join( + str(d) for d in self._base_units) + + def __repr__(self): + return '' % repr(self._base_units) + + def extend(self, base, units=(), name="", description="", dimension_system=None, derived_units: dict[Dimension, Quantity]={}): + """Extend the current system into a new one. + + Take the base and normal units of the current system to merge + them to the base and normal units given in argument. + If not provided, name and description are overridden by empty strings. + """ + + base = self._base_units + tuple(base) + units = self._units + tuple(units) + + return UnitSystem(base, units, name, description, dimension_system, {**self._derived_units, **derived_units}) + + def get_dimension_system(self): + return self._dimension_system + + def get_quantity_dimension(self, unit): + qdm = self.get_dimension_system()._quantity_dimension_map + if unit in qdm: + return qdm[unit] + return super().get_quantity_dimension(unit) + + def get_quantity_scale_factor(self, unit): + qsfm = self.get_dimension_system()._quantity_scale_factors + if unit in qsfm: + return qsfm[unit] + return super().get_quantity_scale_factor(unit) + + @staticmethod + def get_unit_system(unit_system): + if isinstance(unit_system, UnitSystem): + return unit_system + + if unit_system not in UnitSystem._unit_systems: + raise ValueError( + "Unit system is not supported. Currently" + "supported unit systems are {}".format( + ", ".join(sorted(UnitSystem._unit_systems)) + ) + ) + + return UnitSystem._unit_systems[unit_system] + + @staticmethod + def get_default_unit_system(): + return UnitSystem._unit_systems["SI"] + + @property + def dim(self): + """ + Give the dimension of the system. + + That is return the number of units forming the basis. + """ + return len(self._base_units) + + @property + def is_consistent(self): + """ + Check if the underlying dimension system is consistent. + """ + # test is performed in DimensionSystem + return self.get_dimension_system().is_consistent + + @property + def derived_units(self) -> dict[Dimension, Quantity]: + return self._derived_units + + def get_dimensional_expr(self, expr): + from sympy.physics.units import Quantity + if isinstance(expr, Mul): + return Mul(*[self.get_dimensional_expr(i) for i in expr.args]) + elif isinstance(expr, Pow): + return self.get_dimensional_expr(expr.base) ** expr.exp + elif isinstance(expr, Add): + return self.get_dimensional_expr(expr.args[0]) + elif isinstance(expr, Derivative): + dim = self.get_dimensional_expr(expr.expr) + for independent, count in expr.variable_count: + dim /= self.get_dimensional_expr(independent)**count + return dim + elif isinstance(expr, Function): + args = [self.get_dimensional_expr(arg) for arg in expr.args] + if all(i == 1 for i in args): + return S.One + return expr.func(*args) + elif isinstance(expr, Quantity): + return self.get_quantity_dimension(expr).name + return S.One + + def _collect_factor_and_dimension(self, expr): + """ + Return tuple with scale factor expression and dimension expression. + """ + from sympy.physics.units import Quantity + if isinstance(expr, Quantity): + return expr.scale_factor, expr.dimension + elif isinstance(expr, Mul): + factor = 1 + dimension = Dimension(1) + for arg in expr.args: + arg_factor, arg_dim = self._collect_factor_and_dimension(arg) + factor *= arg_factor + dimension *= arg_dim + return factor, dimension + elif isinstance(expr, Pow): + factor, dim = self._collect_factor_and_dimension(expr.base) + exp_factor, exp_dim = self._collect_factor_and_dimension(expr.exp) + if self.get_dimension_system().is_dimensionless(exp_dim): + exp_dim = 1 + return factor ** exp_factor, dim ** (exp_factor * exp_dim) + elif isinstance(expr, Add): + factor, dim = self._collect_factor_and_dimension(expr.args[0]) + for addend in expr.args[1:]: + addend_factor, addend_dim = \ + self._collect_factor_and_dimension(addend) + if not self.get_dimension_system().equivalent_dims(dim, addend_dim): + raise ValueError( + 'Dimension of "{}" is {}, ' + 'but it should be {}'.format( + addend, addend_dim, dim)) + factor += addend_factor + return factor, dim + elif isinstance(expr, Derivative): + factor, dim = self._collect_factor_and_dimension(expr.args[0]) + for independent, count in expr.variable_count: + ifactor, idim = self._collect_factor_and_dimension(independent) + factor /= ifactor**count + dim /= idim**count + return factor, dim + elif isinstance(expr, Function): + fds = [self._collect_factor_and_dimension(arg) for arg in expr.args] + dims = [Dimension(1) if self.get_dimension_system().is_dimensionless(d[1]) else d[1] for d in fds] + return (expr.func(*(f[0] for f in fds)), *dims) + elif isinstance(expr, Dimension): + return S.One, expr + else: + return expr, Dimension(1) + + def get_units_non_prefixed(self) -> set[Quantity]: + """ + Return the units of the system that do not have a prefix. + """ + return set(filter(lambda u: not u.is_prefixed and not u.is_physical_constant, self._units)) diff --git a/phivenv/Lib/site-packages/sympy/physics/units/util.py b/phivenv/Lib/site-packages/sympy/physics/units/util.py new file mode 100644 index 0000000000000000000000000000000000000000..ccd6300acdb1a3c60b74076d4700e7f699ca46f5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/units/util.py @@ -0,0 +1,265 @@ +""" +Several methods to simplify expressions involving unit objects. +""" +from functools import reduce +from collections.abc import Iterable +from typing import Optional + +from sympy import default_sort_key +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.sorting import ordered +from sympy.core.sympify import sympify +from sympy.core.function import Function +from sympy.matrices.exceptions import NonInvertibleMatrixError +from sympy.physics.units.dimensions import Dimension, DimensionSystem +from sympy.physics.units.prefixes import Prefix +from sympy.physics.units.quantities import Quantity +from sympy.physics.units.unitsystem import UnitSystem +from sympy.utilities.iterables import sift + + +def _get_conversion_matrix_for_expr(expr, target_units, unit_system): + from sympy.matrices.dense import Matrix + + dimension_system = unit_system.get_dimension_system() + + expr_dim = Dimension(unit_system.get_dimensional_expr(expr)) + dim_dependencies = dimension_system.get_dimensional_dependencies(expr_dim, mark_dimensionless=True) + target_dims = [Dimension(unit_system.get_dimensional_expr(x)) for x in target_units] + canon_dim_units = [i for x in target_dims for i in dimension_system.get_dimensional_dependencies(x, mark_dimensionless=True)] + canon_expr_units = set(dim_dependencies) + + if not canon_expr_units.issubset(set(canon_dim_units)): + return None + + seen = set() + canon_dim_units = [i for i in canon_dim_units if not (i in seen or seen.add(i))] + + camat = Matrix([[dimension_system.get_dimensional_dependencies(i, mark_dimensionless=True).get(j, 0) for i in target_dims] for j in canon_dim_units]) + exprmat = Matrix([dim_dependencies.get(k, 0) for k in canon_dim_units]) + + try: + res_exponents = camat.solve(exprmat) + except NonInvertibleMatrixError: + return None + + return res_exponents + + +def convert_to(expr, target_units, unit_system="SI"): + """ + Convert ``expr`` to the same expression with all of its units and quantities + represented as factors of ``target_units``, whenever the dimension is compatible. + + ``target_units`` may be a single unit/quantity, or a collection of + units/quantities. + + Examples + ======== + + >>> from sympy.physics.units import speed_of_light, meter, gram, second, day + >>> from sympy.physics.units import mile, newton, kilogram, atomic_mass_constant + >>> from sympy.physics.units import kilometer, centimeter + >>> from sympy.physics.units import gravitational_constant, hbar + >>> from sympy.physics.units import convert_to + >>> convert_to(mile, kilometer) + 25146*kilometer/15625 + >>> convert_to(mile, kilometer).n() + 1.609344*kilometer + >>> convert_to(speed_of_light, meter/second) + 299792458*meter/second + >>> convert_to(day, second) + 86400*second + >>> 3*newton + 3*newton + >>> convert_to(3*newton, kilogram*meter/second**2) + 3*kilogram*meter/second**2 + >>> convert_to(atomic_mass_constant, gram) + 1.660539060e-24*gram + + Conversion to multiple units: + + >>> convert_to(speed_of_light, [meter, second]) + 299792458*meter/second + >>> convert_to(3*newton, [centimeter, gram, second]) + 300000*centimeter*gram/second**2 + + Conversion to Planck units: + + >>> convert_to(atomic_mass_constant, [gravitational_constant, speed_of_light, hbar]).n() + 7.62963087839509e-20*hbar**0.5*speed_of_light**0.5/gravitational_constant**0.5 + + """ + from sympy.physics.units import UnitSystem + unit_system = UnitSystem.get_unit_system(unit_system) + + if not isinstance(target_units, (Iterable, Tuple)): + target_units = [target_units] + + def handle_Adds(expr): + return Add.fromiter(convert_to(i, target_units, unit_system) + for i in expr.args) + + if isinstance(expr, Add): + return handle_Adds(expr) + elif isinstance(expr, Pow) and isinstance(expr.base, Add): + return handle_Adds(expr.base) ** expr.exp + + expr = sympify(expr) + target_units = sympify(target_units) + + if isinstance(expr, Function): + expr = expr.together() + + if not isinstance(expr, Quantity) and expr.has(Quantity): + expr = expr.replace(lambda x: isinstance(x, Quantity), + lambda x: x.convert_to(target_units, unit_system)) + + def get_total_scale_factor(expr): + if isinstance(expr, Mul): + return reduce(lambda x, y: x * y, + [get_total_scale_factor(i) for i in expr.args]) + elif isinstance(expr, Pow): + return get_total_scale_factor(expr.base) ** expr.exp + elif isinstance(expr, Quantity): + return unit_system.get_quantity_scale_factor(expr) + return expr + + depmat = _get_conversion_matrix_for_expr(expr, target_units, unit_system) + if depmat is None: + return expr + + expr_scale_factor = get_total_scale_factor(expr) + return expr_scale_factor * Mul.fromiter( + (1/get_total_scale_factor(u)*u)**p for u, p in + zip(target_units, depmat)) + + +def quantity_simplify(expr, across_dimensions: bool=False, unit_system=None): + """Return an equivalent expression in which prefixes are replaced + with numerical values and all units of a given dimension are the + unified in a canonical manner by default. `across_dimensions` allows + for units of different dimensions to be simplified together. + + `unit_system` must be specified if `across_dimensions` is True. + + Examples + ======== + + >>> from sympy.physics.units.util import quantity_simplify + >>> from sympy.physics.units.prefixes import kilo + >>> from sympy.physics.units import foot, inch, joule, coulomb + >>> quantity_simplify(kilo*foot*inch) + 250*foot**2/3 + >>> quantity_simplify(foot - 6*inch) + foot/2 + >>> quantity_simplify(5*joule/coulomb, across_dimensions=True, unit_system="SI") + 5*volt + """ + + if expr.is_Atom or not expr.has(Prefix, Quantity): + return expr + + # replace all prefixes with numerical values + p = expr.atoms(Prefix) + expr = expr.xreplace({p: p.scale_factor for p in p}) + + # replace all quantities of given dimension with a canonical + # quantity, chosen from those in the expression + d = sift(expr.atoms(Quantity), lambda i: i.dimension) + for k in d: + if len(d[k]) == 1: + continue + v = list(ordered(d[k])) + ref = v[0]/v[0].scale_factor + expr = expr.xreplace({vi: ref*vi.scale_factor for vi in v[1:]}) + + if across_dimensions: + # combine quantities of different dimensions into a single + # quantity that is equivalent to the original expression + + if unit_system is None: + raise ValueError("unit_system must be specified if across_dimensions is True") + + unit_system = UnitSystem.get_unit_system(unit_system) + dimension_system: DimensionSystem = unit_system.get_dimension_system() + dim_expr = unit_system.get_dimensional_expr(expr) + dim_deps = dimension_system.get_dimensional_dependencies(dim_expr, mark_dimensionless=True) + + target_dimension: Optional[Dimension] = None + for ds_dim, ds_dim_deps in dimension_system.dimensional_dependencies.items(): + if ds_dim_deps == dim_deps: + target_dimension = ds_dim + break + + if target_dimension is None: + # if we can't find a target dimension, we can't do anything. unsure how to handle this case. + return expr + + target_unit = unit_system.derived_units.get(target_dimension) + if target_unit: + expr = convert_to(expr, target_unit, unit_system) + + return expr + + +def check_dimensions(expr, unit_system="SI"): + """Return expr if units in addends have the same + base dimensions, else raise a ValueError.""" + # the case of adding a number to a dimensional quantity + # is ignored for the sake of SymPy core routines, so this + # function will raise an error now if such an addend is + # found. + # Also, when doing substitutions, multiplicative constants + # might be introduced, so remove those now + + from sympy.physics.units import UnitSystem + unit_system = UnitSystem.get_unit_system(unit_system) + + def addDict(dict1, dict2): + """Merge dictionaries by adding values of common keys and + removing keys with value of 0.""" + dict3 = {**dict1, **dict2} + for key, value in dict3.items(): + if key in dict1 and key in dict2: + dict3[key] = value + dict1[key] + return {key:val for key, val in dict3.items() if val != 0} + + adds = expr.atoms(Add) + DIM_OF = unit_system.get_dimension_system().get_dimensional_dependencies + for a in adds: + deset = set() + for ai in a.args: + if ai.is_number: + deset.add(()) + continue + dims = [] + skip = False + dimdict = {} + for i in Mul.make_args(ai): + if i.has(Quantity): + i = Dimension(unit_system.get_dimensional_expr(i)) + if i.has(Dimension): + dimdict = addDict(dimdict, DIM_OF(i)) + elif i.free_symbols: + skip = True + break + dims.extend(dimdict.items()) + if not skip: + deset.add(tuple(sorted(dims, key=default_sort_key))) + if len(deset) > 1: + raise ValueError( + "addends have incompatible dimensions: {}".format(deset)) + + # clear multiplicative constants on Dimensions which may be + # left after substitution + reps = {} + for m in expr.atoms(Mul): + if any(isinstance(i, Dimension) for i in m.args): + reps[m] = m.func(*[ + i for i in m.args if not i.is_number]) + + return expr.xreplace(reps) diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/__init__.py b/phivenv/Lib/site-packages/sympy/physics/vector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e714852064c0b940ebda2e5fe7a08faf13f07ed0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/__init__.py @@ -0,0 +1,36 @@ +__all__ = [ + 'CoordinateSym', 'ReferenceFrame', + + 'Dyadic', + + 'Vector', + + 'Point', + + 'cross', 'dot', 'express', 'time_derivative', 'outer', + 'kinematic_equations', 'get_motion_params', 'partial_velocity', + 'dynamicsymbols', + + 'vprint', 'vsstrrepr', 'vsprint', 'vpprint', 'vlatex', 'init_vprinting', + + 'curl', 'divergence', 'gradient', 'is_conservative', 'is_solenoidal', + 'scalar_potential', 'scalar_potential_difference', + +] +from .frame import CoordinateSym, ReferenceFrame + +from .dyadic import Dyadic + +from .vector import Vector + +from .point import Point + +from .functions import (cross, dot, express, time_derivative, outer, + kinematic_equations, get_motion_params, partial_velocity, + dynamicsymbols) + +from .printing import (vprint, vsstrrepr, vsprint, vpprint, vlatex, + init_vprinting) + +from .fieldfunctions import (curl, divergence, gradient, is_conservative, + is_solenoidal, scalar_potential, scalar_potential_difference) diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..01489cd8c1c1cfcf658bc9220cbca4f3a39494b7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/dyadic.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/dyadic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..728d6e76747929e7ba90d1339dd64e7e5a1a60fd Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/dyadic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/fieldfunctions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/fieldfunctions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e815e34ac8d917511e01c44cf9eb639adb889af Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/fieldfunctions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/frame.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/frame.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed71aea5a60fffb7a92d7ccdb603ff1f29300a9b Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/frame.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/functions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f48806f1b34742b7940c3af2b10b35b1dd69678 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/point.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/point.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..35c97448affcb51f33cb3c62d3b150be9376217f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/point.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/printing.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/printing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b1f088074b30ccc19ed194d17b241c1144d7690 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/printing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/vector.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/vector.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ca81e2c29c39d6352741c060200837cd24e23f0 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/__pycache__/vector.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/dyadic.py b/phivenv/Lib/site-packages/sympy/physics/vector/dyadic.py new file mode 100644 index 0000000000000000000000000000000000000000..0adacab2c2be5a287f59b6944206a07398a5fb9d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/dyadic.py @@ -0,0 +1,545 @@ +from sympy import sympify, Add, ImmutableMatrix as Matrix +from sympy.core.evalf import EvalfMixin +from sympy.printing.defaults import Printable + +from mpmath.libmp.libmpf import prec_to_dps + + +__all__ = ['Dyadic'] + + +class Dyadic(Printable, EvalfMixin): + """A Dyadic object. + + See: + https://en.wikipedia.org/wiki/Dyadic_tensor + Kane, T., Levinson, D. Dynamics Theory and Applications. 1985 McGraw-Hill + + A more powerful way to represent a rigid body's inertia. While it is more + complex, by choosing Dyadic components to be in body fixed basis vectors, + the resulting matrix is equivalent to the inertia tensor. + + """ + + is_number = False + + def __init__(self, inlist): + """ + Just like Vector's init, you should not call this unless creating a + zero dyadic. + + zd = Dyadic(0) + + Stores a Dyadic as a list of lists; the inner list has the measure + number and the two unit vectors; the outerlist holds each unique + unit vector pair. + + """ + + self.args = [] + if inlist == 0: + inlist = [] + while len(inlist) != 0: + added = 0 + for i, v in enumerate(self.args): + if ((str(inlist[0][1]) == str(self.args[i][1])) and + (str(inlist[0][2]) == str(self.args[i][2]))): + self.args[i] = (self.args[i][0] + inlist[0][0], + inlist[0][1], inlist[0][2]) + inlist.remove(inlist[0]) + added = 1 + break + if added != 1: + self.args.append(inlist[0]) + inlist.remove(inlist[0]) + i = 0 + # This code is to remove empty parts from the list + while i < len(self.args): + if ((self.args[i][0] == 0) | (self.args[i][1] == 0) | + (self.args[i][2] == 0)): + self.args.remove(self.args[i]) + i -= 1 + i += 1 + + @property + def func(self): + """Returns the class Dyadic. """ + return Dyadic + + def __add__(self, other): + """The add operator for Dyadic. """ + other = _check_dyadic(other) + return Dyadic(self.args + other.args) + + __radd__ = __add__ + + def __mul__(self, other): + """Multiplies the Dyadic by a sympifyable expression. + + Parameters + ========== + + other : Sympafiable + The scalar to multiply this Dyadic with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer + >>> N = ReferenceFrame('N') + >>> d = outer(N.x, N.x) + >>> 5 * d + 5*(N.x|N.x) + + """ + newlist = list(self.args) + other = sympify(other) + for i in range(len(newlist)): + newlist[i] = (other * newlist[i][0], newlist[i][1], + newlist[i][2]) + return Dyadic(newlist) + + __rmul__ = __mul__ + + def dot(self, other): + """The inner product operator for a Dyadic and a Dyadic or Vector. + + Parameters + ========== + + other : Dyadic or Vector + The other Dyadic or Vector to take the inner product with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer + >>> N = ReferenceFrame('N') + >>> D1 = outer(N.x, N.y) + >>> D2 = outer(N.y, N.y) + >>> D1.dot(D2) + (N.x|N.y) + >>> D1.dot(N.y) + N.x + + """ + from sympy.physics.vector.vector import Vector, _check_vector + if isinstance(other, Dyadic): + other = _check_dyadic(other) + ol = Dyadic(0) + for v in self.args: + for v2 in other.args: + ol += v[0] * v2[0] * (v[2].dot(v2[1])) * (v[1].outer(v2[2])) + else: + other = _check_vector(other) + ol = Vector(0) + for v in self.args: + ol += v[0] * v[1] * (v[2].dot(other)) + return ol + + # NOTE : supports non-advertised Dyadic & Dyadic, Dyadic & Vector notation + __and__ = dot + + def __truediv__(self, other): + """Divides the Dyadic by a sympifyable expression. """ + return self.__mul__(1 / other) + + def __eq__(self, other): + """Tests for equality. + + Is currently weak; needs stronger comparison testing + + """ + + if other == 0: + other = Dyadic(0) + other = _check_dyadic(other) + if (self.args == []) and (other.args == []): + return True + elif (self.args == []) or (other.args == []): + return False + return set(self.args) == set(other.args) + + def __ne__(self, other): + return not self == other + + def __neg__(self): + return self * -1 + + def _latex(self, printer): + ar = self.args # just to shorten things + if len(ar) == 0: + return str(0) + ol = [] # output list, to be concatenated to a string + for v in ar: + # if the coef of the dyadic is 1, we skip the 1 + if v[0] == 1: + ol.append(' + ' + printer._print(v[1]) + r"\otimes " + + printer._print(v[2])) + # if the coef of the dyadic is -1, we skip the 1 + elif v[0] == -1: + ol.append(' - ' + + printer._print(v[1]) + + r"\otimes " + + printer._print(v[2])) + # If the coefficient of the dyadic is not 1 or -1, + # we might wrap it in parentheses, for readability. + elif v[0] != 0: + arg_str = printer._print(v[0]) + if isinstance(v[0], Add): + arg_str = '(%s)' % arg_str + if arg_str.startswith('-'): + arg_str = arg_str[1:] + str_start = ' - ' + else: + str_start = ' + ' + ol.append(str_start + arg_str + printer._print(v[1]) + + r"\otimes " + printer._print(v[2])) + outstr = ''.join(ol) + if outstr.startswith(' + '): + outstr = outstr[3:] + elif outstr.startswith(' '): + outstr = outstr[1:] + return outstr + + def _pretty(self, printer): + e = self + + class Fake: + baseline = 0 + + def render(self, *args, **kwargs): + ar = e.args # just to shorten things + mpp = printer + if len(ar) == 0: + return str(0) + bar = "\N{CIRCLED TIMES}" if printer._use_unicode else "|" + ol = [] # output list, to be concatenated to a string + for v in ar: + # if the coef of the dyadic is 1, we skip the 1 + if v[0] == 1: + ol.extend([" + ", + mpp.doprint(v[1]), + bar, + mpp.doprint(v[2])]) + + # if the coef of the dyadic is -1, we skip the 1 + elif v[0] == -1: + ol.extend([" - ", + mpp.doprint(v[1]), + bar, + mpp.doprint(v[2])]) + + # If the coefficient of the dyadic is not 1 or -1, + # we might wrap it in parentheses, for readability. + elif v[0] != 0: + if isinstance(v[0], Add): + arg_str = mpp._print( + v[0]).parens()[0] + else: + arg_str = mpp.doprint(v[0]) + if arg_str.startswith("-"): + arg_str = arg_str[1:] + str_start = " - " + else: + str_start = " + " + ol.extend([str_start, arg_str, " ", + mpp.doprint(v[1]), + bar, + mpp.doprint(v[2])]) + + outstr = "".join(ol) + if outstr.startswith(" + "): + outstr = outstr[3:] + elif outstr.startswith(" "): + outstr = outstr[1:] + return outstr + return Fake() + + def __rsub__(self, other): + return (-1 * self) + other + + def _sympystr(self, printer): + """Printing method. """ + ar = self.args # just to shorten things + if len(ar) == 0: + return printer._print(0) + ol = [] # output list, to be concatenated to a string + for v in ar: + # if the coef of the dyadic is 1, we skip the 1 + if v[0] == 1: + ol.append(' + (' + printer._print(v[1]) + '|' + + printer._print(v[2]) + ')') + # if the coef of the dyadic is -1, we skip the 1 + elif v[0] == -1: + ol.append(' - (' + printer._print(v[1]) + '|' + + printer._print(v[2]) + ')') + # If the coefficient of the dyadic is not 1 or -1, + # we might wrap it in parentheses, for readability. + elif v[0] != 0: + arg_str = printer._print(v[0]) + if isinstance(v[0], Add): + arg_str = "(%s)" % arg_str + if arg_str[0] == '-': + arg_str = arg_str[1:] + str_start = ' - ' + else: + str_start = ' + ' + ol.append(str_start + arg_str + '*(' + + printer._print(v[1]) + + '|' + printer._print(v[2]) + ')') + outstr = ''.join(ol) + if outstr.startswith(' + '): + outstr = outstr[3:] + elif outstr.startswith(' '): + outstr = outstr[1:] + return outstr + + def __sub__(self, other): + """The subtraction operator. """ + return self.__add__(other * -1) + + def cross(self, other): + """Returns the dyadic resulting from the dyadic vector cross product: + Dyadic x Vector. + + Parameters + ========== + other : Vector + Vector to cross with. + + Examples + ======== + >>> from sympy.physics.vector import ReferenceFrame, outer, cross + >>> N = ReferenceFrame('N') + >>> d = outer(N.x, N.x) + >>> cross(d, N.y) + (N.x|N.z) + + """ + from sympy.physics.vector.vector import _check_vector + other = _check_vector(other) + ol = Dyadic(0) + for v in self.args: + ol += v[0] * (v[1].outer((v[2].cross(other)))) + return ol + + # NOTE : supports non-advertised Dyadic ^ Vector notation + __xor__ = cross + + def express(self, frame1, frame2=None): + """Expresses this Dyadic in alternate frame(s) + + The first frame is the list side expression, the second frame is the + right side; if Dyadic is in form A.x|B.y, you can express it in two + different frames. If no second frame is given, the Dyadic is + expressed in only one frame. + + Calls the global express function + + Parameters + ========== + + frame1 : ReferenceFrame + The frame to express the left side of the Dyadic in + frame2 : ReferenceFrame + If provided, the frame to express the right side of the Dyadic in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> d = outer(N.x, N.x) + >>> d.express(B, N) + cos(q)*(B.x|N.x) - sin(q)*(B.y|N.x) + + """ + from sympy.physics.vector.functions import express + return express(self, frame1, frame2) + + def to_matrix(self, reference_frame, second_reference_frame=None): + """Returns the matrix form of the dyadic with respect to one or two + reference frames. + + Parameters + ---------- + reference_frame : ReferenceFrame + The reference frame that the rows and columns of the matrix + correspond to. If a second reference frame is provided, this + only corresponds to the rows of the matrix. + second_reference_frame : ReferenceFrame, optional, default=None + The reference frame that the columns of the matrix correspond + to. + + Returns + ------- + matrix : ImmutableMatrix, shape(3,3) + The matrix that gives the 2D tensor form. + + Examples + ======== + + >>> from sympy import symbols, trigsimp + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.mechanics import inertia + >>> Ixx, Iyy, Izz, Ixy, Iyz, Ixz = symbols('Ixx, Iyy, Izz, Ixy, Iyz, Ixz') + >>> N = ReferenceFrame('N') + >>> inertia_dyadic = inertia(N, Ixx, Iyy, Izz, Ixy, Iyz, Ixz) + >>> inertia_dyadic.to_matrix(N) + Matrix([ + [Ixx, Ixy, Ixz], + [Ixy, Iyy, Iyz], + [Ixz, Iyz, Izz]]) + >>> beta = symbols('beta') + >>> A = N.orientnew('A', 'Axis', (beta, N.x)) + >>> trigsimp(inertia_dyadic.to_matrix(A)) + Matrix([ + [ Ixx, Ixy*cos(beta) + Ixz*sin(beta), -Ixy*sin(beta) + Ixz*cos(beta)], + [ Ixy*cos(beta) + Ixz*sin(beta), Iyy*cos(2*beta)/2 + Iyy/2 + Iyz*sin(2*beta) - Izz*cos(2*beta)/2 + Izz/2, -Iyy*sin(2*beta)/2 + Iyz*cos(2*beta) + Izz*sin(2*beta)/2], + [-Ixy*sin(beta) + Ixz*cos(beta), -Iyy*sin(2*beta)/2 + Iyz*cos(2*beta) + Izz*sin(2*beta)/2, -Iyy*cos(2*beta)/2 + Iyy/2 - Iyz*sin(2*beta) + Izz*cos(2*beta)/2 + Izz/2]]) + + """ + + if second_reference_frame is None: + second_reference_frame = reference_frame + + return Matrix([i.dot(self).dot(j) for i in reference_frame for j in + second_reference_frame]).reshape(3, 3) + + def doit(self, **hints): + """Calls .doit() on each term in the Dyadic""" + return sum([Dyadic([(v[0].doit(**hints), v[1], v[2])]) + for v in self.args], Dyadic(0)) + + def dt(self, frame): + """Take the time derivative of this Dyadic in a frame. + + This function calls the global time_derivative method + + Parameters + ========== + + frame : ReferenceFrame + The frame to take the time derivative in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> d = outer(N.x, N.x) + >>> d.dt(B) + - q'*(N.y|N.x) - q'*(N.x|N.y) + + """ + from sympy.physics.vector.functions import time_derivative + return time_derivative(self, frame) + + def simplify(self): + """Returns a simplified Dyadic.""" + out = Dyadic(0) + for v in self.args: + out += Dyadic([(v[0].simplify(), v[1], v[2])]) + return out + + def subs(self, *args, **kwargs): + """Substitution on the Dyadic. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy import Symbol + >>> N = ReferenceFrame('N') + >>> s = Symbol('s') + >>> a = s*(N.x|N.x) + >>> a.subs({s: 2}) + 2*(N.x|N.x) + + """ + + return sum([Dyadic([(v[0].subs(*args, **kwargs), v[1], v[2])]) + for v in self.args], Dyadic(0)) + + def applyfunc(self, f): + """Apply a function to each component of a Dyadic.""" + if not callable(f): + raise TypeError("`f` must be callable.") + + out = Dyadic(0) + for a, b, c in self.args: + out += f(a) * (b.outer(c)) + return out + + def _eval_evalf(self, prec): + if not self.args: + return self + new_args = [] + dps = prec_to_dps(prec) + for inlist in self.args: + new_inlist = list(inlist) + new_inlist[0] = inlist[0].evalf(n=dps) + new_args.append(tuple(new_inlist)) + return Dyadic(new_args) + + def xreplace(self, rule): + """ + Replace occurrences of objects within the measure numbers of the + Dyadic. + + Parameters + ========== + + rule : dict-like + Expresses a replacement rule. + + Returns + ======= + + Dyadic + Result of the replacement. + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.physics.vector import ReferenceFrame, outer + >>> N = ReferenceFrame('N') + >>> D = outer(N.x, N.x) + >>> x, y, z = symbols('x y z') + >>> ((1 + x*y) * D).xreplace({x: pi}) + (pi*y + 1)*(N.x|N.x) + >>> ((1 + x*y) * D).xreplace({x: pi, y: 2}) + (1 + 2*pi)*(N.x|N.x) + + Replacements occur only if an entire node in the expression tree is + matched: + + >>> ((x*y + z) * D).xreplace({x*y: pi}) + (z + pi)*(N.x|N.x) + >>> ((x*y*z) * D).xreplace({x*y: pi}) + x*y*z*(N.x|N.x) + + """ + + new_args = [] + for inlist in self.args: + new_inlist = list(inlist) + new_inlist[0] = new_inlist[0].xreplace(rule) + new_args.append(tuple(new_inlist)) + return Dyadic(new_args) + + +def _check_dyadic(other): + if not isinstance(other, Dyadic): + raise TypeError('A Dyadic must be supplied') + return other diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/fieldfunctions.py b/phivenv/Lib/site-packages/sympy/physics/vector/fieldfunctions.py new file mode 100644 index 0000000000000000000000000000000000000000..50dd74ff9e5cb4fdf469a0ea5d72d812c8f03f15 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/fieldfunctions.py @@ -0,0 +1,313 @@ +from sympy.core.function import diff +from sympy.core.singleton import S +from sympy.integrals.integrals import integrate +from sympy.physics.vector import Vector, express +from sympy.physics.vector.frame import _check_frame +from sympy.physics.vector.vector import _check_vector + + +__all__ = ['curl', 'divergence', 'gradient', 'is_conservative', + 'is_solenoidal', 'scalar_potential', + 'scalar_potential_difference'] + + +def curl(vect, frame): + """ + Returns the curl of a vector field computed wrt the coordinate + symbols of the given frame. + + Parameters + ========== + + vect : Vector + The vector operand + + frame : ReferenceFrame + The reference frame to calculate the curl in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import curl + >>> R = ReferenceFrame('R') + >>> v1 = R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z + >>> curl(v1, R) + 0 + >>> v2 = R[0]*R[1]*R[2]*R.x + >>> curl(v2, R) + R_x*R_y*R.y - R_x*R_z*R.z + + """ + + _check_vector(vect) + if vect == 0: + return Vector(0) + vect = express(vect, frame, variables=True) + # A mechanical approach to avoid looping overheads + vectx = vect.dot(frame.x) + vecty = vect.dot(frame.y) + vectz = vect.dot(frame.z) + outvec = Vector(0) + outvec += (diff(vectz, frame[1]) - diff(vecty, frame[2])) * frame.x + outvec += (diff(vectx, frame[2]) - diff(vectz, frame[0])) * frame.y + outvec += (diff(vecty, frame[0]) - diff(vectx, frame[1])) * frame.z + return outvec + + +def divergence(vect, frame): + """ + Returns the divergence of a vector field computed wrt the coordinate + symbols of the given frame. + + Parameters + ========== + + vect : Vector + The vector operand + + frame : ReferenceFrame + The reference frame to calculate the divergence in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import divergence + >>> R = ReferenceFrame('R') + >>> v1 = R[0]*R[1]*R[2] * (R.x+R.y+R.z) + >>> divergence(v1, R) + R_x*R_y + R_x*R_z + R_y*R_z + >>> v2 = 2*R[1]*R[2]*R.y + >>> divergence(v2, R) + 2*R_z + + """ + + _check_vector(vect) + if vect == 0: + return S.Zero + vect = express(vect, frame, variables=True) + vectx = vect.dot(frame.x) + vecty = vect.dot(frame.y) + vectz = vect.dot(frame.z) + out = S.Zero + out += diff(vectx, frame[0]) + out += diff(vecty, frame[1]) + out += diff(vectz, frame[2]) + return out + + +def gradient(scalar, frame): + """ + Returns the vector gradient of a scalar field computed wrt the + coordinate symbols of the given frame. + + Parameters + ========== + + scalar : sympifiable + The scalar field to take the gradient of + + frame : ReferenceFrame + The frame to calculate the gradient in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import gradient + >>> R = ReferenceFrame('R') + >>> s1 = R[0]*R[1]*R[2] + >>> gradient(s1, R) + R_y*R_z*R.x + R_x*R_z*R.y + R_x*R_y*R.z + >>> s2 = 5*R[0]**2*R[2] + >>> gradient(s2, R) + 10*R_x*R_z*R.x + 5*R_x**2*R.z + + """ + + _check_frame(frame) + outvec = Vector(0) + scalar = express(scalar, frame, variables=True) + for i, x in enumerate(frame): + outvec += diff(scalar, frame[i]) * x # noqa: PLR1736 + return outvec + + +def is_conservative(field): + """ + Checks if a field is conservative. + + Parameters + ========== + + field : Vector + The field to check for conservative property + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import is_conservative + >>> R = ReferenceFrame('R') + >>> is_conservative(R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z) + True + >>> is_conservative(R[2] * R.y) + False + + """ + + # Field is conservative irrespective of frame + # Take the first frame in the result of the separate method of Vector + if field == Vector(0): + return True + frame = list(field.separate())[0] + return curl(field, frame).simplify() == Vector(0) + + +def is_solenoidal(field): + """ + Checks if a field is solenoidal. + + Parameters + ========== + + field : Vector + The field to check for solenoidal property + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import is_solenoidal + >>> R = ReferenceFrame('R') + >>> is_solenoidal(R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z) + True + >>> is_solenoidal(R[1] * R.y) + False + + """ + + # Field is solenoidal irrespective of frame + # Take the first frame in the result of the separate method in Vector + if field == Vector(0): + return True + frame = list(field.separate())[0] + return divergence(field, frame).simplify() is S.Zero + + +def scalar_potential(field, frame): + """ + Returns the scalar potential function of a field in a given frame + (without the added integration constant). + + Parameters + ========== + + field : Vector + The vector field whose scalar potential function is to be + calculated + + frame : ReferenceFrame + The frame to do the calculation in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import scalar_potential, gradient + >>> R = ReferenceFrame('R') + >>> scalar_potential(R.z, R) == R[2] + True + >>> scalar_field = 2*R[0]**2*R[1]*R[2] + >>> grad_field = gradient(scalar_field, R) + >>> scalar_potential(grad_field, R) + 2*R_x**2*R_y*R_z + + """ + + # Check whether field is conservative + if not is_conservative(field): + raise ValueError("Field is not conservative") + if field == Vector(0): + return S.Zero + # Express the field exntirely in frame + # Substitute coordinate variables also + _check_frame(frame) + field = express(field, frame, variables=True) + # Make a list of dimensions of the frame + dimensions = list(frame) + # Calculate scalar potential function + temp_function = integrate(field.dot(dimensions[0]), frame[0]) + for i, dim in enumerate(dimensions[1:]): + partial_diff = diff(temp_function, frame[i + 1]) + partial_diff = field.dot(dim) - partial_diff + temp_function += integrate(partial_diff, frame[i + 1]) + return temp_function + + +def scalar_potential_difference(field, frame, point1, point2, origin): + """ + Returns the scalar potential difference between two points in a + certain frame, wrt a given field. + + If a scalar field is provided, its values at the two points are + considered. If a conservative vector field is provided, the values + of its scalar potential function at the two points are used. + + Returns (potential at position 2) - (potential at position 1) + + Parameters + ========== + + field : Vector/sympyfiable + The field to calculate wrt + + frame : ReferenceFrame + The frame to do the calculations in + + point1 : Point + The initial Point in given frame + + position2 : Point + The second Point in the given frame + + origin : Point + The Point to use as reference point for position vector + calculation + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, Point + >>> from sympy.physics.vector import scalar_potential_difference + >>> R = ReferenceFrame('R') + >>> O = Point('O') + >>> P = O.locatenew('P', R[0]*R.x + R[1]*R.y + R[2]*R.z) + >>> vectfield = 4*R[0]*R[1]*R.x + 2*R[0]**2*R.y + >>> scalar_potential_difference(vectfield, R, O, P, O) + 2*R_x**2*R_y + >>> Q = O.locatenew('O', 3*R.x + R.y + 2*R.z) + >>> scalar_potential_difference(vectfield, R, P, Q, O) + -2*R_x**2*R_y + 18 + + """ + + _check_frame(frame) + if isinstance(field, Vector): + # Get the scalar potential function + scalar_fn = scalar_potential(field, frame) + else: + # Field is a scalar + scalar_fn = field + # Express positions in required frame + position1 = express(point1.pos_from(origin), frame, variables=True) + position2 = express(point2.pos_from(origin), frame, variables=True) + # Get the two positions as substitution dicts for coordinate variables + subs_dict1 = {} + subs_dict2 = {} + for i, x in enumerate(frame): + subs_dict1[frame[i]] = x.dot(position1) + subs_dict2[frame[i]] = x.dot(position2) + return scalar_fn.subs(subs_dict2) - scalar_fn.subs(subs_dict1) diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/frame.py b/phivenv/Lib/site-packages/sympy/physics/vector/frame.py new file mode 100644 index 0000000000000000000000000000000000000000..4aa28fe3717696b6fd8196e652b6b1aa0daf5609 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/frame.py @@ -0,0 +1,1575 @@ +from sympy import (diff, expand, sin, cos, sympify, eye, zeros, + ImmutableMatrix as Matrix, MatrixBase) +from sympy.core.symbol import Symbol +from sympy.simplify.trigsimp import trigsimp +from sympy.physics.vector.vector import Vector, _check_vector +from sympy.utilities.misc import translate + +from warnings import warn + +__all__ = ['CoordinateSym', 'ReferenceFrame'] + + +class CoordinateSym(Symbol): + """ + A coordinate symbol/base scalar associated wrt a Reference Frame. + + Ideally, users should not instantiate this class. Instances of + this class must only be accessed through the corresponding frame + as 'frame[index]'. + + CoordinateSyms having the same frame and index parameters are equal + (even though they may be instantiated separately). + + Parameters + ========== + + name : string + The display name of the CoordinateSym + + frame : ReferenceFrame + The reference frame this base scalar belongs to + + index : 0, 1 or 2 + The index of the dimension denoted by this coordinate variable + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, CoordinateSym + >>> A = ReferenceFrame('A') + >>> A[1] + A_y + >>> type(A[0]) + + >>> a_y = CoordinateSym('a_y', A, 1) + >>> a_y == A[1] + True + + """ + + def __new__(cls, name, frame, index): + # We can't use the cached Symbol.__new__ because this class depends on + # frame and index, which are not passed to Symbol.__xnew__. + assumptions = {} + super()._sanitize(assumptions, cls) + obj = super().__xnew__(cls, name, **assumptions) + _check_frame(frame) + if index not in range(0, 3): + raise ValueError("Invalid index specified") + obj._id = (frame, index) + return obj + + def __getnewargs_ex__(self): + return (self.name, *self._id), {} + + @property + def frame(self): + return self._id[0] + + def __eq__(self, other): + # Check if the other object is a CoordinateSym of the same frame and + # same index + if isinstance(other, CoordinateSym): + if other._id == self._id: + return True + return False + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return (self._id[0].__hash__(), self._id[1]).__hash__() + + +class ReferenceFrame: + """A reference frame in classical mechanics. + + ReferenceFrame is a class used to represent a reference frame in classical + mechanics. It has a standard basis of three unit vectors in the frame's + x, y, and z directions. + + It also can have a rotation relative to a parent frame; this rotation is + defined by a direction cosine matrix relating this frame's basis vectors to + the parent frame's basis vectors. It can also have an angular velocity + vector, defined in another frame. + + """ + _count = 0 + + def __init__(self, name, indices=None, latexs=None, variables=None): + """ReferenceFrame initialization method. + + A ReferenceFrame has a set of orthonormal basis vectors, along with + orientations relative to other ReferenceFrames and angular velocities + relative to other ReferenceFrames. + + Parameters + ========== + + indices : tuple of str + Enables the reference frame's basis unit vectors to be accessed by + Python's square bracket indexing notation using the provided three + indice strings and alters the printing of the unit vectors to + reflect this choice. + latexs : tuple of str + Alters the LaTeX printing of the reference frame's basis unit + vectors to the provided three valid LaTeX strings. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, vlatex + >>> N = ReferenceFrame('N') + >>> N.x + N.x + >>> O = ReferenceFrame('O', indices=('1', '2', '3')) + >>> O.x + O['1'] + >>> O['1'] + O['1'] + >>> P = ReferenceFrame('P', latexs=('A1', 'A2', 'A3')) + >>> vlatex(P.x) + 'A1' + + ``symbols()`` can be used to create multiple Reference Frames in one + step, for example: + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy import symbols + >>> A, B, C = symbols('A B C', cls=ReferenceFrame) + >>> D, E = symbols('D E', cls=ReferenceFrame, indices=('1', '2', '3')) + >>> A[0] + A_x + >>> D.x + D['1'] + >>> E.y + E['2'] + >>> type(A) == type(D) + True + + Unit dyads for the ReferenceFrame can be accessed through the attributes ``xx``, ``xy``, etc. For example: + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> N.yz + (N.y|N.z) + >>> N.zx + (N.z|N.x) + >>> P = ReferenceFrame('P', indices=['1', '2', '3']) + >>> P.xx + (P['1']|P['1']) + >>> P.zy + (P['3']|P['2']) + + Unit dyadic is also accessible via the ``u`` attribute: + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> N.u + (N.x|N.x) + (N.y|N.y) + (N.z|N.z) + >>> P = ReferenceFrame('P', indices=['1', '2', '3']) + >>> P.u + (P['1']|P['1']) + (P['2']|P['2']) + (P['3']|P['3']) + + """ + + if not isinstance(name, str): + raise TypeError('Need to supply a valid name') + # The if statements below are for custom printing of basis-vectors for + # each frame. + # First case, when custom indices are supplied + if indices is not None: + if not isinstance(indices, (tuple, list)): + raise TypeError('Supply the indices as a list') + if len(indices) != 3: + raise ValueError('Supply 3 indices') + for i in indices: + if not isinstance(i, str): + raise TypeError('Indices must be strings') + self.str_vecs = [(name + '[\'' + indices[0] + '\']'), + (name + '[\'' + indices[1] + '\']'), + (name + '[\'' + indices[2] + '\']')] + self.pretty_vecs = [(name.lower() + "_" + indices[0]), + (name.lower() + "_" + indices[1]), + (name.lower() + "_" + indices[2])] + self.latex_vecs = [(r"\mathbf{\hat{%s}_{%s}}" % (name.lower(), + indices[0])), + (r"\mathbf{\hat{%s}_{%s}}" % (name.lower(), + indices[1])), + (r"\mathbf{\hat{%s}_{%s}}" % (name.lower(), + indices[2]))] + self.indices = indices + # Second case, when no custom indices are supplied + else: + self.str_vecs = [(name + '.x'), (name + '.y'), (name + '.z')] + self.pretty_vecs = [name.lower() + "_x", + name.lower() + "_y", + name.lower() + "_z"] + self.latex_vecs = [(r"\mathbf{\hat{%s}_x}" % name.lower()), + (r"\mathbf{\hat{%s}_y}" % name.lower()), + (r"\mathbf{\hat{%s}_z}" % name.lower())] + self.indices = ['x', 'y', 'z'] + # Different step, for custom latex basis vectors + if latexs is not None: + if not isinstance(latexs, (tuple, list)): + raise TypeError('Supply the indices as a list') + if len(latexs) != 3: + raise ValueError('Supply 3 indices') + for i in latexs: + if not isinstance(i, str): + raise TypeError('Latex entries must be strings') + self.latex_vecs = latexs + self.name = name + self._var_dict = {} + # The _dcm_dict dictionary will only store the dcms of adjacent + # parent-child relationships. The _dcm_cache dictionary will store + # calculated dcm along with all content of _dcm_dict for faster + # retrieval of dcms. + self._dcm_dict = {} + self._dcm_cache = {} + self._ang_vel_dict = {} + self._ang_acc_dict = {} + self._dlist = [self._dcm_dict, self._ang_vel_dict, self._ang_acc_dict] + self._cur = 0 + self._x = Vector([(Matrix([1, 0, 0]), self)]) + self._y = Vector([(Matrix([0, 1, 0]), self)]) + self._z = Vector([(Matrix([0, 0, 1]), self)]) + # Associate coordinate symbols wrt this frame + if variables is not None: + if not isinstance(variables, (tuple, list)): + raise TypeError('Supply the variable names as a list/tuple') + if len(variables) != 3: + raise ValueError('Supply 3 variable names') + for i in variables: + if not isinstance(i, str): + raise TypeError('Variable names must be strings') + else: + variables = [name + '_x', name + '_y', name + '_z'] + self.varlist = (CoordinateSym(variables[0], self, 0), + CoordinateSym(variables[1], self, 1), + CoordinateSym(variables[2], self, 2)) + ReferenceFrame._count += 1 + self.index = ReferenceFrame._count + + def __getitem__(self, ind): + """ + Returns basis vector for the provided index, if the index is a string. + + If the index is a number, returns the coordinate variable correspon- + -ding to that index. + """ + if not isinstance(ind, str): + if ind < 3: + return self.varlist[ind] + else: + raise ValueError("Invalid index provided") + if self.indices[0] == ind: + return self.x + if self.indices[1] == ind: + return self.y + if self.indices[2] == ind: + return self.z + else: + raise ValueError('Not a defined index') + + def __iter__(self): + return iter([self.x, self.y, self.z]) + + def __str__(self): + """Returns the name of the frame. """ + return self.name + + __repr__ = __str__ + + def _dict_list(self, other, num): + """Returns an inclusive list of reference frames that connect this + reference frame to the provided reference frame. + + Parameters + ========== + other : ReferenceFrame + The other reference frame to look for a connecting relationship to. + num : integer + ``0``, ``1``, and ``2`` will look for orientation, angular + velocity, and angular acceleration relationships between the two + frames, respectively. + + Returns + ======= + list + Inclusive list of reference frames that connect this reference + frame to the other reference frame. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> A = ReferenceFrame('A') + >>> B = ReferenceFrame('B') + >>> C = ReferenceFrame('C') + >>> D = ReferenceFrame('D') + >>> B.orient_axis(A, A.x, 1.0) + >>> C.orient_axis(B, B.x, 1.0) + >>> D.orient_axis(C, C.x, 1.0) + >>> D._dict_list(A, 0) + [D, C, B, A] + + Raises + ====== + + ValueError + When no path is found between the two reference frames or ``num`` + is an incorrect value. + + """ + + connect_type = {0: 'orientation', + 1: 'angular velocity', + 2: 'angular acceleration'} + + if num not in connect_type.keys(): + raise ValueError('Valid values for num are 0, 1, or 2.') + + possible_connecting_paths = [[self]] + oldlist = [[]] + while possible_connecting_paths != oldlist: + oldlist = possible_connecting_paths.copy() + for frame_list in possible_connecting_paths: + frames_adjacent_to_last = frame_list[-1]._dlist[num].keys() + for adjacent_frame in frames_adjacent_to_last: + if adjacent_frame not in frame_list: + connecting_path = frame_list + [adjacent_frame] + if connecting_path not in possible_connecting_paths: + possible_connecting_paths.append(connecting_path) + + for connecting_path in oldlist: + if connecting_path[-1] != other: + possible_connecting_paths.remove(connecting_path) + possible_connecting_paths.sort(key=len) + + if len(possible_connecting_paths) != 0: + return possible_connecting_paths[0] # selects the shortest path + + msg = 'No connecting {} path found between {} and {}.' + raise ValueError(msg.format(connect_type[num], self.name, other.name)) + + def _w_diff_dcm(self, otherframe): + """Angular velocity from time differentiating the DCM. """ + from sympy.physics.vector.functions import dynamicsymbols + dcm2diff = otherframe.dcm(self) + diffed = dcm2diff.diff(dynamicsymbols._t) + angvelmat = diffed * dcm2diff.T + w1 = trigsimp(expand(angvelmat[7]), recursive=True) + w2 = trigsimp(expand(angvelmat[2]), recursive=True) + w3 = trigsimp(expand(angvelmat[3]), recursive=True) + return Vector([(Matrix([w1, w2, w3]), otherframe)]) + + def variable_map(self, otherframe): + """ + Returns a dictionary which expresses the coordinate variables + of this frame in terms of the variables of otherframe. + + If Vector.simp is True, returns a simplified version of the mapped + values. Else, returns them without simplification. + + Simplification of the expressions may take time. + + Parameters + ========== + + otherframe : ReferenceFrame + The other frame to map the variables to + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dynamicsymbols + >>> A = ReferenceFrame('A') + >>> q = dynamicsymbols('q') + >>> B = A.orientnew('B', 'Axis', [q, A.z]) + >>> A.variable_map(B) + {A_x: B_x*cos(q(t)) - B_y*sin(q(t)), A_y: B_x*sin(q(t)) + B_y*cos(q(t)), A_z: B_z} + + """ + + _check_frame(otherframe) + if (otherframe, Vector.simp) in self._var_dict: + return self._var_dict[(otherframe, Vector.simp)] + else: + vars_matrix = self.dcm(otherframe) * Matrix(otherframe.varlist) + mapping = {} + for i, x in enumerate(self): + if Vector.simp: + mapping[self.varlist[i]] = trigsimp(vars_matrix[i], + method='fu') + else: + mapping[self.varlist[i]] = vars_matrix[i] + self._var_dict[(otherframe, Vector.simp)] = mapping + return mapping + + def ang_acc_in(self, otherframe): + """Returns the angular acceleration Vector of the ReferenceFrame. + + Effectively returns the Vector: + + ``N_alpha_B`` + + which represent the angular acceleration of B in N, where B is self, + and N is otherframe. + + Parameters + ========== + + otherframe : ReferenceFrame + The ReferenceFrame which the angular acceleration is returned in. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> V = 10 * N.x + >>> A.set_ang_acc(N, V) + >>> A.ang_acc_in(N) + 10*N.x + + """ + + _check_frame(otherframe) + if otherframe in self._ang_acc_dict: + return self._ang_acc_dict[otherframe] + else: + return self.ang_vel_in(otherframe).dt(otherframe) + + def ang_vel_in(self, otherframe): + """Returns the angular velocity Vector of the ReferenceFrame. + + Effectively returns the Vector: + + ^N omega ^B + + which represent the angular velocity of B in N, where B is self, and + N is otherframe. + + Parameters + ========== + + otherframe : ReferenceFrame + The ReferenceFrame which the angular velocity is returned in. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> V = 10 * N.x + >>> A.set_ang_vel(N, V) + >>> A.ang_vel_in(N) + 10*N.x + + """ + + _check_frame(otherframe) + flist = self._dict_list(otherframe, 1) + outvec = Vector(0) + for i in range(len(flist) - 1): + outvec += flist[i]._ang_vel_dict[flist[i + 1]] + return outvec + + def dcm(self, otherframe): + r"""Returns the direction cosine matrix of this reference frame + relative to the provided reference frame. + + The returned matrix can be used to express the orthogonal unit vectors + of this frame in terms of the orthogonal unit vectors of + ``otherframe``. + + Parameters + ========== + + otherframe : ReferenceFrame + The reference frame which the direction cosine matrix of this frame + is formed relative to. + + Examples + ======== + + The following example rotates the reference frame A relative to N by a + simple rotation and then calculates the direction cosine matrix of N + relative to A. + + >>> from sympy import symbols, sin, cos + >>> from sympy.physics.vector import ReferenceFrame + >>> q1 = symbols('q1') + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> A.orient_axis(N, q1, N.x) + >>> N.dcm(A) + Matrix([ + [1, 0, 0], + [0, cos(q1), -sin(q1)], + [0, sin(q1), cos(q1)]]) + + The second row of the above direction cosine matrix represents the + ``N.y`` unit vector in N expressed in A. Like so: + + >>> Ny = 0*A.x + cos(q1)*A.y - sin(q1)*A.z + + Thus, expressing ``N.y`` in A should return the same result: + + >>> N.y.express(A) + cos(q1)*A.y - sin(q1)*A.z + + Notes + ===== + + It is important to know what form of the direction cosine matrix is + returned. If ``B.dcm(A)`` is called, it means the "direction cosine + matrix of B rotated relative to A". This is the matrix + :math:`{}^B\mathbf{C}^A` shown in the following relationship: + + .. math:: + + \begin{bmatrix} + \hat{\mathbf{b}}_1 \\ + \hat{\mathbf{b}}_2 \\ + \hat{\mathbf{b}}_3 + \end{bmatrix} + = + {}^B\mathbf{C}^A + \begin{bmatrix} + \hat{\mathbf{a}}_1 \\ + \hat{\mathbf{a}}_2 \\ + \hat{\mathbf{a}}_3 + \end{bmatrix}. + + :math:`{}^B\mathbf{C}^A` is the matrix that expresses the B unit + vectors in terms of the A unit vectors. + + """ + + _check_frame(otherframe) + # Check if the dcm wrt that frame has already been calculated + if otherframe in self._dcm_cache: + return self._dcm_cache[otherframe] + flist = self._dict_list(otherframe, 0) + outdcm = eye(3) + for i in range(len(flist) - 1): + outdcm = outdcm * flist[i]._dcm_dict[flist[i + 1]] + # After calculation, store the dcm in dcm cache for faster future + # retrieval + self._dcm_cache[otherframe] = outdcm + otherframe._dcm_cache[self] = outdcm.T + return outdcm + + def _dcm(self, parent, parent_orient): + # If parent.oreint(self) is already defined,then + # update the _dcm_dict of parent while over write + # all content of self._dcm_dict and self._dcm_cache + # with new dcm relation. + # Else update _dcm_cache and _dcm_dict of both + # self and parent. + frames = self._dcm_cache.keys() + dcm_dict_del = [] + dcm_cache_del = [] + if parent in frames: + for frame in frames: + if frame in self._dcm_dict: + dcm_dict_del += [frame] + dcm_cache_del += [frame] + # Reset the _dcm_cache of this frame, and remove it from the + # _dcm_caches of the frames it is linked to. Also remove it from + # the _dcm_dict of its parent + for frame in dcm_dict_del: + del frame._dcm_dict[self] + for frame in dcm_cache_del: + del frame._dcm_cache[self] + # Reset the _dcm_dict + self._dcm_dict = self._dlist[0] = {} + # Reset the _dcm_cache + self._dcm_cache = {} + + else: + # Check for loops and raise warning accordingly. + visited = [] + queue = list(frames) + cont = True # Flag to control queue loop. + while queue and cont: + node = queue.pop(0) + if node not in visited: + visited.append(node) + neighbors = node._dcm_dict.keys() + for neighbor in neighbors: + if neighbor == parent: + warn('Loops are defined among the orientation of ' + 'frames. This is likely not desired and may ' + 'cause errors in your calculations.') + cont = False + break + queue.append(neighbor) + + # Add the dcm relationship to _dcm_dict + self._dcm_dict.update({parent: parent_orient.T}) + parent._dcm_dict.update({self: parent_orient}) + # Update the dcm cache + self._dcm_cache.update({parent: parent_orient.T}) + parent._dcm_cache.update({self: parent_orient}) + + def orient_axis(self, parent, axis, angle): + """Sets the orientation of this reference frame with respect to a + parent reference frame by rotating through an angle about an axis fixed + in the parent reference frame. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + axis : Vector + Vector fixed in the parent frame about about which this frame is + rotated. It need not be a unit vector and the rotation follows the + right hand rule. + angle : sympifiable + Angle in radians by which it the frame is to be rotated. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> q1 = symbols('q1') + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B.orient_axis(N, N.x, q1) + + The ``orient_axis()`` method generates a direction cosine matrix and + its transpose which defines the orientation of B relative to N and vice + versa. Once orient is called, ``dcm()`` outputs the appropriate + direction cosine matrix: + + >>> B.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + >>> N.dcm(B) + Matrix([ + [1, 0, 0], + [0, cos(q1), -sin(q1)], + [0, sin(q1), cos(q1)]]) + + The following two lines show that the sense of the rotation can be + defined by negating the vector direction or the angle. Both lines + produce the same result. + + >>> B.orient_axis(N, -N.x, q1) + >>> B.orient_axis(N, N.x, -q1) + + """ + + from sympy.physics.vector.functions import dynamicsymbols + _check_frame(parent) + + if not isinstance(axis, Vector) and isinstance(angle, Vector): + axis, angle = angle, axis + + axis = _check_vector(axis) + theta = sympify(angle) + + if not axis.dt(parent) == 0: + raise ValueError('Axis cannot be time-varying.') + unit_axis = axis.express(parent).normalize() + unit_col = unit_axis.args[0][0] + parent_orient_axis = ( + (eye(3) - unit_col * unit_col.T) * cos(theta) + + Matrix([[0, -unit_col[2], unit_col[1]], + [unit_col[2], 0, -unit_col[0]], + [-unit_col[1], unit_col[0], 0]]) * + sin(theta) + unit_col * unit_col.T) + + self._dcm(parent, parent_orient_axis) + + thetad = (theta).diff(dynamicsymbols._t) + wvec = thetad*axis.express(parent).normalize() + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def orient_explicit(self, parent, dcm): + """Sets the orientation of this reference frame relative to another (parent) reference frame + using a direction cosine matrix that describes the rotation from the parent to the child. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + dcm : Matrix, shape(3, 3) + Direction cosine matrix that specifies the relative rotation + between the two reference frames. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols, Matrix, sin, cos + >>> from sympy.physics.vector import ReferenceFrame + >>> q1 = symbols('q1') + >>> A = ReferenceFrame('A') + >>> B = ReferenceFrame('B') + >>> N = ReferenceFrame('N') + + A simple rotation of ``A`` relative to ``N`` about ``N.x`` is defined + by the following direction cosine matrix: + + >>> dcm = Matrix([[1, 0, 0], + ... [0, cos(q1), -sin(q1)], + ... [0, sin(q1), cos(q1)]]) + >>> A.orient_explicit(N, dcm) + >>> A.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + This is equivalent to using ``orient_axis()``: + + >>> B.orient_axis(N, N.x, q1) + >>> B.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + **Note carefully that** ``N.dcm(B)`` **(the transpose) would be passed + into** ``orient_explicit()`` **for** ``A.dcm(N)`` **to match** + ``B.dcm(N)``: + + >>> A.orient_explicit(N, N.dcm(B)) + >>> A.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + """ + _check_frame(parent) + # amounts must be a Matrix type object + # (e.g. sympy.matrices.dense.MutableDenseMatrix). + if not isinstance(dcm, MatrixBase): + raise TypeError("Amounts must be a SymPy Matrix type object.") + + self.orient_dcm(parent, dcm.T) + + def orient_dcm(self, parent, dcm): + """Sets the orientation of this reference frame relative to another (parent) reference frame + using a direction cosine matrix that describes the rotation from the child to the parent. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + dcm : Matrix, shape(3, 3) + Direction cosine matrix that specifies the relative rotation + between the two reference frames. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols, Matrix, sin, cos + >>> from sympy.physics.vector import ReferenceFrame + >>> q1 = symbols('q1') + >>> A = ReferenceFrame('A') + >>> B = ReferenceFrame('B') + >>> N = ReferenceFrame('N') + + A simple rotation of ``A`` relative to ``N`` about ``N.x`` is defined + by the following direction cosine matrix: + + >>> dcm = Matrix([[1, 0, 0], + ... [0, cos(q1), sin(q1)], + ... [0, -sin(q1), cos(q1)]]) + >>> A.orient_dcm(N, dcm) + >>> A.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + This is equivalent to using ``orient_axis()``: + + >>> B.orient_axis(N, N.x, q1) + >>> B.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + """ + + _check_frame(parent) + # amounts must be a Matrix type object + # (e.g. sympy.matrices.dense.MutableDenseMatrix). + if not isinstance(dcm, MatrixBase): + raise TypeError("Amounts must be a SymPy Matrix type object.") + + self._dcm(parent, dcm.T) + + wvec = self._w_diff_dcm(parent) + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def _rot(self, axis, angle): + """DCM for simple axis 1,2,or 3 rotations.""" + if axis == 1: + return Matrix([[1, 0, 0], + [0, cos(angle), -sin(angle)], + [0, sin(angle), cos(angle)]]) + elif axis == 2: + return Matrix([[cos(angle), 0, sin(angle)], + [0, 1, 0], + [-sin(angle), 0, cos(angle)]]) + elif axis == 3: + return Matrix([[cos(angle), -sin(angle), 0], + [sin(angle), cos(angle), 0], + [0, 0, 1]]) + + def _parse_consecutive_rotations(self, angles, rotation_order): + """Helper for orient_body_fixed and orient_space_fixed. + + Parameters + ========== + angles : 3-tuple of sympifiable + Three angles in radians used for the successive rotations. + rotation_order : 3 character string or 3 digit integer + Order of the rotations. The order can be specified by the strings + ``'XZX'``, ``'131'``, or the integer ``131``. There are 12 unique + valid rotation orders. + + Returns + ======= + + amounts : list + List of sympifiables corresponding to the rotation angles. + rot_order : list + List of integers corresponding to the axis of rotation. + rot_matrices : list + List of DCM around the given axis with corresponding magnitude. + + """ + amounts = list(angles) + for i, v in enumerate(amounts): + if not isinstance(v, Vector): + amounts[i] = sympify(v) + + approved_orders = ('123', '231', '312', '132', '213', '321', '121', + '131', '212', '232', '313', '323', '') + # make sure XYZ => 123 + rot_order = translate(str(rotation_order), 'XYZxyz', '123123') + if rot_order not in approved_orders: + raise TypeError('The rotation order is not a valid order.') + + rot_order = [int(r) for r in rot_order] + if not (len(amounts) == 3 & len(rot_order) == 3): + raise TypeError('Body orientation takes 3 values & 3 orders') + rot_matrices = [self._rot(order, amount) + for (order, amount) in zip(rot_order, amounts)] + return amounts, rot_order, rot_matrices + + def orient_body_fixed(self, parent, angles, rotation_order): + """Rotates this reference frame relative to the parent reference frame + by right hand rotating through three successive body fixed simple axis + rotations. Each subsequent axis of rotation is about the "body fixed" + unit vectors of a new intermediate reference frame. This type of + rotation is also referred to rotating through the `Euler and Tait-Bryan + Angles`_. + + .. _Euler and Tait-Bryan Angles: https://en.wikipedia.org/wiki/Euler_angles + + The computed angular velocity in this method is by default expressed in + the child's frame, so it is most preferable to use ``u1 * child.x + u2 * + child.y + u3 * child.z`` as generalized speeds. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + angles : 3-tuple of sympifiable + Three angles in radians used for the successive rotations. + rotation_order : 3 character string or 3 digit integer + Order of the rotations about each intermediate reference frames' + unit vectors. The Euler rotation about the X, Z', X'' axes can be + specified by the strings ``'XZX'``, ``'131'``, or the integer + ``131``. There are 12 unique valid rotation orders (6 Euler and 6 + Tait-Bryan): zxz, xyx, yzy, zyz, xzx, yxy, xyz, yzx, zxy, xzy, zyx, + and yxz. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> q1, q2, q3 = symbols('q1, q2, q3') + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B1 = ReferenceFrame('B1') + >>> B2 = ReferenceFrame('B2') + >>> B3 = ReferenceFrame('B3') + + For example, a classic Euler Angle rotation can be done by: + + >>> B.orient_body_fixed(N, (q1, q2, q3), 'XYX') + >>> B.dcm(N) + Matrix([ + [ cos(q2), sin(q1)*sin(q2), -sin(q2)*cos(q1)], + [sin(q2)*sin(q3), -sin(q1)*sin(q3)*cos(q2) + cos(q1)*cos(q3), sin(q1)*cos(q3) + sin(q3)*cos(q1)*cos(q2)], + [sin(q2)*cos(q3), -sin(q1)*cos(q2)*cos(q3) - sin(q3)*cos(q1), -sin(q1)*sin(q3) + cos(q1)*cos(q2)*cos(q3)]]) + + This rotates reference frame B relative to reference frame N through + ``q1`` about ``N.x``, then rotates B again through ``q2`` about + ``B.y``, and finally through ``q3`` about ``B.x``. It is equivalent to + three successive ``orient_axis()`` calls: + + >>> B1.orient_axis(N, N.x, q1) + >>> B2.orient_axis(B1, B1.y, q2) + >>> B3.orient_axis(B2, B2.x, q3) + >>> B3.dcm(N) + Matrix([ + [ cos(q2), sin(q1)*sin(q2), -sin(q2)*cos(q1)], + [sin(q2)*sin(q3), -sin(q1)*sin(q3)*cos(q2) + cos(q1)*cos(q3), sin(q1)*cos(q3) + sin(q3)*cos(q1)*cos(q2)], + [sin(q2)*cos(q3), -sin(q1)*cos(q2)*cos(q3) - sin(q3)*cos(q1), -sin(q1)*sin(q3) + cos(q1)*cos(q2)*cos(q3)]]) + + Acceptable rotation orders are of length 3, expressed in as a string + ``'XYZ'`` or ``'123'`` or integer ``123``. Rotations about an axis + twice in a row are prohibited. + + >>> B.orient_body_fixed(N, (q1, q2, 0), 'ZXZ') + >>> B.orient_body_fixed(N, (q1, q2, 0), '121') + >>> B.orient_body_fixed(N, (q1, q2, q3), 123) + + """ + from sympy.physics.vector.functions import dynamicsymbols + + _check_frame(parent) + + amounts, rot_order, rot_matrices = self._parse_consecutive_rotations( + angles, rotation_order) + self._dcm(parent, rot_matrices[0] * rot_matrices[1] * rot_matrices[2]) + + rot_vecs = [zeros(3, 1) for _ in range(3)] + for i, order in enumerate(rot_order): + rot_vecs[i][order - 1] = amounts[i].diff(dynamicsymbols._t) + u1, u2, u3 = rot_vecs[2] + rot_matrices[2].T * ( + rot_vecs[1] + rot_matrices[1].T * rot_vecs[0]) + wvec = u1 * self.x + u2 * self.y + u3 * self.z # There is a double - + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def orient_space_fixed(self, parent, angles, rotation_order): + """Rotates this reference frame relative to the parent reference frame + by right hand rotating through three successive space fixed simple axis + rotations. Each subsequent axis of rotation is about the "space fixed" + unit vectors of the parent reference frame. + + The computed angular velocity in this method is by default expressed in + the child's frame, so it is most preferable to use ``u1 * child.x + u2 * + child.y + u3 * child.z`` as generalized speeds. + + Parameters + ========== + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + angles : 3-tuple of sympifiable + Three angles in radians used for the successive rotations. + rotation_order : 3 character string or 3 digit integer + Order of the rotations about the parent reference frame's unit + vectors. The order can be specified by the strings ``'XZX'``, + ``'131'``, or the integer ``131``. There are 12 unique valid + rotation orders. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> q1, q2, q3 = symbols('q1, q2, q3') + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B1 = ReferenceFrame('B1') + >>> B2 = ReferenceFrame('B2') + >>> B3 = ReferenceFrame('B3') + + >>> B.orient_space_fixed(N, (q1, q2, q3), '312') + >>> B.dcm(N) + Matrix([ + [ sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3), sin(q1)*cos(q2), sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1)], + [-sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1), cos(q1)*cos(q2), sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3)], + [ sin(q3)*cos(q2), -sin(q2), cos(q2)*cos(q3)]]) + + is equivalent to: + + >>> B1.orient_axis(N, N.z, q1) + >>> B2.orient_axis(B1, N.x, q2) + >>> B3.orient_axis(B2, N.y, q3) + >>> B3.dcm(N).simplify() + Matrix([ + [ sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3), sin(q1)*cos(q2), sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1)], + [-sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1), cos(q1)*cos(q2), sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3)], + [ sin(q3)*cos(q2), -sin(q2), cos(q2)*cos(q3)]]) + + It is worth noting that space-fixed and body-fixed rotations are + related by the order of the rotations, i.e. the reverse order of body + fixed will give space fixed and vice versa. + + >>> B.orient_space_fixed(N, (q1, q2, q3), '231') + >>> B.dcm(N) + Matrix([ + [cos(q1)*cos(q2), sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3), -sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1)], + [ -sin(q2), cos(q2)*cos(q3), sin(q3)*cos(q2)], + [sin(q1)*cos(q2), sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1), sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3)]]) + + >>> B.orient_body_fixed(N, (q3, q2, q1), '132') + >>> B.dcm(N) + Matrix([ + [cos(q1)*cos(q2), sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3), -sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1)], + [ -sin(q2), cos(q2)*cos(q3), sin(q3)*cos(q2)], + [sin(q1)*cos(q2), sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1), sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3)]]) + + """ + from sympy.physics.vector.functions import dynamicsymbols + + _check_frame(parent) + + amounts, rot_order, rot_matrices = self._parse_consecutive_rotations( + angles, rotation_order) + self._dcm(parent, rot_matrices[2] * rot_matrices[1] * rot_matrices[0]) + + rot_vecs = [zeros(3, 1) for _ in range(3)] + for i, order in enumerate(rot_order): + rot_vecs[i][order - 1] = amounts[i].diff(dynamicsymbols._t) + u1, u2, u3 = rot_vecs[0] + rot_matrices[0].T * ( + rot_vecs[1] + rot_matrices[1].T * rot_vecs[2]) + wvec = u1 * self.x + u2 * self.y + u3 * self.z # There is a double - + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def orient_quaternion(self, parent, numbers): + """Sets the orientation of this reference frame relative to a parent + reference frame via an orientation quaternion. An orientation + quaternion is defined as a finite rotation a unit vector, ``(lambda_x, + lambda_y, lambda_z)``, by an angle ``theta``. The orientation + quaternion is described by four parameters: + + - ``q0 = cos(theta/2)`` + - ``q1 = lambda_x*sin(theta/2)`` + - ``q2 = lambda_y*sin(theta/2)`` + - ``q3 = lambda_z*sin(theta/2)`` + + See `Quaternions and Spatial Rotation + `_ on + Wikipedia for more information. + + Parameters + ========== + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + numbers : 4-tuple of sympifiable + The four quaternion scalar numbers as defined above: ``q0``, + ``q1``, ``q2``, ``q3``. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + + Set the orientation: + + >>> B.orient_quaternion(N, (q0, q1, q2, q3)) + >>> B.dcm(N) + Matrix([ + [q0**2 + q1**2 - q2**2 - q3**2, 2*q0*q3 + 2*q1*q2, -2*q0*q2 + 2*q1*q3], + [ -2*q0*q3 + 2*q1*q2, q0**2 - q1**2 + q2**2 - q3**2, 2*q0*q1 + 2*q2*q3], + [ 2*q0*q2 + 2*q1*q3, -2*q0*q1 + 2*q2*q3, q0**2 - q1**2 - q2**2 + q3**2]]) + + """ + + from sympy.physics.vector.functions import dynamicsymbols + _check_frame(parent) + + numbers = list(numbers) + for i, v in enumerate(numbers): + if not isinstance(v, Vector): + numbers[i] = sympify(v) + + if not (isinstance(numbers, (list, tuple)) & (len(numbers) == 4)): + raise TypeError('Amounts are a list or tuple of length 4') + q0, q1, q2, q3 = numbers + parent_orient_quaternion = ( + Matrix([[q0**2 + q1**2 - q2**2 - q3**2, + 2 * (q1 * q2 - q0 * q3), + 2 * (q0 * q2 + q1 * q3)], + [2 * (q1 * q2 + q0 * q3), + q0**2 - q1**2 + q2**2 - q3**2, + 2 * (q2 * q3 - q0 * q1)], + [2 * (q1 * q3 - q0 * q2), + 2 * (q0 * q1 + q2 * q3), + q0**2 - q1**2 - q2**2 + q3**2]])) + + self._dcm(parent, parent_orient_quaternion) + + t = dynamicsymbols._t + q0, q1, q2, q3 = numbers + q0d = diff(q0, t) + q1d = diff(q1, t) + q2d = diff(q2, t) + q3d = diff(q3, t) + w1 = 2 * (q1d * q0 + q2d * q3 - q3d * q2 - q0d * q1) + w2 = 2 * (q2d * q0 + q3d * q1 - q1d * q3 - q0d * q2) + w3 = 2 * (q3d * q0 + q1d * q2 - q2d * q1 - q0d * q3) + wvec = Vector([(Matrix([w1, w2, w3]), self)]) + + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def orient(self, parent, rot_type, amounts, rot_order=''): + """Sets the orientation of this reference frame relative to another + (parent) reference frame. + + .. note:: It is now recommended to use the ``.orient_axis, + .orient_body_fixed, .orient_space_fixed, .orient_quaternion`` + methods for the different rotation types. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + rot_type : str + The method used to generate the direction cosine matrix. Supported + methods are: + + - ``'Axis'``: simple rotations about a single common axis + - ``'DCM'``: for setting the direction cosine matrix directly + - ``'Body'``: three successive rotations about new intermediate + axes, also called "Euler and Tait-Bryan angles" + - ``'Space'``: three successive rotations about the parent + frames' unit vectors + - ``'Quaternion'``: rotations defined by four parameters which + result in a singularity free direction cosine matrix + + amounts : + Expressions defining the rotation angles or direction cosine + matrix. These must match the ``rot_type``. See examples below for + details. The input types are: + + - ``'Axis'``: 2-tuple (expr/sym/func, Vector) + - ``'DCM'``: Matrix, shape(3,3) + - ``'Body'``: 3-tuple of expressions, symbols, or functions + - ``'Space'``: 3-tuple of expressions, symbols, or functions + - ``'Quaternion'``: 4-tuple of expressions, symbols, or + functions + + rot_order : str or int, optional + If applicable, the order of the successive of rotations. The string + ``'123'`` and integer ``123`` are equivalent, for example. Required + for ``'Body'`` and ``'Space'``. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + """ + + _check_frame(parent) + + approved_orders = ('123', '231', '312', '132', '213', '321', '121', + '131', '212', '232', '313', '323', '') + rot_order = translate(str(rot_order), 'XYZxyz', '123123') + rot_type = rot_type.upper() + + if rot_order not in approved_orders: + raise TypeError('The supplied order is not an approved type') + + if rot_type == 'AXIS': + self.orient_axis(parent, amounts[1], amounts[0]) + + elif rot_type == 'DCM': + self.orient_explicit(parent, amounts) + + elif rot_type == 'BODY': + self.orient_body_fixed(parent, amounts, rot_order) + + elif rot_type == 'SPACE': + self.orient_space_fixed(parent, amounts, rot_order) + + elif rot_type == 'QUATERNION': + self.orient_quaternion(parent, amounts) + + else: + raise NotImplementedError('That is not an implemented rotation') + + def orientnew(self, newname, rot_type, amounts, rot_order='', + variables=None, indices=None, latexs=None): + r"""Returns a new reference frame oriented with respect to this + reference frame. + + See ``ReferenceFrame.orient()`` for detailed examples of how to orient + reference frames. + + Parameters + ========== + + newname : str + Name for the new reference frame. + rot_type : str + The method used to generate the direction cosine matrix. Supported + methods are: + + - ``'Axis'``: simple rotations about a single common axis + - ``'DCM'``: for setting the direction cosine matrix directly + - ``'Body'``: three successive rotations about new intermediate + axes, also called "Euler and Tait-Bryan angles" + - ``'Space'``: three successive rotations about the parent + frames' unit vectors + - ``'Quaternion'``: rotations defined by four parameters which + result in a singularity free direction cosine matrix + + amounts : + Expressions defining the rotation angles or direction cosine + matrix. These must match the ``rot_type``. See examples below for + details. The input types are: + + - ``'Axis'``: 2-tuple (expr/sym/func, Vector) + - ``'DCM'``: Matrix, shape(3,3) + - ``'Body'``: 3-tuple of expressions, symbols, or functions + - ``'Space'``: 3-tuple of expressions, symbols, or functions + - ``'Quaternion'``: 4-tuple of expressions, symbols, or + functions + + rot_order : str or int, optional + If applicable, the order of the successive of rotations. The string + ``'123'`` and integer ``123`` are equivalent, for example. Required + for ``'Body'`` and ``'Space'``. + indices : tuple of str + Enables the reference frame's basis unit vectors to be accessed by + Python's square bracket indexing notation using the provided three + indice strings and alters the printing of the unit vectors to + reflect this choice. + latexs : tuple of str + Alters the LaTeX printing of the reference frame's basis unit + vectors to the provided three valid LaTeX strings. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame, vlatex + >>> q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + >>> N = ReferenceFrame('N') + + Create a new reference frame A rotated relative to N through a simple + rotation. + + >>> A = N.orientnew('A', 'Axis', (q0, N.x)) + + Create a new reference frame B rotated relative to N through body-fixed + rotations. + + >>> B = N.orientnew('B', 'Body', (q1, q2, q3), '123') + + Create a new reference frame C rotated relative to N through a simple + rotation with unique indices and LaTeX printing. + + >>> C = N.orientnew('C', 'Axis', (q0, N.x), indices=('1', '2', '3'), + ... latexs=(r'\hat{\mathbf{c}}_1',r'\hat{\mathbf{c}}_2', + ... r'\hat{\mathbf{c}}_3')) + >>> C['1'] + C['1'] + >>> print(vlatex(C['1'])) + \hat{\mathbf{c}}_1 + + """ + + newframe = self.__class__(newname, variables=variables, + indices=indices, latexs=latexs) + + approved_orders = ('123', '231', '312', '132', '213', '321', '121', + '131', '212', '232', '313', '323', '') + rot_order = translate(str(rot_order), 'XYZxyz', '123123') + rot_type = rot_type.upper() + + if rot_order not in approved_orders: + raise TypeError('The supplied order is not an approved type') + + if rot_type == 'AXIS': + newframe.orient_axis(self, amounts[1], amounts[0]) + + elif rot_type == 'DCM': + newframe.orient_explicit(self, amounts) + + elif rot_type == 'BODY': + newframe.orient_body_fixed(self, amounts, rot_order) + + elif rot_type == 'SPACE': + newframe.orient_space_fixed(self, amounts, rot_order) + + elif rot_type == 'QUATERNION': + newframe.orient_quaternion(self, amounts) + + else: + raise NotImplementedError('That is not an implemented rotation') + return newframe + + def set_ang_acc(self, otherframe, value): + """Define the angular acceleration Vector in a ReferenceFrame. + + Defines the angular acceleration of this ReferenceFrame, in another. + Angular acceleration can be defined with respect to multiple different + ReferenceFrames. Care must be taken to not create loops which are + inconsistent. + + Parameters + ========== + + otherframe : ReferenceFrame + A ReferenceFrame to define the angular acceleration in + value : Vector + The Vector representing angular acceleration + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> V = 10 * N.x + >>> A.set_ang_acc(N, V) + >>> A.ang_acc_in(N) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + _check_frame(otherframe) + self._ang_acc_dict.update({otherframe: value}) + otherframe._ang_acc_dict.update({self: -value}) + + def set_ang_vel(self, otherframe, value): + """Define the angular velocity vector in a ReferenceFrame. + + Defines the angular velocity of this ReferenceFrame, in another. + Angular velocity can be defined with respect to multiple different + ReferenceFrames. Care must be taken to not create loops which are + inconsistent. + + Parameters + ========== + + otherframe : ReferenceFrame + A ReferenceFrame to define the angular velocity in + value : Vector + The Vector representing angular velocity + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> V = 10 * N.x + >>> A.set_ang_vel(N, V) + >>> A.ang_vel_in(N) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + _check_frame(otherframe) + self._ang_vel_dict.update({otherframe: value}) + otherframe._ang_vel_dict.update({self: -value}) + + @property + def x(self): + """The basis Vector for the ReferenceFrame, in the x direction. """ + return self._x + + @property + def y(self): + """The basis Vector for the ReferenceFrame, in the y direction. """ + return self._y + + @property + def z(self): + """The basis Vector for the ReferenceFrame, in the z direction. """ + return self._z + + @property + def xx(self): + """Unit dyad of basis Vectors x and x for the ReferenceFrame.""" + return Vector.outer(self.x, self.x) + + @property + def xy(self): + """Unit dyad of basis Vectors x and y for the ReferenceFrame.""" + return Vector.outer(self.x, self.y) + + @property + def xz(self): + """Unit dyad of basis Vectors x and z for the ReferenceFrame.""" + return Vector.outer(self.x, self.z) + + @property + def yx(self): + """Unit dyad of basis Vectors y and x for the ReferenceFrame.""" + return Vector.outer(self.y, self.x) + + @property + def yy(self): + """Unit dyad of basis Vectors y and y for the ReferenceFrame.""" + return Vector.outer(self.y, self.y) + + @property + def yz(self): + """Unit dyad of basis Vectors y and z for the ReferenceFrame.""" + return Vector.outer(self.y, self.z) + + @property + def zx(self): + """Unit dyad of basis Vectors z and x for the ReferenceFrame.""" + return Vector.outer(self.z, self.x) + + @property + def zy(self): + """Unit dyad of basis Vectors z and y for the ReferenceFrame.""" + return Vector.outer(self.z, self.y) + + @property + def zz(self): + """Unit dyad of basis Vectors z and z for the ReferenceFrame.""" + return Vector.outer(self.z, self.z) + + @property + def u(self): + """Unit dyadic for the ReferenceFrame.""" + return self.xx + self.yy + self.zz + + def partial_velocity(self, frame, *gen_speeds): + """Returns the partial angular velocities of this frame in the given + frame with respect to one or more provided generalized speeds. + + Parameters + ========== + frame : ReferenceFrame + The frame with which the angular velocity is defined in. + gen_speeds : functions of time + The generalized speeds. + + Returns + ======= + partial_velocities : tuple of Vector + The partial angular velocity vectors corresponding to the provided + generalized speeds. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dynamicsymbols + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> u1, u2 = dynamicsymbols('u1, u2') + >>> A.set_ang_vel(N, u1 * A.x + u2 * N.y) + >>> A.partial_velocity(N, u1) + A.x + >>> A.partial_velocity(N, u1, u2) + (A.x, N.y) + + """ + + from sympy.physics.vector.functions import partial_velocity + + vel = self.ang_vel_in(frame) + partials = partial_velocity([vel], gen_speeds, frame)[0] + + if len(partials) == 1: + return partials[0] + else: + return tuple(partials) + + +def _check_frame(other): + from .vector import VectorTypeError + if not isinstance(other, ReferenceFrame): + raise VectorTypeError(other, ReferenceFrame('A')) diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/functions.py b/phivenv/Lib/site-packages/sympy/physics/vector/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..6775b4b23bb376992d6a9e7651ba73a951c84287 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/functions.py @@ -0,0 +1,650 @@ +from functools import reduce + +from sympy import (sympify, diff, sin, cos, Matrix, symbols, + Function, S, Symbol, linear_eq_to_matrix) +from sympy.integrals.integrals import integrate +from sympy.simplify.trigsimp import trigsimp +from .vector import Vector, _check_vector +from .frame import CoordinateSym, _check_frame +from .dyadic import Dyadic +from .printing import vprint, vsprint, vpprint, vlatex, init_vprinting +from sympy.utilities.iterables import iterable +from sympy.utilities.misc import translate + +__all__ = ['cross', 'dot', 'express', 'time_derivative', 'outer', + 'kinematic_equations', 'get_motion_params', 'partial_velocity', + 'dynamicsymbols', 'vprint', 'vsprint', 'vpprint', 'vlatex', + 'init_vprinting'] + + +def cross(vec1, vec2): + """Cross product convenience wrapper for Vector.cross(): \n""" + if not isinstance(vec1, (Vector, Dyadic)): + raise TypeError('Cross product is between two vectors') + return vec1 ^ vec2 + + +cross.__doc__ += Vector.cross.__doc__ # type: ignore + + +def dot(vec1, vec2): + """Dot product convenience wrapper for Vector.dot(): \n""" + if not isinstance(vec1, (Vector, Dyadic)): + raise TypeError('Dot product is between two vectors') + return vec1 & vec2 + + +dot.__doc__ += Vector.dot.__doc__ # type: ignore + + +def express(expr, frame, frame2=None, variables=False): + """ + Global function for 'express' functionality. + + Re-expresses a Vector, scalar(sympyfiable) or Dyadic in given frame. + + Refer to the local methods of Vector and Dyadic for details. + If 'variables' is True, then the coordinate variables (CoordinateSym + instances) of other frames present in the vector/scalar field or + dyadic expression are also substituted in terms of the base scalars of + this frame. + + Parameters + ========== + + expr : Vector/Dyadic/scalar(sympyfiable) + The expression to re-express in ReferenceFrame 'frame' + + frame: ReferenceFrame + The reference frame to express expr in + + frame2 : ReferenceFrame + The other frame required for re-expression(only for Dyadic expr) + + variables : boolean + Specifies whether to substitute the coordinate variables present + in expr, in terms of those of frame + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> d = outer(N.x, N.x) + >>> from sympy.physics.vector import express + >>> express(d, B, N) + cos(q)*(B.x|N.x) - sin(q)*(B.y|N.x) + >>> express(B.x, N) + cos(q)*N.x + sin(q)*N.y + >>> express(N[0], B, variables=True) + B_x*cos(q) - B_y*sin(q) + + """ + + _check_frame(frame) + + if expr == 0: + return expr + + if isinstance(expr, Vector): + # Given expr is a Vector + if variables: + # If variables attribute is True, substitute the coordinate + # variables in the Vector + frame_list = [x[-1] for x in expr.args] + subs_dict = {} + for f in frame_list: + subs_dict.update(f.variable_map(frame)) + expr = expr.subs(subs_dict) + # Re-express in this frame + outvec = Vector([]) + for v in expr.args: + if v[1] != frame: + temp = frame.dcm(v[1]) * v[0] + if Vector.simp: + temp = temp.applyfunc(lambda x: + trigsimp(x, method='fu')) + outvec += Vector([(temp, frame)]) + else: + outvec += Vector([v]) + return outvec + + if isinstance(expr, Dyadic): + if frame2 is None: + frame2 = frame + _check_frame(frame2) + ol = Dyadic(0) + for v in expr.args: + ol += express(v[0], frame, variables=variables) * \ + (express(v[1], frame, variables=variables) | + express(v[2], frame2, variables=variables)) + return ol + + else: + if variables: + # Given expr is a scalar field + frame_set = set() + expr = sympify(expr) + # Substitute all the coordinate variables + for x in expr.free_symbols: + if isinstance(x, CoordinateSym) and x.frame != frame: + frame_set.add(x.frame) + subs_dict = {} + for f in frame_set: + subs_dict.update(f.variable_map(frame)) + return expr.subs(subs_dict) + return expr + + +def time_derivative(expr, frame, order=1): + """ + Calculate the time derivative of a vector/scalar field function + or dyadic expression in given frame. + + References + ========== + + https://en.wikipedia.org/wiki/Rotating_reference_frame#Time_derivatives_in_the_two_frames + + Parameters + ========== + + expr : Vector/Dyadic/sympifyable + The expression whose time derivative is to be calculated + + frame : ReferenceFrame + The reference frame to calculate the time derivative in + + order : integer + The order of the derivative to be calculated + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> from sympy import Symbol + >>> q1 = Symbol('q1') + >>> u1 = dynamicsymbols('u1') + >>> N = ReferenceFrame('N') + >>> A = N.orientnew('A', 'Axis', [q1, N.x]) + >>> v = u1 * N.x + >>> A.set_ang_vel(N, 10*A.x) + >>> from sympy.physics.vector import time_derivative + >>> time_derivative(v, N) + u1'*N.x + >>> time_derivative(u1*A[0], N) + N_x*u1' + >>> B = N.orientnew('B', 'Axis', [u1, N.z]) + >>> from sympy.physics.vector import outer + >>> d = outer(N.x, N.x) + >>> time_derivative(d, B) + - u1'*(N.y|N.x) - u1'*(N.x|N.y) + + """ + + t = dynamicsymbols._t + _check_frame(frame) + + if order == 0: + return expr + if order % 1 != 0 or order < 0: + raise ValueError("Unsupported value of order entered") + + if isinstance(expr, Vector): + outlist = [] + for v in expr.args: + if v[1] == frame: + outlist += [(express(v[0], frame, variables=True).diff(t), + frame)] + else: + outlist += (time_derivative(Vector([v]), v[1]) + + (v[1].ang_vel_in(frame) ^ Vector([v]))).args + outvec = Vector(outlist) + return time_derivative(outvec, frame, order - 1) + + if isinstance(expr, Dyadic): + ol = Dyadic(0) + for v in expr.args: + ol += (v[0].diff(t) * (v[1] | v[2])) + ol += (v[0] * (time_derivative(v[1], frame) | v[2])) + ol += (v[0] * (v[1] | time_derivative(v[2], frame))) + return time_derivative(ol, frame, order - 1) + + else: + return diff(express(expr, frame, variables=True), t, order) + + +def outer(vec1, vec2): + """Outer product convenience wrapper for Vector.outer():\n""" + if not isinstance(vec1, Vector): + raise TypeError('Outer product is between two Vectors') + return vec1.outer(vec2) + + +outer.__doc__ += Vector.outer.__doc__ # type: ignore + + +def kinematic_equations(speeds, coords, rot_type, rot_order=''): + """Gives equations relating the qdot's to u's for a rotation type. + + Supply rotation type and order as in orient. Speeds are assumed to be + body-fixed; if we are defining the orientation of B in A using by rot_type, + the angular velocity of B in A is assumed to be in the form: speed[0]*B.x + + speed[1]*B.y + speed[2]*B.z + + Parameters + ========== + + speeds : list of length 3 + The body fixed angular velocity measure numbers. + coords : list of length 3 or 4 + The coordinates used to define the orientation of the two frames. + rot_type : str + The type of rotation used to create the equations. Body, Space, or + Quaternion only + rot_order : str or int + If applicable, the order of a series of rotations. + + Examples + ======== + + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy.physics.vector import kinematic_equations, vprint + >>> u1, u2, u3 = dynamicsymbols('u1 u2 u3') + >>> q1, q2, q3 = dynamicsymbols('q1 q2 q3') + >>> vprint(kinematic_equations([u1,u2,u3], [q1,q2,q3], 'body', '313'), + ... order=None) + [-(u1*sin(q3) + u2*cos(q3))/sin(q2) + q1', -u1*cos(q3) + u2*sin(q3) + q2', (u1*sin(q3) + u2*cos(q3))*cos(q2)/sin(q2) - u3 + q3'] + + """ + + # Code below is checking and sanitizing input + approved_orders = ('123', '231', '312', '132', '213', '321', '121', '131', + '212', '232', '313', '323', '1', '2', '3', '') + # make sure XYZ => 123 and rot_type is in lower case + rot_order = translate(str(rot_order), 'XYZxyz', '123123') + rot_type = rot_type.lower() + + if not isinstance(speeds, (list, tuple)): + raise TypeError('Need to supply speeds in a list') + if len(speeds) != 3: + raise TypeError('Need to supply 3 body-fixed speeds') + if not isinstance(coords, (list, tuple)): + raise TypeError('Need to supply coordinates in a list') + if rot_type in ['body', 'space']: + if rot_order not in approved_orders: + raise ValueError('Not an acceptable rotation order') + if len(coords) != 3: + raise ValueError('Need 3 coordinates for body or space') + # Actual hard-coded kinematic differential equations + w1, w2, w3 = speeds + if w1 == w2 == w3 == 0: + return [S.Zero]*3 + q1, q2, q3 = coords + q1d, q2d, q3d = [diff(i, dynamicsymbols._t) for i in coords] + s1, s2, s3 = [sin(q1), sin(q2), sin(q3)] + c1, c2, c3 = [cos(q1), cos(q2), cos(q3)] + if rot_type == 'body': + if rot_order == '123': + return [q1d - (w1 * c3 - w2 * s3) / c2, q2d - w1 * s3 - w2 * + c3, q3d - (-w1 * c3 + w2 * s3) * s2 / c2 - w3] + if rot_order == '231': + return [q1d - (w2 * c3 - w3 * s3) / c2, q2d - w2 * s3 - w3 * + c3, q3d - w1 - (- w2 * c3 + w3 * s3) * s2 / c2] + if rot_order == '312': + return [q1d - (-w1 * s3 + w3 * c3) / c2, q2d - w1 * c3 - w3 * + s3, q3d - (w1 * s3 - w3 * c3) * s2 / c2 - w2] + if rot_order == '132': + return [q1d - (w1 * c3 + w3 * s3) / c2, q2d + w1 * s3 - w3 * + c3, q3d - (w1 * c3 + w3 * s3) * s2 / c2 - w2] + if rot_order == '213': + return [q1d - (w1 * s3 + w2 * c3) / c2, q2d - w1 * c3 + w2 * + s3, q3d - (w1 * s3 + w2 * c3) * s2 / c2 - w3] + if rot_order == '321': + return [q1d - (w2 * s3 + w3 * c3) / c2, q2d - w2 * c3 + w3 * + s3, q3d - w1 - (w2 * s3 + w3 * c3) * s2 / c2] + if rot_order == '121': + return [q1d - (w2 * s3 + w3 * c3) / s2, q2d - w2 * c3 + w3 * + s3, q3d - w1 + (w2 * s3 + w3 * c3) * c2 / s2] + if rot_order == '131': + return [q1d - (-w2 * c3 + w3 * s3) / s2, q2d - w2 * s3 - w3 * + c3, q3d - w1 - (w2 * c3 - w3 * s3) * c2 / s2] + if rot_order == '212': + return [q1d - (w1 * s3 - w3 * c3) / s2, q2d - w1 * c3 - w3 * + s3, q3d - (-w1 * s3 + w3 * c3) * c2 / s2 - w2] + if rot_order == '232': + return [q1d - (w1 * c3 + w3 * s3) / s2, q2d + w1 * s3 - w3 * + c3, q3d + (w1 * c3 + w3 * s3) * c2 / s2 - w2] + if rot_order == '313': + return [q1d - (w1 * s3 + w2 * c3) / s2, q2d - w1 * c3 + w2 * + s3, q3d + (w1 * s3 + w2 * c3) * c2 / s2 - w3] + if rot_order == '323': + return [q1d - (-w1 * c3 + w2 * s3) / s2, q2d - w1 * s3 - w2 * + c3, q3d - (w1 * c3 - w2 * s3) * c2 / s2 - w3] + if rot_type == 'space': + if rot_order == '123': + return [q1d - w1 - (w2 * s1 + w3 * c1) * s2 / c2, q2d - w2 * + c1 + w3 * s1, q3d - (w2 * s1 + w3 * c1) / c2] + if rot_order == '231': + return [q1d - (w1 * c1 + w3 * s1) * s2 / c2 - w2, q2d + w1 * + s1 - w3 * c1, q3d - (w1 * c1 + w3 * s1) / c2] + if rot_order == '312': + return [q1d - (w1 * s1 + w2 * c1) * s2 / c2 - w3, q2d - w1 * + c1 + w2 * s1, q3d - (w1 * s1 + w2 * c1) / c2] + if rot_order == '132': + return [q1d - w1 - (-w2 * c1 + w3 * s1) * s2 / c2, q2d - w2 * + s1 - w3 * c1, q3d - (w2 * c1 - w3 * s1) / c2] + if rot_order == '213': + return [q1d - (w1 * s1 - w3 * c1) * s2 / c2 - w2, q2d - w1 * + c1 - w3 * s1, q3d - (-w1 * s1 + w3 * c1) / c2] + if rot_order == '321': + return [q1d - (-w1 * c1 + w2 * s1) * s2 / c2 - w3, q2d - w1 * + s1 - w2 * c1, q3d - (w1 * c1 - w2 * s1) / c2] + if rot_order == '121': + return [q1d - w1 + (w2 * s1 + w3 * c1) * c2 / s2, q2d - w2 * + c1 + w3 * s1, q3d - (w2 * s1 + w3 * c1) / s2] + if rot_order == '131': + return [q1d - w1 - (w2 * c1 - w3 * s1) * c2 / s2, q2d - w2 * + s1 - w3 * c1, q3d - (-w2 * c1 + w3 * s1) / s2] + if rot_order == '212': + return [q1d - (-w1 * s1 + w3 * c1) * c2 / s2 - w2, q2d - w1 * + c1 - w3 * s1, q3d - (w1 * s1 - w3 * c1) / s2] + if rot_order == '232': + return [q1d + (w1 * c1 + w3 * s1) * c2 / s2 - w2, q2d + w1 * + s1 - w3 * c1, q3d - (w1 * c1 + w3 * s1) / s2] + if rot_order == '313': + return [q1d + (w1 * s1 + w2 * c1) * c2 / s2 - w3, q2d - w1 * + c1 + w2 * s1, q3d - (w1 * s1 + w2 * c1) / s2] + if rot_order == '323': + return [q1d - (w1 * c1 - w2 * s1) * c2 / s2 - w3, q2d - w1 * + s1 - w2 * c1, q3d - (-w1 * c1 + w2 * s1) / s2] + elif rot_type == 'quaternion': + if rot_order != '': + raise ValueError('Cannot have rotation order for quaternion') + if len(coords) != 4: + raise ValueError('Need 4 coordinates for quaternion') + # Actual hard-coded kinematic differential equations + e0, e1, e2, e3 = coords + w = Matrix(speeds + [0]) + E = Matrix([[e0, -e3, e2, e1], + [e3, e0, -e1, e2], + [-e2, e1, e0, e3], + [-e1, -e2, -e3, e0]]) + edots = Matrix([diff(i, dynamicsymbols._t) for i in [e1, e2, e3, e0]]) + return list(edots.T - 0.5 * w.T * E.T) + else: + raise ValueError('Not an approved rotation type for this function') + + +def get_motion_params(frame, **kwargs): + """ + Returns the three motion parameters - (acceleration, velocity, and + position) as vectorial functions of time in the given frame. + + If a higher order differential function is provided, the lower order + functions are used as boundary conditions. For example, given the + acceleration, the velocity and position parameters are taken as + boundary conditions. + + The values of time at which the boundary conditions are specified + are taken from timevalue1(for position boundary condition) and + timevalue2(for velocity boundary condition). + + If any of the boundary conditions are not provided, they are taken + to be zero by default (zero vectors, in case of vectorial inputs). If + the boundary conditions are also functions of time, they are converted + to constants by substituting the time values in the dynamicsymbols._t + time Symbol. + + This function can also be used for calculating rotational motion + parameters. Have a look at the Parameters and Examples for more clarity. + + Parameters + ========== + + frame : ReferenceFrame + The frame to express the motion parameters in + + acceleration : Vector + Acceleration of the object/frame as a function of time + + velocity : Vector + Velocity as function of time or as boundary condition + of velocity at time = timevalue1 + + position : Vector + Velocity as function of time or as boundary condition + of velocity at time = timevalue1 + + timevalue1 : sympyfiable + Value of time for position boundary condition + + timevalue2 : sympyfiable + Value of time for velocity boundary condition + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, get_motion_params, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> from sympy import symbols + >>> R = ReferenceFrame('R') + >>> v1, v2, v3 = dynamicsymbols('v1 v2 v3') + >>> v = v1*R.x + v2*R.y + v3*R.z + >>> get_motion_params(R, position = v) + (v1''*R.x + v2''*R.y + v3''*R.z, v1'*R.x + v2'*R.y + v3'*R.z, v1*R.x + v2*R.y + v3*R.z) + >>> a, b, c = symbols('a b c') + >>> v = a*R.x + b*R.y + c*R.z + >>> get_motion_params(R, velocity = v) + (0, a*R.x + b*R.y + c*R.z, a*t*R.x + b*t*R.y + c*t*R.z) + >>> parameters = get_motion_params(R, acceleration = v) + >>> parameters[1] + a*t*R.x + b*t*R.y + c*t*R.z + >>> parameters[2] + a*t**2/2*R.x + b*t**2/2*R.y + c*t**2/2*R.z + + """ + + def _process_vector_differential(vectdiff, condition, variable, ordinate, + frame): + """ + Helper function for get_motion methods. Finds derivative of vectdiff + wrt variable, and its integral using the specified boundary condition + at value of variable = ordinate. + Returns a tuple of - (derivative, function and integral) wrt vectdiff + + """ + + # Make sure boundary condition is independent of 'variable' + if condition != 0: + condition = express(condition, frame, variables=True) + # Special case of vectdiff == 0 + if vectdiff == Vector(0): + return (0, 0, condition) + # Express vectdiff completely in condition's frame to give vectdiff1 + vectdiff1 = express(vectdiff, frame) + # Find derivative of vectdiff + vectdiff2 = time_derivative(vectdiff, frame) + # Integrate and use boundary condition + vectdiff0 = Vector(0) + lims = (variable, ordinate, variable) + for dim in frame: + function1 = vectdiff1.dot(dim) + abscissa = dim.dot(condition).subs({variable: ordinate}) + # Indefinite integral of 'function1' wrt 'variable', using + # the given initial condition (ordinate, abscissa). + vectdiff0 += (integrate(function1, lims) + abscissa) * dim + # Return tuple + return (vectdiff2, vectdiff, vectdiff0) + + _check_frame(frame) + # Decide mode of operation based on user's input + if 'acceleration' in kwargs: + mode = 2 + elif 'velocity' in kwargs: + mode = 1 + else: + mode = 0 + # All the possible parameters in kwargs + # Not all are required for every case + # If not specified, set to default values(may or may not be used in + # calculations) + conditions = ['acceleration', 'velocity', 'position', + 'timevalue', 'timevalue1', 'timevalue2'] + for i, x in enumerate(conditions): + if x not in kwargs: + if i < 3: + kwargs[x] = Vector(0) + else: + kwargs[x] = S.Zero + elif i < 3: + _check_vector(kwargs[x]) + else: + kwargs[x] = sympify(kwargs[x]) + if mode == 2: + vel = _process_vector_differential(kwargs['acceleration'], + kwargs['velocity'], + dynamicsymbols._t, + kwargs['timevalue2'], frame)[2] + pos = _process_vector_differential(vel, kwargs['position'], + dynamicsymbols._t, + kwargs['timevalue1'], frame)[2] + return (kwargs['acceleration'], vel, pos) + elif mode == 1: + return _process_vector_differential(kwargs['velocity'], + kwargs['position'], + dynamicsymbols._t, + kwargs['timevalue1'], frame) + else: + vel = time_derivative(kwargs['position'], frame) + acc = time_derivative(vel, frame) + return (acc, vel, kwargs['position']) + + +def partial_velocity(vel_vecs, gen_speeds, frame): + """Returns a list of partial velocities with respect to the provided + generalized speeds in the given reference frame for each of the supplied + velocity vectors. + + The output is a list of lists. The outer list has a number of elements + equal to the number of supplied velocity vectors. The inner lists are, for + each velocity vector, the partial derivatives of that velocity vector with + respect to the generalized speeds supplied. + + Parameters + ========== + + vel_vecs : iterable + An iterable of velocity vectors (angular or linear). + gen_speeds : iterable + An iterable of generalized speeds. + frame : ReferenceFrame + The reference frame that the partial derivatives are going to be taken + in. + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy.physics.vector import partial_velocity + >>> u = dynamicsymbols('u') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, u * N.x) + >>> vel_vecs = [P.vel(N)] + >>> gen_speeds = [u] + >>> partial_velocity(vel_vecs, gen_speeds, N) + [[N.x]] + + """ + + if not iterable(vel_vecs): + raise TypeError('Velocity vectors must be contained in an iterable.') + + if not iterable(gen_speeds): + raise TypeError('Generalized speeds must be contained in an iterable') + + vec_partials = [] + gen_speeds = list(gen_speeds) + for vel in vel_vecs: + partials = [Vector(0) for _ in gen_speeds] + for components, ref in vel.args: + mat, _ = linear_eq_to_matrix(components, gen_speeds) + for i in range(len(gen_speeds)): + for dim, direction in enumerate(ref): + if mat[dim, i] != 0: + partials[i] += direction * mat[dim, i] + + vec_partials.append(partials) + + return vec_partials + + +def dynamicsymbols(names, level=0, **assumptions): + """Uses symbols and Function for functions of time. + + Creates a SymPy UndefinedFunction, which is then initialized as a function + of a variable, the default being Symbol('t'). + + Parameters + ========== + + names : str + Names of the dynamic symbols you want to create; works the same way as + inputs to symbols + level : int + Level of differentiation of the returned function; d/dt once of t, + twice of t, etc. + assumptions : + - real(bool) : This is used to set the dynamicsymbol as real, + by default is False. + - positive(bool) : This is used to set the dynamicsymbol as positive, + by default is False. + - commutative(bool) : This is used to set the commutative property of + a dynamicsymbol, by default is True. + - integer(bool) : This is used to set the dynamicsymbol as integer, + by default is False. + + Examples + ======== + + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy import diff, Symbol + >>> q1 = dynamicsymbols('q1') + >>> q1 + q1(t) + >>> q2 = dynamicsymbols('q2', real=True) + >>> q2.is_real + True + >>> q3 = dynamicsymbols('q3', positive=True) + >>> q3.is_positive + True + >>> q4, q5 = dynamicsymbols('q4,q5', commutative=False) + >>> bool(q4*q5 != q5*q4) + True + >>> q6 = dynamicsymbols('q6', integer=True) + >>> q6.is_integer + True + >>> diff(q1, Symbol('t')) + Derivative(q1(t), t) + + """ + esses = symbols(names, cls=Function, **assumptions) + t = dynamicsymbols._t + if iterable(esses): + esses = [reduce(diff, [t] * level, e(t)) for e in esses] + return esses + else: + return reduce(diff, [t] * level, esses(t)) + + +dynamicsymbols._t = Symbol('t') # type: ignore +dynamicsymbols._str = '\'' # type: ignore diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/point.py b/phivenv/Lib/site-packages/sympy/physics/vector/point.py new file mode 100644 index 0000000000000000000000000000000000000000..2841f9d465883b6fa6e1b5dc8bc0c107f18b65f7 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/point.py @@ -0,0 +1,635 @@ +from .vector import Vector, _check_vector +from .frame import _check_frame +from warnings import warn +from sympy.utilities.misc import filldedent + +__all__ = ['Point'] + + +class Point: + """This object represents a point in a dynamic system. + + It stores the: position, velocity, and acceleration of a point. + The position is a vector defined as the vector distance from a parent + point to this point. + + Parameters + ========== + + name : string + The display name of the Point + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> P = Point('P') + >>> u1, u2, u3 = dynamicsymbols('u1 u2 u3') + >>> O.set_vel(N, u1 * N.x + u2 * N.y + u3 * N.z) + >>> O.acc(N) + u1'*N.x + u2'*N.y + u3'*N.z + + ``symbols()`` can be used to create multiple Points in a single step, for + example: + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> from sympy import symbols + >>> N = ReferenceFrame('N') + >>> u1, u2 = dynamicsymbols('u1 u2') + >>> A, B = symbols('A B', cls=Point) + >>> type(A) + + >>> A.set_vel(N, u1 * N.x + u2 * N.y) + >>> B.set_vel(N, u2 * N.x + u1 * N.y) + >>> A.acc(N) - B.acc(N) + (u1' - u2')*N.x + (-u1' + u2')*N.y + + """ + + def __init__(self, name): + """Initialization of a Point object. """ + self.name = name + self._pos_dict = {} + self._vel_dict = {} + self._acc_dict = {} + self._pdlist = [self._pos_dict, self._vel_dict, self._acc_dict] + + def __str__(self): + return self.name + + __repr__ = __str__ + + def _check_point(self, other): + if not isinstance(other, Point): + raise TypeError('A Point must be supplied') + + def _pdict_list(self, other, num): + """Returns a list of points that gives the shortest path with respect + to position, velocity, or acceleration from this point to the provided + point. + + Parameters + ========== + other : Point + A point that may be related to this point by position, velocity, or + acceleration. + num : integer + 0 for searching the position tree, 1 for searching the velocity + tree, and 2 for searching the acceleration tree. + + Returns + ======= + list of Points + A sequence of points from self to other. + + Notes + ===== + + It is not clear if num = 1 or num = 2 actually works because the keys + to ``_vel_dict`` and ``_acc_dict`` are :class:`ReferenceFrame` objects + which do not have the ``_pdlist`` attribute. + + """ + outlist = [[self]] + oldlist = [[]] + while outlist != oldlist: + oldlist = outlist.copy() + for v in outlist: + templist = v[-1]._pdlist[num].keys() + for v2 in templist: + if not v.__contains__(v2): + littletemplist = v + [v2] + if not outlist.__contains__(littletemplist): + outlist.append(littletemplist) + for v in oldlist: + if v[-1] != other: + outlist.remove(v) + outlist.sort(key=len) + if len(outlist) != 0: + return outlist[0] + raise ValueError('No Connecting Path found between ' + other.name + + ' and ' + self.name) + + def a1pt_theory(self, otherpoint, outframe, interframe): + """Sets the acceleration of this point with the 1-point theory. + + The 1-point theory for point acceleration looks like this: + + ^N a^P = ^B a^P + ^N a^O + ^N alpha^B x r^OP + ^N omega^B x (^N omega^B + x r^OP) + 2 ^N omega^B x ^B v^P + + where O is a point fixed in B, P is a point moving in B, and B is + rotating in frame N. + + Parameters + ========== + + otherpoint : Point + The first point of the 1-point theory (O) + outframe : ReferenceFrame + The frame we want this point's acceleration defined in (N) + fixedframe : ReferenceFrame + The intermediate frame in this calculation (B) + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q = dynamicsymbols('q') + >>> q2 = dynamicsymbols('q2') + >>> qd = dynamicsymbols('q', 1) + >>> q2d = dynamicsymbols('q2', 1) + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B.set_ang_vel(N, 5 * B.y) + >>> O = Point('O') + >>> P = O.locatenew('P', q * B.x + q2 * B.y) + >>> P.set_vel(B, qd * B.x + q2d * B.y) + >>> O.set_vel(N, 0) + >>> P.a1pt_theory(O, N, B) + (-25*q + q'')*B.x + q2''*B.y - 10*q'*B.z + + """ + + _check_frame(outframe) + _check_frame(interframe) + self._check_point(otherpoint) + dist = self.pos_from(otherpoint) + v = self.vel(interframe) + a1 = otherpoint.acc(outframe) + a2 = self.acc(interframe) + omega = interframe.ang_vel_in(outframe) + alpha = interframe.ang_acc_in(outframe) + self.set_acc(outframe, a2 + 2 * (omega.cross(v)) + a1 + + (alpha.cross(dist)) + (omega.cross(omega.cross(dist)))) + return self.acc(outframe) + + def a2pt_theory(self, otherpoint, outframe, fixedframe): + """Sets the acceleration of this point with the 2-point theory. + + The 2-point theory for point acceleration looks like this: + + ^N a^P = ^N a^O + ^N alpha^B x r^OP + ^N omega^B x (^N omega^B x r^OP) + + where O and P are both points fixed in frame B, which is rotating in + frame N. + + Parameters + ========== + + otherpoint : Point + The first point of the 2-point theory (O) + outframe : ReferenceFrame + The frame we want this point's acceleration defined in (N) + fixedframe : ReferenceFrame + The frame in which both points are fixed (B) + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> N = ReferenceFrame('N') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> O = Point('O') + >>> P = O.locatenew('P', 10 * B.x) + >>> O.set_vel(N, 5 * N.x) + >>> P.a2pt_theory(O, N, B) + - 10*q'**2*B.x + 10*q''*B.y + + """ + + _check_frame(outframe) + _check_frame(fixedframe) + self._check_point(otherpoint) + dist = self.pos_from(otherpoint) + a = otherpoint.acc(outframe) + omega = fixedframe.ang_vel_in(outframe) + alpha = fixedframe.ang_acc_in(outframe) + self.set_acc(outframe, a + (alpha.cross(dist)) + + (omega.cross(omega.cross(dist)))) + return self.acc(outframe) + + def acc(self, frame): + """The acceleration Vector of this Point in a ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which the returned acceleration vector will be defined + in. + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p1.set_acc(N, 10 * N.x) + >>> p1.acc(N) + 10*N.x + + """ + + _check_frame(frame) + if not (frame in self._acc_dict): + if self.vel(frame) != 0: + return (self._vel_dict[frame]).dt(frame) + else: + return Vector(0) + return self._acc_dict[frame] + + def locatenew(self, name, value): + """Creates a new point with a position defined from this point. + + Parameters + ========== + + name : str + The name for the new point + value : Vector + The position of the new point relative to this point + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, Point + >>> N = ReferenceFrame('N') + >>> P1 = Point('P1') + >>> P2 = P1.locatenew('P2', 10 * N.x) + + """ + + if not isinstance(name, str): + raise TypeError('Must supply a valid name') + if value == 0: + value = Vector(0) + value = _check_vector(value) + p = Point(name) + p.set_pos(self, value) + self.set_pos(p, -value) + return p + + def pos_from(self, otherpoint): + """Returns a Vector distance between this Point and the other Point. + + Parameters + ========== + + otherpoint : Point + The otherpoint we are locating this one relative to + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p2 = Point('p2') + >>> p1.set_pos(p2, 10 * N.x) + >>> p1.pos_from(p2) + 10*N.x + + """ + + outvec = Vector(0) + plist = self._pdict_list(otherpoint, 0) + for i in range(len(plist) - 1): + outvec += plist[i]._pos_dict[plist[i + 1]] + return outvec + + def set_acc(self, frame, value): + """Used to set the acceleration of this Point in a ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which this point's acceleration is defined + value : Vector + The vector value of this point's acceleration in the frame + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p1.set_acc(N, 10 * N.x) + >>> p1.acc(N) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + _check_frame(frame) + self._acc_dict.update({frame: value}) + + def set_pos(self, otherpoint, value): + """Used to set the position of this point w.r.t. another point. + + Parameters + ========== + + otherpoint : Point + The other point which this point's location is defined relative to + value : Vector + The vector which defines the location of this point + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p2 = Point('p2') + >>> p1.set_pos(p2, 10 * N.x) + >>> p1.pos_from(p2) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + self._check_point(otherpoint) + self._pos_dict.update({otherpoint: value}) + otherpoint._pos_dict.update({self: -value}) + + def set_vel(self, frame, value): + """Sets the velocity Vector of this Point in a ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which this point's velocity is defined + value : Vector + The vector value of this point's velocity in the frame + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p1.set_vel(N, 10 * N.x) + >>> p1.vel(N) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + _check_frame(frame) + self._vel_dict.update({frame: value}) + + def v1pt_theory(self, otherpoint, outframe, interframe): + """Sets the velocity of this point with the 1-point theory. + + The 1-point theory for point velocity looks like this: + + ^N v^P = ^B v^P + ^N v^O + ^N omega^B x r^OP + + where O is a point fixed in B, P is a point moving in B, and B is + rotating in frame N. + + Parameters + ========== + + otherpoint : Point + The first point of the 1-point theory (O) + outframe : ReferenceFrame + The frame we want this point's velocity defined in (N) + interframe : ReferenceFrame + The intermediate frame in this calculation (B) + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q = dynamicsymbols('q') + >>> q2 = dynamicsymbols('q2') + >>> qd = dynamicsymbols('q', 1) + >>> q2d = dynamicsymbols('q2', 1) + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B.set_ang_vel(N, 5 * B.y) + >>> O = Point('O') + >>> P = O.locatenew('P', q * B.x + q2 * B.y) + >>> P.set_vel(B, qd * B.x + q2d * B.y) + >>> O.set_vel(N, 0) + >>> P.v1pt_theory(O, N, B) + q'*B.x + q2'*B.y - 5*q*B.z + + """ + + _check_frame(outframe) + _check_frame(interframe) + self._check_point(otherpoint) + dist = self.pos_from(otherpoint) + v1 = self.vel(interframe) + v2 = otherpoint.vel(outframe) + omega = interframe.ang_vel_in(outframe) + self.set_vel(outframe, v1 + v2 + (omega.cross(dist))) + return self.vel(outframe) + + def v2pt_theory(self, otherpoint, outframe, fixedframe): + """Sets the velocity of this point with the 2-point theory. + + The 2-point theory for point velocity looks like this: + + ^N v^P = ^N v^O + ^N omega^B x r^OP + + where O and P are both points fixed in frame B, which is rotating in + frame N. + + Parameters + ========== + + otherpoint : Point + The first point of the 2-point theory (O) + outframe : ReferenceFrame + The frame we want this point's velocity defined in (N) + fixedframe : ReferenceFrame + The frame in which both points are fixed (B) + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> N = ReferenceFrame('N') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> O = Point('O') + >>> P = O.locatenew('P', 10 * B.x) + >>> O.set_vel(N, 5 * N.x) + >>> P.v2pt_theory(O, N, B) + 5*N.x + 10*q'*B.y + + """ + + _check_frame(outframe) + _check_frame(fixedframe) + self._check_point(otherpoint) + dist = self.pos_from(otherpoint) + v = otherpoint.vel(outframe) + omega = fixedframe.ang_vel_in(outframe) + self.set_vel(outframe, v + (omega.cross(dist))) + return self.vel(outframe) + + def vel(self, frame): + """The velocity Vector of this Point in the ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which the returned velocity vector will be defined in + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p1.set_vel(N, 10 * N.x) + >>> p1.vel(N) + 10*N.x + + Velocities will be automatically calculated if possible, otherwise a + ``ValueError`` will be returned. If it is possible to calculate + multiple different velocities from the relative points, the points + defined most directly relative to this point will be used. In the case + of inconsistent relative positions of points, incorrect velocities may + be returned. It is up to the user to define prior relative positions + and velocities of points in a self-consistent way. + + >>> p = Point('p') + >>> q = dynamicsymbols('q') + >>> p.set_vel(N, 10 * N.x) + >>> p2 = Point('p2') + >>> p2.set_pos(p, q*N.x) + >>> p2.vel(N) + (Derivative(q(t), t) + 10)*N.x + + """ + + _check_frame(frame) + if not (frame in self._vel_dict): + valid_neighbor_found = False + is_cyclic = False + visited = [] + queue = [self] + candidate_neighbor = [] + while queue: # BFS to find nearest point + node = queue.pop(0) + if node not in visited: + visited.append(node) + for neighbor, neighbor_pos in node._pos_dict.items(): + if neighbor in visited: + continue + try: + # Checks if pos vector is valid + neighbor_pos.express(frame) + except ValueError: + continue + if neighbor in queue: + is_cyclic = True + try: + # Checks if point has its vel defined in req frame + neighbor_velocity = neighbor._vel_dict[frame] + except KeyError: + queue.append(neighbor) + continue + candidate_neighbor.append(neighbor) + if not valid_neighbor_found: + self.set_vel(frame, self.pos_from(neighbor).dt(frame) + neighbor_velocity) + valid_neighbor_found = True + if is_cyclic: + warn(filldedent(""" + Kinematic loops are defined among the positions of points. This + is likely not desired and may cause errors in your calculations. + """)) + if len(candidate_neighbor) > 1: + warn(filldedent(f""" + Velocity of {self.name} automatically calculated based on point + {candidate_neighbor[0].name} but it is also possible from + points(s): {str(candidate_neighbor[1:])}. Velocities from these + points are not necessarily the same. This may cause errors in + your calculations.""")) + if valid_neighbor_found: + return self._vel_dict[frame] + else: + raise ValueError(filldedent(f""" + Velocity of point {self.name} has not been defined in + ReferenceFrame {frame.name}.""")) + + return self._vel_dict[frame] + + def partial_velocity(self, frame, *gen_speeds): + """Returns the partial velocities of the linear velocity vector of this + point in the given frame with respect to one or more provided + generalized speeds. + + Parameters + ========== + frame : ReferenceFrame + The frame with which the velocity is defined in. + gen_speeds : functions of time + The generalized speeds. + + Returns + ======= + partial_velocities : tuple of Vector + The partial velocity vectors corresponding to the provided + generalized speeds. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, Point + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> p = Point('p') + >>> u1, u2 = dynamicsymbols('u1, u2') + >>> p.set_vel(N, u1 * N.x + u2 * A.y) + >>> p.partial_velocity(N, u1) + N.x + >>> p.partial_velocity(N, u1, u2) + (N.x, A.y) + + """ + + from sympy.physics.vector.functions import partial_velocity + + vel = self.vel(frame) + partials = partial_velocity([vel], gen_speeds, frame)[0] + + if len(partials) == 1: + return partials[0] + else: + return tuple(partials) diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/printing.py b/phivenv/Lib/site-packages/sympy/physics/vector/printing.py new file mode 100644 index 0000000000000000000000000000000000000000..2b589f673329e1e598b9b568fba6c07b8abe67bc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/printing.py @@ -0,0 +1,371 @@ +from sympy.core.function import Derivative +from sympy.core.function import UndefinedFunction, AppliedUndef +from sympy.core.symbol import Symbol +from sympy.interactive.printing import init_printing +from sympy.printing.latex import LatexPrinter +from sympy.printing.pretty.pretty import PrettyPrinter +from sympy.printing.pretty.pretty_symbology import center_accent +from sympy.printing.str import StrPrinter +from sympy.printing.precedence import PRECEDENCE + +__all__ = ['vprint', 'vsstrrepr', 'vsprint', 'vpprint', 'vlatex', + 'init_vprinting'] + + +class VectorStrPrinter(StrPrinter): + """String Printer for vector expressions. """ + + def _print_Derivative(self, e): + from sympy.physics.vector.functions import dynamicsymbols + t = dynamicsymbols._t + if (bool(sum(i == t for i in e.variables)) & + isinstance(type(e.args[0]), UndefinedFunction)): + ol = str(e.args[0].func) + for i, v in enumerate(e.variables): + ol += dynamicsymbols._str + return ol + else: + return StrPrinter().doprint(e) + + def _print_Function(self, e): + from sympy.physics.vector.functions import dynamicsymbols + t = dynamicsymbols._t + if isinstance(type(e), UndefinedFunction): + return StrPrinter().doprint(e).replace("(%s)" % t, '') + return e.func.__name__ + "(%s)" % self.stringify(e.args, ", ") + + +class VectorStrReprPrinter(VectorStrPrinter): + """String repr printer for vector expressions.""" + def _print_str(self, s): + return repr(s) + + +class VectorLatexPrinter(LatexPrinter): + """Latex Printer for vector expressions. """ + + def _print_Function(self, expr, exp=None): + from sympy.physics.vector.functions import dynamicsymbols + func = expr.func.__name__ + t = dynamicsymbols._t + + if (hasattr(self, '_print_' + func) and not + isinstance(type(expr), UndefinedFunction)): + return getattr(self, '_print_' + func)(expr, exp) + elif isinstance(type(expr), UndefinedFunction) and (expr.args == (t,)): + # treat this function like a symbol + expr = Symbol(func) + if exp is not None: + # copied from LatexPrinter._helper_print_standard_power, which + # we can't call because we only have exp as a string. + base = self.parenthesize(expr, PRECEDENCE['Pow']) + base = self.parenthesize_super(base) + return r"%s^{%s}" % (base, exp) + else: + return super()._print(expr) + else: + return super()._print_Function(expr, exp) + + def _print_Derivative(self, der_expr): + from sympy.physics.vector.functions import dynamicsymbols + # make sure it is in the right form + der_expr = der_expr.doit() + if not isinstance(der_expr, Derivative): + return r"\left(%s\right)" % self.doprint(der_expr) + + # check if expr is a dynamicsymbol + t = dynamicsymbols._t + expr = der_expr.expr + red = expr.atoms(AppliedUndef) + syms = der_expr.variables + test1 = not all(True for i in red if i.free_symbols == {t}) + test2 = not all(t == i for i in syms) + if test1 or test2: + return super()._print_Derivative(der_expr) + + # done checking + dots = len(syms) + base = self._print_Function(expr) + base_split = base.split('_', 1) + base = base_split[0] + if dots == 1: + base = r"\dot{%s}" % base + elif dots == 2: + base = r"\ddot{%s}" % base + elif dots == 3: + base = r"\dddot{%s}" % base + elif dots == 4: + base = r"\ddddot{%s}" % base + else: # Fallback to standard printing + return super()._print_Derivative(der_expr) + if len(base_split) != 1: + base += '_' + base_split[1] + return base + + +class VectorPrettyPrinter(PrettyPrinter): + """Pretty Printer for vectorialexpressions. """ + + def _print_Derivative(self, deriv): + from sympy.physics.vector.functions import dynamicsymbols + # XXX use U('PARTIAL DIFFERENTIAL') here ? + t = dynamicsymbols._t + dot_i = 0 + syms = list(reversed(deriv.variables)) + + while len(syms) > 0: + if syms[-1] == t: + syms.pop() + dot_i += 1 + else: + return super()._print_Derivative(deriv) + + if not (isinstance(type(deriv.expr), UndefinedFunction) and + (deriv.expr.args == (t,))): + return super()._print_Derivative(deriv) + else: + pform = self._print_Function(deriv.expr) + + # the following condition would happen with some sort of non-standard + # dynamic symbol I guess, so we'll just print the SymPy way + if len(pform.picture) > 1: + return super()._print_Derivative(deriv) + + # There are only special symbols up to fourth-order derivatives + if dot_i >= 5: + return super()._print_Derivative(deriv) + + # Deal with special symbols + dots = {0: "", + 1: "\N{COMBINING DOT ABOVE}", + 2: "\N{COMBINING DIAERESIS}", + 3: "\N{COMBINING THREE DOTS ABOVE}", + 4: "\N{COMBINING FOUR DOTS ABOVE}"} + + d = pform.__dict__ + # if unicode is false then calculate number of apostrophes needed and + # add to output + if not self._use_unicode: + apostrophes = "" + for i in range(0, dot_i): + apostrophes += "'" + d['picture'][0] += apostrophes + "(t)" + else: + d['picture'] = [center_accent(d['picture'][0], dots[dot_i])] + return pform + + def _print_Function(self, e): + from sympy.physics.vector.functions import dynamicsymbols + t = dynamicsymbols._t + # XXX works only for applied functions + func = e.func + args = e.args + func_name = func.__name__ + pform = self._print_Symbol(Symbol(func_name)) + # If this function is an Undefined function of t, it is probably a + # dynamic symbol, so we'll skip the (t). The rest of the code is + # identical to the normal PrettyPrinter code + if not (isinstance(func, UndefinedFunction) and (args == (t,))): + return super()._print_Function(e) + return pform + + +def vprint(expr, **settings): + r"""Function for printing of expressions generated in the + sympy.physics vector package. + + Extends SymPy's StrPrinter, takes the same setting accepted by SymPy's + :func:`~.sstr`, and is equivalent to ``print(sstr(foo))``. + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to print. + settings : args + Same as the settings accepted by SymPy's sstr(). + + Examples + ======== + + >>> from sympy.physics.vector import vprint, dynamicsymbols + >>> u1 = dynamicsymbols('u1') + >>> print(u1) + u1(t) + >>> vprint(u1) + u1 + + """ + + outstr = vsprint(expr, **settings) + + import builtins + if (outstr != 'None'): + builtins._ = outstr + print(outstr) + + +def vsstrrepr(expr, **settings): + """Function for displaying expression representation's with vector + printing enabled. + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to print. + settings : args + Same as the settings accepted by SymPy's sstrrepr(). + + """ + p = VectorStrReprPrinter(settings) + return p.doprint(expr) + + +def vsprint(expr, **settings): + r"""Function for displaying expressions generated in the + sympy.physics vector package. + + Returns the output of vprint() as a string. + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to print + settings : args + Same as the settings accepted by SymPy's sstr(). + + Examples + ======== + + >>> from sympy.physics.vector import vsprint, dynamicsymbols + >>> u1, u2 = dynamicsymbols('u1 u2') + >>> u2d = dynamicsymbols('u2', level=1) + >>> print("%s = %s" % (u1, u2 + u2d)) + u1(t) = u2(t) + Derivative(u2(t), t) + >>> print("%s = %s" % (vsprint(u1), vsprint(u2 + u2d))) + u1 = u2 + u2' + + """ + + string_printer = VectorStrPrinter(settings) + return string_printer.doprint(expr) + + +def vpprint(expr, **settings): + r"""Function for pretty printing of expressions generated in the + sympy.physics vector package. + + Mainly used for expressions not inside a vector; the output of running + scripts and generating equations of motion. Takes the same options as + SymPy's :func:`~.pretty_print`; see that function for more information. + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to pretty print + settings : args + Same as those accepted by SymPy's pretty_print. + + + """ + + pp = VectorPrettyPrinter(settings) + + # Note that this is copied from sympy.printing.pretty.pretty_print: + + # XXX: this is an ugly hack, but at least it works + use_unicode = pp._settings['use_unicode'] + from sympy.printing.pretty.pretty_symbology import pretty_use_unicode + uflag = pretty_use_unicode(use_unicode) + + try: + return pp.doprint(expr) + finally: + pretty_use_unicode(uflag) + + +def vlatex(expr, **settings): + r"""Function for printing latex representation of sympy.physics.vector + objects. + + For latex representation of Vectors, Dyadics, and dynamicsymbols. Takes the + same options as SymPy's :func:`~.latex`; see that function for more + information; + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to represent in LaTeX form + settings : args + Same as latex() + + Examples + ======== + + >>> from sympy.physics.vector import vlatex, ReferenceFrame, dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q1, q2 = dynamicsymbols('q1 q2') + >>> q1d, q2d = dynamicsymbols('q1 q2', 1) + >>> q1dd, q2dd = dynamicsymbols('q1 q2', 2) + >>> vlatex(N.x + N.y) + '\\mathbf{\\hat{n}_x} + \\mathbf{\\hat{n}_y}' + >>> vlatex(q1 + q2) + 'q_{1} + q_{2}' + >>> vlatex(q1d) + '\\dot{q}_{1}' + >>> vlatex(q1 * q2d) + 'q_{1} \\dot{q}_{2}' + >>> vlatex(q1dd * q1 / q1d) + '\\frac{q_{1} \\ddot{q}_{1}}{\\dot{q}_{1}}' + + """ + latex_printer = VectorLatexPrinter(settings) + + return latex_printer.doprint(expr) + + +def init_vprinting(**kwargs): + """Initializes time derivative printing for all SymPy objects, i.e. any + functions of time will be displayed in a more compact notation. The main + benefit of this is for printing of time derivatives; instead of + displaying as ``Derivative(f(t),t)``, it will display ``f'``. This is + only actually needed for when derivatives are present and are not in a + physics.vector.Vector or physics.vector.Dyadic object. This function is a + light wrapper to :func:`~.init_printing`. Any keyword + arguments for it are valid here. + + {0} + + Examples + ======== + + >>> from sympy import Function, symbols + >>> t, x = symbols('t, x') + >>> omega = Function('omega') + >>> omega(x).diff() + Derivative(omega(x), x) + >>> omega(t).diff() + Derivative(omega(t), t) + + Now use the string printer: + + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> omega(x).diff() + Derivative(omega(x), x) + >>> omega(t).diff() + omega' + + """ + kwargs['str_printer'] = vsstrrepr + kwargs['pretty_printer'] = vpprint + kwargs['latex_printer'] = vlatex + init_printing(**kwargs) + + +params = init_printing.__doc__.split('Examples\n ========')[0] # type: ignore +init_vprinting.__doc__ = init_vprinting.__doc__.format(params) # type: ignore diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/__init__.py b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c71f4634e26c9b03a32f364ee02067ab39612c3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_dyadic.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_dyadic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c87f1b6d66a9935ff28872d026c381a95863619c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_dyadic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_fieldfunctions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_fieldfunctions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14fd29d49cde21bf9fb51eed8bb4e12770cb3163 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_fieldfunctions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_frame.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_frame.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f81c80af66883bfb72f26dd682deb494a7e6d133 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_frame.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_functions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..efd87cd4f16a6e4c52c2ea8640cea47af79f4e26 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_output.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_output.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62b8379cf190af403db00d049b658d2bb96160dc Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_output.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_point.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_point.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67ce301fc68e3d845f721c0f54af755ebe55080a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_point.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_printing.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_printing.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..334a4f2d65b5c6618aec3d72c5c67664d42a4125 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_printing.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_vector.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_vector.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fd5a480da8a205d975b0c4defd9ba279dd474d2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/physics/vector/tests/__pycache__/test_vector.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_dyadic.py b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_dyadic.py new file mode 100644 index 0000000000000000000000000000000000000000..ab365b4687162ccbd3b21dd9709b84dbcdec8aa0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_dyadic.py @@ -0,0 +1,123 @@ +from sympy.core.numbers import (Float, pi) +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.physics.vector import ReferenceFrame, dynamicsymbols, outer +from sympy.physics.vector.dyadic import _check_dyadic +from sympy.testing.pytest import raises + +A = ReferenceFrame('A') + + +def test_dyadic(): + d1 = A.x | A.x + d2 = A.y | A.y + d3 = A.x | A.y + assert d1 * 0 == 0 + assert d1 != 0 + assert d1 * 2 == 2 * A.x | A.x + assert d1 / 2. == 0.5 * d1 + assert d1 & (0 * d1) == 0 + assert d1 & d2 == 0 + assert d1 & A.x == A.x + assert d1 ^ A.x == 0 + assert d1 ^ A.y == A.x | A.z + assert d1 ^ A.z == - A.x | A.y + assert d2 ^ A.x == - A.y | A.z + assert A.x ^ d1 == 0 + assert A.y ^ d1 == - A.z | A.x + assert A.z ^ d1 == A.y | A.x + assert A.x & d1 == A.x + assert A.y & d1 == 0 + assert A.y & d2 == A.y + assert d1 & d3 == A.x | A.y + assert d3 & d1 == 0 + assert d1.dt(A) == 0 + q = dynamicsymbols('q') + qd = dynamicsymbols('q', 1) + B = A.orientnew('B', 'Axis', [q, A.z]) + assert d1.express(B) == d1.express(B, B) + assert d1.express(B) == ((cos(q)**2) * (B.x | B.x) + (-sin(q) * cos(q)) * + (B.x | B.y) + (-sin(q) * cos(q)) * (B.y | B.x) + (sin(q)**2) * + (B.y | B.y)) + assert d1.express(B, A) == (cos(q)) * (B.x | A.x) + (-sin(q)) * (B.y | A.x) + assert d1.express(A, B) == (cos(q)) * (A.x | B.x) + (-sin(q)) * (A.x | B.y) + assert d1.dt(B) == (-qd) * (A.y | A.x) + (-qd) * (A.x | A.y) + + assert d1.to_matrix(A) == Matrix([[1, 0, 0], [0, 0, 0], [0, 0, 0]]) + assert d1.to_matrix(A, B) == Matrix([[cos(q), -sin(q), 0], + [0, 0, 0], + [0, 0, 0]]) + assert d3.to_matrix(A) == Matrix([[0, 1, 0], [0, 0, 0], [0, 0, 0]]) + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + v1 = a * A.x + b * A.y + c * A.z + v2 = d * A.x + e * A.y + f * A.z + d4 = v1.outer(v2) + assert d4.to_matrix(A) == Matrix([[a * d, a * e, a * f], + [b * d, b * e, b * f], + [c * d, c * e, c * f]]) + d5 = v1.outer(v1) + C = A.orientnew('C', 'Axis', [q, A.x]) + for expected, actual in zip(C.dcm(A) * d5.to_matrix(A) * C.dcm(A).T, + d5.to_matrix(C)): + assert (expected - actual).simplify() == 0 + + raises(TypeError, lambda: d1.applyfunc(0)) + + +def test_dyadic_simplify(): + x, y, z, k, n, m, w, f, s, A = symbols('x, y, z, k, n, m, w, f, s, A') + N = ReferenceFrame('N') + + dy = N.x | N.x + test1 = (1 / x + 1 / y) * dy + assert (N.x & test1 & N.x) != (x + y) / (x * y) + test1 = test1.simplify() + assert (N.x & test1 & N.x) == (x + y) / (x * y) + + test2 = (A**2 * s**4 / (4 * pi * k * m**3)) * dy + test2 = test2.simplify() + assert (N.x & test2 & N.x) == (A**2 * s**4 / (4 * pi * k * m**3)) + + test3 = ((4 + 4 * x - 2 * (2 + 2 * x)) / (2 + 2 * x)) * dy + test3 = test3.simplify() + assert (N.x & test3 & N.x) == 0 + + test4 = ((-4 * x * y**2 - 2 * y**3 - 2 * x**2 * y) / (x + y)**2) * dy + test4 = test4.simplify() + assert (N.x & test4 & N.x) == -2 * y + + +def test_dyadic_subs(): + N = ReferenceFrame('N') + s = symbols('s') + a = s*(N.x | N.x) + assert a.subs({s: 2}) == 2*(N.x | N.x) + + +def test_check_dyadic(): + raises(TypeError, lambda: _check_dyadic(0)) + + +def test_dyadic_evalf(): + N = ReferenceFrame('N') + a = pi * (N.x | N.x) + assert a.evalf(3) == Float('3.1416', 3) * (N.x | N.x) + s = symbols('s') + a = 5 * s * pi* (N.x | N.x) + assert a.evalf(2) == Float('5', 2) * Float('3.1416', 2) * s * (N.x | N.x) + assert a.evalf(9, subs={s: 5.124}) == Float('80.48760378', 9) * (N.x | N.x) + + +def test_dyadic_xreplace(): + x, y, z = symbols('x y z') + N = ReferenceFrame('N') + D = outer(N.x, N.x) + v = x*y * D + assert v.xreplace({x : cos(x)}) == cos(x)*y * D + assert v.xreplace({x*y : pi}) == pi * D + v = (x*y)**z * D + assert v.xreplace({(x*y)**z : 1}) == D + assert v.xreplace({x:1, z:0}) == D + raises(TypeError, lambda: v.xreplace()) + raises(TypeError, lambda: v.xreplace([x, y])) diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_fieldfunctions.py b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_fieldfunctions.py new file mode 100644 index 0000000000000000000000000000000000000000..4e5c67aad44ca972dac6e455c57b60a74bae207a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_fieldfunctions.py @@ -0,0 +1,133 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.physics.vector import ReferenceFrame, Vector, Point, \ + dynamicsymbols +from sympy.physics.vector.fieldfunctions import divergence, \ + gradient, curl, is_conservative, is_solenoidal, \ + scalar_potential, scalar_potential_difference +from sympy.testing.pytest import raises + +R = ReferenceFrame('R') +q = dynamicsymbols('q') +P = R.orientnew('P', 'Axis', [q, R.z]) + + +def test_curl(): + assert curl(Vector(0), R) == Vector(0) + assert curl(R.x, R) == Vector(0) + assert curl(2*R[1]**2*R.y, R) == Vector(0) + assert curl(R[0]*R[1]*R.z, R) == R[0]*R.x - R[1]*R.y + assert curl(R[0]*R[1]*R[2] * (R.x+R.y+R.z), R) == \ + (-R[0]*R[1] + R[0]*R[2])*R.x + (R[0]*R[1] - R[1]*R[2])*R.y + \ + (-R[0]*R[2] + R[1]*R[2])*R.z + assert curl(2*R[0]**2*R.y, R) == 4*R[0]*R.z + assert curl(P[0]**2*R.x + P.y, R) == \ + - 2*(R[0]*cos(q) + R[1]*sin(q))*sin(q)*R.z + assert curl(P[0]*R.y, P) == cos(q)*P.z + + +def test_divergence(): + assert divergence(Vector(0), R) is S.Zero + assert divergence(R.x, R) is S.Zero + assert divergence(R[0]**2*R.x, R) == 2*R[0] + assert divergence(R[0]*R[1]*R[2] * (R.x+R.y+R.z), R) == \ + R[0]*R[1] + R[0]*R[2] + R[1]*R[2] + assert divergence((1/(R[0]*R[1]*R[2])) * (R.x+R.y+R.z), R) == \ + -1/(R[0]*R[1]*R[2]**2) - 1/(R[0]*R[1]**2*R[2]) - \ + 1/(R[0]**2*R[1]*R[2]) + v = P[0]*P.x + P[1]*P.y + P[2]*P.z + assert divergence(v, P) == 3 + assert divergence(v, R).simplify() == 3 + assert divergence(P[0]*R.x + R[0]*P.x, R) == 2*cos(q) + + +def test_gradient(): + a = Symbol('a') + assert gradient(0, R) == Vector(0) + assert gradient(R[0], R) == R.x + assert gradient(R[0]*R[1]*R[2], R) == \ + R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z + assert gradient(2*R[0]**2, R) == 4*R[0]*R.x + assert gradient(a*sin(R[1])/R[0], R) == \ + - a*sin(R[1])/R[0]**2*R.x + a*cos(R[1])/R[0]*R.y + assert gradient(P[0]*P[1], R) == \ + ((-R[0]*sin(q) + R[1]*cos(q))*cos(q) - (R[0]*cos(q) + R[1]*sin(q))*sin(q))*R.x + \ + ((-R[0]*sin(q) + R[1]*cos(q))*sin(q) + (R[0]*cos(q) + R[1]*sin(q))*cos(q))*R.y + assert gradient(P[0]*R[2], P) == P[2]*P.x + P[0]*P.z + + +scalar_field = 2*R[0]**2*R[1]*R[2] +grad_field = gradient(scalar_field, R) +vector_field = R[1]**2*R.x + 3*R[0]*R.y + 5*R[1]*R[2]*R.z +curl_field = curl(vector_field, R) + + +def test_conservative(): + assert is_conservative(0) is True + assert is_conservative(R.x) is True + assert is_conservative(2 * R.x + 3 * R.y + 4 * R.z) is True + assert is_conservative(R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z) is \ + True + assert is_conservative(R[0] * R.y) is False + assert is_conservative(grad_field) is True + assert is_conservative(curl_field) is False + assert is_conservative(4*R[0]*R[1]*R[2]*R.x + 2*R[0]**2*R[2]*R.y) is \ + False + assert is_conservative(R[2]*P.x + P[0]*R.z) is True + + +def test_solenoidal(): + assert is_solenoidal(0) is True + assert is_solenoidal(R.x) is True + assert is_solenoidal(2 * R.x + 3 * R.y + 4 * R.z) is True + assert is_solenoidal(R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z) is \ + True + assert is_solenoidal(R[1] * R.y) is False + assert is_solenoidal(grad_field) is False + assert is_solenoidal(curl_field) is True + assert is_solenoidal((-2*R[1] + 3)*R.z) is True + assert is_solenoidal(cos(q)*R.x + sin(q)*R.y + cos(q)*P.z) is True + assert is_solenoidal(R[2]*P.x + P[0]*R.z) is True + + +def test_scalar_potential(): + assert scalar_potential(0, R) == 0 + assert scalar_potential(R.x, R) == R[0] + assert scalar_potential(R.y, R) == R[1] + assert scalar_potential(R.z, R) == R[2] + assert scalar_potential(R[1]*R[2]*R.x + R[0]*R[2]*R.y + \ + R[0]*R[1]*R.z, R) == R[0]*R[1]*R[2] + assert scalar_potential(grad_field, R) == scalar_field + assert scalar_potential(R[2]*P.x + P[0]*R.z, R) == \ + R[0]*R[2]*cos(q) + R[1]*R[2]*sin(q) + assert scalar_potential(R[2]*P.x + P[0]*R.z, P) == P[0]*P[2] + raises(ValueError, lambda: scalar_potential(R[0] * R.y, R)) + + +def test_scalar_potential_difference(): + origin = Point('O') + point1 = origin.locatenew('P1', 1*R.x + 2*R.y + 3*R.z) + point2 = origin.locatenew('P2', 4*R.x + 5*R.y + 6*R.z) + genericpointR = origin.locatenew('RP', R[0]*R.x + R[1]*R.y + R[2]*R.z) + genericpointP = origin.locatenew('PP', P[0]*P.x + P[1]*P.y + P[2]*P.z) + assert scalar_potential_difference(S.Zero, R, point1, point2, \ + origin) == 0 + assert scalar_potential_difference(scalar_field, R, origin, \ + genericpointR, origin) == \ + scalar_field + assert scalar_potential_difference(grad_field, R, origin, \ + genericpointR, origin) == \ + scalar_field + assert scalar_potential_difference(grad_field, R, point1, point2, + origin) == 948 + assert scalar_potential_difference(R[1]*R[2]*R.x + R[0]*R[2]*R.y + \ + R[0]*R[1]*R.z, R, point1, + genericpointR, origin) == \ + R[0]*R[1]*R[2] - 6 + potential_diff_P = 2*P[2]*(P[0]*sin(q) + P[1]*cos(q))*\ + (P[0]*cos(q) - P[1]*sin(q))**2 + assert scalar_potential_difference(grad_field, P, origin, \ + genericpointP, \ + origin).simplify() == \ + potential_diff_P diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_frame.py b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..8e2d0234c7d2d9f91fdb5421c5a92f05495006c6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_frame.py @@ -0,0 +1,761 @@ +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.simplify import trigsimp +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import (eye, zeros) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.simplify.simplify import simplify +from sympy.physics.vector import (ReferenceFrame, Vector, CoordinateSym, + dynamicsymbols, time_derivative, express, + dot) +from sympy.physics.vector.frame import _check_frame +from sympy.physics.vector.vector import VectorTypeError +from sympy.testing.pytest import raises +import warnings +import pickle + + +def test_dict_list(): + + A = ReferenceFrame('A') + B = ReferenceFrame('B') + C = ReferenceFrame('C') + D = ReferenceFrame('D') + E = ReferenceFrame('E') + F = ReferenceFrame('F') + + B.orient_axis(A, A.x, 1.0) + C.orient_axis(B, B.x, 1.0) + D.orient_axis(C, C.x, 1.0) + + assert D._dict_list(A, 0) == [D, C, B, A] + + E.orient_axis(D, D.x, 1.0) + + assert C._dict_list(A, 0) == [C, B, A] + assert C._dict_list(E, 0) == [C, D, E] + + # only 0, 1, 2 permitted for second argument + raises(ValueError, lambda: C._dict_list(E, 5)) + # no connecting path + raises(ValueError, lambda: F._dict_list(A, 0)) + + +def test_coordinate_vars(): + """Tests the coordinate variables functionality""" + A = ReferenceFrame('A') + assert CoordinateSym('Ax', A, 0) == A[0] + assert CoordinateSym('Ax', A, 1) == A[1] + assert CoordinateSym('Ax', A, 2) == A[2] + raises(ValueError, lambda: CoordinateSym('Ax', A, 3)) + q = dynamicsymbols('q') + qd = dynamicsymbols('q', 1) + assert isinstance(A[0], CoordinateSym) and \ + isinstance(A[0], CoordinateSym) and \ + isinstance(A[0], CoordinateSym) + assert A.variable_map(A) == {A[0]:A[0], A[1]:A[1], A[2]:A[2]} + assert A[0].frame == A + B = A.orientnew('B', 'Axis', [q, A.z]) + assert B.variable_map(A) == {B[2]: A[2], B[1]: -A[0]*sin(q) + A[1]*cos(q), + B[0]: A[0]*cos(q) + A[1]*sin(q)} + assert A.variable_map(B) == {A[0]: B[0]*cos(q) - B[1]*sin(q), + A[1]: B[0]*sin(q) + B[1]*cos(q), A[2]: B[2]} + assert time_derivative(B[0], A) == -A[0]*sin(q)*qd + A[1]*cos(q)*qd + assert time_derivative(B[1], A) == -A[0]*cos(q)*qd - A[1]*sin(q)*qd + assert time_derivative(B[2], A) == 0 + assert express(B[0], A, variables=True) == A[0]*cos(q) + A[1]*sin(q) + assert express(B[1], A, variables=True) == -A[0]*sin(q) + A[1]*cos(q) + assert express(B[2], A, variables=True) == A[2] + assert time_derivative(A[0]*A.x + A[1]*A.y + A[2]*A.z, B) == A[1]*qd*A.x - A[0]*qd*A.y + assert time_derivative(B[0]*B.x + B[1]*B.y + B[2]*B.z, A) == - B[1]*qd*B.x + B[0]*qd*B.y + assert express(B[0]*B[1]*B[2], A, variables=True) == \ + A[2]*(-A[0]*sin(q) + A[1]*cos(q))*(A[0]*cos(q) + A[1]*sin(q)) + assert (time_derivative(B[0]*B[1]*B[2], A) - + (A[2]*(-A[0]**2*cos(2*q) - + 2*A[0]*A[1]*sin(2*q) + + A[1]**2*cos(2*q))*qd)).trigsimp() == 0 + assert express(B[0]*B.x + B[1]*B.y + B[2]*B.z, A) == \ + (B[0]*cos(q) - B[1]*sin(q))*A.x + (B[0]*sin(q) + \ + B[1]*cos(q))*A.y + B[2]*A.z + assert express(B[0]*B.x + B[1]*B.y + B[2]*B.z, A, + variables=True).simplify() == A[0]*A.x + A[1]*A.y + A[2]*A.z + assert express(A[0]*A.x + A[1]*A.y + A[2]*A.z, B) == \ + (A[0]*cos(q) + A[1]*sin(q))*B.x + \ + (-A[0]*sin(q) + A[1]*cos(q))*B.y + A[2]*B.z + assert express(A[0]*A.x + A[1]*A.y + A[2]*A.z, B, + variables=True).simplify() == B[0]*B.x + B[1]*B.y + B[2]*B.z + N = B.orientnew('N', 'Axis', [-q, B.z]) + assert ({k: v.simplify() for k, v in N.variable_map(A).items()} == + {N[0]: A[0], N[2]: A[2], N[1]: A[1]}) + C = A.orientnew('C', 'Axis', [q, A.x + A.y + A.z]) + mapping = A.variable_map(C) + assert trigsimp(mapping[A[0]]) == (2*C[0]*cos(q)/3 + C[0]/3 - + 2*C[1]*sin(q + pi/6)/3 + + C[1]/3 - 2*C[2]*cos(q + pi/3)/3 + + C[2]/3) + assert trigsimp(mapping[A[1]]) == -2*C[0]*cos(q + pi/3)/3 + \ + C[0]/3 + 2*C[1]*cos(q)/3 + C[1]/3 - 2*C[2]*sin(q + pi/6)/3 + C[2]/3 + assert trigsimp(mapping[A[2]]) == -2*C[0]*sin(q + pi/6)/3 + C[0]/3 - \ + 2*C[1]*cos(q + pi/3)/3 + C[1]/3 + 2*C[2]*cos(q)/3 + C[2]/3 + + +def test_ang_vel(): + q1, q2, q3, q4 = dynamicsymbols('q1 q2 q3 q4') + q1d, q2d, q3d, q4d = dynamicsymbols('q1 q2 q3 q4', 1) + N = ReferenceFrame('N') + A = N.orientnew('A', 'Axis', [q1, N.z]) + B = A.orientnew('B', 'Axis', [q2, A.x]) + C = B.orientnew('C', 'Axis', [q3, B.y]) + D = N.orientnew('D', 'Axis', [q4, N.y]) + u1, u2, u3 = dynamicsymbols('u1 u2 u3') + assert A.ang_vel_in(N) == (q1d)*A.z + assert B.ang_vel_in(N) == (q2d)*B.x + (q1d)*A.z + assert C.ang_vel_in(N) == (q3d)*C.y + (q2d)*B.x + (q1d)*A.z + + A2 = N.orientnew('A2', 'Axis', [q4, N.y]) + assert N.ang_vel_in(N) == 0 + assert N.ang_vel_in(A) == -q1d*N.z + assert N.ang_vel_in(B) == -q1d*A.z - q2d*B.x + assert N.ang_vel_in(C) == -q1d*A.z - q2d*B.x - q3d*B.y + assert N.ang_vel_in(A2) == -q4d*N.y + + assert A.ang_vel_in(N) == q1d*N.z + assert A.ang_vel_in(A) == 0 + assert A.ang_vel_in(B) == - q2d*B.x + assert A.ang_vel_in(C) == - q2d*B.x - q3d*B.y + assert A.ang_vel_in(A2) == q1d*N.z - q4d*N.y + + assert B.ang_vel_in(N) == q1d*A.z + q2d*A.x + assert B.ang_vel_in(A) == q2d*A.x + assert B.ang_vel_in(B) == 0 + assert B.ang_vel_in(C) == -q3d*B.y + assert B.ang_vel_in(A2) == q1d*A.z + q2d*A.x - q4d*N.y + + assert C.ang_vel_in(N) == q1d*A.z + q2d*A.x + q3d*B.y + assert C.ang_vel_in(A) == q2d*A.x + q3d*C.y + assert C.ang_vel_in(B) == q3d*B.y + assert C.ang_vel_in(C) == 0 + assert C.ang_vel_in(A2) == q1d*A.z + q2d*A.x + q3d*B.y - q4d*N.y + + assert A2.ang_vel_in(N) == q4d*A2.y + assert A2.ang_vel_in(A) == q4d*A2.y - q1d*N.z + assert A2.ang_vel_in(B) == q4d*N.y - q1d*A.z - q2d*A.x + assert A2.ang_vel_in(C) == q4d*N.y - q1d*A.z - q2d*A.x - q3d*B.y + assert A2.ang_vel_in(A2) == 0 + + C.set_ang_vel(N, u1*C.x + u2*C.y + u3*C.z) + assert C.ang_vel_in(N) == (u1)*C.x + (u2)*C.y + (u3)*C.z + assert N.ang_vel_in(C) == (-u1)*C.x + (-u2)*C.y + (-u3)*C.z + assert C.ang_vel_in(D) == (u1)*C.x + (u2)*C.y + (u3)*C.z + (-q4d)*D.y + assert D.ang_vel_in(C) == (-u1)*C.x + (-u2)*C.y + (-u3)*C.z + (q4d)*D.y + + q0 = dynamicsymbols('q0') + q0d = dynamicsymbols('q0', 1) + E = N.orientnew('E', 'Quaternion', (q0, q1, q2, q3)) + assert E.ang_vel_in(N) == ( + 2 * (q1d * q0 + q2d * q3 - q3d * q2 - q0d * q1) * E.x + + 2 * (q2d * q0 + q3d * q1 - q1d * q3 - q0d * q2) * E.y + + 2 * (q3d * q0 + q1d * q2 - q2d * q1 - q0d * q3) * E.z) + + F = N.orientnew('F', 'Body', (q1, q2, q3), 313) + assert F.ang_vel_in(N) == ((sin(q2)*sin(q3)*q1d + cos(q3)*q2d)*F.x + + (sin(q2)*cos(q3)*q1d - sin(q3)*q2d)*F.y + (cos(q2)*q1d + q3d)*F.z) + G = N.orientnew('G', 'Axis', (q1, N.x + N.y)) + assert G.ang_vel_in(N) == q1d * (N.x + N.y).normalize() + assert N.ang_vel_in(G) == -q1d * (N.x + N.y).normalize() + + +def test_dcm(): + q1, q2, q3, q4 = dynamicsymbols('q1 q2 q3 q4') + N = ReferenceFrame('N') + A = N.orientnew('A', 'Axis', [q1, N.z]) + B = A.orientnew('B', 'Axis', [q2, A.x]) + C = B.orientnew('C', 'Axis', [q3, B.y]) + D = N.orientnew('D', 'Axis', [q4, N.y]) + E = N.orientnew('E', 'Space', [q1, q2, q3], '123') + assert N.dcm(C) == Matrix([ + [- sin(q1) * sin(q2) * sin(q3) + cos(q1) * cos(q3), - sin(q1) * + cos(q2), sin(q1) * sin(q2) * cos(q3) + sin(q3) * cos(q1)], [sin(q1) * + cos(q3) + sin(q2) * sin(q3) * cos(q1), cos(q1) * cos(q2), sin(q1) * + sin(q3) - sin(q2) * cos(q1) * cos(q3)], [- sin(q3) * cos(q2), sin(q2), + cos(q2) * cos(q3)]]) + # This is a little touchy. Is it ok to use simplify in assert? + test_mat = D.dcm(C) - Matrix( + [[cos(q1) * cos(q3) * cos(q4) - sin(q3) * (- sin(q4) * cos(q2) + + sin(q1) * sin(q2) * cos(q4)), - sin(q2) * sin(q4) - sin(q1) * + cos(q2) * cos(q4), sin(q3) * cos(q1) * cos(q4) + cos(q3) * (- sin(q4) * + cos(q2) + sin(q1) * sin(q2) * cos(q4))], [sin(q1) * cos(q3) + + sin(q2) * sin(q3) * cos(q1), cos(q1) * cos(q2), sin(q1) * sin(q3) - + sin(q2) * cos(q1) * cos(q3)], [sin(q4) * cos(q1) * cos(q3) - + sin(q3) * (cos(q2) * cos(q4) + sin(q1) * sin(q2) * sin(q4)), sin(q2) * + cos(q4) - sin(q1) * sin(q4) * cos(q2), sin(q3) * sin(q4) * cos(q1) + + cos(q3) * (cos(q2) * cos(q4) + sin(q1) * sin(q2) * sin(q4))]]) + assert test_mat.expand() == zeros(3, 3) + assert E.dcm(N) == Matrix( + [[cos(q2)*cos(q3), sin(q3)*cos(q2), -sin(q2)], + [sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1), sin(q1)*sin(q2)*sin(q3) + + cos(q1)*cos(q3), sin(q1)*cos(q2)], [sin(q1)*sin(q3) + + sin(q2)*cos(q1)*cos(q3), - sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1), + cos(q1)*cos(q2)]]) + +def test_w_diff_dcm1(): + # Ref: + # Dynamics Theory and Applications, Kane 1985 + # Sec. 2.1 ANGULAR VELOCITY + A = ReferenceFrame('A') + B = ReferenceFrame('B') + + c11, c12, c13 = dynamicsymbols('C11 C12 C13') + c21, c22, c23 = dynamicsymbols('C21 C22 C23') + c31, c32, c33 = dynamicsymbols('C31 C32 C33') + + c11d, c12d, c13d = dynamicsymbols('C11 C12 C13', level=1) + c21d, c22d, c23d = dynamicsymbols('C21 C22 C23', level=1) + c31d, c32d, c33d = dynamicsymbols('C31 C32 C33', level=1) + + DCM = Matrix([ + [c11, c12, c13], + [c21, c22, c23], + [c31, c32, c33] + ]) + + B.orient(A, 'DCM', DCM) + b1a = (B.x).express(A) + b2a = (B.y).express(A) + b3a = (B.z).express(A) + + # Equation (2.1.1) + B.set_ang_vel(A, B.x*(dot((b3a).dt(A), B.y)) + + B.y*(dot((b1a).dt(A), B.z)) + + B.z*(dot((b2a).dt(A), B.x))) + + # Equation (2.1.21) + expr = ( (c12*c13d + c22*c23d + c32*c33d)*B.x + + (c13*c11d + c23*c21d + c33*c31d)*B.y + + (c11*c12d + c21*c22d + c31*c32d)*B.z) + assert B.ang_vel_in(A) - expr == 0 + +def test_w_diff_dcm2(): + q1, q2, q3 = dynamicsymbols('q1:4') + N = ReferenceFrame('N') + A = N.orientnew('A', 'axis', [q1, N.x]) + B = A.orientnew('B', 'axis', [q2, A.y]) + C = B.orientnew('C', 'axis', [q3, B.z]) + + DCM = C.dcm(N).T + D = N.orientnew('D', 'DCM', DCM) + + # Frames D and C are the same ReferenceFrame, + # since they have equal DCM respect to frame N. + # Therefore, D and C should have same angle velocity in N. + assert D.dcm(N) == C.dcm(N) == Matrix([ + [cos(q2)*cos(q3), sin(q1)*sin(q2)*cos(q3) + + sin(q3)*cos(q1), sin(q1)*sin(q3) - + sin(q2)*cos(q1)*cos(q3)], [-sin(q3)*cos(q2), + -sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3), + sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1)], + [sin(q2), -sin(q1)*cos(q2), cos(q1)*cos(q2)]]) + assert (D.ang_vel_in(N) - C.ang_vel_in(N)).express(N).simplify() == 0 + +def test_orientnew_respects_parent_class(): + class MyReferenceFrame(ReferenceFrame): + pass + B = MyReferenceFrame('B') + C = B.orientnew('C', 'Axis', [0, B.x]) + assert isinstance(C, MyReferenceFrame) + + +def test_orientnew_respects_input_indices(): + N = ReferenceFrame('N') + q1 = dynamicsymbols('q1') + A = N.orientnew('a', 'Axis', [q1, N.z]) + #modify default indices: + minds = [x+'1' for x in N.indices] + B = N.orientnew('b', 'Axis', [q1, N.z], indices=minds) + + assert N.indices == A.indices + assert B.indices == minds + +def test_orientnew_respects_input_latexs(): + N = ReferenceFrame('N') + q1 = dynamicsymbols('q1') + A = N.orientnew('a', 'Axis', [q1, N.z]) + + #build default and alternate latex_vecs: + def_latex_vecs = [(r"\mathbf{\hat{%s}_%s}" % (A.name.lower(), + A.indices[0])), (r"\mathbf{\hat{%s}_%s}" % + (A.name.lower(), A.indices[1])), + (r"\mathbf{\hat{%s}_%s}" % (A.name.lower(), + A.indices[2]))] + + name = 'b' + indices = [x+'1' for x in N.indices] + new_latex_vecs = [(r"\mathbf{\hat{%s}_{%s}}" % (name.lower(), + indices[0])), (r"\mathbf{\hat{%s}_{%s}}" % + (name.lower(), indices[1])), + (r"\mathbf{\hat{%s}_{%s}}" % (name.lower(), + indices[2]))] + + B = N.orientnew(name, 'Axis', [q1, N.z], latexs=new_latex_vecs) + + assert A.latex_vecs == def_latex_vecs + assert B.latex_vecs == new_latex_vecs + assert B.indices != indices + +def test_orientnew_respects_input_variables(): + N = ReferenceFrame('N') + q1 = dynamicsymbols('q1') + A = N.orientnew('a', 'Axis', [q1, N.z]) + + #build non-standard variable names + name = 'b' + new_variables = ['notb_'+x+'1' for x in N.indices] + B = N.orientnew(name, 'Axis', [q1, N.z], variables=new_variables) + + for j,var in enumerate(A.varlist): + assert var.name == A.name + '_' + A.indices[j] + + for j,var in enumerate(B.varlist): + assert var.name == new_variables[j] + +def test_issue_10348(): + u = dynamicsymbols('u:3') + I = ReferenceFrame('I') + I.orientnew('A', 'space', u, 'XYZ') + + +def test_issue_11503(): + A = ReferenceFrame("A") + A.orientnew("B", "Axis", [35, A.y]) + C = ReferenceFrame("C") + A.orient(C, "Axis", [70, C.z]) + + +def test_partial_velocity(): + + N = ReferenceFrame('N') + A = ReferenceFrame('A') + + u1, u2 = dynamicsymbols('u1, u2') + + A.set_ang_vel(N, u1 * A.x + u2 * N.y) + + assert N.partial_velocity(A, u1) == -A.x + assert N.partial_velocity(A, u1, u2) == (-A.x, -N.y) + + assert A.partial_velocity(N, u1) == A.x + assert A.partial_velocity(N, u1, u2) == (A.x, N.y) + + assert N.partial_velocity(N, u1) == 0 + assert A.partial_velocity(A, u1) == 0 + + +def test_issue_11498(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + + # Identity transformation + A.orient(B, 'DCM', eye(3)) + assert A.dcm(B) == Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert B.dcm(A) == Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + + # x -> y + # y -> -z + # z -> -x + A.orient(B, 'DCM', Matrix([[0, 1, 0], [0, 0, -1], [-1, 0, 0]])) + assert B.dcm(A) == Matrix([[0, 1, 0], [0, 0, -1], [-1, 0, 0]]) + assert A.dcm(B) == Matrix([[0, 0, -1], [1, 0, 0], [0, -1, 0]]) + assert B.dcm(A).T == A.dcm(B) + + +def test_reference_frame(): + raises(TypeError, lambda: ReferenceFrame(0)) + raises(TypeError, lambda: ReferenceFrame('N', 0)) + raises(ValueError, lambda: ReferenceFrame('N', [0, 1])) + raises(TypeError, lambda: ReferenceFrame('N', [0, 1, 2])) + raises(TypeError, lambda: ReferenceFrame('N', ['a', 'b', 'c'], 0)) + raises(ValueError, lambda: ReferenceFrame('N', ['a', 'b', 'c'], [0, 1])) + raises(TypeError, lambda: ReferenceFrame('N', ['a', 'b', 'c'], [0, 1, 2])) + raises(TypeError, lambda: ReferenceFrame('N', ['a', 'b', 'c'], + ['a', 'b', 'c'], 0)) + raises(ValueError, lambda: ReferenceFrame('N', ['a', 'b', 'c'], + ['a', 'b', 'c'], [0, 1])) + raises(TypeError, lambda: ReferenceFrame('N', ['a', 'b', 'c'], + ['a', 'b', 'c'], [0, 1, 2])) + N = ReferenceFrame('N') + assert N[0] == CoordinateSym('N_x', N, 0) + assert N[1] == CoordinateSym('N_y', N, 1) + assert N[2] == CoordinateSym('N_z', N, 2) + raises(ValueError, lambda: N[3]) + N = ReferenceFrame('N', ['a', 'b', 'c']) + assert N['a'] == N.x + assert N['b'] == N.y + assert N['c'] == N.z + raises(ValueError, lambda: N['d']) + assert str(N) == 'N' + + A = ReferenceFrame('A') + B = ReferenceFrame('B') + q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + raises(TypeError, lambda: A.orient(B, 'DCM', 0)) + raises(TypeError, lambda: B.orient(N, 'Space', [q1, q2, q3], '222')) + raises(TypeError, lambda: B.orient(N, 'Axis', [q1, N.x + 2 * N.y], '222')) + raises(TypeError, lambda: B.orient(N, 'Axis', q1)) + raises(IndexError, lambda: B.orient(N, 'Axis', [q1])) + raises(TypeError, lambda: B.orient(N, 'Quaternion', [q0, q1, q2, q3], '222')) + raises(TypeError, lambda: B.orient(N, 'Quaternion', q0)) + raises(TypeError, lambda: B.orient(N, 'Quaternion', [q0, q1, q2])) + raises(NotImplementedError, lambda: B.orient(N, 'Foo', [q0, q1, q2])) + raises(TypeError, lambda: B.orient(N, 'Body', [q1, q2], '232')) + raises(TypeError, lambda: B.orient(N, 'Space', [q1, q2], '232')) + + N.set_ang_acc(B, 0) + assert N.ang_acc_in(B) == Vector(0) + N.set_ang_vel(B, 0) + assert N.ang_vel_in(B) == Vector(0) + + +def test_check_frame(): + raises(VectorTypeError, lambda: _check_frame(0)) + + +def test_dcm_diff_16824(): + # NOTE : This is a regression test for the bug introduced in PR 14758, + # identified in 16824, and solved by PR 16828. + + # This is the solution to Problem 2.2 on page 264 in Kane & Lenvinson's + # 1985 book. + + q1, q2, q3 = dynamicsymbols('q1:4') + + s1 = sin(q1) + c1 = cos(q1) + s2 = sin(q2) + c2 = cos(q2) + s3 = sin(q3) + c3 = cos(q3) + + dcm = Matrix([[c2*c3, s1*s2*c3 - s3*c1, c1*s2*c3 + s3*s1], + [c2*s3, s1*s2*s3 + c3*c1, c1*s2*s3 - c3*s1], + [-s2, s1*c2, c1*c2]]) + + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B.orient(A, 'DCM', dcm) + + AwB = B.ang_vel_in(A) + + alpha2 = s3*c2*q1.diff() + c3*q2.diff() + beta2 = s1*c2*q3.diff() + c1*q2.diff() + + assert simplify(AwB.dot(A.y) - alpha2) == 0 + assert simplify(AwB.dot(B.y) - beta2) == 0 + +def test_orient_explicit(): + cxx, cyy, czz = dynamicsymbols('c_{xx}, c_{yy}, c_{zz}') + cxy, cxz, cyx = dynamicsymbols('c_{xy}, c_{xz}, c_{yx}') + cyz, czx, czy = dynamicsymbols('c_{yz}, c_{zx}, c_{zy}') + dcxx, dcyy, dczz = dynamicsymbols('c_{xx}, c_{yy}, c_{zz}', 1) + dcxy, dcxz, dcyx = dynamicsymbols('c_{xy}, c_{xz}, c_{yx}', 1) + dcyz, dczx, dczy = dynamicsymbols('c_{yz}, c_{zx}, c_{zy}', 1) + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B_C_A = Matrix([[cxx, cxy, cxz], + [cyx, cyy, cyz], + [czx, czy, czz]]) + B_w_A = ((cyx*dczx + cyy*dczy + cyz*dczz)*B.x + + (czx*dcxx + czy*dcxy + czz*dcxz)*B.y + + (cxx*dcyx + cxy*dcyy + cxz*dcyz)*B.z) + A.orient_explicit(B, B_C_A) + assert B.dcm(A) == B_C_A + assert A.ang_vel_in(B) == B_w_A + assert B.ang_vel_in(A) == -B_w_A + +def test_orient_dcm(): + cxx, cyy, czz = dynamicsymbols('c_{xx}, c_{yy}, c_{zz}') + cxy, cxz, cyx = dynamicsymbols('c_{xy}, c_{xz}, c_{yx}') + cyz, czx, czy = dynamicsymbols('c_{yz}, c_{zx}, c_{zy}') + B_C_A = Matrix([[cxx, cxy, cxz], + [cyx, cyy, cyz], + [czx, czy, czz]]) + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B.orient_dcm(A, B_C_A) + assert B.dcm(A) == Matrix([[cxx, cxy, cxz], + [cyx, cyy, cyz], + [czx, czy, czz]]) + +def test_orient_axis(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + A.orient_axis(B,-B.x, 1) + A1 = A.dcm(B) + A.orient_axis(B, B.x, -1) + A2 = A.dcm(B) + A.orient_axis(B, 1, -B.x) + A3 = A.dcm(B) + assert A1 == A2 + assert A2 == A3 + raises(TypeError, lambda: A.orient_axis(B, 1, 1)) + +def test_orient_body(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B.orient_body_fixed(A, (1,1,0), 'XYX') + assert B.dcm(A) == Matrix([[cos(1), sin(1)**2, -sin(1)*cos(1)], [0, cos(1), sin(1)], [sin(1), -sin(1)*cos(1), cos(1)**2]]) + + +def test_orient_body_advanced(): + q1, q2, q3 = dynamicsymbols('q1:4') + c1, c2, c3 = symbols('c1:4') + u1, u2, u3 = dynamicsymbols('q1:4', 1) + + # Test with everything as dynamicsymbols + A, B = ReferenceFrame('A'), ReferenceFrame('B') + B.orient_body_fixed(A, (q1, q2, q3), 'zxy') + assert A.dcm(B) == Matrix([ + [-sin(q1) * sin(q2) * sin(q3) + cos(q1) * cos(q3), -sin(q1) * cos(q2), + sin(q1) * sin(q2) * cos(q3) + sin(q3) * cos(q1)], + [sin(q1) * cos(q3) + sin(q2) * sin(q3) * cos(q1), cos(q1) * cos(q2), + sin(q1) * sin(q3) - sin(q2) * cos(q1) * cos(q3)], + [-sin(q3) * cos(q2), sin(q2), cos(q2) * cos(q3)]]) + assert B.ang_vel_in(A).to_matrix(B) == Matrix([ + [-sin(q3) * cos(q2) * u1 + cos(q3) * u2], + [sin(q2) * u1 + u3], + [sin(q3) * u2 + cos(q2) * cos(q3) * u1]]) + + # Test with constant symbol + A, B = ReferenceFrame('A'), ReferenceFrame('B') + B.orient_body_fixed(A, (q1, c2, q3), 131) + assert A.dcm(B) == Matrix([ + [cos(c2), -sin(c2) * cos(q3), sin(c2) * sin(q3)], + [sin(c2) * cos(q1), -sin(q1) * sin(q3) + cos(c2) * cos(q1) * cos(q3), + -sin(q1) * cos(q3) - sin(q3) * cos(c2) * cos(q1)], + [sin(c2) * sin(q1), sin(q1) * cos(c2) * cos(q3) + sin(q3) * cos(q1), + -sin(q1) * sin(q3) * cos(c2) + cos(q1) * cos(q3)]]) + assert B.ang_vel_in(A).to_matrix(B) == Matrix([ + [cos(c2) * u1 + u3], + [-sin(c2) * cos(q3) * u1], + [sin(c2) * sin(q3) * u1]]) + + # Test all symbols not time dependent + A, B = ReferenceFrame('A'), ReferenceFrame('B') + B.orient_body_fixed(A, (c1, c2, c3), 123) + assert B.ang_vel_in(A) == Vector(0) + + +def test_orient_space_advanced(): + # space fixed is in the end like body fixed only in opposite order + q1, q2, q3 = dynamicsymbols('q1:4') + c1, c2, c3 = symbols('c1:4') + u1, u2, u3 = dynamicsymbols('q1:4', 1) + + # Test with everything as dynamicsymbols + A, B = ReferenceFrame('A'), ReferenceFrame('B') + B.orient_space_fixed(A, (q3, q2, q1), 'yxz') + assert A.dcm(B) == Matrix([ + [-sin(q1) * sin(q2) * sin(q3) + cos(q1) * cos(q3), -sin(q1) * cos(q2), + sin(q1) * sin(q2) * cos(q3) + sin(q3) * cos(q1)], + [sin(q1) * cos(q3) + sin(q2) * sin(q3) * cos(q1), cos(q1) * cos(q2), + sin(q1) * sin(q3) - sin(q2) * cos(q1) * cos(q3)], + [-sin(q3) * cos(q2), sin(q2), cos(q2) * cos(q3)]]) + assert B.ang_vel_in(A).to_matrix(B) == Matrix([ + [-sin(q3) * cos(q2) * u1 + cos(q3) * u2], + [sin(q2) * u1 + u3], + [sin(q3) * u2 + cos(q2) * cos(q3) * u1]]) + + # Test with constant symbol + A, B = ReferenceFrame('A'), ReferenceFrame('B') + B.orient_space_fixed(A, (q3, c2, q1), 131) + assert A.dcm(B) == Matrix([ + [cos(c2), -sin(c2) * cos(q3), sin(c2) * sin(q3)], + [sin(c2) * cos(q1), -sin(q1) * sin(q3) + cos(c2) * cos(q1) * cos(q3), + -sin(q1) * cos(q3) - sin(q3) * cos(c2) * cos(q1)], + [sin(c2) * sin(q1), sin(q1) * cos(c2) * cos(q3) + sin(q3) * cos(q1), + -sin(q1) * sin(q3) * cos(c2) + cos(q1) * cos(q3)]]) + assert B.ang_vel_in(A).to_matrix(B) == Matrix([ + [cos(c2) * u1 + u3], + [-sin(c2) * cos(q3) * u1], + [sin(c2) * sin(q3) * u1]]) + + # Test all symbols not time dependent + A, B = ReferenceFrame('A'), ReferenceFrame('B') + B.orient_space_fixed(A, (c1, c2, c3), 123) + assert B.ang_vel_in(A) == Vector(0) + + +def test_orient_body_simple_ang_vel(): + """This test ensures that the simplest form of that linear system solution + is returned, thus the == for the expression comparison.""" + + psi, theta, phi = dynamicsymbols('psi, theta, varphi') + t = dynamicsymbols._t + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B.orient_body_fixed(A, (psi, theta, phi), 'ZXZ') + A_w_B = B.ang_vel_in(A) + assert A_w_B.args[0][1] == B + assert A_w_B.args[0][0][0] == (sin(theta)*sin(phi)*psi.diff(t) + + cos(phi)*theta.diff(t)) + assert A_w_B.args[0][0][1] == (sin(theta)*cos(phi)*psi.diff(t) - + sin(phi)*theta.diff(t)) + assert A_w_B.args[0][0][2] == cos(theta)*psi.diff(t) + phi.diff(t) + + +def test_orient_space(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B.orient_space_fixed(A, (0,0,0), '123') + assert B.dcm(A) == Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + +def test_orient_quaternion(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + B.orient_quaternion(A, (0,0,0,0)) + assert B.dcm(A) == Matrix([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + +def test_looped_frame_warning(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + C = ReferenceFrame('C') + + a, b, c = symbols('a b c') + B.orient_axis(A, A.x, a) + C.orient_axis(B, B.x, b) + + with warnings.catch_warnings(record = True) as w: + warnings.simplefilter("always") + A.orient_axis(C, C.x, c) + assert issubclass(w[-1].category, UserWarning) + assert 'Loops are defined among the orientation of frames. ' + \ + 'This is likely not desired and may cause errors in your calculations.' in str(w[-1].message) + +def test_frame_dict(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + C = ReferenceFrame('C') + + a, b, c = symbols('a b c') + + B.orient_axis(A, A.x, a) + assert A._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(a), -sin(a)],[0, sin(a), cos(a)]])} + assert B._dcm_dict == {A: Matrix([[1, 0, 0],[0, cos(a), sin(a)],[0, -sin(a), cos(a)]])} + assert C._dcm_dict == {} + + B.orient_axis(C, C.x, b) + # Previous relation is not wiped + assert A._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(a), -sin(a)],[0, sin(a), cos(a)]])} + assert B._dcm_dict == {A: Matrix([[1, 0, 0],[0, cos(a), sin(a)],[0, -sin(a), cos(a)]]), \ + C: Matrix([[1, 0, 0],[0, cos(b), sin(b)],[0, -sin(b), cos(b)]])} + assert C._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(b), -sin(b)],[0, sin(b), cos(b)]])} + + A.orient_axis(B, B.x, c) + # Previous relation is updated + assert B._dcm_dict == {C: Matrix([[1, 0, 0],[0, cos(b), sin(b)],[0, -sin(b), cos(b)]]),\ + A: Matrix([[1, 0, 0],[0, cos(c), -sin(c)],[0, sin(c), cos(c)]])} + assert A._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(c), sin(c)],[0, -sin(c), cos(c)]])} + assert C._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(b), -sin(b)],[0, sin(b), cos(b)]])} + +def test_dcm_cache_dict(): + A = ReferenceFrame('A') + B = ReferenceFrame('B') + C = ReferenceFrame('C') + D = ReferenceFrame('D') + + a, b, c = symbols('a b c') + + B.orient_axis(A, A.x, a) + C.orient_axis(B, B.x, b) + D.orient_axis(C, C.x, c) + + assert D._dcm_dict == {C: Matrix([[1, 0, 0],[0, cos(c), sin(c)],[0, -sin(c), cos(c)]])} + assert C._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(b), sin(b)],[0, -sin(b), cos(b)]]), \ + D: Matrix([[1, 0, 0],[0, cos(c), -sin(c)],[0, sin(c), cos(c)]])} + assert B._dcm_dict == {A: Matrix([[1, 0, 0],[0, cos(a), sin(a)],[0, -sin(a), cos(a)]]), \ + C: Matrix([[1, 0, 0],[0, cos(b), -sin(b)],[0, sin(b), cos(b)]])} + assert A._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(a), -sin(a)],[0, sin(a), cos(a)]])} + + assert D._dcm_dict == D._dcm_cache + + D.dcm(A) # Check calculated dcm relation is stored in _dcm_cache and not in _dcm_dict + assert list(A._dcm_cache.keys()) == [A, B, D] + assert list(D._dcm_cache.keys()) == [C, A] + assert list(A._dcm_dict.keys()) == [B] + assert list(D._dcm_dict.keys()) == [C] + assert A._dcm_dict != A._dcm_cache + + A.orient_axis(B, B.x, b) # _dcm_cache of A is wiped out and new relation is stored. + assert A._dcm_dict == {B: Matrix([[1, 0, 0],[0, cos(b), sin(b)],[0, -sin(b), cos(b)]])} + assert A._dcm_dict == A._dcm_cache + assert B._dcm_dict == {C: Matrix([[1, 0, 0],[0, cos(b), -sin(b)],[0, sin(b), cos(b)]]), \ + A: Matrix([[1, 0, 0],[0, cos(b), -sin(b)],[0, sin(b), cos(b)]])} + +def test_xx_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.xx == Vector.outer(N.x, N.x) + assert F.xx == Vector.outer(F.x, F.x) + +def test_xy_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.xy == Vector.outer(N.x, N.y) + assert F.xy == Vector.outer(F.x, F.y) + +def test_xz_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.xz == Vector.outer(N.x, N.z) + assert F.xz == Vector.outer(F.x, F.z) + +def test_yx_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.yx == Vector.outer(N.y, N.x) + assert F.yx == Vector.outer(F.y, F.x) + +def test_yy_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.yy == Vector.outer(N.y, N.y) + assert F.yy == Vector.outer(F.y, F.y) + +def test_yz_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.yz == Vector.outer(N.y, N.z) + assert F.yz == Vector.outer(F.y, F.z) + +def test_zx_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.zx == Vector.outer(N.z, N.x) + assert F.zx == Vector.outer(F.z, F.x) + +def test_zy_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.zy == Vector.outer(N.z, N.y) + assert F.zy == Vector.outer(F.z, F.y) + +def test_zz_dyad(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.zz == Vector.outer(N.z, N.z) + assert F.zz == Vector.outer(F.z, F.z) + +def test_unit_dyadic(): + N = ReferenceFrame('N') + F = ReferenceFrame('F', indices=['1', '2', '3']) + assert N.u == N.xx + N.yy + N.zz + assert F.u == F.xx + F.yy + F.zz + + +def test_pickle_frame(): + N = ReferenceFrame('N') + A = ReferenceFrame('A') + A.orient_axis(N, N.x, 1) + A_C_N = A.dcm(N) + N1 = pickle.loads(pickle.dumps(N)) + A1 = tuple(N1._dcm_dict.keys())[0] + assert A1.dcm(N1) == A_C_N diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_functions.py b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..ff938da980c4bbd51d378b30fd5310a88e528e97 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_functions.py @@ -0,0 +1,509 @@ +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.integrals.integrals import Integral +from sympy.physics.vector import Dyadic, Point, ReferenceFrame, Vector +from sympy.physics.vector.functions import (cross, dot, express, + time_derivative, + kinematic_equations, outer, + partial_velocity, + get_motion_params, dynamicsymbols) +from sympy.simplify import trigsimp +from sympy.testing.pytest import raises + +q1, q2, q3, q4, q5 = symbols('q1 q2 q3 q4 q5') +N = ReferenceFrame('N') +A = N.orientnew('A', 'Axis', [q1, N.z]) +B = A.orientnew('B', 'Axis', [q2, A.x]) +C = B.orientnew('C', 'Axis', [q3, B.y]) + + +def test_dot(): + assert dot(A.x, A.x) == 1 + assert dot(A.x, A.y) == 0 + assert dot(A.x, A.z) == 0 + + assert dot(A.y, A.x) == 0 + assert dot(A.y, A.y) == 1 + assert dot(A.y, A.z) == 0 + + assert dot(A.z, A.x) == 0 + assert dot(A.z, A.y) == 0 + assert dot(A.z, A.z) == 1 + + +def test_dot_different_frames(): + assert dot(N.x, A.x) == cos(q1) + assert dot(N.x, A.y) == -sin(q1) + assert dot(N.x, A.z) == 0 + assert dot(N.y, A.x) == sin(q1) + assert dot(N.y, A.y) == cos(q1) + assert dot(N.y, A.z) == 0 + assert dot(N.z, A.x) == 0 + assert dot(N.z, A.y) == 0 + assert dot(N.z, A.z) == 1 + + assert trigsimp(dot(N.x, A.x + A.y)) == sqrt(2)*cos(q1 + pi/4) + assert trigsimp(dot(N.x, A.x + A.y)) == trigsimp(dot(A.x + A.y, N.x)) + + assert dot(A.x, C.x) == cos(q3) + assert dot(A.x, C.y) == 0 + assert dot(A.x, C.z) == sin(q3) + assert dot(A.y, C.x) == sin(q2)*sin(q3) + assert dot(A.y, C.y) == cos(q2) + assert dot(A.y, C.z) == -sin(q2)*cos(q3) + assert dot(A.z, C.x) == -cos(q2)*sin(q3) + assert dot(A.z, C.y) == sin(q2) + assert dot(A.z, C.z) == cos(q2)*cos(q3) + + +def test_cross(): + assert cross(A.x, A.x) == 0 + assert cross(A.x, A.y) == A.z + assert cross(A.x, A.z) == -A.y + + assert cross(A.y, A.x) == -A.z + assert cross(A.y, A.y) == 0 + assert cross(A.y, A.z) == A.x + + assert cross(A.z, A.x) == A.y + assert cross(A.z, A.y) == -A.x + assert cross(A.z, A.z) == 0 + + +def test_cross_different_frames(): + assert cross(N.x, A.x) == sin(q1)*A.z + assert cross(N.x, A.y) == cos(q1)*A.z + assert cross(N.x, A.z) == -sin(q1)*A.x - cos(q1)*A.y + assert cross(N.y, A.x) == -cos(q1)*A.z + assert cross(N.y, A.y) == sin(q1)*A.z + assert cross(N.y, A.z) == cos(q1)*A.x - sin(q1)*A.y + assert cross(N.z, A.x) == A.y + assert cross(N.z, A.y) == -A.x + assert cross(N.z, A.z) == 0 + + assert cross(N.x, A.x) == sin(q1)*A.z + assert cross(N.x, A.y) == cos(q1)*A.z + assert cross(N.x, A.x + A.y) == sin(q1)*A.z + cos(q1)*A.z + assert cross(A.x + A.y, N.x) == -sin(q1)*A.z - cos(q1)*A.z + + assert cross(A.x, C.x) == sin(q3)*C.y + assert cross(A.x, C.y) == -sin(q3)*C.x + cos(q3)*C.z + assert cross(A.x, C.z) == -cos(q3)*C.y + assert cross(C.x, A.x) == -sin(q3)*C.y + assert cross(C.y, A.x).express(C).simplify() == sin(q3)*C.x - cos(q3)*C.z + assert cross(C.z, A.x) == cos(q3)*C.y + +def test_operator_match(): + """Test that the output of dot, cross, outer functions match + operator behavior. + """ + A = ReferenceFrame('A') + v = A.x + A.y + d = v | v + zerov = Vector(0) + zerod = Dyadic(0) + + # dot products + assert d & d == dot(d, d) + assert d & zerod == dot(d, zerod) + assert zerod & d == dot(zerod, d) + assert d & v == dot(d, v) + assert v & d == dot(v, d) + assert d & zerov == dot(d, zerov) + assert zerov & d == dot(zerov, d) + raises(TypeError, lambda: dot(d, S.Zero)) + raises(TypeError, lambda: dot(S.Zero, d)) + raises(TypeError, lambda: dot(d, 0)) + raises(TypeError, lambda: dot(0, d)) + assert v & v == dot(v, v) + assert v & zerov == dot(v, zerov) + assert zerov & v == dot(zerov, v) + raises(TypeError, lambda: dot(v, S.Zero)) + raises(TypeError, lambda: dot(S.Zero, v)) + raises(TypeError, lambda: dot(v, 0)) + raises(TypeError, lambda: dot(0, v)) + + # cross products + raises(TypeError, lambda: cross(d, d)) + raises(TypeError, lambda: cross(d, zerod)) + raises(TypeError, lambda: cross(zerod, d)) + assert d ^ v == cross(d, v) + assert v ^ d == cross(v, d) + assert d ^ zerov == cross(d, zerov) + assert zerov ^ d == cross(zerov, d) + assert zerov ^ d == cross(zerov, d) + raises(TypeError, lambda: cross(d, S.Zero)) + raises(TypeError, lambda: cross(S.Zero, d)) + raises(TypeError, lambda: cross(d, 0)) + raises(TypeError, lambda: cross(0, d)) + assert v ^ v == cross(v, v) + assert v ^ zerov == cross(v, zerov) + assert zerov ^ v == cross(zerov, v) + raises(TypeError, lambda: cross(v, S.Zero)) + raises(TypeError, lambda: cross(S.Zero, v)) + raises(TypeError, lambda: cross(v, 0)) + raises(TypeError, lambda: cross(0, v)) + + # outer products + raises(TypeError, lambda: outer(d, d)) + raises(TypeError, lambda: outer(d, zerod)) + raises(TypeError, lambda: outer(zerod, d)) + raises(TypeError, lambda: outer(d, v)) + raises(TypeError, lambda: outer(v, d)) + raises(TypeError, lambda: outer(d, zerov)) + raises(TypeError, lambda: outer(zerov, d)) + raises(TypeError, lambda: outer(zerov, d)) + raises(TypeError, lambda: outer(d, S.Zero)) + raises(TypeError, lambda: outer(S.Zero, d)) + raises(TypeError, lambda: outer(d, 0)) + raises(TypeError, lambda: outer(0, d)) + assert v | v == outer(v, v) + assert v | zerov == outer(v, zerov) + assert zerov | v == outer(zerov, v) + raises(TypeError, lambda: outer(v, S.Zero)) + raises(TypeError, lambda: outer(S.Zero, v)) + raises(TypeError, lambda: outer(v, 0)) + raises(TypeError, lambda: outer(0, v)) + + +def test_express(): + assert express(Vector(0), N) == Vector(0) + assert express(S.Zero, N) is S.Zero + assert express(A.x, C) == cos(q3)*C.x + sin(q3)*C.z + assert express(A.y, C) == sin(q2)*sin(q3)*C.x + cos(q2)*C.y - \ + sin(q2)*cos(q3)*C.z + assert express(A.z, C) == -sin(q3)*cos(q2)*C.x + sin(q2)*C.y + \ + cos(q2)*cos(q3)*C.z + assert express(A.x, N) == cos(q1)*N.x + sin(q1)*N.y + assert express(A.y, N) == -sin(q1)*N.x + cos(q1)*N.y + assert express(A.z, N) == N.z + assert express(A.x, A) == A.x + assert express(A.y, A) == A.y + assert express(A.z, A) == A.z + assert express(A.x, B) == B.x + assert express(A.y, B) == cos(q2)*B.y - sin(q2)*B.z + assert express(A.z, B) == sin(q2)*B.y + cos(q2)*B.z + assert express(A.x, C) == cos(q3)*C.x + sin(q3)*C.z + assert express(A.y, C) == sin(q2)*sin(q3)*C.x + cos(q2)*C.y - \ + sin(q2)*cos(q3)*C.z + assert express(A.z, C) == -sin(q3)*cos(q2)*C.x + sin(q2)*C.y + \ + cos(q2)*cos(q3)*C.z + # Check to make sure UnitVectors get converted properly + assert express(N.x, N) == N.x + assert express(N.y, N) == N.y + assert express(N.z, N) == N.z + assert express(N.x, A) == (cos(q1)*A.x - sin(q1)*A.y) + assert express(N.y, A) == (sin(q1)*A.x + cos(q1)*A.y) + assert express(N.z, A) == A.z + assert express(N.x, B) == (cos(q1)*B.x - sin(q1)*cos(q2)*B.y + + sin(q1)*sin(q2)*B.z) + assert express(N.y, B) == (sin(q1)*B.x + cos(q1)*cos(q2)*B.y - + sin(q2)*cos(q1)*B.z) + assert express(N.z, B) == (sin(q2)*B.y + cos(q2)*B.z) + assert express(N.x, C) == ( + (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*C.x - + sin(q1)*cos(q2)*C.y + + (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*C.z) + assert express(N.y, C) == ( + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*C.x + + cos(q1)*cos(q2)*C.y + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*C.z) + assert express(N.z, C) == (-sin(q3)*cos(q2)*C.x + sin(q2)*C.y + + cos(q2)*cos(q3)*C.z) + + assert express(A.x, N) == (cos(q1)*N.x + sin(q1)*N.y) + assert express(A.y, N) == (-sin(q1)*N.x + cos(q1)*N.y) + assert express(A.z, N) == N.z + assert express(A.x, A) == A.x + assert express(A.y, A) == A.y + assert express(A.z, A) == A.z + assert express(A.x, B) == B.x + assert express(A.y, B) == (cos(q2)*B.y - sin(q2)*B.z) + assert express(A.z, B) == (sin(q2)*B.y + cos(q2)*B.z) + assert express(A.x, C) == (cos(q3)*C.x + sin(q3)*C.z) + assert express(A.y, C) == (sin(q2)*sin(q3)*C.x + cos(q2)*C.y - + sin(q2)*cos(q3)*C.z) + assert express(A.z, C) == (-sin(q3)*cos(q2)*C.x + sin(q2)*C.y + + cos(q2)*cos(q3)*C.z) + + assert express(B.x, N) == (cos(q1)*N.x + sin(q1)*N.y) + assert express(B.y, N) == (-sin(q1)*cos(q2)*N.x + + cos(q1)*cos(q2)*N.y + sin(q2)*N.z) + assert express(B.z, N) == (sin(q1)*sin(q2)*N.x - + sin(q2)*cos(q1)*N.y + cos(q2)*N.z) + assert express(B.x, A) == A.x + assert express(B.y, A) == (cos(q2)*A.y + sin(q2)*A.z) + assert express(B.z, A) == (-sin(q2)*A.y + cos(q2)*A.z) + assert express(B.x, B) == B.x + assert express(B.y, B) == B.y + assert express(B.z, B) == B.z + assert express(B.x, C) == (cos(q3)*C.x + sin(q3)*C.z) + assert express(B.y, C) == C.y + assert express(B.z, C) == (-sin(q3)*C.x + cos(q3)*C.z) + + assert express(C.x, N) == ( + (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*N.x + + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*N.y - + sin(q3)*cos(q2)*N.z) + assert express(C.y, N) == ( + -sin(q1)*cos(q2)*N.x + cos(q1)*cos(q2)*N.y + sin(q2)*N.z) + assert express(C.z, N) == ( + (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*N.x + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*N.y + + cos(q2)*cos(q3)*N.z) + assert express(C.x, A) == (cos(q3)*A.x + sin(q2)*sin(q3)*A.y - + sin(q3)*cos(q2)*A.z) + assert express(C.y, A) == (cos(q2)*A.y + sin(q2)*A.z) + assert express(C.z, A) == (sin(q3)*A.x - sin(q2)*cos(q3)*A.y + + cos(q2)*cos(q3)*A.z) + assert express(C.x, B) == (cos(q3)*B.x - sin(q3)*B.z) + assert express(C.y, B) == B.y + assert express(C.z, B) == (sin(q3)*B.x + cos(q3)*B.z) + assert express(C.x, C) == C.x + assert express(C.y, C) == C.y + assert express(C.z, C) == C.z == (C.z) + + # Check to make sure Vectors get converted back to UnitVectors + assert N.x == express((cos(q1)*A.x - sin(q1)*A.y), N).simplify() + assert N.y == express((sin(q1)*A.x + cos(q1)*A.y), N).simplify() + assert N.x == express((cos(q1)*B.x - sin(q1)*cos(q2)*B.y + + sin(q1)*sin(q2)*B.z), N).simplify() + assert N.y == express((sin(q1)*B.x + cos(q1)*cos(q2)*B.y - + sin(q2)*cos(q1)*B.z), N).simplify() + assert N.z == express((sin(q2)*B.y + cos(q2)*B.z), N).simplify() + + """ + These don't really test our code, they instead test the auto simplification + (or lack thereof) of SymPy. + assert N.x == express(( + (cos(q1)*cos(q3)-sin(q1)*sin(q2)*sin(q3))*C.x - + sin(q1)*cos(q2)*C.y + + (sin(q3)*cos(q1)+sin(q1)*sin(q2)*cos(q3))*C.z), N) + assert N.y == express(( + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*C.x + + cos(q1)*cos(q2)*C.y + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*C.z), N) + assert N.z == express((-sin(q3)*cos(q2)*C.x + sin(q2)*C.y + + cos(q2)*cos(q3)*C.z), N) + """ + + assert A.x == express((cos(q1)*N.x + sin(q1)*N.y), A).simplify() + assert A.y == express((-sin(q1)*N.x + cos(q1)*N.y), A).simplify() + + assert A.y == express((cos(q2)*B.y - sin(q2)*B.z), A).simplify() + assert A.z == express((sin(q2)*B.y + cos(q2)*B.z), A).simplify() + + assert A.x == express((cos(q3)*C.x + sin(q3)*C.z), A).simplify() + + # Tripsimp messes up here too. + #print express((sin(q2)*sin(q3)*C.x + cos(q2)*C.y - + # sin(q2)*cos(q3)*C.z), A) + assert A.y == express((sin(q2)*sin(q3)*C.x + cos(q2)*C.y - + sin(q2)*cos(q3)*C.z), A).simplify() + + assert A.z == express((-sin(q3)*cos(q2)*C.x + sin(q2)*C.y + + cos(q2)*cos(q3)*C.z), A).simplify() + assert B.x == express((cos(q1)*N.x + sin(q1)*N.y), B).simplify() + assert B.y == express((-sin(q1)*cos(q2)*N.x + + cos(q1)*cos(q2)*N.y + sin(q2)*N.z), B).simplify() + + assert B.z == express((sin(q1)*sin(q2)*N.x - + sin(q2)*cos(q1)*N.y + cos(q2)*N.z), B).simplify() + + assert B.y == express((cos(q2)*A.y + sin(q2)*A.z), B).simplify() + assert B.z == express((-sin(q2)*A.y + cos(q2)*A.z), B).simplify() + assert B.x == express((cos(q3)*C.x + sin(q3)*C.z), B).simplify() + assert B.z == express((-sin(q3)*C.x + cos(q3)*C.z), B).simplify() + + """ + assert C.x == express(( + (cos(q1)*cos(q3)-sin(q1)*sin(q2)*sin(q3))*N.x + + (sin(q1)*cos(q3)+sin(q2)*sin(q3)*cos(q1))*N.y - + sin(q3)*cos(q2)*N.z), C) + assert C.y == express(( + -sin(q1)*cos(q2)*N.x + cos(q1)*cos(q2)*N.y + sin(q2)*N.z), C) + assert C.z == express(( + (sin(q3)*cos(q1)+sin(q1)*sin(q2)*cos(q3))*N.x + + (sin(q1)*sin(q3)-sin(q2)*cos(q1)*cos(q3))*N.y + + cos(q2)*cos(q3)*N.z), C) + """ + assert C.x == express((cos(q3)*A.x + sin(q2)*sin(q3)*A.y - + sin(q3)*cos(q2)*A.z), C).simplify() + assert C.y == express((cos(q2)*A.y + sin(q2)*A.z), C).simplify() + assert C.z == express((sin(q3)*A.x - sin(q2)*cos(q3)*A.y + + cos(q2)*cos(q3)*A.z), C).simplify() + assert C.x == express((cos(q3)*B.x - sin(q3)*B.z), C).simplify() + assert C.z == express((sin(q3)*B.x + cos(q3)*B.z), C).simplify() + + +def test_time_derivative(): + #The use of time_derivative for calculations pertaining to scalar + #fields has been tested in test_coordinate_vars in test_essential.py + A = ReferenceFrame('A') + q = dynamicsymbols('q') + qd = dynamicsymbols('q', 1) + B = A.orientnew('B', 'Axis', [q, A.z]) + d = A.x | A.x + assert time_derivative(d, B) == (-qd) * (A.y | A.x) + \ + (-qd) * (A.x | A.y) + d1 = A.x | B.y + assert time_derivative(d1, A) == - qd*(A.x|B.x) + assert time_derivative(d1, B) == - qd*(A.y|B.y) + d2 = A.x | B.x + assert time_derivative(d2, A) == qd*(A.x|B.y) + assert time_derivative(d2, B) == - qd*(A.y|B.x) + d3 = A.x | B.z + assert time_derivative(d3, A) == 0 + assert time_derivative(d3, B) == - qd*(A.y|B.z) + q1, q2, q3, q4 = dynamicsymbols('q1 q2 q3 q4') + q1d, q2d, q3d, q4d = dynamicsymbols('q1 q2 q3 q4', 1) + q1dd, q2dd, q3dd, q4dd = dynamicsymbols('q1 q2 q3 q4', 2) + C = B.orientnew('C', 'Axis', [q4, B.x]) + v1 = q1 * A.z + v2 = q2*A.x + q3*B.y + v3 = q1*A.x + q2*A.y + q3*A.z + assert time_derivative(B.x, C) == 0 + assert time_derivative(B.y, C) == - q4d*B.z + assert time_derivative(B.z, C) == q4d*B.y + assert time_derivative(v1, B) == q1d*A.z + assert time_derivative(v1, C) == - q1*sin(q)*q4d*A.x + \ + q1*cos(q)*q4d*A.y + q1d*A.z + assert time_derivative(v2, A) == q2d*A.x - q3*qd*B.x + q3d*B.y + assert time_derivative(v2, C) == q2d*A.x - q2*qd*A.y + \ + q2*sin(q)*q4d*A.z + q3d*B.y - q3*q4d*B.z + assert time_derivative(v3, B) == (q2*qd + q1d)*A.x + \ + (-q1*qd + q2d)*A.y + q3d*A.z + assert time_derivative(d, C) == - qd*(A.y|A.x) + \ + sin(q)*q4d*(A.z|A.x) - qd*(A.x|A.y) + sin(q)*q4d*(A.x|A.z) + raises(ValueError, lambda: time_derivative(B.x, C, order=0.5)) + raises(ValueError, lambda: time_derivative(B.x, C, order=-1)) + + +def test_get_motion_methods(): + #Initialization + t = dynamicsymbols._t + s1, s2, s3 = symbols('s1 s2 s3') + S1, S2, S3 = symbols('S1 S2 S3') + S4, S5, S6 = symbols('S4 S5 S6') + t1, t2 = symbols('t1 t2') + a, b, c = dynamicsymbols('a b c') + ad, bd, cd = dynamicsymbols('a b c', 1) + a2d, b2d, c2d = dynamicsymbols('a b c', 2) + v0 = S1*N.x + S2*N.y + S3*N.z + v01 = S4*N.x + S5*N.y + S6*N.z + v1 = s1*N.x + s2*N.y + s3*N.z + v2 = a*N.x + b*N.y + c*N.z + v2d = ad*N.x + bd*N.y + cd*N.z + v2dd = a2d*N.x + b2d*N.y + c2d*N.z + #Test position parameter + assert get_motion_params(frame = N) == (0, 0, 0) + assert get_motion_params(N, position=v1) == (0, 0, v1) + assert get_motion_params(N, position=v2) == (v2dd, v2d, v2) + #Test velocity parameter + assert get_motion_params(N, velocity=v1) == (0, v1, v1 * t) + assert get_motion_params(N, velocity=v1, position=v0, timevalue1=t1) == \ + (0, v1, v0 + v1*(t - t1)) + answer = get_motion_params(N, velocity=v1, position=v2, timevalue1=t1) + answer_expected = (0, v1, v1*t - v1*t1 + v2.subs(t, t1)) + assert answer == answer_expected + + answer = get_motion_params(N, velocity=v2, position=v0, timevalue1=t1) + integral_vector = Integral(a, (t, t1, t))*N.x + Integral(b, (t, t1, t))*N.y \ + + Integral(c, (t, t1, t))*N.z + answer_expected = (v2d, v2, v0 + integral_vector) + assert answer == answer_expected + + #Test acceleration parameter + assert get_motion_params(N, acceleration=v1) == \ + (v1, v1 * t, v1 * t**2/2) + assert get_motion_params(N, acceleration=v1, velocity=v0, + position=v2, timevalue1=t1, timevalue2=t2) == \ + (v1, (v0 + v1*t - v1*t2), + -v0*t1 + v1*t**2/2 + v1*t2*t1 - \ + v1*t1**2/2 + t*(v0 - v1*t2) + \ + v2.subs(t, t1)) + assert get_motion_params(N, acceleration=v1, velocity=v0, + position=v01, timevalue1=t1, timevalue2=t2) == \ + (v1, v0 + v1*t - v1*t2, + -v0*t1 + v01 + v1*t**2/2 + \ + v1*t2*t1 - v1*t1**2/2 + \ + t*(v0 - v1*t2)) + answer = get_motion_params(N, acceleration=a*N.x, velocity=S1*N.x, + position=S2*N.x, timevalue1=t1, timevalue2=t2) + i1 = Integral(a, (t, t2, t)) + answer_expected = (a*N.x, (S1 + i1)*N.x, \ + (S2 + Integral(S1 + i1, (t, t1, t)))*N.x) + assert answer == answer_expected + + +def test_kin_eqs(): + q0, q1, q2, q3 = dynamicsymbols('q0 q1 q2 q3') + q0d, q1d, q2d, q3d = dynamicsymbols('q0 q1 q2 q3', 1) + u1, u2, u3 = dynamicsymbols('u1 u2 u3') + ke = kinematic_equations([u1,u2,u3], [q1,q2,q3], 'body', 313) + assert ke == kinematic_equations([u1,u2,u3], [q1,q2,q3], 'body', '313') + kds = kinematic_equations([u1, u2, u3], [q0, q1, q2, q3], 'quaternion') + assert kds == [-0.5 * q0 * u1 - 0.5 * q2 * u3 + 0.5 * q3 * u2 + q1d, + -0.5 * q0 * u2 + 0.5 * q1 * u3 - 0.5 * q3 * u1 + q2d, + -0.5 * q0 * u3 - 0.5 * q1 * u2 + 0.5 * q2 * u1 + q3d, + 0.5 * q1 * u1 + 0.5 * q2 * u2 + 0.5 * q3 * u3 + q0d] + raises(ValueError, lambda: kinematic_equations([u1, u2, u3], [q0, q1, q2], 'quaternion')) + raises(ValueError, lambda: kinematic_equations([u1, u2, u3], [q0, q1, q2, q3], 'quaternion', '123')) + raises(ValueError, lambda: kinematic_equations([u1, u2, u3], [q0, q1, q2, q3], 'foo')) + raises(TypeError, lambda: kinematic_equations(u1, [q0, q1, q2, q3], 'quaternion')) + raises(TypeError, lambda: kinematic_equations([u1], [q0, q1, q2, q3], 'quaternion')) + raises(TypeError, lambda: kinematic_equations([u1, u2, u3], q0, 'quaternion')) + raises(ValueError, lambda: kinematic_equations([u1, u2, u3], [q0, q1, q2, q3], 'body')) + raises(ValueError, lambda: kinematic_equations([u1, u2, u3], [q0, q1, q2, q3], 'space')) + raises(ValueError, lambda: kinematic_equations([u1, u2, u3], [q0, q1, q2], 'body', '222')) + assert kinematic_equations([0, 0, 0], [q0, q1, q2], 'space') == [S.Zero, S.Zero, S.Zero] + + +def test_partial_velocity(): + q1, q2, q3, u1, u2, u3 = dynamicsymbols('q1 q2 q3 u1 u2 u3') + u4, u5 = dynamicsymbols('u4, u5') + r = symbols('r') + + N = ReferenceFrame('N') + Y = N.orientnew('Y', 'Axis', [q1, N.z]) + L = Y.orientnew('L', 'Axis', [q2, Y.x]) + R = L.orientnew('R', 'Axis', [q3, L.y]) + R.set_ang_vel(N, u1 * L.x + u2 * L.y + u3 * L.z) + + C = Point('C') + C.set_vel(N, u4 * L.x + u5 * (Y.z ^ L.x)) + Dmc = C.locatenew('Dmc', r * L.z) + Dmc.v2pt_theory(C, N, R) + + vel_list = [Dmc.vel(N), C.vel(N), R.ang_vel_in(N)] + u_list = [u1, u2, u3, u4, u5] + assert (partial_velocity(vel_list, u_list, N) == + [[- r*L.y, r*L.x, 0, L.x, cos(q2)*L.y - sin(q2)*L.z], + [0, 0, 0, L.x, cos(q2)*L.y - sin(q2)*L.z], + [L.x, L.y, L.z, 0, 0]]) + + # Make sure that partial velocities can be computed regardless if the + # orientation between frames is defined or not. + A = ReferenceFrame('A') + B = ReferenceFrame('B') + v = u4 * A.x + u5 * B.y + assert partial_velocity((v, ), (u4, u5), A) == [[A.x, B.y]] + + raises(TypeError, lambda: partial_velocity(Dmc.vel(N), u_list, N)) + raises(TypeError, lambda: partial_velocity(vel_list, u1, N)) + +def test_dynamicsymbols(): + #Tests to check the assumptions applied to dynamicsymbols + f1 = dynamicsymbols('f1') + f2 = dynamicsymbols('f2', real=True) + f3 = dynamicsymbols('f3', positive=True) + f4, f5 = dynamicsymbols('f4,f5', commutative=False) + f6 = dynamicsymbols('f6', integer=True) + assert f1.is_real is None + assert f2.is_real + assert f3.is_positive + assert f4*f5 != f5*f4 + assert f6.is_integer diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_output.py b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_output.py new file mode 100644 index 0000000000000000000000000000000000000000..e02f3e5962bc23bbb62929e343a5afac574a2570 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_output.py @@ -0,0 +1,75 @@ +from sympy.core.singleton import S +from sympy.physics.vector import Vector, ReferenceFrame, Dyadic +from sympy.testing.pytest import raises + +A = ReferenceFrame('A') + + +def test_output_type(): + A = ReferenceFrame('A') + v = A.x + A.y + d = v | v + zerov = Vector(0) + zerod = Dyadic(0) + + # dot products + assert isinstance(d & d, Dyadic) + assert isinstance(d & zerod, Dyadic) + assert isinstance(zerod & d, Dyadic) + assert isinstance(d & v, Vector) + assert isinstance(v & d, Vector) + assert isinstance(d & zerov, Vector) + assert isinstance(zerov & d, Vector) + raises(TypeError, lambda: d & S.Zero) + raises(TypeError, lambda: S.Zero & d) + raises(TypeError, lambda: d & 0) + raises(TypeError, lambda: 0 & d) + assert not isinstance(v & v, (Vector, Dyadic)) + assert not isinstance(v & zerov, (Vector, Dyadic)) + assert not isinstance(zerov & v, (Vector, Dyadic)) + raises(TypeError, lambda: v & S.Zero) + raises(TypeError, lambda: S.Zero & v) + raises(TypeError, lambda: v & 0) + raises(TypeError, lambda: 0 & v) + + # cross products + raises(TypeError, lambda: d ^ d) + raises(TypeError, lambda: d ^ zerod) + raises(TypeError, lambda: zerod ^ d) + assert isinstance(d ^ v, Dyadic) + assert isinstance(v ^ d, Dyadic) + assert isinstance(d ^ zerov, Dyadic) + assert isinstance(zerov ^ d, Dyadic) + assert isinstance(zerov ^ d, Dyadic) + raises(TypeError, lambda: d ^ S.Zero) + raises(TypeError, lambda: S.Zero ^ d) + raises(TypeError, lambda: d ^ 0) + raises(TypeError, lambda: 0 ^ d) + assert isinstance(v ^ v, Vector) + assert isinstance(v ^ zerov, Vector) + assert isinstance(zerov ^ v, Vector) + raises(TypeError, lambda: v ^ S.Zero) + raises(TypeError, lambda: S.Zero ^ v) + raises(TypeError, lambda: v ^ 0) + raises(TypeError, lambda: 0 ^ v) + + # outer products + raises(TypeError, lambda: d | d) + raises(TypeError, lambda: d | zerod) + raises(TypeError, lambda: zerod | d) + raises(TypeError, lambda: d | v) + raises(TypeError, lambda: v | d) + raises(TypeError, lambda: d | zerov) + raises(TypeError, lambda: zerov | d) + raises(TypeError, lambda: zerov | d) + raises(TypeError, lambda: d | S.Zero) + raises(TypeError, lambda: S.Zero | d) + raises(TypeError, lambda: d | 0) + raises(TypeError, lambda: 0 | d) + assert isinstance(v | v, Dyadic) + assert isinstance(v | zerov, Dyadic) + assert isinstance(zerov | v, Dyadic) + raises(TypeError, lambda: v | S.Zero) + raises(TypeError, lambda: S.Zero | v) + raises(TypeError, lambda: v | 0) + raises(TypeError, lambda: 0 | v) diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_point.py b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_point.py new file mode 100644 index 0000000000000000000000000000000000000000..0e0c8b092ef61c590d3c713cef25feb3e64051c6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_point.py @@ -0,0 +1,382 @@ +from sympy.physics.vector import dynamicsymbols, Point, ReferenceFrame +from sympy.testing.pytest import raises, ignore_warnings +import warnings + +def test_point_v1pt_theorys(): + q, q2 = dynamicsymbols('q q2') + qd, q2d = dynamicsymbols('q q2', 1) + qdd, q2dd = dynamicsymbols('q q2', 2) + N = ReferenceFrame('N') + B = ReferenceFrame('B') + B.set_ang_vel(N, qd * B.z) + O = Point('O') + P = O.locatenew('P', B.x) + P.set_vel(B, 0) + O.set_vel(N, 0) + assert P.v1pt_theory(O, N, B) == qd * B.y + O.set_vel(N, N.x) + assert P.v1pt_theory(O, N, B) == N.x + qd * B.y + P.set_vel(B, B.z) + assert P.v1pt_theory(O, N, B) == B.z + N.x + qd * B.y + + +def test_point_a1pt_theorys(): + q, q2 = dynamicsymbols('q q2') + qd, q2d = dynamicsymbols('q q2', 1) + qdd, q2dd = dynamicsymbols('q q2', 2) + N = ReferenceFrame('N') + B = ReferenceFrame('B') + B.set_ang_vel(N, qd * B.z) + O = Point('O') + P = O.locatenew('P', B.x) + P.set_vel(B, 0) + O.set_vel(N, 0) + assert P.a1pt_theory(O, N, B) == -(qd**2) * B.x + qdd * B.y + P.set_vel(B, q2d * B.z) + assert P.a1pt_theory(O, N, B) == -(qd**2) * B.x + qdd * B.y + q2dd * B.z + O.set_vel(N, q2d * B.x) + assert P.a1pt_theory(O, N, B) == ((q2dd - qd**2) * B.x + (q2d * qd + qdd) * B.y + + q2dd * B.z) + + +def test_point_v2pt_theorys(): + q = dynamicsymbols('q') + qd = dynamicsymbols('q', 1) + N = ReferenceFrame('N') + B = N.orientnew('B', 'Axis', [q, N.z]) + O = Point('O') + P = O.locatenew('P', 0) + O.set_vel(N, 0) + assert P.v2pt_theory(O, N, B) == 0 + P = O.locatenew('P', B.x) + assert P.v2pt_theory(O, N, B) == (qd * B.z ^ B.x) + O.set_vel(N, N.x) + assert P.v2pt_theory(O, N, B) == N.x + qd * B.y + + +def test_point_a2pt_theorys(): + q = dynamicsymbols('q') + qd = dynamicsymbols('q', 1) + qdd = dynamicsymbols('q', 2) + N = ReferenceFrame('N') + B = N.orientnew('B', 'Axis', [q, N.z]) + O = Point('O') + P = O.locatenew('P', 0) + O.set_vel(N, 0) + assert P.a2pt_theory(O, N, B) == 0 + P.set_pos(O, B.x) + assert P.a2pt_theory(O, N, B) == (-qd**2) * B.x + (qdd) * B.y + + +def test_point_funcs(): + q, q2 = dynamicsymbols('q q2') + qd, q2d = dynamicsymbols('q q2', 1) + qdd, q2dd = dynamicsymbols('q q2', 2) + N = ReferenceFrame('N') + B = ReferenceFrame('B') + B.set_ang_vel(N, 5 * B.y) + O = Point('O') + P = O.locatenew('P', q * B.x + q2 * B.y) + assert P.pos_from(O) == q * B.x + q2 * B.y + P.set_vel(B, qd * B.x + q2d * B.y) + assert P.vel(B) == qd * B.x + q2d * B.y + O.set_vel(N, 0) + assert O.vel(N) == 0 + assert P.a1pt_theory(O, N, B) == ((-25 * q + qdd) * B.x + (q2dd) * B.y + + (-10 * qd) * B.z) + + B = N.orientnew('B', 'Axis', [q, N.z]) + O = Point('O') + P = O.locatenew('P', 10 * B.x) + O.set_vel(N, 5 * N.x) + assert O.vel(N) == 5 * N.x + assert P.a2pt_theory(O, N, B) == (-10 * qd**2) * B.x + (10 * qdd) * B.y + + B.set_ang_vel(N, 5 * B.y) + O = Point('O') + P = O.locatenew('P', q * B.x + q2 * B.y) + P.set_vel(B, qd * B.x + q2d * B.y) + O.set_vel(N, 0) + assert P.v1pt_theory(O, N, B) == qd * B.x + q2d * B.y - 5 * q * B.z + + +def test_point_pos(): + q = dynamicsymbols('q') + N = ReferenceFrame('N') + B = N.orientnew('B', 'Axis', [q, N.z]) + O = Point('O') + P = O.locatenew('P', 10 * N.x + 5 * B.x) + assert P.pos_from(O) == 10 * N.x + 5 * B.x + Q = P.locatenew('Q', 10 * N.y + 5 * B.y) + assert Q.pos_from(P) == 10 * N.y + 5 * B.y + assert Q.pos_from(O) == 10 * N.x + 10 * N.y + 5 * B.x + 5 * B.y + assert O.pos_from(Q) == -10 * N.x - 10 * N.y - 5 * B.x - 5 * B.y + +def test_point_partial_velocity(): + + N = ReferenceFrame('N') + A = ReferenceFrame('A') + + p = Point('p') + + u1, u2 = dynamicsymbols('u1, u2') + + p.set_vel(N, u1 * A.x + u2 * N.y) + + assert p.partial_velocity(N, u1) == A.x + assert p.partial_velocity(N, u1, u2) == (A.x, N.y) + raises(ValueError, lambda: p.partial_velocity(A, u1)) + +def test_point_vel(): #Basic functionality + q1, q2 = dynamicsymbols('q1 q2') + N = ReferenceFrame('N') + B = ReferenceFrame('B') + Q = Point('Q') + O = Point('O') + Q.set_pos(O, q1 * N.x) + raises(ValueError , lambda: Q.vel(N)) # Velocity of O in N is not defined + O.set_vel(N, q2 * N.y) + assert O.vel(N) == q2 * N.y + raises(ValueError , lambda : O.vel(B)) #Velocity of O is not defined in B + +def test_auto_point_vel(): + t = dynamicsymbols._t + q1, q2 = dynamicsymbols('q1 q2') + N = ReferenceFrame('N') + B = ReferenceFrame('B') + O = Point('O') + Q = Point('Q') + Q.set_pos(O, q1 * N.x) + O.set_vel(N, q2 * N.y) + assert Q.vel(N) == q1.diff(t) * N.x + q2 * N.y # Velocity of Q using O + P1 = Point('P1') + P1.set_pos(O, q1 * B.x) + P2 = Point('P2') + P2.set_pos(P1, q2 * B.z) + raises(ValueError, lambda : P2.vel(B)) # O's velocity is defined in different frame, and no + #point in between has its velocity defined + raises(ValueError, lambda: P2.vel(N)) # Velocity of O not defined in N + +def test_auto_point_vel_multiple_point_path(): + t = dynamicsymbols._t + q1, q2 = dynamicsymbols('q1 q2') + B = ReferenceFrame('B') + P = Point('P') + P.set_vel(B, q1 * B.x) + P1 = Point('P1') + P1.set_pos(P, q2 * B.y) + P1.set_vel(B, q1 * B.z) + P2 = Point('P2') + P2.set_pos(P1, q1 * B.z) + P3 = Point('P3') + P3.set_pos(P2, 10 * q1 * B.y) + assert P3.vel(B) == 10 * q1.diff(t) * B.y + (q1 + q1.diff(t)) * B.z + +def test_auto_vel_dont_overwrite(): + t = dynamicsymbols._t + q1, q2, u1 = dynamicsymbols('q1, q2, u1') + N = ReferenceFrame('N') + P = Point('P1') + P.set_vel(N, u1 * N.x) + P1 = Point('P1') + P1.set_pos(P, q2 * N.y) + assert P1.vel(N) == q2.diff(t) * N.y + u1 * N.x + assert P.vel(N) == u1 * N.x + P1.set_vel(N, u1 * N.z) + assert P1.vel(N) == u1 * N.z + +def test_auto_point_vel_if_tree_has_vel_but_inappropriate_pos_vector(): + q1, q2 = dynamicsymbols('q1 q2') + B = ReferenceFrame('B') + S = ReferenceFrame('S') + P = Point('P') + P.set_vel(B, q1 * B.x) + P1 = Point('P1') + P1.set_pos(P, S.y) + raises(ValueError, lambda : P1.vel(B)) # P1.pos_from(P) can't be expressed in B + raises(ValueError, lambda : P1.vel(S)) # P.vel(S) not defined + +def test_auto_point_vel_shortest_path(): + t = dynamicsymbols._t + q1, q2, u1, u2 = dynamicsymbols('q1 q2 u1 u2') + B = ReferenceFrame('B') + P = Point('P') + P.set_vel(B, u1 * B.x) + P1 = Point('P1') + P1.set_pos(P, q2 * B.y) + P1.set_vel(B, q1 * B.z) + P2 = Point('P2') + P2.set_pos(P1, q1 * B.z) + P3 = Point('P3') + P3.set_pos(P2, 10 * q1 * B.y) + P4 = Point('P4') + P4.set_pos(P3, q1 * B.x) + O = Point('O') + O.set_vel(B, u2 * B.y) + O1 = Point('O1') + O1.set_pos(O, q2 * B.z) + P4.set_pos(O1, q1 * B.x + q2 * B.z) + with warnings.catch_warnings(): #There are two possible paths in this point tree, thus a warning is raised + warnings.simplefilter('error') + with ignore_warnings(UserWarning): + assert P4.vel(B) == q1.diff(t) * B.x + u2 * B.y + 2 * q2.diff(t) * B.z + +def test_auto_point_vel_connected_frames(): + t = dynamicsymbols._t + q, q1, q2, u = dynamicsymbols('q q1 q2 u') + N = ReferenceFrame('N') + B = ReferenceFrame('B') + O = Point('O') + O.set_vel(N, u * N.x) + P = Point('P') + P.set_pos(O, q1 * N.x + q2 * B.y) + raises(ValueError, lambda: P.vel(N)) + N.orient(B, 'Axis', (q, B.x)) + assert P.vel(N) == (u + q1.diff(t)) * N.x + q2.diff(t) * B.y - q2 * q.diff(t) * B.z + +def test_auto_point_vel_multiple_paths_warning_arises(): + q, u = dynamicsymbols('q u') + N = ReferenceFrame('N') + O = Point('O') + P = Point('P') + Q = Point('Q') + R = Point('R') + P.set_vel(N, u * N.x) + Q.set_vel(N, u *N.y) + R.set_vel(N, u * N.z) + O.set_pos(P, q * N.z) + O.set_pos(Q, q * N.y) + O.set_pos(R, q * N.x) + with warnings.catch_warnings(): #There are two possible paths in this point tree, thus a warning is raised + warnings.simplefilter("error") + raises(UserWarning ,lambda: O.vel(N)) + +def test_auto_vel_cyclic_warning_arises(): + P = Point('P') + P1 = Point('P1') + P2 = Point('P2') + P3 = Point('P3') + N = ReferenceFrame('N') + P.set_vel(N, N.x) + P1.set_pos(P, N.x) + P2.set_pos(P1, N.y) + P3.set_pos(P2, N.z) + P1.set_pos(P3, N.x + N.y) + with warnings.catch_warnings(): #The path is cyclic at P1, thus a warning is raised + warnings.simplefilter("error") + raises(UserWarning ,lambda: P2.vel(N)) + +def test_auto_vel_cyclic_warning_msg(): + P = Point('P') + P1 = Point('P1') + P2 = Point('P2') + P3 = Point('P3') + N = ReferenceFrame('N') + P.set_vel(N, N.x) + P1.set_pos(P, N.x) + P2.set_pos(P1, N.y) + P3.set_pos(P2, N.z) + P1.set_pos(P3, N.x + N.y) + with warnings.catch_warnings(record = True) as w: #The path is cyclic at P1, thus a warning is raised + warnings.simplefilter("always") + P2.vel(N) + msg = str(w[-1].message).replace("\n", " ") + assert issubclass(w[-1].category, UserWarning) + assert 'Kinematic loops are defined among the positions of points. This is likely not desired and may cause errors in your calculations.' in msg + +def test_auto_vel_multiple_path_warning_msg(): + N = ReferenceFrame('N') + O = Point('O') + P = Point('P') + Q = Point('Q') + P.set_vel(N, N.x) + Q.set_vel(N, N.y) + O.set_pos(P, N.z) + O.set_pos(Q, N.y) + with warnings.catch_warnings(record = True) as w: #There are two possible paths in this point tree, thus a warning is raised + warnings.simplefilter("always") + O.vel(N) + msg = str(w[-1].message).replace("\n", " ") + assert issubclass(w[-1].category, UserWarning) + assert 'Velocity' in msg + assert 'automatically calculated based on point' in msg + assert 'Velocities from these points are not necessarily the same. This may cause errors in your calculations.' in msg + +def test_auto_vel_derivative(): + q1, q2 = dynamicsymbols('q1:3') + u1, u2 = dynamicsymbols('u1:3', 1) + A = ReferenceFrame('A') + B = ReferenceFrame('B') + C = ReferenceFrame('C') + B.orient_axis(A, A.z, q1) + B.set_ang_vel(A, u1 * A.z) + C.orient_axis(B, B.z, q2) + C.set_ang_vel(B, u2 * B.z) + + Am = Point('Am') + Am.set_vel(A, 0) + Bm = Point('Bm') + Bm.set_pos(Am, B.x) + Bm.set_vel(B, 0) + Bm.set_vel(C, 0) + Cm = Point('Cm') + Cm.set_pos(Bm, C.x) + Cm.set_vel(C, 0) + temp = Cm._vel_dict.copy() + assert Cm.vel(A) == (u1 * B.y + (u1 + u2) * C.y) + Cm._vel_dict = temp + Cm.v2pt_theory(Bm, B, C) + assert Cm.vel(A) == (u1 * B.y + (u1 + u2) * C.y) + +def test_auto_point_acc_zero_vel(): + N = ReferenceFrame('N') + O = Point('O') + O.set_vel(N, 0) + assert O.acc(N) == 0 * N.x + +def test_auto_point_acc_compute_vel(): + t = dynamicsymbols._t + q1 = dynamicsymbols('q1') + N = ReferenceFrame('N') + A = ReferenceFrame('A') + A.orient_axis(N, N.z, q1) + + O = Point('O') + O.set_vel(N, 0) + P = Point('P') + P.set_pos(O, A.x) + assert P.acc(N) == -q1.diff(t) ** 2 * A.x + q1.diff(t, 2) * A.y + +def test_auto_acc_derivative(): + # Tests whether the Point.acc method gives the correct acceleration of the + # end point of two linkages in series, while getting minimal information. + q1, q2 = dynamicsymbols('q1:3') + u1, u2 = dynamicsymbols('q1:3', 1) + v1, v2 = dynamicsymbols('q1:3', 2) + A = ReferenceFrame('A') + B = ReferenceFrame('B') + C = ReferenceFrame('C') + B.orient_axis(A, A.z, q1) + C.orient_axis(B, B.z, q2) + + Am = Point('Am') + Am.set_vel(A, 0) + Bm = Point('Bm') + Bm.set_pos(Am, B.x) + Bm.set_vel(B, 0) + Bm.set_vel(C, 0) + Cm = Point('Cm') + Cm.set_pos(Bm, C.x) + Cm.set_vel(C, 0) + + # Copy dictionaries to later check the calculation using the 2pt_theories + Bm_vel_dict, Cm_vel_dict = Bm._vel_dict.copy(), Cm._vel_dict.copy() + Bm_acc_dict, Cm_acc_dict = Bm._acc_dict.copy(), Cm._acc_dict.copy() + check = -u1 ** 2 * B.x + v1 * B.y - (u1 + u2) ** 2 * C.x + (v1 + v2) * C.y + assert Cm.acc(A) == check + Bm._vel_dict, Cm._vel_dict = Bm_vel_dict, Cm_vel_dict + Bm._acc_dict, Cm._acc_dict = Bm_acc_dict, Cm_acc_dict + Bm.v2pt_theory(Am, A, B) + Cm.v2pt_theory(Bm, A, C) + Bm.a2pt_theory(Am, A, B) + assert Cm.a2pt_theory(Bm, A, C) == check diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_printing.py b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_printing.py new file mode 100644 index 0000000000000000000000000000000000000000..0930fe9d0bc6e2fcc60b34f37215fdb19e32fdc4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_printing.py @@ -0,0 +1,353 @@ +# -*- coding: utf-8 -*- + +from sympy.core.function import Function +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (asin, cos, sin) +from sympy.physics.vector import ReferenceFrame, dynamicsymbols, Dyadic +from sympy.physics.vector.printing import (VectorLatexPrinter, vpprint, + vsprint, vsstrrepr, vlatex) + + +a, b, c = symbols('a, b, c') +alpha, omega, beta = dynamicsymbols('alpha, omega, beta') + +A = ReferenceFrame('A') +N = ReferenceFrame('N') + +v = a ** 2 * N.x + b * N.y + c * sin(alpha) * N.z +w = alpha * N.x + sin(omega) * N.y + alpha * beta * N.z +ww = alpha * N.x + asin(omega) * N.y - alpha.diff() * beta * N.z +o = a/b * N.x + (c+b)/a * N.y + c**2/b * N.z + +y = a ** 2 * (N.x | N.y) + b * (N.y | N.y) + c * sin(alpha) * (N.z | N.y) +x = alpha * (N.x | N.x) + sin(omega) * (N.y | N.z) + alpha * beta * (N.z | N.x) +xx = N.x | (-N.y - N.z) +xx2 = N.x | (N.y + N.z) + +def ascii_vpretty(expr): + return vpprint(expr, use_unicode=False, wrap_line=False) + + +def unicode_vpretty(expr): + return vpprint(expr, use_unicode=True, wrap_line=False) + + +def test_latex_printer(): + r = Function('r')('t') + assert VectorLatexPrinter().doprint(r ** 2) == "r^{2}" + r2 = Function('r^2')('t') + assert VectorLatexPrinter().doprint(r2.diff()) == r'\dot{r^{2}}' + ra = Function('r__a')('t') + assert VectorLatexPrinter().doprint(ra.diff().diff()) == r'\ddot{r^{a}}' + + +def test_vector_pretty_print(): + + # TODO : The unit vectors should print with subscripts but they just + # print as `n_x` instead of making `x` a subscript with unicode. + + # TODO : The pretty print division does not print correctly here: + # w = alpha * N.x + sin(omega) * N.y + alpha / beta * N.z + + expected = """\ + 2 \n\ +a n_x + b n_y + c*sin(alpha) n_z\ +""" + uexpected = """\ + 2 \n\ +a n_x + b n_y + c⋅sin(α) n_z\ +""" + + assert ascii_vpretty(v) == expected + assert unicode_vpretty(v) == uexpected + + expected = 'alpha n_x + sin(omega) n_y + alpha*beta n_z' + uexpected = 'α n_x + sin(ω) n_y + α⋅β n_z' + + assert ascii_vpretty(w) == expected + assert unicode_vpretty(w) == uexpected + + expected = """\ + 2 \n\ +a b + c c \n\ +- n_x + ----- n_y + -- n_z\n\ +b a b \ +""" + uexpected = """\ + 2 \n\ +a b + c c \n\ +─ n_x + ───── n_y + ── n_z\n\ +b a b \ +""" + + assert ascii_vpretty(o) == expected + assert unicode_vpretty(o) == uexpected + + # https://github.com/sympy/sympy/issues/26731 + assert ascii_vpretty(-A.x) == '-a_x' + assert unicode_vpretty(-A.x) == '-a_x' + + # https://github.com/sympy/sympy/issues/26799 + assert ascii_vpretty(0*A.x) == '0' + assert unicode_vpretty(0*A.x) == '0' + + +def test_vector_latex(): + + a, b, c, d, omega = symbols('a, b, c, d, omega') + + v = (a ** 2 + b / c) * A.x + sqrt(d) * A.y + cos(omega) * A.z + + assert vlatex(v) == (r'(a^{2} + \frac{b}{c})\mathbf{\hat{a}_x} + ' + r'\sqrt{d}\mathbf{\hat{a}_y} + ' + r'\cos{\left(\omega \right)}' + r'\mathbf{\hat{a}_z}') + + theta, omega, alpha, q = dynamicsymbols('theta, omega, alpha, q') + + v = theta * A.x + omega * omega * A.y + (q * alpha) * A.z + + assert vlatex(v) == (r'\theta\mathbf{\hat{a}_x} + ' + r'\omega^{2}\mathbf{\hat{a}_y} + ' + r'\alpha q\mathbf{\hat{a}_z}') + + phi1, phi2, phi3 = dynamicsymbols('phi1, phi2, phi3') + theta1, theta2, theta3 = symbols('theta1, theta2, theta3') + + v = (sin(theta1) * A.x + + cos(phi1) * cos(phi2) * A.y + + cos(theta1 + phi3) * A.z) + + assert vlatex(v) == (r'\sin{\left(\theta_{1} \right)}' + r'\mathbf{\hat{a}_x} + \cos{' + r'\left(\phi_{1} \right)} \cos{' + r'\left(\phi_{2} \right)}\mathbf{\hat{a}_y} + ' + r'\cos{\left(\theta_{1} + ' + r'\phi_{3} \right)}\mathbf{\hat{a}_z}') + + N = ReferenceFrame('N') + + a, b, c, d, omega = symbols('a, b, c, d, omega') + + v = (a ** 2 + b / c) * N.x + sqrt(d) * N.y + cos(omega) * N.z + + expected = (r'(a^{2} + \frac{b}{c})\mathbf{\hat{n}_x} + ' + r'\sqrt{d}\mathbf{\hat{n}_y} + ' + r'\cos{\left(\omega \right)}' + r'\mathbf{\hat{n}_z}') + + assert vlatex(v) == expected + + # Try custom unit vectors. + + N = ReferenceFrame('N', latexs=(r'\hat{i}', r'\hat{j}', r'\hat{k}')) + + v = (a ** 2 + b / c) * N.x + sqrt(d) * N.y + cos(omega) * N.z + + expected = (r'(a^{2} + \frac{b}{c})\hat{i} + ' + r'\sqrt{d}\hat{j} + ' + r'\cos{\left(\omega \right)}\hat{k}') + assert vlatex(v) == expected + + expected = r'\alpha\mathbf{\hat{n}_x} + \operatorname{asin}{\left(\omega ' \ + r'\right)}\mathbf{\hat{n}_y} - \beta \dot{\alpha}\mathbf{\hat{n}_z}' + assert vlatex(ww) == expected + + expected = r'- \mathbf{\hat{n}_x}\otimes \mathbf{\hat{n}_y} - ' \ + r'\mathbf{\hat{n}_x}\otimes \mathbf{\hat{n}_z}' + assert vlatex(xx) == expected + + expected = r'\mathbf{\hat{n}_x}\otimes \mathbf{\hat{n}_y} + ' \ + r'\mathbf{\hat{n}_x}\otimes \mathbf{\hat{n}_z}' + assert vlatex(xx2) == expected + + +def test_vector_latex_arguments(): + assert vlatex(N.x * 3.0, full_prec=False) == r'3.0\mathbf{\hat{n}_x}' + assert vlatex(N.x * 3.0, full_prec=True) == r'3.00000000000000\mathbf{\hat{n}_x}' + + +def test_vector_latex_with_functions(): + + N = ReferenceFrame('N') + + omega, alpha = dynamicsymbols('omega, alpha') + + v = omega.diff() * N.x + + assert vlatex(v) == r'\dot{\omega}\mathbf{\hat{n}_x}' + + v = omega.diff() ** alpha * N.x + + assert vlatex(v) == (r'\dot{\omega}^{\alpha}' + r'\mathbf{\hat{n}_x}') + + +def test_dyadic_pretty_print(): + + expected = """\ + 2 +a n_x|n_y + b n_y|n_y + c*sin(alpha) n_z|n_y\ +""" + + uexpected = """\ + 2 +a n_x⊗n_y + b n_y⊗n_y + c⋅sin(α) n_z⊗n_y\ +""" + assert ascii_vpretty(y) == expected + assert unicode_vpretty(y) == uexpected + + expected = 'alpha n_x|n_x + sin(omega) n_y|n_z + alpha*beta n_z|n_x' + uexpected = 'α n_x⊗n_x + sin(ω) n_y⊗n_z + α⋅β n_z⊗n_x' + assert ascii_vpretty(x) == expected + assert unicode_vpretty(x) == uexpected + + assert ascii_vpretty(Dyadic([])) == '0' + assert unicode_vpretty(Dyadic([])) == '0' + + assert ascii_vpretty(xx) == '- n_x|n_y - n_x|n_z' + assert unicode_vpretty(xx) == '- n_x⊗n_y - n_x⊗n_z' + + assert ascii_vpretty(xx2) == 'n_x|n_y + n_x|n_z' + assert unicode_vpretty(xx2) == 'n_x⊗n_y + n_x⊗n_z' + + +def test_dyadic_latex(): + + expected = (r'a^{2}\mathbf{\hat{n}_x}\otimes \mathbf{\hat{n}_y} + ' + r'b\mathbf{\hat{n}_y}\otimes \mathbf{\hat{n}_y} + ' + r'c \sin{\left(\alpha \right)}' + r'\mathbf{\hat{n}_z}\otimes \mathbf{\hat{n}_y}') + + assert vlatex(y) == expected + + expected = (r'\alpha\mathbf{\hat{n}_x}\otimes \mathbf{\hat{n}_x} + ' + r'\sin{\left(\omega \right)}\mathbf{\hat{n}_y}' + r'\otimes \mathbf{\hat{n}_z} + ' + r'\alpha \beta\mathbf{\hat{n}_z}\otimes \mathbf{\hat{n}_x}') + + assert vlatex(x) == expected + + assert vlatex(Dyadic([])) == '0' + + +def test_dyadic_str(): + assert vsprint(Dyadic([])) == '0' + assert vsprint(y) == 'a**2*(N.x|N.y) + b*(N.y|N.y) + c*sin(alpha)*(N.z|N.y)' + assert vsprint(x) == 'alpha*(N.x|N.x) + sin(omega)*(N.y|N.z) + alpha*beta*(N.z|N.x)' + assert vsprint(ww) == "alpha*N.x + asin(omega)*N.y - beta*alpha'*N.z" + assert vsprint(xx) == '- (N.x|N.y) - (N.x|N.z)' + assert vsprint(xx2) == '(N.x|N.y) + (N.x|N.z)' + + +def test_vlatex(): # vlatex is broken #12078 + from sympy.physics.vector import vlatex + + x = symbols('x') + J = symbols('J') + + f = Function('f') + g = Function('g') + h = Function('h') + + expected = r'J \left(\frac{d}{d x} g{\left(x \right)} - \frac{d}{d x} h{\left(x \right)}\right)' + + expr = J*f(x).diff(x).subs(f(x), g(x)-h(x)) + + assert vlatex(expr) == expected + + +def test_issue_13354(): + """ + Test for proper pretty printing of physics vectors with ADD + instances in arguments. + + Test is exactly the one suggested in the original bug report by + @moorepants. + """ + + a, b, c = symbols('a, b, c') + A = ReferenceFrame('A') + v = a * A.x + b * A.y + c * A.z + w = b * A.x + c * A.y + a * A.z + z = w + v + + expected = """(a + b) a_x + (b + c) a_y + (a + c) a_z""" + + assert ascii_vpretty(z) == expected + + +def test_vector_derivative_printing(): + # First order + v = omega.diff() * N.x + assert unicode_vpretty(v) == 'ω̇ n_x' + assert ascii_vpretty(v) == "omega'(t) n_x" + + # Second order + v = omega.diff().diff() * N.x + + assert vlatex(v) == r'\ddot{\omega}\mathbf{\hat{n}_x}' + assert unicode_vpretty(v) == 'ω̈ n_x' + assert ascii_vpretty(v) == "omega''(t) n_x" + + # Third order + v = omega.diff().diff().diff() * N.x + + assert vlatex(v) == r'\dddot{\omega}\mathbf{\hat{n}_x}' + assert unicode_vpretty(v) == 'ω⃛ n_x' + assert ascii_vpretty(v) == "omega'''(t) n_x" + + # Fourth order + v = omega.diff().diff().diff().diff() * N.x + + assert vlatex(v) == r'\ddddot{\omega}\mathbf{\hat{n}_x}' + assert unicode_vpretty(v) == 'ω⃜ n_x' + assert ascii_vpretty(v) == "omega''''(t) n_x" + + # Fifth order + v = omega.diff().diff().diff().diff().diff() * N.x + + assert vlatex(v) == r'\frac{d^{5}}{d t^{5}} \omega\mathbf{\hat{n}_x}' + expected = '''\ + 5 \n\ +d \n\ +---(omega) n_x\n\ + 5 \n\ +dt \ +''' + uexpected = '''\ + 5 \n\ +d \n\ +───(ω) n_x\n\ + 5 \n\ +dt \ +''' + assert unicode_vpretty(v) == uexpected + assert ascii_vpretty(v) == expected + + +def test_vector_str_printing(): + assert vsprint(w) == 'alpha*N.x + sin(omega)*N.y + alpha*beta*N.z' + assert vsprint(omega.diff() * N.x) == "omega'*N.x" + assert vsstrrepr(w) == 'alpha*N.x + sin(omega)*N.y + alpha*beta*N.z' + + +def test_vector_str_arguments(): + assert vsprint(N.x * 3.0, full_prec=False) == '3.0*N.x' + assert vsprint(N.x * 3.0, full_prec=True) == '3.00000000000000*N.x' + + +def test_issue_14041(): + import sympy.physics.mechanics as me + + A_frame = me.ReferenceFrame('A') + thetad, phid = me.dynamicsymbols('theta, phi', 1) + L = symbols('L') + + assert vlatex(L*(phid + thetad)**2*A_frame.x) == \ + r"L \left(\dot{\phi} + \dot{\theta}\right)^{2}\mathbf{\hat{a}_x}" + assert vlatex((phid + thetad)**2*A_frame.x) == \ + r"\left(\dot{\phi} + \dot{\theta}\right)^{2}\mathbf{\hat{a}_x}" + assert vlatex((phid*thetad)**a*A_frame.x) == \ + r"\left(\dot{\phi} \dot{\theta}\right)^{a}\mathbf{\hat{a}_x}" diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_vector.py b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..2b9c154e60be553228d37eec609dfc23120935ff --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/tests/test_vector.py @@ -0,0 +1,274 @@ +from sympy.core.numbers import (Float, pi) +from sympy.core.symbol import symbols +from sympy.core.sorting import ordered +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.physics.vector import ReferenceFrame, Vector, dynamicsymbols, dot +from sympy.physics.vector.vector import VectorTypeError +from sympy.abc import x, y, z +from sympy.testing.pytest import raises + +A = ReferenceFrame('A') + + +def test_free_dynamicsymbols(): + A, B, C, D = symbols('A, B, C, D', cls=ReferenceFrame) + a, b, c, d, e, f = dynamicsymbols('a, b, c, d, e, f') + B.orient_axis(A, a, A.x) + C.orient_axis(B, b, B.y) + D.orient_axis(C, c, C.x) + + v = d*D.x + e*D.y + f*D.z + + assert set(ordered(v.free_dynamicsymbols(A))) == {a, b, c, d, e, f} + assert set(ordered(v.free_dynamicsymbols(B))) == {b, c, d, e, f} + assert set(ordered(v.free_dynamicsymbols(C))) == {c, d, e, f} + assert set(ordered(v.free_dynamicsymbols(D))) == {d, e, f} + + +def test_Vector(): + assert A.x != A.y + assert A.y != A.z + assert A.z != A.x + + assert A.x + 0 == A.x + + v1 = x*A.x + y*A.y + z*A.z + v2 = x**2*A.x + y**2*A.y + z**2*A.z + v3 = v1 + v2 + v4 = v1 - v2 + + assert isinstance(v1, Vector) + assert dot(v1, A.x) == x + assert dot(v1, A.y) == y + assert dot(v1, A.z) == z + + assert isinstance(v2, Vector) + assert dot(v2, A.x) == x**2 + assert dot(v2, A.y) == y**2 + assert dot(v2, A.z) == z**2 + + assert isinstance(v3, Vector) + # We probably shouldn't be using simplify in dot... + assert dot(v3, A.x) == x**2 + x + assert dot(v3, A.y) == y**2 + y + assert dot(v3, A.z) == z**2 + z + + assert isinstance(v4, Vector) + # We probably shouldn't be using simplify in dot... + assert dot(v4, A.x) == x - x**2 + assert dot(v4, A.y) == y - y**2 + assert dot(v4, A.z) == z - z**2 + + assert v1.to_matrix(A) == Matrix([[x], [y], [z]]) + q = symbols('q') + B = A.orientnew('B', 'Axis', (q, A.x)) + assert v1.to_matrix(B) == Matrix([[x], + [ y * cos(q) + z * sin(q)], + [-y * sin(q) + z * cos(q)]]) + + #Test the separate method + B = ReferenceFrame('B') + v5 = x*A.x + y*A.y + z*B.z + assert Vector(0).separate() == {} + assert v1.separate() == {A: v1} + assert v5.separate() == {A: x*A.x + y*A.y, B: z*B.z} + + #Test the free_symbols property + v6 = x*A.x + y*A.y + z*A.z + assert v6.free_symbols(A) == {x,y,z} + + raises(TypeError, lambda: v3.applyfunc(v1)) + + +def test_Vector_diffs(): + q1, q2, q3, q4 = dynamicsymbols('q1 q2 q3 q4') + q1d, q2d, q3d, q4d = dynamicsymbols('q1 q2 q3 q4', 1) + q1dd, q2dd, q3dd, q4dd = dynamicsymbols('q1 q2 q3 q4', 2) + N = ReferenceFrame('N') + A = N.orientnew('A', 'Axis', [q3, N.z]) + B = A.orientnew('B', 'Axis', [q2, A.x]) + v1 = q2 * A.x + q3 * N.y + v2 = q3 * B.x + v1 + v3 = v1.dt(B) + v4 = v2.dt(B) + v5 = q1*A.x + q2*A.y + q3*A.z + + assert v1.dt(N) == q2d * A.x + q2 * q3d * A.y + q3d * N.y + assert v1.dt(A) == q2d * A.x + q3 * q3d * N.x + q3d * N.y + assert v1.dt(B) == (q2d * A.x + q3 * q3d * N.x + q3d * + N.y - q3 * cos(q3) * q2d * N.z) + assert v2.dt(N) == (q2d * A.x + (q2 + q3) * q3d * A.y + q3d * B.x + q3d * + N.y) + assert v2.dt(A) == q2d * A.x + q3d * B.x + q3 * q3d * N.x + q3d * N.y + assert v2.dt(B) == (q2d * A.x + q3d * B.x + q3 * q3d * N.x + q3d * N.y - + q3 * cos(q3) * q2d * N.z) + assert v3.dt(N) == (q2dd * A.x + q2d * q3d * A.y + (q3d**2 + q3 * q3dd) * + N.x + q3dd * N.y + (q3 * sin(q3) * q2d * q3d - + cos(q3) * q2d * q3d - q3 * cos(q3) * q2dd) * N.z) + assert v3.dt(A) == (q2dd * A.x + (2 * q3d**2 + q3 * q3dd) * N.x + (q3dd - + q3 * q3d**2) * N.y + (q3 * sin(q3) * q2d * q3d - + cos(q3) * q2d * q3d - q3 * cos(q3) * q2dd) * N.z) + assert (v3.dt(B) - (q2dd*A.x - q3*cos(q3)*q2d**2*A.y + (2*q3d**2 + + q3*q3dd)*N.x + (q3dd - q3*q3d**2)*N.y + (2*q3*sin(q3)*q2d*q3d - + 2*cos(q3)*q2d*q3d - q3*cos(q3)*q2dd)*N.z)).express(B).simplify() == 0 + assert v4.dt(N) == (q2dd * A.x + q3d * (q2d + q3d) * A.y + q3dd * B.x + + (q3d**2 + q3 * q3dd) * N.x + q3dd * N.y + (q3 * + sin(q3) * q2d * q3d - cos(q3) * q2d * q3d - q3 * + cos(q3) * q2dd) * N.z) + assert v4.dt(A) == (q2dd * A.x + q3dd * B.x + (2 * q3d**2 + q3 * q3dd) * + N.x + (q3dd - q3 * q3d**2) * N.y + (q3 * sin(q3) * + q2d * q3d - cos(q3) * q2d * q3d - q3 * cos(q3) * + q2dd) * N.z) + assert (v4.dt(B) - (q2dd*A.x - q3*cos(q3)*q2d**2*A.y + q3dd*B.x + + (2*q3d**2 + q3*q3dd)*N.x + (q3dd - q3*q3d**2)*N.y + + (2*q3*sin(q3)*q2d*q3d - 2*cos(q3)*q2d*q3d - + q3*cos(q3)*q2dd)*N.z)).express(B).simplify() == 0 + assert v5.dt(B) == q1d*A.x + (q3*q2d + q2d)*A.y + (-q2*q2d + q3d)*A.z + assert v5.dt(A) == q1d*A.x + q2d*A.y + q3d*A.z + assert v5.dt(N) == (-q2*q3d + q1d)*A.x + (q1*q3d + q2d)*A.y + q3d*A.z + assert v3.diff(q1d, N) == 0 + assert v3.diff(q2d, N) == A.x - q3 * cos(q3) * N.z + assert v3.diff(q3d, N) == q3 * N.x + N.y + assert v3.diff(q1d, A) == 0 + assert v3.diff(q2d, A) == A.x - q3 * cos(q3) * N.z + assert v3.diff(q3d, A) == q3 * N.x + N.y + assert v3.diff(q1d, B) == 0 + assert v3.diff(q2d, B) == A.x - q3 * cos(q3) * N.z + assert v3.diff(q3d, B) == q3 * N.x + N.y + assert v4.diff(q1d, N) == 0 + assert v4.diff(q2d, N) == A.x - q3 * cos(q3) * N.z + assert v4.diff(q3d, N) == B.x + q3 * N.x + N.y + assert v4.diff(q1d, A) == 0 + assert v4.diff(q2d, A) == A.x - q3 * cos(q3) * N.z + assert v4.diff(q3d, A) == B.x + q3 * N.x + N.y + assert v4.diff(q1d, B) == 0 + assert v4.diff(q2d, B) == A.x - q3 * cos(q3) * N.z + assert v4.diff(q3d, B) == B.x + q3 * N.x + N.y + + # diff() should only express vector components in the derivative frame if + # the orientation of the component's frame depends on the variable + v6 = q2**2*N.y + q2**2*A.y + q2**2*B.y + # already expressed in N + n_measy = 2*q2 + # A_C_N does not depend on q2, so don't express in N + a_measy = 2*q2 + # B_C_N depends on q2, so express in N + b_measx = (q2**2*B.y).dot(N.x).diff(q2) + b_measy = (q2**2*B.y).dot(N.y).diff(q2) + b_measz = (q2**2*B.y).dot(N.z).diff(q2) + n_comp, a_comp = v6.diff(q2, N).args + assert len(v6.diff(q2, N).args) == 2 # only N and A parts + assert n_comp[1] == N + assert a_comp[1] == A + assert n_comp[0] == Matrix([b_measx, b_measy + n_measy, b_measz]) + assert a_comp[0] == Matrix([0, a_measy, 0]) + + +def test_vector_var_in_dcm(): + + N = ReferenceFrame('N') + A = ReferenceFrame('A') + B = ReferenceFrame('B') + u1, u2, u3, u4 = dynamicsymbols('u1 u2 u3 u4') + + v = u1 * u2 * A.x + u3 * N.y + u4**2 * N.z + + assert v.diff(u1, N, var_in_dcm=False) == u2 * A.x + assert v.diff(u1, A, var_in_dcm=False) == u2 * A.x + assert v.diff(u3, N, var_in_dcm=False) == N.y + assert v.diff(u3, A, var_in_dcm=False) == N.y + assert v.diff(u3, B, var_in_dcm=False) == N.y + assert v.diff(u4, N, var_in_dcm=False) == 2 * u4 * N.z + + raises(ValueError, lambda: v.diff(u1, N)) + + +def test_vector_simplify(): + x, y, z, k, n, m, w, f, s, A = symbols('x, y, z, k, n, m, w, f, s, A') + N = ReferenceFrame('N') + + test1 = (1 / x + 1 / y) * N.x + assert (test1 & N.x) != (x + y) / (x * y) + test1 = test1.simplify() + assert (test1 & N.x) == (x + y) / (x * y) + + test2 = (A**2 * s**4 / (4 * pi * k * m**3)) * N.x + test2 = test2.simplify() + assert (test2 & N.x) == (A**2 * s**4 / (4 * pi * k * m**3)) + + test3 = ((4 + 4 * x - 2 * (2 + 2 * x)) / (2 + 2 * x)) * N.x + test3 = test3.simplify() + assert (test3 & N.x) == 0 + + test4 = ((-4 * x * y**2 - 2 * y**3 - 2 * x**2 * y) / (x + y)**2) * N.x + test4 = test4.simplify() + assert (test4 & N.x) == -2 * y + + +def test_vector_evalf(): + a, b = symbols('a b') + v = pi * A.x + assert v.evalf(2) == Float('3.1416', 2) * A.x + v = pi * A.x + 5 * a * A.y - b * A.z + assert v.evalf(3) == Float('3.1416', 3) * A.x + Float('5', 3) * a * A.y - b * A.z + assert v.evalf(5, subs={a: 1.234, b:5.8973}) == Float('3.1415926536', 5) * A.x + Float('6.17', 5) * A.y - Float('5.8973', 5) * A.z + + +def test_vector_angle(): + A = ReferenceFrame('A') + v1 = A.x + A.y + v2 = A.z + assert v1.angle_between(v2) == pi/2 + B = ReferenceFrame('B') + B.orient_axis(A, A.x, pi) + v3 = A.x + v4 = B.x + assert v3.angle_between(v4) == 0 + + +def test_vector_xreplace(): + x, y, z = symbols('x y z') + v = x**2 * A.x + x*y * A.y + x*y*z * A.z + assert v.xreplace({x : cos(x)}) == cos(x)**2 * A.x + y*cos(x) * A.y + y*z*cos(x) * A.z + assert v.xreplace({x*y : pi}) == x**2 * A.x + pi * A.y + x*y*z * A.z + assert v.xreplace({x*y*z : 1}) == x**2*A.x + x*y*A.y + A.z + assert v.xreplace({x:1, z:0}) == A.x + y * A.y + raises(TypeError, lambda: v.xreplace()) + raises(TypeError, lambda: v.xreplace([x, y])) + +def test_issue_23366(): + u1 = dynamicsymbols('u1') + N = ReferenceFrame('N') + N_v_A = u1*N.x + raises(VectorTypeError, lambda: N_v_A.diff(N, u1)) + + +def test_vector_outer(): + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + N = ReferenceFrame('N') + v1 = a*N.x + b*N.y + c*N.z + v2 = d*N.x + e*N.y + f*N.z + v1v2 = Matrix([[a*d, a*e, a*f], + [b*d, b*e, b*f], + [c*d, c*e, c*f]]) + assert v1.outer(v2).to_matrix(N) == v1v2 + assert (v1 | v2).to_matrix(N) == v1v2 + v2v1 = Matrix([[d*a, d*b, d*c], + [e*a, e*b, e*c], + [f*a, f*b, f*c]]) + assert v2.outer(v1).to_matrix(N) == v2v1 + assert (v2 | v1).to_matrix(N) == v2v1 + + +def test_overloaded_operators(): + a, b, c, d, e, f = symbols('a, b, c, d, e, f') + N = ReferenceFrame('N') + v1 = a*N.x + b*N.y + c*N.z + v2 = d*N.x + e*N.y + f*N.z + + assert v1 + v2 == v2 + v1 + assert v1 - v2 == -v2 + v1 + assert v1 & v2 == v2 & v1 + assert v1 ^ v2 == v1.cross(v2) + assert v2 ^ v1 == v2.cross(v1) diff --git a/phivenv/Lib/site-packages/sympy/physics/vector/vector.py b/phivenv/Lib/site-packages/sympy/physics/vector/vector.py new file mode 100644 index 0000000000000000000000000000000000000000..96510c7c55470e0605276a924ce9777f226acd8e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/vector/vector.py @@ -0,0 +1,806 @@ +from sympy import (S, sympify, expand, sqrt, Add, zeros, acos, + ImmutableMatrix as Matrix, simplify) +from sympy.simplify.trigsimp import trigsimp +from sympy.printing.defaults import Printable +from sympy.utilities.misc import filldedent +from sympy.core.evalf import EvalfMixin + +from mpmath.libmp.libmpf import prec_to_dps + + +__all__ = ['Vector'] + + +class Vector(Printable, EvalfMixin): + """The class used to define vectors. + + It along with ReferenceFrame are the building blocks of describing a + classical mechanics system in PyDy and sympy.physics.vector. + + Attributes + ========== + + simp : Boolean + Let certain methods use trigsimp on their outputs + + """ + + simp = False + is_number = False + + def __init__(self, inlist): + """This is the constructor for the Vector class. You should not be + calling this, it should only be used by other functions. You should be + treating Vectors like you would with if you were doing the math by + hand, and getting the first 3 from the standard basis vectors from a + ReferenceFrame. + + The only exception is to create a zero vector: + zv = Vector(0) + + """ + + self.args = [] + if inlist == 0: + inlist = [] + if isinstance(inlist, dict): + d = inlist + else: + d = {} + for inp in inlist: + if inp[1] in d: + d[inp[1]] += inp[0] + else: + d[inp[1]] = inp[0] + + for k, v in d.items(): + if v != Matrix([0, 0, 0]): + self.args.append((v, k)) + + @property + def func(self): + """Returns the class Vector. """ + return Vector + + def __hash__(self): + return hash(tuple(self.args)) + + def __add__(self, other): + """The add operator for Vector. """ + if other == 0: + return self + other = _check_vector(other) + return Vector(self.args + other.args) + + def dot(self, other): + """Dot product of two vectors. + + Returns a scalar, the dot product of the two Vectors + + Parameters + ========== + + other : Vector + The Vector which we are dotting with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dot + >>> from sympy import symbols + >>> q1 = symbols('q1') + >>> N = ReferenceFrame('N') + >>> dot(N.x, N.x) + 1 + >>> dot(N.x, N.y) + 0 + >>> A = N.orientnew('A', 'Axis', [q1, N.x]) + >>> dot(N.y, A.y) + cos(q1) + + """ + + from sympy.physics.vector.dyadic import Dyadic, _check_dyadic + if isinstance(other, Dyadic): + other = _check_dyadic(other) + ol = Vector(0) + for v in other.args: + ol += v[0] * v[2] * (v[1].dot(self)) + return ol + other = _check_vector(other) + out = S.Zero + for v1 in self.args: + for v2 in other.args: + out += ((v2[0].T) * (v2[1].dcm(v1[1])) * (v1[0]))[0] + if Vector.simp: + return trigsimp(out, recursive=True) + else: + return out + + def __truediv__(self, other): + """This uses mul and inputs self and 1 divided by other. """ + return self.__mul__(S.One / other) + + def __eq__(self, other): + """Tests for equality. + + It is very import to note that this is only as good as the SymPy + equality test; False does not always mean they are not equivalent + Vectors. + If other is 0, and self is empty, returns True. + If other is 0 and self is not empty, returns False. + If none of the above, only accepts other as a Vector. + + """ + + if other == 0: + other = Vector(0) + try: + other = _check_vector(other) + except TypeError: + return False + if (self.args == []) and (other.args == []): + return True + elif (self.args == []) or (other.args == []): + return False + + frame = self.args[0][1] + for v in frame: + if expand((self - other).dot(v)) != 0: + return False + return True + + def __mul__(self, other): + """Multiplies the Vector by a sympifyable expression. + + Parameters + ========== + + other : Sympifyable + The scalar to multiply this Vector with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy import Symbol + >>> N = ReferenceFrame('N') + >>> b = Symbol('b') + >>> V = 10 * b * N.x + >>> print(V) + 10*b*N.x + + """ + + newlist = list(self.args) + other = sympify(other) + for i in range(len(newlist)): + newlist[i] = (other * newlist[i][0], newlist[i][1]) + return Vector(newlist) + + def __neg__(self): + return self * -1 + + def outer(self, other): + """Outer product between two Vectors. + + A rank increasing operation, which returns a Dyadic from two Vectors + + Parameters + ========== + + other : Vector + The Vector to take the outer product with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer + >>> N = ReferenceFrame('N') + >>> outer(N.x, N.x) + (N.x|N.x) + + """ + + from sympy.physics.vector.dyadic import Dyadic + other = _check_vector(other) + ol = Dyadic(0) + for v in self.args: + for v2 in other.args: + # it looks this way because if we are in the same frame and + # use the enumerate function on the same frame in a nested + # fashion, then bad things happen + ol += Dyadic([(v[0][0] * v2[0][0], v[1].x, v2[1].x)]) + ol += Dyadic([(v[0][0] * v2[0][1], v[1].x, v2[1].y)]) + ol += Dyadic([(v[0][0] * v2[0][2], v[1].x, v2[1].z)]) + ol += Dyadic([(v[0][1] * v2[0][0], v[1].y, v2[1].x)]) + ol += Dyadic([(v[0][1] * v2[0][1], v[1].y, v2[1].y)]) + ol += Dyadic([(v[0][1] * v2[0][2], v[1].y, v2[1].z)]) + ol += Dyadic([(v[0][2] * v2[0][0], v[1].z, v2[1].x)]) + ol += Dyadic([(v[0][2] * v2[0][1], v[1].z, v2[1].y)]) + ol += Dyadic([(v[0][2] * v2[0][2], v[1].z, v2[1].z)]) + return ol + + def _latex(self, printer): + """Latex Printing method. """ + + ar = self.args # just to shorten things + if len(ar) == 0: + return str(0) + ol = [] # output list, to be concatenated to a string + for v in ar: + for j in 0, 1, 2: + # if the coef of the basis vector is 1, we skip the 1 + if v[0][j] == 1: + ol.append(' + ' + v[1].latex_vecs[j]) + # if the coef of the basis vector is -1, we skip the 1 + elif v[0][j] == -1: + ol.append(' - ' + v[1].latex_vecs[j]) + elif v[0][j] != 0: + # If the coefficient of the basis vector is not 1 or -1; + # also, we might wrap it in parentheses, for readability. + arg_str = printer._print(v[0][j]) + if isinstance(v[0][j], Add): + arg_str = "(%s)" % arg_str + if arg_str[0] == '-': + arg_str = arg_str[1:] + str_start = ' - ' + else: + str_start = ' + ' + ol.append(str_start + arg_str + v[1].latex_vecs[j]) + outstr = ''.join(ol) + if outstr.startswith(' + '): + outstr = outstr[3:] + elif outstr.startswith(' '): + outstr = outstr[1:] + return outstr + + def _pretty(self, printer): + """Pretty Printing method. """ + from sympy.printing.pretty.stringpict import prettyForm + + terms = [] + + def juxtapose(a, b): + pa = printer._print(a) + pb = printer._print(b) + if a.is_Add: + pa = prettyForm(*pa.parens()) + return printer._print_seq([pa, pb], delimiter=' ') + + for M, N in self.args: + for i in range(3): + if M[i] == 0: + continue + elif M[i] == 1: + terms.append(prettyForm(N.pretty_vecs[i])) + elif M[i] == -1: + terms.append(prettyForm("-1") * prettyForm(N.pretty_vecs[i])) + else: + terms.append(juxtapose(M[i], N.pretty_vecs[i])) + + if terms: + pretty_result = prettyForm.__add__(*terms) + else: + pretty_result = prettyForm("0") + + return pretty_result + + def __rsub__(self, other): + return (-1 * self) + other + + def _sympystr(self, printer, order=True): + """Printing method. """ + if not order or len(self.args) == 1: + ar = list(self.args) + elif len(self.args) == 0: + return printer._print(0) + else: + d = {v[1]: v[0] for v in self.args} + keys = sorted(d.keys(), key=lambda x: x.index) + ar = [] + for key in keys: + ar.append((d[key], key)) + ol = [] # output list, to be concatenated to a string + for v in ar: + for j in 0, 1, 2: + # if the coef of the basis vector is 1, we skip the 1 + if v[0][j] == 1: + ol.append(' + ' + v[1].str_vecs[j]) + # if the coef of the basis vector is -1, we skip the 1 + elif v[0][j] == -1: + ol.append(' - ' + v[1].str_vecs[j]) + elif v[0][j] != 0: + # If the coefficient of the basis vector is not 1 or -1; + # also, we might wrap it in parentheses, for readability. + arg_str = printer._print(v[0][j]) + if isinstance(v[0][j], Add): + arg_str = "(%s)" % arg_str + if arg_str[0] == '-': + arg_str = arg_str[1:] + str_start = ' - ' + else: + str_start = ' + ' + ol.append(str_start + arg_str + '*' + v[1].str_vecs[j]) + outstr = ''.join(ol) + if outstr.startswith(' + '): + outstr = outstr[3:] + elif outstr.startswith(' '): + outstr = outstr[1:] + return outstr + + def __sub__(self, other): + """The subtraction operator. """ + return self.__add__(other * -1) + + def cross(self, other): + """The cross product operator for two Vectors. + + Returns a Vector, expressed in the same ReferenceFrames as self. + + Parameters + ========== + + other : Vector + The Vector which we are crossing with + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame, cross + >>> q1 = symbols('q1') + >>> N = ReferenceFrame('N') + >>> cross(N.x, N.y) + N.z + >>> A = ReferenceFrame('A') + >>> A.orient_axis(N, q1, N.x) + >>> cross(A.x, N.y) + N.z + >>> cross(N.y, A.x) + - sin(q1)*A.y - cos(q1)*A.z + + """ + + from sympy.physics.vector.dyadic import Dyadic, _check_dyadic + if isinstance(other, Dyadic): + other = _check_dyadic(other) + ol = Dyadic(0) + for i, v in enumerate(other.args): + ol += v[0] * ((self.cross(v[1])).outer(v[2])) + return ol + other = _check_vector(other) + if other.args == []: + return Vector(0) + + def _det(mat): + """This is needed as a little method for to find the determinant + of a list in python; needs to work for a 3x3 list. + SymPy's Matrix will not take in Vector, so need a custom function. + You should not be calling this. + + """ + + return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) + + mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * + mat[2][2]) + mat[0][2] * (mat[1][0] * mat[2][1] - + mat[1][1] * mat[2][0])) + + outlist = [] + ar = other.args # For brevity + for v in ar: + tempx = v[1].x + tempy = v[1].y + tempz = v[1].z + tempm = ([[tempx, tempy, tempz], + [self.dot(tempx), self.dot(tempy), self.dot(tempz)], + [Vector([v]).dot(tempx), Vector([v]).dot(tempy), + Vector([v]).dot(tempz)]]) + outlist += _det(tempm).args + return Vector(outlist) + + __radd__ = __add__ + __rmul__ = __mul__ + + def separate(self): + """ + The constituents of this vector in different reference frames, + as per its definition. + + Returns a dict mapping each ReferenceFrame to the corresponding + constituent Vector. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> R1 = ReferenceFrame('R1') + >>> R2 = ReferenceFrame('R2') + >>> v = R1.x + R2.x + >>> v.separate() == {R1: R1.x, R2: R2.x} + True + + """ + + components = {} + for x in self.args: + components[x[1]] = Vector([x]) + return components + + def __and__(self, other): + return self.dot(other) + __and__.__doc__ = dot.__doc__ + __rand__ = __and__ + + def __xor__(self, other): + return self.cross(other) + __xor__.__doc__ = cross.__doc__ + + def __or__(self, other): + return self.outer(other) + __or__.__doc__ = outer.__doc__ + + def diff(self, var, frame, var_in_dcm=True): + """Returns the partial derivative of the vector with respect to a + variable in the provided reference frame. + + Parameters + ========== + var : Symbol + What the partial derivative is taken with respect to. + frame : ReferenceFrame + The reference frame that the partial derivative is taken in. + var_in_dcm : boolean + If true, the differentiation algorithm assumes that the variable + may be present in any of the direction cosine matrices that relate + the frame to the frames of any component of the vector. But if it + is known that the variable is not present in the direction cosine + matrices, false can be set to skip full reexpression in the desired + frame. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.vector import dynamicsymbols, ReferenceFrame + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> t = Symbol('t') + >>> q1 = dynamicsymbols('q1') + >>> N = ReferenceFrame('N') + >>> A = N.orientnew('A', 'Axis', [q1, N.y]) + >>> A.x.diff(t, N) + - sin(q1)*q1'*N.x - cos(q1)*q1'*N.z + >>> A.x.diff(t, N).express(A).simplify() + - q1'*A.z + >>> B = ReferenceFrame('B') + >>> u1, u2 = dynamicsymbols('u1, u2') + >>> v = u1 * A.x + u2 * B.y + >>> v.diff(u2, N, var_in_dcm=False) + B.y + + """ + + from sympy.physics.vector.frame import _check_frame + + _check_frame(frame) + var = sympify(var) + + inlist = [] + + for vector_component in self.args: + measure_number = vector_component[0] + component_frame = vector_component[1] + if component_frame == frame: + inlist += [(measure_number.diff(var), frame)] + else: + # If the direction cosine matrix relating the component frame + # with the derivative frame does not contain the variable. + if not var_in_dcm or (frame.dcm(component_frame).diff(var) == + zeros(3, 3)): + inlist += [(measure_number.diff(var), component_frame)] + else: # else express in the frame + reexp_vec_comp = Vector([vector_component]).express(frame) + deriv = reexp_vec_comp.args[0][0].diff(var) + inlist += Vector([(deriv, frame)]).args + + return Vector(inlist) + + def express(self, otherframe, variables=False): + """ + Returns a Vector equivalent to this one, expressed in otherframe. + Uses the global express method. + + Parameters + ========== + + otherframe : ReferenceFrame + The frame for this Vector to be described in + + variables : boolean + If True, the coordinate symbols(if present) in this Vector + are re-expressed in terms otherframe + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q1 = dynamicsymbols('q1') + >>> N = ReferenceFrame('N') + >>> A = N.orientnew('A', 'Axis', [q1, N.y]) + >>> A.x.express(N) + cos(q1)*N.x - sin(q1)*N.z + + """ + from sympy.physics.vector import express + return express(self, otherframe, variables=variables) + + def to_matrix(self, reference_frame): + """Returns the matrix form of the vector with respect to the given + frame. + + Parameters + ---------- + reference_frame : ReferenceFrame + The reference frame that the rows of the matrix correspond to. + + Returns + ------- + matrix : ImmutableMatrix, shape(3,1) + The matrix that gives the 1D vector. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> a, b, c = symbols('a, b, c') + >>> N = ReferenceFrame('N') + >>> vector = a * N.x + b * N.y + c * N.z + >>> vector.to_matrix(N) + Matrix([ + [a], + [b], + [c]]) + >>> beta = symbols('beta') + >>> A = N.orientnew('A', 'Axis', (beta, N.x)) + >>> vector.to_matrix(A) + Matrix([ + [ a], + [ b*cos(beta) + c*sin(beta)], + [-b*sin(beta) + c*cos(beta)]]) + + """ + + return Matrix([self.dot(unit_vec) for unit_vec in + reference_frame]).reshape(3, 1) + + def doit(self, **hints): + """Calls .doit() on each term in the Vector""" + d = {} + for v in self.args: + d[v[1]] = v[0].applyfunc(lambda x: x.doit(**hints)) + return Vector(d) + + def dt(self, otherframe): + """ + Returns a Vector which is the time derivative of + the self Vector, taken in frame otherframe. + + Calls the global time_derivative method + + Parameters + ========== + + otherframe : ReferenceFrame + The frame to calculate the time derivative in + + """ + from sympy.physics.vector import time_derivative + return time_derivative(self, otherframe) + + def simplify(self): + """Returns a simplified Vector.""" + d = {} + for v in self.args: + d[v[1]] = simplify(v[0]) + return Vector(d) + + def subs(self, *args, **kwargs): + """Substitution on the Vector. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy import Symbol + >>> N = ReferenceFrame('N') + >>> s = Symbol('s') + >>> a = N.x * s + >>> a.subs({s: 2}) + 2*N.x + + """ + + d = {} + for v in self.args: + d[v[1]] = v[0].subs(*args, **kwargs) + return Vector(d) + + def magnitude(self): + """Returns the magnitude (Euclidean norm) of self. + + Warnings + ======== + + Python ignores the leading negative sign so that might + give wrong results. + ``-A.x.magnitude()`` would be treated as ``-(A.x.magnitude())``, + instead of ``(-A.x).magnitude()``. + + """ + return sqrt(self.dot(self)) + + def normalize(self): + """Returns a Vector of magnitude 1, codirectional with self.""" + return Vector(self.args + []) / self.magnitude() + + def applyfunc(self, f): + """Apply a function to each component of a vector.""" + if not callable(f): + raise TypeError("`f` must be callable.") + + d = {} + for v in self.args: + d[v[1]] = v[0].applyfunc(f) + return Vector(d) + + def angle_between(self, vec): + """ + Returns the smallest angle between Vector 'vec' and self. + + Parameter + ========= + + vec : Vector + The Vector between which angle is needed. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> A = ReferenceFrame("A") + >>> v1 = A.x + >>> v2 = A.y + >>> v1.angle_between(v2) + pi/2 + + >>> v3 = A.x + A.y + A.z + >>> v1.angle_between(v3) + acos(sqrt(3)/3) + + Warnings + ======== + + Python ignores the leading negative sign so that might give wrong + results. ``-A.x.angle_between()`` would be treated as + ``-(A.x.angle_between())``, instead of ``(-A.x).angle_between()``. + + """ + + vec1 = self.normalize() + vec2 = vec.normalize() + angle = acos(vec1.dot(vec2)) + return angle + + def free_symbols(self, reference_frame): + """Returns the free symbols in the measure numbers of the vector + expressed in the given reference frame. + + Parameters + ========== + reference_frame : ReferenceFrame + The frame with respect to which the free symbols of the given + vector is to be determined. + + Returns + ======= + set of Symbol + set of symbols present in the measure numbers of + ``reference_frame``. + + """ + + return self.to_matrix(reference_frame).free_symbols + + def free_dynamicsymbols(self, reference_frame): + """Returns the free dynamic symbols (functions of time ``t``) in the + measure numbers of the vector expressed in the given reference frame. + + Parameters + ========== + reference_frame : ReferenceFrame + The frame with respect to which the free dynamic symbols of the + given vector is to be determined. + + Returns + ======= + set + Set of functions of time ``t``, e.g. + ``Function('f')(me.dynamicsymbols._t)``. + + """ + # TODO : Circular dependency if imported at top. Should move + # find_dynamicsymbols into physics.vector.functions. + from sympy.physics.mechanics.functions import find_dynamicsymbols + + return find_dynamicsymbols(self, reference_frame=reference_frame) + + def _eval_evalf(self, prec): + if not self.args: + return self + new_args = [] + dps = prec_to_dps(prec) + for mat, frame in self.args: + new_args.append([mat.evalf(n=dps), frame]) + return Vector(new_args) + + def xreplace(self, rule): + """Replace occurrences of objects within the measure numbers of the + vector. + + Parameters + ========== + + rule : dict-like + Expresses a replacement rule. + + Returns + ======= + + Vector + Result of the replacement. + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.physics.vector import ReferenceFrame + >>> A = ReferenceFrame('A') + >>> x, y, z = symbols('x y z') + >>> ((1 + x*y) * A.x).xreplace({x: pi}) + (pi*y + 1)*A.x + >>> ((1 + x*y) * A.x).xreplace({x: pi, y: 2}) + (1 + 2*pi)*A.x + + Replacements occur only if an entire node in the expression tree is + matched: + + >>> ((x*y + z) * A.x).xreplace({x*y: pi}) + (z + pi)*A.x + >>> ((x*y*z) * A.x).xreplace({x*y: pi}) + x*y*z*A.x + + """ + + new_args = [] + for mat, frame in self.args: + mat = mat.xreplace(rule) + new_args.append([mat, frame]) + return Vector(new_args) + + +class VectorTypeError(TypeError): + + def __init__(self, other, want): + msg = filldedent("Expected an instance of %s, but received object " + "'%s' of %s." % (type(want), other, type(other))) + super().__init__(msg) + + +def _check_vector(other): + if not isinstance(other, Vector): + raise TypeError('A Vector must be supplied') + return other diff --git a/phivenv/Lib/site-packages/sympy/physics/wigner.py b/phivenv/Lib/site-packages/sympy/physics/wigner.py new file mode 100644 index 0000000000000000000000000000000000000000..e08f3fb4a480439fd2bb1f8ff8c305bf69d7abae --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/physics/wigner.py @@ -0,0 +1,1213 @@ +# -*- coding: utf-8 -*- +r""" +Wigner, Clebsch-Gordan, Racah, and Gaunt coefficients + +Collection of functions for calculating Wigner 3j, 6j, 9j, +Clebsch-Gordan, Racah as well as Gaunt coefficients exactly, all +evaluating to a rational number times the square root of a rational +number [Rasch03]_. + +Please see the description of the individual functions for further +details and examples. + +References +========== + +.. [Regge58] 'Symmetry Properties of Clebsch-Gordan Coefficients', + T. Regge, Nuovo Cimento, Volume 10, pp. 544 (1958) +.. [Regge59] 'Symmetry Properties of Racah Coefficients', + T. Regge, Nuovo Cimento, Volume 11, pp. 116 (1959) +.. [Edmonds74] A. R. Edmonds. Angular momentum in quantum mechanics. + Investigations in physics, 4.; Investigations in physics, no. 4. + Princeton, N.J., Princeton University Press, 1957. +.. [Rasch03] J. Rasch and A. C. H. Yu, 'Efficient Storage Scheme for + Pre-calculated Wigner 3j, 6j and Gaunt Coefficients', SIAM + J. Sci. Comput. Volume 25, Issue 4, pp. 1416-1428 (2003) +.. [Liberatodebrito82] 'FORTRAN program for the integral of three + spherical harmonics', A. Liberato de Brito, + Comput. Phys. Commun., Volume 25, pp. 81-85 (1982) +.. [Homeier96] 'Some Properties of the Coupling Coefficients of Real + Spherical Harmonics and Their Relation to Gaunt Coefficients', + H. H. H. Homeier and E. O. Steinborn J. Mol. Struct., Volume 368, + pp. 31-37 (1996) + +Credits and Copyright +===================== + +This code was taken from Sage with the permission of all authors: + +https://groups.google.com/forum/#!topic/sage-devel/M4NZdu-7O38 + +Authors +======= + +- Jens Rasch (2009-03-24): initial version for Sage + +- Jens Rasch (2009-05-31): updated to sage-4.0 + +- Oscar Gerardo Lazo Arjona (2017-06-18): added Wigner D matrices + +- Phil Adam LeMaitre (2022-09-19): added real Gaunt coefficient + +Copyright (C) 2008 Jens Rasch + +""" +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.numbers import int_valued +from sympy.core.function import Function +from sympy.core.numbers import (Float, I, Integer, pi, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.complexes import re +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.spherical_harmonics import Ynm +from sympy.matrices.dense import zeros +from sympy.matrices.immutable import ImmutableMatrix +from sympy.utilities.misc import as_int + +# This list of precomputed factorials is needed to massively +# accelerate future calculations of the various coefficients +_Factlist = [1] + + +def _calc_factlist(nn): + r""" + Function calculates a list of precomputed factorials in order to + massively accelerate future calculations of the various + coefficients. + + Parameters + ========== + + nn : integer + Highest factorial to be computed. + + Returns + ======= + + list of integers : + The list of precomputed factorials. + + Examples + ======== + + Calculate list of factorials:: + + sage: from sage.functions.wigner import _calc_factlist + sage: _calc_factlist(10) + [1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880, 3628800] + """ + if nn >= len(_Factlist): + for ii in range(len(_Factlist), int(nn + 1)): + _Factlist.append(_Factlist[ii - 1] * ii) + return _Factlist[:int(nn) + 1] + + +def _int_or_halfint(value): + """return Python int unless value is half-int (then return float)""" + if isinstance(value, int): + return value + elif type(value) is float: + if value.is_integer(): + return int(value) # an int + if (2*value).is_integer(): + return value # a float + elif isinstance(value, Rational): + if value.q == 2: + return value.p/value.q # a float + elif value.q == 1: + return value.p # an int + elif isinstance(value, Float): + return _int_or_halfint(float(value)) + raise ValueError("expecting integer or half-integer, got %s" % value) + + +def wigner_3j(j_1, j_2, j_3, m_1, m_2, m_3): + r""" + Calculate the Wigner 3j symbol `\operatorname{Wigner3j}(j_1,j_2,j_3,m_1,m_2,m_3)`. + + Parameters + ========== + + j_1, j_2, j_3, m_1, m_2, m_3 : + Integer or half integer. + + Returns + ======= + + Rational number times the square root of a rational number. + + Examples + ======== + + >>> from sympy.physics.wigner import wigner_3j + >>> wigner_3j(2, 6, 4, 0, 0, 0) + sqrt(715)/143 + >>> wigner_3j(2, 6, 4, 0, 0, 1) + 0 + + It is an error to have arguments that are not integer or half + integer values:: + + sage: wigner_3j(2.1, 6, 4, 0, 0, 0) + Traceback (most recent call last): + ... + ValueError: j values must be integer or half integer + sage: wigner_3j(2, 6, 4, 1, 0, -1.1) + Traceback (most recent call last): + ... + ValueError: m values must be integer or half integer + + Notes + ===== + + The Wigner 3j symbol obeys the following symmetry rules: + + - invariant under any permutation of the columns (with the + exception of a sign change where `J:=j_1+j_2+j_3`): + + .. math:: + + \begin{aligned} + \operatorname{Wigner3j}(j_1,j_2,j_3,m_1,m_2,m_3) + &=\operatorname{Wigner3j}(j_3,j_1,j_2,m_3,m_1,m_2) \\ + &=\operatorname{Wigner3j}(j_2,j_3,j_1,m_2,m_3,m_1) \\ + &=(-1)^J \operatorname{Wigner3j}(j_3,j_2,j_1,m_3,m_2,m_1) \\ + &=(-1)^J \operatorname{Wigner3j}(j_1,j_3,j_2,m_1,m_3,m_2) \\ + &=(-1)^J \operatorname{Wigner3j}(j_2,j_1,j_3,m_2,m_1,m_3) + \end{aligned} + + - invariant under space inflection, i.e. + + .. math:: + + \operatorname{Wigner3j}(j_1,j_2,j_3,m_1,m_2,m_3) + =(-1)^J \operatorname{Wigner3j}(j_1,j_2,j_3,-m_1,-m_2,-m_3) + + - symmetric with respect to the 72 additional symmetries based on + the work by [Regge58]_ + + - zero for `j_1`, `j_2`, `j_3` not fulfilling triangle relation + + - zero for `m_1 + m_2 + m_3 \neq 0` + + - zero for violating any one of the conditions + `m_1 \in \{-|j_1|, \ldots, |j_1|\}`, + `m_2 \in \{-|j_2|, \ldots, |j_2|\}`, + `m_3 \in \{-|j_3|, \ldots, |j_3|\}` + + Algorithm + ========= + + This function uses the algorithm of [Edmonds74]_ to calculate the + value of the 3j symbol exactly. Note that the formula contains + alternating sums over large factorials and is therefore unsuitable + for finite precision arithmetic and only useful for a computer + algebra system [Rasch03]_. + + Authors + ======= + + - Jens Rasch (2009-03-24): initial version + """ + + j_1, j_2, j_3, m_1, m_2, m_3 = \ + map(_int_or_halfint, map(sympify, + [j_1, j_2, j_3, m_1, m_2, m_3])) + + if m_1 + m_2 + m_3 != 0: + return S.Zero + a1 = j_1 + j_2 - j_3 + if a1 < 0: + return S.Zero + a2 = j_1 - j_2 + j_3 + if a2 < 0: + return S.Zero + a3 = -j_1 + j_2 + j_3 + if a3 < 0: + return S.Zero + if (abs(m_1) > j_1) or (abs(m_2) > j_2) or (abs(m_3) > j_3): + return S.Zero + if not (int_valued(j_1 - m_1) and \ + int_valued(j_2 - m_2) and \ + int_valued(j_3 - m_3)): + return S.Zero + + maxfact = max(j_1 + j_2 + j_3 + 1, j_1 + abs(m_1), j_2 + abs(m_2), + j_3 + abs(m_3)) + _calc_factlist(int(maxfact)) + + argsqrt = Integer(_Factlist[int(j_1 + j_2 - j_3)] * + _Factlist[int(j_1 - j_2 + j_3)] * + _Factlist[int(-j_1 + j_2 + j_3)] * + _Factlist[int(j_1 - m_1)] * + _Factlist[int(j_1 + m_1)] * + _Factlist[int(j_2 - m_2)] * + _Factlist[int(j_2 + m_2)] * + _Factlist[int(j_3 - m_3)] * + _Factlist[int(j_3 + m_3)]) / \ + _Factlist[int(j_1 + j_2 + j_3 + 1)] + + ressqrt = sqrt(argsqrt) + if ressqrt.is_complex or ressqrt.is_infinite: + ressqrt = ressqrt.as_real_imag()[0] + + imin = max(-j_3 + j_1 + m_2, -j_3 + j_2 - m_1, 0) + imax = min(j_2 + m_2, j_1 - m_1, j_1 + j_2 - j_3) + sumres = 0 + for ii in range(int(imin), int(imax) + 1): + den = _Factlist[ii] * \ + _Factlist[int(ii + j_3 - j_1 - m_2)] * \ + _Factlist[int(j_2 + m_2 - ii)] * \ + _Factlist[int(j_1 - ii - m_1)] * \ + _Factlist[int(ii + j_3 - j_2 + m_1)] * \ + _Factlist[int(j_1 + j_2 - j_3 - ii)] + sumres = sumres + Integer((-1) ** ii) / den + + prefid = Integer((-1) ** int(j_1 - j_2 - m_3)) + res = ressqrt * sumres * prefid + return res + + +def clebsch_gordan(j_1, j_2, j_3, m_1, m_2, m_3): + r""" + Calculates the Clebsch-Gordan coefficient. + `\left\langle j_1 m_1 \; j_2 m_2 | j_3 m_3 \right\rangle`. + + The reference for this function is [Edmonds74]_. + + Parameters + ========== + + j_1, j_2, j_3, m_1, m_2, m_3 : + Integer or half integer. + + Returns + ======= + + Rational number times the square root of a rational number. + + Examples + ======== + + >>> from sympy import S + >>> from sympy.physics.wigner import clebsch_gordan + >>> clebsch_gordan(S(3)/2, S(1)/2, 2, S(3)/2, S(1)/2, 2) + 1 + >>> clebsch_gordan(S(3)/2, S(1)/2, 1, S(3)/2, -S(1)/2, 1) + sqrt(3)/2 + >>> clebsch_gordan(S(3)/2, S(1)/2, 1, -S(1)/2, S(1)/2, 0) + -sqrt(2)/2 + + Notes + ===== + + The Clebsch-Gordan coefficient will be evaluated via its relation + to Wigner 3j symbols: + + .. math:: + + \left\langle j_1 m_1 \; j_2 m_2 | j_3 m_3 \right\rangle + =(-1)^{j_1-j_2+m_3} \sqrt{2j_3+1} + \operatorname{Wigner3j}(j_1,j_2,j_3,m_1,m_2,-m_3) + + See also the documentation on Wigner 3j symbols which exhibit much + higher symmetry relations than the Clebsch-Gordan coefficient. + + Authors + ======= + + - Jens Rasch (2009-03-24): initial version + """ + j_1 = sympify(j_1) + j_2 = sympify(j_2) + j_3 = sympify(j_3) + m_1 = sympify(m_1) + m_2 = sympify(m_2) + m_3 = sympify(m_3) + + w = wigner_3j(j_1, j_2, j_3, m_1, m_2, -m_3) + + return (-1) ** (j_1 - j_2 + m_3) * sqrt(2 * j_3 + 1) * w + + +def _big_delta_coeff(aa, bb, cc, prec=None): + r""" + Calculates the Delta coefficient of the 3 angular momenta for + Racah symbols. Also checks that the differences are of integer + value. + + Parameters + ========== + + aa : + First angular momentum, integer or half integer. + bb : + Second angular momentum, integer or half integer. + cc : + Third angular momentum, integer or half integer. + prec : + Precision of the ``sqrt()`` calculation. + + Returns + ======= + + double : Value of the Delta coefficient. + + Examples + ======== + + sage: from sage.functions.wigner import _big_delta_coeff + sage: _big_delta_coeff(1,1,1) + 1/2*sqrt(1/6) + """ + + # the triangle test will only pass if a) all 3 values are ints or + # b) 1 is an int and the other two are half-ints + if not int_valued(aa + bb - cc): + raise ValueError("j values must be integer or half integer and fulfill the triangle relation") + if not int_valued(aa + cc - bb): + raise ValueError("j values must be integer or half integer and fulfill the triangle relation") + if not int_valued(bb + cc - aa): + raise ValueError("j values must be integer or half integer and fulfill the triangle relation") + if (aa + bb - cc) < 0: + return S.Zero + if (aa + cc - bb) < 0: + return S.Zero + if (bb + cc - aa) < 0: + return S.Zero + + maxfact = max(aa + bb - cc, aa + cc - bb, bb + cc - aa, aa + bb + cc + 1) + _calc_factlist(maxfact) + + argsqrt = Integer(_Factlist[int(aa + bb - cc)] * + _Factlist[int(aa + cc - bb)] * + _Factlist[int(bb + cc - aa)]) / \ + Integer(_Factlist[int(aa + bb + cc + 1)]) + + ressqrt = sqrt(argsqrt) + if prec: + ressqrt = ressqrt.evalf(prec).as_real_imag()[0] + return ressqrt + + +def racah(aa, bb, cc, dd, ee, ff, prec=None): + r""" + Calculate the Racah symbol `W(a,b,c,d;e,f)`. + + Parameters + ========== + + a, ..., f : + Integer or half integer. + prec : + Precision, default: ``None``. Providing a precision can + drastically speed up the calculation. + + Returns + ======= + + Rational number times the square root of a rational number + (if ``prec=None``), or real number if a precision is given. + + Examples + ======== + + >>> from sympy.physics.wigner import racah + >>> racah(3,3,3,3,3,3) + -1/14 + + Notes + ===== + + The Racah symbol is related to the Wigner 6j symbol: + + .. math:: + + \operatorname{Wigner6j}(j_1,j_2,j_3,j_4,j_5,j_6) + =(-1)^{j_1+j_2+j_4+j_5} W(j_1,j_2,j_5,j_4,j_3,j_6) + + Please see the 6j symbol for its much richer symmetries and for + additional properties. + + Algorithm + ========= + + This function uses the algorithm of [Edmonds74]_ to calculate the + value of the 6j symbol exactly. Note that the formula contains + alternating sums over large factorials and is therefore unsuitable + for finite precision arithmetic and only useful for a computer + algebra system [Rasch03]_. + + Authors + ======= + + - Jens Rasch (2009-03-24): initial version + """ + prefac = _big_delta_coeff(aa, bb, ee, prec) * \ + _big_delta_coeff(cc, dd, ee, prec) * \ + _big_delta_coeff(aa, cc, ff, prec) * \ + _big_delta_coeff(bb, dd, ff, prec) + if prefac == 0: + return S.Zero + imin = max(aa + bb + ee, cc + dd + ee, aa + cc + ff, bb + dd + ff) + imax = min(aa + bb + cc + dd, aa + dd + ee + ff, bb + cc + ee + ff) + + maxfact = max(imax + 1, aa + bb + cc + dd, aa + dd + ee + ff, + bb + cc + ee + ff) + _calc_factlist(maxfact) + + sumres = 0 + for kk in range(int(imin), int(imax) + 1): + den = _Factlist[int(kk - aa - bb - ee)] * \ + _Factlist[int(kk - cc - dd - ee)] * \ + _Factlist[int(kk - aa - cc - ff)] * \ + _Factlist[int(kk - bb - dd - ff)] * \ + _Factlist[int(aa + bb + cc + dd - kk)] * \ + _Factlist[int(aa + dd + ee + ff - kk)] * \ + _Factlist[int(bb + cc + ee + ff - kk)] + sumres = sumres + Integer((-1) ** kk * _Factlist[kk + 1]) / den + + res = prefac * sumres * (-1) ** int(aa + bb + cc + dd) + return res + + +def wigner_6j(j_1, j_2, j_3, j_4, j_5, j_6, prec=None): + r""" + Calculate the Wigner 6j symbol `\operatorname{Wigner6j}(j_1,j_2,j_3,j_4,j_5,j_6)`. + + Parameters + ========== + + j_1, ..., j_6 : + Integer or half integer. + prec : + Precision, default: ``None``. Providing a precision can + drastically speed up the calculation. + + Returns + ======= + + Rational number times the square root of a rational number + (if ``prec=None``), or real number if a precision is given. + + Examples + ======== + + >>> from sympy.physics.wigner import wigner_6j + >>> wigner_6j(3,3,3,3,3,3) + -1/14 + >>> wigner_6j(5,5,5,5,5,5) + 1/52 + + It is an error to have arguments that are not integer or half + integer values or do not fulfill the triangle relation:: + + sage: wigner_6j(2.5,2.5,2.5,2.5,2.5,2.5) + Traceback (most recent call last): + ... + ValueError: j values must be integer or half integer and fulfill the triangle relation + sage: wigner_6j(0.5,0.5,1.1,0.5,0.5,1.1) + Traceback (most recent call last): + ... + ValueError: j values must be integer or half integer and fulfill the triangle relation + + Notes + ===== + + The Wigner 6j symbol is related to the Racah symbol but exhibits + more symmetries as detailed below. + + .. math:: + + \operatorname{Wigner6j}(j_1,j_2,j_3,j_4,j_5,j_6) + =(-1)^{j_1+j_2+j_4+j_5} W(j_1,j_2,j_5,j_4,j_3,j_6) + + The Wigner 6j symbol obeys the following symmetry rules: + + - Wigner 6j symbols are left invariant under any permutation of + the columns: + + .. math:: + + \begin{aligned} + \operatorname{Wigner6j}(j_1,j_2,j_3,j_4,j_5,j_6) + &=\operatorname{Wigner6j}(j_3,j_1,j_2,j_6,j_4,j_5) \\ + &=\operatorname{Wigner6j}(j_2,j_3,j_1,j_5,j_6,j_4) \\ + &=\operatorname{Wigner6j}(j_3,j_2,j_1,j_6,j_5,j_4) \\ + &=\operatorname{Wigner6j}(j_1,j_3,j_2,j_4,j_6,j_5) \\ + &=\operatorname{Wigner6j}(j_2,j_1,j_3,j_5,j_4,j_6) + \end{aligned} + + - They are invariant under the exchange of the upper and lower + arguments in each of any two columns, i.e. + + .. math:: + + \begin{aligned} + \operatorname{Wigner6j}(j_1,j_2,j_3,j_4,j_5,j_6) + &=\operatorname{Wigner6j}(j_1,j_5,j_6,j_4,j_2,j_3)\\ + &=\operatorname{Wigner6j}(j_4,j_2,j_6,j_1,j_5,j_3)\\ + &=\operatorname{Wigner6j}(j_4,j_5,j_3,j_1,j_2,j_6) + \end{aligned} + + - additional 6 symmetries [Regge59]_ giving rise to 144 symmetries + in total + + - only non-zero if any triple of `j`'s fulfill a triangle relation + + Algorithm + ========= + + This function uses the algorithm of [Edmonds74]_ to calculate the + value of the 6j symbol exactly. Note that the formula contains + alternating sums over large factorials and is therefore unsuitable + for finite precision arithmetic and only useful for a computer + algebra system [Rasch03]_. + + """ + j_1, j_2, j_3, j_4, j_5, j_6 = map(sympify, \ + [j_1, j_2, j_3, j_4, j_5, j_6]) + res = (-1) ** int(j_1 + j_2 + j_4 + j_5) * \ + racah(j_1, j_2, j_5, j_4, j_3, j_6, prec) + return res + + +def wigner_9j(j_1, j_2, j_3, j_4, j_5, j_6, j_7, j_8, j_9, prec=None): + r""" + Calculate the Wigner 9j symbol + `\operatorname{Wigner9j}(j_1,j_2,j_3,j_4,j_5,j_6,j_7,j_8,j_9)`. + + Parameters + ========== + + j_1, ..., j_9 : + Integer or half integer. + prec : precision, default + ``None``. Providing a precision can + drastically speed up the calculation. + + Returns + ======= + + Rational number times the square root of a rational number + (if ``prec=None``), or real number if a precision is given. + + Examples + ======== + + >>> from sympy.physics.wigner import wigner_9j + >>> wigner_9j(1,1,1, 1,1,1, 1,1,0, prec=64) + 0.05555555555555555555555555555555555555555555555555555555555555555 + + >>> wigner_9j(1/2,1/2,0, 1/2,3/2,1, 0,1,1, prec=64) + 0.1666666666666666666666666666666666666666666666666666666666666667 + + It is an error to have arguments that are not integer or half + integer values or do not fulfill the triangle relation:: + + sage: wigner_9j(0.5,0.5,0.5, 0.5,0.5,0.5, 0.5,0.5,0.5,prec=64) + Traceback (most recent call last): + ... + ValueError: j values must be integer or half integer and fulfill the triangle relation + sage: wigner_9j(1,1,1, 0.5,1,1.5, 0.5,1,2.5,prec=64) + Traceback (most recent call last): + ... + ValueError: j values must be integer or half integer and fulfill the triangle relation + + Algorithm + ========= + + This function uses the algorithm of [Edmonds74]_ to calculate the + value of the 3j symbol exactly. Note that the formula contains + alternating sums over large factorials and is therefore unsuitable + for finite precision arithmetic and only useful for a computer + algebra system [Rasch03]_. + """ + j_1, j_2, j_3, j_4, j_5, j_6, j_7, j_8, j_9 = map(sympify, \ + [j_1, j_2, j_3, j_4, j_5, j_6, j_7, j_8, j_9]) + imax = int(min(j_1 + j_9, j_2 + j_6, j_4 + j_8) * 2) + imin = imax % 2 + sumres = 0 + for kk in range(imin, int(imax) + 1, 2): + sumres = sumres + (kk + 1) * \ + racah(j_1, j_2, j_9, j_6, j_3, kk / 2, prec) * \ + racah(j_4, j_6, j_8, j_2, j_5, kk / 2, prec) * \ + racah(j_1, j_4, j_9, j_8, j_7, kk / 2, prec) + return sumres + + +def gaunt(l_1, l_2, l_3, m_1, m_2, m_3, prec=None): + r""" + Calculate the Gaunt coefficient. + + Explanation + =========== + + The Gaunt coefficient is defined as the integral over three + spherical harmonics: + + .. math:: + + \begin{aligned} + \operatorname{Gaunt}(l_1,l_2,l_3,m_1,m_2,m_3) + &=\int Y_{l_1,m_1}(\Omega) + Y_{l_2,m_2}(\Omega) Y_{l_3,m_3}(\Omega) \,d\Omega \\ + &=\sqrt{\frac{(2l_1+1)(2l_2+1)(2l_3+1)}{4\pi}} + \operatorname{Wigner3j}(l_1,l_2,l_3,0,0,0) + \operatorname{Wigner3j}(l_1,l_2,l_3,m_1,m_2,m_3) + \end{aligned} + + Parameters + ========== + + l_1, l_2, l_3, m_1, m_2, m_3 : + Integer. + prec - precision, default: ``None``. + Providing a precision can + drastically speed up the calculation. + + Returns + ======= + + Rational number times the square root of a rational number + (if ``prec=None``), or real number if a precision is given. + + Examples + ======== + + >>> from sympy.physics.wigner import gaunt + >>> gaunt(1,0,1,1,0,-1) + -1/(2*sqrt(pi)) + >>> gaunt(1000,1000,1200,9,3,-12).n(64) + 0.006895004219221134484332976156744208248842039317638217822322799675 + + It is an error to use non-integer values for `l` and `m`:: + + sage: gaunt(1.2,0,1.2,0,0,0) + Traceback (most recent call last): + ... + ValueError: l values must be integer + sage: gaunt(1,0,1,1.1,0,-1.1) + Traceback (most recent call last): + ... + ValueError: m values must be integer + + Notes + ===== + + The Gaunt coefficient obeys the following symmetry rules: + + - invariant under any permutation of the columns + + .. math:: + \begin{aligned} + Y(l_1,l_2,l_3,m_1,m_2,m_3) + &=Y(l_3,l_1,l_2,m_3,m_1,m_2) \\ + &=Y(l_2,l_3,l_1,m_2,m_3,m_1) \\ + &=Y(l_3,l_2,l_1,m_3,m_2,m_1) \\ + &=Y(l_1,l_3,l_2,m_1,m_3,m_2) \\ + &=Y(l_2,l_1,l_3,m_2,m_1,m_3) + \end{aligned} + + - invariant under space inflection, i.e. + + .. math:: + Y(l_1,l_2,l_3,m_1,m_2,m_3) + =Y(l_1,l_2,l_3,-m_1,-m_2,-m_3) + + - symmetric with respect to the 72 Regge symmetries as inherited + for the `3j` symbols [Regge58]_ + + - zero for `l_1`, `l_2`, `l_3` not fulfilling triangle relation + + - zero for violating any one of the conditions: `l_1 \ge |m_1|`, + `l_2 \ge |m_2|`, `l_3 \ge |m_3|` + + - non-zero only for an even sum of the `l_i`, i.e. + `L = l_1 + l_2 + l_3 = 2n` for `n` in `\mathbb{N}` + + Algorithms + ========== + + This function uses the algorithm of [Liberatodebrito82]_ to + calculate the value of the Gaunt coefficient exactly. Note that + the formula contains alternating sums over large factorials and is + therefore unsuitable for finite precision arithmetic and only + useful for a computer algebra system [Rasch03]_. + + Authors + ======= + + Jens Rasch (2009-03-24): initial version for Sage. + """ + l_1, l_2, l_3, m_1, m_2, m_3 = [ + as_int(i) for i in (l_1, l_2, l_3, m_1, m_2, m_3)] + + if l_1 + l_2 - l_3 < 0: + return S.Zero + if l_1 - l_2 + l_3 < 0: + return S.Zero + if -l_1 + l_2 + l_3 < 0: + return S.Zero + if (m_1 + m_2 + m_3) != 0: + return S.Zero + if (abs(m_1) > l_1) or (abs(m_2) > l_2) or (abs(m_3) > l_3): + return S.Zero + bigL, remL = divmod(l_1 + l_2 + l_3, 2) + if remL % 2: + return S.Zero + + imin = max(-l_3 + l_1 + m_2, -l_3 + l_2 - m_1, 0) + imax = min(l_2 + m_2, l_1 - m_1, l_1 + l_2 - l_3) + + _calc_factlist(max(l_1 + l_2 + l_3 + 1, imax + 1)) + + ressqrt = sqrt((2 * l_1 + 1) * (2 * l_2 + 1) * (2 * l_3 + 1) * \ + _Factlist[l_1 - m_1] * _Factlist[l_1 + m_1] * _Factlist[l_2 - m_2] * \ + _Factlist[l_2 + m_2] * _Factlist[l_3 - m_3] * _Factlist[l_3 + m_3] / \ + (4*pi)) + + prefac = Integer(_Factlist[bigL] * _Factlist[l_2 - l_1 + l_3] * + _Factlist[l_1 - l_2 + l_3] * _Factlist[l_1 + l_2 - l_3])/ \ + _Factlist[2 * bigL + 1]/ \ + (_Factlist[bigL - l_1] * + _Factlist[bigL - l_2] * _Factlist[bigL - l_3]) + + sumres = 0 + for ii in range(int(imin), int(imax) + 1): + den = _Factlist[ii] * _Factlist[ii + l_3 - l_1 - m_2] * \ + _Factlist[l_2 + m_2 - ii] * _Factlist[l_1 - ii - m_1] * \ + _Factlist[ii + l_3 - l_2 + m_1] * _Factlist[l_1 + l_2 - l_3 - ii] + sumres = sumres + Integer((-1) ** ii) / den + + res = ressqrt * prefac * sumres * Integer((-1) ** (bigL + l_3 + m_1 - m_2)) + if prec is not None: + res = res.n(prec) + return res + + +def real_gaunt(l_1, l_2, l_3, mu_1, mu_2, mu_3, prec=None): + r""" + Calculate the real Gaunt coefficient. + + Explanation + =========== + + The real Gaunt coefficient is defined as the integral over three + real spherical harmonics: + + .. math:: + \begin{aligned} + \operatorname{RealGaunt}(l_1,l_2,l_3,\mu_1,\mu_2,\mu_3) + &=\int Z^{\mu_1}_{l_1}(\Omega) + Z^{\mu_2}_{l_2}(\Omega) Z^{\mu_3}_{l_3}(\Omega) \,d\Omega \\ + \end{aligned} + + Alternatively, it can be defined in terms of the standard Gaunt + coefficient by relating the real spherical harmonics to the standard + spherical harmonics via a unitary transformation `U`, i.e. + `Z^{\mu}_{l}(\Omega)=\sum_{m'}U^{\mu}_{m'}Y^{m'}_{l}(\Omega)` [Homeier96]_. + The real Gaunt coefficient is then defined as + + .. math:: + \begin{aligned} + \operatorname{RealGaunt}(l_1,l_2,l_3,\mu_1,\mu_2,\mu_3) + &=\int Z^{\mu_1}_{l_1}(\Omega) + Z^{\mu_2}_{l_2}(\Omega) Z^{\mu_3}_{l_3}(\Omega) \,d\Omega \\ + &=\sum_{m'_1 m'_2 m'_3} U^{\mu_1}_{m'_1}U^{\mu_2}_{m'_2}U^{\mu_3}_{m'_3} + \operatorname{Gaunt}(l_1,l_2,l_3,m'_1,m'_2,m'_3) + \end{aligned} + + The unitary matrix `U` has components + + .. math:: + \begin{aligned} + U^\mu_{m} = \delta_{|\mu||m|}*(\delta_{m0}\delta_{\mu 0} + \frac{1}{\sqrt{2}}\big[\Theta(\mu)\big(\delta_{m\mu}+(-1)^{m}\delta_{m-\mu}\big) + +i \Theta(-\mu)\big((-1)^{m}\delta_{m\mu}-\delta_{m-\mu}\big)\big]) + \end{aligned} + + + where `\delta_{ij}` is the Kronecker delta symbol and `\Theta` is a step + function defined as + + .. math:: + \begin{aligned} + \Theta(x) = \begin{cases} 1 \,\text{for}\, x > 0 \\ 0 \,\text{for}\, x \leq 0 \end{cases} + \end{aligned} + + Parameters + ========== + + l_1, l_2, l_3, mu_1, mu_2, mu_3 : + Integer degree and order + + prec - precision, default: ``None``. + Providing a precision can + drastically speed up the calculation. + + Returns + ======= + + Rational number times the square root of a rational number. + + Examples + ======== + >>> from sympy.physics.wigner import real_gaunt + >>> real_gaunt(1,1,2,-1,1,-2) + sqrt(15)/(10*sqrt(pi)) + >>> real_gaunt(10,10,20,-9,-9,0,prec=64) + -0.00002480019791932209313156167176797577821140084216297395518482071448 + + It is an error to use non-integer values for `l` and `\mu`:: + real_gaunt(2.8,0.5,1.3,0,0,0) + Traceback (most recent call last): + ... + ValueError: l values must be integer + + real_gaunt(2,2,4,0.7,1,-3.4) + Traceback (most recent call last): + ... + ValueError: mu values must be integer + + Notes + ===== + + The real Gaunt coefficient inherits from the standard Gaunt coefficient, + the invariance under any permutation of the pairs `(l_i, \mu_i)` and the + requirement that the sum of the `l_i` be even to yield a non-zero value. + It also obeys the following symmetry rules: + + - zero for `l_1`, `l_2`, `l_3` not fulfilling the condition + `l_1 \in \{l_{\text{max}}, l_{\text{max}}-2, \ldots, l_{\text{min}}\}`, + where `l_{\text{max}} = l_2+l_3`, + + .. math:: + \begin{aligned} + l_{\text{min}} = \begin{cases} \kappa(l_2, l_3, \mu_2, \mu_3) & \text{if}\, + \kappa(l_2, l_3, \mu_2, \mu_3) + l_{\text{max}}\, \text{is even} \\ + \kappa(l_2, l_3, \mu_2, \mu_3)+1 & \text{if}\, \kappa(l_2, l_3, \mu_2, \mu_3) + + l_{\text{max}}\, \text{is odd}\end{cases} + \end{aligned} + + and `\kappa(l_2, l_3, \mu_2, \mu_3) = \max{\big(|l_2-l_3|, \min{\big(|\mu_2+\mu_3|, + |\mu_2-\mu_3|\big)}\big)}` + + - zero for an odd number of negative `\mu_i` + + Algorithms + ========== + + This function uses the algorithms of [Homeier96]_ and [Rasch03]_ to + calculate the value of the real Gaunt coefficient exactly. Note that + the formula used in [Rasch03]_ contains alternating sums over large + factorials and is therefore unsuitable for finite precision arithmetic + and only useful for a computer algebra system [Rasch03]_. However, this + function can in principle use any algorithm that computes the Gaunt + coefficient, so it is suitable for finite precision arithmetic in so far + as the algorithm which computes the Gaunt coefficient is. + """ + l_1, l_2, l_3, mu_1, mu_2, mu_3 = [ + as_int(i) for i in (l_1, l_2, l_3, mu_1, mu_2, mu_3)] + + # check for quick exits + if sum(1 for i in (mu_1, mu_2, mu_3) if i < 0) % 2: + return S.Zero # odd number of negative m + if (l_1 + l_2 + l_3) % 2: + return S.Zero # sum of l is odd + lmax = l_2 + l_3 + lmin = max(abs(l_2 - l_3), min(abs(mu_2 + mu_3), abs(mu_2 - mu_3))) + if (lmin + lmax) % 2: + lmin += 1 + if lmin not in range(lmax, lmin - 2, -2): + return S.Zero + + kron_del = lambda i, j: 1 if i == j else 0 + s = lambda e: -1 if e % 2 else 1 # (-1)**e to give +/-1, avoiding float when e<0 + + t = lambda x: 1 if x > 0 else 0 + A = lambda mu, m: t(-mu) * (s(m) * kron_del(m, mu) - kron_del(m, -mu)) + B = lambda mu, m: t(mu) * (kron_del(m, mu) + s(m) * kron_del(m, -mu)) + U = lambda mu, m: kron_del(abs(mu), abs(m)) * (kron_del(mu, 0) * kron_del(m, 0) + (B(mu, m) + I * A(mu, m))/sqrt(2)) + + ugnt = 0 + for m1 in range(-l_1, l_1+1): + U1 = U(mu_1, m1) + for m2 in range(-l_2, l_2+1): + U2 = U(mu_2, m2) + U3 = U(mu_3,-m1-m2) + ugnt = ugnt + re(U1*U2*U3)*gaunt(l_1, l_2, l_3, m1, m2, -m1 - m2, prec=prec) + + return ugnt + + +class Wigner3j(Function): + + def doit(self, **hints): + if all(obj.is_number for obj in self.args): + return wigner_3j(*self.args) + else: + return self + +def dot_rot_grad_Ynm(j, p, l, m, theta, phi): + r""" + Returns dot product of rotational gradients of spherical harmonics. + + Explanation + =========== + + This function returns the right hand side of the following expression: + + .. math :: + \vec{R}Y{_j^{p}} \cdot \vec{R}Y{_l^{m}} = (-1)^{m+p} + \sum\limits_{k=|l-j|}^{l+j}Y{_k^{m+p}} * \alpha_{l,m,j,p,k} * + \frac{1}{2} (k^2-j^2-l^2+k-j-l) + + + Arguments + ========= + + j, p, l, m .... indices in spherical harmonics (expressions or integers) + theta, phi .... angle arguments in spherical harmonics + + Example + ======= + + >>> from sympy import symbols + >>> from sympy.physics.wigner import dot_rot_grad_Ynm + >>> theta, phi = symbols("theta phi") + >>> dot_rot_grad_Ynm(3, 2, 2, 0, theta, phi).doit() + 3*sqrt(55)*Ynm(5, 2, theta, phi)/(11*sqrt(pi)) + + """ + j = sympify(j) + p = sympify(p) + l = sympify(l) + m = sympify(m) + theta = sympify(theta) + phi = sympify(phi) + k = Dummy("k") + + def alpha(l,m,j,p,k): + return sqrt((2*l+1)*(2*j+1)*(2*k+1)/(4*pi)) * \ + Wigner3j(j, l, k, S.Zero, S.Zero, S.Zero) * \ + Wigner3j(j, l, k, p, m, -m-p) + + return (S.NegativeOne)**(m+p) * Sum(Ynm(k, m+p, theta, phi) * alpha(l,m,j,p,k) / 2 \ + *(k**2-j**2-l**2+k-j-l), (k, abs(l-j), l+j)) + + +def wigner_d_small(J, beta): + """Return the small Wigner d matrix for angular momentum J. + + Explanation + =========== + + J : An integer, half-integer, or SymPy symbol for the total angular + momentum of the angular momentum space being rotated. + beta : A real number representing the Euler angle of rotation about + the so-called line of nodes. See [Edmonds74]_. + + Returns + ======= + + A matrix representing the corresponding Euler angle rotation( in the basis + of eigenvectors of `J_z`). + + .. math :: + \\mathcal{d}_{\\beta} = \\exp\\big( \\frac{i\\beta}{\\hbar} J_y\\big) + + such that + + .. math :: + d^{(J)}_{m',m}(\\beta) = \\mathtt{wigner\\_d\\_small(J,beta)[J-mprime,J-m]} + + The components are calculated using the general form [Edmonds74]_, + equation 4.1.15. + + Examples + ======== + + >>> from sympy import Integer, symbols, pi, pprint + >>> from sympy.physics.wigner import wigner_d_small + >>> half = 1/Integer(2) + >>> beta = symbols("beta", real=True) + >>> pprint(wigner_d_small(half, beta), use_unicode=True) + ⎡ ⎛β⎞ ⎛β⎞⎤ + ⎢cos⎜─⎟ sin⎜─⎟⎥ + ⎢ ⎝2⎠ ⎝2⎠⎥ + ⎢ ⎥ + ⎢ ⎛β⎞ ⎛β⎞⎥ + ⎢-sin⎜─⎟ cos⎜─⎟⎥ + ⎣ ⎝2⎠ ⎝2⎠⎦ + + >>> pprint(wigner_d_small(2*half, beta), use_unicode=True) + ⎡ 2⎛β⎞ ⎛β⎞ ⎛β⎞ 2⎛β⎞ ⎤ + ⎢ cos ⎜─⎟ √2⋅sin⎜─⎟⋅cos⎜─⎟ sin ⎜─⎟ ⎥ + ⎢ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎥ + ⎢ ⎥ + ⎢ ⎛β⎞ ⎛β⎞ 2⎛β⎞ 2⎛β⎞ ⎛β⎞ ⎛β⎞⎥ + ⎢-√2⋅sin⎜─⎟⋅cos⎜─⎟ - sin ⎜─⎟ + cos ⎜─⎟ √2⋅sin⎜─⎟⋅cos⎜─⎟⎥ + ⎢ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎝2⎠⎥ + ⎢ ⎥ + ⎢ 2⎛β⎞ ⎛β⎞ ⎛β⎞ 2⎛β⎞ ⎥ + ⎢ sin ⎜─⎟ -√2⋅sin⎜─⎟⋅cos⎜─⎟ cos ⎜─⎟ ⎥ + ⎣ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎝2⎠ ⎦ + + From table 4 in [Edmonds74]_ + + >>> pprint(wigner_d_small(half, beta).subs({beta:pi/2}), use_unicode=True) + ⎡ √2 √2⎤ + ⎢ ── ──⎥ + ⎢ 2 2 ⎥ + ⎢ ⎥ + ⎢-√2 √2⎥ + ⎢──── ──⎥ + ⎣ 2 2 ⎦ + + >>> pprint(wigner_d_small(2*half, beta).subs({beta:pi/2}), + ... use_unicode=True) + ⎡ √2 ⎤ + ⎢1/2 ── 1/2⎥ + ⎢ 2 ⎥ + ⎢ ⎥ + ⎢-√2 √2 ⎥ + ⎢──── 0 ── ⎥ + ⎢ 2 2 ⎥ + ⎢ ⎥ + ⎢ -√2 ⎥ + ⎢1/2 ──── 1/2⎥ + ⎣ 2 ⎦ + + >>> pprint(wigner_d_small(3*half, beta).subs({beta:pi/2}), + ... use_unicode=True) + ⎡ √2 √6 √6 √2⎤ + ⎢ ── ── ── ──⎥ + ⎢ 4 4 4 4 ⎥ + ⎢ ⎥ + ⎢-√6 -√2 √2 √6⎥ + ⎢──── ──── ── ──⎥ + ⎢ 4 4 4 4 ⎥ + ⎢ ⎥ + ⎢ √6 -√2 -√2 √6⎥ + ⎢ ── ──── ──── ──⎥ + ⎢ 4 4 4 4 ⎥ + ⎢ ⎥ + ⎢-√2 √6 -√6 √2⎥ + ⎢──── ── ──── ──⎥ + ⎣ 4 4 4 4 ⎦ + + >>> pprint(wigner_d_small(4*half, beta).subs({beta:pi/2}), + ... use_unicode=True) + ⎡ √6 ⎤ + ⎢1/4 1/2 ── 1/2 1/4⎥ + ⎢ 4 ⎥ + ⎢ ⎥ + ⎢-1/2 -1/2 0 1/2 1/2⎥ + ⎢ ⎥ + ⎢ √6 √6 ⎥ + ⎢ ── 0 -1/2 0 ── ⎥ + ⎢ 4 4 ⎥ + ⎢ ⎥ + ⎢-1/2 1/2 0 -1/2 1/2⎥ + ⎢ ⎥ + ⎢ √6 ⎥ + ⎢1/4 -1/2 ── -1/2 1/4⎥ + ⎣ 4 ⎦ + + """ + M = [J-i for i in range(2*J+1)] + d = zeros(2*J+1) + + # Mi corresponds to Edmonds' $m'$, and Mj to $m$. + for i, Mi in enumerate(M): + for j, Mj in enumerate(M): + + # We get the maximum and minimum value of sigma. + sigmamax = min([J-Mi, J-Mj]) + sigmamin = max([0, -Mi-Mj]) + + dij = sqrt(factorial(J+Mi)*factorial(J-Mi) / + factorial(J+Mj)/factorial(J-Mj)) + terms = [(-1)**(J-Mi-s) * + binomial(J+Mj, J-Mi-s) * + binomial(J-Mj, s) * + cos(beta/2)**(2*s+Mi+Mj) * + sin(beta/2)**(2*J-2*s-Mj-Mi) + for s in range(sigmamin, sigmamax+1)] + + d[i, j] = dij*Add(*terms) + + return ImmutableMatrix(d) + + +def wigner_d(J, alpha, beta, gamma): + """Return the Wigner D matrix for angular momentum J. + + Explanation + =========== + + J : + An integer, half-integer, or SymPy symbol for the total angular + momentum of the angular momentum space being rotated. + alpha, beta, gamma - Real numbers representing the Euler. + Angles of rotation about the so-called figure axis, line of nodes, + and vertical. See [Edmonds74]_, however note that the symbols alpha + and gamma are swapped in this implementation. + + Returns + ======= + + A matrix representing the corresponding Euler angle rotation (in the basis + of eigenvectors of `J_z`). + + .. math :: + \\mathcal{D}_{\\alpha \\beta \\gamma} = + \\exp\\big( \\frac{i\\alpha}{\\hbar} J_z\\big) + \\exp\\big( \\frac{i\\beta}{\\hbar} J_y\\big) + \\exp\\big( \\frac{i\\gamma}{\\hbar} J_z\\big) + + such that + + .. math :: + \\mathcal{D}^{(J)}_{m',m}(\\alpha, \\beta, \\gamma) = + \\mathtt{wigner_d(J, alpha, beta, gamma)[J-mprime,J-m]} + + The components are calculated using the general form [Edmonds74]_, + equation 4.1.12, however note that the angles alpha and gamma are swapped + in this implementation. + + Examples + ======== + + The simplest possible example: + + >>> from sympy.physics.wigner import wigner_d + >>> from sympy import Integer, symbols, pprint + >>> half = 1/Integer(2) + >>> alpha, beta, gamma = symbols("alpha, beta, gamma", real=True) + >>> pprint(wigner_d(half, alpha, beta, gamma), use_unicode=True) + ⎡ ⅈ⋅α ⅈ⋅γ ⅈ⋅α -ⅈ⋅γ ⎤ + ⎢ ─── ─── ─── ───── ⎥ + ⎢ 2 2 ⎛β⎞ 2 2 ⎛β⎞ ⎥ + ⎢ ℯ ⋅ℯ ⋅cos⎜─⎟ ℯ ⋅ℯ ⋅sin⎜─⎟ ⎥ + ⎢ ⎝2⎠ ⎝2⎠ ⎥ + ⎢ ⎥ + ⎢ -ⅈ⋅α ⅈ⋅γ -ⅈ⋅α -ⅈ⋅γ ⎥ + ⎢ ───── ─── ───── ───── ⎥ + ⎢ 2 2 ⎛β⎞ 2 2 ⎛β⎞⎥ + ⎢-ℯ ⋅ℯ ⋅sin⎜─⎟ ℯ ⋅ℯ ⋅cos⎜─⎟⎥ + ⎣ ⎝2⎠ ⎝2⎠⎦ + + """ + d = wigner_d_small(J, beta) + M = [J-i for i in range(2*J+1)] + # Mi corresponds to Edmonds' $m'$, and Mj to $m$. + D = [[exp(I*Mi*alpha)*d[i, j]*exp(I*Mj*gamma) + for j, Mj in enumerate(M)] for i, Mi in enumerate(M)] + return ImmutableMatrix(D) diff --git a/phivenv/Lib/site-packages/sympy/plotting/__init__.py b/phivenv/Lib/site-packages/sympy/plotting/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..074bcf93b7375eb3dc96d16b5450b539074d8f7d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/__init__.py @@ -0,0 +1,22 @@ +from .plot import plot_backends +from .plot_implicit import plot_implicit +from .textplot import textplot +from .pygletplot import PygletPlot +from .plot import PlotGrid +from .plot import (plot, plot_parametric, plot3d, plot3d_parametric_surface, + plot3d_parametric_line, plot_contour) + +__all__ = [ + 'plot_backends', + + 'plot_implicit', + + 'textplot', + + 'PygletPlot', + + 'PlotGrid', + + 'plot', 'plot_parametric', 'plot3d', 'plot3d_parametric_surface', + 'plot3d_parametric_line', 'plot_contour' +] diff --git a/phivenv/Lib/site-packages/sympy/plotting/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0d222d3d5037e0512cb74bf5f7e3b6e9500373d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/__pycache__/experimental_lambdify.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/experimental_lambdify.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4fcc4974d3b47dedf276d5ff940649dde6377030 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/experimental_lambdify.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/__pycache__/plot.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/plot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0312b644d90d0e227e509ea5092490e8e6f1c95 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/plot.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/__pycache__/plot_implicit.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/plot_implicit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..438d65c50fdac7d725426444aa4906653c4e82e4 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/plot_implicit.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/__pycache__/plotgrid.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/plotgrid.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1794c88e0b957da4a2990ac067052646b6a3f638 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/plotgrid.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/__pycache__/series.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/series.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a820b56c6abcc4410eecd76087e3a1d637dc53c7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/series.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/__pycache__/textplot.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/textplot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..05f1a2fcc7252fe9e21f84438613c314adb1f731 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/textplot.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07199b10b138271518fa82690fbc5b1ee555c76d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/backends/__init__.py b/phivenv/Lib/site-packages/sympy/plotting/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/plotting/backends/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/backends/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..707b7a108a40c5f639a117b7f911164786f18e4f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/backends/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/backends/__pycache__/base_backend.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/backends/__pycache__/base_backend.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80570c9cb8510fa2117b5d441f1ffe5880a8ecaf Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/backends/__pycache__/base_backend.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/backends/base_backend.py b/phivenv/Lib/site-packages/sympy/plotting/backends/base_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..a43cfa18eb7aff90ddacd6cdb60dfb0dadcb0abf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/backends/base_backend.py @@ -0,0 +1,419 @@ +from sympy.plotting.series import BaseSeries, GenericDataSeries +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import is_sequence + + +__doctest_requires__ = { + ('Plot.append', 'Plot.extend'): ['matplotlib'], +} + + +# Global variable +# Set to False when running tests / doctests so that the plots don't show. +_show = True + +def unset_show(): + """ + Disable show(). For use in the tests. + """ + global _show + _show = False + + +def _deprecation_msg_m_a_r_f(attr): + sympy_deprecation_warning( + f"The `{attr}` property is deprecated. The `{attr}` keyword " + "argument should be passed to a plotting function, which generates " + "the appropriate data series. If needed, index the plot object to " + "retrieve a specific data series.", + deprecated_since_version="1.13", + active_deprecations_target="deprecated-markers-annotations-fill-rectangles", + stacklevel=4) + + +def _create_generic_data_series(**kwargs): + keywords = ["annotations", "markers", "fill", "rectangles"] + series = [] + for kw in keywords: + dictionaries = kwargs.pop(kw, []) + if dictionaries is None: + dictionaries = [] + if isinstance(dictionaries, dict): + dictionaries = [dictionaries] + for d in dictionaries: + args = d.pop("args", []) + series.append(GenericDataSeries(kw, *args, **d)) + return series + + +class Plot: + """Base class for all backends. A backend represents the plotting library, + which implements the necessary functionalities in order to use SymPy + plotting functions. + + For interactive work the function :func:`plot` is better suited. + + This class permits the plotting of SymPy expressions using numerous + backends (:external:mod:`matplotlib`, textplot, the old pyglet module for SymPy, Google + charts api, etc). + + The figure can contain an arbitrary number of plots of SymPy expressions, + lists of coordinates of points, etc. Plot has a private attribute _series that + contains all data series to be plotted (expressions for lines or surfaces, + lists of points, etc (all subclasses of BaseSeries)). Those data series are + instances of classes not imported by ``from sympy import *``. + + The customization of the figure is on two levels. Global options that + concern the figure as a whole (e.g. title, xlabel, scale, etc) and + per-data series options (e.g. name) and aesthetics (e.g. color, point shape, + line type, etc.). + + The difference between options and aesthetics is that an aesthetic can be + a function of the coordinates (or parameters in a parametric plot). The + supported values for an aesthetic are: + + - None (the backend uses default values) + - a constant + - a function of one variable (the first coordinate or parameter) + - a function of two variables (the first and second coordinate or parameters) + - a function of three variables (only in nonparametric 3D plots) + + Their implementation depends on the backend so they may not work in some + backends. + + If the plot is parametric and the arity of the aesthetic function permits + it the aesthetic is calculated over parameters and not over coordinates. + If the arity does not permit calculation over parameters the calculation is + done over coordinates. + + Only cartesian coordinates are supported for the moment, but you can use + the parametric plots to plot in polar, spherical and cylindrical + coordinates. + + The arguments for the constructor Plot must be subclasses of BaseSeries. + + Any global option can be specified as a keyword argument. + + The global options for a figure are: + + - title : str + - xlabel : str or Symbol + - ylabel : str or Symbol + - zlabel : str or Symbol + - legend : bool + - xscale : {'linear', 'log'} + - yscale : {'linear', 'log'} + - axis : bool + - axis_center : tuple of two floats or {'center', 'auto'} + - xlim : tuple of two floats + - ylim : tuple of two floats + - aspect_ratio : tuple of two floats or {'auto'} + - autoscale : bool + - margin : float in [0, 1] + - backend : {'default', 'matplotlib', 'text'} or a subclass of BaseBackend + - size : optional tuple of two floats, (width, height); default: None + + The per data series options and aesthetics are: + There are none in the base series. See below for options for subclasses. + + Some data series support additional aesthetics or options: + + :class:`~.LineOver1DRangeSeries`, :class:`~.Parametric2DLineSeries`, and + :class:`~.Parametric3DLineSeries` support the following: + + Aesthetics: + + - line_color : string, or float, or function, optional + Specifies the color for the plot, which depends on the backend being + used. + + For example, if ``MatplotlibBackend`` is being used, then + Matplotlib string colors are acceptable (``"red"``, ``"r"``, + ``"cyan"``, ``"c"``, ...). + Alternatively, we can use a float number, 0 < color < 1, wrapped in a + string (for example, ``line_color="0.5"``) to specify grayscale colors. + Alternatively, We can specify a function returning a single + float value: this will be used to apply a color-loop (for example, + ``line_color=lambda x: math.cos(x)``). + + Note that by setting line_color, it would be applied simultaneously + to all the series. + + Options: + + - label : str + - steps : bool + - integers_only : bool + + :class:`~.SurfaceOver2DRangeSeries` and :class:`~.ParametricSurfaceSeries` + support the following: + + Aesthetics: + + - surface_color : function which returns a float. + + Notes + ===== + + How the plotting module works: + + 1. Whenever a plotting function is called, the provided expressions are + processed and a list of instances of the + :class:`~sympy.plotting.series.BaseSeries` class is created, containing + the necessary information to plot the expressions + (e.g. the expression, ranges, series name, ...). Eventually, these + objects will generate the numerical data to be plotted. + 2. A subclass of :class:`~.Plot` class is instantiaed (referred to as + backend, from now on), which stores the list of series and the main + attributes of the plot (e.g. axis labels, title, ...). + The backend implements the logic to generate the actual figure with + some plotting library. + 3. When the ``show`` command is executed, series are processed one by one + to generate numerical data and add it to the figure. The backend is also + going to set the axis labels, title, ..., according to the values stored + in the Plot instance. + + The backend should check if it supports the data series that it is given + (e.g. :class:`TextBackend` supports only + :class:`~sympy.plotting.series.LineOver1DRangeSeries`). + + It is the backend responsibility to know how to use the class of data series + that it's given. Note that the current implementation of the ``*Series`` + classes is "matplotlib-centric": the numerical data returned by the + ``get_points`` and ``get_meshes`` methods is meant to be used directly by + Matplotlib. Therefore, the new backend will have to pre-process the + numerical data to make it compatible with the chosen plotting library. + Keep in mind that future SymPy versions may improve the ``*Series`` classes + in order to return numerical data "non-matplotlib-centric", hence if you code + a new backend you have the responsibility to check if its working on each + SymPy release. + + Please explore the :class:`MatplotlibBackend` source code to understand + how a backend should be coded. + + In order to be used by SymPy plotting functions, a backend must implement + the following methods: + + * show(self): used to loop over the data series, generate the numerical + data, plot it and set the axis labels, title, ... + * save(self, path): used to save the current plot to the specified file + path. + * close(self): used to close the current plot backend (note: some plotting + library does not support this functionality. In that case, just raise a + warning). + """ + + def __init__(self, *args, + title=None, xlabel=None, ylabel=None, zlabel=None, aspect_ratio='auto', + xlim=None, ylim=None, axis_center='auto', axis=True, + xscale='linear', yscale='linear', legend=False, autoscale=True, + margin=0, annotations=None, markers=None, rectangles=None, + fill=None, backend='default', size=None, **kwargs): + + # Options for the graph as a whole. + # The possible values for each option are described in the docstring of + # Plot. They are based purely on convention, no checking is done. + self.title = title + self.xlabel = xlabel + self.ylabel = ylabel + self.zlabel = zlabel + self.aspect_ratio = aspect_ratio + self.axis_center = axis_center + self.axis = axis + self.xscale = xscale + self.yscale = yscale + self.legend = legend + self.autoscale = autoscale + self.margin = margin + self._annotations = annotations + self._markers = markers + self._rectangles = rectangles + self._fill = fill + + # Contains the data objects to be plotted. The backend should be smart + # enough to iterate over this list. + self._series = [] + self._series.extend(args) + self._series.extend(_create_generic_data_series( + annotations=annotations, markers=markers, rectangles=rectangles, + fill=fill)) + + is_real = \ + lambda lim: all(getattr(i, 'is_real', True) for i in lim) + is_finite = \ + lambda lim: all(getattr(i, 'is_finite', True) for i in lim) + + # reduce code repetition + def check_and_set(t_name, t): + if t: + if not is_real(t): + raise ValueError( + "All numbers from {}={} must be real".format(t_name, t)) + if not is_finite(t): + raise ValueError( + "All numbers from {}={} must be finite".format(t_name, t)) + setattr(self, t_name, (float(t[0]), float(t[1]))) + + self.xlim = None + check_and_set("xlim", xlim) + self.ylim = None + check_and_set("ylim", ylim) + self.size = None + check_and_set("size", size) + + @property + def _backend(self): + return self + + @property + def backend(self): + return type(self) + + def __str__(self): + series_strs = [('[%d]: ' % i) + str(s) + for i, s in enumerate(self._series)] + return 'Plot object containing:\n' + '\n'.join(series_strs) + + def __getitem__(self, index): + return self._series[index] + + def __setitem__(self, index, *args): + if len(args) == 1 and isinstance(args[0], BaseSeries): + self._series[index] = args + + def __delitem__(self, index): + del self._series[index] + + def append(self, arg): + """Adds an element from a plot's series to an existing plot. + + Examples + ======== + + Consider two ``Plot`` objects, ``p1`` and ``p2``. To add the + second plot's first series object to the first, use the + ``append`` method, like so: + + .. plot:: + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.plotting import plot + >>> x = symbols('x') + >>> p1 = plot(x*x, show=False) + >>> p2 = plot(x, show=False) + >>> p1.append(p2[0]) + >>> p1 + Plot object containing: + [0]: cartesian line: x**2 for x over (-10.0, 10.0) + [1]: cartesian line: x for x over (-10.0, 10.0) + >>> p1.show() + + See Also + ======== + + extend + + """ + if isinstance(arg, BaseSeries): + self._series.append(arg) + else: + raise TypeError('Must specify element of plot to append.') + + def extend(self, arg): + """Adds all series from another plot. + + Examples + ======== + + Consider two ``Plot`` objects, ``p1`` and ``p2``. To add the + second plot to the first, use the ``extend`` method, like so: + + .. plot:: + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.plotting import plot + >>> x = symbols('x') + >>> p1 = plot(x**2, show=False) + >>> p2 = plot(x, -x, show=False) + >>> p1.extend(p2) + >>> p1 + Plot object containing: + [0]: cartesian line: x**2 for x over (-10.0, 10.0) + [1]: cartesian line: x for x over (-10.0, 10.0) + [2]: cartesian line: -x for x over (-10.0, 10.0) + >>> p1.show() + + """ + if isinstance(arg, Plot): + self._series.extend(arg._series) + elif is_sequence(arg): + self._series.extend(arg) + else: + raise TypeError('Expecting Plot or sequence of BaseSeries') + + def show(self): + raise NotImplementedError + + def save(self, path): + raise NotImplementedError + + def close(self): + raise NotImplementedError + + # deprecations + + @property + def markers(self): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("markers") + return self._markers + + @markers.setter + def markers(self, v): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("markers") + self._series.extend(_create_generic_data_series(markers=v)) + self._markers = v + + @property + def annotations(self): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("annotations") + return self._annotations + + @annotations.setter + def annotations(self, v): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("annotations") + self._series.extend(_create_generic_data_series(annotations=v)) + self._annotations = v + + @property + def rectangles(self): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("rectangles") + return self._rectangles + + @rectangles.setter + def rectangles(self, v): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("rectangles") + self._series.extend(_create_generic_data_series(rectangles=v)) + self._rectangles = v + + @property + def fill(self): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("fill") + return self._fill + + @fill.setter + def fill(self, v): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("fill") + self._series.extend(_create_generic_data_series(fill=v)) + self._fill = v diff --git a/phivenv/Lib/site-packages/sympy/plotting/backends/matplotlibbackend/__init__.py b/phivenv/Lib/site-packages/sympy/plotting/backends/matplotlibbackend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8623940dadb9272730fdeccc1668374781c2e5cf --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/backends/matplotlibbackend/__init__.py @@ -0,0 +1,5 @@ +from sympy.plotting.backends.matplotlibbackend.matplotlib import ( + MatplotlibBackend, _matplotlib_list +) + +__all__ = ["MatplotlibBackend", "_matplotlib_list"] diff --git a/phivenv/Lib/site-packages/sympy/plotting/backends/matplotlibbackend/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/backends/matplotlibbackend/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b162f28b0dd8f558fff2df53277f0ad9bd7244e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/backends/matplotlibbackend/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/backends/matplotlibbackend/__pycache__/matplotlib.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/backends/matplotlibbackend/__pycache__/matplotlib.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02bea903e9b8e4ab6383202f4e488bbd1be7e6d6 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/backends/matplotlibbackend/__pycache__/matplotlib.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/backends/matplotlibbackend/matplotlib.py b/phivenv/Lib/site-packages/sympy/plotting/backends/matplotlibbackend/matplotlib.py new file mode 100644 index 0000000000000000000000000000000000000000..f598a10a7cd17d40e18d1438e8c6bb174071d0a6 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/backends/matplotlibbackend/matplotlib.py @@ -0,0 +1,318 @@ +from collections.abc import Callable +from sympy.core.basic import Basic +from sympy.external import import_module +import sympy.plotting.backends.base_backend as base_backend +from sympy.printing.latex import latex + + +# N.B. +# When changing the minimum module version for matplotlib, please change +# the same in the `SymPyDocTestFinder`` in `sympy/testing/runtests.py` + + +def _str_or_latex(label): + if isinstance(label, Basic): + return latex(label, mode='inline') + return str(label) + + +def _matplotlib_list(interval_list): + """ + Returns lists for matplotlib ``fill`` command from a list of bounding + rectangular intervals + """ + xlist = [] + ylist = [] + if len(interval_list): + for intervals in interval_list: + intervalx = intervals[0] + intervaly = intervals[1] + xlist.extend([intervalx.start, intervalx.start, + intervalx.end, intervalx.end, None]) + ylist.extend([intervaly.start, intervaly.end, + intervaly.end, intervaly.start, None]) + else: + #XXX Ugly hack. Matplotlib does not accept empty lists for ``fill`` + xlist.extend((None, None, None, None)) + ylist.extend((None, None, None, None)) + return xlist, ylist + + +# Don't have to check for the success of importing matplotlib in each case; +# we will only be using this backend if we can successfully import matploblib +class MatplotlibBackend(base_backend.Plot): + """ This class implements the functionalities to use Matplotlib with SymPy + plotting functions. + """ + + def __init__(self, *series, **kwargs): + super().__init__(*series, **kwargs) + self.matplotlib = import_module('matplotlib', + import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']}, + min_module_version='1.1.0', catch=(RuntimeError,)) + self.plt = self.matplotlib.pyplot + self.cm = self.matplotlib.cm + self.LineCollection = self.matplotlib.collections.LineCollection + self.aspect = kwargs.get('aspect_ratio', 'auto') + if self.aspect != 'auto': + self.aspect = float(self.aspect[1]) / self.aspect[0] + # PlotGrid can provide its figure and axes to be populated with + # the data from the series. + self._plotgrid_fig = kwargs.pop("fig", None) + self._plotgrid_ax = kwargs.pop("ax", None) + + def _create_figure(self): + def set_spines(ax): + ax.spines['left'].set_position('zero') + ax.spines['right'].set_color('none') + ax.spines['bottom'].set_position('zero') + ax.spines['top'].set_color('none') + ax.xaxis.set_ticks_position('bottom') + ax.yaxis.set_ticks_position('left') + + if self._plotgrid_fig is not None: + self.fig = self._plotgrid_fig + self.ax = self._plotgrid_ax + if not any(s.is_3D for s in self._series): + set_spines(self.ax) + else: + self.fig = self.plt.figure(figsize=self.size) + if any(s.is_3D for s in self._series): + self.ax = self.fig.add_subplot(1, 1, 1, projection="3d") + else: + self.ax = self.fig.add_subplot(1, 1, 1) + set_spines(self.ax) + + @staticmethod + def get_segments(x, y, z=None): + """ Convert two list of coordinates to a list of segments to be used + with Matplotlib's :external:class:`~matplotlib.collections.LineCollection`. + + Parameters + ========== + x : list + List of x-coordinates + + y : list + List of y-coordinates + + z : list + List of z-coordinates for a 3D line. + """ + np = import_module('numpy') + if z is not None: + dim = 3 + points = (x, y, z) + else: + dim = 2 + points = (x, y) + points = np.ma.array(points).T.reshape(-1, 1, dim) + return np.ma.concatenate([points[:-1], points[1:]], axis=1) + + def _process_series(self, series, ax): + np = import_module('numpy') + mpl_toolkits = import_module( + 'mpl_toolkits', import_kwargs={'fromlist': ['mplot3d']}) + + # XXX Workaround for matplotlib issue + # https://github.com/matplotlib/matplotlib/issues/17130 + xlims, ylims, zlims = [], [], [] + + for s in series: + # Create the collections + if s.is_2Dline: + if s.is_parametric: + x, y, param = s.get_data() + else: + x, y = s.get_data() + if (isinstance(s.line_color, (int, float)) or + callable(s.line_color)): + segments = self.get_segments(x, y) + collection = self.LineCollection(segments) + collection.set_array(s.get_color_array()) + ax.add_collection(collection) + else: + lbl = _str_or_latex(s.label) + line, = ax.plot(x, y, label=lbl, color=s.line_color) + elif s.is_contour: + ax.contour(*s.get_data()) + elif s.is_3Dline: + x, y, z, param = s.get_data() + if (isinstance(s.line_color, (int, float)) or + callable(s.line_color)): + art3d = mpl_toolkits.mplot3d.art3d + segments = self.get_segments(x, y, z) + collection = art3d.Line3DCollection(segments) + collection.set_array(s.get_color_array()) + ax.add_collection(collection) + else: + lbl = _str_or_latex(s.label) + ax.plot(x, y, z, label=lbl, color=s.line_color) + + xlims.append(s._xlim) + ylims.append(s._ylim) + zlims.append(s._zlim) + elif s.is_3Dsurface: + if s.is_parametric: + x, y, z, u, v = s.get_data() + else: + x, y, z = s.get_data() + collection = ax.plot_surface(x, y, z, + cmap=getattr(self.cm, 'viridis', self.cm.jet), + rstride=1, cstride=1, linewidth=0.1) + if isinstance(s.surface_color, (float, int, Callable)): + color_array = s.get_color_array() + color_array = color_array.reshape(color_array.size) + collection.set_array(color_array) + else: + collection.set_color(s.surface_color) + + xlims.append(s._xlim) + ylims.append(s._ylim) + zlims.append(s._zlim) + elif s.is_implicit: + points = s.get_data() + if len(points) == 2: + # interval math plotting + x, y = _matplotlib_list(points[0]) + ax.fill(x, y, facecolor=s.line_color, edgecolor='None') + else: + # use contourf or contour depending on whether it is + # an inequality or equality. + # XXX: ``contour`` plots multiple lines. Should be fixed. + ListedColormap = self.matplotlib.colors.ListedColormap + colormap = ListedColormap(["white", s.line_color]) + xarray, yarray, zarray, plot_type = points + if plot_type == 'contour': + ax.contour(xarray, yarray, zarray, cmap=colormap) + else: + ax.contourf(xarray, yarray, zarray, cmap=colormap) + elif s.is_generic: + if s.type == "markers": + # s.rendering_kw["color"] = s.line_color + ax.plot(*s.args, **s.rendering_kw) + elif s.type == "annotations": + ax.annotate(*s.args, **s.rendering_kw) + elif s.type == "fill": + # s.rendering_kw["color"] = s.line_color + ax.fill_between(*s.args, **s.rendering_kw) + elif s.type == "rectangles": + # s.rendering_kw["color"] = s.line_color + ax.add_patch( + self.matplotlib.patches.Rectangle( + *s.args, **s.rendering_kw)) + else: + raise NotImplementedError( + '{} is not supported in the SymPy plotting module ' + 'with matplotlib backend. Please report this issue.' + .format(ax)) + + Axes3D = mpl_toolkits.mplot3d.Axes3D + if not isinstance(ax, Axes3D): + ax.autoscale_view( + scalex=ax.get_autoscalex_on(), + scaley=ax.get_autoscaley_on()) + else: + # XXX Workaround for matplotlib issue + # https://github.com/matplotlib/matplotlib/issues/17130 + if xlims: + xlims = np.array(xlims) + xlim = (np.amin(xlims[:, 0]), np.amax(xlims[:, 1])) + ax.set_xlim(xlim) + else: + ax.set_xlim([0, 1]) + + if ylims: + ylims = np.array(ylims) + ylim = (np.amin(ylims[:, 0]), np.amax(ylims[:, 1])) + ax.set_ylim(ylim) + else: + ax.set_ylim([0, 1]) + + if zlims: + zlims = np.array(zlims) + zlim = (np.amin(zlims[:, 0]), np.amax(zlims[:, 1])) + ax.set_zlim(zlim) + else: + ax.set_zlim([0, 1]) + + # Set global options. + # TODO The 3D stuff + # XXX The order of those is important. + if self.xscale and not isinstance(ax, Axes3D): + ax.set_xscale(self.xscale) + if self.yscale and not isinstance(ax, Axes3D): + ax.set_yscale(self.yscale) + if not isinstance(ax, Axes3D) or self.matplotlib.__version__ >= '1.2.0': # XXX in the distant future remove this check + ax.set_autoscale_on(self.autoscale) + if self.axis_center: + val = self.axis_center + if isinstance(ax, Axes3D): + pass + elif val == 'center': + ax.spines['left'].set_position('center') + ax.spines['bottom'].set_position('center') + elif val == 'auto': + xl, xh = ax.get_xlim() + yl, yh = ax.get_ylim() + pos_left = ('data', 0) if xl*xh <= 0 else 'center' + pos_bottom = ('data', 0) if yl*yh <= 0 else 'center' + ax.spines['left'].set_position(pos_left) + ax.spines['bottom'].set_position(pos_bottom) + else: + ax.spines['left'].set_position(('data', val[0])) + ax.spines['bottom'].set_position(('data', val[1])) + if not self.axis: + ax.set_axis_off() + if self.legend: + if ax.legend(): + ax.legend_.set_visible(self.legend) + if self.margin: + ax.set_xmargin(self.margin) + ax.set_ymargin(self.margin) + if self.title: + ax.set_title(self.title) + if self.xlabel: + xlbl = _str_or_latex(self.xlabel) + ax.set_xlabel(xlbl, position=(1, 0)) + if self.ylabel: + ylbl = _str_or_latex(self.ylabel) + ax.set_ylabel(ylbl, position=(0, 1)) + if isinstance(ax, Axes3D) and self.zlabel: + zlbl = _str_or_latex(self.zlabel) + ax.set_zlabel(zlbl, position=(0, 1)) + + # xlim and ylim should always be set at last so that plot limits + # doesn't get altered during the process. + if self.xlim: + ax.set_xlim(self.xlim) + if self.ylim: + ax.set_ylim(self.ylim) + self.ax.set_aspect(self.aspect) + + + def process_series(self): + """ + Iterates over every ``Plot`` object and further calls + _process_series() + """ + self._create_figure() + self._process_series(self._series, self.ax) + + def show(self): + self.process_series() + #TODO after fixing https://github.com/ipython/ipython/issues/1255 + # you can uncomment the next line and remove the pyplot.show() call + #self.fig.show() + if base_backend._show: + self.fig.tight_layout() + self.plt.show() + else: + self.close() + + def save(self, path): + self.process_series() + self.fig.savefig(path) + + def close(self): + self.plt.close(self.fig) diff --git a/phivenv/Lib/site-packages/sympy/plotting/backends/textbackend/__init__.py b/phivenv/Lib/site-packages/sympy/plotting/backends/textbackend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4685e4b7790653a97b712c27b240ade5bb481a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/backends/textbackend/__init__.py @@ -0,0 +1,3 @@ +from sympy.plotting.backends.textbackend.text import TextBackend + +__all__ = ["TextBackend"] diff --git a/phivenv/Lib/site-packages/sympy/plotting/backends/textbackend/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/backends/textbackend/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbe3d577d8c2cd999cdce5675254343a40154389 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/backends/textbackend/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/backends/textbackend/__pycache__/text.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/backends/textbackend/__pycache__/text.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9fa8761965240f274bdbdfe6f2e73613474ff2c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/backends/textbackend/__pycache__/text.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/backends/textbackend/text.py b/phivenv/Lib/site-packages/sympy/plotting/backends/textbackend/text.py new file mode 100644 index 0000000000000000000000000000000000000000..0917ec78b3463a929c373c98fdd279d84ce4c9e5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/backends/textbackend/text.py @@ -0,0 +1,24 @@ +import sympy.plotting.backends.base_backend as base_backend +from sympy.plotting.series import LineOver1DRangeSeries +from sympy.plotting.textplot import textplot + + +class TextBackend(base_backend.Plot): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def show(self): + if not base_backend._show: + return + if len(self._series) != 1: + raise ValueError( + 'The TextBackend supports only one graph per Plot.') + elif not isinstance(self._series[0], LineOver1DRangeSeries): + raise ValueError( + 'The TextBackend supports only expressions over a 1D range') + else: + ser = self._series[0] + textplot(ser.expr, ser.start, ser.end) + + def close(self): + pass diff --git a/phivenv/Lib/site-packages/sympy/plotting/experimental_lambdify.py b/phivenv/Lib/site-packages/sympy/plotting/experimental_lambdify.py new file mode 100644 index 0000000000000000000000000000000000000000..ae17e7adf45f2933ccd71514917199c85d14549e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/experimental_lambdify.py @@ -0,0 +1,641 @@ +""" rewrite of lambdify - This stuff is not stable at all. + +It is for internal use in the new plotting module. +It may (will! see the Q'n'A in the source) be rewritten. + +It's completely self contained. Especially it does not use lambdarepr. + +It does not aim to replace the current lambdify. Most importantly it will never +ever support anything else than SymPy expressions (no Matrices, dictionaries +and so on). +""" + + +import re +from sympy.core.numbers import (I, NumberSymbol, oo, zoo) +from sympy.core.symbol import Symbol +from sympy.utilities.iterables import numbered_symbols + +# We parse the expression string into a tree that identifies functions. Then +# we translate the names of the functions and we translate also some strings +# that are not names of functions (all this according to translation +# dictionaries). +# If the translation goes to another module (like numpy) the +# module is imported and 'func' is translated to 'module.func'. +# If a function can not be translated, the inner nodes of that part of the +# tree are not translated. So if we have Integral(sqrt(x)), sqrt is not +# translated to np.sqrt and the Integral does not crash. +# A namespace for all this is generated by crawling the (func, args) tree of +# the expression. The creation of this namespace involves many ugly +# workarounds. +# The namespace consists of all the names needed for the SymPy expression and +# all the name of modules used for translation. Those modules are imported only +# as a name (import numpy as np) in order to keep the namespace small and +# manageable. + +# Please, if there is a bug, do not try to fix it here! Rewrite this by using +# the method proposed in the last Q'n'A below. That way the new function will +# work just as well, be just as simple, but it wont need any new workarounds. +# If you insist on fixing it here, look at the workarounds in the function +# sympy_expression_namespace and in lambdify. + +# Q: Why are you not using Python abstract syntax tree? +# A: Because it is more complicated and not much more powerful in this case. + +# Q: What if I have Symbol('sin') or g=Function('f')? +# A: You will break the algorithm. We should use srepr to defend against this? +# The problem with Symbol('sin') is that it will be printed as 'sin'. The +# parser will distinguish it from the function 'sin' because functions are +# detected thanks to the opening parenthesis, but the lambda expression won't +# understand the difference if we have also the sin function. +# The solution (complicated) is to use srepr and maybe ast. +# The problem with the g=Function('f') is that it will be printed as 'f' but in +# the global namespace we have only 'g'. But as the same printer is used in the +# constructor of the namespace there will be no problem. + +# Q: What if some of the printers are not printing as expected? +# A: The algorithm wont work. You must use srepr for those cases. But even +# srepr may not print well. All problems with printers should be considered +# bugs. + +# Q: What about _imp_ functions? +# A: Those are taken care for by evalf. A special case treatment will work +# faster but it's not worth the code complexity. + +# Q: Will ast fix all possible problems? +# A: No. You will always have to use some printer. Even srepr may not work in +# some cases. But if the printer does not work, that should be considered a +# bug. + +# Q: Is there same way to fix all possible problems? +# A: Probably by constructing our strings ourself by traversing the (func, +# args) tree and creating the namespace at the same time. That actually sounds +# good. + +from sympy.external import import_module +import warnings + +#TODO debugging output + + +class vectorized_lambdify: + """ Return a sufficiently smart, vectorized and lambdified function. + + Returns only reals. + + Explanation + =========== + + This function uses experimental_lambdify to created a lambdified + expression ready to be used with numpy. Many of the functions in SymPy + are not implemented in numpy so in some cases we resort to Python cmath or + even to evalf. + + The following translations are tried: + only numpy complex + - on errors raised by SymPy trying to work with ndarray: + only Python cmath and then vectorize complex128 + + When using Python cmath there is no need for evalf or float/complex + because Python cmath calls those. + + This function never tries to mix numpy directly with evalf because numpy + does not understand SymPy Float. If this is needed one can use the + float_wrap_evalf/complex_wrap_evalf options of experimental_lambdify or + better one can be explicit about the dtypes that numpy works with. + Check numpy bug http://projects.scipy.org/numpy/ticket/1013 to know what + types of errors to expect. + """ + def __init__(self, args, expr): + self.args = args + self.expr = expr + self.np = import_module('numpy') + + self.lambda_func_1 = experimental_lambdify( + args, expr, use_np=True) + self.vector_func_1 = self.lambda_func_1 + + self.lambda_func_2 = experimental_lambdify( + args, expr, use_python_cmath=True) + self.vector_func_2 = self.np.vectorize( + self.lambda_func_2, otypes=[complex]) + + self.vector_func = self.vector_func_1 + self.failure = False + + def __call__(self, *args): + np = self.np + + try: + temp_args = (np.array(a, dtype=complex) for a in args) + results = self.vector_func(*temp_args) + results = np.ma.masked_where( + np.abs(results.imag) > 1e-7 * np.abs(results), + results.real, copy=False) + return results + except ValueError: + if self.failure: + raise + + self.failure = True + self.vector_func = self.vector_func_2 + warnings.warn( + 'The evaluation of the expression is problematic. ' + 'We are trying a failback method that may still work. ' + 'Please report this as a bug.') + return self.__call__(*args) + + +class lambdify: + """Returns the lambdified function. + + Explanation + =========== + + This function uses experimental_lambdify to create a lambdified + expression. It uses cmath to lambdify the expression. If the function + is not implemented in Python cmath, Python cmath calls evalf on those + functions. + """ + + def __init__(self, args, expr): + self.args = args + self.expr = expr + self.lambda_func_1 = experimental_lambdify( + args, expr, use_python_cmath=True, use_evalf=True) + self.lambda_func_2 = experimental_lambdify( + args, expr, use_python_math=True, use_evalf=True) + self.lambda_func_3 = experimental_lambdify( + args, expr, use_evalf=True, complex_wrap_evalf=True) + self.lambda_func = self.lambda_func_1 + self.failure = False + + def __call__(self, args): + try: + #The result can be sympy.Float. Hence wrap it with complex type. + result = complex(self.lambda_func(args)) + if abs(result.imag) > 1e-7 * abs(result): + return None + return result.real + except (ZeroDivisionError, OverflowError): + return None + except TypeError as e: + if self.failure: + raise e + + if self.lambda_func == self.lambda_func_1: + self.lambda_func = self.lambda_func_2 + return self.__call__(args) + + self.failure = True + self.lambda_func = self.lambda_func_3 + warnings.warn( + 'The evaluation of the expression is problematic. ' + 'We are trying a failback method that may still work. ' + 'Please report this as a bug.', stacklevel=2) + return self.__call__(args) + + +def experimental_lambdify(*args, **kwargs): + l = Lambdifier(*args, **kwargs) + return l + + +class Lambdifier: + def __init__(self, args, expr, print_lambda=False, use_evalf=False, + float_wrap_evalf=False, complex_wrap_evalf=False, + use_np=False, use_python_math=False, use_python_cmath=False, + use_interval=False): + + self.print_lambda = print_lambda + self.use_evalf = use_evalf + self.float_wrap_evalf = float_wrap_evalf + self.complex_wrap_evalf = complex_wrap_evalf + self.use_np = use_np + self.use_python_math = use_python_math + self.use_python_cmath = use_python_cmath + self.use_interval = use_interval + + # Constructing the argument string + # - check + if not all(isinstance(a, Symbol) for a in args): + raise ValueError('The arguments must be Symbols.') + # - use numbered symbols + syms = numbered_symbols(exclude=expr.free_symbols) + newargs = [next(syms) for _ in args] + expr = expr.xreplace(dict(zip(args, newargs))) + argstr = ', '.join([str(a) for a in newargs]) + del syms, newargs, args + + # Constructing the translation dictionaries and making the translation + self.dict_str = self.get_dict_str() + self.dict_fun = self.get_dict_fun() + exprstr = str(expr) + newexpr = self.tree2str_translate(self.str2tree(exprstr)) + + # Constructing the namespaces + namespace = {} + namespace.update(self.sympy_atoms_namespace(expr)) + namespace.update(self.sympy_expression_namespace(expr)) + # XXX Workaround + # Ugly workaround because Pow(a,Half) prints as sqrt(a) + # and sympy_expression_namespace can not catch it. + from sympy.functions.elementary.miscellaneous import sqrt + namespace.update({'sqrt': sqrt}) + namespace.update({'Eq': lambda x, y: x == y}) + namespace.update({'Ne': lambda x, y: x != y}) + # End workaround. + if use_python_math: + namespace.update({'math': __import__('math')}) + if use_python_cmath: + namespace.update({'cmath': __import__('cmath')}) + if use_np: + try: + namespace.update({'np': __import__('numpy')}) + except ImportError: + raise ImportError( + 'experimental_lambdify failed to import numpy.') + if use_interval: + namespace.update({'imath': __import__( + 'sympy.plotting.intervalmath', fromlist=['intervalmath'])}) + namespace.update({'math': __import__('math')}) + + # Construct the lambda + if self.print_lambda: + print(newexpr) + eval_str = 'lambda %s : ( %s )' % (argstr, newexpr) + self.eval_str = eval_str + exec("MYNEWLAMBDA = %s" % eval_str, namespace) + self.lambda_func = namespace['MYNEWLAMBDA'] + + def __call__(self, *args, **kwargs): + return self.lambda_func(*args, **kwargs) + + + ############################################################################## + # Dicts for translating from SymPy to other modules + ############################################################################## + ### + # builtins + ### + # Functions with different names in builtins + builtin_functions_different = { + 'Min': 'min', + 'Max': 'max', + 'Abs': 'abs', + } + + # Strings that should be translated + builtin_not_functions = { + 'I': '1j', +# 'oo': '1e400', + } + + ### + # numpy + ### + + # Functions that are the same in numpy + numpy_functions_same = [ + 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'exp', 'log', + 'sqrt', 'floor', 'conjugate', 'sign', + ] + + # Functions with different names in numpy + numpy_functions_different = { + "acos": "arccos", + "acosh": "arccosh", + "arg": "angle", + "asin": "arcsin", + "asinh": "arcsinh", + "atan": "arctan", + "atan2": "arctan2", + "atanh": "arctanh", + "ceiling": "ceil", + "im": "imag", + "ln": "log", + "Max": "amax", + "Min": "amin", + "re": "real", + "Abs": "abs", + } + + # Strings that should be translated + numpy_not_functions = { + 'pi': 'np.pi', + 'oo': 'np.inf', + 'E': 'np.e', + } + + ### + # Python math + ### + + # Functions that are the same in math + math_functions_same = [ + 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'atan2', + 'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh', + 'exp', 'log', 'erf', 'sqrt', 'floor', 'factorial', 'gamma', + ] + + # Functions with different names in math + math_functions_different = { + 'ceiling': 'ceil', + 'ln': 'log', + 'loggamma': 'lgamma' + } + + # Strings that should be translated + math_not_functions = { + 'pi': 'math.pi', + 'E': 'math.e', + } + + ### + # Python cmath + ### + + # Functions that are the same in cmath + cmath_functions_same = [ + 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', + 'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh', + 'exp', 'log', 'sqrt', + ] + + # Functions with different names in cmath + cmath_functions_different = { + 'ln': 'log', + 'arg': 'phase', + } + + # Strings that should be translated + cmath_not_functions = { + 'pi': 'cmath.pi', + 'E': 'cmath.e', + } + + ### + # intervalmath + ### + + interval_not_functions = { + 'pi': 'math.pi', + 'E': 'math.e' + } + + interval_functions_same = [ + 'sin', 'cos', 'exp', 'tan', 'atan', 'log', + 'sqrt', 'cosh', 'sinh', 'tanh', 'floor', + 'acos', 'asin', 'acosh', 'asinh', 'atanh', + 'Abs', 'And', 'Or' + ] + + interval_functions_different = { + 'Min': 'imin', + 'Max': 'imax', + 'ceiling': 'ceil', + + } + + ### + # mpmath, etc + ### + #TODO + + ### + # Create the final ordered tuples of dictionaries + ### + + # For strings + def get_dict_str(self): + dict_str = dict(self.builtin_not_functions) + if self.use_np: + dict_str.update(self.numpy_not_functions) + if self.use_python_math: + dict_str.update(self.math_not_functions) + if self.use_python_cmath: + dict_str.update(self.cmath_not_functions) + if self.use_interval: + dict_str.update(self.interval_not_functions) + return dict_str + + # For functions + def get_dict_fun(self): + dict_fun = dict(self.builtin_functions_different) + if self.use_np: + for s in self.numpy_functions_same: + dict_fun[s] = 'np.' + s + for k, v in self.numpy_functions_different.items(): + dict_fun[k] = 'np.' + v + if self.use_python_math: + for s in self.math_functions_same: + dict_fun[s] = 'math.' + s + for k, v in self.math_functions_different.items(): + dict_fun[k] = 'math.' + v + if self.use_python_cmath: + for s in self.cmath_functions_same: + dict_fun[s] = 'cmath.' + s + for k, v in self.cmath_functions_different.items(): + dict_fun[k] = 'cmath.' + v + if self.use_interval: + for s in self.interval_functions_same: + dict_fun[s] = 'imath.' + s + for k, v in self.interval_functions_different.items(): + dict_fun[k] = 'imath.' + v + return dict_fun + + ############################################################################## + # The translator functions, tree parsers, etc. + ############################################################################## + + def str2tree(self, exprstr): + """Converts an expression string to a tree. + + Explanation + =========== + + Functions are represented by ('func_name(', tree_of_arguments). + Other expressions are (head_string, mid_tree, tail_str). + Expressions that do not contain functions are directly returned. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy import Integral, sin + >>> from sympy.plotting.experimental_lambdify import Lambdifier + >>> str2tree = Lambdifier([x], x).str2tree + + >>> str2tree(str(Integral(x, (x, 1, y)))) + ('', ('Integral(', 'x, (x, 1, y)'), ')') + >>> str2tree(str(x+y)) + 'x + y' + >>> str2tree(str(x+y*sin(z)+1)) + ('x + y*', ('sin(', 'z'), ') + 1') + >>> str2tree('sin(y*(y + 1.1) + (sin(y)))') + ('', ('sin(', ('y*(y + 1.1) + (', ('sin(', 'y'), '))')), ')') + """ + #matches the first 'function_name(' + first_par = re.search(r'(\w+\()', exprstr) + if first_par is None: + return exprstr + else: + start = first_par.start() + end = first_par.end() + head = exprstr[:start] + func = exprstr[start:end] + tail = exprstr[end:] + count = 0 + for i, c in enumerate(tail): + if c == '(': + count += 1 + elif c == ')': + count -= 1 + if count == -1: + break + func_tail = self.str2tree(tail[:i]) + tail = self.str2tree(tail[i:]) + return (head, (func, func_tail), tail) + + @classmethod + def tree2str(cls, tree): + """Converts a tree to string without translations. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy import sin + >>> from sympy.plotting.experimental_lambdify import Lambdifier + >>> str2tree = Lambdifier([x], x).str2tree + >>> tree2str = Lambdifier([x], x).tree2str + + >>> tree2str(str2tree(str(x+y*sin(z)+1))) + 'x + y*sin(z) + 1' + """ + if isinstance(tree, str): + return tree + else: + return ''.join(map(cls.tree2str, tree)) + + def tree2str_translate(self, tree): + """Converts a tree to string with translations. + + Explanation + =========== + + Function names are translated by translate_func. + Other strings are translated by translate_str. + """ + if isinstance(tree, str): + return self.translate_str(tree) + elif isinstance(tree, tuple) and len(tree) == 2: + return self.translate_func(tree[0][:-1], tree[1]) + else: + return ''.join([self.tree2str_translate(t) for t in tree]) + + def translate_str(self, estr): + """Translate substrings of estr using in order the dictionaries in + dict_tuple_str.""" + for pattern, repl in self.dict_str.items(): + estr = re.sub(pattern, repl, estr) + return estr + + def translate_func(self, func_name, argtree): + """Translate function names and the tree of arguments. + + Explanation + =========== + + If the function name is not in the dictionaries of dict_tuple_fun then the + function is surrounded by a float((...).evalf()). + + The use of float is necessary as np.(sympy.Float(..)) raises an + error.""" + if func_name in self.dict_fun: + new_name = self.dict_fun[func_name] + argstr = self.tree2str_translate(argtree) + return new_name + '(' + argstr + elif func_name in ['Eq', 'Ne']: + op = {'Eq': '==', 'Ne': '!='} + return "(lambda x, y: x {} y)({}".format(op[func_name], self.tree2str_translate(argtree)) + else: + template = '(%s(%s)).evalf(' if self.use_evalf else '%s(%s' + if self.float_wrap_evalf: + template = 'float(%s)' % template + elif self.complex_wrap_evalf: + template = 'complex(%s)' % template + + # Wrapping should only happen on the outermost expression, which + # is the only thing we know will be a number. + float_wrap_evalf = self.float_wrap_evalf + complex_wrap_evalf = self.complex_wrap_evalf + self.float_wrap_evalf = False + self.complex_wrap_evalf = False + ret = template % (func_name, self.tree2str_translate(argtree)) + self.float_wrap_evalf = float_wrap_evalf + self.complex_wrap_evalf = complex_wrap_evalf + return ret + + ############################################################################## + # The namespace constructors + ############################################################################## + + @classmethod + def sympy_expression_namespace(cls, expr): + """Traverses the (func, args) tree of an expression and creates a SymPy + namespace. All other modules are imported only as a module name. That way + the namespace is not polluted and rests quite small. It probably causes much + more variable lookups and so it takes more time, but there are no tests on + that for the moment.""" + if expr is None: + return {} + else: + funcname = str(expr.func) + # XXX Workaround + # Here we add an ugly workaround because str(func(x)) + # is not always the same as str(func). Eg + # >>> str(Integral(x)) + # "Integral(x)" + # >>> str(Integral) + # "" + # >>> str(sqrt(x)) + # "sqrt(x)" + # >>> str(sqrt) + # "" + # >>> str(sin(x)) + # "sin(x)" + # >>> str(sin) + # "sin" + # Either one of those can be used but not all at the same time. + # The code considers the sin example as the right one. + regexlist = [ + r'$', + # the example Integral + r'$', # the example sqrt + ] + for r in regexlist: + m = re.match(r, funcname) + if m is not None: + funcname = m.groups()[0] + # End of the workaround + # XXX debug: print funcname + args_dict = {} + for a in expr.args: + if (isinstance(a, (Symbol, NumberSymbol)) or a in [I, zoo, oo]): + continue + else: + args_dict.update(cls.sympy_expression_namespace(a)) + args_dict.update({funcname: expr.func}) + return args_dict + + @staticmethod + def sympy_atoms_namespace(expr): + """For no real reason this function is separated from + sympy_expression_namespace. It can be moved to it.""" + atoms = expr.atoms(Symbol, NumberSymbol, I, zoo, oo) + d = {} + for a in atoms: + # XXX debug: print 'atom:' + str(a) + d[str(a)] = a + return d diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__init__.py b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb9a6a57f94e931f0c5f5b3dda7b0b6fd31841f4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__init__.py @@ -0,0 +1,12 @@ +from .interval_arithmetic import interval +from .lib_interval import (Abs, exp, log, log10, sin, cos, tan, sqrt, + imin, imax, sinh, cosh, tanh, acosh, asinh, atanh, + asin, acos, atan, ceil, floor, And, Or) + +__all__ = [ + 'interval', + + 'Abs', 'exp', 'log', 'log10', 'sin', 'cos', 'tan', 'sqrt', 'imin', 'imax', + 'sinh', 'cosh', 'tanh', 'acosh', 'asinh', 'atanh', 'asin', 'acos', 'atan', + 'ceil', 'floor', 'And', 'Or', +] diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdcfcbafeb630affe50dc87187d63064253f66a8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__pycache__/interval_arithmetic.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__pycache__/interval_arithmetic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef0888e52c86a7059115f0957a211e55e8a4d85d Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__pycache__/interval_arithmetic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__pycache__/interval_membership.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__pycache__/interval_membership.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b345b5b7b59444b47fd48333185781dc0ee004a2 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__pycache__/interval_membership.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__pycache__/lib_interval.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__pycache__/lib_interval.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5245bdde48351983551094944680db32043f4d21 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/__pycache__/lib_interval.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/interval_arithmetic.py b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/interval_arithmetic.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5c0e2ef118c7cf4f80de53a3590de11130410e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/interval_arithmetic.py @@ -0,0 +1,413 @@ +""" +Interval Arithmetic for plotting. +This module does not implement interval arithmetic accurately and +hence cannot be used for purposes other than plotting. If you want +to use interval arithmetic, use mpmath's interval arithmetic. + +The module implements interval arithmetic using numpy and +python floating points. The rounding up and down is not handled +and hence this is not an accurate implementation of interval +arithmetic. + +The module uses numpy for speed which cannot be achieved with mpmath. +""" + +# Q: Why use numpy? Why not simply use mpmath's interval arithmetic? +# A: mpmath's interval arithmetic simulates a floating point unit +# and hence is slow, while numpy evaluations are orders of magnitude +# faster. + +# Q: Why create a separate class for intervals? Why not use SymPy's +# Interval Sets? +# A: The functionalities that will be required for plotting is quite +# different from what Interval Sets implement. + +# Q: Why is rounding up and down according to IEEE754 not handled? +# A: It is not possible to do it in both numpy and python. An external +# library has to used, which defeats the whole purpose i.e., speed. Also +# rounding is handled for very few functions in those libraries. + +# Q Will my plots be affected? +# A It will not affect most of the plots. The interval arithmetic +# module based suffers the same problems as that of floating point +# arithmetic. + +from sympy.core.numbers import int_valued +from sympy.core.logic import fuzzy_and +from sympy.simplify.simplify import nsimplify + +from .interval_membership import intervalMembership + + +class interval: + """ Represents an interval containing floating points as start and + end of the interval + The is_valid variable tracks whether the interval obtained as the + result of the function is in the domain and is continuous. + - True: Represents the interval result of a function is continuous and + in the domain of the function. + - False: The interval argument of the function was not in the domain of + the function, hence the is_valid of the result interval is False + - None: The function was not continuous over the interval or + the function's argument interval is partly in the domain of the + function + + A comparison between an interval and a real number, or a + comparison between two intervals may return ``intervalMembership`` + of two 3-valued logic values. + """ + + def __init__(self, *args, is_valid=True, **kwargs): + self.is_valid = is_valid + if len(args) == 1: + if isinstance(args[0], interval): + self.start, self.end = args[0].start, args[0].end + else: + self.start = float(args[0]) + self.end = float(args[0]) + elif len(args) == 2: + if args[0] < args[1]: + self.start = float(args[0]) + self.end = float(args[1]) + else: + self.start = float(args[1]) + self.end = float(args[0]) + + else: + raise ValueError("interval takes a maximum of two float values " + "as arguments") + + @property + def mid(self): + return (self.start + self.end) / 2.0 + + @property + def width(self): + return self.end - self.start + + def __repr__(self): + return "interval(%f, %f)" % (self.start, self.end) + + def __str__(self): + return "[%f, %f]" % (self.start, self.end) + + def __lt__(self, other): + if isinstance(other, (int, float)): + if self.end < other: + return intervalMembership(True, self.is_valid) + elif self.start > other: + return intervalMembership(False, self.is_valid) + else: + return intervalMembership(None, self.is_valid) + + elif isinstance(other, interval): + valid = fuzzy_and([self.is_valid, other.is_valid]) + if self.end < other. start: + return intervalMembership(True, valid) + if self.start > other.end: + return intervalMembership(False, valid) + return intervalMembership(None, valid) + else: + return NotImplemented + + def __gt__(self, other): + if isinstance(other, (int, float)): + if self.start > other: + return intervalMembership(True, self.is_valid) + elif self.end < other: + return intervalMembership(False, self.is_valid) + else: + return intervalMembership(None, self.is_valid) + elif isinstance(other, interval): + return other.__lt__(self) + else: + return NotImplemented + + def __eq__(self, other): + if isinstance(other, (int, float)): + if self.start == other and self.end == other: + return intervalMembership(True, self.is_valid) + if other in self: + return intervalMembership(None, self.is_valid) + else: + return intervalMembership(False, self.is_valid) + + if isinstance(other, interval): + valid = fuzzy_and([self.is_valid, other.is_valid]) + if self.start == other.start and self.end == other.end: + return intervalMembership(True, valid) + elif self.__lt__(other)[0] is not None: + return intervalMembership(False, valid) + else: + return intervalMembership(None, valid) + else: + return NotImplemented + + def __ne__(self, other): + if isinstance(other, (int, float)): + if self.start == other and self.end == other: + return intervalMembership(False, self.is_valid) + if other in self: + return intervalMembership(None, self.is_valid) + else: + return intervalMembership(True, self.is_valid) + + if isinstance(other, interval): + valid = fuzzy_and([self.is_valid, other.is_valid]) + if self.start == other.start and self.end == other.end: + return intervalMembership(False, valid) + if not self.__lt__(other)[0] is None: + return intervalMembership(True, valid) + return intervalMembership(None, valid) + else: + return NotImplemented + + def __le__(self, other): + if isinstance(other, (int, float)): + if self.end <= other: + return intervalMembership(True, self.is_valid) + if self.start > other: + return intervalMembership(False, self.is_valid) + else: + return intervalMembership(None, self.is_valid) + + if isinstance(other, interval): + valid = fuzzy_and([self.is_valid, other.is_valid]) + if self.end <= other.start: + return intervalMembership(True, valid) + if self.start > other.end: + return intervalMembership(False, valid) + return intervalMembership(None, valid) + else: + return NotImplemented + + def __ge__(self, other): + if isinstance(other, (int, float)): + if self.start >= other: + return intervalMembership(True, self.is_valid) + elif self.end < other: + return intervalMembership(False, self.is_valid) + else: + return intervalMembership(None, self.is_valid) + elif isinstance(other, interval): + return other.__le__(self) + + def __add__(self, other): + if isinstance(other, (int, float)): + if self.is_valid: + return interval(self.start + other, self.end + other) + else: + start = self.start + other + end = self.end + other + return interval(start, end, is_valid=self.is_valid) + + elif isinstance(other, interval): + start = self.start + other.start + end = self.end + other.end + valid = fuzzy_and([self.is_valid, other.is_valid]) + return interval(start, end, is_valid=valid) + else: + return NotImplemented + + __radd__ = __add__ + + def __sub__(self, other): + if isinstance(other, (int, float)): + start = self.start - other + end = self.end - other + return interval(start, end, is_valid=self.is_valid) + + elif isinstance(other, interval): + start = self.start - other.end + end = self.end - other.start + valid = fuzzy_and([self.is_valid, other.is_valid]) + return interval(start, end, is_valid=valid) + else: + return NotImplemented + + def __rsub__(self, other): + if isinstance(other, (int, float)): + start = other - self.end + end = other - self.start + return interval(start, end, is_valid=self.is_valid) + elif isinstance(other, interval): + return other.__sub__(self) + else: + return NotImplemented + + def __neg__(self): + if self.is_valid: + return interval(-self.end, -self.start) + else: + return interval(-self.end, -self.start, is_valid=self.is_valid) + + def __mul__(self, other): + if isinstance(other, interval): + if self.is_valid is False or other.is_valid is False: + return interval(-float('inf'), float('inf'), is_valid=False) + elif self.is_valid is None or other.is_valid is None: + return interval(-float('inf'), float('inf'), is_valid=None) + else: + inters = [] + inters.append(self.start * other.start) + inters.append(self.end * other.start) + inters.append(self.start * other.end) + inters.append(self.end * other.end) + start = min(inters) + end = max(inters) + return interval(start, end) + elif isinstance(other, (int, float)): + return interval(self.start*other, self.end*other, is_valid=self.is_valid) + else: + return NotImplemented + + __rmul__ = __mul__ + + def __contains__(self, other): + if isinstance(other, (int, float)): + return self.start <= other and self.end >= other + else: + return self.start <= other.start and other.end <= self.end + + def __rtruediv__(self, other): + if isinstance(other, (int, float)): + other = interval(other) + return other.__truediv__(self) + elif isinstance(other, interval): + return other.__truediv__(self) + else: + return NotImplemented + + def __truediv__(self, other): + # Both None and False are handled + if not self.is_valid: + # Don't divide as the value is not valid + return interval(-float('inf'), float('inf'), is_valid=self.is_valid) + if isinstance(other, (int, float)): + if other == 0: + # Divide by zero encountered. valid nowhere + return interval(-float('inf'), float('inf'), is_valid=False) + else: + return interval(self.start / other, self.end / other) + + elif isinstance(other, interval): + if other.is_valid is False or self.is_valid is False: + return interval(-float('inf'), float('inf'), is_valid=False) + elif other.is_valid is None or self.is_valid is None: + return interval(-float('inf'), float('inf'), is_valid=None) + else: + # denominator contains both signs, i.e. being divided by zero + # return the whole real line with is_valid = None + if 0 in other: + return interval(-float('inf'), float('inf'), is_valid=None) + + # denominator negative + this = self + if other.end < 0: + this = -this + other = -other + + # denominator positive + inters = [] + inters.append(this.start / other.start) + inters.append(this.end / other.start) + inters.append(this.start / other.end) + inters.append(this.end / other.end) + start = max(inters) + end = min(inters) + return interval(start, end) + else: + return NotImplemented + + def __pow__(self, other): + # Implements only power to an integer. + from .lib_interval import exp, log + if not self.is_valid: + return self + if isinstance(other, interval): + return exp(other * log(self)) + elif isinstance(other, (float, int)): + if other < 0: + return 1 / self.__pow__(abs(other)) + else: + if int_valued(other): + return _pow_int(self, other) + else: + return _pow_float(self, other) + else: + return NotImplemented + + def __rpow__(self, other): + if isinstance(other, (float, int)): + if not self.is_valid: + #Don't do anything + return self + elif other < 0: + if self.width > 0: + return interval(-float('inf'), float('inf'), is_valid=False) + else: + power_rational = nsimplify(self.start) + num, denom = power_rational.as_numer_denom() + if denom % 2 == 0: + return interval(-float('inf'), float('inf'), + is_valid=False) + else: + start = -abs(other)**self.start + end = start + return interval(start, end) + else: + return interval(other**self.start, other**self.end) + elif isinstance(other, interval): + return other.__pow__(self) + else: + return NotImplemented + + def __hash__(self): + return hash((self.is_valid, self.start, self.end)) + + +def _pow_float(inter, power): + """Evaluates an interval raised to a floating point.""" + power_rational = nsimplify(power) + num, denom = power_rational.as_numer_denom() + if num % 2 == 0: + start = abs(inter.start)**power + end = abs(inter.end)**power + if start < 0: + ret = interval(0, max(start, end)) + else: + ret = interval(start, end) + return ret + elif denom % 2 == 0: + if inter.end < 0: + return interval(-float('inf'), float('inf'), is_valid=False) + elif inter.start < 0: + return interval(0, inter.end**power, is_valid=None) + else: + return interval(inter.start**power, inter.end**power) + else: + if inter.start < 0: + start = -abs(inter.start)**power + else: + start = inter.start**power + + if inter.end < 0: + end = -abs(inter.end)**power + else: + end = inter.end**power + + return interval(start, end, is_valid=inter.is_valid) + + +def _pow_int(inter, power): + """Evaluates an interval raised to an integer power""" + power = int(power) + if power & 1: + return interval(inter.start**power, inter.end**power) + else: + if inter.start < 0 and inter.end > 0: + start = 0 + end = max(inter.start**power, inter.end**power) + return interval(start, end) + else: + return interval(inter.start**power, inter.end**power) diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/interval_membership.py b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/interval_membership.py new file mode 100644 index 0000000000000000000000000000000000000000..c4887c2d96f0d006b95a8e207a4f4a75940aec23 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/interval_membership.py @@ -0,0 +1,78 @@ +from sympy.core.logic import fuzzy_and, fuzzy_or, fuzzy_not, fuzzy_xor + + +class intervalMembership: + """Represents a boolean expression returned by the comparison of + the interval object. + + Parameters + ========== + + (a, b) : (bool, bool) + The first value determines the comparison as follows: + - True: If the comparison is True throughout the intervals. + - False: If the comparison is False throughout the intervals. + - None: If the comparison is True for some part of the intervals. + + The second value is determined as follows: + - True: If both the intervals in comparison are valid. + - False: If at least one of the intervals is False, else + - None + """ + def __init__(self, a, b): + self._wrapped = (a, b) + + def __getitem__(self, i): + try: + return self._wrapped[i] + except IndexError: + raise IndexError( + "{} must be a valid indexing for the 2-tuple." + .format(i)) + + def __len__(self): + return 2 + + def __iter__(self): + return iter(self._wrapped) + + def __str__(self): + return "intervalMembership({}, {})".format(*self) + __repr__ = __str__ + + def __and__(self, other): + if not isinstance(other, intervalMembership): + raise ValueError( + "The comparison is not supported for {}.".format(other)) + + a1, b1 = self + a2, b2 = other + return intervalMembership(fuzzy_and([a1, a2]), fuzzy_and([b1, b2])) + + def __or__(self, other): + if not isinstance(other, intervalMembership): + raise ValueError( + "The comparison is not supported for {}.".format(other)) + + a1, b1 = self + a2, b2 = other + return intervalMembership(fuzzy_or([a1, a2]), fuzzy_and([b1, b2])) + + def __invert__(self): + a, b = self + return intervalMembership(fuzzy_not(a), b) + + def __xor__(self, other): + if not isinstance(other, intervalMembership): + raise ValueError( + "The comparison is not supported for {}.".format(other)) + + a1, b1 = self + a2, b2 = other + return intervalMembership(fuzzy_xor([a1, a2]), fuzzy_and([b1, b2])) + + def __eq__(self, other): + return self._wrapped == other + + def __ne__(self, other): + return self._wrapped != other diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/lib_interval.py b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/lib_interval.py new file mode 100644 index 0000000000000000000000000000000000000000..7549a05820d747ce057892f8df1fbcbc61cc3f43 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/lib_interval.py @@ -0,0 +1,452 @@ +""" The module contains implemented functions for interval arithmetic.""" +from functools import reduce + +from sympy.plotting.intervalmath import interval +from sympy.external import import_module + + +def Abs(x): + if isinstance(x, (int, float)): + return interval(abs(x)) + elif isinstance(x, interval): + if x.start < 0 and x.end > 0: + return interval(0, max(abs(x.start), abs(x.end)), is_valid=x.is_valid) + else: + return interval(abs(x.start), abs(x.end)) + else: + raise NotImplementedError + +#Monotonic + + +def exp(x): + """evaluates the exponential of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.exp(x), np.exp(x)) + elif isinstance(x, interval): + return interval(np.exp(x.start), np.exp(x.end), is_valid=x.is_valid) + else: + raise NotImplementedError + + +#Monotonic +def log(x): + """evaluates the natural logarithm of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + if x <= 0: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.log(x)) + elif isinstance(x, interval): + if not x.is_valid: + return interval(-np.inf, np.inf, is_valid=x.is_valid) + elif x.end <= 0: + return interval(-np.inf, np.inf, is_valid=False) + elif x.start <= 0: + return interval(-np.inf, np.inf, is_valid=None) + + return interval(np.log(x.start), np.log(x.end)) + else: + raise NotImplementedError + + +#Monotonic +def log10(x): + """evaluates the logarithm to the base 10 of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + if x <= 0: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.log10(x)) + elif isinstance(x, interval): + if not x.is_valid: + return interval(-np.inf, np.inf, is_valid=x.is_valid) + elif x.end <= 0: + return interval(-np.inf, np.inf, is_valid=False) + elif x.start <= 0: + return interval(-np.inf, np.inf, is_valid=None) + return interval(np.log10(x.start), np.log10(x.end)) + else: + raise NotImplementedError + + +#Monotonic +def atan(x): + """evaluates the tan inverse of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.arctan(x)) + elif isinstance(x, interval): + start = np.arctan(x.start) + end = np.arctan(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + raise NotImplementedError + + +#periodic +def sin(x): + """evaluates the sine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.sin(x)) + elif isinstance(x, interval): + if not x.is_valid: + return interval(-1, 1, is_valid=x.is_valid) + na, __ = divmod(x.start, np.pi / 2.0) + nb, __ = divmod(x.end, np.pi / 2.0) + start = min(np.sin(x.start), np.sin(x.end)) + end = max(np.sin(x.start), np.sin(x.end)) + if nb - na > 4: + return interval(-1, 1, is_valid=x.is_valid) + elif na == nb: + return interval(start, end, is_valid=x.is_valid) + else: + if (na - 1) // 4 != (nb - 1) // 4: + #sin has max + end = 1 + if (na - 3) // 4 != (nb - 3) // 4: + #sin has min + start = -1 + return interval(start, end) + else: + raise NotImplementedError + + +#periodic +def cos(x): + """Evaluates the cos of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.sin(x)) + elif isinstance(x, interval): + if not (np.isfinite(x.start) and np.isfinite(x.end)): + return interval(-1, 1, is_valid=x.is_valid) + na, __ = divmod(x.start, np.pi / 2.0) + nb, __ = divmod(x.end, np.pi / 2.0) + start = min(np.cos(x.start), np.cos(x.end)) + end = max(np.cos(x.start), np.cos(x.end)) + if nb - na > 4: + #differ more than 2*pi + return interval(-1, 1, is_valid=x.is_valid) + elif na == nb: + #in the same quadarant + return interval(start, end, is_valid=x.is_valid) + else: + if (na) // 4 != (nb) // 4: + #cos has max + end = 1 + if (na - 2) // 4 != (nb - 2) // 4: + #cos has min + start = -1 + return interval(start, end, is_valid=x.is_valid) + else: + raise NotImplementedError + + +def tan(x): + """Evaluates the tan of an interval""" + return sin(x) / cos(x) + + +#Monotonic +def sqrt(x): + """Evaluates the square root of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + if x > 0: + return interval(np.sqrt(x)) + else: + return interval(-np.inf, np.inf, is_valid=False) + elif isinstance(x, interval): + #Outside the domain + if x.end < 0: + return interval(-np.inf, np.inf, is_valid=False) + #Partially outside the domain + elif x.start < 0: + return interval(-np.inf, np.inf, is_valid=None) + else: + return interval(np.sqrt(x.start), np.sqrt(x.end), + is_valid=x.is_valid) + else: + raise NotImplementedError + + +def imin(*args): + """Evaluates the minimum of a list of intervals""" + np = import_module('numpy') + if not all(isinstance(arg, (int, float, interval)) for arg in args): + return NotImplementedError + else: + new_args = [a for a in args if isinstance(a, (int, float)) + or a.is_valid] + if len(new_args) == 0: + if all(a.is_valid is False for a in args): + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(-np.inf, np.inf, is_valid=None) + start_array = [a if isinstance(a, (int, float)) else a.start + for a in new_args] + + end_array = [a if isinstance(a, (int, float)) else a.end + for a in new_args] + return interval(min(start_array), min(end_array)) + + +def imax(*args): + """Evaluates the maximum of a list of intervals""" + np = import_module('numpy') + if not all(isinstance(arg, (int, float, interval)) for arg in args): + return NotImplementedError + else: + new_args = [a for a in args if isinstance(a, (int, float)) + or a.is_valid] + if len(new_args) == 0: + if all(a.is_valid is False for a in args): + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(-np.inf, np.inf, is_valid=None) + start_array = [a if isinstance(a, (int, float)) else a.start + for a in new_args] + + end_array = [a if isinstance(a, (int, float)) else a.end + for a in new_args] + + return interval(max(start_array), max(end_array)) + + +#Monotonic +def sinh(x): + """Evaluates the hyperbolic sine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.sinh(x), np.sinh(x)) + elif isinstance(x, interval): + return interval(np.sinh(x.start), np.sinh(x.end), is_valid=x.is_valid) + else: + raise NotImplementedError + + +def cosh(x): + """Evaluates the hyperbolic cos of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.cosh(x), np.cosh(x)) + elif isinstance(x, interval): + #both signs + if x.start < 0 and x.end > 0: + end = max(np.cosh(x.start), np.cosh(x.end)) + return interval(1, end, is_valid=x.is_valid) + else: + #Monotonic + start = np.cosh(x.start) + end = np.cosh(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + raise NotImplementedError + + +#Monotonic +def tanh(x): + """Evaluates the hyperbolic tan of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.tanh(x), np.tanh(x)) + elif isinstance(x, interval): + return interval(np.tanh(x.start), np.tanh(x.end), is_valid=x.is_valid) + else: + raise NotImplementedError + + +def asin(x): + """Evaluates the inverse sine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + #Outside the domain + if abs(x) > 1: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.arcsin(x), np.arcsin(x)) + elif isinstance(x, interval): + #Outside the domain + if x.is_valid is False or x.start > 1 or x.end < -1: + return interval(-np.inf, np.inf, is_valid=False) + #Partially outside the domain + elif x.start < -1 or x.end > 1: + return interval(-np.inf, np.inf, is_valid=None) + else: + start = np.arcsin(x.start) + end = np.arcsin(x.end) + return interval(start, end, is_valid=x.is_valid) + + +def acos(x): + """Evaluates the inverse cos of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + if abs(x) > 1: + #Outside the domain + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.arccos(x), np.arccos(x)) + elif isinstance(x, interval): + #Outside the domain + if x.is_valid is False or x.start > 1 or x.end < -1: + return interval(-np.inf, np.inf, is_valid=False) + #Partially outside the domain + elif x.start < -1 or x.end > 1: + return interval(-np.inf, np.inf, is_valid=None) + else: + start = np.arccos(x.start) + end = np.arccos(x.end) + return interval(start, end, is_valid=x.is_valid) + + +def ceil(x): + """Evaluates the ceiling of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.ceil(x)) + elif isinstance(x, interval): + if x.is_valid is False: + return interval(-np.inf, np.inf, is_valid=False) + else: + start = np.ceil(x.start) + end = np.ceil(x.end) + #Continuous over the interval + if start == end: + return interval(start, end, is_valid=x.is_valid) + else: + #Not continuous over the interval + return interval(start, end, is_valid=None) + else: + return NotImplementedError + + +def floor(x): + """Evaluates the floor of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.floor(x)) + elif isinstance(x, interval): + if x.is_valid is False: + return interval(-np.inf, np.inf, is_valid=False) + else: + start = np.floor(x.start) + end = np.floor(x.end) + #continuous over the argument + if start == end: + return interval(start, end, is_valid=x.is_valid) + else: + #not continuous over the interval + return interval(start, end, is_valid=None) + else: + return NotImplementedError + + +def acosh(x): + """Evaluates the inverse hyperbolic cosine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + #Outside the domain + if x < 1: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.arccosh(x)) + elif isinstance(x, interval): + #Outside the domain + if x.end < 1: + return interval(-np.inf, np.inf, is_valid=False) + #Partly outside the domain + elif x.start < 1: + return interval(-np.inf, np.inf, is_valid=None) + else: + start = np.arccosh(x.start) + end = np.arccosh(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + return NotImplementedError + + +#Monotonic +def asinh(x): + """Evaluates the inverse hyperbolic sine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.arcsinh(x)) + elif isinstance(x, interval): + start = np.arcsinh(x.start) + end = np.arcsinh(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + return NotImplementedError + + +def atanh(x): + """Evaluates the inverse hyperbolic tangent of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + #Outside the domain + if abs(x) >= 1: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.arctanh(x)) + elif isinstance(x, interval): + #outside the domain + if x.is_valid is False or x.start >= 1 or x.end <= -1: + return interval(-np.inf, np.inf, is_valid=False) + #partly outside the domain + elif x.start <= -1 or x.end >= 1: + return interval(-np.inf, np.inf, is_valid=None) + else: + start = np.arctanh(x.start) + end = np.arctanh(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + return NotImplementedError + + +#Three valued logic for interval plotting. + +def And(*args): + """Defines the three valued ``And`` behaviour for a 2-tuple of + three valued logic values""" + def reduce_and(cmp_intervala, cmp_intervalb): + if cmp_intervala[0] is False or cmp_intervalb[0] is False: + first = False + elif cmp_intervala[0] is None or cmp_intervalb[0] is None: + first = None + else: + first = True + if cmp_intervala[1] is False or cmp_intervalb[1] is False: + second = False + elif cmp_intervala[1] is None or cmp_intervalb[1] is None: + second = None + else: + second = True + return (first, second) + return reduce(reduce_and, args) + + +def Or(*args): + """Defines the three valued ``Or`` behaviour for a 2-tuple of + three valued logic values""" + def reduce_or(cmp_intervala, cmp_intervalb): + if cmp_intervala[0] is True or cmp_intervalb[0] is True: + first = True + elif cmp_intervala[0] is None or cmp_intervalb[0] is None: + first = None + else: + first = False + + if cmp_intervala[1] is True or cmp_intervalb[1] is True: + second = True + elif cmp_intervala[1] is None or cmp_intervalb[1] is None: + second = None + else: + second = False + return (first, second) + return reduce(reduce_or, args) diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__init__.py b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3d8fe306f56849d8a3f8f18113c5f3804d6a615 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_interval_functions.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_interval_functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3e1417d70c337f9224a9fa80582a10a1c1cb29f Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_interval_functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_interval_membership.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_interval_membership.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49e498618ddbf8ba1f12f643ba07eebe2aecab90 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_interval_membership.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_intervalmath.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_intervalmath.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec939ec5cd394c292337a4c2af55f8babff78511 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_intervalmath.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/test_interval_functions.py b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/test_interval_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..861c3660df024d3fbec788a027708348e9929655 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/test_interval_functions.py @@ -0,0 +1,415 @@ +from sympy.external import import_module +from sympy.plotting.intervalmath import ( + Abs, acos, acosh, And, asin, asinh, atan, atanh, ceil, cos, cosh, + exp, floor, imax, imin, interval, log, log10, Or, sin, sinh, sqrt, + tan, tanh, +) + +np = import_module('numpy') +if not np: + disabled = True + + +#requires Numpy. Hence included in interval_functions + + +def test_interval_pow(): + a = 2**interval(1, 2) == interval(2, 4) + assert a == (True, True) + a = interval(1, 2)**interval(1, 2) == interval(1, 4) + assert a == (True, True) + a = interval(-1, 1)**interval(0.5, 2) + assert a.is_valid is None + a = interval(-2, -1) ** interval(1, 2) + assert a.is_valid is False + a = interval(-2, -1) ** (1.0 / 2) + assert a.is_valid is False + a = interval(-1, 1)**(1.0 / 2) + assert a.is_valid is None + a = interval(-1, 1)**(1.0 / 3) == interval(-1, 1) + assert a == (True, True) + a = interval(-1, 1)**2 == interval(0, 1) + assert a == (True, True) + a = interval(-1, 1) ** (1.0 / 29) == interval(-1, 1) + assert a == (True, True) + a = -2**interval(1, 1) == interval(-2, -2) + assert a == (True, True) + + a = interval(1, 2, is_valid=False)**2 + assert a.is_valid is False + + a = (-3)**interval(1, 2) + assert a.is_valid is False + a = (-4)**interval(0.5, 0.5) + assert a.is_valid is False + assert ((-3)**interval(1, 1) == interval(-3, -3)) == (True, True) + + a = interval(8, 64)**(2.0 / 3) + assert abs(a.start - 4) < 1e-10 # eps + assert abs(a.end - 16) < 1e-10 + a = interval(-8, 64)**(2.0 / 3) + assert abs(a.start - 4) < 1e-10 # eps + assert abs(a.end - 16) < 1e-10 + + +def test_exp(): + a = exp(interval(-np.inf, 0)) + assert a.start == np.exp(-np.inf) + assert a.end == np.exp(0) + a = exp(interval(1, 2)) + assert a.start == np.exp(1) + assert a.end == np.exp(2) + a = exp(1) + assert a.start == np.exp(1) + assert a.end == np.exp(1) + + +def test_log(): + a = log(interval(1, 2)) + assert a.start == 0 + assert a.end == np.log(2) + a = log(interval(-1, 1)) + assert a.is_valid is None + a = log(interval(-3, -1)) + assert a.is_valid is False + a = log(-3) + assert a.is_valid is False + a = log(2) + assert a.start == np.log(2) + assert a.end == np.log(2) + + +def test_log10(): + a = log10(interval(1, 2)) + assert a.start == 0 + assert a.end == np.log10(2) + a = log10(interval(-1, 1)) + assert a.is_valid is None + a = log10(interval(-3, -1)) + assert a.is_valid is False + a = log10(-3) + assert a.is_valid is False + a = log10(2) + assert a.start == np.log10(2) + assert a.end == np.log10(2) + + +def test_atan(): + a = atan(interval(0, 1)) + assert a.start == np.arctan(0) + assert a.end == np.arctan(1) + a = atan(1) + assert a.start == np.arctan(1) + assert a.end == np.arctan(1) + + +def test_sin(): + a = sin(interval(0, np.pi / 4)) + assert a.start == np.sin(0) + assert a.end == np.sin(np.pi / 4) + + a = sin(interval(-np.pi / 4, np.pi / 4)) + assert a.start == np.sin(-np.pi / 4) + assert a.end == np.sin(np.pi / 4) + + a = sin(interval(np.pi / 4, 3 * np.pi / 4)) + assert a.start == np.sin(np.pi / 4) + assert a.end == 1 + + a = sin(interval(7 * np.pi / 6, 7 * np.pi / 4)) + assert a.start == -1 + assert a.end == np.sin(7 * np.pi / 6) + + a = sin(interval(0, 3 * np.pi)) + assert a.start == -1 + assert a.end == 1 + + a = sin(interval(np.pi / 3, 7 * np.pi / 4)) + assert a.start == -1 + assert a.end == 1 + + a = sin(np.pi / 4) + assert a.start == np.sin(np.pi / 4) + assert a.end == np.sin(np.pi / 4) + + a = sin(interval(1, 2, is_valid=False)) + assert a.is_valid is False + + +def test_cos(): + a = cos(interval(0, np.pi / 4)) + assert a.start == np.cos(np.pi / 4) + assert a.end == 1 + + a = cos(interval(-np.pi / 4, np.pi / 4)) + assert a.start == np.cos(-np.pi / 4) + assert a.end == 1 + + a = cos(interval(np.pi / 4, 3 * np.pi / 4)) + assert a.start == np.cos(3 * np.pi / 4) + assert a.end == np.cos(np.pi / 4) + + a = cos(interval(3 * np.pi / 4, 5 * np.pi / 4)) + assert a.start == -1 + assert a.end == np.cos(3 * np.pi / 4) + + a = cos(interval(0, 3 * np.pi)) + assert a.start == -1 + assert a.end == 1 + + a = cos(interval(- np.pi / 3, 5 * np.pi / 4)) + assert a.start == -1 + assert a.end == 1 + + a = cos(interval(1, 2, is_valid=False)) + assert a.is_valid is False + + +def test_tan(): + a = tan(interval(0, np.pi / 4)) + assert a.start == 0 + # must match lib_interval definition of tan: + assert a.end == np.sin(np.pi / 4)/np.cos(np.pi / 4) + + a = tan(interval(np.pi / 4, 3 * np.pi / 4)) + #discontinuity + assert a.is_valid is None + + +def test_sqrt(): + a = sqrt(interval(1, 4)) + assert a.start == 1 + assert a.end == 2 + + a = sqrt(interval(0.01, 1)) + assert a.start == np.sqrt(0.01) + assert a.end == 1 + + a = sqrt(interval(-1, 1)) + assert a.is_valid is None + + a = sqrt(interval(-3, -1)) + assert a.is_valid is False + + a = sqrt(4) + assert (a == interval(2, 2)) == (True, True) + + a = sqrt(-3) + assert a.is_valid is False + + +def test_imin(): + a = imin(interval(1, 3), interval(2, 5), interval(-1, 3)) + assert a.start == -1 + assert a.end == 3 + + a = imin(-2, interval(1, 4)) + assert a.start == -2 + assert a.end == -2 + + a = imin(5, interval(3, 4), interval(-2, 2, is_valid=False)) + assert a.start == 3 + assert a.end == 4 + + +def test_imax(): + a = imax(interval(-2, 2), interval(2, 7), interval(-3, 9)) + assert a.start == 2 + assert a.end == 9 + + a = imax(8, interval(1, 4)) + assert a.start == 8 + assert a.end == 8 + + a = imax(interval(1, 2), interval(3, 4), interval(-2, 2, is_valid=False)) + assert a.start == 3 + assert a.end == 4 + + +def test_sinh(): + a = sinh(interval(-1, 1)) + assert a.start == np.sinh(-1) + assert a.end == np.sinh(1) + + a = sinh(1) + assert a.start == np.sinh(1) + assert a.end == np.sinh(1) + + +def test_cosh(): + a = cosh(interval(1, 2)) + assert a.start == np.cosh(1) + assert a.end == np.cosh(2) + a = cosh(interval(-2, -1)) + assert a.start == np.cosh(-1) + assert a.end == np.cosh(-2) + + a = cosh(interval(-2, 1)) + assert a.start == 1 + assert a.end == np.cosh(-2) + + a = cosh(1) + assert a.start == np.cosh(1) + assert a.end == np.cosh(1) + + +def test_tanh(): + a = tanh(interval(-3, 3)) + assert a.start == np.tanh(-3) + assert a.end == np.tanh(3) + + a = tanh(3) + assert a.start == np.tanh(3) + assert a.end == np.tanh(3) + + +def test_asin(): + a = asin(interval(-0.5, 0.5)) + assert a.start == np.arcsin(-0.5) + assert a.end == np.arcsin(0.5) + + a = asin(interval(-1.5, 1.5)) + assert a.is_valid is None + a = asin(interval(-2, -1.5)) + assert a.is_valid is False + + a = asin(interval(0, 2)) + assert a.is_valid is None + + a = asin(interval(2, 5)) + assert a.is_valid is False + + a = asin(0.5) + assert a.start == np.arcsin(0.5) + assert a.end == np.arcsin(0.5) + + a = asin(1.5) + assert a.is_valid is False + + +def test_acos(): + a = acos(interval(-0.5, 0.5)) + assert a.start == np.arccos(0.5) + assert a.end == np.arccos(-0.5) + + a = acos(interval(-1.5, 1.5)) + assert a.is_valid is None + a = acos(interval(-2, -1.5)) + assert a.is_valid is False + + a = acos(interval(0, 2)) + assert a.is_valid is None + + a = acos(interval(2, 5)) + assert a.is_valid is False + + a = acos(0.5) + assert a.start == np.arccos(0.5) + assert a.end == np.arccos(0.5) + + a = acos(1.5) + assert a.is_valid is False + + +def test_ceil(): + a = ceil(interval(0.2, 0.5)) + assert a.start == 1 + assert a.end == 1 + + a = ceil(interval(0.5, 1.5)) + assert a.start == 1 + assert a.end == 2 + assert a.is_valid is None + + a = ceil(interval(-5, 5)) + assert a.is_valid is None + + a = ceil(5.4) + assert a.start == 6 + assert a.end == 6 + + +def test_floor(): + a = floor(interval(0.2, 0.5)) + assert a.start == 0 + assert a.end == 0 + + a = floor(interval(0.5, 1.5)) + assert a.start == 0 + assert a.end == 1 + assert a.is_valid is None + + a = floor(interval(-5, 5)) + assert a.is_valid is None + + a = floor(5.4) + assert a.start == 5 + assert a.end == 5 + + +def test_asinh(): + a = asinh(interval(1, 2)) + assert a.start == np.arcsinh(1) + assert a.end == np.arcsinh(2) + + a = asinh(0.5) + assert a.start == np.arcsinh(0.5) + assert a.end == np.arcsinh(0.5) + + +def test_acosh(): + a = acosh(interval(3, 5)) + assert a.start == np.arccosh(3) + assert a.end == np.arccosh(5) + + a = acosh(interval(0, 3)) + assert a.is_valid is None + a = acosh(interval(-3, 0.5)) + assert a.is_valid is False + + a = acosh(0.5) + assert a.is_valid is False + + a = acosh(2) + assert a.start == np.arccosh(2) + assert a.end == np.arccosh(2) + + +def test_atanh(): + a = atanh(interval(-0.5, 0.5)) + assert a.start == np.arctanh(-0.5) + assert a.end == np.arctanh(0.5) + + a = atanh(interval(0, 3)) + assert a.is_valid is None + + a = atanh(interval(-3, -2)) + assert a.is_valid is False + + a = atanh(0.5) + assert a.start == np.arctanh(0.5) + assert a.end == np.arctanh(0.5) + + a = atanh(1.5) + assert a.is_valid is False + + +def test_Abs(): + assert (Abs(interval(-0.5, 0.5)) == interval(0, 0.5)) == (True, True) + assert (Abs(interval(-3, -2)) == interval(2, 3)) == (True, True) + assert (Abs(-3) == interval(3, 3)) == (True, True) + + +def test_And(): + args = [(True, True), (True, False), (True, None)] + assert And(*args) == (True, False) + + args = [(False, True), (None, None), (True, True)] + assert And(*args) == (False, None) + + +def test_Or(): + args = [(True, True), (True, False), (False, None)] + assert Or(*args) == (True, True) + args = [(None, None), (False, None), (False, False)] + assert Or(*args) == (None, None) diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/test_interval_membership.py b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/test_interval_membership.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7f23680d60a64a6257a84c2476e31a8b5dfce8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/test_interval_membership.py @@ -0,0 +1,150 @@ +from sympy.core.symbol import Symbol +from sympy.plotting.intervalmath import interval +from sympy.plotting.intervalmath.interval_membership import intervalMembership +from sympy.plotting.experimental_lambdify import experimental_lambdify +from sympy.testing.pytest import raises + + +def test_creation(): + assert intervalMembership(True, True) + raises(TypeError, lambda: intervalMembership(True)) + raises(TypeError, lambda: intervalMembership(True, True, True)) + + +def test_getitem(): + a = intervalMembership(True, False) + assert a[0] is True + assert a[1] is False + raises(IndexError, lambda: a[2]) + + +def test_str(): + a = intervalMembership(True, False) + assert str(a) == 'intervalMembership(True, False)' + assert repr(a) == 'intervalMembership(True, False)' + + +def test_equivalence(): + a = intervalMembership(True, True) + b = intervalMembership(True, False) + assert (a == b) is False + assert (a != b) is True + + a = intervalMembership(True, False) + b = intervalMembership(True, False) + assert (a == b) is True + assert (a != b) is False + + +def test_not(): + x = Symbol('x') + + r1 = x > -1 + r2 = x <= -1 + + i = interval + + f1 = experimental_lambdify((x,), r1) + f2 = experimental_lambdify((x,), r2) + + tt = i(-0.1, 0.1, is_valid=True) + tn = i(-0.1, 0.1, is_valid=None) + tf = i(-0.1, 0.1, is_valid=False) + + assert f1(tt) == ~f2(tt) + assert f1(tn) == ~f2(tn) + assert f1(tf) == ~f2(tf) + + nt = i(0.9, 1.1, is_valid=True) + nn = i(0.9, 1.1, is_valid=None) + nf = i(0.9, 1.1, is_valid=False) + + assert f1(nt) == ~f2(nt) + assert f1(nn) == ~f2(nn) + assert f1(nf) == ~f2(nf) + + ft = i(1.9, 2.1, is_valid=True) + fn = i(1.9, 2.1, is_valid=None) + ff = i(1.9, 2.1, is_valid=False) + + assert f1(ft) == ~f2(ft) + assert f1(fn) == ~f2(fn) + assert f1(ff) == ~f2(ff) + + +def test_boolean(): + # There can be 9*9 test cases in full mapping of the cartesian product. + # But we only consider 3*3 cases for simplicity. + s = [ + intervalMembership(False, False), + intervalMembership(None, None), + intervalMembership(True, True) + ] + + # Reduced tests for 'And' + a1 = [ + intervalMembership(False, False), + intervalMembership(False, False), + intervalMembership(False, False), + intervalMembership(False, False), + intervalMembership(None, None), + intervalMembership(None, None), + intervalMembership(False, False), + intervalMembership(None, None), + intervalMembership(True, True) + ] + a1_iter = iter(a1) + for i in range(len(s)): + for j in range(len(s)): + assert s[i] & s[j] == next(a1_iter) + + # Reduced tests for 'Or' + a1 = [ + intervalMembership(False, False), + intervalMembership(None, False), + intervalMembership(True, False), + intervalMembership(None, False), + intervalMembership(None, None), + intervalMembership(True, None), + intervalMembership(True, False), + intervalMembership(True, None), + intervalMembership(True, True) + ] + a1_iter = iter(a1) + for i in range(len(s)): + for j in range(len(s)): + assert s[i] | s[j] == next(a1_iter) + + # Reduced tests for 'Xor' + a1 = [ + intervalMembership(False, False), + intervalMembership(None, False), + intervalMembership(True, False), + intervalMembership(None, False), + intervalMembership(None, None), + intervalMembership(None, None), + intervalMembership(True, False), + intervalMembership(None, None), + intervalMembership(False, True) + ] + a1_iter = iter(a1) + for i in range(len(s)): + for j in range(len(s)): + assert s[i] ^ s[j] == next(a1_iter) + + # Reduced tests for 'Not' + a1 = [ + intervalMembership(True, False), + intervalMembership(None, None), + intervalMembership(False, True) + ] + a1_iter = iter(a1) + for i in range(len(s)): + assert ~s[i] == next(a1_iter) + + +def test_boolean_errors(): + a = intervalMembership(True, True) + raises(ValueError, lambda: a & 1) + raises(ValueError, lambda: a | 1) + raises(ValueError, lambda: a ^ 1) diff --git a/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/test_intervalmath.py b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/test_intervalmath.py new file mode 100644 index 0000000000000000000000000000000000000000..e30f217a44b4ea795270c0e2c66b6813b05e63ea --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/intervalmath/tests/test_intervalmath.py @@ -0,0 +1,213 @@ +from sympy.plotting.intervalmath import interval +from sympy.testing.pytest import raises + + +def test_interval(): + assert (interval(1, 1) == interval(1, 1, is_valid=True)) == (True, True) + assert (interval(1, 1) == interval(1, 1, is_valid=False)) == (True, False) + assert (interval(1, 1) == interval(1, 1, is_valid=None)) == (True, None) + assert (interval(1, 1.5) == interval(1, 2)) == (None, True) + assert (interval(0, 1) == interval(2, 3)) == (False, True) + assert (interval(0, 1) == interval(1, 2)) == (None, True) + assert (interval(1, 2) != interval(1, 2)) == (False, True) + assert (interval(1, 3) != interval(2, 3)) == (None, True) + assert (interval(1, 3) != interval(-5, -3)) == (True, True) + assert ( + interval(1, 3, is_valid=False) != interval(-5, -3)) == (True, False) + assert (interval(1, 3, is_valid=None) != interval(-5, 3)) == (None, None) + assert (interval(4, 4) != 4) == (False, True) + assert (interval(1, 1) == 1) == (True, True) + assert (interval(1, 3, is_valid=False) == interval(1, 3)) == (True, False) + assert (interval(1, 3, is_valid=None) == interval(1, 3)) == (True, None) + inter = interval(-5, 5) + assert (interval(inter) == interval(-5, 5)) == (True, True) + assert inter.width == 10 + assert 0 in inter + assert -5 in inter + assert 5 in inter + assert interval(0, 3) in inter + assert interval(-6, 2) not in inter + assert -5.05 not in inter + assert 5.3 not in inter + interb = interval(-float('inf'), float('inf')) + assert 0 in inter + assert inter in interb + assert interval(0, float('inf')) in interb + assert interval(-float('inf'), 5) in interb + assert interval(-1e50, 1e50) in interb + assert ( + -interval(-1, -2, is_valid=False) == interval(1, 2)) == (True, False) + raises(ValueError, lambda: interval(1, 2, 3)) + + +def test_interval_add(): + assert (interval(1, 2) + interval(2, 3) == interval(3, 5)) == (True, True) + assert (1 + interval(1, 2) == interval(2, 3)) == (True, True) + assert (interval(1, 2) + 1 == interval(2, 3)) == (True, True) + compare = (1 + interval(0, float('inf')) == interval(1, float('inf'))) + assert compare == (True, True) + a = 1 + interval(2, 5, is_valid=False) + assert a.is_valid is False + a = 1 + interval(2, 5, is_valid=None) + assert a.is_valid is None + a = interval(2, 5, is_valid=False) + interval(3, 5, is_valid=None) + assert a.is_valid is False + a = interval(3, 5) + interval(-1, 1, is_valid=None) + assert a.is_valid is None + a = interval(2, 5, is_valid=False) + 1 + assert a.is_valid is False + + +def test_interval_sub(): + assert (interval(1, 2) - interval(1, 5) == interval(-4, 1)) == (True, True) + assert (interval(1, 2) - 1 == interval(0, 1)) == (True, True) + assert (1 - interval(1, 2) == interval(-1, 0)) == (True, True) + a = 1 - interval(1, 2, is_valid=False) + assert a.is_valid is False + a = interval(1, 4, is_valid=None) - 1 + assert a.is_valid is None + a = interval(1, 3, is_valid=False) - interval(1, 3) + assert a.is_valid is False + a = interval(1, 3, is_valid=None) - interval(1, 3) + assert a.is_valid is None + + +def test_interval_inequality(): + assert (interval(1, 2) < interval(3, 4)) == (True, True) + assert (interval(1, 2) < interval(2, 4)) == (None, True) + assert (interval(1, 2) < interval(-2, 0)) == (False, True) + assert (interval(1, 2) <= interval(2, 4)) == (True, True) + assert (interval(1, 2) <= interval(1.5, 6)) == (None, True) + assert (interval(2, 3) <= interval(1, 2)) == (None, True) + assert (interval(2, 3) <= interval(1, 1.5)) == (False, True) + assert ( + interval(1, 2, is_valid=False) <= interval(-2, 0)) == (False, False) + assert (interval(1, 2, is_valid=None) <= interval(-2, 0)) == (False, None) + assert (interval(1, 2) <= 1.5) == (None, True) + assert (interval(1, 2) <= 3) == (True, True) + assert (interval(1, 2) <= 0) == (False, True) + assert (interval(5, 8) > interval(2, 3)) == (True, True) + assert (interval(2, 5) > interval(1, 3)) == (None, True) + assert (interval(2, 3) > interval(3.1, 5)) == (False, True) + + assert (interval(-1, 1) == 0) == (None, True) + assert (interval(-1, 1) == 2) == (False, True) + assert (interval(-1, 1) != 0) == (None, True) + assert (interval(-1, 1) != 2) == (True, True) + + assert (interval(3, 5) > 2) == (True, True) + assert (interval(3, 5) < 2) == (False, True) + assert (interval(1, 5) < 2) == (None, True) + assert (interval(1, 5) > 2) == (None, True) + assert (interval(0, 1) > 2) == (False, True) + assert (interval(1, 2) >= interval(0, 1)) == (True, True) + assert (interval(1, 2) >= interval(0, 1.5)) == (None, True) + assert (interval(1, 2) >= interval(3, 4)) == (False, True) + assert (interval(1, 2) >= 0) == (True, True) + assert (interval(1, 2) >= 1.2) == (None, True) + assert (interval(1, 2) >= 3) == (False, True) + assert (2 > interval(0, 1)) == (True, True) + a = interval(-1, 1, is_valid=False) < interval(2, 5, is_valid=None) + assert a == (True, False) + a = interval(-1, 1, is_valid=None) < interval(2, 5, is_valid=False) + assert a == (True, False) + a = interval(-1, 1, is_valid=None) < interval(2, 5, is_valid=None) + assert a == (True, None) + a = interval(-1, 1, is_valid=False) > interval(-5, -2, is_valid=None) + assert a == (True, False) + a = interval(-1, 1, is_valid=None) > interval(-5, -2, is_valid=False) + assert a == (True, False) + a = interval(-1, 1, is_valid=None) > interval(-5, -2, is_valid=None) + assert a == (True, None) + + +def test_interval_mul(): + assert ( + interval(1, 5) * interval(2, 10) == interval(2, 50)) == (True, True) + a = interval(-1, 1) * interval(2, 10) == interval(-10, 10) + assert a == (True, True) + + a = interval(-1, 1) * interval(-5, 3) == interval(-5, 5) + assert a == (True, True) + + assert (interval(1, 3) * 2 == interval(2, 6)) == (True, True) + assert (3 * interval(-1, 2) == interval(-3, 6)) == (True, True) + + a = 3 * interval(1, 2, is_valid=False) + assert a.is_valid is False + + a = 3 * interval(1, 2, is_valid=None) + assert a.is_valid is None + + a = interval(1, 5, is_valid=False) * interval(1, 2, is_valid=None) + assert a.is_valid is False + + +def test_interval_div(): + div = interval(1, 2, is_valid=False) / 3 + assert div == interval(-float('inf'), float('inf'), is_valid=False) + + div = interval(1, 2, is_valid=None) / 3 + assert div == interval(-float('inf'), float('inf'), is_valid=None) + + div = 3 / interval(1, 2, is_valid=None) + assert div == interval(-float('inf'), float('inf'), is_valid=None) + a = interval(1, 2) / 0 + assert a.is_valid is False + a = interval(0.5, 1) / interval(-1, 0) + assert a.is_valid is None + a = interval(0, 1) / interval(0, 1) + assert a.is_valid is None + + a = interval(-1, 1) / interval(-1, 1) + assert a.is_valid is None + + a = interval(-1, 2) / interval(0.5, 1) == interval(-2.0, 4.0) + assert a == (True, True) + a = interval(0, 1) / interval(0.5, 1) == interval(0.0, 2.0) + assert a == (True, True) + a = interval(-1, 0) / interval(0.5, 1) == interval(-2.0, 0.0) + assert a == (True, True) + a = interval(-0.5, -0.25) / interval(0.5, 1) == interval(-1.0, -0.25) + assert a == (True, True) + a = interval(0.5, 1) / interval(0.5, 1) == interval(0.5, 2.0) + assert a == (True, True) + a = interval(0.5, 4) / interval(0.5, 1) == interval(0.5, 8.0) + assert a == (True, True) + a = interval(-1, -0.5) / interval(0.5, 1) == interval(-2.0, -0.5) + assert a == (True, True) + a = interval(-4, -0.5) / interval(0.5, 1) == interval(-8.0, -0.5) + assert a == (True, True) + a = interval(-1, 2) / interval(-2, -0.5) == interval(-4.0, 2.0) + assert a == (True, True) + a = interval(0, 1) / interval(-2, -0.5) == interval(-2.0, 0.0) + assert a == (True, True) + a = interval(-1, 0) / interval(-2, -0.5) == interval(0.0, 2.0) + assert a == (True, True) + a = interval(-0.5, -0.25) / interval(-2, -0.5) == interval(0.125, 1.0) + assert a == (True, True) + a = interval(0.5, 1) / interval(-2, -0.5) == interval(-2.0, -0.25) + assert a == (True, True) + a = interval(0.5, 4) / interval(-2, -0.5) == interval(-8.0, -0.25) + assert a == (True, True) + a = interval(-1, -0.5) / interval(-2, -0.5) == interval(0.25, 2.0) + assert a == (True, True) + a = interval(-4, -0.5) / interval(-2, -0.5) == interval(0.25, 8.0) + assert a == (True, True) + a = interval(-5, 5, is_valid=False) / 2 + assert a.is_valid is False + +def test_hashable(): + ''' + test that interval objects are hashable. + this is required in order to be able to put them into the cache, which + appears to be necessary for plotting in py3k. For details, see: + + https://github.com/sympy/sympy/pull/2101 + https://github.com/sympy/sympy/issues/6533 + ''' + hash(interval(1, 1)) + hash(interval(1, 1, is_valid=True)) + hash(interval(-4, -0.5)) + hash(interval(-2, -0.5)) + hash(interval(0.25, 8.0)) diff --git a/phivenv/Lib/site-packages/sympy/plotting/plot.py b/phivenv/Lib/site-packages/sympy/plotting/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..50029392a1ac70491f93f28c4d443da15e7fc31e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/plot.py @@ -0,0 +1,1234 @@ +"""Plotting module for SymPy. + +A plot is represented by the ``Plot`` class that contains a reference to the +backend and a list of the data series to be plotted. The data series are +instances of classes meant to simplify getting points and meshes from SymPy +expressions. ``plot_backends`` is a dictionary with all the backends. + +This module gives only the essential. For all the fancy stuff use directly +the backend. You can get the backend wrapper for every plot from the +``_backend`` attribute. Moreover the data series classes have various useful +methods like ``get_points``, ``get_meshes``, etc, that may +be useful if you wish to use another plotting library. + +Especially if you need publication ready graphs and this module is not enough +for you - just get the ``_backend`` attribute and add whatever you want +directly to it. In the case of matplotlib (the common way to graph data in +python) just copy ``_backend.fig`` which is the figure and ``_backend.ax`` +which is the axis and work on them as you would on any other matplotlib object. + +Simplicity of code takes much greater importance than performance. Do not use it +if you care at all about performance. A new backend instance is initialized +every time you call ``show()`` and the old one is left to the garbage collector. +""" + +from sympy.concrete.summations import Sum +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import Function, AppliedUndef +from sympy.core.symbol import (Dummy, Symbol, Wild) +from sympy.external import import_module +from sympy.functions import sign +from sympy.plotting.backends.base_backend import Plot +from sympy.plotting.backends.matplotlibbackend import MatplotlibBackend +from sympy.plotting.backends.textbackend import TextBackend +from sympy.plotting.series import ( + LineOver1DRangeSeries, Parametric2DLineSeries, Parametric3DLineSeries, + ParametricSurfaceSeries, SurfaceOver2DRangeSeries, ContourSeries) +from sympy.plotting.utils import _check_arguments, _plot_sympify +from sympy.tensor.indexed import Indexed +# to maintain back-compatibility +from sympy.plotting.plotgrid import PlotGrid # noqa: F401 +from sympy.plotting.series import BaseSeries # noqa: F401 +from sympy.plotting.series import Line2DBaseSeries # noqa: F401 +from sympy.plotting.series import Line3DBaseSeries # noqa: F401 +from sympy.plotting.series import SurfaceBaseSeries # noqa: F401 +from sympy.plotting.series import List2DSeries # noqa: F401 +from sympy.plotting.series import GenericDataSeries # noqa: F401 +from sympy.plotting.series import centers_of_faces # noqa: F401 +from sympy.plotting.series import centers_of_segments # noqa: F401 +from sympy.plotting.series import flat # noqa: F401 +from sympy.plotting.backends.base_backend import unset_show # noqa: F401 +from sympy.plotting.backends.matplotlibbackend import _matplotlib_list # noqa: F401 +from sympy.plotting.textplot import textplot # noqa: F401 + + +__doctest_requires__ = { + ('plot3d', + 'plot3d_parametric_line', + 'plot3d_parametric_surface', + 'plot_parametric'): ['matplotlib'], + # XXX: The plot doctest possibly should not require matplotlib. It fails at + # plot(x**2, (x, -5, 5)) which should be fine for text backend. + ('plot',): ['matplotlib'], +} + + +def _process_summations(sum_bound, *args): + """Substitute oo (infinity) in the lower/upper bounds of a summation with + some integer number. + + Parameters + ========== + + sum_bound : int + oo will be substituted with this integer number. + *args : list/tuple + pre-processed arguments of the form (expr, range, ...) + + Notes + ===== + Let's consider the following summation: ``Sum(1 / x**2, (x, 1, oo))``. + The current implementation of lambdify (SymPy 1.12 at the time of + writing this) will create something of this form: + ``sum(1 / x**2 for x in range(1, INF))`` + The problem is that ``type(INF)`` is float, while ``range`` requires + integers: the evaluation fails. + Instead of modifying ``lambdify`` (which requires a deep knowledge), just + replace it with some integer number. + """ + def new_bound(t, bound): + if (not t.is_number) or t.is_finite: + return t + if sign(t) >= 0: + return bound + return -bound + + args = list(args) + expr = args[0] + + # select summations whose lower/upper bound is infinity + w = Wild("w", properties=[ + lambda t: isinstance(t, Sum), + lambda t: any((not a[1].is_finite) or (not a[2].is_finite) for i, a in enumerate(t.args) if i > 0) + ]) + + for t in list(expr.find(w)): + sums_args = list(t.args) + for i, a in enumerate(sums_args): + if i > 0: + sums_args[i] = (a[0], new_bound(a[1], sum_bound), + new_bound(a[2], sum_bound)) + s = Sum(*sums_args) + expr = expr.subs(t, s) + args[0] = expr + return args + + +def _build_line_series(*args, **kwargs): + """Loop over the provided arguments and create the necessary line series. + """ + series = [] + sum_bound = int(kwargs.get("sum_bound", 1000)) + for arg in args: + expr, r, label, rendering_kw = arg + kw = kwargs.copy() + if rendering_kw is not None: + kw["rendering_kw"] = rendering_kw + # TODO: _process_piecewise check goes here + if not callable(expr): + arg = _process_summations(sum_bound, *arg) + series.append(LineOver1DRangeSeries(*arg[:-1], **kw)) + return series + + +def _create_series(series_type, plot_expr, **kwargs): + """Extract the rendering_kw dictionary from the provided arguments and + create an appropriate data series. + """ + series = [] + for args in plot_expr: + kw = kwargs.copy() + if args[-1] is not None: + kw["rendering_kw"] = args[-1] + series.append(series_type(*args[:-1], **kw)) + return series + + +def _set_labels(series, labels, rendering_kw): + """Apply the `label` and `rendering_kw` keyword arguments to the series. + """ + if not isinstance(labels, (list, tuple)): + labels = [labels] + if len(labels) > 0: + if len(labels) == 1 and len(series) > 1: + # if one label is provided and multiple series are being plotted, + # set the same label to all data series. It maintains + # back-compatibility + labels *= len(series) + if len(series) != len(labels): + raise ValueError("The number of labels must be equal to the " + "number of expressions being plotted.\nReceived " + f"{len(series)} expressions and {len(labels)} labels") + + for s, l in zip(series, labels): + s.label = l + + if rendering_kw: + if isinstance(rendering_kw, dict): + rendering_kw = [rendering_kw] + if len(rendering_kw) == 1: + rendering_kw *= len(series) + elif len(series) != len(rendering_kw): + raise ValueError("The number of rendering dictionaries must be " + "equal to the number of expressions being plotted.\nReceived " + f"{len(series)} expressions and {len(labels)} labels") + for s, r in zip(series, rendering_kw): + s.rendering_kw = r + + +def plot_factory(*args, **kwargs): + backend = kwargs.pop("backend", "default") + if isinstance(backend, str): + if backend == "default": + matplotlib = import_module('matplotlib', + min_module_version='1.1.0', catch=(RuntimeError,)) + if matplotlib: + return MatplotlibBackend(*args, **kwargs) + return TextBackend(*args, **kwargs) + return plot_backends[backend](*args, **kwargs) + elif (type(backend) == type) and issubclass(backend, Plot): + return backend(*args, **kwargs) + else: + raise TypeError("backend must be either a string or a subclass of ``Plot``.") + + +plot_backends = { + 'matplotlib': MatplotlibBackend, + 'text': TextBackend, +} + + +####New API for plotting module #### + +# TODO: Add color arrays for plots. +# TODO: Add more plotting options for 3d plots. +# TODO: Adaptive sampling for 3D plots. + +def plot(*args, show=True, **kwargs): + """Plots a function of a single variable as a curve. + + Parameters + ========== + + args : + The first argument is the expression representing the function + of single variable to be plotted. + + The last argument is a 3-tuple denoting the range of the free + variable. e.g. ``(x, 0, 5)`` + + Typical usage examples are in the following: + + - Plotting a single expression with a single range. + ``plot(expr, range, **kwargs)`` + - Plotting a single expression with the default range (-10, 10). + ``plot(expr, **kwargs)`` + - Plotting multiple expressions with a single range. + ``plot(expr1, expr2, ..., range, **kwargs)`` + - Plotting multiple expressions with multiple ranges. + ``plot((expr1, range1), (expr2, range2), ..., **kwargs)`` + + It is best practice to specify range explicitly because default + range may change in the future if a more advanced default range + detection algorithm is implemented. + + show : bool, optional + The default value is set to ``True``. Set show to ``False`` and + the function will not display the plot. The returned instance of + the ``Plot`` class can then be used to save or display the plot + by calling the ``save()`` and ``show()`` methods respectively. + + line_color : string, or float, or function, optional + Specifies the color for the plot. + See ``Plot`` to see how to set color for the plots. + Note that by setting ``line_color``, it would be applied simultaneously + to all the series. + + title : str, optional + Title of the plot. It is set to the latex representation of + the expression, if the plot has only one expression. + + label : str, optional + The label of the expression in the plot. It will be used when + called with ``legend``. Default is the name of the expression. + e.g. ``sin(x)`` + + xlabel : str or expression, optional + Label for the x-axis. + + ylabel : str or expression, optional + Label for the y-axis. + + xscale : 'linear' or 'log', optional + Sets the scaling of the x-axis. + + yscale : 'linear' or 'log', optional + Sets the scaling of the y-axis. + + axis_center : (float, float), optional + Tuple of two floats denoting the coordinates of the center or + {'center', 'auto'} + + xlim : (float, float), optional + Denotes the x-axis limits, ``(min, max)```. + + ylim : (float, float), optional + Denotes the y-axis limits, ``(min, max)```. + + annotations : list, optional + A list of dictionaries specifying the type of annotation + required. The keys in the dictionary should be equivalent + to the arguments of the :external:mod:`matplotlib`'s + :external:meth:`~matplotlib.axes.Axes.annotate` method. + + markers : list, optional + A list of dictionaries specifying the type the markers required. + The keys in the dictionary should be equivalent to the arguments + of the :external:mod:`matplotlib`'s :external:func:`~matplotlib.pyplot.plot()` function + along with the marker related keyworded arguments. + + rectangles : list, optional + A list of dictionaries specifying the dimensions of the + rectangles to be plotted. The keys in the dictionary should be + equivalent to the arguments of the :external:mod:`matplotlib`'s + :external:class:`~matplotlib.patches.Rectangle` class. + + fill : dict, optional + A dictionary specifying the type of color filling required in + the plot. The keys in the dictionary should be equivalent to the + arguments of the :external:mod:`matplotlib`'s + :external:meth:`~matplotlib.axes.Axes.fill_between` method. + + adaptive : bool, optional + The default value for the ``adaptive`` parameter is now ``False``. + To enable adaptive sampling, set ``adaptive=True`` and specify ``n`` if uniform sampling is required. + + The plotting uses an adaptive algorithm which samples + recursively to accurately plot. The adaptive algorithm uses a + random point near the midpoint of two points that has to be + further sampled. Hence the same plots can appear slightly + different. + + depth : int, optional + Recursion depth of the adaptive algorithm. A depth of value + `n` samples a maximum of `2^{n}` points. + + If the ``adaptive`` flag is set to ``False``, this will be + ignored. + + n : int, optional + Used when the ``adaptive`` is set to ``False``. The function + is uniformly sampled at ``n`` number of points. If the ``adaptive`` + flag is set to ``True``, this will be ignored. + This keyword argument replaces ``nb_of_points``, which should be + considered deprecated. + + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of + the overall figure. The default value is set to ``None``, meaning + the size will be set by the default backend. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.plotting import plot + >>> x = symbols('x') + + Single Plot + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot(x**2, (x, -5, 5)) + Plot object containing: + [0]: cartesian line: x**2 for x over (-5.0, 5.0) + + Multiple plots with single range. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot(x, x**2, x**3, (x, -5, 5)) + Plot object containing: + [0]: cartesian line: x for x over (-5.0, 5.0) + [1]: cartesian line: x**2 for x over (-5.0, 5.0) + [2]: cartesian line: x**3 for x over (-5.0, 5.0) + + Multiple plots with different ranges. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot((x**2, (x, -6, 6)), (x, (x, -5, 5))) + Plot object containing: + [0]: cartesian line: x**2 for x over (-6.0, 6.0) + [1]: cartesian line: x for x over (-5.0, 5.0) + + No adaptive sampling by default. If adaptive sampling is required, set ``adaptive=True``. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot(x**2, adaptive=True, n=400) + Plot object containing: + [0]: cartesian line: x**2 for x over (-10.0, 10.0) + + See Also + ======== + + Plot, LineOver1DRangeSeries + + """ + args = _plot_sympify(args) + plot_expr = _check_arguments(args, 1, 1, **kwargs) + params = kwargs.get("params", None) + free = set() + for p in plot_expr: + if not isinstance(p[1][0], str): + free |= {p[1][0]} + else: + free |= {Symbol(p[1][0])} + if params: + free = free.difference(params.keys()) + x = free.pop() if free else Symbol("x") + kwargs.setdefault('xlabel', x) + kwargs.setdefault('ylabel', Function('f')(x)) + + labels = kwargs.pop("label", []) + rendering_kw = kwargs.pop("rendering_kw", None) + series = _build_line_series(*plot_expr, **kwargs) + _set_labels(series, labels, rendering_kw) + + plots = plot_factory(*series, **kwargs) + if show: + plots.show() + return plots + + +def plot_parametric(*args, show=True, **kwargs): + """ + Plots a 2D parametric curve. + + Parameters + ========== + + args + Common specifications are: + + - Plotting a single parametric curve with a range + ``plot_parametric((expr_x, expr_y), range)`` + - Plotting multiple parametric curves with the same range + ``plot_parametric((expr_x, expr_y), ..., range)`` + - Plotting multiple parametric curves with different ranges + ``plot_parametric((expr_x, expr_y, range), ...)`` + + ``expr_x`` is the expression representing $x$ component of the + parametric function. + + ``expr_y`` is the expression representing $y$ component of the + parametric function. + + ``range`` is a 3-tuple denoting the parameter symbol, start and + stop. For example, ``(u, 0, 5)``. + + If the range is not specified, then a default range of (-10, 10) + is used. + + However, if the arguments are specified as + ``(expr_x, expr_y, range), ...``, you must specify the ranges + for each expressions manually. + + Default range may change in the future if a more advanced + algorithm is implemented. + + adaptive : bool, optional + Specifies whether to use the adaptive sampling or not. + + The default value is set to ``True``. Set adaptive to ``False`` + and specify ``n`` if uniform sampling is required. + + depth : int, optional + The recursion depth of the adaptive algorithm. A depth of + value $n$ samples a maximum of $2^n$ points. + + n : int, optional + Used when the ``adaptive`` flag is set to ``False``. Specifies the + number of the points used for the uniform sampling. + This keyword argument replaces ``nb_of_points``, which should be + considered deprecated. + + line_color : string, or float, or function, optional + Specifies the color for the plot. + See ``Plot`` to see how to set color for the plots. + Note that by setting ``line_color``, it would be applied simultaneously + to all the series. + + label : str, optional + The label of the expression in the plot. It will be used when + called with ``legend``. Default is the name of the expression. + e.g. ``sin(x)`` + + xlabel : str, optional + Label for the x-axis. + + ylabel : str, optional + Label for the y-axis. + + xscale : 'linear' or 'log', optional + Sets the scaling of the x-axis. + + yscale : 'linear' or 'log', optional + Sets the scaling of the y-axis. + + axis_center : (float, float), optional + Tuple of two floats denoting the coordinates of the center or + {'center', 'auto'} + + xlim : (float, float), optional + Denotes the x-axis limits, ``(min, max)```. + + ylim : (float, float), optional + Denotes the y-axis limits, ``(min, max)```. + + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of + the overall figure. The default value is set to ``None``, meaning + the size will be set by the default backend. + + Examples + ======== + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import plot_parametric, symbols, cos, sin + >>> u = symbols('u') + + A parametric plot with a single expression: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot_parametric((cos(u), sin(u)), (u, -5, 5)) + Plot object containing: + [0]: parametric cartesian line: (cos(u), sin(u)) for u over (-5.0, 5.0) + + A parametric plot with multiple expressions with the same range: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot_parametric((cos(u), sin(u)), (u, cos(u)), (u, -10, 10)) + Plot object containing: + [0]: parametric cartesian line: (cos(u), sin(u)) for u over (-10.0, 10.0) + [1]: parametric cartesian line: (u, cos(u)) for u over (-10.0, 10.0) + + A parametric plot with multiple expressions with different ranges + for each curve: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot_parametric((cos(u), sin(u), (u, -5, 5)), + ... (cos(u), u, (u, -5, 5))) + Plot object containing: + [0]: parametric cartesian line: (cos(u), sin(u)) for u over (-5.0, 5.0) + [1]: parametric cartesian line: (cos(u), u) for u over (-5.0, 5.0) + + Notes + ===== + + The plotting uses an adaptive algorithm which samples recursively to + accurately plot the curve. The adaptive algorithm uses a random point + near the midpoint of two points that has to be further sampled. + Hence, repeating the same plot command can give slightly different + results because of the random sampling. + + If there are multiple plots, then the same optional arguments are + applied to all the plots drawn in the same canvas. If you want to + set these options separately, you can index the returned ``Plot`` + object and set it. + + For example, when you specify ``line_color`` once, it would be + applied simultaneously to both series. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import pi + >>> expr1 = (u, cos(2*pi*u)/2 + 1/2) + >>> expr2 = (u, sin(2*pi*u)/2 + 1/2) + >>> p = plot_parametric(expr1, expr2, (u, 0, 1), line_color='blue') + + If you want to specify the line color for the specific series, you + should index each item and apply the property manually. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p[0].line_color = 'red' + >>> p.show() + + See Also + ======== + + Plot, Parametric2DLineSeries + """ + args = _plot_sympify(args) + plot_expr = _check_arguments(args, 2, 1, **kwargs) + + labels = kwargs.pop("label", []) + rendering_kw = kwargs.pop("rendering_kw", None) + series = _create_series(Parametric2DLineSeries, plot_expr, **kwargs) + _set_labels(series, labels, rendering_kw) + + plots = plot_factory(*series, **kwargs) + if show: + plots.show() + return plots + + +def plot3d_parametric_line(*args, show=True, **kwargs): + """ + Plots a 3D parametric line plot. + + Usage + ===== + + Single plot: + + ``plot3d_parametric_line(expr_x, expr_y, expr_z, range, **kwargs)`` + + If the range is not specified, then a default range of (-10, 10) is used. + + Multiple plots. + + ``plot3d_parametric_line((expr_x, expr_y, expr_z, range), ..., **kwargs)`` + + Ranges have to be specified for every expression. + + Default range may change in the future if a more advanced default range + detection algorithm is implemented. + + Arguments + ========= + + expr_x : Expression representing the function along x. + + expr_y : Expression representing the function along y. + + expr_z : Expression representing the function along z. + + range : (:class:`~.Symbol`, float, float) + A 3-tuple denoting the range of the parameter variable, e.g., (u, 0, 5). + + Keyword Arguments + ================= + + Arguments for ``Parametric3DLineSeries`` class. + + n : int + The range is uniformly sampled at ``n`` number of points. + This keyword argument replaces ``nb_of_points``, which should be + considered deprecated. + + Aesthetics: + + line_color : string, or float, or function, optional + Specifies the color for the plot. + See ``Plot`` to see how to set color for the plots. + Note that by setting ``line_color``, it would be applied simultaneously + to all the series. + + label : str + The label to the plot. It will be used when called with ``legend=True`` + to denote the function with the given label in the plot. + + If there are multiple plots, then the same series arguments are applied to + all the plots. If you want to set these options separately, you can index + the returned ``Plot`` object and set it. + + Arguments for ``Plot`` class. + + title : str + Title of the plot. + + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of + the overall figure. The default value is set to ``None``, meaning + the size will be set by the default backend. + + Examples + ======== + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import symbols, cos, sin + >>> from sympy.plotting import plot3d_parametric_line + >>> u = symbols('u') + + Single plot. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot3d_parametric_line(cos(u), sin(u), u, (u, -5, 5)) + Plot object containing: + [0]: 3D parametric cartesian line: (cos(u), sin(u), u) for u over (-5.0, 5.0) + + + Multiple plots. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot3d_parametric_line((cos(u), sin(u), u, (u, -5, 5)), + ... (sin(u), u**2, u, (u, -5, 5))) + Plot object containing: + [0]: 3D parametric cartesian line: (cos(u), sin(u), u) for u over (-5.0, 5.0) + [1]: 3D parametric cartesian line: (sin(u), u**2, u) for u over (-5.0, 5.0) + + + See Also + ======== + + Plot, Parametric3DLineSeries + + """ + args = _plot_sympify(args) + plot_expr = _check_arguments(args, 3, 1, **kwargs) + kwargs.setdefault("xlabel", "x") + kwargs.setdefault("ylabel", "y") + kwargs.setdefault("zlabel", "z") + + labels = kwargs.pop("label", []) + rendering_kw = kwargs.pop("rendering_kw", None) + series = _create_series(Parametric3DLineSeries, plot_expr, **kwargs) + _set_labels(series, labels, rendering_kw) + + plots = plot_factory(*series, **kwargs) + if show: + plots.show() + return plots + + +def _plot3d_plot_contour_helper(Series, *args, **kwargs): + """plot3d and plot_contour are structurally identical. Let's reduce + code repetition. + """ + # NOTE: if this import would be at the top-module level, it would trigger + # SymPy's optional-dependencies tests to fail. + from sympy.vector import BaseScalar + + args = _plot_sympify(args) + plot_expr = _check_arguments(args, 1, 2, **kwargs) + + free_x = set() + free_y = set() + _types = (Symbol, BaseScalar, Indexed, AppliedUndef) + for p in plot_expr: + free_x |= {p[1][0]} if isinstance(p[1][0], _types) else {Symbol(p[1][0])} + free_y |= {p[2][0]} if isinstance(p[2][0], _types) else {Symbol(p[2][0])} + x = free_x.pop() if free_x else Symbol("x") + y = free_y.pop() if free_y else Symbol("y") + kwargs.setdefault("xlabel", x) + kwargs.setdefault("ylabel", y) + kwargs.setdefault("zlabel", Function('f')(x, y)) + + # if a polar discretization is requested and automatic labelling has ben + # applied, hide the labels on the x-y axis. + if kwargs.get("is_polar", False): + if callable(kwargs["xlabel"]): + kwargs["xlabel"] = "" + if callable(kwargs["ylabel"]): + kwargs["ylabel"] = "" + + labels = kwargs.pop("label", []) + rendering_kw = kwargs.pop("rendering_kw", None) + series = _create_series(Series, plot_expr, **kwargs) + _set_labels(series, labels, rendering_kw) + plots = plot_factory(*series, **kwargs) + if kwargs.get("show", True): + plots.show() + return plots + + +def plot3d(*args, show=True, **kwargs): + """ + Plots a 3D surface plot. + + Usage + ===== + + Single plot + + ``plot3d(expr, range_x, range_y, **kwargs)`` + + If the ranges are not specified, then a default range of (-10, 10) is used. + + Multiple plot with the same range. + + ``plot3d(expr1, expr2, range_x, range_y, **kwargs)`` + + If the ranges are not specified, then a default range of (-10, 10) is used. + + Multiple plots with different ranges. + + ``plot3d((expr1, range_x, range_y), (expr2, range_x, range_y), ..., **kwargs)`` + + Ranges have to be specified for every expression. + + Default range may change in the future if a more advanced default range + detection algorithm is implemented. + + Arguments + ========= + + expr : Expression representing the function along x. + + range_x : (:class:`~.Symbol`, float, float) + A 3-tuple denoting the range of the x variable, e.g. (x, 0, 5). + + range_y : (:class:`~.Symbol`, float, float) + A 3-tuple denoting the range of the y variable, e.g. (y, 0, 5). + + Keyword Arguments + ================= + + Arguments for ``SurfaceOver2DRangeSeries`` class: + + n1 : int + The x range is sampled uniformly at ``n1`` of points. + This keyword argument replaces ``nb_of_points_x``, which should be + considered deprecated. + + n2 : int + The y range is sampled uniformly at ``n2`` of points. + This keyword argument replaces ``nb_of_points_y``, which should be + considered deprecated. + + Aesthetics: + + surface_color : Function which returns a float + Specifies the color for the surface of the plot. + See :class:`~.Plot` for more details. + + If there are multiple plots, then the same series arguments are applied to + all the plots. If you want to set these options separately, you can index + the returned ``Plot`` object and set it. + + Arguments for ``Plot`` class: + + title : str + Title of the plot. + + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of the + overall figure. The default value is set to ``None``, meaning the size will + be set by the default backend. + + Examples + ======== + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.plotting import plot3d + >>> x, y = symbols('x y') + + Single plot + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot3d(x*y, (x, -5, 5), (y, -5, 5)) + Plot object containing: + [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0) + + + Multiple plots with same range + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot3d(x*y, -x*y, (x, -5, 5), (y, -5, 5)) + Plot object containing: + [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0) + [1]: cartesian surface: -x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0) + + + Multiple plots with different ranges. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot3d((x**2 + y**2, (x, -5, 5), (y, -5, 5)), + ... (x*y, (x, -3, 3), (y, -3, 3))) + Plot object containing: + [0]: cartesian surface: x**2 + y**2 for x over (-5.0, 5.0) and y over (-5.0, 5.0) + [1]: cartesian surface: x*y for x over (-3.0, 3.0) and y over (-3.0, 3.0) + + + See Also + ======== + + Plot, SurfaceOver2DRangeSeries + + """ + kwargs.setdefault("show", show) + return _plot3d_plot_contour_helper( + SurfaceOver2DRangeSeries, *args, **kwargs) + + +def plot3d_parametric_surface(*args, show=True, **kwargs): + """ + Plots a 3D parametric surface plot. + + Explanation + =========== + + Single plot. + + ``plot3d_parametric_surface(expr_x, expr_y, expr_z, range_u, range_v, **kwargs)`` + + If the ranges is not specified, then a default range of (-10, 10) is used. + + Multiple plots. + + ``plot3d_parametric_surface((expr_x, expr_y, expr_z, range_u, range_v), ..., **kwargs)`` + + Ranges have to be specified for every expression. + + Default range may change in the future if a more advanced default range + detection algorithm is implemented. + + Arguments + ========= + + expr_x : Expression representing the function along ``x``. + + expr_y : Expression representing the function along ``y``. + + expr_z : Expression representing the function along ``z``. + + range_u : (:class:`~.Symbol`, float, float) + A 3-tuple denoting the range of the u variable, e.g. (u, 0, 5). + + range_v : (:class:`~.Symbol`, float, float) + A 3-tuple denoting the range of the v variable, e.g. (v, 0, 5). + + Keyword Arguments + ================= + + Arguments for ``ParametricSurfaceSeries`` class: + + n1 : int + The ``u`` range is sampled uniformly at ``n1`` of points. + This keyword argument replaces ``nb_of_points_u``, which should be + considered deprecated. + + n2 : int + The ``v`` range is sampled uniformly at ``n2`` of points. + This keyword argument replaces ``nb_of_points_v``, which should be + considered deprecated. + + Aesthetics: + + surface_color : Function which returns a float + Specifies the color for the surface of the plot. See + :class:`~Plot` for more details. + + If there are multiple plots, then the same series arguments are applied for + all the plots. If you want to set these options separately, you can index + the returned ``Plot`` object and set it. + + + Arguments for ``Plot`` class: + + title : str + Title of the plot. + + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of the + overall figure. The default value is set to ``None``, meaning the size will + be set by the default backend. + + Examples + ======== + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import symbols, cos, sin + >>> from sympy.plotting import plot3d_parametric_surface + >>> u, v = symbols('u v') + + Single plot. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> plot3d_parametric_surface(cos(u + v), sin(u - v), u - v, + ... (u, -5, 5), (v, -5, 5)) + Plot object containing: + [0]: parametric cartesian surface: (cos(u + v), sin(u - v), u - v) for u over (-5.0, 5.0) and v over (-5.0, 5.0) + + + See Also + ======== + + Plot, ParametricSurfaceSeries + + """ + + args = _plot_sympify(args) + plot_expr = _check_arguments(args, 3, 2, **kwargs) + kwargs.setdefault("xlabel", "x") + kwargs.setdefault("ylabel", "y") + kwargs.setdefault("zlabel", "z") + + labels = kwargs.pop("label", []) + rendering_kw = kwargs.pop("rendering_kw", None) + series = _create_series(ParametricSurfaceSeries, plot_expr, **kwargs) + _set_labels(series, labels, rendering_kw) + + plots = plot_factory(*series, **kwargs) + if show: + plots.show() + return plots + +def plot_contour(*args, show=True, **kwargs): + """ + Draws contour plot of a function + + Usage + ===== + + Single plot + + ``plot_contour(expr, range_x, range_y, **kwargs)`` + + If the ranges are not specified, then a default range of (-10, 10) is used. + + Multiple plot with the same range. + + ``plot_contour(expr1, expr2, range_x, range_y, **kwargs)`` + + If the ranges are not specified, then a default range of (-10, 10) is used. + + Multiple plots with different ranges. + + ``plot_contour((expr1, range_x, range_y), (expr2, range_x, range_y), ..., **kwargs)`` + + Ranges have to be specified for every expression. + + Default range may change in the future if a more advanced default range + detection algorithm is implemented. + + Arguments + ========= + + expr : Expression representing the function along x. + + range_x : (:class:`Symbol`, float, float) + A 3-tuple denoting the range of the x variable, e.g. (x, 0, 5). + + range_y : (:class:`Symbol`, float, float) + A 3-tuple denoting the range of the y variable, e.g. (y, 0, 5). + + Keyword Arguments + ================= + + Arguments for ``ContourSeries`` class: + + n1 : int + The x range is sampled uniformly at ``n1`` of points. + This keyword argument replaces ``nb_of_points_x``, which should be + considered deprecated. + + n2 : int + The y range is sampled uniformly at ``n2`` of points. + This keyword argument replaces ``nb_of_points_y``, which should be + considered deprecated. + + Aesthetics: + + surface_color : Function which returns a float + Specifies the color for the surface of the plot. See + :class:`sympy.plotting.Plot` for more details. + + If there are multiple plots, then the same series arguments are applied to + all the plots. If you want to set these options separately, you can index + the returned ``Plot`` object and set it. + + Arguments for ``Plot`` class: + + title : str + Title of the plot. + + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of + the overall figure. The default value is set to ``None``, meaning + the size will be set by the default backend. + + See Also + ======== + + Plot, ContourSeries + + """ + kwargs.setdefault("show", show) + return _plot3d_plot_contour_helper(ContourSeries, *args, **kwargs) + + +def check_arguments(args, expr_len, nb_of_free_symbols): + """ + Checks the arguments and converts into tuples of the + form (exprs, ranges). + + Examples + ======== + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import cos, sin, symbols + >>> from sympy.plotting.plot import check_arguments + >>> x = symbols('x') + >>> check_arguments([cos(x), sin(x)], 2, 1) + [(cos(x), sin(x), (x, -10, 10))] + + >>> check_arguments([x, x**2], 1, 1) + [(x, (x, -10, 10)), (x**2, (x, -10, 10))] + """ + if not args: + return [] + if expr_len > 1 and isinstance(args[0], Expr): + # Multiple expressions same range. + # The arguments are tuples when the expression length is + # greater than 1. + if len(args) < expr_len: + raise ValueError("len(args) should not be less than expr_len") + for i in range(len(args)): + if isinstance(args[i], Tuple): + break + else: + i = len(args) + 1 + + exprs = Tuple(*args[:i]) + free_symbols = list(set().union(*[e.free_symbols for e in exprs])) + if len(args) == expr_len + nb_of_free_symbols: + #Ranges given + plots = [exprs + Tuple(*args[expr_len:])] + else: + default_range = Tuple(-10, 10) + ranges = [] + for symbol in free_symbols: + ranges.append(Tuple(symbol) + default_range) + + for i in range(len(free_symbols) - nb_of_free_symbols): + ranges.append(Tuple(Dummy()) + default_range) + plots = [exprs + Tuple(*ranges)] + return plots + + if isinstance(args[0], Expr) or (isinstance(args[0], Tuple) and + len(args[0]) == expr_len and + expr_len != 3): + # Cannot handle expressions with number of expression = 3. It is + # not possible to differentiate between expressions and ranges. + #Series of plots with same range + for i in range(len(args)): + if isinstance(args[i], Tuple) and len(args[i]) != expr_len: + break + if not isinstance(args[i], Tuple): + args[i] = Tuple(args[i]) + else: + i = len(args) + 1 + + exprs = args[:i] + assert all(isinstance(e, Expr) for expr in exprs for e in expr) + free_symbols = list(set().union(*[e.free_symbols for expr in exprs + for e in expr])) + + if len(free_symbols) > nb_of_free_symbols: + raise ValueError("The number of free_symbols in the expression " + "is greater than %d" % nb_of_free_symbols) + if len(args) == i + nb_of_free_symbols and isinstance(args[i], Tuple): + ranges = Tuple(*list(args[ + i:i + nb_of_free_symbols])) + plots = [expr + ranges for expr in exprs] + return plots + else: + # Use default ranges. + default_range = Tuple(-10, 10) + ranges = [] + for symbol in free_symbols: + ranges.append(Tuple(symbol) + default_range) + + for i in range(nb_of_free_symbols - len(free_symbols)): + ranges.append(Tuple(Dummy()) + default_range) + ranges = Tuple(*ranges) + plots = [expr + ranges for expr in exprs] + return plots + + elif isinstance(args[0], Tuple) and len(args[0]) == expr_len + nb_of_free_symbols: + # Multiple plots with different ranges. + for arg in args: + for i in range(expr_len): + if not isinstance(arg[i], Expr): + raise ValueError("Expected an expression, given %s" % + str(arg[i])) + for i in range(nb_of_free_symbols): + if not len(arg[i + expr_len]) == 3: + raise ValueError("The ranges should be a tuple of " + "length 3, got %s" % str(arg[i + expr_len])) + return args diff --git a/phivenv/Lib/site-packages/sympy/plotting/plot_implicit.py b/phivenv/Lib/site-packages/sympy/plotting/plot_implicit.py new file mode 100644 index 0000000000000000000000000000000000000000..5dceaf0699a2e6d3ff0bc30f415721918724cad5 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/plot_implicit.py @@ -0,0 +1,233 @@ +"""Implicit plotting module for SymPy. + +Explanation +=========== + +The module implements a data series called ImplicitSeries which is used by +``Plot`` class to plot implicit plots for different backends. The module, +by default, implements plotting using interval arithmetic. It switches to a +fall back algorithm if the expression cannot be plotted using interval arithmetic. +It is also possible to specify to use the fall back algorithm for all plots. + +Boolean combinations of expressions cannot be plotted by the fall back +algorithm. + +See Also +======== + +sympy.plotting.plot + +References +========== + +.. [1] Jeffrey Allen Tupper. Reliable Two-Dimensional Graphing Methods for +Mathematical Formulae with Two Free Variables. + +.. [2] Jeffrey Allen Tupper. Graphing Equations with Generalized Interval +Arithmetic. Master's thesis. University of Toronto, 1996 + +""" + + +from sympy.core.containers import Tuple +from sympy.core.symbol import (Dummy, Symbol) +from sympy.polys.polyutils import _sort_gens +from sympy.plotting.series import ImplicitSeries, _set_discretization_points +from sympy.plotting.plot import plot_factory +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.iterables import flatten + + +__doctest_requires__ = {'plot_implicit': ['matplotlib']} + + +@doctest_depends_on(modules=('matplotlib',)) +def plot_implicit(expr, x_var=None, y_var=None, adaptive=True, depth=0, + n=300, line_color="blue", show=True, **kwargs): + """A plot function to plot implicit equations / inequalities. + + Arguments + ========= + + - expr : The equation / inequality that is to be plotted. + - x_var (optional) : symbol to plot on x-axis or tuple giving symbol + and range as ``(symbol, xmin, xmax)`` + - y_var (optional) : symbol to plot on y-axis or tuple giving symbol + and range as ``(symbol, ymin, ymax)`` + + If neither ``x_var`` nor ``y_var`` are given then the free symbols in the + expression will be assigned in the order they are sorted. + + The following keyword arguments can also be used: + + - ``adaptive`` Boolean. The default value is set to True. It has to be + set to False if you want to use a mesh grid. + + - ``depth`` integer. The depth of recursion for adaptive mesh grid. + Default value is 0. Takes value in the range (0, 4). + + - ``n`` integer. The number of points if adaptive mesh grid is not + used. Default value is 300. This keyword argument replaces ``points``, + which should be considered deprecated. + + - ``show`` Boolean. Default value is True. If set to False, the plot will + not be shown. See ``Plot`` for further information. + + - ``title`` string. The title for the plot. + + - ``xlabel`` string. The label for the x-axis + + - ``ylabel`` string. The label for the y-axis + + Aesthetics options: + + - ``line_color``: float or string. Specifies the color for the plot. + See ``Plot`` to see how to set color for the plots. + Default value is "Blue" + + plot_implicit, by default, uses interval arithmetic to plot functions. If + the expression cannot be plotted using interval arithmetic, it defaults to + a generating a contour using a mesh grid of fixed number of points. By + setting adaptive to False, you can force plot_implicit to use the mesh + grid. The mesh grid method can be effective when adaptive plotting using + interval arithmetic, fails to plot with small line width. + + Examples + ======== + + Plot expressions: + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import plot_implicit, symbols, Eq, And + >>> x, y = symbols('x y') + + Without any ranges for the symbols in the expression: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p1 = plot_implicit(Eq(x**2 + y**2, 5)) + + With the range for the symbols: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p2 = plot_implicit( + ... Eq(x**2 + y**2, 3), (x, -3, 3), (y, -3, 3)) + + With depth of recursion as argument: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p3 = plot_implicit( + ... Eq(x**2 + y**2, 5), (x, -4, 4), (y, -4, 4), depth = 2) + + Using mesh grid and not using adaptive meshing: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p4 = plot_implicit( + ... Eq(x**2 + y**2, 5), (x, -5, 5), (y, -2, 2), + ... adaptive=False) + + Using mesh grid without using adaptive meshing with number of points + specified: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p5 = plot_implicit( + ... Eq(x**2 + y**2, 5), (x, -5, 5), (y, -2, 2), + ... adaptive=False, n=400) + + Plotting regions: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p6 = plot_implicit(y > x**2) + + Plotting Using boolean conjunctions: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p7 = plot_implicit(And(y > x, y > -x)) + + When plotting an expression with a single variable (y - 1, for example), + specify the x or the y variable explicitly: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> p8 = plot_implicit(y - 1, y_var=y) + >>> p9 = plot_implicit(x - 1, x_var=x) + """ + + xyvar = [i for i in (x_var, y_var) if i is not None] + free_symbols = expr.free_symbols + range_symbols = Tuple(*flatten(xyvar)).free_symbols + undeclared = free_symbols - range_symbols + if len(free_symbols & range_symbols) > 2: + raise NotImplementedError("Implicit plotting is not implemented for " + "more than 2 variables") + + #Create default ranges if the range is not provided. + default_range = Tuple(-5, 5) + def _range_tuple(s): + if isinstance(s, Symbol): + return Tuple(s) + default_range + if len(s) == 3: + return Tuple(*s) + raise ValueError('symbol or `(symbol, min, max)` expected but got %s' % s) + + if len(xyvar) == 0: + xyvar = list(_sort_gens(free_symbols)) + var_start_end_x = _range_tuple(xyvar[0]) + x = var_start_end_x[0] + if len(xyvar) != 2: + if x in undeclared or not undeclared: + xyvar.append(Dummy('f(%s)' % x.name)) + else: + xyvar.append(undeclared.pop()) + var_start_end_y = _range_tuple(xyvar[1]) + + kwargs = _set_discretization_points(kwargs, ImplicitSeries) + series_argument = ImplicitSeries( + expr, var_start_end_x, var_start_end_y, + adaptive=adaptive, depth=depth, + n=n, line_color=line_color) + + #set the x and y limits + kwargs['xlim'] = tuple(float(x) for x in var_start_end_x[1:]) + kwargs['ylim'] = tuple(float(y) for y in var_start_end_y[1:]) + # set the x and y labels + kwargs.setdefault('xlabel', var_start_end_x[0]) + kwargs.setdefault('ylabel', var_start_end_y[0]) + p = plot_factory(series_argument, **kwargs) + if show: + p.show() + return p diff --git a/phivenv/Lib/site-packages/sympy/plotting/plotgrid.py b/phivenv/Lib/site-packages/sympy/plotting/plotgrid.py new file mode 100644 index 0000000000000000000000000000000000000000..8ff811c591e762275df1a0e3a221d05920d1804e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/plotgrid.py @@ -0,0 +1,188 @@ + +from sympy.external import import_module +import sympy.plotting.backends.base_backend as base_backend + + +# N.B. +# When changing the minimum module version for matplotlib, please change +# the same in the `SymPyDocTestFinder`` in `sympy/testing/runtests.py` + + +__doctest_requires__ = { + ("PlotGrid",): ["matplotlib"], +} + + +class PlotGrid: + """This class helps to plot subplots from already created SymPy plots + in a single figure. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.plotting import plot, plot3d, PlotGrid + >>> x, y = symbols('x, y') + >>> p1 = plot(x, x**2, x**3, (x, -5, 5)) + >>> p2 = plot((x**2, (x, -6, 6)), (x, (x, -5, 5))) + >>> p3 = plot(x**3, (x, -5, 5)) + >>> p4 = plot3d(x*y, (x, -5, 5), (y, -5, 5)) + + Plotting vertically in a single line: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> PlotGrid(2, 1, p1, p2) + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: x for x over (-5.0, 5.0) + [1]: cartesian line: x**2 for x over (-5.0, 5.0) + [2]: cartesian line: x**3 for x over (-5.0, 5.0) + Plot[1]:Plot object containing: + [0]: cartesian line: x**2 for x over (-6.0, 6.0) + [1]: cartesian line: x for x over (-5.0, 5.0) + + Plotting horizontally in a single line: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> PlotGrid(1, 3, p2, p3, p4) + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: x**2 for x over (-6.0, 6.0) + [1]: cartesian line: x for x over (-5.0, 5.0) + Plot[1]:Plot object containing: + [0]: cartesian line: x**3 for x over (-5.0, 5.0) + Plot[2]:Plot object containing: + [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0) + + Plotting in a grid form: + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> PlotGrid(2, 2, p1, p2, p3, p4) + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: x for x over (-5.0, 5.0) + [1]: cartesian line: x**2 for x over (-5.0, 5.0) + [2]: cartesian line: x**3 for x over (-5.0, 5.0) + Plot[1]:Plot object containing: + [0]: cartesian line: x**2 for x over (-6.0, 6.0) + [1]: cartesian line: x for x over (-5.0, 5.0) + Plot[2]:Plot object containing: + [0]: cartesian line: x**3 for x over (-5.0, 5.0) + Plot[3]:Plot object containing: + [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0) + + """ + def __init__(self, nrows, ncolumns, *args, show=True, size=None, **kwargs): + """ + Parameters + ========== + + nrows : + The number of rows that should be in the grid of the + required subplot. + ncolumns : + The number of columns that should be in the grid + of the required subplot. + + nrows and ncolumns together define the required grid. + + Arguments + ========= + + A list of predefined plot objects entered in a row-wise sequence + i.e. plot objects which are to be in the top row of the required + grid are written first, then the second row objects and so on + + Keyword arguments + ================= + + show : Boolean + The default value is set to ``True``. Set show to ``False`` and + the function will not display the subplot. The returned instance + of the ``PlotGrid`` class can then be used to save or display the + plot by calling the ``save()`` and ``show()`` methods + respectively. + size : (float, float), optional + A tuple in the form (width, height) in inches to specify the size of + the overall figure. The default value is set to ``None``, meaning + the size will be set by the default backend. + """ + self.matplotlib = import_module('matplotlib', + import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']}, + min_module_version='1.1.0', catch=(RuntimeError,)) + self.nrows = nrows + self.ncolumns = ncolumns + self._series = [] + self._fig = None + self.args = args + for arg in args: + self._series.append(arg._series) + self.size = size + if show and self.matplotlib: + self.show() + + def _create_figure(self): + gs = self.matplotlib.gridspec.GridSpec(self.nrows, self.ncolumns) + mapping = {} + c = 0 + for i in range(self.nrows): + for j in range(self.ncolumns): + if c < len(self.args): + mapping[gs[i, j]] = self.args[c] + c += 1 + + kw = {} if not self.size else {"figsize": self.size} + self._fig = self.matplotlib.pyplot.figure(**kw) + for spec, p in mapping.items(): + kw = ({"projection": "3d"} if (len(p._series) > 0 and + p._series[0].is_3D) else {}) + cur_ax = self._fig.add_subplot(spec, **kw) + p._plotgrid_fig = self._fig + p._plotgrid_ax = cur_ax + p.process_series() + + @property + def fig(self): + if not self._fig: + self._create_figure() + return self._fig + + @property + def _backend(self): + return self + + def close(self): + self.matplotlib.pyplot.close(self.fig) + + def show(self): + if base_backend._show: + self.fig.tight_layout() + self.matplotlib.pyplot.show() + else: + self.close() + + def save(self, path): + self.fig.savefig(path) + + def __str__(self): + plot_strs = [('Plot[%d]:' % i) + str(plot) + for i, plot in enumerate(self.args)] + + return 'PlotGrid object containing:\n' + '\n'.join(plot_strs) diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__init__.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd86a505d8c4b8026bd91cde27d441e00223a8bc --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__init__.py @@ -0,0 +1,138 @@ +"""Plotting module that can plot 2D and 3D functions +""" + +from sympy.utilities.decorator import doctest_depends_on + +@doctest_depends_on(modules=('pyglet',)) +def PygletPlot(*args, **kwargs): + """ + + Plot Examples + ============= + + See examples/advanced/pyglet_plotting.py for many more examples. + + >>> from sympy.plotting.pygletplot import PygletPlot as Plot + >>> from sympy.abc import x, y, z + + >>> Plot(x*y**3-y*x**3) + [0]: -x**3*y + x*y**3, 'mode=cartesian' + + >>> p = Plot() + >>> p[1] = x*y + >>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4) + + >>> p = Plot() + >>> p[1] = x**2+y**2 + >>> p[2] = -x**2-y**2 + + + Variable Intervals + ================== + + The basic format is [var, min, max, steps], but the + syntax is flexible and arguments left out are taken + from the defaults for the current coordinate mode: + + >>> Plot(x**2) # implies [x,-5,5,100] + [0]: x**2, 'mode=cartesian' + + >>> Plot(x**2, [], []) # [x,-1,1,40], [y,-1,1,40] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2-y**2, [100], [100]) # [x,-1,1,100], [y,-1,1,100] + [0]: x**2 - y**2, 'mode=cartesian' + >>> Plot(x**2, [x,-13,13,100]) + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [-13,13]) # [x,-13,13,100] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [x,-13,13]) # [x,-13,13,100] + [0]: x**2, 'mode=cartesian' + >>> Plot(1*x, [], [x], mode='cylindrical') + ... # [unbound_theta,0,2*Pi,40], [x,-1,1,20] + [0]: x, 'mode=cartesian' + + + Coordinate Modes + ================ + + Plot supports several curvilinear coordinate modes, and + they independent for each plotted function. You can specify + a coordinate mode explicitly with the 'mode' named argument, + but it can be automatically determined for Cartesian or + parametric plots, and therefore must only be specified for + polar, cylindrical, and spherical modes. + + Specifically, Plot(function arguments) and Plot[n] = + (function arguments) will interpret your arguments as a + Cartesian plot if you provide one function and a parametric + plot if you provide two or three functions. Similarly, the + arguments will be interpreted as a curve if one variable is + used, and a surface if two are used. + + Supported mode names by number of variables: + + 1: parametric, cartesian, polar + 2: parametric, cartesian, cylindrical = polar, spherical + + >>> Plot(1, mode='spherical') + + + Calculator-like Interface + ========================= + + >>> p = Plot(visible=False) + >>> f = x**2 + >>> p[1] = f + >>> p[2] = f.diff(x) + >>> p[3] = f.diff(x).diff(x) + >>> p + [1]: x**2, 'mode=cartesian' + [2]: 2*x, 'mode=cartesian' + [3]: 2, 'mode=cartesian' + >>> p.show() + >>> p.clear() + >>> p + + >>> p[1] = x**2+y**2 + >>> p[1].style = 'solid' + >>> p[2] = -x**2-y**2 + >>> p[2].style = 'wireframe' + >>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4) + >>> p[1].style = 'both' + >>> p[2].style = 'both' + >>> p.close() + + + Plot Window Keyboard Controls + ============================= + + Screen Rotation: + X,Y axis Arrow Keys, A,S,D,W, Numpad 4,6,8,2 + Z axis Q,E, Numpad 7,9 + + Model Rotation: + Z axis Z,C, Numpad 1,3 + + Zoom: R,F, PgUp,PgDn, Numpad +,- + + Reset Camera: X, Numpad 5 + + Camera Presets: + XY F1 + XZ F2 + YZ F3 + Perspective F4 + + Sensitivity Modifier: SHIFT + + Axes Toggle: + Visible F5 + Colors F6 + + Close Window: ESCAPE + + ============================= + """ + + from sympy.plotting.pygletplot.plot import PygletPlot + return PygletPlot(*args, **kwargs) diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8c0b7e7b4d4403c4a00c6db7240c5469fa10dea Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/color_scheme.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/color_scheme.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4f84cb6443667aedce13671bd247de6363a3d9a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/color_scheme.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/managed_window.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/managed_window.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c25b3187ab7da4b0b417290ffbc39eaf2f98897 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/managed_window.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f69444097437bf25fda7c854f46956e34d3831ee Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_axes.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_axes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a835ad981a9956047931ea30b35c23467235aa8 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_axes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_camera.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_camera.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fa74aaca53b9e7b4fbc2b74837021771e397ff3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_camera.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_controller.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_controller.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfbdeec50401b5ab6156d9162bf170b996055fb7 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_controller.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_curve.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_curve.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b3e9d18d58abd46ade6a3ab61a2c9aaa8133f78 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_curve.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_interval.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_interval.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d2510d14a70b71834fadb75dab8caf4ad6378da Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_interval.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_mode.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_mode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c71be9edba106375873863cd318f8edf490550fa Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_mode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_mode_base.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_mode_base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ea1cd023190f1ad9ad6e1e24f40a80bb3bdbb70 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_mode_base.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_modes.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_modes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf109e3d73e7df20e492d4d9cf1e1a80f5997956 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_modes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_object.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_object.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b5cadd32c73b5ce383af24d25f6b814bd808d0a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_object.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_rotation.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_rotation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5657757c0a792b859acb5dce1e748fc8865eda61 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_rotation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_surface.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_surface.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e276db39740cb7f1399e9f8528240ce045a047e3 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_surface.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_window.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_window.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0e17826e40d241191b9c693a382eec15aa5c087e Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/plot_window.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/util.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..516170805aba604c8c25eb28e2f1aa014e81ad00 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/__pycache__/util.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/color_scheme.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/color_scheme.py new file mode 100644 index 0000000000000000000000000000000000000000..613e777a7f45f54349c47d272aa6d1c157bcd117 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/color_scheme.py @@ -0,0 +1,336 @@ +from sympy.core.basic import Basic +from sympy.core.symbol import (Symbol, symbols) +from sympy.utilities.lambdify import lambdify +from .util import interpolate, rinterpolate, create_bounds, update_bounds +from sympy.utilities.iterables import sift + + +class ColorGradient: + colors = [0.4, 0.4, 0.4], [0.9, 0.9, 0.9] + intervals = 0.0, 1.0 + + def __init__(self, *args): + if len(args) == 2: + self.colors = list(args) + self.intervals = [0.0, 1.0] + elif len(args) > 0: + if len(args) % 2 != 0: + raise ValueError("len(args) should be even") + self.colors = [args[i] for i in range(1, len(args), 2)] + self.intervals = [args[i] for i in range(0, len(args), 2)] + assert len(self.colors) == len(self.intervals) + + def copy(self): + c = ColorGradient() + c.colors = [e[::] for e in self.colors] + c.intervals = self.intervals[::] + return c + + def _find_interval(self, v): + m = len(self.intervals) + i = 0 + while i < m - 1 and self.intervals[i] <= v: + i += 1 + return i + + def _interpolate_axis(self, axis, v): + i = self._find_interval(v) + v = rinterpolate(self.intervals[i - 1], self.intervals[i], v) + return interpolate(self.colors[i - 1][axis], self.colors[i][axis], v) + + def __call__(self, r, g, b): + c = self._interpolate_axis + return c(0, r), c(1, g), c(2, b) + +default_color_schemes = {} # defined at the bottom of this file + + +class ColorScheme: + + def __init__(self, *args, **kwargs): + self.args = args + self.f, self.gradient = None, ColorGradient() + + if len(args) == 1 and not isinstance(args[0], Basic) and callable(args[0]): + self.f = args[0] + elif len(args) == 1 and isinstance(args[0], str): + if args[0] in default_color_schemes: + cs = default_color_schemes[args[0]] + self.f, self.gradient = cs.f, cs.gradient.copy() + else: + self.f = lambdify('x,y,z,u,v', args[0]) + else: + self.f, self.gradient = self._interpret_args(args) + self._test_color_function() + if not isinstance(self.gradient, ColorGradient): + raise ValueError("Color gradient not properly initialized. " + "(Not a ColorGradient instance.)") + + def _interpret_args(self, args): + f, gradient = None, self.gradient + atoms, lists = self._sort_args(args) + s = self._pop_symbol_list(lists) + s = self._fill_in_vars(s) + + # prepare the error message for lambdification failure + f_str = ', '.join(str(fa) for fa in atoms) + s_str = (str(sa) for sa in s) + s_str = ', '.join(sa for sa in s_str if sa.find('unbound') < 0) + f_error = ValueError("Could not interpret arguments " + "%s as functions of %s." % (f_str, s_str)) + + # try to lambdify args + if len(atoms) == 1: + fv = atoms[0] + try: + f = lambdify(s, [fv, fv, fv]) + except TypeError: + raise f_error + + elif len(atoms) == 3: + fr, fg, fb = atoms + try: + f = lambdify(s, [fr, fg, fb]) + except TypeError: + raise f_error + + else: + raise ValueError("A ColorScheme must provide 1 or 3 " + "functions in x, y, z, u, and/or v.") + + # try to intrepret any given color information + if len(lists) == 0: + gargs = [] + + elif len(lists) == 1: + gargs = lists[0] + + elif len(lists) == 2: + try: + (r1, g1, b1), (r2, g2, b2) = lists + except TypeError: + raise ValueError("If two color arguments are given, " + "they must be given in the format " + "(r1, g1, b1), (r2, g2, b2).") + gargs = lists + + elif len(lists) == 3: + try: + (r1, r2), (g1, g2), (b1, b2) = lists + except Exception: + raise ValueError("If three color arguments are given, " + "they must be given in the format " + "(r1, r2), (g1, g2), (b1, b2). To create " + "a multi-step gradient, use the syntax " + "[0, colorStart, step1, color1, ..., 1, " + "colorEnd].") + gargs = [[r1, g1, b1], [r2, g2, b2]] + + else: + raise ValueError("Don't know what to do with collection " + "arguments %s." % (', '.join(str(l) for l in lists))) + + if gargs: + try: + gradient = ColorGradient(*gargs) + except Exception as ex: + raise ValueError(("Could not initialize a gradient " + "with arguments %s. Inner " + "exception: %s") % (gargs, str(ex))) + + return f, gradient + + def _pop_symbol_list(self, lists): + symbol_lists = [] + for l in lists: + mark = True + for s in l: + if s is not None and not isinstance(s, Symbol): + mark = False + break + if mark: + lists.remove(l) + symbol_lists.append(l) + if len(symbol_lists) == 1: + return symbol_lists[0] + elif len(symbol_lists) == 0: + return [] + else: + raise ValueError("Only one list of Symbols " + "can be given for a color scheme.") + + def _fill_in_vars(self, args): + defaults = symbols('x,y,z,u,v') + v_error = ValueError("Could not find what to plot.") + if len(args) == 0: + return defaults + if not isinstance(args, (tuple, list)): + raise v_error + if len(args) == 0: + return defaults + for s in args: + if s is not None and not isinstance(s, Symbol): + raise v_error + # when vars are given explicitly, any vars + # not given are marked 'unbound' as to not + # be accidentally used in an expression + vars = [Symbol('unbound%i' % (i)) for i in range(1, 6)] + # interpret as t + if len(args) == 1: + vars[3] = args[0] + # interpret as u,v + elif len(args) == 2: + if args[0] is not None: + vars[3] = args[0] + if args[1] is not None: + vars[4] = args[1] + # interpret as x,y,z + elif len(args) >= 3: + # allow some of x,y,z to be + # left unbound if not given + if args[0] is not None: + vars[0] = args[0] + if args[1] is not None: + vars[1] = args[1] + if args[2] is not None: + vars[2] = args[2] + # interpret the rest as t + if len(args) >= 4: + vars[3] = args[3] + # ...or u,v + if len(args) >= 5: + vars[4] = args[4] + return vars + + def _sort_args(self, args): + lists, atoms = sift(args, + lambda a: isinstance(a, (tuple, list)), binary=True) + return atoms, lists + + def _test_color_function(self): + if not callable(self.f): + raise ValueError("Color function is not callable.") + try: + result = self.f(0, 0, 0, 0, 0) + if len(result) != 3: + raise ValueError("length should be equal to 3") + except TypeError: + raise ValueError("Color function needs to accept x,y,z,u,v, " + "as arguments even if it doesn't use all of them.") + except AssertionError: + raise ValueError("Color function needs to return 3-tuple r,g,b.") + except Exception: + pass # color function probably not valid at 0,0,0,0,0 + + def __call__(self, x, y, z, u, v): + try: + return self.f(x, y, z, u, v) + except Exception: + return None + + def apply_to_curve(self, verts, u_set, set_len=None, inc_pos=None): + """ + Apply this color scheme to a + set of vertices over a single + independent variable u. + """ + bounds = create_bounds() + cverts = [] + if callable(set_len): + set_len(len(u_set)*2) + # calculate f() = r,g,b for each vert + # and find the min and max for r,g,b + for _u in range(len(u_set)): + if verts[_u] is None: + cverts.append(None) + else: + x, y, z = verts[_u] + u, v = u_set[_u], None + c = self(x, y, z, u, v) + if c is not None: + c = list(c) + update_bounds(bounds, c) + cverts.append(c) + if callable(inc_pos): + inc_pos() + # scale and apply gradient + for _u in range(len(u_set)): + if cverts[_u] is not None: + for _c in range(3): + # scale from [f_min, f_max] to [0,1] + cverts[_u][_c] = rinterpolate(bounds[_c][0], bounds[_c][1], + cverts[_u][_c]) + # apply gradient + cverts[_u] = self.gradient(*cverts[_u]) + if callable(inc_pos): + inc_pos() + return cverts + + def apply_to_surface(self, verts, u_set, v_set, set_len=None, inc_pos=None): + """ + Apply this color scheme to a + set of vertices over two + independent variables u and v. + """ + bounds = create_bounds() + cverts = [] + if callable(set_len): + set_len(len(u_set)*len(v_set)*2) + # calculate f() = r,g,b for each vert + # and find the min and max for r,g,b + for _u in range(len(u_set)): + column = [] + for _v in range(len(v_set)): + if verts[_u][_v] is None: + column.append(None) + else: + x, y, z = verts[_u][_v] + u, v = u_set[_u], v_set[_v] + c = self(x, y, z, u, v) + if c is not None: + c = list(c) + update_bounds(bounds, c) + column.append(c) + if callable(inc_pos): + inc_pos() + cverts.append(column) + # scale and apply gradient + for _u in range(len(u_set)): + for _v in range(len(v_set)): + if cverts[_u][_v] is not None: + # scale from [f_min, f_max] to [0,1] + for _c in range(3): + cverts[_u][_v][_c] = rinterpolate(bounds[_c][0], + bounds[_c][1], cverts[_u][_v][_c]) + # apply gradient + cverts[_u][_v] = self.gradient(*cverts[_u][_v]) + if callable(inc_pos): + inc_pos() + return cverts + + def str_base(self): + return ", ".join(str(a) for a in self.args) + + def __repr__(self): + return "%s" % (self.str_base()) + + +x, y, z, t, u, v = symbols('x,y,z,t,u,v') + +default_color_schemes['rainbow'] = ColorScheme(z, y, x) +default_color_schemes['zfade'] = ColorScheme(z, (0.4, 0.4, 0.97), + (0.97, 0.4, 0.4), (None, None, z)) +default_color_schemes['zfade3'] = ColorScheme(z, (None, None, z), + [0.00, (0.2, 0.2, 1.0), + 0.35, (0.2, 0.8, 0.4), + 0.50, (0.3, 0.9, 0.3), + 0.65, (0.4, 0.8, 0.2), + 1.00, (1.0, 0.2, 0.2)]) + +default_color_schemes['zfade4'] = ColorScheme(z, (None, None, z), + [0.0, (0.3, 0.3, 1.0), + 0.30, (0.3, 1.0, 0.3), + 0.55, (0.95, 1.0, 0.2), + 0.65, (1.0, 0.95, 0.2), + 0.85, (1.0, 0.7, 0.2), + 1.0, (1.0, 0.3, 0.2)]) diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/managed_window.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/managed_window.py new file mode 100644 index 0000000000000000000000000000000000000000..81fa2541b4dd9e13534aabfd2a11bf88c479daf8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/managed_window.py @@ -0,0 +1,106 @@ +from pyglet.window import Window +from pyglet.clock import Clock + +from threading import Thread, Lock + +gl_lock = Lock() + + +class ManagedWindow(Window): + """ + A pyglet window with an event loop which executes automatically + in a separate thread. Behavior is added by creating a subclass + which overrides setup, update, and/or draw. + """ + fps_limit = 30 + default_win_args = {"width": 600, + "height": 500, + "vsync": False, + "resizable": True} + + def __init__(self, **win_args): + """ + It is best not to override this function in the child + class, unless you need to take additional arguments. + Do any OpenGL initialization calls in setup(). + """ + + # check if this is run from the doctester + if win_args.get('runfromdoctester', False): + return + + self.win_args = dict(self.default_win_args, **win_args) + self.Thread = Thread(target=self.__event_loop__) + self.Thread.start() + + def __event_loop__(self, **win_args): + """ + The event loop thread function. Do not override or call + directly (it is called by __init__). + """ + gl_lock.acquire() + try: + try: + super().__init__(**self.win_args) + self.switch_to() + self.setup() + except Exception as e: + print("Window initialization failed: %s" % (str(e))) + self.has_exit = True + finally: + gl_lock.release() + + clock = Clock() + clock.fps_limit = self.fps_limit + while not self.has_exit: + dt = clock.tick() + gl_lock.acquire() + try: + try: + self.switch_to() + self.dispatch_events() + self.clear() + self.update(dt) + self.draw() + self.flip() + except Exception as e: + print("Uncaught exception in event loop: %s" % str(e)) + self.has_exit = True + finally: + gl_lock.release() + super().close() + + def close(self): + """ + Closes the window. + """ + self.has_exit = True + + def setup(self): + """ + Called once before the event loop begins. + Override this method in a child class. This + is the best place to put things like OpenGL + initialization calls. + """ + pass + + def update(self, dt): + """ + Called before draw during each iteration of + the event loop. dt is the elapsed time in + seconds since the last update. OpenGL rendering + calls are best put in draw() rather than here. + """ + pass + + def draw(self): + """ + Called after update during each iteration of + the event loop. Put OpenGL rendering calls + here. + """ + pass + +if __name__ == '__main__': + ManagedWindow() diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..8c3dd3c8d4ce6c660cc07f93a55029eef98e55a2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot.py @@ -0,0 +1,464 @@ +from threading import RLock + +# it is sufficient to import "pyglet" here once +try: + import pyglet.gl as pgl +except ImportError: + raise ImportError("pyglet is required for plotting.\n " + "visit https://pyglet.org/") + +from sympy.core.numbers import Integer +from sympy.external.gmpy import SYMPY_INTS +from sympy.geometry.entity import GeometryEntity +from sympy.plotting.pygletplot.plot_axes import PlotAxes +from sympy.plotting.pygletplot.plot_mode import PlotMode +from sympy.plotting.pygletplot.plot_object import PlotObject +from sympy.plotting.pygletplot.plot_window import PlotWindow +from sympy.plotting.pygletplot.util import parse_option_string +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.iterables import is_sequence + +from time import sleep +from os import getcwd, listdir + +import ctypes + +@doctest_depends_on(modules=('pyglet',)) +class PygletPlot: + """ + Plot Examples + ============= + + See examples/advanced/pyglet_plotting.py for many more examples. + + >>> from sympy.plotting.pygletplot import PygletPlot as Plot + >>> from sympy.abc import x, y, z + + >>> Plot(x*y**3-y*x**3) + [0]: -x**3*y + x*y**3, 'mode=cartesian' + + >>> p = Plot() + >>> p[1] = x*y + >>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4) + + >>> p = Plot() + >>> p[1] = x**2+y**2 + >>> p[2] = -x**2-y**2 + + + Variable Intervals + ================== + + The basic format is [var, min, max, steps], but the + syntax is flexible and arguments left out are taken + from the defaults for the current coordinate mode: + + >>> Plot(x**2) # implies [x,-5,5,100] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [], []) # [x,-1,1,40], [y,-1,1,40] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2-y**2, [100], [100]) # [x,-1,1,100], [y,-1,1,100] + [0]: x**2 - y**2, 'mode=cartesian' + >>> Plot(x**2, [x,-13,13,100]) + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [-13,13]) # [x,-13,13,100] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [x,-13,13]) # [x,-13,13,10] + [0]: x**2, 'mode=cartesian' + >>> Plot(1*x, [], [x], mode='cylindrical') + ... # [unbound_theta,0,2*Pi,40], [x,-1,1,20] + [0]: x, 'mode=cartesian' + + + Coordinate Modes + ================ + + Plot supports several curvilinear coordinate modes, and + they independent for each plotted function. You can specify + a coordinate mode explicitly with the 'mode' named argument, + but it can be automatically determined for Cartesian or + parametric plots, and therefore must only be specified for + polar, cylindrical, and spherical modes. + + Specifically, Plot(function arguments) and Plot[n] = + (function arguments) will interpret your arguments as a + Cartesian plot if you provide one function and a parametric + plot if you provide two or three functions. Similarly, the + arguments will be interpreted as a curve if one variable is + used, and a surface if two are used. + + Supported mode names by number of variables: + + 1: parametric, cartesian, polar + 2: parametric, cartesian, cylindrical = polar, spherical + + >>> Plot(1, mode='spherical') + + + Calculator-like Interface + ========================= + + >>> p = Plot(visible=False) + >>> f = x**2 + >>> p[1] = f + >>> p[2] = f.diff(x) + >>> p[3] = f.diff(x).diff(x) + >>> p + [1]: x**2, 'mode=cartesian' + [2]: 2*x, 'mode=cartesian' + [3]: 2, 'mode=cartesian' + >>> p.show() + >>> p.clear() + >>> p + + >>> p[1] = x**2+y**2 + >>> p[1].style = 'solid' + >>> p[2] = -x**2-y**2 + >>> p[2].style = 'wireframe' + >>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4) + >>> p[1].style = 'both' + >>> p[2].style = 'both' + >>> p.close() + + + Plot Window Keyboard Controls + ============================= + + Screen Rotation: + X,Y axis Arrow Keys, A,S,D,W, Numpad 4,6,8,2 + Z axis Q,E, Numpad 7,9 + + Model Rotation: + Z axis Z,C, Numpad 1,3 + + Zoom: R,F, PgUp,PgDn, Numpad +,- + + Reset Camera: X, Numpad 5 + + Camera Presets: + XY F1 + XZ F2 + YZ F3 + Perspective F4 + + Sensitivity Modifier: SHIFT + + Axes Toggle: + Visible F5 + Colors F6 + + Close Window: ESCAPE + + ============================= + + """ + + @doctest_depends_on(modules=('pyglet',)) + def __init__(self, *fargs, **win_args): + """ + Positional Arguments + ==================== + + Any given positional arguments are used to + initialize a plot function at index 1. In + other words... + + >>> from sympy.plotting.pygletplot import PygletPlot as Plot + >>> from sympy.abc import x + >>> p = Plot(x**2, visible=False) + + ...is equivalent to... + + >>> p = Plot(visible=False) + >>> p[1] = x**2 + + Note that in earlier versions of the plotting + module, you were able to specify multiple + functions in the initializer. This functionality + has been dropped in favor of better automatic + plot plot_mode detection. + + + Named Arguments + =============== + + axes + An option string of the form + "key1=value1; key2 = value2" which + can use the following options: + + style = ordinate + none OR frame OR box OR ordinate + + stride = 0.25 + val OR (val_x, val_y, val_z) + + overlay = True (draw on top of plot) + True OR False + + colored = False (False uses Black, + True uses colors + R,G,B = X,Y,Z) + True OR False + + label_axes = False (display axis names + at endpoints) + True OR False + + visible = True (show immediately + True OR False + + + The following named arguments are passed as + arguments to window initialization: + + antialiasing = True + True OR False + + ortho = False + True OR False + + invert_mouse_zoom = False + True OR False + + """ + # Register the plot modes + from . import plot_modes # noqa + + self._win_args = win_args + self._window = None + + self._render_lock = RLock() + + self._functions = {} + self._pobjects = [] + self._screenshot = ScreenShot(self) + + axe_options = parse_option_string(win_args.pop('axes', '')) + self.axes = PlotAxes(**axe_options) + self._pobjects.append(self.axes) + + self[0] = fargs + if win_args.get('visible', True): + self.show() + + ## Window Interfaces + + def show(self): + """ + Creates and displays a plot window, or activates it + (gives it focus) if it has already been created. + """ + if self._window and not self._window.has_exit: + self._window.activate() + else: + self._win_args['visible'] = True + self.axes.reset_resources() + + #if hasattr(self, '_doctest_depends_on'): + # self._win_args['runfromdoctester'] = True + + self._window = PlotWindow(self, **self._win_args) + + def close(self): + """ + Closes the plot window. + """ + if self._window: + self._window.close() + + def saveimage(self, outfile=None, format='', size=(600, 500)): + """ + Saves a screen capture of the plot window to an + image file. + + If outfile is given, it can either be a path + or a file object. Otherwise a png image will + be saved to the current working directory. + If the format is omitted, it is determined from + the filename extension. + """ + self._screenshot.save(outfile, format, size) + + ## Function List Interfaces + + def clear(self): + """ + Clears the function list of this plot. + """ + self._render_lock.acquire() + self._functions = {} + self.adjust_all_bounds() + self._render_lock.release() + + def __getitem__(self, i): + """ + Returns the function at position i in the + function list. + """ + return self._functions[i] + + def __setitem__(self, i, args): + """ + Parses and adds a PlotMode to the function + list. + """ + if not (isinstance(i, (SYMPY_INTS, Integer)) and i >= 0): + raise ValueError("Function index must " + "be an integer >= 0.") + + if isinstance(args, PlotObject): + f = args + else: + if (not is_sequence(args)) or isinstance(args, GeometryEntity): + args = [args] + if len(args) == 0: + return # no arguments given + kwargs = {"bounds_callback": self.adjust_all_bounds} + f = PlotMode(*args, **kwargs) + + if f: + self._render_lock.acquire() + self._functions[i] = f + self._render_lock.release() + else: + raise ValueError("Failed to parse '%s'." + % ', '.join(str(a) for a in args)) + + def __delitem__(self, i): + """ + Removes the function in the function list at + position i. + """ + self._render_lock.acquire() + del self._functions[i] + self.adjust_all_bounds() + self._render_lock.release() + + def firstavailableindex(self): + """ + Returns the first unused index in the function list. + """ + i = 0 + self._render_lock.acquire() + while i in self._functions: + i += 1 + self._render_lock.release() + return i + + def append(self, *args): + """ + Parses and adds a PlotMode to the function + list at the first available index. + """ + self.__setitem__(self.firstavailableindex(), args) + + def __len__(self): + """ + Returns the number of functions in the function list. + """ + return len(self._functions) + + def __iter__(self): + """ + Allows iteration of the function list. + """ + return self._functions.itervalues() + + def __repr__(self): + return str(self) + + def __str__(self): + """ + Returns a string containing a new-line separated + list of the functions in the function list. + """ + s = "" + if len(self._functions) == 0: + s += "" + else: + self._render_lock.acquire() + s += "\n".join(["%s[%i]: %s" % ("", i, str(self._functions[i])) + for i in self._functions]) + self._render_lock.release() + return s + + def adjust_all_bounds(self): + self._render_lock.acquire() + self.axes.reset_bounding_box() + for f in self._functions: + self.axes.adjust_bounds(self._functions[f].bounds) + self._render_lock.release() + + def wait_for_calculations(self): + sleep(0) + self._render_lock.acquire() + for f in self._functions: + a = self._functions[f]._get_calculating_verts + b = self._functions[f]._get_calculating_cverts + while a() or b(): + sleep(0) + self._render_lock.release() + +class ScreenShot: + def __init__(self, plot): + self._plot = plot + self.screenshot_requested = False + self.outfile = None + self.format = '' + self.invisibleMode = False + self.flag = 0 + + def __bool__(self): + return self.screenshot_requested + + def _execute_saving(self): + if self.flag < 3: + self.flag += 1 + return + + size_x, size_y = self._plot._window.get_size() + size = size_x*size_y*4*ctypes.sizeof(ctypes.c_ubyte) + image = ctypes.create_string_buffer(size) + pgl.glReadPixels(0, 0, size_x, size_y, pgl.GL_RGBA, pgl.GL_UNSIGNED_BYTE, image) + from PIL import Image + im = Image.frombuffer('RGBA', (size_x, size_y), + image.raw, 'raw', 'RGBA', 0, 1) + im.transpose(Image.FLIP_TOP_BOTTOM).save(self.outfile, self.format) + + self.flag = 0 + self.screenshot_requested = False + if self.invisibleMode: + self._plot._window.close() + + def save(self, outfile=None, format='', size=(600, 500)): + self.outfile = outfile + self.format = format + self.size = size + self.screenshot_requested = True + + if not self._plot._window or self._plot._window.has_exit: + self._plot._win_args['visible'] = False + + self._plot._win_args['width'] = size[0] + self._plot._win_args['height'] = size[1] + + self._plot.axes.reset_resources() + self._plot._window = PlotWindow(self._plot, **self._plot._win_args) + self.invisibleMode = True + + if self.outfile is None: + self.outfile = self._create_unique_path() + print(self.outfile) + + def _create_unique_path(self): + cwd = getcwd() + l = listdir(cwd) + path = '' + i = 0 + while True: + if not 'plot_%s.png' % i in l: + path = cwd + '/plot_%s.png' % i + break + i += 1 + return path diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_axes.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_axes.py new file mode 100644 index 0000000000000000000000000000000000000000..ae26fb0b2fa64e7f7318c51ce3fe5afaa276b48e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_axes.py @@ -0,0 +1,251 @@ +import pyglet.gl as pgl +from pyglet import font + +from sympy.core import S +from sympy.plotting.pygletplot.plot_object import PlotObject +from sympy.plotting.pygletplot.util import billboard_matrix, dot_product, \ + get_direction_vectors, strided_range, vec_mag, vec_sub +from sympy.utilities.iterables import is_sequence + + +class PlotAxes(PlotObject): + + def __init__(self, *args, + style='', none=None, frame=None, box=None, ordinate=None, + stride=0.25, + visible='', overlay='', colored='', label_axes='', label_ticks='', + tick_length=0.1, + font_face='Arial', font_size=28, + **kwargs): + # initialize style parameter + style = style.lower() + + # allow alias kwargs to override style kwarg + if none is not None: + style = 'none' + if frame is not None: + style = 'frame' + if box is not None: + style = 'box' + if ordinate is not None: + style = 'ordinate' + + if style in ['', 'ordinate']: + self._render_object = PlotAxesOrdinate(self) + elif style in ['frame', 'box']: + self._render_object = PlotAxesFrame(self) + elif style in ['none']: + self._render_object = None + else: + raise ValueError(("Unrecognized axes style %s.") % (style)) + + # initialize stride parameter + try: + stride = eval(stride) + except TypeError: + pass + if is_sequence(stride): + if len(stride) != 3: + raise ValueError("length should be equal to 3") + self._stride = stride + else: + self._stride = [stride, stride, stride] + self._tick_length = float(tick_length) + + # setup bounding box and ticks + self._origin = [0, 0, 0] + self.reset_bounding_box() + + def flexible_boolean(input, default): + if input in [True, False]: + return input + if input in ('f', 'F', 'false', 'False'): + return False + if input in ('t', 'T', 'true', 'True'): + return True + return default + + # initialize remaining parameters + self.visible = flexible_boolean(kwargs, True) + self._overlay = flexible_boolean(overlay, True) + self._colored = flexible_boolean(colored, False) + self._label_axes = flexible_boolean(label_axes, False) + self._label_ticks = flexible_boolean(label_ticks, True) + + # setup label font + self.font_face = font_face + self.font_size = font_size + + # this is also used to reinit the + # font on window close/reopen + self.reset_resources() + + def reset_resources(self): + self.label_font = None + + def reset_bounding_box(self): + self._bounding_box = [[None, None], [None, None], [None, None]] + self._axis_ticks = [[], [], []] + + def draw(self): + if self._render_object: + pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT | pgl.GL_DEPTH_BUFFER_BIT) + if self._overlay: + pgl.glDisable(pgl.GL_DEPTH_TEST) + self._render_object.draw() + pgl.glPopAttrib() + + def adjust_bounds(self, child_bounds): + b = self._bounding_box + c = child_bounds + for i in range(3): + if abs(c[i][0]) is S.Infinity or abs(c[i][1]) is S.Infinity: + continue + b[i][0] = c[i][0] if b[i][0] is None else min([b[i][0], c[i][0]]) + b[i][1] = c[i][1] if b[i][1] is None else max([b[i][1], c[i][1]]) + self._bounding_box = b + self._recalculate_axis_ticks(i) + + def _recalculate_axis_ticks(self, axis): + b = self._bounding_box + if b[axis][0] is None or b[axis][1] is None: + self._axis_ticks[axis] = [] + else: + self._axis_ticks[axis] = strided_range(b[axis][0], b[axis][1], + self._stride[axis]) + + def toggle_visible(self): + self.visible = not self.visible + + def toggle_colors(self): + self._colored = not self._colored + + +class PlotAxesBase(PlotObject): + + def __init__(self, parent_axes): + self._p = parent_axes + + def draw(self): + color = [([0.2, 0.1, 0.3], [0.2, 0.1, 0.3], [0.2, 0.1, 0.3]), + ([0.9, 0.3, 0.5], [0.5, 1.0, 0.5], [0.3, 0.3, 0.9])][self._p._colored] + self.draw_background(color) + self.draw_axis(2, color[2]) + self.draw_axis(1, color[1]) + self.draw_axis(0, color[0]) + + def draw_background(self, color): + pass # optional + + def draw_axis(self, axis, color): + raise NotImplementedError() + + def draw_text(self, text, position, color, scale=1.0): + if len(color) == 3: + color = (color[0], color[1], color[2], 1.0) + + if self._p.label_font is None: + self._p.label_font = font.load(self._p.font_face, + self._p.font_size, + bold=True, italic=False) + + label = font.Text(self._p.label_font, text, + color=color, + valign=font.Text.BASELINE, + halign=font.Text.CENTER) + + pgl.glPushMatrix() + pgl.glTranslatef(*position) + billboard_matrix() + scale_factor = 0.005 * scale + pgl.glScalef(scale_factor, scale_factor, scale_factor) + pgl.glColor4f(0, 0, 0, 0) + label.draw() + pgl.glPopMatrix() + + def draw_line(self, v, color): + o = self._p._origin + pgl.glBegin(pgl.GL_LINES) + pgl.glColor3f(*color) + pgl.glVertex3f(v[0][0] + o[0], v[0][1] + o[1], v[0][2] + o[2]) + pgl.glVertex3f(v[1][0] + o[0], v[1][1] + o[1], v[1][2] + o[2]) + pgl.glEnd() + + +class PlotAxesOrdinate(PlotAxesBase): + + def __init__(self, parent_axes): + super().__init__(parent_axes) + + def draw_axis(self, axis, color): + ticks = self._p._axis_ticks[axis] + radius = self._p._tick_length / 2.0 + if len(ticks) < 2: + return + + # calculate the vector for this axis + axis_lines = [[0, 0, 0], [0, 0, 0]] + axis_lines[0][axis], axis_lines[1][axis] = ticks[0], ticks[-1] + axis_vector = vec_sub(axis_lines[1], axis_lines[0]) + + # calculate angle to the z direction vector + pos_z = get_direction_vectors()[2] + d = abs(dot_product(axis_vector, pos_z)) + d = d / vec_mag(axis_vector) + + # don't draw labels if we're looking down the axis + labels_visible = abs(d - 1.0) > 0.02 + + # draw the ticks and labels + for tick in ticks: + self.draw_tick_line(axis, color, radius, tick, labels_visible) + + # draw the axis line and labels + self.draw_axis_line(axis, color, ticks[0], ticks[-1], labels_visible) + + def draw_axis_line(self, axis, color, a_min, a_max, labels_visible): + axis_line = [[0, 0, 0], [0, 0, 0]] + axis_line[0][axis], axis_line[1][axis] = a_min, a_max + self.draw_line(axis_line, color) + if labels_visible: + self.draw_axis_line_labels(axis, color, axis_line) + + def draw_axis_line_labels(self, axis, color, axis_line): + if not self._p._label_axes: + return + axis_labels = [axis_line[0][::], axis_line[1][::]] + axis_labels[0][axis] -= 0.3 + axis_labels[1][axis] += 0.3 + a_str = ['X', 'Y', 'Z'][axis] + self.draw_text("-" + a_str, axis_labels[0], color) + self.draw_text("+" + a_str, axis_labels[1], color) + + def draw_tick_line(self, axis, color, radius, tick, labels_visible): + tick_axis = {0: 1, 1: 0, 2: 1}[axis] + tick_line = [[0, 0, 0], [0, 0, 0]] + tick_line[0][axis] = tick_line[1][axis] = tick + tick_line[0][tick_axis], tick_line[1][tick_axis] = -radius, radius + self.draw_line(tick_line, color) + if labels_visible: + self.draw_tick_line_label(axis, color, radius, tick) + + def draw_tick_line_label(self, axis, color, radius, tick): + if not self._p._label_axes: + return + tick_label_vector = [0, 0, 0] + tick_label_vector[axis] = tick + tick_label_vector[{0: 1, 1: 0, 2: 1}[axis]] = [-1, 1, 1][ + axis] * radius * 3.5 + self.draw_text(str(tick), tick_label_vector, color, scale=0.5) + + +class PlotAxesFrame(PlotAxesBase): + + def __init__(self, parent_axes): + super().__init__(parent_axes) + + def draw_background(self, color): + pass + + def draw_axis(self, axis, color): + raise NotImplementedError() diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_camera.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_camera.py new file mode 100644 index 0000000000000000000000000000000000000000..43598debac252ffd22beb8690fef30745259c634 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_camera.py @@ -0,0 +1,124 @@ +import pyglet.gl as pgl +from sympy.plotting.pygletplot.plot_rotation import get_spherical_rotatation +from sympy.plotting.pygletplot.util import get_model_matrix, model_to_screen, \ + screen_to_model, vec_subs + + +class PlotCamera: + + min_dist = 0.05 + max_dist = 500.0 + + min_ortho_dist = 100.0 + max_ortho_dist = 10000.0 + + _default_dist = 6.0 + _default_ortho_dist = 600.0 + + rot_presets = { + 'xy': (0, 0, 0), + 'xz': (-90, 0, 0), + 'yz': (0, 90, 0), + 'perspective': (-45, 0, -45) + } + + def __init__(self, window, ortho=False): + self.window = window + self.axes = self.window.plot.axes + self.ortho = ortho + self.reset() + + def init_rot_matrix(self): + pgl.glPushMatrix() + pgl.glLoadIdentity() + self._rot = get_model_matrix() + pgl.glPopMatrix() + + def set_rot_preset(self, preset_name): + self.init_rot_matrix() + if preset_name not in self.rot_presets: + raise ValueError( + "%s is not a valid rotation preset." % preset_name) + r = self.rot_presets[preset_name] + self.euler_rotate(r[0], 1, 0, 0) + self.euler_rotate(r[1], 0, 1, 0) + self.euler_rotate(r[2], 0, 0, 1) + + def reset(self): + self._dist = 0.0 + self._x, self._y = 0.0, 0.0 + self._rot = None + if self.ortho: + self._dist = self._default_ortho_dist + else: + self._dist = self._default_dist + self.init_rot_matrix() + + def mult_rot_matrix(self, rot): + pgl.glPushMatrix() + pgl.glLoadMatrixf(rot) + pgl.glMultMatrixf(self._rot) + self._rot = get_model_matrix() + pgl.glPopMatrix() + + def setup_projection(self): + pgl.glMatrixMode(pgl.GL_PROJECTION) + pgl.glLoadIdentity() + if self.ortho: + # yep, this is pseudo ortho (don't tell anyone) + pgl.gluPerspective( + 0.3, float(self.window.width)/float(self.window.height), + self.min_ortho_dist - 0.01, self.max_ortho_dist + 0.01) + else: + pgl.gluPerspective( + 30.0, float(self.window.width)/float(self.window.height), + self.min_dist - 0.01, self.max_dist + 0.01) + pgl.glMatrixMode(pgl.GL_MODELVIEW) + + def _get_scale(self): + return 1.0, 1.0, 1.0 + + def apply_transformation(self): + pgl.glLoadIdentity() + pgl.glTranslatef(self._x, self._y, -self._dist) + if self._rot is not None: + pgl.glMultMatrixf(self._rot) + pgl.glScalef(*self._get_scale()) + + def spherical_rotate(self, p1, p2, sensitivity=1.0): + mat = get_spherical_rotatation(p1, p2, self.window.width, + self.window.height, sensitivity) + if mat is not None: + self.mult_rot_matrix(mat) + + def euler_rotate(self, angle, x, y, z): + pgl.glPushMatrix() + pgl.glLoadMatrixf(self._rot) + pgl.glRotatef(angle, x, y, z) + self._rot = get_model_matrix() + pgl.glPopMatrix() + + def zoom_relative(self, clicks, sensitivity): + + if self.ortho: + dist_d = clicks * sensitivity * 50.0 + min_dist = self.min_ortho_dist + max_dist = self.max_ortho_dist + else: + dist_d = clicks * sensitivity + min_dist = self.min_dist + max_dist = self.max_dist + + new_dist = (self._dist - dist_d) + if (clicks < 0 and new_dist < max_dist) or new_dist > min_dist: + self._dist = new_dist + + def mouse_translate(self, x, y, dx, dy): + pgl.glPushMatrix() + pgl.glLoadIdentity() + pgl.glTranslatef(0, 0, -self._dist) + z = model_to_screen(0, 0, 0)[2] + d = vec_subs(screen_to_model(x, y, z), screen_to_model(x - dx, y - dy, z)) + pgl.glPopMatrix() + self._x += d[0] + self._y += d[1] diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_controller.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..aa7e01e6fd17fddf07b733442208a0a4c9d87d5b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_controller.py @@ -0,0 +1,218 @@ +from pyglet.window import key +from pyglet.window.mouse import LEFT, RIGHT, MIDDLE +from sympy.plotting.pygletplot.util import get_direction_vectors, get_basis_vectors + + +class PlotController: + + normal_mouse_sensitivity = 4.0 + modified_mouse_sensitivity = 1.0 + + normal_key_sensitivity = 160.0 + modified_key_sensitivity = 40.0 + + keymap = { + key.LEFT: 'left', + key.A: 'left', + key.NUM_4: 'left', + + key.RIGHT: 'right', + key.D: 'right', + key.NUM_6: 'right', + + key.UP: 'up', + key.W: 'up', + key.NUM_8: 'up', + + key.DOWN: 'down', + key.S: 'down', + key.NUM_2: 'down', + + key.Z: 'rotate_z_neg', + key.NUM_1: 'rotate_z_neg', + + key.C: 'rotate_z_pos', + key.NUM_3: 'rotate_z_pos', + + key.Q: 'spin_left', + key.NUM_7: 'spin_left', + key.E: 'spin_right', + key.NUM_9: 'spin_right', + + key.X: 'reset_camera', + key.NUM_5: 'reset_camera', + + key.NUM_ADD: 'zoom_in', + key.PAGEUP: 'zoom_in', + key.R: 'zoom_in', + + key.NUM_SUBTRACT: 'zoom_out', + key.PAGEDOWN: 'zoom_out', + key.F: 'zoom_out', + + key.RSHIFT: 'modify_sensitivity', + key.LSHIFT: 'modify_sensitivity', + + key.F1: 'rot_preset_xy', + key.F2: 'rot_preset_xz', + key.F3: 'rot_preset_yz', + key.F4: 'rot_preset_perspective', + + key.F5: 'toggle_axes', + key.F6: 'toggle_axe_colors', + + key.F8: 'save_image' + } + + def __init__(self, window, *, invert_mouse_zoom=False, **kwargs): + self.invert_mouse_zoom = invert_mouse_zoom + self.window = window + self.camera = window.camera + self.action = { + # Rotation around the view Y (up) vector + 'left': False, + 'right': False, + # Rotation around the view X vector + 'up': False, + 'down': False, + # Rotation around the view Z vector + 'spin_left': False, + 'spin_right': False, + # Rotation around the model Z vector + 'rotate_z_neg': False, + 'rotate_z_pos': False, + # Reset to the default rotation + 'reset_camera': False, + # Performs camera z-translation + 'zoom_in': False, + 'zoom_out': False, + # Use alternative sensitivity (speed) + 'modify_sensitivity': False, + # Rotation presets + 'rot_preset_xy': False, + 'rot_preset_xz': False, + 'rot_preset_yz': False, + 'rot_preset_perspective': False, + # axes + 'toggle_axes': False, + 'toggle_axe_colors': False, + # screenshot + 'save_image': False + } + + def update(self, dt): + z = 0 + if self.action['zoom_out']: + z -= 1 + if self.action['zoom_in']: + z += 1 + if z != 0: + self.camera.zoom_relative(z/10.0, self.get_key_sensitivity()/10.0) + + dx, dy, dz = 0, 0, 0 + if self.action['left']: + dx -= 1 + if self.action['right']: + dx += 1 + if self.action['up']: + dy -= 1 + if self.action['down']: + dy += 1 + if self.action['spin_left']: + dz += 1 + if self.action['spin_right']: + dz -= 1 + + if not self.is_2D(): + if dx != 0: + self.camera.euler_rotate(dx*dt*self.get_key_sensitivity(), + *(get_direction_vectors()[1])) + if dy != 0: + self.camera.euler_rotate(dy*dt*self.get_key_sensitivity(), + *(get_direction_vectors()[0])) + if dz != 0: + self.camera.euler_rotate(dz*dt*self.get_key_sensitivity(), + *(get_direction_vectors()[2])) + else: + self.camera.mouse_translate(0, 0, dx*dt*self.get_key_sensitivity(), + -dy*dt*self.get_key_sensitivity()) + + rz = 0 + if self.action['rotate_z_neg'] and not self.is_2D(): + rz -= 1 + if self.action['rotate_z_pos'] and not self.is_2D(): + rz += 1 + + if rz != 0: + self.camera.euler_rotate(rz*dt*self.get_key_sensitivity(), + *(get_basis_vectors()[2])) + + if self.action['reset_camera']: + self.camera.reset() + + if self.action['rot_preset_xy']: + self.camera.set_rot_preset('xy') + if self.action['rot_preset_xz']: + self.camera.set_rot_preset('xz') + if self.action['rot_preset_yz']: + self.camera.set_rot_preset('yz') + if self.action['rot_preset_perspective']: + self.camera.set_rot_preset('perspective') + + if self.action['toggle_axes']: + self.action['toggle_axes'] = False + self.camera.axes.toggle_visible() + + if self.action['toggle_axe_colors']: + self.action['toggle_axe_colors'] = False + self.camera.axes.toggle_colors() + + if self.action['save_image']: + self.action['save_image'] = False + self.window.plot.saveimage() + + return True + + def get_mouse_sensitivity(self): + if self.action['modify_sensitivity']: + return self.modified_mouse_sensitivity + else: + return self.normal_mouse_sensitivity + + def get_key_sensitivity(self): + if self.action['modify_sensitivity']: + return self.modified_key_sensitivity + else: + return self.normal_key_sensitivity + + def on_key_press(self, symbol, modifiers): + if symbol in self.keymap: + self.action[self.keymap[symbol]] = True + + def on_key_release(self, symbol, modifiers): + if symbol in self.keymap: + self.action[self.keymap[symbol]] = False + + def on_mouse_drag(self, x, y, dx, dy, buttons, modifiers): + if buttons & LEFT: + if self.is_2D(): + self.camera.mouse_translate(x, y, dx, dy) + else: + self.camera.spherical_rotate((x - dx, y - dy), (x, y), + self.get_mouse_sensitivity()) + if buttons & MIDDLE: + self.camera.zoom_relative([1, -1][self.invert_mouse_zoom]*dy, + self.get_mouse_sensitivity()/20.0) + if buttons & RIGHT: + self.camera.mouse_translate(x, y, dx, dy) + + def on_mouse_scroll(self, x, y, dx, dy): + self.camera.zoom_relative([1, -1][self.invert_mouse_zoom]*dy, + self.get_mouse_sensitivity()) + + def is_2D(self): + functions = self.window.plot._functions + for i in functions: + if len(functions[i].i_vars) > 1 or len(functions[i].d_vars) > 2: + return False + return True diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_curve.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..6b97dac843f58c76694d424f0b0b7e3499ba5202 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_curve.py @@ -0,0 +1,82 @@ +import pyglet.gl as pgl +from sympy.core import S +from sympy.plotting.pygletplot.plot_mode_base import PlotModeBase + + +class PlotCurve(PlotModeBase): + + style_override = 'wireframe' + + def _on_calculate_verts(self): + self.t_interval = self.intervals[0] + self.t_set = list(self.t_interval.frange()) + self.bounds = [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + evaluate = self._get_evaluator() + + self._calculating_verts_pos = 0.0 + self._calculating_verts_len = float(self.t_interval.v_len) + + self.verts = [] + b = self.bounds + for t in self.t_set: + try: + _e = evaluate(t) # calculate vertex + except (NameError, ZeroDivisionError): + _e = None + if _e is not None: # update bounding box + for axis in range(3): + b[axis][0] = min([b[axis][0], _e[axis]]) + b[axis][1] = max([b[axis][1], _e[axis]]) + self.verts.append(_e) + self._calculating_verts_pos += 1.0 + + for axis in range(3): + b[axis][2] = b[axis][1] - b[axis][0] + if b[axis][2] == 0.0: + b[axis][2] = 1.0 + + self.push_wireframe(self.draw_verts(False)) + + def _on_calculate_cverts(self): + if not self.verts or not self.color: + return + + def set_work_len(n): + self._calculating_cverts_len = float(n) + + def inc_work_pos(): + self._calculating_cverts_pos += 1.0 + set_work_len(1) + self._calculating_cverts_pos = 0 + self.cverts = self.color.apply_to_curve(self.verts, + self.t_set, + set_len=set_work_len, + inc_pos=inc_work_pos) + self.push_wireframe(self.draw_verts(True)) + + def calculate_one_cvert(self, t): + vert = self.verts[t] + return self.color(vert[0], vert[1], vert[2], + self.t_set[t], None) + + def draw_verts(self, use_cverts): + def f(): + pgl.glBegin(pgl.GL_LINE_STRIP) + for t in range(len(self.t_set)): + p = self.verts[t] + if p is None: + pgl.glEnd() + pgl.glBegin(pgl.GL_LINE_STRIP) + continue + if use_cverts: + c = self.cverts[t] + if c is None: + c = (0, 0, 0) + pgl.glColor3f(*c) + else: + pgl.glColor3f(*self.default_wireframe_color) + pgl.glVertex3f(*p) + pgl.glEnd() + return f diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_interval.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_interval.py new file mode 100644 index 0000000000000000000000000000000000000000..085ab096915bbc4a3761b71736b4dd14f1ff779f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_interval.py @@ -0,0 +1,181 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.core.numbers import Integer + + +class PlotInterval: + """ + """ + _v, _v_min, _v_max, _v_steps = None, None, None, None + + def require_all_args(f): + def check(self, *args, **kwargs): + for g in [self._v, self._v_min, self._v_max, self._v_steps]: + if g is None: + raise ValueError("PlotInterval is incomplete.") + return f(self, *args, **kwargs) + return check + + def __init__(self, *args): + if len(args) == 1: + if isinstance(args[0], PlotInterval): + self.fill_from(args[0]) + return + elif isinstance(args[0], str): + try: + args = eval(args[0]) + except TypeError: + s_eval_error = "Could not interpret string %s." + raise ValueError(s_eval_error % (args[0])) + elif isinstance(args[0], (tuple, list)): + args = args[0] + else: + raise ValueError("Not an interval.") + if not isinstance(args, (tuple, list)) or len(args) > 4: + f_error = "PlotInterval must be a tuple or list of length 4 or less." + raise ValueError(f_error) + + args = list(args) + if len(args) > 0 and (args[0] is None or isinstance(args[0], Symbol)): + self.v = args.pop(0) + if len(args) in [2, 3]: + self.v_min = args.pop(0) + self.v_max = args.pop(0) + if len(args) == 1: + self.v_steps = args.pop(0) + elif len(args) == 1: + self.v_steps = args.pop(0) + + def get_v(self): + return self._v + + def set_v(self, v): + if v is None: + self._v = None + return + if not isinstance(v, Symbol): + raise ValueError("v must be a SymPy Symbol.") + self._v = v + + def get_v_min(self): + return self._v_min + + def set_v_min(self, v_min): + if v_min is None: + self._v_min = None + return + try: + self._v_min = sympify(v_min) + float(self._v_min.evalf()) + except TypeError: + raise ValueError("v_min could not be interpreted as a number.") + + def get_v_max(self): + return self._v_max + + def set_v_max(self, v_max): + if v_max is None: + self._v_max = None + return + try: + self._v_max = sympify(v_max) + float(self._v_max.evalf()) + except TypeError: + raise ValueError("v_max could not be interpreted as a number.") + + def get_v_steps(self): + return self._v_steps + + def set_v_steps(self, v_steps): + if v_steps is None: + self._v_steps = None + return + if isinstance(v_steps, int): + v_steps = Integer(v_steps) + elif not isinstance(v_steps, Integer): + raise ValueError("v_steps must be an int or SymPy Integer.") + if v_steps <= S.Zero: + raise ValueError("v_steps must be positive.") + self._v_steps = v_steps + + @require_all_args + def get_v_len(self): + return self.v_steps + 1 + + v = property(get_v, set_v) + v_min = property(get_v_min, set_v_min) + v_max = property(get_v_max, set_v_max) + v_steps = property(get_v_steps, set_v_steps) + v_len = property(get_v_len) + + def fill_from(self, b): + if b.v is not None: + self.v = b.v + if b.v_min is not None: + self.v_min = b.v_min + if b.v_max is not None: + self.v_max = b.v_max + if b.v_steps is not None: + self.v_steps = b.v_steps + + @staticmethod + def try_parse(*args): + """ + Returns a PlotInterval if args can be interpreted + as such, otherwise None. + """ + if len(args) == 1 and isinstance(args[0], PlotInterval): + return args[0] + try: + return PlotInterval(*args) + except ValueError: + return None + + def _str_base(self): + return ",".join([str(self.v), str(self.v_min), + str(self.v_max), str(self.v_steps)]) + + def __repr__(self): + """ + A string representing the interval in class constructor form. + """ + return "PlotInterval(%s)" % (self._str_base()) + + def __str__(self): + """ + A string representing the interval in list form. + """ + return "[%s]" % (self._str_base()) + + @require_all_args + def assert_complete(self): + pass + + @require_all_args + def vrange(self): + """ + Yields v_steps+1 SymPy numbers ranging from + v_min to v_max. + """ + d = (self.v_max - self.v_min) / self.v_steps + for i in range(self.v_steps + 1): + a = self.v_min + (d * Integer(i)) + yield a + + @require_all_args + def vrange2(self): + """ + Yields v_steps pairs of SymPy numbers ranging from + (v_min, v_min + step) to (v_max - step, v_max). + """ + d = (self.v_max - self.v_min) / self.v_steps + a = self.v_min + (d * S.Zero) + for i in range(self.v_steps): + b = self.v_min + (d * Integer(i + 1)) + yield a, b + a = b + + def frange(self): + for i in self.vrange(): + yield float(i.evalf()) diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_mode.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ee00db9177b98b3259438949836fe5b69416c2 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_mode.py @@ -0,0 +1,400 @@ +from .plot_interval import PlotInterval +from .plot_object import PlotObject +from .util import parse_option_string +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.geometry.entity import GeometryEntity +from sympy.utilities.iterables import is_sequence + + +class PlotMode(PlotObject): + """ + Grandparent class for plotting + modes. Serves as interface for + registration, lookup, and init + of modes. + + To create a new plot mode, + inherit from PlotModeBase + or one of its children, such + as PlotSurface or PlotCurve. + """ + + ## Class-level attributes + ## used to register and lookup + ## plot modes. See PlotModeBase + ## for descriptions and usage. + + i_vars, d_vars = '', '' + intervals = [] + aliases = [] + is_default = False + + ## Draw is the only method here which + ## is meant to be overridden in child + ## classes, and PlotModeBase provides + ## a base implementation. + def draw(self): + raise NotImplementedError() + + ## Everything else in this file has to + ## do with registration and retrieval + ## of plot modes. This is where I've + ## hidden much of the ugliness of automatic + ## plot mode divination... + + ## Plot mode registry data structures + _mode_alias_list = [] + _mode_map = { + 1: {1: {}, 2: {}}, + 2: {1: {}, 2: {}}, + 3: {1: {}, 2: {}}, + } # [d][i][alias_str]: class + _mode_default_map = { + 1: {}, + 2: {}, + 3: {}, + } # [d][i]: class + _i_var_max, _d_var_max = 2, 3 + + def __new__(cls, *args, **kwargs): + """ + This is the function which interprets + arguments given to Plot.__init__ and + Plot.__setattr__. Returns an initialized + instance of the appropriate child class. + """ + + newargs, newkwargs = PlotMode._extract_options(args, kwargs) + mode_arg = newkwargs.get('mode', '') + + # Interpret the arguments + d_vars, intervals = PlotMode._interpret_args(newargs) + i_vars = PlotMode._find_i_vars(d_vars, intervals) + i, d = max([len(i_vars), len(intervals)]), len(d_vars) + + # Find the appropriate mode + subcls = PlotMode._get_mode(mode_arg, i, d) + + # Create the object + o = object.__new__(subcls) + + # Do some setup for the mode instance + o.d_vars = d_vars + o._fill_i_vars(i_vars) + o._fill_intervals(intervals) + o.options = newkwargs + + return o + + @staticmethod + def _get_mode(mode_arg, i_var_count, d_var_count): + """ + Tries to return an appropriate mode class. + Intended to be called only by __new__. + + mode_arg + Can be a string or a class. If it is a + PlotMode subclass, it is simply returned. + If it is a string, it can an alias for + a mode or an empty string. In the latter + case, we try to find a default mode for + the i_var_count and d_var_count. + + i_var_count + The number of independent variables + needed to evaluate the d_vars. + + d_var_count + The number of dependent variables; + usually the number of functions to + be evaluated in plotting. + + For example, a Cartesian function y = f(x) has + one i_var (x) and one d_var (y). A parametric + form x,y,z = f(u,v), f(u,v), f(u,v) has two + two i_vars (u,v) and three d_vars (x,y,z). + """ + # if the mode_arg is simply a PlotMode class, + # check that the mode supports the numbers + # of independent and dependent vars, then + # return it + try: + m = None + if issubclass(mode_arg, PlotMode): + m = mode_arg + except TypeError: + pass + if m: + if not m._was_initialized: + raise ValueError(("To use unregistered plot mode %s " + "you must first call %s._init_mode().") + % (m.__name__, m.__name__)) + if d_var_count != m.d_var_count: + raise ValueError(("%s can only plot functions " + "with %i dependent variables.") + % (m.__name__, + m.d_var_count)) + if i_var_count > m.i_var_count: + raise ValueError(("%s cannot plot functions " + "with more than %i independent " + "variables.") + % (m.__name__, + m.i_var_count)) + return m + # If it is a string, there are two possibilities. + if isinstance(mode_arg, str): + i, d = i_var_count, d_var_count + if i > PlotMode._i_var_max: + raise ValueError(var_count_error(True, True)) + if d > PlotMode._d_var_max: + raise ValueError(var_count_error(False, True)) + # If the string is '', try to find a suitable + # default mode + if not mode_arg: + return PlotMode._get_default_mode(i, d) + # Otherwise, interpret the string as a mode + # alias (e.g. 'cartesian', 'parametric', etc) + else: + return PlotMode._get_aliased_mode(mode_arg, i, d) + else: + raise ValueError("PlotMode argument must be " + "a class or a string") + + @staticmethod + def _get_default_mode(i, d, i_vars=-1): + if i_vars == -1: + i_vars = i + try: + return PlotMode._mode_default_map[d][i] + except KeyError: + # Keep looking for modes in higher i var counts + # which support the given d var count until we + # reach the max i_var count. + if i < PlotMode._i_var_max: + return PlotMode._get_default_mode(i + 1, d, i_vars) + else: + raise ValueError(("Couldn't find a default mode " + "for %i independent and %i " + "dependent variables.") % (i_vars, d)) + + @staticmethod + def _get_aliased_mode(alias, i, d, i_vars=-1): + if i_vars == -1: + i_vars = i + if alias not in PlotMode._mode_alias_list: + raise ValueError(("Couldn't find a mode called" + " %s. Known modes: %s.") + % (alias, ", ".join(PlotMode._mode_alias_list))) + try: + return PlotMode._mode_map[d][i][alias] + except TypeError: + # Keep looking for modes in higher i var counts + # which support the given d var count and alias + # until we reach the max i_var count. + if i < PlotMode._i_var_max: + return PlotMode._get_aliased_mode(alias, i + 1, d, i_vars) + else: + raise ValueError(("Couldn't find a %s mode " + "for %i independent and %i " + "dependent variables.") + % (alias, i_vars, d)) + + @classmethod + def _register(cls): + """ + Called once for each user-usable plot mode. + For Cartesian2D, it is invoked after the + class definition: Cartesian2D._register() + """ + name = cls.__name__ + cls._init_mode() + + try: + i, d = cls.i_var_count, cls.d_var_count + # Add the mode to _mode_map under all + # given aliases + for a in cls.aliases: + if a not in PlotMode._mode_alias_list: + # Also track valid aliases, so + # we can quickly know when given + # an invalid one in _get_mode. + PlotMode._mode_alias_list.append(a) + PlotMode._mode_map[d][i][a] = cls + if cls.is_default: + # If this mode was marked as the + # default for this d,i combination, + # also set that. + PlotMode._mode_default_map[d][i] = cls + + except Exception as e: + raise RuntimeError(("Failed to register " + "plot mode %s. Reason: %s") + % (name, (str(e)))) + + @classmethod + def _init_mode(cls): + """ + Initializes the plot mode based on + the 'mode-specific parameters' above. + Only intended to be called by + PlotMode._register(). To use a mode without + registering it, you can directly call + ModeSubclass._init_mode(). + """ + def symbols_list(symbol_str): + return [Symbol(s) for s in symbol_str] + + # Convert the vars strs into + # lists of symbols. + cls.i_vars = symbols_list(cls.i_vars) + cls.d_vars = symbols_list(cls.d_vars) + + # Var count is used often, calculate + # it once here + cls.i_var_count = len(cls.i_vars) + cls.d_var_count = len(cls.d_vars) + + if cls.i_var_count > PlotMode._i_var_max: + raise ValueError(var_count_error(True, False)) + if cls.d_var_count > PlotMode._d_var_max: + raise ValueError(var_count_error(False, False)) + + # Try to use first alias as primary_alias + if len(cls.aliases) > 0: + cls.primary_alias = cls.aliases[0] + else: + cls.primary_alias = cls.__name__ + + di = cls.intervals + if len(di) != cls.i_var_count: + raise ValueError("Plot mode must provide a " + "default interval for each i_var.") + for i in range(cls.i_var_count): + # default intervals must be given [min,max,steps] + # (no var, but they must be in the same order as i_vars) + if len(di[i]) != 3: + raise ValueError("length should be equal to 3") + + # Initialize an incomplete interval, + # to later be filled with a var when + # the mode is instantiated. + di[i] = PlotInterval(None, *di[i]) + + # To prevent people from using modes + # without these required fields set up. + cls._was_initialized = True + + _was_initialized = False + + ## Initializer Helper Methods + + @staticmethod + def _find_i_vars(functions, intervals): + i_vars = [] + + # First, collect i_vars in the + # order they are given in any + # intervals. + for i in intervals: + if i.v is None: + continue + elif i.v in i_vars: + raise ValueError(("Multiple intervals given " + "for %s.") % (str(i.v))) + i_vars.append(i.v) + + # Then, find any remaining + # i_vars in given functions + # (aka d_vars) + for f in functions: + for a in f.free_symbols: + if a not in i_vars: + i_vars.append(a) + + return i_vars + + def _fill_i_vars(self, i_vars): + # copy default i_vars + self.i_vars = [Symbol(str(i)) for i in self.i_vars] + # replace with given i_vars + for i in range(len(i_vars)): + self.i_vars[i] = i_vars[i] + + def _fill_intervals(self, intervals): + # copy default intervals + self.intervals = [PlotInterval(i) for i in self.intervals] + # track i_vars used so far + v_used = [] + # fill copy of default + # intervals with given info + for i in range(len(intervals)): + self.intervals[i].fill_from(intervals[i]) + if self.intervals[i].v is not None: + v_used.append(self.intervals[i].v) + # Find any orphan intervals and + # assign them i_vars + for i in range(len(self.intervals)): + if self.intervals[i].v is None: + u = [v for v in self.i_vars if v not in v_used] + if len(u) == 0: + raise ValueError("length should not be equal to 0") + self.intervals[i].v = u[0] + v_used.append(u[0]) + + @staticmethod + def _interpret_args(args): + interval_wrong_order = "PlotInterval %s was given before any function(s)." + interpret_error = "Could not interpret %s as a function or interval." + + functions, intervals = [], [] + if isinstance(args[0], GeometryEntity): + for coords in list(args[0].arbitrary_point()): + functions.append(coords) + intervals.append(PlotInterval.try_parse(args[0].plot_interval())) + else: + for a in args: + i = PlotInterval.try_parse(a) + if i is not None: + if len(functions) == 0: + raise ValueError(interval_wrong_order % (str(i))) + else: + intervals.append(i) + else: + if is_sequence(a, include=str): + raise ValueError(interpret_error % (str(a))) + try: + f = sympify(a) + functions.append(f) + except TypeError: + raise ValueError(interpret_error % str(a)) + + return functions, intervals + + @staticmethod + def _extract_options(args, kwargs): + newkwargs, newargs = {}, [] + for a in args: + if isinstance(a, str): + newkwargs = dict(newkwargs, **parse_option_string(a)) + else: + newargs.append(a) + newkwargs = dict(newkwargs, **kwargs) + return newargs, newkwargs + + +def var_count_error(is_independent, is_plotting): + """ + Used to format an error message which differs + slightly in 4 places. + """ + if is_plotting: + v = "Plotting" + else: + v = "Registering plot modes" + if is_independent: + n, s = PlotMode._i_var_max, "independent" + else: + n, s = PlotMode._d_var_max, "dependent" + return ("%s with more than %i %s variables " + "is not supported.") % (v, n, s) diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_mode_base.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_mode_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2c6503650afda122e271bdecb2365c8fa20f2376 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_mode_base.py @@ -0,0 +1,378 @@ +import pyglet.gl as pgl +from sympy.core import S +from sympy.plotting.pygletplot.color_scheme import ColorScheme +from sympy.plotting.pygletplot.plot_mode import PlotMode +from sympy.utilities.iterables import is_sequence +from time import sleep +from threading import Thread, Event, RLock +import warnings + + +class PlotModeBase(PlotMode): + """ + Intended parent class for plotting + modes. Provides base functionality + in conjunction with its parent, + PlotMode. + """ + + ## + ## Class-Level Attributes + ## + + """ + The following attributes are meant + to be set at the class level, and serve + as parameters to the plot mode registry + (in PlotMode). See plot_modes.py for + concrete examples. + """ + + """ + i_vars + 'x' for Cartesian2D + 'xy' for Cartesian3D + etc. + + d_vars + 'y' for Cartesian2D + 'r' for Polar + etc. + """ + i_vars, d_vars = '', '' + + """ + intervals + Default intervals for each i_var, and in the + same order. Specified [min, max, steps]. + No variable can be given (it is bound later). + """ + intervals = [] + + """ + aliases + A list of strings which can be used to + access this mode. + 'cartesian' for Cartesian2D and Cartesian3D + 'polar' for Polar + 'cylindrical', 'polar' for Cylindrical + + Note that _init_mode chooses the first alias + in the list as the mode's primary_alias, which + will be displayed to the end user in certain + contexts. + """ + aliases = [] + + """ + is_default + Whether to set this mode as the default + for arguments passed to PlotMode() containing + the same number of d_vars as this mode and + at most the same number of i_vars. + """ + is_default = False + + """ + All of the above attributes are defined in PlotMode. + The following ones are specific to PlotModeBase. + """ + + """ + A list of the render styles. Do not modify. + """ + styles = {'wireframe': 1, 'solid': 2, 'both': 3} + + """ + style_override + Always use this style if not blank. + """ + style_override = '' + + """ + default_wireframe_color + default_solid_color + Can be used when color is None or being calculated. + Used by PlotCurve and PlotSurface, but not anywhere + in PlotModeBase. + """ + + default_wireframe_color = (0.85, 0.85, 0.85) + default_solid_color = (0.6, 0.6, 0.9) + default_rot_preset = 'xy' + + ## + ## Instance-Level Attributes + ## + + ## 'Abstract' member functions + def _get_evaluator(self): + if self.use_lambda_eval: + try: + e = self._get_lambda_evaluator() + return e + except Exception: + warnings.warn("\nWarning: creating lambda evaluator failed. " + "Falling back on SymPy subs evaluator.") + return self._get_sympy_evaluator() + + def _get_sympy_evaluator(self): + raise NotImplementedError() + + def _get_lambda_evaluator(self): + raise NotImplementedError() + + def _on_calculate_verts(self): + raise NotImplementedError() + + def _on_calculate_cverts(self): + raise NotImplementedError() + + ## Base member functions + def __init__(self, *args, bounds_callback=None, **kwargs): + self.verts = [] + self.cverts = [] + self.bounds = [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + self.cbounds = [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + + self._draw_lock = RLock() + + self._calculating_verts = Event() + self._calculating_cverts = Event() + self._calculating_verts_pos = 0.0 + self._calculating_verts_len = 0.0 + self._calculating_cverts_pos = 0.0 + self._calculating_cverts_len = 0.0 + + self._max_render_stack_size = 3 + self._draw_wireframe = [-1] + self._draw_solid = [-1] + + self._style = None + self._color = None + + self.predraw = [] + self.postdraw = [] + + self.use_lambda_eval = self.options.pop('use_sympy_eval', None) is None + self.style = self.options.pop('style', '') + self.color = self.options.pop('color', 'rainbow') + self.bounds_callback = bounds_callback + + self._on_calculate() + + def synchronized(f): + def w(self, *args, **kwargs): + self._draw_lock.acquire() + try: + r = f(self, *args, **kwargs) + return r + finally: + self._draw_lock.release() + return w + + @synchronized + def push_wireframe(self, function): + """ + Push a function which performs gl commands + used to build a display list. (The list is + built outside of the function) + """ + assert callable(function) + self._draw_wireframe.append(function) + if len(self._draw_wireframe) > self._max_render_stack_size: + del self._draw_wireframe[1] # leave marker element + + @synchronized + def push_solid(self, function): + """ + Push a function which performs gl commands + used to build a display list. (The list is + built outside of the function) + """ + assert callable(function) + self._draw_solid.append(function) + if len(self._draw_solid) > self._max_render_stack_size: + del self._draw_solid[1] # leave marker element + + def _create_display_list(self, function): + dl = pgl.glGenLists(1) + pgl.glNewList(dl, pgl.GL_COMPILE) + function() + pgl.glEndList() + return dl + + def _render_stack_top(self, render_stack): + top = render_stack[-1] + if top == -1: + return -1 # nothing to display + elif callable(top): + dl = self._create_display_list(top) + render_stack[-1] = (dl, top) + return dl # display newly added list + elif len(top) == 2: + if pgl.GL_TRUE == pgl.glIsList(top[0]): + return top[0] # display stored list + dl = self._create_display_list(top[1]) + render_stack[-1] = (dl, top[1]) + return dl # display regenerated list + + def _draw_solid_display_list(self, dl): + pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT) + pgl.glPolygonMode(pgl.GL_FRONT_AND_BACK, pgl.GL_FILL) + pgl.glCallList(dl) + pgl.glPopAttrib() + + def _draw_wireframe_display_list(self, dl): + pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT) + pgl.glPolygonMode(pgl.GL_FRONT_AND_BACK, pgl.GL_LINE) + pgl.glEnable(pgl.GL_POLYGON_OFFSET_LINE) + pgl.glPolygonOffset(-0.005, -50.0) + pgl.glCallList(dl) + pgl.glPopAttrib() + + @synchronized + def draw(self): + for f in self.predraw: + if callable(f): + f() + if self.style_override: + style = self.styles[self.style_override] + else: + style = self.styles[self._style] + # Draw solid component if style includes solid + if style & 2: + dl = self._render_stack_top(self._draw_solid) + if dl > 0 and pgl.GL_TRUE == pgl.glIsList(dl): + self._draw_solid_display_list(dl) + # Draw wireframe component if style includes wireframe + if style & 1: + dl = self._render_stack_top(self._draw_wireframe) + if dl > 0 and pgl.GL_TRUE == pgl.glIsList(dl): + self._draw_wireframe_display_list(dl) + for f in self.postdraw: + if callable(f): + f() + + def _on_change_color(self, color): + Thread(target=self._calculate_cverts).start() + + def _on_calculate(self): + Thread(target=self._calculate_all).start() + + def _calculate_all(self): + self._calculate_verts() + self._calculate_cverts() + + def _calculate_verts(self): + if self._calculating_verts.is_set(): + return + self._calculating_verts.set() + try: + self._on_calculate_verts() + finally: + self._calculating_verts.clear() + if callable(self.bounds_callback): + self.bounds_callback() + + def _calculate_cverts(self): + if self._calculating_verts.is_set(): + return + while self._calculating_cverts.is_set(): + sleep(0) # wait for previous calculation + self._calculating_cverts.set() + try: + self._on_calculate_cverts() + finally: + self._calculating_cverts.clear() + + def _get_calculating_verts(self): + return self._calculating_verts.is_set() + + def _get_calculating_verts_pos(self): + return self._calculating_verts_pos + + def _get_calculating_verts_len(self): + return self._calculating_verts_len + + def _get_calculating_cverts(self): + return self._calculating_cverts.is_set() + + def _get_calculating_cverts_pos(self): + return self._calculating_cverts_pos + + def _get_calculating_cverts_len(self): + return self._calculating_cverts_len + + ## Property handlers + def _get_style(self): + return self._style + + @synchronized + def _set_style(self, v): + if v is None: + return + if v == '': + step_max = 0 + for i in self.intervals: + if i.v_steps is None: + continue + step_max = max([step_max, int(i.v_steps)]) + v = ['both', 'solid'][step_max > 40] + if v not in self.styles: + raise ValueError("v should be there in self.styles") + if v == self._style: + return + self._style = v + + def _get_color(self): + return self._color + + @synchronized + def _set_color(self, v): + try: + if v is not None: + if is_sequence(v): + v = ColorScheme(*v) + else: + v = ColorScheme(v) + if repr(v) == repr(self._color): + return + self._on_change_color(v) + self._color = v + except Exception as e: + raise RuntimeError("Color change failed. " + "Reason: %s" % (str(e))) + + style = property(_get_style, _set_style) + color = property(_get_color, _set_color) + + calculating_verts = property(_get_calculating_verts) + calculating_verts_pos = property(_get_calculating_verts_pos) + calculating_verts_len = property(_get_calculating_verts_len) + + calculating_cverts = property(_get_calculating_cverts) + calculating_cverts_pos = property(_get_calculating_cverts_pos) + calculating_cverts_len = property(_get_calculating_cverts_len) + + ## String representations + + def __str__(self): + f = ", ".join(str(d) for d in self.d_vars) + o = "'mode=%s'" % (self.primary_alias) + return ", ".join([f, o]) + + def __repr__(self): + f = ", ".join(str(d) for d in self.d_vars) + i = ", ".join(str(i) for i in self.intervals) + d = [('mode', self.primary_alias), + ('color', str(self.color)), + ('style', str(self.style))] + + o = "'%s'" % ("; ".join("%s=%s" % (k, v) + for k, v in d if v != 'None')) + return ", ".join([f, i, o]) diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_modes.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_modes.py new file mode 100644 index 0000000000000000000000000000000000000000..e78e0b4ce291b071f684fa3ffc02f456dffe0023 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_modes.py @@ -0,0 +1,209 @@ +from sympy.utilities.lambdify import lambdify +from sympy.core.numbers import pi +from sympy.functions import sin, cos +from sympy.plotting.pygletplot.plot_curve import PlotCurve +from sympy.plotting.pygletplot.plot_surface import PlotSurface + +from math import sin as p_sin +from math import cos as p_cos + + +def float_vec3(f): + def inner(*args): + v = f(*args) + return float(v[0]), float(v[1]), float(v[2]) + return inner + + +class Cartesian2D(PlotCurve): + i_vars, d_vars = 'x', 'y' + intervals = [[-5, 5, 100]] + aliases = ['cartesian'] + is_default = True + + def _get_sympy_evaluator(self): + fy = self.d_vars[0] + x = self.t_interval.v + + @float_vec3 + def e(_x): + return (_x, fy.subs(x, _x), 0.0) + return e + + def _get_lambda_evaluator(self): + fy = self.d_vars[0] + x = self.t_interval.v + return lambdify([x], [x, fy, 0.0]) + + +class Cartesian3D(PlotSurface): + i_vars, d_vars = 'xy', 'z' + intervals = [[-1, 1, 40], [-1, 1, 40]] + aliases = ['cartesian', 'monge'] + is_default = True + + def _get_sympy_evaluator(self): + fz = self.d_vars[0] + x = self.u_interval.v + y = self.v_interval.v + + @float_vec3 + def e(_x, _y): + return (_x, _y, fz.subs(x, _x).subs(y, _y)) + return e + + def _get_lambda_evaluator(self): + fz = self.d_vars[0] + x = self.u_interval.v + y = self.v_interval.v + return lambdify([x, y], [x, y, fz]) + + +class ParametricCurve2D(PlotCurve): + i_vars, d_vars = 't', 'xy' + intervals = [[0, 2*pi, 100]] + aliases = ['parametric'] + is_default = True + + def _get_sympy_evaluator(self): + fx, fy = self.d_vars + t = self.t_interval.v + + @float_vec3 + def e(_t): + return (fx.subs(t, _t), fy.subs(t, _t), 0.0) + return e + + def _get_lambda_evaluator(self): + fx, fy = self.d_vars + t = self.t_interval.v + return lambdify([t], [fx, fy, 0.0]) + + +class ParametricCurve3D(PlotCurve): + i_vars, d_vars = 't', 'xyz' + intervals = [[0, 2*pi, 100]] + aliases = ['parametric'] + is_default = True + + def _get_sympy_evaluator(self): + fx, fy, fz = self.d_vars + t = self.t_interval.v + + @float_vec3 + def e(_t): + return (fx.subs(t, _t), fy.subs(t, _t), fz.subs(t, _t)) + return e + + def _get_lambda_evaluator(self): + fx, fy, fz = self.d_vars + t = self.t_interval.v + return lambdify([t], [fx, fy, fz]) + + +class ParametricSurface(PlotSurface): + i_vars, d_vars = 'uv', 'xyz' + intervals = [[-1, 1, 40], [-1, 1, 40]] + aliases = ['parametric'] + is_default = True + + def _get_sympy_evaluator(self): + fx, fy, fz = self.d_vars + u = self.u_interval.v + v = self.v_interval.v + + @float_vec3 + def e(_u, _v): + return (fx.subs(u, _u).subs(v, _v), + fy.subs(u, _u).subs(v, _v), + fz.subs(u, _u).subs(v, _v)) + return e + + def _get_lambda_evaluator(self): + fx, fy, fz = self.d_vars + u = self.u_interval.v + v = self.v_interval.v + return lambdify([u, v], [fx, fy, fz]) + + +class Polar(PlotCurve): + i_vars, d_vars = 't', 'r' + intervals = [[0, 2*pi, 100]] + aliases = ['polar'] + is_default = False + + def _get_sympy_evaluator(self): + fr = self.d_vars[0] + t = self.t_interval.v + + def e(_t): + _r = float(fr.subs(t, _t)) + return (_r*p_cos(_t), _r*p_sin(_t), 0.0) + return e + + def _get_lambda_evaluator(self): + fr = self.d_vars[0] + t = self.t_interval.v + fx, fy = fr*cos(t), fr*sin(t) + return lambdify([t], [fx, fy, 0.0]) + + +class Cylindrical(PlotSurface): + i_vars, d_vars = 'th', 'r' + intervals = [[0, 2*pi, 40], [-1, 1, 20]] + aliases = ['cylindrical', 'polar'] + is_default = False + + def _get_sympy_evaluator(self): + fr = self.d_vars[0] + t = self.u_interval.v + h = self.v_interval.v + + def e(_t, _h): + _r = float(fr.subs(t, _t).subs(h, _h)) + return (_r*p_cos(_t), _r*p_sin(_t), _h) + return e + + def _get_lambda_evaluator(self): + fr = self.d_vars[0] + t = self.u_interval.v + h = self.v_interval.v + fx, fy = fr*cos(t), fr*sin(t) + return lambdify([t, h], [fx, fy, h]) + + +class Spherical(PlotSurface): + i_vars, d_vars = 'tp', 'r' + intervals = [[0, 2*pi, 40], [0, pi, 20]] + aliases = ['spherical'] + is_default = False + + def _get_sympy_evaluator(self): + fr = self.d_vars[0] + t = self.u_interval.v + p = self.v_interval.v + + def e(_t, _p): + _r = float(fr.subs(t, _t).subs(p, _p)) + return (_r*p_cos(_t)*p_sin(_p), + _r*p_sin(_t)*p_sin(_p), + _r*p_cos(_p)) + return e + + def _get_lambda_evaluator(self): + fr = self.d_vars[0] + t = self.u_interval.v + p = self.v_interval.v + fx = fr * cos(t) * sin(p) + fy = fr * sin(t) * sin(p) + fz = fr * cos(p) + return lambdify([t, p], [fx, fy, fz]) + +Cartesian2D._register() +Cartesian3D._register() +ParametricCurve2D._register() +ParametricCurve3D._register() +ParametricSurface._register() +Polar._register() +Cylindrical._register() +Spherical._register() diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_object.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_object.py new file mode 100644 index 0000000000000000000000000000000000000000..e51040fb8b1a52c49d849b96692f6c0dba329d75 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_object.py @@ -0,0 +1,17 @@ +class PlotObject: + """ + Base class for objects which can be displayed in + a Plot. + """ + visible = True + + def _draw(self): + if self.visible: + self.draw() + + def draw(self): + """ + OpenGL rendering code for the plot object. + Override in base class. + """ + pass diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_rotation.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..11ede2d1c3e74e5470cf601348e494c35720b9a8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_rotation.py @@ -0,0 +1,68 @@ +try: + from ctypes import c_float +except ImportError: + pass + +import pyglet.gl as pgl +from math import sqrt as _sqrt, acos as _acos, pi + + +def cross(a, b): + return (a[1] * b[2] - a[2] * b[1], + a[2] * b[0] - a[0] * b[2], + a[0] * b[1] - a[1] * b[0]) + + +def dot(a, b): + return a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + + +def mag(a): + return _sqrt(a[0]**2 + a[1]**2 + a[2]**2) + + +def norm(a): + m = mag(a) + return (a[0] / m, a[1] / m, a[2] / m) + + +def get_sphere_mapping(x, y, width, height): + x = min([max([x, 0]), width]) + y = min([max([y, 0]), height]) + + sr = _sqrt((width/2)**2 + (height/2)**2) + sx = ((x - width / 2) / sr) + sy = ((y - height / 2) / sr) + + sz = 1.0 - sx**2 - sy**2 + + if sz > 0.0: + sz = _sqrt(sz) + return (sx, sy, sz) + else: + sz = 0 + return norm((sx, sy, sz)) + +rad2deg = 180.0 / pi + + +def get_spherical_rotatation(p1, p2, width, height, theta_multiplier): + v1 = get_sphere_mapping(p1[0], p1[1], width, height) + v2 = get_sphere_mapping(p2[0], p2[1], width, height) + + d = min(max([dot(v1, v2), -1]), 1) + + if abs(d - 1.0) < 0.000001: + return None + + raxis = norm( cross(v1, v2) ) + rtheta = theta_multiplier * rad2deg * _acos(d) + + pgl.glPushMatrix() + pgl.glLoadIdentity() + pgl.glRotatef(rtheta, *raxis) + mat = (c_float*16)() + pgl.glGetFloatv(pgl.GL_MODELVIEW_MATRIX, mat) + pgl.glPopMatrix() + + return mat diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_surface.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_surface.py new file mode 100644 index 0000000000000000000000000000000000000000..ed421eebb441d193f4d9b763f56e146c11e5a42c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_surface.py @@ -0,0 +1,102 @@ +import pyglet.gl as pgl + +from sympy.core import S +from sympy.plotting.pygletplot.plot_mode_base import PlotModeBase + + +class PlotSurface(PlotModeBase): + + default_rot_preset = 'perspective' + + def _on_calculate_verts(self): + self.u_interval = self.intervals[0] + self.u_set = list(self.u_interval.frange()) + self.v_interval = self.intervals[1] + self.v_set = list(self.v_interval.frange()) + self.bounds = [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + evaluate = self._get_evaluator() + + self._calculating_verts_pos = 0.0 + self._calculating_verts_len = float( + self.u_interval.v_len*self.v_interval.v_len) + + verts = [] + b = self.bounds + for u in self.u_set: + column = [] + for v in self.v_set: + try: + _e = evaluate(u, v) # calculate vertex + except ZeroDivisionError: + _e = None + if _e is not None: # update bounding box + for axis in range(3): + b[axis][0] = min([b[axis][0], _e[axis]]) + b[axis][1] = max([b[axis][1], _e[axis]]) + column.append(_e) + self._calculating_verts_pos += 1.0 + + verts.append(column) + for axis in range(3): + b[axis][2] = b[axis][1] - b[axis][0] + if b[axis][2] == 0.0: + b[axis][2] = 1.0 + + self.verts = verts + self.push_wireframe(self.draw_verts(False, False)) + self.push_solid(self.draw_verts(False, True)) + + def _on_calculate_cverts(self): + if not self.verts or not self.color: + return + + def set_work_len(n): + self._calculating_cverts_len = float(n) + + def inc_work_pos(): + self._calculating_cverts_pos += 1.0 + set_work_len(1) + self._calculating_cverts_pos = 0 + self.cverts = self.color.apply_to_surface(self.verts, + self.u_set, + self.v_set, + set_len=set_work_len, + inc_pos=inc_work_pos) + self.push_solid(self.draw_verts(True, True)) + + def calculate_one_cvert(self, u, v): + vert = self.verts[u][v] + return self.color(vert[0], vert[1], vert[2], + self.u_set[u], self.v_set[v]) + + def draw_verts(self, use_cverts, use_solid_color): + def f(): + for u in range(1, len(self.u_set)): + pgl.glBegin(pgl.GL_QUAD_STRIP) + for v in range(len(self.v_set)): + pa = self.verts[u - 1][v] + pb = self.verts[u][v] + if pa is None or pb is None: + pgl.glEnd() + pgl.glBegin(pgl.GL_QUAD_STRIP) + continue + if use_cverts: + ca = self.cverts[u - 1][v] + cb = self.cverts[u][v] + if ca is None: + ca = (0, 0, 0) + if cb is None: + cb = (0, 0, 0) + else: + if use_solid_color: + ca = cb = self.default_solid_color + else: + ca = cb = self.default_wireframe_color + pgl.glColor3f(*ca) + pgl.glVertex3f(*pa) + pgl.glColor3f(*cb) + pgl.glVertex3f(*pb) + pgl.glEnd() + return f diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_window.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_window.py new file mode 100644 index 0000000000000000000000000000000000000000..d9df4cc453acb05d7c2d871e9e8efeb36905de5d --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/plot_window.py @@ -0,0 +1,144 @@ +from time import perf_counter + + +import pyglet.gl as pgl + +from sympy.plotting.pygletplot.managed_window import ManagedWindow +from sympy.plotting.pygletplot.plot_camera import PlotCamera +from sympy.plotting.pygletplot.plot_controller import PlotController + + +class PlotWindow(ManagedWindow): + + def __init__(self, plot, antialiasing=True, ortho=False, + invert_mouse_zoom=False, linewidth=1.5, caption="SymPy Plot", + **kwargs): + """ + Named Arguments + =============== + + antialiasing = True + True OR False + ortho = False + True OR False + invert_mouse_zoom = False + True OR False + """ + self.plot = plot + + self.camera = None + self._calculating = False + + self.antialiasing = antialiasing + self.ortho = ortho + self.invert_mouse_zoom = invert_mouse_zoom + self.linewidth = linewidth + self.title = caption + self.last_caption_update = 0 + self.caption_update_interval = 0.2 + self.drawing_first_object = True + + super().__init__(**kwargs) + + def setup(self): + self.camera = PlotCamera(self, ortho=self.ortho) + self.controller = PlotController(self, + invert_mouse_zoom=self.invert_mouse_zoom) + self.push_handlers(self.controller) + + pgl.glClearColor(1.0, 1.0, 1.0, 0.0) + pgl.glClearDepth(1.0) + + pgl.glDepthFunc(pgl.GL_LESS) + pgl.glEnable(pgl.GL_DEPTH_TEST) + + pgl.glEnable(pgl.GL_LINE_SMOOTH) + pgl.glShadeModel(pgl.GL_SMOOTH) + pgl.glLineWidth(self.linewidth) + + pgl.glEnable(pgl.GL_BLEND) + pgl.glBlendFunc(pgl.GL_SRC_ALPHA, pgl.GL_ONE_MINUS_SRC_ALPHA) + + if self.antialiasing: + pgl.glHint(pgl.GL_LINE_SMOOTH_HINT, pgl.GL_NICEST) + pgl.glHint(pgl.GL_POLYGON_SMOOTH_HINT, pgl.GL_NICEST) + + self.camera.setup_projection() + + def on_resize(self, w, h): + super().on_resize(w, h) + if self.camera is not None: + self.camera.setup_projection() + + def update(self, dt): + self.controller.update(dt) + + def draw(self): + self.plot._render_lock.acquire() + self.camera.apply_transformation() + + calc_verts_pos, calc_verts_len = 0, 0 + calc_cverts_pos, calc_cverts_len = 0, 0 + + should_update_caption = (perf_counter() - self.last_caption_update > + self.caption_update_interval) + + if len(self.plot._functions.values()) == 0: + self.drawing_first_object = True + + iterfunctions = iter(self.plot._functions.values()) + + for r in iterfunctions: + if self.drawing_first_object: + self.camera.set_rot_preset(r.default_rot_preset) + self.drawing_first_object = False + + pgl.glPushMatrix() + r._draw() + pgl.glPopMatrix() + + # might as well do this while we are + # iterating and have the lock rather + # than locking and iterating twice + # per frame: + + if should_update_caption: + try: + if r.calculating_verts: + calc_verts_pos += r.calculating_verts_pos + calc_verts_len += r.calculating_verts_len + if r.calculating_cverts: + calc_cverts_pos += r.calculating_cverts_pos + calc_cverts_len += r.calculating_cverts_len + except ValueError: + pass + + for r in self.plot._pobjects: + pgl.glPushMatrix() + r._draw() + pgl.glPopMatrix() + + if should_update_caption: + self.update_caption(calc_verts_pos, calc_verts_len, + calc_cverts_pos, calc_cverts_len) + self.last_caption_update = perf_counter() + + if self.plot._screenshot: + self.plot._screenshot._execute_saving() + + self.plot._render_lock.release() + + def update_caption(self, calc_verts_pos, calc_verts_len, + calc_cverts_pos, calc_cverts_len): + caption = self.title + if calc_verts_len or calc_cverts_len: + caption += " (calculating" + if calc_verts_len > 0: + p = (calc_verts_pos / calc_verts_len) * 100 + caption += " vertices %i%%" % (p) + if calc_cverts_len > 0: + p = (calc_cverts_pos / calc_cverts_len) * 100 + caption += " colors %i%%" % (p) + caption += ")" + if self.caption != caption: + self.set_caption(caption) diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/tests/__init__.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f1a28c8ca0d38654db931af7e238997d4729425 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/tests/__pycache__/test_plotting.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/tests/__pycache__/test_plotting.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d2abeb149bdb0f7785189ec40afc351a8f7c948 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/tests/__pycache__/test_plotting.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/tests/test_plotting.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/tests/test_plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc4aaf3621a8c9056ce0d81c89ca6a0a681bbdb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/tests/test_plotting.py @@ -0,0 +1,88 @@ +from sympy.external.importtools import import_module + +disabled = False + +# if pyglet.gl fails to import, e.g. opengl is missing, we disable the tests +pyglet_gl = import_module("pyglet.gl", catch=(OSError,)) +pyglet_window = import_module("pyglet.window", catch=(OSError,)) +if not pyglet_gl or not pyglet_window: + disabled = True + + +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.trigonometric import (cos, sin) +x, y, z = symbols('x, y, z') + + +def test_plot_2d(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(x, [x, -5, 5, 4], visible=False) + p.wait_for_calculations() + + +def test_plot_2d_discontinuous(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(1/x, [x, -1, 1, 2], visible=False) + p.wait_for_calculations() + + +def test_plot_3d(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(x*y, [x, -5, 5, 5], [y, -5, 5, 5], visible=False) + p.wait_for_calculations() + + +def test_plot_3d_discontinuous(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(1/x, [x, -3, 3, 6], [y, -1, 1, 1], visible=False) + p.wait_for_calculations() + + +def test_plot_2d_polar(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(1/x, [x, -1, 1, 4], 'mode=polar', visible=False) + p.wait_for_calculations() + + +def test_plot_3d_cylinder(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot( + 1/y, [x, 0, 6.282, 4], [y, -1, 1, 4], 'mode=polar;style=solid', + visible=False) + p.wait_for_calculations() + + +def test_plot_3d_spherical(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot( + 1, [x, 0, 6.282, 4], [y, 0, 3.141, + 4], 'mode=spherical;style=wireframe', + visible=False) + p.wait_for_calculations() + + +def test_plot_2d_parametric(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(sin(x), cos(x), [x, 0, 6.282, 4], visible=False) + p.wait_for_calculations() + + +def test_plot_3d_parametric(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(sin(x), cos(x), x/5.0, [x, 0, 6.282, 4], visible=False) + p.wait_for_calculations() + + +def _test_plot_log(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(log(x), [x, 0, 6.282, 4], 'mode=polar', visible=False) + p.wait_for_calculations() + + +def test_plot_integral(): + # Make sure it doesn't treat x as an independent variable + from sympy.plotting.pygletplot import PygletPlot + from sympy.integrals.integrals import Integral + p = PygletPlot(Integral(z*x, (x, 1, z), (z, 1, y)), visible=False) + p.wait_for_calculations() diff --git a/phivenv/Lib/site-packages/sympy/plotting/pygletplot/util.py b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/util.py new file mode 100644 index 0000000000000000000000000000000000000000..43b882ca18274dcdb273cf35680016453db3c698 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/pygletplot/util.py @@ -0,0 +1,188 @@ +try: + from ctypes import c_float, c_int, c_double +except ImportError: + pass + +import pyglet.gl as pgl +from sympy.core import S + + +def get_model_matrix(array_type=c_float, glGetMethod=pgl.glGetFloatv): + """ + Returns the current modelview matrix. + """ + m = (array_type*16)() + glGetMethod(pgl.GL_MODELVIEW_MATRIX, m) + return m + + +def get_projection_matrix(array_type=c_float, glGetMethod=pgl.glGetFloatv): + """ + Returns the current modelview matrix. + """ + m = (array_type*16)() + glGetMethod(pgl.GL_PROJECTION_MATRIX, m) + return m + + +def get_viewport(): + """ + Returns the current viewport. + """ + m = (c_int*4)() + pgl.glGetIntegerv(pgl.GL_VIEWPORT, m) + return m + + +def get_direction_vectors(): + m = get_model_matrix() + return ((m[0], m[4], m[8]), + (m[1], m[5], m[9]), + (m[2], m[6], m[10])) + + +def get_view_direction_vectors(): + m = get_model_matrix() + return ((m[0], m[1], m[2]), + (m[4], m[5], m[6]), + (m[8], m[9], m[10])) + + +def get_basis_vectors(): + return ((1, 0, 0), (0, 1, 0), (0, 0, 1)) + + +def screen_to_model(x, y, z): + m = get_model_matrix(c_double, pgl.glGetDoublev) + p = get_projection_matrix(c_double, pgl.glGetDoublev) + w = get_viewport() + mx, my, mz = c_double(), c_double(), c_double() + pgl.gluUnProject(x, y, z, m, p, w, mx, my, mz) + return float(mx.value), float(my.value), float(mz.value) + + +def model_to_screen(x, y, z): + m = get_model_matrix(c_double, pgl.glGetDoublev) + p = get_projection_matrix(c_double, pgl.glGetDoublev) + w = get_viewport() + mx, my, mz = c_double(), c_double(), c_double() + pgl.gluProject(x, y, z, m, p, w, mx, my, mz) + return float(mx.value), float(my.value), float(mz.value) + + +def vec_subs(a, b): + return tuple(a[i] - b[i] for i in range(len(a))) + + +def billboard_matrix(): + """ + Removes rotational components of + current matrix so that primitives + are always drawn facing the viewer. + + |1|0|0|x| + |0|1|0|x| + |0|0|1|x| (x means left unchanged) + |x|x|x|x| + """ + m = get_model_matrix() + # XXX: for i in range(11): m[i] = i ? + m[0] = 1 + m[1] = 0 + m[2] = 0 + m[4] = 0 + m[5] = 1 + m[6] = 0 + m[8] = 0 + m[9] = 0 + m[10] = 1 + pgl.glLoadMatrixf(m) + + +def create_bounds(): + return [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + + +def update_bounds(b, v): + if v is None: + return + for axis in range(3): + b[axis][0] = min([b[axis][0], v[axis]]) + b[axis][1] = max([b[axis][1], v[axis]]) + + +def interpolate(a_min, a_max, a_ratio): + return a_min + a_ratio * (a_max - a_min) + + +def rinterpolate(a_min, a_max, a_value): + a_range = a_max - a_min + if a_max == a_min: + a_range = 1.0 + return (a_value - a_min) / float(a_range) + + +def interpolate_color(color1, color2, ratio): + return tuple(interpolate(color1[i], color2[i], ratio) for i in range(3)) + + +def scale_value(v, v_min, v_len): + return (v - v_min) / v_len + + +def scale_value_list(flist): + v_min, v_max = min(flist), max(flist) + v_len = v_max - v_min + return [scale_value(f, v_min, v_len) for f in flist] + + +def strided_range(r_min, r_max, stride, max_steps=50): + o_min, o_max = r_min, r_max + if abs(r_min - r_max) < 0.001: + return [] + try: + range(int(r_min - r_max)) + except (TypeError, OverflowError): + return [] + if r_min > r_max: + raise ValueError("r_min cannot be greater than r_max") + r_min_s = (r_min % stride) + r_max_s = stride - (r_max % stride) + if abs(r_max_s - stride) < 0.001: + r_max_s = 0.0 + r_min -= r_min_s + r_max += r_max_s + r_steps = int((r_max - r_min)/stride) + if max_steps and r_steps > max_steps: + return strided_range(o_min, o_max, stride*2) + return [r_min] + [r_min + e*stride for e in range(1, r_steps + 1)] + [r_max] + + +def parse_option_string(s): + if not isinstance(s, str): + return None + options = {} + for token in s.split(';'): + pieces = token.split('=') + if len(pieces) == 1: + option, value = pieces[0], "" + elif len(pieces) == 2: + option, value = pieces + else: + raise ValueError("Plot option string '%s' is malformed." % (s)) + options[option.strip()] = value.strip() + return options + + +def dot_product(v1, v2): + return sum(v1[i]*v2[i] for i in range(3)) + + +def vec_sub(v1, v2): + return tuple(v1[i] - v2[i] for i in range(3)) + + +def vec_mag(v): + return sum(v[i]**2 for i in range(3))**(0.5) diff --git a/phivenv/Lib/site-packages/sympy/plotting/series.py b/phivenv/Lib/site-packages/sympy/plotting/series.py new file mode 100644 index 0000000000000000000000000000000000000000..ddd64116277668389fb8defc8289543667d2c9e8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/series.py @@ -0,0 +1,2591 @@ +### The base class for all series +from collections.abc import Callable +from sympy.calculus.util import continuous_domain +from sympy.concrete import Sum, Product +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import arity +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Symbol +from sympy.functions import atan2, zeta, frac, ceiling, floor, im +from sympy.core.relational import (Equality, GreaterThan, + LessThan, Relational, Ne) +from sympy.core.sympify import sympify +from sympy.external import import_module +from sympy.logic.boolalg import BooleanFunction +from sympy.plotting.utils import _get_free_symbols, extract_solution +from sympy.printing.latex import latex +from sympy.printing.pycode import PythonCodePrinter +from sympy.printing.precedence import precedence +from sympy.sets.sets import Set, Interval, Union +from sympy.simplify.simplify import nsimplify +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.lambdify import lambdify +from .intervalmath import interval +import warnings + + +class IntervalMathPrinter(PythonCodePrinter): + """A printer to be used inside `plot_implicit` when `adaptive=True`, + in which case the interval arithmetic module is going to be used, which + requires the following edits. + """ + def _print_And(self, expr): + PREC = precedence(expr) + return " & ".join(self.parenthesize(a, PREC) + for a in sorted(expr.args, key=default_sort_key)) + + def _print_Or(self, expr): + PREC = precedence(expr) + return " | ".join(self.parenthesize(a, PREC) + for a in sorted(expr.args, key=default_sort_key)) + + +def _uniform_eval(f1, f2, *args, modules=None, + force_real_eval=False, has_sum=False): + """ + Note: this is an experimental function, as such it is prone to changes. + Please, do not use it in your code. + """ + np = import_module('numpy') + + def wrapper_func(func, *args): + try: + return complex(func(*args)) + except (ZeroDivisionError, OverflowError): + return complex(np.nan, np.nan) + + # NOTE: np.vectorize is much slower than numpy vectorized operations. + # However, this modules must be able to evaluate functions also with + # mpmath or sympy. + wrapper_func = np.vectorize(wrapper_func, otypes=[complex]) + + def _eval_with_sympy(err=None): + if f2 is None: + msg = "Impossible to evaluate the provided numerical function" + if err is None: + msg += "." + else: + msg += "because the following exception was raised:\n" + "{}: {}".format(type(err).__name__, err) + raise RuntimeError(msg) + if err: + warnings.warn( + "The evaluation with %s failed.\n" % ( + "NumPy/SciPy" if not modules else modules) + + "{}: {}\n".format(type(err).__name__, err) + + "Trying to evaluate the expression with Sympy, but it might " + "be a slow operation." + ) + return wrapper_func(f2, *args) + + if modules == "sympy": + return _eval_with_sympy() + + try: + return wrapper_func(f1, *args) + except Exception as err: + return _eval_with_sympy(err) + + +def _adaptive_eval(f, x): + """Evaluate f(x) with an adaptive algorithm. Post-process the result. + If a symbolic expression is evaluated with SymPy, it might returns + another symbolic expression, containing additions, ... + Force evaluation to a float. + + Parameters + ========== + f : callable + x : float + """ + np = import_module('numpy') + + y = f(x) + if isinstance(y, Expr) and (not y.is_Number): + y = y.evalf() + y = complex(y) + if y.imag > 1e-08: + return np.nan + return y.real + + +def _get_wrapper_for_expr(ret): + wrapper = "%s" + if ret == "real": + wrapper = "re(%s)" + elif ret == "imag": + wrapper = "im(%s)" + elif ret == "abs": + wrapper = "abs(%s)" + elif ret == "arg": + wrapper = "arg(%s)" + return wrapper + + +class BaseSeries: + """Base class for the data objects containing stuff to be plotted. + + Notes + ===== + + The backend should check if it supports the data series that is given. + (e.g. TextBackend supports only LineOver1DRangeSeries). + It is the backend responsibility to know how to use the class of + data series that is given. + + Some data series classes are grouped (using a class attribute like is_2Dline) + according to the api they present (based only on convention). The backend is + not obliged to use that api (e.g. LineOver1DRangeSeries belongs to the + is_2Dline group and presents the get_points method, but the + TextBackend does not use the get_points method). + + BaseSeries + """ + + # Some flags follow. The rationale for using flags instead of checking base + # classes is that setting multiple flags is simpler than multiple + # inheritance. + + is_2Dline = False + # Some of the backends expect: + # - get_points returning 1D np.arrays list_x, list_y + # - get_color_array returning 1D np.array (done in Line2DBaseSeries) + # with the colors calculated at the points from get_points + + is_3Dline = False + # Some of the backends expect: + # - get_points returning 1D np.arrays list_x, list_y, list_y + # - get_color_array returning 1D np.array (done in Line2DBaseSeries) + # with the colors calculated at the points from get_points + + is_3Dsurface = False + # Some of the backends expect: + # - get_meshes returning mesh_x, mesh_y, mesh_z (2D np.arrays) + # - get_points an alias for get_meshes + + is_contour = False + # Some of the backends expect: + # - get_meshes returning mesh_x, mesh_y, mesh_z (2D np.arrays) + # - get_points an alias for get_meshes + + is_implicit = False + # Some of the backends expect: + # - get_meshes returning mesh_x (1D array), mesh_y(1D array, + # mesh_z (2D np.arrays) + # - get_points an alias for get_meshes + # Different from is_contour as the colormap in backend will be + # different + + is_interactive = False + # An interactive series can update its data. + + is_parametric = False + # The calculation of aesthetics expects: + # - get_parameter_points returning one or two np.arrays (1D or 2D) + # used for calculation aesthetics + + is_generic = False + # Represent generic user-provided numerical data + + is_vector = False + is_2Dvector = False + is_3Dvector = False + # Represents a 2D or 3D vector data series + + _N = 100 + # default number of discretization points for uniform sampling. Each + # subclass can set its number. + + def __init__(self, *args, **kwargs): + kwargs = _set_discretization_points(kwargs.copy(), type(self)) + # discretize the domain using only integer numbers + self.only_integers = kwargs.get("only_integers", False) + # represents the evaluation modules to be used by lambdify + self.modules = kwargs.get("modules", None) + # plot functions might create data series that might not be useful to + # be shown on the legend, for example wireframe lines on 3D plots. + self.show_in_legend = kwargs.get("show_in_legend", True) + # line and surface series can show data with a colormap, hence a + # colorbar is essential to understand the data. However, sometime it + # is useful to hide it on series-by-series base. The following keyword + # controls whether the series should show a colorbar or not. + self.colorbar = kwargs.get("colorbar", True) + # Some series might use a colormap as default coloring. Setting this + # attribute to False will inform the backends to use solid color. + self.use_cm = kwargs.get("use_cm", False) + # If True, the backend will attempt to render it on a polar-projection + # axis, or using a polar discretization if a 3D plot is requested + self.is_polar = kwargs.get("is_polar", kwargs.get("polar", False)) + # If True, the rendering will use points, not lines. + self.is_point = kwargs.get("is_point", kwargs.get("point", False)) + # some backend is able to render latex, other needs standard text + self._label = self._latex_label = "" + + self._ranges = [] + self._n = [ + int(kwargs.get("n1", self._N)), + int(kwargs.get("n2", self._N)), + int(kwargs.get("n3", self._N)) + ] + self._scales = [ + kwargs.get("xscale", "linear"), + kwargs.get("yscale", "linear"), + kwargs.get("zscale", "linear") + ] + + # enable interactive widget plots + self._params = kwargs.get("params", {}) + if not isinstance(self._params, dict): + raise TypeError("`params` must be a dictionary mapping symbols " + "to numeric values.") + if len(self._params) > 0: + self.is_interactive = True + + # contains keyword arguments that will be passed to the rendering + # function of the chosen plotting library + self.rendering_kw = kwargs.get("rendering_kw", {}) + + # numerical transformation functions to be applied to the output data: + # x, y, z (coordinates), p (parameter on parametric plots) + self._tx = kwargs.get("tx", None) + self._ty = kwargs.get("ty", None) + self._tz = kwargs.get("tz", None) + self._tp = kwargs.get("tp", None) + if not all(callable(t) or (t is None) for t in + [self._tx, self._ty, self._tz, self._tp]): + raise TypeError("`tx`, `ty`, `tz`, `tp` must be functions.") + + # list of numerical functions representing the expressions to evaluate + self._functions = [] + # signature for the numerical functions + self._signature = [] + # some expressions don't like to be evaluated over complex data. + # if that's the case, set this to True + self._force_real_eval = kwargs.get("force_real_eval", None) + # this attribute will eventually contain a dictionary with the + # discretized ranges + self._discretized_domain = None + # whether the series contains any interactive range, which is a range + # where the minimum and maximum values can be changed with an + # interactive widget + self._interactive_ranges = False + # NOTE: consider a generic summation, for example: + # s = Sum(cos(pi * x), (x, 1, y)) + # This gets lambdified to something: + # sum(cos(pi*x) for x in range(1, y+1)) + # Hence, y needs to be an integer, otherwise it raises: + # TypeError: 'complex' object cannot be interpreted as an integer + # This list will contains symbols that are upper bound to summations + # or products + self._needs_to_be_int = [] + # a color function will be responsible to set the line/surface color + # according to some logic. Each data series will et an appropriate + # default value. + self.color_func = None + # NOTE: color_func usually receives numerical functions that are going + # to be evaluated over the coordinates of the computed points (or the + # discretized meshes). + # However, if an expression is given to color_func, then it will be + # lambdified with symbols in self._signature, and it will be evaluated + # with the same data used to evaluate the plotted expression. + self._eval_color_func_with_signature = False + + def _block_lambda_functions(self, *exprs): + """Some data series can be used to plot numerical functions, others + cannot. Execute this method inside the `__init__` to prevent the + processing of numerical functions. + """ + if any(callable(e) for e in exprs): + raise TypeError(type(self).__name__ + " requires a symbolic " + "expression.") + + def _check_fs(self): + """ Checks if there are enough parameters and free symbols. + """ + exprs, ranges = self.expr, self.ranges + params, label = self.params, self.label + exprs = exprs if hasattr(exprs, "__iter__") else [exprs] + if any(callable(e) for e in exprs): + return + + # from the expression's free symbols, remove the ones used in + # the parameters and the ranges + fs = _get_free_symbols(exprs) + fs = fs.difference(params.keys()) + if ranges is not None: + fs = fs.difference([r[0] for r in ranges]) + + if len(fs) > 0: + raise ValueError( + "Incompatible expression and parameters.\n" + + "Expression: {}\n".format( + (exprs, ranges, label) if ranges is not None else (exprs, label)) + + "params: {}\n".format(params) + + "Specify what these symbols represent: {}\n".format(fs) + + "Are they ranges or parameters?" + ) + + # verify that all symbols are known (they either represent plotting + # ranges or parameters) + range_symbols = [r[0] for r in ranges] + for r in ranges: + fs = set().union(*[e.free_symbols for e in r[1:]]) + if any(t in fs for t in range_symbols): + # ranges can't depend on each other, for example this are + # not allowed: + # (x, 0, y), (y, 0, 3) + # (x, 0, y), (y, x + 2, 3) + raise ValueError("Range symbols can't be included into " + "minimum and maximum of a range. " + "Received range: %s" % str(r)) + if len(fs) > 0: + self._interactive_ranges = True + remaining_fs = fs.difference(params.keys()) + if len(remaining_fs) > 0: + raise ValueError( + "Unknown symbols found in plotting range: %s. " % (r,) + + "Are the following parameters? %s" % remaining_fs) + + def _create_lambda_func(self): + """Create the lambda functions to be used by the uniform meshing + strategy. + + Notes + ===== + The old sympy.plotting used experimental_lambdify. It created one + lambda function each time an evaluation was requested. If that failed, + it went on to create a different lambda function and evaluated it, + and so on. + + This new module changes strategy: it creates right away the default + lambda function as well as the backup one. The reason is that the + series could be interactive, hence the numerical function will be + evaluated multiple times. So, let's create the functions just once. + + This approach works fine for the majority of cases, in which the + symbolic expression is relatively short, hence the lambdification + is fast. If the expression is very long, this approach takes twice + the time to create the lambda functions. Be aware of that! + """ + exprs = self.expr if hasattr(self.expr, "__iter__") else [self.expr] + if not any(callable(e) for e in exprs): + fs = _get_free_symbols(exprs) + self._signature = sorted(fs, key=lambda t: t.name) + + # Generate a list of lambda functions, two for each expression: + # 1. the default one. + # 2. the backup one, in case of failures with the default one. + self._functions = [] + for e in exprs: + # TODO: set cse=True once this issue is solved: + # https://github.com/sympy/sympy/issues/24246 + self._functions.append([ + lambdify(self._signature, e, modules=self.modules), + lambdify(self._signature, e, modules="sympy", dummify=True), + ]) + else: + self._signature = sorted([r[0] for r in self.ranges], key=lambda t: t.name) + self._functions = [(e, None) for e in exprs] + + # deal with symbolic color_func + if isinstance(self.color_func, Expr): + self.color_func = lambdify(self._signature, self.color_func) + self._eval_color_func_with_signature = True + + def _update_range_value(self, t): + """If the value of a plotting range is a symbolic expression, + substitute the parameters in order to get a numerical value. + """ + if not self._interactive_ranges: + return complex(t) + return complex(t.subs(self.params)) + + def _create_discretized_domain(self): + """Discretize the ranges for uniform meshing strategy. + """ + # NOTE: the goal is to create a dictionary stored in + # self._discretized_domain, mapping symbols to a numpy array + # representing the discretization + discr_symbols = [] + discretizations = [] + + # create a 1D discretization + for i, r in enumerate(self.ranges): + discr_symbols.append(r[0]) + c_start = self._update_range_value(r[1]) + c_end = self._update_range_value(r[2]) + start = c_start.real if c_start.imag == c_end.imag == 0 else c_start + end = c_end.real if c_start.imag == c_end.imag == 0 else c_end + needs_integer_discr = self.only_integers or (r[0] in self._needs_to_be_int) + d = BaseSeries._discretize(start, end, self.n[i], + scale=self.scales[i], + only_integers=needs_integer_discr) + + if ((not self._force_real_eval) and (not needs_integer_discr) and + (d.dtype != "complex")): + d = d + 1j * c_start.imag + + if needs_integer_discr: + d = d.astype(int) + + discretizations.append(d) + + # create 2D or 3D + self._create_discretized_domain_helper(discr_symbols, discretizations) + + def _create_discretized_domain_helper(self, discr_symbols, discretizations): + """Create 2D or 3D discretized grids. + + Subclasses should override this method in order to implement a + different behaviour. + """ + np = import_module('numpy') + + # discretization suitable for 2D line plots, 3D surface plots, + # contours plots, vector plots + # NOTE: why indexing='ij'? Because it produces consistent results with + # np.mgrid. This is important as Mayavi requires this indexing + # to correctly compute 3D streamlines. While VTK is able to compute + # streamlines regardless of the indexing, with indexing='xy' it + # produces "strange" results with "voids" into the + # discretization volume. indexing='ij' solves the problem. + # Also note that matplotlib 2D streamlines requires indexing='xy'. + indexing = "xy" + if self.is_3Dvector or (self.is_3Dsurface and self.is_implicit): + indexing = "ij" + meshes = np.meshgrid(*discretizations, indexing=indexing) + self._discretized_domain = dict(zip(discr_symbols, meshes)) + + def _evaluate(self, cast_to_real=True): + """Evaluation of the symbolic expression (or expressions) with the + uniform meshing strategy, based on current values of the parameters. + """ + np = import_module('numpy') + + # create lambda functions + if not self._functions: + self._create_lambda_func() + # create (or update) the discretized domain + if (not self._discretized_domain) or self._interactive_ranges: + self._create_discretized_domain() + # ensure that discretized domains are returned with the proper order + discr = [self._discretized_domain[s[0]] for s in self.ranges] + + args = self._aggregate_args() + + results = [] + for f in self._functions: + r = _uniform_eval(*f, *args) + # the evaluation might produce an int/float. Need this correction. + r = self._correct_shape(np.array(r), discr[0]) + # sometime the evaluation is performed over arrays of type object. + # hence, `result` might be of type object, which don't work well + # with numpy real and imag functions. + r = r.astype(complex) + results.append(r) + + if cast_to_real: + discr = [np.real(d.astype(complex)) for d in discr] + return [*discr, *results] + + def _aggregate_args(self): + """Create a list of arguments to be passed to the lambda function, + sorted according to self._signature. + """ + args = [] + for s in self._signature: + if s in self._params.keys(): + args.append( + int(self._params[s]) if s in self._needs_to_be_int else + self._params[s] if self._force_real_eval + else complex(self._params[s])) + else: + args.append(self._discretized_domain[s]) + return args + + @property + def expr(self): + """Return the expression (or expressions) of the series.""" + return self._expr + + @expr.setter + def expr(self, e): + """Set the expression (or expressions) of the series.""" + is_iter = hasattr(e, "__iter__") + is_callable = callable(e) if not is_iter else any(callable(t) for t in e) + if is_callable: + self._expr = e + else: + self._expr = sympify(e) if not is_iter else Tuple(*e) + + # look for the upper bound of summations and products + s = set() + for e in self._expr.atoms(Sum, Product): + for a in e.args[1:]: + if isinstance(a[-1], Symbol): + s.add(a[-1]) + self._needs_to_be_int = list(s) + + # list of sympy functions that when lambdified, the corresponding + # numpy functions don't like complex-type arguments + pf = [ceiling, floor, atan2, frac, zeta] + if self._force_real_eval is not True: + check_res = [self._expr.has(f) for f in pf] + self._force_real_eval = any(check_res) + if self._force_real_eval and ((self.modules is None) or + (isinstance(self.modules, str) and "numpy" in self.modules)): + funcs = [f for f, c in zip(pf, check_res) if c] + warnings.warn("NumPy is unable to evaluate with complex " + "numbers some of the functions included in this " + "symbolic expression: %s. " % funcs + + "Hence, the evaluation will use real numbers. " + "If you believe the resulting plot is incorrect, " + "change the evaluation module by setting the " + "`modules` keyword argument.") + if self._functions: + # update lambda functions + self._create_lambda_func() + + @property + def is_3D(self): + flags3D = [self.is_3Dline, self.is_3Dsurface, self.is_3Dvector] + return any(flags3D) + + @property + def is_line(self): + flagslines = [self.is_2Dline, self.is_3Dline] + return any(flagslines) + + def _line_surface_color(self, prop, val): + """This method enables back-compatibility with old sympy.plotting""" + # NOTE: color_func is set inside the init method of the series. + # If line_color/surface_color is not a callable, then color_func will + # be set to None. + setattr(self, prop, val) + if callable(val) or isinstance(val, Expr): + self.color_func = val + setattr(self, prop, None) + elif val is not None: + self.color_func = None + + @property + def line_color(self): + return self._line_color + + @line_color.setter + def line_color(self, val): + self._line_surface_color("_line_color", val) + + @property + def n(self): + """Returns a list [n1, n2, n3] of numbers of discratization points. + """ + return self._n + + @n.setter + def n(self, v): + """Set the numbers of discretization points. ``v`` must be an int or + a list. + + Let ``s`` be a series. Then: + + * to set the number of discretization points along the x direction (or + first parameter): ``s.n = 10`` + * to set the number of discretization points along the x and y + directions (or first and second parameters): ``s.n = [10, 15]`` + * to set the number of discretization points along the x, y and z + directions: ``s.n = [10, 15, 20]`` + + The following is highly unreccomended, because it prevents + the execution of necessary code in order to keep updated data: + ``s.n[1] = 15`` + """ + if not hasattr(v, "__iter__"): + self._n[0] = v + else: + self._n[:len(v)] = v + if self._discretized_domain: + # update the discretized domain + self._create_discretized_domain() + + @property + def params(self): + """Get or set the current parameters dictionary. + + Parameters + ========== + + p : dict + + * key: symbol associated to the parameter + * val: the numeric value + """ + return self._params + + @params.setter + def params(self, p): + self._params = p + + def _post_init(self): + exprs = self.expr if hasattr(self.expr, "__iter__") else [self.expr] + if any(callable(e) for e in exprs) and self.params: + raise TypeError("`params` was provided, hence an interactive plot " + "is expected. However, interactive plots do not support " + "user-provided numerical functions.") + + # if the expressions is a lambda function and no label has been + # provided, then its better to do the following in order to avoid + # surprises on the backend + if any(callable(e) for e in exprs): + if self._label == str(self.expr): + self.label = "" + + self._check_fs() + + if hasattr(self, "adaptive") and self.adaptive and self.params: + warnings.warn("`params` was provided, hence an interactive plot " + "is expected. However, interactive plots do not support " + "adaptive evaluation. Automatically switched to " + "adaptive=False.") + self.adaptive = False + + @property + def scales(self): + return self._scales + + @scales.setter + def scales(self, v): + if isinstance(v, str): + self._scales[0] = v + else: + self._scales[:len(v)] = v + + @property + def surface_color(self): + return self._surface_color + + @surface_color.setter + def surface_color(self, val): + self._line_surface_color("_surface_color", val) + + @property + def rendering_kw(self): + return self._rendering_kw + + @rendering_kw.setter + def rendering_kw(self, kwargs): + if isinstance(kwargs, dict): + self._rendering_kw = kwargs + else: + self._rendering_kw = {} + if kwargs is not None: + warnings.warn( + "`rendering_kw` must be a dictionary, instead an " + "object of type %s was received. " % type(kwargs) + + "Automatically setting `rendering_kw` to an empty " + "dictionary") + + @staticmethod + def _discretize(start, end, N, scale="linear", only_integers=False): + """Discretize a 1D domain. + + Returns + ======= + + domain : np.ndarray with dtype=float or complex + The domain's dtype will be float or complex (depending on the + type of start/end) even if only_integers=True. It is left for + the downstream code to perform further casting, if necessary. + """ + np = import_module('numpy') + + if only_integers is True: + start, end = int(start), int(end) + N = end - start + 1 + + if scale == "linear": + return np.linspace(start, end, N) + return np.geomspace(start, end, N) + + @staticmethod + def _correct_shape(a, b): + """Convert ``a`` to a np.ndarray of the same shape of ``b``. + + Parameters + ========== + + a : int, float, complex, np.ndarray + Usually, this is the result of a numerical evaluation of a + symbolic expression. Even if a discretized domain was used to + evaluate the function, the result can be a scalar (int, float, + complex). Think for example to ``expr = Float(2)`` and + ``f = lambdify(x, expr)``. No matter the shape of the numerical + array representing x, the result of the evaluation will be + a single value. + + b : np.ndarray + It represents the correct shape that ``a`` should have. + + Returns + ======= + new_a : np.ndarray + An array with the correct shape. + """ + np = import_module('numpy') + + if not isinstance(a, np.ndarray): + a = np.array(a) + if a.shape != b.shape: + if a.shape == (): + a = a * np.ones_like(b) + else: + a = a.reshape(b.shape) + return a + + def eval_color_func(self, *args): + """Evaluate the color function. + + Parameters + ========== + + args : tuple + Arguments to be passed to the coloring function. Can be coordinates + or parameters or both. + + Notes + ===== + + The backend will request the data series to generate the numerical + data. Depending on the data series, either the data series itself or + the backend will eventually execute this function to generate the + appropriate coloring value. + """ + np = import_module('numpy') + if self.color_func is None: + # NOTE: with the line_color and surface_color attributes + # (back-compatibility with the old sympy.plotting module) it is + # possible to create a plot with a callable line_color (or + # surface_color). For example: + # p = plot(sin(x), line_color=lambda x, y: -y) + # This creates a ColoredLineOver1DRangeSeries with line_color=None + # and color_func=lambda x, y: -y, which effectively is a + # parametric series. Later we could change it to a string value: + # p[0].line_color = "red" + # However, this sets ine_color="red" and color_func=None, but the + # series is still ColoredLineOver1DRangeSeries (a parametric + # series), which will render using a color_func... + warnings.warn("This is likely not the result you were " + "looking for. Please, re-execute the plot command, this time " + "with the appropriate an appropriate value to line_color " + "or surface_color.") + return np.ones_like(args[0]) + + if self._eval_color_func_with_signature: + args = self._aggregate_args() + color = self.color_func(*args) + _re, _im = np.real(color), np.imag(color) + _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan + return _re + + nargs = arity(self.color_func) + if nargs == 1: + if self.is_2Dline and self.is_parametric: + if len(args) == 2: + # ColoredLineOver1DRangeSeries + return self._correct_shape(self.color_func(args[0]), args[0]) + # Parametric2DLineSeries + return self._correct_shape(self.color_func(args[2]), args[2]) + elif self.is_3Dline and self.is_parametric: + return self._correct_shape(self.color_func(args[3]), args[3]) + elif self.is_3Dsurface and self.is_parametric: + return self._correct_shape(self.color_func(args[3]), args[3]) + return self._correct_shape(self.color_func(args[0]), args[0]) + elif nargs == 2: + if self.is_3Dsurface and self.is_parametric: + return self._correct_shape(self.color_func(*args[3:]), args[3]) + return self._correct_shape(self.color_func(*args[:2]), args[0]) + return self._correct_shape(self.color_func(*args[:nargs]), args[0]) + + def get_data(self): + """Compute and returns the numerical data. + + The number of parameters returned by this method depends on the + specific instance. If ``s`` is the series, make sure to read + ``help(s.get_data)`` to understand what it returns. + """ + raise NotImplementedError + + def _get_wrapped_label(self, label, wrapper): + """Given a latex representation of an expression, wrap it inside + some characters. Matplotlib needs "$%s%$", K3D-Jupyter needs "%s". + """ + return wrapper % label + + def get_label(self, use_latex=False, wrapper="$%s$"): + """Return the label to be used to display the expression. + + Parameters + ========== + use_latex : bool + If False, the string representation of the expression is returned. + If True, the latex representation is returned. + wrapper : str + The backend might need the latex representation to be wrapped by + some characters. Default to ``"$%s$"``. + + Returns + ======= + label : str + """ + if use_latex is False: + return self._label + if self._label == str(self.expr): + # when the backend requests a latex label and user didn't provide + # any label + return self._get_wrapped_label(self._latex_label, wrapper) + return self._latex_label + + @property + def label(self): + return self.get_label() + + @label.setter + def label(self, val): + """Set the labels associated to this series.""" + # NOTE: the init method of any series requires a label. If the user do + # not provide it, the preprocessing function will set label=None, which + # informs the series to initialize two attributes: + # _label contains the string representation of the expression. + # _latex_label contains the latex representation of the expression. + self._label = self._latex_label = val + + @property + def ranges(self): + return self._ranges + + @ranges.setter + def ranges(self, val): + new_vals = [] + for v in val: + if v is not None: + new_vals.append(tuple([sympify(t) for t in v])) + self._ranges = new_vals + + def _apply_transform(self, *args): + """Apply transformations to the results of numerical evaluation. + + Parameters + ========== + args : tuple + Results of numerical evaluation. + + Returns + ======= + transformed_args : tuple + Tuple containing the transformed results. + """ + t = lambda x, transform: x if transform is None else transform(x) + x, y, z = None, None, None + if len(args) == 2: + x, y = args + return t(x, self._tx), t(y, self._ty) + elif (len(args) == 3) and isinstance(self, Parametric2DLineSeries): + x, y, u = args + return (t(x, self._tx), t(y, self._ty), t(u, self._tp)) + elif len(args) == 3: + x, y, z = args + return t(x, self._tx), t(y, self._ty), t(z, self._tz) + elif (len(args) == 4) and isinstance(self, Parametric3DLineSeries): + x, y, z, u = args + return (t(x, self._tx), t(y, self._ty), t(z, self._tz), t(u, self._tp)) + elif len(args) == 4: # 2D vector plot + x, y, u, v = args + return ( + t(x, self._tx), t(y, self._ty), + t(u, self._tx), t(v, self._ty) + ) + elif (len(args) == 5) and isinstance(self, ParametricSurfaceSeries): + x, y, z, u, v = args + return (t(x, self._tx), t(y, self._ty), t(z, self._tz), u, v) + elif (len(args) == 6) and self.is_3Dvector: # 3D vector plot + x, y, z, u, v, w = args + return ( + t(x, self._tx), t(y, self._ty), t(z, self._tz), + t(u, self._tx), t(v, self._ty), t(w, self._tz) + ) + elif len(args) == 6: # complex plot + x, y, _abs, _arg, img, colors = args + return ( + x, y, t(_abs, self._tz), _arg, img, colors) + return args + + def _str_helper(self, s): + pre, post = "", "" + if self.is_interactive: + pre = "interactive " + post = " and parameters " + str(tuple(self.params.keys())) + return pre + s + post + + +def _detect_poles_numerical_helper(x, y, eps=0.01, expr=None, symb=None, symbolic=False): + """Compute the steepness of each segment. If it's greater than a + threshold, set the right-point y-value non NaN and record the + corresponding x-location for further processing. + + Returns + ======= + x : np.ndarray + Unchanged x-data. + yy : np.ndarray + Modified y-data with NaN values. + """ + np = import_module('numpy') + + yy = y.copy() + threshold = np.pi / 2 - eps + for i in range(len(x) - 1): + dx = x[i + 1] - x[i] + dy = abs(y[i + 1] - y[i]) + angle = np.arctan(dy / dx) + if abs(angle) >= threshold: + yy[i + 1] = np.nan + + return x, yy + +def _detect_poles_symbolic_helper(expr, symb, start, end): + """Attempts to compute symbolic discontinuities. + + Returns + ======= + pole : list + List of symbolic poles, possibly empty. + """ + poles = [] + interval = Interval(nsimplify(start), nsimplify(end)) + res = continuous_domain(expr, symb, interval) + res = res.simplify() + if res == interval: + pass + elif (isinstance(res, Union) and + all(isinstance(t, Interval) for t in res.args)): + poles = [] + for s in res.args: + if s.left_open: + poles.append(s.left) + if s.right_open: + poles.append(s.right) + poles = list(set(poles)) + else: + raise ValueError( + f"Could not parse the following object: {res} .\n" + "Please, submit this as a bug. Consider also to set " + "`detect_poles=True`." + ) + return poles + + +### 2D lines +class Line2DBaseSeries(BaseSeries): + """A base class for 2D lines. + + - adding the label, steps and only_integers options + - making is_2Dline true + - defining get_segments and get_color_array + """ + + is_2Dline = True + _dim = 2 + _N = 1000 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.steps = kwargs.get("steps", False) + self.is_point = kwargs.get("is_point", kwargs.get("point", False)) + self.is_filled = kwargs.get("is_filled", kwargs.get("fill", True)) + self.adaptive = kwargs.get("adaptive", False) + self.depth = kwargs.get('depth', 12) + self.use_cm = kwargs.get("use_cm", False) + self.color_func = kwargs.get("color_func", None) + self.line_color = kwargs.get("line_color", None) + self.detect_poles = kwargs.get("detect_poles", False) + self.eps = kwargs.get("eps", 0.01) + self.is_polar = kwargs.get("is_polar", kwargs.get("polar", False)) + self.unwrap = kwargs.get("unwrap", False) + # when detect_poles="symbolic", stores the location of poles so that + # they can be appropriately rendered + self.poles_locations = [] + exclude = kwargs.get("exclude", []) + if isinstance(exclude, Set): + exclude = list(extract_solution(exclude, n=100)) + if not hasattr(exclude, "__iter__"): + exclude = [exclude] + exclude = [float(e) for e in exclude] + self.exclude = sorted(exclude) + + def get_data(self): + """Return coordinates for plotting the line. + + Returns + ======= + + x: np.ndarray + x-coordinates + + y: np.ndarray + y-coordinates + + z: np.ndarray (optional) + z-coordinates in case of Parametric3DLineSeries, + Parametric3DLineInteractiveSeries + + param : np.ndarray (optional) + The parameter in case of Parametric2DLineSeries, + Parametric3DLineSeries or AbsArgLineSeries (and their + corresponding interactive series). + """ + np = import_module('numpy') + points = self._get_data_helper() + + if (isinstance(self, LineOver1DRangeSeries) and + (self.detect_poles == "symbolic")): + poles = _detect_poles_symbolic_helper( + self.expr.subs(self.params), *self.ranges[0]) + poles = np.array([float(t) for t in poles]) + t = lambda x, transform: x if transform is None else transform(x) + self.poles_locations = t(np.array(poles), self._tx) + + # postprocessing + points = self._apply_transform(*points) + + if self.is_2Dline and self.detect_poles: + if len(points) == 2: + x, y = points + x, y = _detect_poles_numerical_helper( + x, y, self.eps) + points = (x, y) + else: + x, y, p = points + x, y = _detect_poles_numerical_helper(x, y, self.eps) + points = (x, y, p) + + if self.unwrap: + kw = {} + if self.unwrap is not True: + kw = self.unwrap + if self.is_2Dline: + if len(points) == 2: + x, y = points + y = np.unwrap(y, **kw) + points = (x, y) + else: + x, y, p = points + y = np.unwrap(y, **kw) + points = (x, y, p) + + if self.steps is True: + if self.is_2Dline: + x, y = points[0], points[1] + x = np.array((x, x)).T.flatten()[1:] + y = np.array((y, y)).T.flatten()[:-1] + if self.is_parametric: + points = (x, y, points[2]) + else: + points = (x, y) + elif self.is_3Dline: + x = np.repeat(points[0], 3)[2:] + y = np.repeat(points[1], 3)[:-2] + z = np.repeat(points[2], 3)[1:-1] + if len(points) > 3: + points = (x, y, z, points[3]) + else: + points = (x, y, z) + + if len(self.exclude) > 0: + points = self._insert_exclusions(points) + return points + + def get_segments(self): + sympy_deprecation_warning( + """ + The Line2DBaseSeries.get_segments() method is deprecated. + + Instead, use the MatplotlibBackend.get_segments() method, or use + The get_points() or get_data() methods. + """, + deprecated_since_version="1.9", + active_deprecations_target="deprecated-get-segments") + + np = import_module('numpy') + points = type(self).get_data(self) + points = np.ma.array(points).T.reshape(-1, 1, self._dim) + return np.ma.concatenate([points[:-1], points[1:]], axis=1) + + def _insert_exclusions(self, points): + """Add NaN to each of the exclusion point. Practically, this adds a + NaN to the exclusion point, plus two other nearby points evaluated with + the numerical functions associated to this data series. + These nearby points are important when the number of discretization + points is low, or the scale is logarithm. + + NOTE: it would be easier to just add exclusion points to the + discretized domain before evaluation, then after evaluation add NaN + to the exclusion points. But that's only work with adaptive=False. + The following approach work even with adaptive=True. + """ + np = import_module("numpy") + points = list(points) + n = len(points) + # index of the x-coordinate (for 2d plots) or parameter (for 2d/3d + # parametric plots) + k = n - 1 + if n == 2: + k = 0 + # indices of the other coordinates + j_indeces = sorted(set(range(n)).difference([k])) + # TODO: for now, I assume that numpy functions are going to succeed + funcs = [f[0] for f in self._functions] + + for e in self.exclude: + res = points[k] - e >= 0 + # if res contains both True and False, ie, if e is found + if any(res) and any(~res): + idx = np.nanargmax(res) + # select the previous point with respect to e + idx -= 1 + # TODO: what if points[k][idx]==e or points[k][idx+1]==e? + + if idx > 0 and idx < len(points[k]) - 1: + delta_prev = abs(e - points[k][idx]) + delta_post = abs(e - points[k][idx + 1]) + delta = min(delta_prev, delta_post) / 100 + prev = e - delta + post = e + delta + + # add points to the x-coord or the parameter + points[k] = np.concatenate( + (points[k][:idx], [prev, e, post], points[k][idx+1:])) + + # add points to the other coordinates + c = 0 + for j in j_indeces: + values = funcs[c](np.array([prev, post])) + c += 1 + points[j] = np.concatenate( + (points[j][:idx], [values[0], np.nan, values[1]], points[j][idx+1:])) + return points + + @property + def var(self): + return None if not self.ranges else self.ranges[0][0] + + @property + def start(self): + if not self.ranges: + return None + try: + return self._cast(self.ranges[0][1]) + except TypeError: + return self.ranges[0][1] + + @property + def end(self): + if not self.ranges: + return None + try: + return self._cast(self.ranges[0][2]) + except TypeError: + return self.ranges[0][2] + + @property + def xscale(self): + return self._scales[0] + + @xscale.setter + def xscale(self, v): + self.scales = v + + def get_color_array(self): + np = import_module('numpy') + c = self.line_color + if hasattr(c, '__call__'): + f = np.vectorize(c) + nargs = arity(c) + if nargs == 1 and self.is_parametric: + x = self.get_parameter_points() + return f(centers_of_segments(x)) + else: + variables = list(map(centers_of_segments, self.get_points())) + if nargs == 1: + return f(variables[0]) + elif nargs == 2: + return f(*variables[:2]) + else: # only if the line is 3D (otherwise raises an error) + return f(*variables) + else: + return c*np.ones(self.nb_of_points) + + +class List2DSeries(Line2DBaseSeries): + """Representation for a line consisting of list of points.""" + + def __init__(self, list_x, list_y, label="", **kwargs): + super().__init__(**kwargs) + np = import_module('numpy') + if len(list_x) != len(list_y): + raise ValueError( + "The two lists of coordinates must have the same " + "number of elements.\n" + "Received: len(list_x) = {} ".format(len(list_x)) + + "and len(list_y) = {}".format(len(list_y)) + ) + self._block_lambda_functions(list_x, list_y) + check = lambda l: [isinstance(t, Expr) and (not t.is_number) for t in l] + if any(check(list_x) + check(list_y)) or self.params: + if not self.params: + raise ValueError("Some or all elements of the provided lists " + "are symbolic expressions, but the ``params`` dictionary " + "was not provided: those elements can't be evaluated.") + self.list_x = Tuple(*list_x) + self.list_y = Tuple(*list_y) + else: + self.list_x = np.array(list_x, dtype=np.float64) + self.list_y = np.array(list_y, dtype=np.float64) + + self._expr = (self.list_x, self.list_y) + if not any(isinstance(t, np.ndarray) for t in [self.list_x, self.list_y]): + self._check_fs() + self.is_polar = kwargs.get("is_polar", kwargs.get("polar", False)) + self.label = label + self.rendering_kw = kwargs.get("rendering_kw", {}) + if self.use_cm and self.color_func: + self.is_parametric = True + if isinstance(self.color_func, Expr): + raise TypeError( + "%s don't support symbolic " % self.__class__.__name__ + + "expression for `color_func`.") + + def __str__(self): + return "2D list plot" + + def _get_data_helper(self): + """Returns coordinates that needs to be postprocessed.""" + lx, ly = self.list_x, self.list_y + + if not self.is_interactive: + return self._eval_color_func_and_return(lx, ly) + + np = import_module('numpy') + lx = np.array([t.evalf(subs=self.params) for t in lx], dtype=float) + ly = np.array([t.evalf(subs=self.params) for t in ly], dtype=float) + return self._eval_color_func_and_return(lx, ly) + + def _eval_color_func_and_return(self, *data): + if self.use_cm and callable(self.color_func): + return [*data, self.eval_color_func(*data)] + return data + + +class LineOver1DRangeSeries(Line2DBaseSeries): + """Representation for a line consisting of a SymPy expression over a range.""" + + def __init__(self, expr, var_start_end, label="", **kwargs): + super().__init__(**kwargs) + self.expr = expr if callable(expr) else sympify(expr) + self._label = str(self.expr) if label is None else label + self._latex_label = latex(self.expr) if label is None else label + self.ranges = [var_start_end] + self._cast = complex + # for complex-related data series, this determines what data to return + # on the y-axis + self._return = kwargs.get("return", None) + self._post_init() + + if not self._interactive_ranges: + # NOTE: the following check is only possible when the minimum and + # maximum values of a plotting range are numeric + start, end = [complex(t) for t in self.ranges[0][1:]] + if im(start) != im(end): + raise ValueError( + "%s requires the imaginary " % self.__class__.__name__ + + "part of the start and end values of the range " + "to be the same.") + + if self.adaptive and self._return: + warnings.warn("The adaptive algorithm is unable to deal with " + "complex numbers. Automatically switching to uniform meshing.") + self.adaptive = False + + @property + def nb_of_points(self): + return self.n[0] + + @nb_of_points.setter + def nb_of_points(self, v): + self.n = v + + def __str__(self): + def f(t): + if isinstance(t, complex): + if t.imag != 0: + return t + return t.real + return t + pre = "interactive " if self.is_interactive else "" + post = "" + if self.is_interactive: + post = " and parameters " + str(tuple(self.params.keys())) + wrapper = _get_wrapper_for_expr(self._return) + return pre + "cartesian line: %s for %s over %s" % ( + wrapper % self.expr, + str(self.var), + str((f(self.start), f(self.end))), + ) + post + + def get_points(self): + """Return lists of coordinates for plotting. Depending on the + ``adaptive`` option, this function will either use an adaptive algorithm + or it will uniformly sample the expression over the provided range. + + This function is available for back-compatibility purposes. Consider + using ``get_data()`` instead. + + Returns + ======= + x : list + List of x-coordinates + + y : list + List of y-coordinates + """ + return self._get_data_helper() + + def _adaptive_sampling(self): + try: + if callable(self.expr): + f = self.expr + else: + f = lambdify([self.var], self.expr, self.modules) + x, y = self._adaptive_sampling_helper(f) + except Exception as err: + warnings.warn( + "The evaluation with %s failed.\n" % ( + "NumPy/SciPy" if not self.modules else self.modules) + + "{}: {}\n".format(type(err).__name__, err) + + "Trying to evaluate the expression with Sympy, but it might " + "be a slow operation." + ) + f = lambdify([self.var], self.expr, "sympy") + x, y = self._adaptive_sampling_helper(f) + return x, y + + def _adaptive_sampling_helper(self, f): + """The adaptive sampling is done by recursively checking if three + points are almost collinear. If they are not collinear, then more + points are added between those points. + + References + ========== + + .. [1] Adaptive polygonal approximation of parametric curves, + Luiz Henrique de Figueiredo. + """ + np = import_module('numpy') + + x_coords = [] + y_coords = [] + def sample(p, q, depth): + """ Samples recursively if three points are almost collinear. + For depth < 6, points are added irrespective of whether they + satisfy the collinearity condition or not. The maximum depth + allowed is 12. + """ + # Randomly sample to avoid aliasing. + random = 0.45 + np.random.rand() * 0.1 + if self.xscale == 'log': + xnew = 10**(np.log10(p[0]) + random * (np.log10(q[0]) - + np.log10(p[0]))) + else: + xnew = p[0] + random * (q[0] - p[0]) + ynew = _adaptive_eval(f, xnew) + new_point = np.array([xnew, ynew]) + + # Maximum depth + if depth > self.depth: + x_coords.append(q[0]) + y_coords.append(q[1]) + + # Sample to depth of 6 (whether the line is flat or not) + # without using linspace (to avoid aliasing). + elif depth < 6: + sample(p, new_point, depth + 1) + sample(new_point, q, depth + 1) + + # Sample ten points if complex values are encountered + # at both ends. If there is a real value in between, then + # sample those points further. + elif p[1] is None and q[1] is None: + if self.xscale == 'log': + xarray = np.logspace(p[0], q[0], 10) + else: + xarray = np.linspace(p[0], q[0], 10) + yarray = list(map(f, xarray)) + if not all(y is None for y in yarray): + for i in range(len(yarray) - 1): + if not (yarray[i] is None and yarray[i + 1] is None): + sample([xarray[i], yarray[i]], + [xarray[i + 1], yarray[i + 1]], depth + 1) + + # Sample further if one of the end points in None (i.e. a + # complex value) or the three points are not almost collinear. + elif (p[1] is None or q[1] is None or new_point[1] is None + or not flat(p, new_point, q)): + sample(p, new_point, depth + 1) + sample(new_point, q, depth + 1) + else: + x_coords.append(q[0]) + y_coords.append(q[1]) + + f_start = _adaptive_eval(f, self.start.real) + f_end = _adaptive_eval(f, self.end.real) + x_coords.append(self.start.real) + y_coords.append(f_start) + sample(np.array([self.start.real, f_start]), + np.array([self.end.real, f_end]), 0) + + return (x_coords, y_coords) + + def _uniform_sampling(self): + np = import_module('numpy') + + x, result = self._evaluate() + _re, _im = np.real(result), np.imag(result) + _re = self._correct_shape(_re, x) + _im = self._correct_shape(_im, x) + return x, _re, _im + + def _get_data_helper(self): + """Returns coordinates that needs to be postprocessed. + """ + np = import_module('numpy') + if self.adaptive and (not self.only_integers): + x, y = self._adaptive_sampling() + return [np.array(t) for t in [x, y]] + + x, _re, _im = self._uniform_sampling() + + if self._return is None: + # The evaluation could produce complex numbers. Set real elements + # to NaN where there are non-zero imaginary elements + _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan + elif self._return == "real": + pass + elif self._return == "imag": + _re = _im + elif self._return == "abs": + _re = np.sqrt(_re**2 + _im**2) + elif self._return == "arg": + _re = np.arctan2(_im, _re) + else: + raise ValueError("`_return` not recognized. " + "Received: %s" % self._return) + + return x, _re + + +class ParametricLineBaseSeries(Line2DBaseSeries): + is_parametric = True + + def _set_parametric_line_label(self, label): + """Logic to set the correct label to be shown on the plot. + If `use_cm=True` there will be a colorbar, so we show the parameter. + If `use_cm=False`, there might be a legend, so we show the expressions. + + Parameters + ========== + label : str + label passed in by the pre-processor or the user + """ + self._label = str(self.var) if label is None else label + self._latex_label = latex(self.var) if label is None else label + if (self.use_cm is False) and (self._label == str(self.var)): + self._label = str(self.expr) + self._latex_label = latex(self.expr) + # if the expressions is a lambda function and use_cm=False and no label + # has been provided, then its better to do the following in order to + # avoid surprises on the backend + if any(callable(e) for e in self.expr) and (not self.use_cm): + if self._label == str(self.expr): + self._label = "" + + def get_label(self, use_latex=False, wrapper="$%s$"): + # parametric lines returns the representation of the parameter to be + # shown on the colorbar if `use_cm=True`, otherwise it returns the + # representation of the expression to be placed on the legend. + if self.use_cm: + if str(self.var) == self._label: + if use_latex: + return self._get_wrapped_label(latex(self.var), wrapper) + return str(self.var) + # here the user has provided a custom label + return self._label + if use_latex: + if self._label != str(self.expr): + return self._latex_label + return self._get_wrapped_label(self._latex_label, wrapper) + return self._label + + def _get_data_helper(self): + """Returns coordinates that needs to be postprocessed. + Depending on the `adaptive` option, this function will either use an + adaptive algorithm or it will uniformly sample the expression over the + provided range. + """ + if self.adaptive: + np = import_module("numpy") + coords = self._adaptive_sampling() + coords = [np.array(t) for t in coords] + else: + coords = self._uniform_sampling() + + if self.is_2Dline and self.is_polar: + # when plot_polar is executed with polar_axis=True + np = import_module('numpy') + x, y, _ = coords + r = np.sqrt(x**2 + y**2) + t = np.arctan2(y, x) + coords = [t, r, coords[-1]] + + if callable(self.color_func): + coords = list(coords) + coords[-1] = self.eval_color_func(*coords) + + return coords + + def _uniform_sampling(self): + """Returns coordinates that needs to be postprocessed.""" + np = import_module('numpy') + + results = self._evaluate() + for i, r in enumerate(results): + _re, _im = np.real(r), np.imag(r) + _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan + results[i] = _re + + return [*results[1:], results[0]] + + def get_parameter_points(self): + return self.get_data()[-1] + + def get_points(self): + """ Return lists of coordinates for plotting. Depending on the + ``adaptive`` option, this function will either use an adaptive algorithm + or it will uniformly sample the expression over the provided range. + + This function is available for back-compatibility purposes. Consider + using ``get_data()`` instead. + + Returns + ======= + x : list + List of x-coordinates + y : list + List of y-coordinates + z : list + List of z-coordinates, only for 3D parametric line plot. + """ + return self._get_data_helper()[:-1] + + @property + def nb_of_points(self): + return self.n[0] + + @nb_of_points.setter + def nb_of_points(self, v): + self.n = v + + +class Parametric2DLineSeries(ParametricLineBaseSeries): + """Representation for a line consisting of two parametric SymPy expressions + over a range.""" + + is_2Dline = True + + def __init__(self, expr_x, expr_y, var_start_end, label="", **kwargs): + super().__init__(**kwargs) + self.expr_x = expr_x if callable(expr_x) else sympify(expr_x) + self.expr_y = expr_y if callable(expr_y) else sympify(expr_y) + self.expr = (self.expr_x, self.expr_y) + self.ranges = [var_start_end] + self._cast = float + self.use_cm = kwargs.get("use_cm", True) + self._set_parametric_line_label(label) + self._post_init() + + def __str__(self): + return self._str_helper( + "parametric cartesian line: (%s, %s) for %s over %s" % ( + str(self.expr_x), + str(self.expr_y), + str(self.var), + str((self.start, self.end)) + )) + + def _adaptive_sampling(self): + try: + if callable(self.expr_x) and callable(self.expr_y): + f_x = self.expr_x + f_y = self.expr_y + else: + f_x = lambdify([self.var], self.expr_x) + f_y = lambdify([self.var], self.expr_y) + x, y, p = self._adaptive_sampling_helper(f_x, f_y) + except Exception as err: + warnings.warn( + "The evaluation with %s failed.\n" % ( + "NumPy/SciPy" if not self.modules else self.modules) + + "{}: {}\n".format(type(err).__name__, err) + + "Trying to evaluate the expression with Sympy, but it might " + "be a slow operation." + ) + f_x = lambdify([self.var], self.expr_x, "sympy") + f_y = lambdify([self.var], self.expr_y, "sympy") + x, y, p = self._adaptive_sampling_helper(f_x, f_y) + return x, y, p + + def _adaptive_sampling_helper(self, f_x, f_y): + """The adaptive sampling is done by recursively checking if three + points are almost collinear. If they are not collinear, then more + points are added between those points. + + References + ========== + + .. [1] Adaptive polygonal approximation of parametric curves, + Luiz Henrique de Figueiredo. + """ + x_coords = [] + y_coords = [] + param = [] + + def sample(param_p, param_q, p, q, depth): + """ Samples recursively if three points are almost collinear. + For depth < 6, points are added irrespective of whether they + satisfy the collinearity condition or not. The maximum depth + allowed is 12. + """ + # Randomly sample to avoid aliasing. + np = import_module('numpy') + random = 0.45 + np.random.rand() * 0.1 + param_new = param_p + random * (param_q - param_p) + xnew = _adaptive_eval(f_x, param_new) + ynew = _adaptive_eval(f_y, param_new) + new_point = np.array([xnew, ynew]) + + # Maximum depth + if depth > self.depth: + x_coords.append(q[0]) + y_coords.append(q[1]) + param.append(param_p) + + # Sample irrespective of whether the line is flat till the + # depth of 6. We are not using linspace to avoid aliasing. + elif depth < 6: + sample(param_p, param_new, p, new_point, depth + 1) + sample(param_new, param_q, new_point, q, depth + 1) + + # Sample ten points if complex values are encountered + # at both ends. If there is a real value in between, then + # sample those points further. + elif ((p[0] is None and q[1] is None) or + (p[1] is None and q[1] is None)): + param_array = np.linspace(param_p, param_q, 10) + x_array = [_adaptive_eval(f_x, t) for t in param_array] + y_array = [_adaptive_eval(f_y, t) for t in param_array] + if not all(x is None and y is None + for x, y in zip(x_array, y_array)): + for i in range(len(y_array) - 1): + if ((x_array[i] is not None and y_array[i] is not None) or + (x_array[i + 1] is not None and y_array[i + 1] is not None)): + point_a = [x_array[i], y_array[i]] + point_b = [x_array[i + 1], y_array[i + 1]] + sample(param_array[i], param_array[i], point_a, + point_b, depth + 1) + + # Sample further if one of the end points in None (i.e. a complex + # value) or the three points are not almost collinear. + elif (p[0] is None or p[1] is None + or q[1] is None or q[0] is None + or not flat(p, new_point, q)): + sample(param_p, param_new, p, new_point, depth + 1) + sample(param_new, param_q, new_point, q, depth + 1) + else: + x_coords.append(q[0]) + y_coords.append(q[1]) + param.append(param_p) + + f_start_x = _adaptive_eval(f_x, self.start) + f_start_y = _adaptive_eval(f_y, self.start) + start = [f_start_x, f_start_y] + f_end_x = _adaptive_eval(f_x, self.end) + f_end_y = _adaptive_eval(f_y, self.end) + end = [f_end_x, f_end_y] + x_coords.append(f_start_x) + y_coords.append(f_start_y) + param.append(self.start) + sample(self.start, self.end, start, end, 0) + + return x_coords, y_coords, param + + +### 3D lines +class Line3DBaseSeries(Line2DBaseSeries): + """A base class for 3D lines. + + Most of the stuff is derived from Line2DBaseSeries.""" + + is_2Dline = False + is_3Dline = True + _dim = 3 + + def __init__(self): + super().__init__() + + +class Parametric3DLineSeries(ParametricLineBaseSeries): + """Representation for a 3D line consisting of three parametric SymPy + expressions and a range.""" + + is_2Dline = False + is_3Dline = True + + def __init__(self, expr_x, expr_y, expr_z, var_start_end, label="", **kwargs): + super().__init__(**kwargs) + self.expr_x = expr_x if callable(expr_x) else sympify(expr_x) + self.expr_y = expr_y if callable(expr_y) else sympify(expr_y) + self.expr_z = expr_z if callable(expr_z) else sympify(expr_z) + self.expr = (self.expr_x, self.expr_y, self.expr_z) + self.ranges = [var_start_end] + self._cast = float + self.adaptive = False + self.use_cm = kwargs.get("use_cm", True) + self._set_parametric_line_label(label) + self._post_init() + # TODO: remove this + self._xlim = None + self._ylim = None + self._zlim = None + + def __str__(self): + return self._str_helper( + "3D parametric cartesian line: (%s, %s, %s) for %s over %s" % ( + str(self.expr_x), + str(self.expr_y), + str(self.expr_z), + str(self.var), + str((self.start, self.end)) + )) + + def get_data(self): + # TODO: remove this + np = import_module("numpy") + x, y, z, p = super().get_data() + self._xlim = (np.amin(x), np.amax(x)) + self._ylim = (np.amin(y), np.amax(y)) + self._zlim = (np.amin(z), np.amax(z)) + return x, y, z, p + + +### Surfaces +class SurfaceBaseSeries(BaseSeries): + """A base class for 3D surfaces.""" + + is_3Dsurface = True + + def __init__(self, *args, **kwargs): + super().__init__(**kwargs) + self.use_cm = kwargs.get("use_cm", False) + # NOTE: why should SurfaceOver2DRangeSeries support is polar? + # After all, the same result can be achieve with + # ParametricSurfaceSeries. For example: + # sin(r) for (r, 0, 2 * pi) and (theta, 0, pi/2) can be parameterized + # as (r * cos(theta), r * sin(theta), sin(t)) for (r, 0, 2 * pi) and + # (theta, 0, pi/2). + # Because it is faster to evaluate (important for interactive plots). + self.is_polar = kwargs.get("is_polar", kwargs.get("polar", False)) + self.surface_color = kwargs.get("surface_color", None) + self.color_func = kwargs.get("color_func", lambda x, y, z: z) + if callable(self.surface_color): + self.color_func = self.surface_color + self.surface_color = None + + def _set_surface_label(self, label): + exprs = self.expr + self._label = str(exprs) if label is None else label + self._latex_label = latex(exprs) if label is None else label + # if the expressions is a lambda function and no label + # has been provided, then its better to do the following to avoid + # surprises on the backend + is_lambda = (callable(exprs) if not hasattr(exprs, "__iter__") + else any(callable(e) for e in exprs)) + if is_lambda and (self._label == str(exprs)): + self._label = "" + self._latex_label = "" + + def get_color_array(self): + np = import_module('numpy') + c = self.surface_color + if isinstance(c, Callable): + f = np.vectorize(c) + nargs = arity(c) + if self.is_parametric: + variables = list(map(centers_of_faces, self.get_parameter_meshes())) + if nargs == 1: + return f(variables[0]) + elif nargs == 2: + return f(*variables) + variables = list(map(centers_of_faces, self.get_meshes())) + if nargs == 1: + return f(variables[0]) + elif nargs == 2: + return f(*variables[:2]) + else: + return f(*variables) + else: + if isinstance(self, SurfaceOver2DRangeSeries): + return c*np.ones(min(self.nb_of_points_x, self.nb_of_points_y)) + else: + return c*np.ones(min(self.nb_of_points_u, self.nb_of_points_v)) + + +class SurfaceOver2DRangeSeries(SurfaceBaseSeries): + """Representation for a 3D surface consisting of a SymPy expression and 2D + range.""" + + def __init__(self, expr, var_start_end_x, var_start_end_y, label="", **kwargs): + super().__init__(**kwargs) + self.expr = expr if callable(expr) else sympify(expr) + self.ranges = [var_start_end_x, var_start_end_y] + self._set_surface_label(label) + self._post_init() + # TODO: remove this + self._xlim = (self.start_x, self.end_x) + self._ylim = (self.start_y, self.end_y) + + @property + def var_x(self): + return self.ranges[0][0] + + @property + def var_y(self): + return self.ranges[1][0] + + @property + def start_x(self): + try: + return float(self.ranges[0][1]) + except TypeError: + return self.ranges[0][1] + + @property + def end_x(self): + try: + return float(self.ranges[0][2]) + except TypeError: + return self.ranges[0][2] + + @property + def start_y(self): + try: + return float(self.ranges[1][1]) + except TypeError: + return self.ranges[1][1] + + @property + def end_y(self): + try: + return float(self.ranges[1][2]) + except TypeError: + return self.ranges[1][2] + + @property + def nb_of_points_x(self): + return self.n[0] + + @nb_of_points_x.setter + def nb_of_points_x(self, v): + n = self.n + self.n = [v, n[1:]] + + @property + def nb_of_points_y(self): + return self.n[1] + + @nb_of_points_y.setter + def nb_of_points_y(self, v): + n = self.n + self.n = [n[0], v, n[2]] + + def __str__(self): + series_type = "cartesian surface" if self.is_3Dsurface else "contour" + return self._str_helper( + series_type + ": %s for" " %s over %s and %s over %s" % ( + str(self.expr), + str(self.var_x), str((self.start_x, self.end_x)), + str(self.var_y), str((self.start_y, self.end_y)), + )) + + def get_meshes(self): + """Return the x,y,z coordinates for plotting the surface. + This function is available for back-compatibility purposes. Consider + using ``get_data()`` instead. + """ + return self.get_data() + + def get_data(self): + """Return arrays of coordinates for plotting. + + Returns + ======= + mesh_x : np.ndarray + Discretized x-domain. + mesh_y : np.ndarray + Discretized y-domain. + mesh_z : np.ndarray + Results of the evaluation. + """ + np = import_module('numpy') + + results = self._evaluate() + # mask out complex values + for i, r in enumerate(results): + _re, _im = np.real(r), np.imag(r) + _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan + results[i] = _re + + x, y, z = results + if self.is_polar and self.is_3Dsurface: + r = x.copy() + x = r * np.cos(y) + y = r * np.sin(y) + + # TODO: remove this + self._zlim = (np.amin(z), np.amax(z)) + + return self._apply_transform(x, y, z) + + +class ParametricSurfaceSeries(SurfaceBaseSeries): + """Representation for a 3D surface consisting of three parametric SymPy + expressions and a range.""" + + is_parametric = True + + def __init__(self, expr_x, expr_y, expr_z, + var_start_end_u, var_start_end_v, label="", **kwargs): + super().__init__(**kwargs) + self.expr_x = expr_x if callable(expr_x) else sympify(expr_x) + self.expr_y = expr_y if callable(expr_y) else sympify(expr_y) + self.expr_z = expr_z if callable(expr_z) else sympify(expr_z) + self.expr = (self.expr_x, self.expr_y, self.expr_z) + self.ranges = [var_start_end_u, var_start_end_v] + self.color_func = kwargs.get("color_func", lambda x, y, z, u, v: z) + self._set_surface_label(label) + self._post_init() + + @property + def var_u(self): + return self.ranges[0][0] + + @property + def var_v(self): + return self.ranges[1][0] + + @property + def start_u(self): + try: + return float(self.ranges[0][1]) + except TypeError: + return self.ranges[0][1] + + @property + def end_u(self): + try: + return float(self.ranges[0][2]) + except TypeError: + return self.ranges[0][2] + + @property + def start_v(self): + try: + return float(self.ranges[1][1]) + except TypeError: + return self.ranges[1][1] + + @property + def end_v(self): + try: + return float(self.ranges[1][2]) + except TypeError: + return self.ranges[1][2] + + @property + def nb_of_points_u(self): + return self.n[0] + + @nb_of_points_u.setter + def nb_of_points_u(self, v): + n = self.n + self.n = [v, n[1:]] + + @property + def nb_of_points_v(self): + return self.n[1] + + @nb_of_points_v.setter + def nb_of_points_v(self, v): + n = self.n + self.n = [n[0], v, n[2]] + + def __str__(self): + return self._str_helper( + "parametric cartesian surface: (%s, %s, %s) for" + " %s over %s and %s over %s" % ( + str(self.expr_x), str(self.expr_y), str(self.expr_z), + str(self.var_u), str((self.start_u, self.end_u)), + str(self.var_v), str((self.start_v, self.end_v)), + )) + + def get_parameter_meshes(self): + return self.get_data()[3:] + + def get_meshes(self): + """Return the x,y,z coordinates for plotting the surface. + This function is available for back-compatibility purposes. Consider + using ``get_data()`` instead. + """ + return self.get_data()[:3] + + def get_data(self): + """Return arrays of coordinates for plotting. + + Returns + ======= + x : np.ndarray [n2 x n1] + x-coordinates. + y : np.ndarray [n2 x n1] + y-coordinates. + z : np.ndarray [n2 x n1] + z-coordinates. + mesh_u : np.ndarray [n2 x n1] + Discretized u range. + mesh_v : np.ndarray [n2 x n1] + Discretized v range. + """ + np = import_module('numpy') + + results = self._evaluate() + # mask out complex values + for i, r in enumerate(results): + _re, _im = np.real(r), np.imag(r) + _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan + results[i] = _re + + # TODO: remove this + x, y, z = results[2:] + self._xlim = (np.amin(x), np.amax(x)) + self._ylim = (np.amin(y), np.amax(y)) + self._zlim = (np.amin(z), np.amax(z)) + + return self._apply_transform(*results[2:], *results[:2]) + + +### Contours +class ContourSeries(SurfaceOver2DRangeSeries): + """Representation for a contour plot.""" + + is_3Dsurface = False + is_contour = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_filled = kwargs.get("is_filled", kwargs.get("fill", True)) + self.show_clabels = kwargs.get("clabels", True) + + # NOTE: contour plots are used by plot_contour, plot_vector and + # plot_complex_vector. By implementing contour_kw we are able to + # quickly target the contour plot. + self.rendering_kw = kwargs.get("contour_kw", + kwargs.get("rendering_kw", {})) + + +class GenericDataSeries(BaseSeries): + """Represents generic numerical data. + + Notes + ===== + This class serves the purpose of back-compatibility with the "markers, + annotations, fill, rectangles" keyword arguments that represent + user-provided numerical data. In particular, it solves the problem of + combining together two or more plot-objects with the ``extend`` or + ``append`` methods: user-provided numerical data is also taken into + consideration because it is stored in this series class. + + Also note that the current implementation is far from optimal, as each + keyword argument is stored into an attribute in the ``Plot`` class, which + requires a hard-coded if-statement in the ``MatplotlibBackend`` class. + The implementation suggests that it is ok to add attributes and + if-statements to provide more and more functionalities for user-provided + numerical data (e.g. adding horizontal lines, or vertical lines, or bar + plots, etc). However, in doing so one would reinvent the wheel: plotting + libraries (like Matplotlib) already implements the necessary API. + + Instead of adding more keyword arguments and attributes, users interested + in adding custom numerical data to a plot should retrieve the figure + created by this plotting module. For example, this code: + + .. plot:: + :context: close-figs + :include-source: True + + from sympy import Symbol, plot, cos + x = Symbol("x") + p = plot(cos(x), markers=[{"args": [[0, 1, 2], [0, 1, -1], "*"]}]) + + Becomes: + + .. plot:: + :context: close-figs + :include-source: True + + p = plot(cos(x), backend="matplotlib") + fig, ax = p._backend.fig, p._backend.ax + ax.plot([0, 1, 2], [0, 1, -1], "*") + fig + + Which is far better in terms of readability. Also, it gives access to the + full plotting library capabilities, without the need to reinvent the wheel. + """ + is_generic = True + + def __init__(self, tp, *args, **kwargs): + self.type = tp + self.args = args + self.rendering_kw = kwargs + + def get_data(self): + return self.args + + +class ImplicitSeries(BaseSeries): + """Representation for 2D Implicit plot.""" + + is_implicit = True + use_cm = False + _N = 100 + + def __init__(self, expr, var_start_end_x, var_start_end_y, label="", **kwargs): + super().__init__(**kwargs) + self.adaptive = kwargs.get("adaptive", False) + self.expr = expr + self._label = str(expr) if label is None else label + self._latex_label = latex(expr) if label is None else label + self.ranges = [var_start_end_x, var_start_end_y] + self.var_x, self.start_x, self.end_x = self.ranges[0] + self.var_y, self.start_y, self.end_y = self.ranges[1] + self._color = kwargs.get("color", kwargs.get("line_color", None)) + + if self.is_interactive and self.adaptive: + raise NotImplementedError("Interactive plot with `adaptive=True` " + "is not supported.") + + # Check whether the depth is greater than 4 or less than 0. + depth = kwargs.get("depth", 0) + if depth > 4: + depth = 4 + elif depth < 0: + depth = 0 + self.depth = 4 + depth + self._post_init() + + @property + def expr(self): + if self.adaptive: + return self._adaptive_expr + return self._non_adaptive_expr + + @expr.setter + def expr(self, expr): + self._block_lambda_functions(expr) + # these are needed for adaptive evaluation + expr, has_equality = self._has_equality(sympify(expr)) + self._adaptive_expr = expr + self.has_equality = has_equality + self._label = str(expr) + self._latex_label = latex(expr) + + if isinstance(expr, (BooleanFunction, Ne)) and (not self.adaptive): + self.adaptive = True + msg = "contains Boolean functions. " + if isinstance(expr, Ne): + msg = "is an unequality. " + warnings.warn( + "The provided expression " + msg + + "In order to plot the expression, the algorithm " + + "automatically switched to an adaptive sampling." + ) + + if isinstance(expr, BooleanFunction): + self._non_adaptive_expr = None + self._is_equality = False + else: + # these are needed for uniform meshing evaluation + expr, is_equality = self._preprocess_meshgrid_expression(expr, self.adaptive) + self._non_adaptive_expr = expr + self._is_equality = is_equality + + @property + def line_color(self): + return self._color + + @line_color.setter + def line_color(self, v): + self._color = v + + color = line_color + + def _has_equality(self, expr): + # Represents whether the expression contains an Equality, GreaterThan + # or LessThan + has_equality = False + + def arg_expand(bool_expr): + """Recursively expands the arguments of an Boolean Function""" + for arg in bool_expr.args: + if isinstance(arg, BooleanFunction): + arg_expand(arg) + elif isinstance(arg, Relational): + arg_list.append(arg) + + arg_list = [] + if isinstance(expr, BooleanFunction): + arg_expand(expr) + # Check whether there is an equality in the expression provided. + if any(isinstance(e, (Equality, GreaterThan, LessThan)) for e in arg_list): + has_equality = True + elif not isinstance(expr, Relational): + expr = Equality(expr, 0) + has_equality = True + elif isinstance(expr, (Equality, GreaterThan, LessThan)): + has_equality = True + + return expr, has_equality + + def __str__(self): + f = lambda t: float(t) if len(t.free_symbols) == 0 else t + + return self._str_helper( + "Implicit expression: %s for %s over %s and %s over %s") % ( + str(self._adaptive_expr), + str(self.var_x), + str((f(self.start_x), f(self.end_x))), + str(self.var_y), + str((f(self.start_y), f(self.end_y))), + ) + + def get_data(self): + """Returns numerical data. + + Returns + ======= + + If the series is evaluated with the `adaptive=True` it returns: + + interval_list : list + List of bounding rectangular intervals to be postprocessed and + eventually used with Matplotlib's ``fill`` command. + dummy : str + A string containing ``"fill"``. + + Otherwise, it returns 2D numpy arrays to be used with Matplotlib's + ``contour`` or ``contourf`` commands: + + x_array : np.ndarray + y_array : np.ndarray + z_array : np.ndarray + plot_type : str + A string specifying which plot command to use, ``"contour"`` + or ``"contourf"``. + """ + if self.adaptive: + data = self._adaptive_eval() + if data is not None: + return data + + return self._get_meshes_grid() + + def _adaptive_eval(self): + """ + References + ========== + + .. [1] Jeffrey Allen Tupper. Reliable Two-Dimensional Graphing Methods for + Mathematical Formulae with Two Free Variables. + + .. [2] Jeffrey Allen Tupper. Graphing Equations with Generalized Interval + Arithmetic. Master's thesis. University of Toronto, 1996 + """ + import sympy.plotting.intervalmath.lib_interval as li + + user_functions = {} + printer = IntervalMathPrinter({ + 'fully_qualified_modules': False, 'inline': True, + 'allow_unknown_functions': True, + 'user_functions': user_functions}) + + keys = [t for t in dir(li) if ("__" not in t) and (t not in ["import_module", "interval"])] + vals = [getattr(li, k) for k in keys] + d = dict(zip(keys, vals)) + func = lambdify((self.var_x, self.var_y), self.expr, modules=[d], printer=printer) + data = None + + try: + data = self._get_raster_interval(func) + except NameError as err: + warnings.warn( + "Adaptive meshing could not be applied to the" + " expression, as some functions are not yet implemented" + " in the interval math module:\n\n" + "NameError: %s\n\n" % err + + "Proceeding with uniform meshing." + ) + self.adaptive = False + except TypeError: + warnings.warn( + "Adaptive meshing could not be applied to the" + " expression. Using uniform meshing.") + self.adaptive = False + + return data + + def _get_raster_interval(self, func): + """Uses interval math to adaptively mesh and obtain the plot""" + np = import_module('numpy') + + k = self.depth + interval_list = [] + sx, sy = [float(t) for t in [self.start_x, self.start_y]] + ex, ey = [float(t) for t in [self.end_x, self.end_y]] + # Create initial 32 divisions + xsample = np.linspace(sx, ex, 33) + ysample = np.linspace(sy, ey, 33) + + # Add a small jitter so that there are no false positives for equality. + # Ex: y==x becomes True for x interval(1, 2) and y interval(1, 2) + # which will draw a rectangle. + jitterx = ( + (np.random.rand(len(xsample)) * 2 - 1) + * (ex - sx) + / 2 ** 20 + ) + jittery = ( + (np.random.rand(len(ysample)) * 2 - 1) + * (ey - sy) + / 2 ** 20 + ) + xsample += jitterx + ysample += jittery + + xinter = [interval(x1, x2) for x1, x2 in zip(xsample[:-1], xsample[1:])] + yinter = [interval(y1, y2) for y1, y2 in zip(ysample[:-1], ysample[1:])] + interval_list = [[x, y] for x in xinter for y in yinter] + plot_list = [] + + # recursive call refinepixels which subdivides the intervals which are + # neither True nor False according to the expression. + def refine_pixels(interval_list): + """Evaluates the intervals and subdivides the interval if the + expression is partially satisfied.""" + temp_interval_list = [] + plot_list = [] + for intervals in interval_list: + + # Convert the array indices to x and y values + intervalx = intervals[0] + intervaly = intervals[1] + func_eval = func(intervalx, intervaly) + # The expression is valid in the interval. Change the contour + # array values to 1. + if func_eval[1] is False or func_eval[0] is False: + pass + elif func_eval == (True, True): + plot_list.append([intervalx, intervaly]) + elif func_eval[1] is None or func_eval[0] is None: + # Subdivide + avgx = intervalx.mid + avgy = intervaly.mid + a = interval(intervalx.start, avgx) + b = interval(avgx, intervalx.end) + c = interval(intervaly.start, avgy) + d = interval(avgy, intervaly.end) + temp_interval_list.append([a, c]) + temp_interval_list.append([a, d]) + temp_interval_list.append([b, c]) + temp_interval_list.append([b, d]) + return temp_interval_list, plot_list + + while k >= 0 and len(interval_list): + interval_list, plot_list_temp = refine_pixels(interval_list) + plot_list.extend(plot_list_temp) + k = k - 1 + # Check whether the expression represents an equality + # If it represents an equality, then none of the intervals + # would have satisfied the expression due to floating point + # differences. Add all the undecided values to the plot. + if self.has_equality: + for intervals in interval_list: + intervalx = intervals[0] + intervaly = intervals[1] + func_eval = func(intervalx, intervaly) + if func_eval[1] and func_eval[0] is not False: + plot_list.append([intervalx, intervaly]) + return plot_list, "fill" + + def _get_meshes_grid(self): + """Generates the mesh for generating a contour. + + In the case of equality, ``contour`` function of matplotlib can + be used. In other cases, matplotlib's ``contourf`` is used. + """ + np = import_module('numpy') + + xarray, yarray, z_grid = self._evaluate() + _re, _im = np.real(z_grid), np.imag(z_grid) + _re[np.invert(np.isclose(_im, np.zeros_like(_im)))] = np.nan + if self._is_equality: + return xarray, yarray, _re, 'contour' + return xarray, yarray, _re, 'contourf' + + @staticmethod + def _preprocess_meshgrid_expression(expr, adaptive): + """If the expression is a Relational, rewrite it as a single + expression. + + Returns + ======= + + expr : Expr + The rewritten expression + + equality : Boolean + Whether the original expression was an Equality or not. + """ + equality = False + if isinstance(expr, Equality): + expr = expr.lhs - expr.rhs + equality = True + elif isinstance(expr, Relational): + expr = expr.gts - expr.lts + elif not adaptive: + raise NotImplementedError( + "The expression is not supported for " + "plotting in uniform meshed plot." + ) + return expr, equality + + def get_label(self, use_latex=False, wrapper="$%s$"): + """Return the label to be used to display the expression. + + Parameters + ========== + use_latex : bool + If False, the string representation of the expression is returned. + If True, the latex representation is returned. + wrapper : str + The backend might need the latex representation to be wrapped by + some characters. Default to ``"$%s$"``. + + Returns + ======= + label : str + """ + if use_latex is False: + return self._label + if self._label == str(self._adaptive_expr): + return self._get_wrapped_label(self._latex_label, wrapper) + return self._latex_label + + +############################################################################## +# Finding the centers of line segments or mesh faces +############################################################################## + +def centers_of_segments(array): + np = import_module('numpy') + return np.mean(np.vstack((array[:-1], array[1:])), 0) + + +def centers_of_faces(array): + np = import_module('numpy') + return np.mean(np.dstack((array[:-1, :-1], + array[1:, :-1], + array[:-1, 1:], + array[:-1, :-1], + )), 2) + + +def flat(x, y, z, eps=1e-3): + """Checks whether three points are almost collinear""" + np = import_module('numpy') + # Workaround plotting piecewise (#8577) + vector_a = (x - y).astype(float) + vector_b = (z - y).astype(float) + dot_product = np.dot(vector_a, vector_b) + vector_a_norm = np.linalg.norm(vector_a) + vector_b_norm = np.linalg.norm(vector_b) + cos_theta = dot_product / (vector_a_norm * vector_b_norm) + return abs(cos_theta + 1) < eps + + +def _set_discretization_points(kwargs, pt): + """Allow the use of the keyword arguments ``n, n1, n2`` to + specify the number of discretization points in one and two + directions, while keeping back-compatibility with older keyword arguments + like, ``nb_of_points, nb_of_points_*, points``. + + Parameters + ========== + + kwargs : dict + Dictionary of keyword arguments passed into a plotting function. + pt : type + The type of the series, which indicates the kind of plot we are + trying to create. + """ + replace_old_keywords = { + "nb_of_points": "n", + "nb_of_points_x": "n1", + "nb_of_points_y": "n2", + "nb_of_points_u": "n1", + "nb_of_points_v": "n2", + "points": "n" + } + for k, v in replace_old_keywords.items(): + if k in kwargs.keys(): + kwargs[v] = kwargs.pop(k) + + if pt in [LineOver1DRangeSeries, Parametric2DLineSeries, + Parametric3DLineSeries]: + if "n" in kwargs.keys(): + kwargs["n1"] = kwargs["n"] + if hasattr(kwargs["n"], "__iter__") and (len(kwargs["n"]) > 0): + kwargs["n1"] = kwargs["n"][0] + elif pt in [SurfaceOver2DRangeSeries, ContourSeries, + ParametricSurfaceSeries, ImplicitSeries]: + if "n" in kwargs.keys(): + if hasattr(kwargs["n"], "__iter__") and (len(kwargs["n"]) > 1): + kwargs["n1"] = kwargs["n"][0] + kwargs["n2"] = kwargs["n"][1] + else: + kwargs["n1"] = kwargs["n2"] = kwargs["n"] + return kwargs diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/__init__.py b/phivenv/Lib/site-packages/sympy/plotting/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ddcfd1595ad64a30d205ae4a812a84e3bbe2d032 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_experimental_lambdify.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_experimental_lambdify.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61c318ecd6e45bb0562d415f38296407aabb71bb Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_experimental_lambdify.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_plot.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_plot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f764ceab251cefd2217602c64bb9719008a7361 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_plot.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_plot_implicit.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_plot_implicit.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ea420752c369c2a807222b563fce020b598df90 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_plot_implicit.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_series.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_series.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b10fa4efbb76d365c1f01e15c64aa2e8b817f08 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_series.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_textplot.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_textplot.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fed69af7a639beb829b4a9155e7b388541da7027 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_textplot.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_utils.cpython-39.pyc b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b675e5216cb08748466d60b8f41986a40fe7853 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/tests/__pycache__/test_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/test_experimental_lambdify.py b/phivenv/Lib/site-packages/sympy/plotting/tests/test_experimental_lambdify.py new file mode 100644 index 0000000000000000000000000000000000000000..95839d668762be7be94d0de5092594306ceeadbd --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/tests/test_experimental_lambdify.py @@ -0,0 +1,77 @@ +from sympy.core.symbol import symbols, Symbol +from sympy.functions import Max +from sympy.plotting.experimental_lambdify import experimental_lambdify +from sympy.plotting.intervalmath.interval_arithmetic import \ + interval, intervalMembership + + +# Tests for exception handling in experimental_lambdify +def test_experimental_lambify(): + x = Symbol('x') + f = experimental_lambdify([x], Max(x, 5)) + # XXX should f be tested? If f(2) is attempted, an + # error is raised because a complex produced during wrapping of the arg + # is being compared with an int. + assert Max(2, 5) == 5 + assert Max(5, 7) == 7 + + x = Symbol('x-3') + f = experimental_lambdify([x], x + 1) + assert f(1) == 2 + + +def test_composite_boolean_region(): + x, y = symbols('x y') + + r1 = (x - 1)**2 + y**2 < 2 + r2 = (x + 1)**2 + y**2 < 2 + + f = experimental_lambdify((x, y), r1 & r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(False, True) + + f = experimental_lambdify((x, y), r1 | r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(False, True) + + f = experimental_lambdify((x, y), r1 & ~r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(False, True) + + f = experimental_lambdify((x, y), ~r1 & r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(False, True) + + f = experimental_lambdify((x, y), ~r1 & ~r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(True, True) diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/test_plot.py b/phivenv/Lib/site-packages/sympy/plotting/tests/test_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..e5246c38a19552222aa62720d3f5e9e320344662 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/tests/test_plot.py @@ -0,0 +1,1344 @@ +import os +from tempfile import TemporaryDirectory +import pytest +from sympy.concrete.summations import Sum +from sympy.core.numbers import (I, oo, pi) +from sympy.core.relational import Ne +from sympy.core.symbol import Symbol, symbols +from sympy.functions.elementary.exponential import (LambertW, exp, exp_polar, log) +from sympy.functions.elementary.miscellaneous import (real_root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.elementary.miscellaneous import Min +from sympy.functions.special.hyper import meijerg +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import And +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.external import import_module +from sympy.plotting.plot import ( + Plot, plot, plot_parametric, plot3d_parametric_line, plot3d, + plot3d_parametric_surface) +from sympy.plotting.plot import ( + unset_show, plot_contour, PlotGrid, MatplotlibBackend, TextBackend) +from sympy.plotting.series import ( + LineOver1DRangeSeries, Parametric2DLineSeries, Parametric3DLineSeries, + ParametricSurfaceSeries, SurfaceOver2DRangeSeries) +from sympy.testing.pytest import skip, skip_under_pyodide, warns, raises, warns_deprecated_sympy +from sympy.utilities import lambdify as lambdify_ +from sympy.utilities.exceptions import ignore_warnings + +unset_show() + + +matplotlib = import_module( + 'matplotlib', min_module_version='1.1.0', catch=(RuntimeError,)) + + +class DummyBackendNotOk(Plot): + """ Used to verify if users can create their own backends. + This backend is meant to raise NotImplementedError for methods `show`, + `save`, `close`. + """ + def __new__(cls, *args, **kwargs): + return object.__new__(cls) + + +class DummyBackendOk(Plot): + """ Used to verify if users can create their own backends. + This backend is meant to pass all tests. + """ + def __new__(cls, *args, **kwargs): + return object.__new__(cls) + + def show(self): + pass + + def save(self): + pass + + def close(self): + pass + +def test_basic_plotting_backend(): + x = Symbol('x') + plot(x, (x, 0, 3), backend='text') + plot(x**2 + 1, (x, 0, 3), backend='text') + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_1(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + ### + # Examples from the 'introduction' notebook + ### + p = plot(x, legend=True, label='f1', adaptive=adaptive, n=10) + p = plot(x*sin(x), x*cos(x), label='f2', adaptive=adaptive, n=10) + p.extend(p) + p[0].line_color = lambda a: a + p[1].line_color = 'b' + p.title = 'Big title' + p.xlabel = 'the x axis' + p[1].label = 'straight line' + p.legend = True + p.aspect_ratio = (1, 1) + p.xlim = (-15, 20) + filename = 'test_basic_options_and_colors.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p.extend(plot(x + 1, adaptive=adaptive, n=10)) + p.append(plot(x + 3, x**2, adaptive=adaptive, n=10)[1]) + filename = 'test_plot_extend_append.png' + p.save(os.path.join(tmpdir, filename)) + + p[2] = plot(x**2, (x, -2, 3), adaptive=adaptive, n=10) + filename = 'test_plot_setitem.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(sin(x), (x, -2*pi, 4*pi), adaptive=adaptive, n=10) + filename = 'test_line_explicit.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(sin(x), adaptive=adaptive, n=10) + filename = 'test_line_default_range.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot((x**2, (x, -5, 5)), (x**3, (x, -3, 3)), adaptive=adaptive, n=10) + filename = 'test_line_multiple_range.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + raises(ValueError, lambda: plot(x, y)) + + #Piecewise plots + p = plot(Piecewise((1, x > 0), (0, True)), (x, -1, 1), adaptive=adaptive, n=10) + filename = 'test_plot_piecewise.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(Piecewise((x, x < 1), (x**2, True)), (x, -3, 3), adaptive=adaptive, n=10) + filename = 'test_plot_piecewise_2.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # test issue 7471 + p1 = plot(x, adaptive=adaptive, n=10) + p2 = plot(3, adaptive=adaptive, n=10) + p1.extend(p2) + filename = 'test_horizontal_line.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # test issue 10925 + f = Piecewise((-1, x < -1), (x, And(-1 <= x, x < 0)), \ + (x**2, And(0 <= x, x < 1)), (x**3, x >= 1)) + p = plot(f, (x, -3, 3), adaptive=adaptive, n=10) + filename = 'test_plot_piecewise_3.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_2(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + z = Symbol('z') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + #parametric 2d plots. + #Single plot with default range. + p = plot_parametric(sin(x), cos(x), adaptive=adaptive, n=10) + filename = 'test_parametric.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #Single plot with range. + p = plot_parametric( + sin(x), cos(x), (x, -5, 5), legend=True, label='parametric_plot', + adaptive=adaptive, n=10) + filename = 'test_parametric_range.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #Multiple plots with same range. + p = plot_parametric((sin(x), cos(x)), (x, sin(x)), + adaptive=adaptive, n=10) + filename = 'test_parametric_multiple.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #Multiple plots with different ranges. + p = plot_parametric( + (sin(x), cos(x), (x, -3, 3)), (x, sin(x), (x, -5, 5)), + adaptive=adaptive, n=10) + filename = 'test_parametric_multiple_ranges.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #depth of recursion specified. + p = plot_parametric(x, sin(x), depth=13, + adaptive=adaptive, n=10) + filename = 'test_recursion_depth.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #No adaptive sampling. + p = plot_parametric(cos(x), sin(x), adaptive=False, n=500) + filename = 'test_adaptive.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #3d parametric plots + p = plot3d_parametric_line( + sin(x), cos(x), x, legend=True, label='3d_parametric_plot', + adaptive=adaptive, n=10) + filename = 'test_3d_line.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d_parametric_line( + (sin(x), cos(x), x, (x, -5, 5)), (cos(x), sin(x), x, (x, -3, 3)), + adaptive=adaptive, n=10) + filename = 'test_3d_line_multiple.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d_parametric_line(sin(x), cos(x), x, n=30, + adaptive=adaptive) + filename = 'test_3d_line_points.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # 3d surface single plot. + p = plot3d(x * y, adaptive=adaptive, n=10) + filename = 'test_surface.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple 3D plots with same range. + p = plot3d(-x * y, x * y, (x, -5, 5), adaptive=adaptive, n=10) + filename = 'test_surface_multiple.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple 3D plots with different ranges. + p = plot3d( + (x * y, (x, -3, 3), (y, -3, 3)), (-x * y, (x, -3, 3), (y, -3, 3)), + adaptive=adaptive, n=10) + filename = 'test_surface_multiple_ranges.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Single Parametric 3D plot + p = plot3d_parametric_surface(sin(x + y), cos(x - y), x - y, + adaptive=adaptive, n=10) + filename = 'test_parametric_surface.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple Parametric 3D plots. + p = plot3d_parametric_surface( + (x*sin(z), x*cos(z), z, (x, -5, 5), (z, -5, 5)), + (sin(x + y), cos(x - y), x - y, (x, -5, 5), (y, -5, 5)), + adaptive=adaptive, n=10) + filename = 'test_parametric_surface.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Single Contour plot. + p = plot_contour(sin(x)*sin(y), (x, -5, 5), (y, -5, 5), + adaptive=adaptive, n=10) + filename = 'test_contour_plot.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple Contour plots with same range. + p = plot_contour(x**2 + y**2, x**3 + y**3, (x, -5, 5), (y, -5, 5), + adaptive=adaptive, n=10) + filename = 'test_contour_plot.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple Contour plots with different range. + p = plot_contour( + (x**2 + y**2, (x, -5, 5), (y, -5, 5)), + (x**3 + y**3, (x, -3, 3), (y, -3, 3)), + adaptive=adaptive, n=10) + filename = 'test_contour_plot.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_3(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + z = Symbol('z') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + ### + # Examples from the 'colors' notebook + ### + + p = plot(sin(x), adaptive=adaptive, n=10) + p[0].line_color = lambda a: a + filename = 'test_colors_line_arity1.png' + p.save(os.path.join(tmpdir, filename)) + + p[0].line_color = lambda a, b: b + filename = 'test_colors_line_arity2.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(x*sin(x), x*cos(x), (x, 0, 10), adaptive=adaptive, n=10) + p[0].line_color = lambda a: a + filename = 'test_colors_param_line_arity1.png' + p.save(os.path.join(tmpdir, filename)) + + p[0].line_color = lambda a, b: a + filename = 'test_colors_param_line_arity1.png' + p.save(os.path.join(tmpdir, filename)) + + p[0].line_color = lambda a, b: b + filename = 'test_colors_param_line_arity2b.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d_parametric_line( + sin(x) + 0.1*sin(x)*cos(7*x), + cos(x) + 0.1*cos(x)*cos(7*x), + 0.1*sin(7*x), + (x, 0, 2*pi), adaptive=adaptive, n=10) + p[0].line_color = lambdify_(x, sin(4*x)) + filename = 'test_colors_3d_line_arity1.png' + p.save(os.path.join(tmpdir, filename)) + p[0].line_color = lambda a, b: b + filename = 'test_colors_3d_line_arity2.png' + p.save(os.path.join(tmpdir, filename)) + p[0].line_color = lambda a, b, c: c + filename = 'test_colors_3d_line_arity3.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d(sin(x)*y, (x, 0, 6*pi), (y, -5, 5), adaptive=adaptive, n=10) + p[0].surface_color = lambda a: a + filename = 'test_colors_surface_arity1.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambda a, b: b + filename = 'test_colors_surface_arity2.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambda a, b, c: c + filename = 'test_colors_surface_arity3a.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambdify_((x, y, z), sqrt((x - 3*pi)**2 + y**2)) + filename = 'test_colors_surface_arity3b.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d_parametric_surface(x * cos(4 * y), x * sin(4 * y), y, + (x, -1, 1), (y, -1, 1), adaptive=adaptive, n=10) + p[0].surface_color = lambda a: a + filename = 'test_colors_param_surf_arity1.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambda a, b: a*b + filename = 'test_colors_param_surf_arity2.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambdify_((x, y, z), sqrt(x**2 + y**2 + z**2)) + filename = 'test_colors_param_surf_arity3.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True]) +def test_plot_and_save_4(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + + ### + # Examples from the 'advanced' notebook + ### + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + i = Integral(log((sin(x)**2 + 1)*sqrt(x**2 + 1)), (x, 0, y)) + p = plot(i, (y, 1, 5), adaptive=adaptive, n=10, force_real_eval=True) + filename = 'test_advanced_integral.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_5(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + s = Sum(1/x**y, (x, 1, oo)) + p = plot(s, (y, 2, 10), adaptive=adaptive, n=10) + filename = 'test_advanced_inf_sum.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(Sum(1/x, (x, 1, y)), (y, 2, 10), show=False, + adaptive=adaptive, n=10) + p[0].only_integers = True + p[0].steps = True + filename = 'test_advanced_fin_sum.png' + + # XXX: This should be fixed in experimental_lambdify or by using + # ordinary lambdify so that it doesn't warn. The error results from + # passing an array of values as the integration limit. + # + # UserWarning: The evaluation of the expression is problematic. We are + # trying a failback method that may still work. Please report this as a + # bug. + with ignore_warnings(UserWarning): + p.save(os.path.join(tmpdir, filename)) + + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_6(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + filename = 'test.png' + ### + # Test expressions that can not be translated to np and generate complex + # results. + ### + p = plot(sin(x) + I*cos(x)) + p.save(os.path.join(tmpdir, filename)) + + with ignore_warnings(RuntimeWarning): + p = plot(sqrt(sqrt(-x))) + p.save(os.path.join(tmpdir, filename)) + + p = plot(LambertW(x)) + p.save(os.path.join(tmpdir, filename)) + p = plot(sqrt(LambertW(x))) + p.save(os.path.join(tmpdir, filename)) + + #Characteristic function of a StudentT distribution with nu=10 + x1 = 5 * x**2 * exp_polar(-I*pi)/2 + m1 = meijerg(((1 / 2,), ()), ((5, 0, 1 / 2), ()), x1) + x2 = 5*x**2 * exp_polar(I*pi)/2 + m2 = meijerg(((1/2,), ()), ((5, 0, 1/2), ()), x2) + expr = (m1 + m2) / (48 * pi) + with warns( + UserWarning, + match="The evaluation with NumPy/SciPy failed", + test_stacklevel=False, + ): + p = plot(expr, (x, 1e-6, 1e-2), adaptive=adaptive, n=10) + p.save(os.path.join(tmpdir, filename)) + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plotgrid_and_save(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + p1 = plot(x, adaptive=adaptive, n=10) + p2 = plot_parametric((sin(x), cos(x)), (x, sin(x)), show=False, + adaptive=adaptive, n=10) + p3 = plot_parametric( + cos(x), sin(x), adaptive=adaptive, n=10, show=False) + p4 = plot3d_parametric_line(sin(x), cos(x), x, show=False, + adaptive=adaptive, n=10) + # symmetric grid + p = PlotGrid(2, 2, p1, p2, p3, p4) + filename = 'test_grid1.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # grid size greater than the number of subplots + p = PlotGrid(3, 4, p1, p2, p3, p4) + filename = 'test_grid2.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p5 = plot(cos(x),(x, -pi, pi), show=False, adaptive=adaptive, n=10) + p5[0].line_color = lambda a: a + p6 = plot(Piecewise((1, x > 0), (0, True)), (x, -1, 1), show=False, + adaptive=adaptive, n=10) + p7 = plot_contour( + (x**2 + y**2, (x, -5, 5), (y, -5, 5)), + (x**3 + y**3, (x, -3, 3), (y, -3, 3)), show=False, + adaptive=adaptive, n=10) + # unsymmetric grid (subplots in one line) + p = PlotGrid(1, 3, p5, p6, p7) + filename = 'test_grid3.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_append_issue_7140(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p1 = plot(x, adaptive=adaptive, n=10) + p2 = plot(x**2, adaptive=adaptive, n=10) + plot(x + 2, adaptive=adaptive, n=10) + + # append a series + p2.append(p1[0]) + assert len(p2._series) == 2 + + with raises(TypeError): + p1.append(p2) + + with raises(TypeError): + p1.append(p2._series) + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_15265(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + eqn = sin(x) + + p = plot(eqn, xlim=(-S.Pi, S.Pi), ylim=(-1, 1), adaptive=adaptive, n=10) + p._backend.close() + + p = plot(eqn, xlim=(-1, 1), ylim=(-S.Pi, S.Pi), adaptive=adaptive, n=10) + p._backend.close() + + p = plot(eqn, xlim=(-1, 1), adaptive=adaptive, n=10, + ylim=(sympify('-3.14'), sympify('3.14'))) + p._backend.close() + + p = plot(eqn, adaptive=adaptive, n=10, + xlim=(sympify('-3.14'), sympify('3.14')), ylim=(-1, 1)) + p._backend.close() + + raises(ValueError, + lambda: plot(eqn, adaptive=adaptive, n=10, + xlim=(-S.ImaginaryUnit, 1), ylim=(-1, 1))) + + raises(ValueError, + lambda: plot(eqn, adaptive=adaptive, n=10, + xlim=(-1, 1), ylim=(-1, S.ImaginaryUnit))) + + raises(ValueError, + lambda: plot(eqn, adaptive=adaptive, n=10, + xlim=(S.NegativeInfinity, 1), ylim=(-1, 1))) + + raises(ValueError, + lambda: plot(eqn, adaptive=adaptive, n=10, + xlim=(-1, 1), ylim=(-1, S.Infinity))) + + +def test_empty_Plot(): + if not matplotlib: + skip("Matplotlib not the default backend") + + # No exception showing an empty plot + plot() + # Plot is only a base class: doesn't implement any logic for showing + # images + p = Plot() + raises(NotImplementedError, lambda: p.show()) + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_17405(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + f = x**0.3 - 10*x**3 + x**2 + p = plot(f, (x, -10, 10), adaptive=adaptive, n=30, show=False) + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + + # RuntimeWarning: invalid value encountered in double_scalars + with ignore_warnings(RuntimeWarning): + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_logplot_PR_16796(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(x, (x, .001, 100), adaptive=adaptive, n=30, + xscale='log', show=False) + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + assert len(p[0].get_data()[0]) >= 30 + assert p[0].end == 100.0 + assert p[0].start == .001 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_16572(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(LambertW(x), show=False, adaptive=adaptive, n=30) + # Random number of segments, probably more than 50, but we want to see + # that there are segments generated, as opposed to when the bug was present + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_11865(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + k = Symbol('k', integer=True) + f = Piecewise((-I*exp(I*pi*k)/k + I*exp(-I*pi*k)/k, Ne(k, 0)), (2*pi, True)) + p = plot(f, show=False, adaptive=adaptive, n=30) + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + # and that there are no exceptions. + assert len(p[0].get_data()[0]) >= 30 + + +@skip_under_pyodide("Warnings not emitted in Pyodide because of lack of WASM fp exception support") +def test_issue_11461(): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(real_root((log(x/(x-2))), 3), show=False, adaptive=True) + with warns( + RuntimeWarning, + match="invalid value encountered in", + test_stacklevel=False, + ): + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + # and that there are no exceptions. + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_11764(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot_parametric(cos(x), sin(x), (x, 0, 2 * pi), + aspect_ratio=(1,1), show=False, adaptive=adaptive, n=30) + assert p.aspect_ratio == (1, 1) + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_13516(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + pm = plot(sin(x), backend="matplotlib", show=False, adaptive=adaptive, n=30) + assert pm.backend == MatplotlibBackend + assert len(pm[0].get_data()[0]) >= 30 + + pt = plot(sin(x), backend="text", show=False, adaptive=adaptive, n=30) + assert pt.backend == TextBackend + assert len(pt[0].get_data()[0]) >= 30 + + pd = plot(sin(x), backend="default", show=False, adaptive=adaptive, n=30) + assert pd.backend == MatplotlibBackend + assert len(pd[0].get_data()[0]) >= 30 + + p = plot(sin(x), show=False, adaptive=adaptive, n=30) + assert p.backend == MatplotlibBackend + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_limits(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(x, x**2, (x, -10, 10), adaptive=adaptive, n=10) + backend = p._backend + + xmin, xmax = backend.ax.get_xlim() + assert abs(xmin + 10) < 2 + assert abs(xmax - 10) < 2 + ymin, ymax = backend.ax.get_ylim() + assert abs(ymin + 10) < 10 + assert abs(ymax - 100) < 10 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot3d_parametric_line_limits(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + v1 = (2*cos(x), 2*sin(x), 2*x, (x, -5, 5)) + v2 = (sin(x), cos(x), x, (x, -5, 5)) + p = plot3d_parametric_line(v1, v2, adaptive=adaptive, n=60) + backend = p._backend + + xmin, xmax = backend.ax.get_xlim() + assert abs(xmin + 2) < 1e-2 + assert abs(xmax - 2) < 1e-2 + ymin, ymax = backend.ax.get_ylim() + assert abs(ymin + 2) < 1e-2 + assert abs(ymax - 2) < 1e-2 + zmin, zmax = backend.ax.get_zlim() + assert abs(zmin + 10) < 1e-2 + assert abs(zmax - 10) < 1e-2 + + p = plot3d_parametric_line(v2, v1, adaptive=adaptive, n=60) + backend = p._backend + + xmin, xmax = backend.ax.get_xlim() + assert abs(xmin + 2) < 1e-2 + assert abs(xmax - 2) < 1e-2 + ymin, ymax = backend.ax.get_ylim() + assert abs(ymin + 2) < 1e-2 + assert abs(ymax - 2) < 1e-2 + zmin, zmax = backend.ax.get_zlim() + assert abs(zmin + 10) < 1e-2 + assert abs(zmax - 10) < 1e-2 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_size(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + p1 = plot(sin(x), backend="matplotlib", size=(8, 4), + adaptive=adaptive, n=10) + s1 = p1._backend.fig.get_size_inches() + assert (s1[0] == 8) and (s1[1] == 4) + p2 = plot(sin(x), backend="matplotlib", size=(5, 10), + adaptive=adaptive, n=10) + s2 = p2._backend.fig.get_size_inches() + assert (s2[0] == 5) and (s2[1] == 10) + p3 = PlotGrid(2, 1, p1, p2, size=(6, 2), + adaptive=adaptive, n=10) + s3 = p3._backend.fig.get_size_inches() + assert (s3[0] == 6) and (s3[1] == 2) + + with raises(ValueError): + plot(sin(x), backend="matplotlib", size=(-1, 3)) + + +def test_issue_20113(): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + # verify the capability to use custom backends + plot(sin(x), backend=Plot, show=False) + p2 = plot(sin(x), backend=MatplotlibBackend, show=False) + assert p2.backend == MatplotlibBackend + assert len(p2[0].get_data()[0]) >= 30 + p3 = plot(sin(x), backend=DummyBackendOk, show=False) + assert p3.backend == DummyBackendOk + assert len(p3[0].get_data()[0]) >= 30 + + # test for an improper coded backend + p4 = plot(sin(x), backend=DummyBackendNotOk, show=False) + assert p4.backend == DummyBackendNotOk + assert len(p4[0].get_data()[0]) >= 30 + with raises(NotImplementedError): + p4.show() + with raises(NotImplementedError): + p4.save("test/path") + with raises(NotImplementedError): + p4._backend.close() + + +def test_custom_coloring(): + x = Symbol('x') + y = Symbol('y') + plot(cos(x), line_color=lambda a: a) + plot(cos(x), line_color=1) + plot(cos(x), line_color="r") + plot_parametric(cos(x), sin(x), line_color=lambda a: a) + plot_parametric(cos(x), sin(x), line_color=1) + plot_parametric(cos(x), sin(x), line_color="r") + plot3d_parametric_line(cos(x), sin(x), x, line_color=lambda a: a) + plot3d_parametric_line(cos(x), sin(x), x, line_color=1) + plot3d_parametric_line(cos(x), sin(x), x, line_color="r") + plot3d_parametric_surface(cos(x + y), sin(x - y), x - y, + (x, -5, 5), (y, -5, 5), + surface_color=lambda a, b: a**2 + b**2) + plot3d_parametric_surface(cos(x + y), sin(x - y), x - y, + (x, -5, 5), (y, -5, 5), + surface_color=1) + plot3d_parametric_surface(cos(x + y), sin(x - y), x - y, + (x, -5, 5), (y, -5, 5), + surface_color="r") + plot3d(x*y, (x, -5, 5), (y, -5, 5), + surface_color=lambda a, b: a**2 + b**2) + plot3d(x*y, (x, -5, 5), (y, -5, 5), surface_color=1) + plot3d(x*y, (x, -5, 5), (y, -5, 5), surface_color="r") + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_deprecated_get_segments(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + f = sin(x) + p = plot(f, (x, -10, 10), show=False, adaptive=adaptive, n=10) + with warns_deprecated_sympy(): + p[0].get_segments() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_generic_data_series(adaptive): + # verify that no errors are raised when generic data series are used + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol("x") + p = plot(x, + markers=[{"args":[[0, 1], [0, 1]], "marker": "*", "linestyle": "none"}], + annotations=[{"text": "test", "xy": (0, 0)}], + fill={"x": [0, 1, 2, 3], "y1": [0, 1, 2, 3]}, + rectangles=[{"xy": (0, 0), "width": 5, "height": 1}], + adaptive=adaptive, n=10) + assert len(p._backend.ax.collections) == 1 + assert len(p._backend.ax.patches) == 1 + assert len(p._backend.ax.lines) == 2 + assert len(p._backend.ax.texts) == 1 + + +def test_deprecated_markers_annotations_rectangles_fill(): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(sin(x), (x, -10, 10), show=False) + with warns_deprecated_sympy(): + p.markers = [{"args":[[0, 1], [0, 1]], "marker": "*", "linestyle": "none"}] + assert len(p._series) == 2 + with warns_deprecated_sympy(): + p.annotations = [{"text": "test", "xy": (0, 0)}] + assert len(p._series) == 3 + with warns_deprecated_sympy(): + p.fill = {"x": [0, 1, 2, 3], "y1": [0, 1, 2, 3]} + assert len(p._series) == 4 + with warns_deprecated_sympy(): + p.rectangles = [{"xy": (0, 0), "width": 5, "height": 1}] + assert len(p._series) == 5 + + +def test_back_compatibility(): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + p = plot(sin(x), adaptive=False, n=5) + assert len(p[0].get_points()) == 2 + assert len(p[0].get_data()) == 2 + p = plot_parametric(cos(x), sin(x), (x, 0, 2), adaptive=False, n=5) + assert len(p[0].get_points()) == 2 + assert len(p[0].get_data()) == 3 + p = plot3d_parametric_line(cos(x), sin(x), x, (x, 0, 2), + adaptive=False, n=5) + assert len(p[0].get_points()) == 3 + assert len(p[0].get_data()) == 4 + p = plot3d(cos(x**2 + y**2), (x, -pi, pi), (y, -pi, pi), n=5) + assert len(p[0].get_meshes()) == 3 + assert len(p[0].get_data()) == 3 + p = plot_contour(cos(x**2 + y**2), (x, -pi, pi), (y, -pi, pi), n=5) + assert len(p[0].get_meshes()) == 3 + assert len(p[0].get_data()) == 3 + p = plot3d_parametric_surface(x * cos(y), x * sin(y), x * cos(4 * y) / 2, + (x, 0, pi), (y, 0, 2*pi), n=5) + assert len(p[0].get_meshes()) == 3 + assert len(p[0].get_data()) == 5 + + +def test_plot_arguments(): + ### Test arguments for plot() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single expressions + p = plot(x + 1) + assert isinstance(p[0], LineOver1DRangeSeries) + assert p[0].expr == x + 1 + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x + 1" + assert p[0].rendering_kw == {} + + # single expressions custom label + p = plot(x + 1, "label") + assert isinstance(p[0], LineOver1DRangeSeries) + assert p[0].expr == x + 1 + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "label" + assert p[0].rendering_kw == {} + + # single expressions with range + p = plot(x + 1, (x, -2, 2)) + assert p[0].ranges == [(x, -2, 2)] + + # single expressions with range, label and rendering-kw dictionary + p = plot(x + 1, (x, -2, 2), "test", {"color": "r"}) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"color": "r"} + + # multiple expressions + p = plot(x + 1, x**2) + assert isinstance(p[0], LineOver1DRangeSeries) + assert p[0].expr == x + 1 + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x + 1" + assert p[0].rendering_kw == {} + assert isinstance(p[1], LineOver1DRangeSeries) + assert p[1].expr == x**2 + assert p[1].ranges == [(x, -10, 10)] + assert p[1].get_label(False) == "x**2" + assert p[1].rendering_kw == {} + + # multiple expressions over the same range + p = plot(x + 1, x**2, (x, 0, 5)) + assert p[0].ranges == [(x, 0, 5)] + assert p[1].ranges == [(x, 0, 5)] + + # multiple expressions over the same range with the same rendering kws + p = plot(x + 1, x**2, (x, 0, 5), {"color": "r"}) + assert p[0].ranges == [(x, 0, 5)] + assert p[1].ranges == [(x, 0, 5)] + assert p[0].rendering_kw == {"color": "r"} + assert p[1].rendering_kw == {"color": "r"} + + # multiple expressions with different ranges, labels and rendering kws + p = plot( + (x + 1, (x, 0, 5)), + (x**2, (x, -2, 2), "test", {"color": "r"})) + assert isinstance(p[0], LineOver1DRangeSeries) + assert p[0].expr == x + 1 + assert p[0].ranges == [(x, 0, 5)] + assert p[0].get_label(False) == "x + 1" + assert p[0].rendering_kw == {} + assert isinstance(p[1], LineOver1DRangeSeries) + assert p[1].expr == x**2 + assert p[1].ranges == [(x, -2, 2)] + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {"color": "r"} + + # single argument: lambda function + f = lambda t: t + p = plot(lambda t: t) + assert isinstance(p[0], LineOver1DRangeSeries) + assert callable(p[0].expr) + assert p[0].ranges[0][1:] == (-10, 10) + assert p[0].get_label(False) == "" + assert p[0].rendering_kw == {} + + # single argument: lambda function + custom range and label + p = plot(f, ("t", -5, 6), "test") + assert p[0].ranges[0][1:] == (-5, 6) + assert p[0].get_label(False) == "test" + + +def test_plot_parametric_arguments(): + ### Test arguments for plot_parametric() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single parametric expression + p = plot_parametric(x + 1, x) + assert isinstance(p[0], Parametric2DLineSeries) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + + # single parametric expression with custom range, label and rendering kws + p = plot_parametric(x + 1, x, (x, -2, 2), "test", + {"cmap": "Reds"}) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"cmap": "Reds"} + + p = plot_parametric((x + 1, x), (x, -2, 2), "test") + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + # multiple parametric expressions same symbol + p = plot_parametric((x + 1, x), (x ** 2, x + 1)) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, x + 1) + assert p[1].ranges == [(x, -10, 10)] + assert p[1].get_label(False) == "x" + assert p[1].rendering_kw == {} + + # multiple parametric expressions different symbols + p = plot_parametric((x + 1, x), (y ** 2, y + 1, "test")) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (y ** 2, y + 1) + assert p[1].ranges == [(y, -10, 10)] + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {} + + # multiple parametric expressions same range + p = plot_parametric((x + 1, x), (x ** 2, x + 1), (x, -2, 2)) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, x + 1) + assert p[1].ranges == [(x, -2, 2)] + assert p[1].get_label(False) == "x" + assert p[1].rendering_kw == {} + + # multiple parametric expressions, custom ranges and labels + p = plot_parametric( + (x + 1, x, (x, -2, 2), "test1"), + (x ** 2, x + 1, (x, -3, 3), "test2", {"cmap": "Reds"})) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test1" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, x + 1) + assert p[1].ranges == [(x, -3, 3)] + assert p[1].get_label(False) == "test2" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # single argument: lambda function + fx = lambda t: t + fy = lambda t: 2 * t + p = plot_parametric(fx, fy) + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (-10, 10) + assert "Dummy" in p[0].get_label(False) + assert p[0].rendering_kw == {} + + # single argument: lambda function + custom range + label + p = plot_parametric(fx, fy, ("t", 0, 2), "test") + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (0, 2) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + +def test_plot3d_parametric_line_arguments(): + ### Test arguments for plot3d_parametric_line() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single parametric expression + p = plot3d_parametric_line(x + 1, x, sin(x)) + assert isinstance(p[0], Parametric3DLineSeries) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + + # single parametric expression with custom range, label and rendering kws + p = plot3d_parametric_line(x + 1, x, sin(x), (x, -2, 2), + "test", {"cmap": "Reds"}) + assert isinstance(p[0], Parametric3DLineSeries) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"cmap": "Reds"} + + p = plot3d_parametric_line((x + 1, x, sin(x)), (x, -2, 2), "test") + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + # multiple parametric expression same symbol + p = plot3d_parametric_line( + (x + 1, x, sin(x)), (x ** 2, 1, cos(x), {"cmap": "Reds"})) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, 1, cos(x)) + assert p[1].ranges == [(x, -10, 10)] + assert p[1].get_label(False) == "x" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # multiple parametric expression different symbols + p = plot3d_parametric_line((x + 1, x, sin(x)), (y ** 2, 1, cos(y))) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (y ** 2, 1, cos(y)) + assert p[1].ranges == [(y, -10, 10)] + assert p[1].get_label(False) == "y" + assert p[1].rendering_kw == {} + + # multiple parametric expression, custom ranges and labels + p = plot3d_parametric_line( + (x + 1, x, sin(x)), + (x ** 2, 1, cos(x), (x, -2, 2), "test", {"cmap": "Reds"})) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, 1, cos(x)) + assert p[1].ranges == [(x, -2, 2)] + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # single argument: lambda function + fx = lambda t: t + fy = lambda t: 2 * t + fz = lambda t: 3 * t + p = plot3d_parametric_line(fx, fy, fz) + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (-10, 10) + assert "Dummy" in p[0].get_label(False) + assert p[0].rendering_kw == {} + + # single argument: lambda function + custom range + label + p = plot3d_parametric_line(fx, fy, fz, ("t", 0, 2), "test") + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (0, 2) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + +def test_plot3d_plot_contour_arguments(): + ### Test arguments for plot3d() and plot_contour() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single expression + p = plot3d(x + y) + assert isinstance(p[0], SurfaceOver2DRangeSeries) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[0].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[0].get_label(False) == "x + y" + assert p[0].rendering_kw == {} + + # single expression, custom range, label and rendering kws + p = plot3d(x + y, (x, -2, 2), "test", {"cmap": "Reds"}) + assert isinstance(p[0], SurfaceOver2DRangeSeries) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -10, 10) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"cmap": "Reds"} + + p = plot3d(x + y, (x, -2, 2), (y, -4, 4), "test") + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -4, 4) + + # multiple expressions + p = plot3d(x + y, x * y) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[0].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[0].get_label(False) == "x + y" + assert p[0].rendering_kw == {} + assert p[1].expr == x * y + assert p[1].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[1].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[1].get_label(False) == "x*y" + assert p[1].rendering_kw == {} + + # multiple expressions, same custom ranges + p = plot3d(x + y, x * y, (x, -2, 2), (y, -4, 4)) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -4, 4) + assert p[0].get_label(False) == "x + y" + assert p[0].rendering_kw == {} + assert p[1].expr == x * y + assert p[1].ranges[0] == (x, -2, 2) + assert p[1].ranges[1] == (y, -4, 4) + assert p[1].get_label(False) == "x*y" + assert p[1].rendering_kw == {} + + # multiple expressions, custom ranges, labels and rendering kws + p = plot3d( + (x + y, (x, -2, 2), (y, -4, 4)), + (x * y, (x, -3, 3), (y, -6, 6), "test", {"cmap": "Reds"})) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -4, 4) + assert p[0].get_label(False) == "x + y" + assert p[0].rendering_kw == {} + assert p[1].expr == x * y + assert p[1].ranges[0] == (x, -3, 3) + assert p[1].ranges[1] == (y, -6, 6) + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # single expression: lambda function + f = lambda x, y: x + y + p = plot3d(f) + assert callable(p[0].expr) + assert p[0].ranges[0][1:] == (-10, 10) + assert p[0].ranges[1][1:] == (-10, 10) + assert p[0].get_label(False) == "" + assert p[0].rendering_kw == {} + + # single expression: lambda function + custom ranges + label + p = plot3d(f, ("a", -5, 3), ("b", -2, 1), "test") + assert callable(p[0].expr) + assert p[0].ranges[0][1:] == (-5, 3) + assert p[0].ranges[1][1:] == (-2, 1) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + # test issue 25818 + # single expression, custom range, min/max functions + p = plot3d(Min(x, y), (x, 0, 10), (y, 0, 10)) + assert isinstance(p[0], SurfaceOver2DRangeSeries) + assert p[0].expr == Min(x, y) + assert p[0].ranges[0] == (x, 0, 10) + assert p[0].ranges[1] == (y, 0, 10) + assert p[0].get_label(False) == "Min(x, y)" + assert p[0].rendering_kw == {} + + +def test_plot3d_parametric_surface_arguments(): + ### Test arguments for plot3d_parametric_surface() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single parametric expression + p = plot3d_parametric_surface(x + y, cos(x + y), sin(x + y)) + assert isinstance(p[0], ParametricSurfaceSeries) + assert p[0].expr == (x + y, cos(x + y), sin(x + y)) + assert p[0].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[0].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[0].get_label(False) == "(x + y, cos(x + y), sin(x + y))" + assert p[0].rendering_kw == {} + + # single parametric expression, custom ranges, labels and rendering kws + p = plot3d_parametric_surface(x + y, cos(x + y), sin(x + y), + (x, -2, 2), (y, -4, 4), "test", {"cmap": "Reds"}) + assert isinstance(p[0], ParametricSurfaceSeries) + assert p[0].expr == (x + y, cos(x + y), sin(x + y)) + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -4, 4) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"cmap": "Reds"} + + # multiple parametric expressions + p = plot3d_parametric_surface( + (x + y, cos(x + y), sin(x + y)), + (x - y, cos(x - y), sin(x - y), "test")) + assert p[0].expr == (x + y, cos(x + y), sin(x + y)) + assert p[0].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[0].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[0].get_label(False) == "(x + y, cos(x + y), sin(x + y))" + assert p[0].rendering_kw == {} + assert p[1].expr == (x - y, cos(x - y), sin(x - y)) + assert p[1].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[1].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {} + + # multiple parametric expressions, custom ranges and labels + p = plot3d_parametric_surface( + (x + y, cos(x + y), sin(x + y), (x, -2, 2), "test"), + (x - y, cos(x - y), sin(x - y), (x, -3, 3), (y, -4, 4), + "test2", {"cmap": "Reds"})) + assert p[0].expr == (x + y, cos(x + y), sin(x + y)) + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -10, 10) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + assert p[1].expr == (x - y, cos(x - y), sin(x - y)) + assert p[1].ranges[0] == (x, -3, 3) + assert p[1].ranges[1] == (y, -4, 4) + assert p[1].get_label(False) == "test2" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # lambda functions instead of symbolic expressions for a single 3D + # parametric surface + p = plot3d_parametric_surface( + lambda u, v: u, lambda u, v: v, lambda u, v: u + v, + ("u", 0, 2), ("v", -3, 4)) + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (-0, 2) + assert p[0].ranges[1][1:] == (-3, 4) + assert p[0].get_label(False) == "" + assert p[0].rendering_kw == {} + + # lambda functions instead of symbolic expressions for multiple 3D + # parametric surfaces + p = plot3d_parametric_surface( + (lambda u, v: u, lambda u, v: v, lambda u, v: u + v, + ("u", 0, 2), ("v", -3, 4)), + (lambda u, v: v, lambda u, v: u, lambda u, v: u - v, + ("u", -2, 3), ("v", -4, 5), "test")) + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (0, 2) + assert p[0].ranges[1][1:] == (-3, 4) + assert p[0].get_label(False) == "" + assert p[0].rendering_kw == {} + assert all(callable(t) for t in p[1].expr) + assert p[1].ranges[0][1:] == (-2, 3) + assert p[1].ranges[1][1:] == (-4, 5) + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {} diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/test_plot_implicit.py b/phivenv/Lib/site-packages/sympy/plotting/tests/test_plot_implicit.py new file mode 100644 index 0000000000000000000000000000000000000000..73c7b186c83f0b64d5f6f4cc5cd9f6a08efef43a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/tests/test_plot_implicit.py @@ -0,0 +1,146 @@ +from sympy.core.numbers import (I, pi) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import re +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.logic.boolalg import (And, Or) +from sympy.plotting.plot_implicit import plot_implicit +from sympy.plotting.plot import unset_show +from tempfile import NamedTemporaryFile, mkdtemp +from sympy.testing.pytest import skip, warns, XFAIL +from sympy.external import import_module +from sympy.testing.tmpfiles import TmpFileManager + +import os + +#Set plots not to show +unset_show() + +def tmp_file(dir=None, name=''): + return NamedTemporaryFile( + suffix='.png', dir=dir, delete=False).name + +def plot_and_save(expr, *args, name='', dir=None, **kwargs): + p = plot_implicit(expr, *args, **kwargs) + p.save(tmp_file(dir=dir, name=name)) + # Close the plot to avoid a warning from matplotlib + p._backend.close() + +def plot_implicit_tests(name): + temp_dir = mkdtemp() + TmpFileManager.tmp_folder(temp_dir) + x = Symbol('x') + y = Symbol('y') + #implicit plot tests + plot_and_save(Eq(y, cos(x)), (x, -5, 5), (y, -2, 2), name=name, dir=temp_dir) + plot_and_save(Eq(y**2, x**3 - x), (x, -5, 5), + (y, -4, 4), name=name, dir=temp_dir) + plot_and_save(y > 1 / x, (x, -5, 5), + (y, -2, 2), name=name, dir=temp_dir) + plot_and_save(y < 1 / tan(x), (x, -5, 5), + (y, -2, 2), name=name, dir=temp_dir) + plot_and_save(y >= 2 * sin(x) * cos(x), (x, -5, 5), + (y, -2, 2), name=name, dir=temp_dir) + plot_and_save(y <= x**2, (x, -3, 3), + (y, -1, 5), name=name, dir=temp_dir) + + #Test all input args for plot_implicit + plot_and_save(Eq(y**2, x**3 - x), dir=temp_dir) + plot_and_save(Eq(y**2, x**3 - x), adaptive=False, dir=temp_dir) + plot_and_save(Eq(y**2, x**3 - x), adaptive=False, n=500, dir=temp_dir) + plot_and_save(y > x, (x, -5, 5), dir=temp_dir) + plot_and_save(And(y > exp(x), y > x + 2), dir=temp_dir) + plot_and_save(Or(y > x, y > -x), dir=temp_dir) + plot_and_save(x**2 - 1, (x, -5, 5), dir=temp_dir) + plot_and_save(x**2 - 1, dir=temp_dir) + plot_and_save(y > x, depth=-5, dir=temp_dir) + plot_and_save(y > x, depth=5, dir=temp_dir) + plot_and_save(y > cos(x), adaptive=False, dir=temp_dir) + plot_and_save(y < cos(x), adaptive=False, dir=temp_dir) + plot_and_save(And(y > cos(x), Or(y > x, Eq(y, x))), dir=temp_dir) + plot_and_save(y - cos(pi / x), dir=temp_dir) + + plot_and_save(x**2 - 1, title='An implicit plot', dir=temp_dir) + +@XFAIL +def test_no_adaptive_meshing(): + matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,)) + if matplotlib: + try: + temp_dir = mkdtemp() + TmpFileManager.tmp_folder(temp_dir) + x = Symbol('x') + y = Symbol('y') + # Test plots which cannot be rendered using the adaptive algorithm + + # This works, but it triggers a deprecation warning from sympify(). The + # code needs to be updated to detect if interval math is supported without + # relying on random AttributeErrors. + with warns(UserWarning, match="Adaptive meshing could not be applied"): + plot_and_save(Eq(y, re(cos(x) + I*sin(x))), name='test', dir=temp_dir) + finally: + TmpFileManager.cleanup() + else: + skip("Matplotlib not the default backend") +def test_line_color(): + x, y = symbols('x, y') + p = plot_implicit(x**2 + y**2 - 1, line_color="green", show=False) + assert p._series[0].line_color == "green" + p = plot_implicit(x**2 + y**2 - 1, line_color='r', show=False) + assert p._series[0].line_color == "r" + +def test_matplotlib(): + matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,)) + if matplotlib: + try: + plot_implicit_tests('test') + test_line_color() + finally: + TmpFileManager.cleanup() + else: + skip("Matplotlib not the default backend") + + +def test_region_and(): + matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,)) + if not matplotlib: + skip("Matplotlib not the default backend") + + from matplotlib.testing.compare import compare_images + test_directory = os.path.dirname(os.path.abspath(__file__)) + + try: + temp_dir = mkdtemp() + TmpFileManager.tmp_folder(temp_dir) + + x, y = symbols('x y') + + r1 = (x - 1)**2 + y**2 < 2 + r2 = (x + 1)**2 + y**2 < 2 + + test_filename = tmp_file(dir=temp_dir, name="test_region_and") + cmp_filename = os.path.join(test_directory, "test_region_and.png") + p = plot_implicit(r1 & r2, x, y) + p.save(test_filename) + compare_images(cmp_filename, test_filename, 0.005) + + test_filename = tmp_file(dir=temp_dir, name="test_region_or") + cmp_filename = os.path.join(test_directory, "test_region_or.png") + p = plot_implicit(r1 | r2, x, y) + p.save(test_filename) + compare_images(cmp_filename, test_filename, 0.005) + + test_filename = tmp_file(dir=temp_dir, name="test_region_not") + cmp_filename = os.path.join(test_directory, "test_region_not.png") + p = plot_implicit(~r1, x, y) + p.save(test_filename) + compare_images(cmp_filename, test_filename, 0.005) + + test_filename = tmp_file(dir=temp_dir, name="test_region_xor") + cmp_filename = os.path.join(test_directory, "test_region_xor.png") + p = plot_implicit(r1 ^ r2, x, y) + p.save(test_filename) + compare_images(cmp_filename, test_filename, 0.005) + finally: + TmpFileManager.cleanup() diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/test_region_and.png b/phivenv/Lib/site-packages/sympy/plotting/tests/test_region_and.png new file mode 100644 index 0000000000000000000000000000000000000000..61dda4c2054e5e4bd5018cb84af86a832e81886a Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/tests/test_region_and.png differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/test_region_not.png b/phivenv/Lib/site-packages/sympy/plotting/tests/test_region_not.png new file mode 100644 index 0000000000000000000000000000000000000000..29d3d47b5a95346cb7c44655c12a2a63e6c7a857 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/tests/test_region_not.png differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/test_region_or.png b/phivenv/Lib/site-packages/sympy/plotting/tests/test_region_or.png new file mode 100644 index 0000000000000000000000000000000000000000..8a6329dd8dd368c37e431a7741e0869ec84f8f68 Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/tests/test_region_or.png differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/test_region_xor.png b/phivenv/Lib/site-packages/sympy/plotting/tests/test_region_xor.png new file mode 100644 index 0000000000000000000000000000000000000000..1a48862909d3ad09a5f4d306bf6c8f96117d080c Binary files /dev/null and b/phivenv/Lib/site-packages/sympy/plotting/tests/test_region_xor.png differ diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/test_series.py b/phivenv/Lib/site-packages/sympy/plotting/tests/test_series.py new file mode 100644 index 0000000000000000000000000000000000000000..9fdacbd73aef18b07d2e14ce444b709654ee6f23 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/tests/test_series.py @@ -0,0 +1,1771 @@ +from sympy import ( + latex, exp, symbols, I, pi, sin, cos, tan, log, sqrt, + re, im, arg, frac, Sum, S, Abs, lambdify, + Function, dsolve, Eq, floor, Tuple +) +from sympy.external import import_module +from sympy.plotting.series import ( + LineOver1DRangeSeries, Parametric2DLineSeries, Parametric3DLineSeries, + SurfaceOver2DRangeSeries, ContourSeries, ParametricSurfaceSeries, + ImplicitSeries, _set_discretization_points, List2DSeries +) +from sympy.testing.pytest import raises, warns, XFAIL, skip, ignore_warnings + +np = import_module('numpy') + + +def test_adaptive(): + # verify that adaptive-related keywords produces the expected results + if not np: + skip("numpy not installed.") + + x, y = symbols("x, y") + + s1 = LineOver1DRangeSeries(sin(x), (x, -10, 10), "", adaptive=True, + depth=2) + x1, _ = s1.get_data() + s2 = LineOver1DRangeSeries(sin(x), (x, -10, 10), "", adaptive=True, + depth=5) + x2, _ = s2.get_data() + s3 = LineOver1DRangeSeries(sin(x), (x, -10, 10), "", adaptive=True) + x3, _ = s3.get_data() + assert len(x1) < len(x2) < len(x3) + + s1 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=True, depth=2) + x1, _, _, = s1.get_data() + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=True, depth=5) + x2, _, _ = s2.get_data() + s3 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=True) + x3, _, _ = s3.get_data() + assert len(x1) < len(x2) < len(x3) + + +def test_detect_poles(): + if not np: + skip("numpy not installed.") + + x, u = symbols("x, u") + + s1 = LineOver1DRangeSeries(tan(x), (x, -pi, pi), + adaptive=False, n=1000, detect_poles=False) + xx1, yy1 = s1.get_data() + s2 = LineOver1DRangeSeries(tan(x), (x, -pi, pi), + adaptive=False, n=1000, detect_poles=True, eps=0.01) + xx2, yy2 = s2.get_data() + # eps is too small: doesn't detect any poles + s3 = LineOver1DRangeSeries(tan(x), (x, -pi, pi), + adaptive=False, n=1000, detect_poles=True, eps=1e-06) + xx3, yy3 = s3.get_data() + s4 = LineOver1DRangeSeries(tan(x), (x, -pi, pi), + adaptive=False, n=1000, detect_poles="symbolic") + xx4, yy4 = s4.get_data() + + assert np.allclose(xx1, xx2) and np.allclose(xx1, xx3) and np.allclose(xx1, xx4) + assert not np.any(np.isnan(yy1)) + assert not np.any(np.isnan(yy3)) + assert np.any(np.isnan(yy2)) + assert np.any(np.isnan(yy4)) + assert len(s2.poles_locations) == len(s3.poles_locations) == 0 + assert len(s4.poles_locations) == 2 + assert np.allclose(np.abs(s4.poles_locations), np.pi / 2) + + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some of", + test_stacklevel=False, + ): + s1 = LineOver1DRangeSeries(frac(x), (x, -10, 10), + adaptive=False, n=1000, detect_poles=False) + s2 = LineOver1DRangeSeries(frac(x), (x, -10, 10), + adaptive=False, n=1000, detect_poles=True, eps=0.05) + s3 = LineOver1DRangeSeries(frac(x), (x, -10, 10), + adaptive=False, n=1000, detect_poles="symbolic") + xx1, yy1 = s1.get_data() + xx2, yy2 = s2.get_data() + xx3, yy3 = s3.get_data() + assert np.allclose(xx1, xx2) and np.allclose(xx1, xx3) + assert not np.any(np.isnan(yy1)) + assert np.any(np.isnan(yy2)) and np.any(np.isnan(yy2)) + assert not np.allclose(yy1, yy2, equal_nan=True) + # The poles below are actually step discontinuities. + assert len(s3.poles_locations) == 21 + + s1 = LineOver1DRangeSeries(tan(u * x), (x, -pi, pi), params={u: 1}, + adaptive=False, n=1000, detect_poles=False) + xx1, yy1 = s1.get_data() + s2 = LineOver1DRangeSeries(tan(u * x), (x, -pi, pi), params={u: 1}, + adaptive=False, n=1000, detect_poles=True, eps=0.01) + xx2, yy2 = s2.get_data() + # eps is too small: doesn't detect any poles + s3 = LineOver1DRangeSeries(tan(u * x), (x, -pi, pi), params={u: 1}, + adaptive=False, n=1000, detect_poles=True, eps=1e-06) + xx3, yy3 = s3.get_data() + s4 = LineOver1DRangeSeries(tan(u * x), (x, -pi, pi), params={u: 1}, + adaptive=False, n=1000, detect_poles="symbolic") + xx4, yy4 = s4.get_data() + + assert np.allclose(xx1, xx2) and np.allclose(xx1, xx3) and np.allclose(xx1, xx4) + assert not np.any(np.isnan(yy1)) + assert not np.any(np.isnan(yy3)) + assert np.any(np.isnan(yy2)) + assert np.any(np.isnan(yy4)) + assert len(s2.poles_locations) == len(s3.poles_locations) == 0 + assert len(s4.poles_locations) == 2 + assert np.allclose(np.abs(s4.poles_locations), np.pi / 2) + + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some of", + test_stacklevel=False, + ): + u, v = symbols("u, v", real=True) + n = S(1) / 3 + f = (u + I * v)**n + r, i = re(f), im(f) + s1 = Parametric2DLineSeries(r.subs(u, -2), i.subs(u, -2), (v, -2, 2), + adaptive=False, n=1000, detect_poles=False) + s2 = Parametric2DLineSeries(r.subs(u, -2), i.subs(u, -2), (v, -2, 2), + adaptive=False, n=1000, detect_poles=True) + with ignore_warnings(RuntimeWarning): + xx1, yy1, pp1 = s1.get_data() + assert not np.isnan(yy1).any() + xx2, yy2, pp2 = s2.get_data() + assert np.isnan(yy2).any() + + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some of", + test_stacklevel=False, + ): + f = (x * u + x * I * v)**n + r, i = re(f), im(f) + s1 = Parametric2DLineSeries(r.subs(u, -2), i.subs(u, -2), + (v, -2, 2), params={x: 1}, + adaptive=False, n1=1000, detect_poles=False) + s2 = Parametric2DLineSeries(r.subs(u, -2), i.subs(u, -2), + (v, -2, 2), params={x: 1}, + adaptive=False, n1=1000, detect_poles=True) + with ignore_warnings(RuntimeWarning): + xx1, yy1, pp1 = s1.get_data() + assert not np.isnan(yy1).any() + xx2, yy2, pp2 = s2.get_data() + assert np.isnan(yy2).any() + + +def test_number_discretization_points(): + # verify that the different ways to set the number of discretization + # points are consistent with each other. + if not np: + skip("numpy not installed.") + + x, y, z = symbols("x:z") + + for pt in [LineOver1DRangeSeries, Parametric2DLineSeries, + Parametric3DLineSeries]: + kw1 = _set_discretization_points({"n": 10}, pt) + kw2 = _set_discretization_points({"n": [10, 20, 30]}, pt) + kw3 = _set_discretization_points({"n1": 10}, pt) + assert all(("n1" in kw) and kw["n1"] == 10 for kw in [kw1, kw2, kw3]) + + for pt in [SurfaceOver2DRangeSeries, ContourSeries, ParametricSurfaceSeries, + ImplicitSeries]: + kw1 = _set_discretization_points({"n": 10}, pt) + kw2 = _set_discretization_points({"n": [10, 20, 30]}, pt) + kw3 = _set_discretization_points({"n1": 10, "n2": 20}, pt) + assert kw1["n1"] == kw1["n2"] == 10 + assert all((kw["n1"] == 10) and (kw["n2"] == 20) for kw in [kw2, kw3]) + + # verify that line-related series can deal with large float number of + # discretization points + LineOver1DRangeSeries(cos(x), (x, -5, 5), adaptive=False, n=1e04).get_data() + + +def test_list2dseries(): + if not np: + skip("numpy not installed.") + + xx = np.linspace(-3, 3, 10) + yy1 = np.cos(xx) + yy2 = np.linspace(-3, 3, 20) + + # same number of elements: everything is fine + s = List2DSeries(xx, yy1) + assert not s.is_parametric + # different number of elements: error + raises(ValueError, lambda: List2DSeries(xx, yy2)) + + # no color func: returns only x, y components and s in not parametric + s = List2DSeries(xx, yy1) + xxs, yys = s.get_data() + assert np.allclose(xx, xxs) + assert np.allclose(yy1, yys) + assert not s.is_parametric + + +def test_interactive_vs_noninteractive(): + # verify that if a *Series class receives a `params` dictionary, it sets + # is_interactive=True + x, y, z, u, v = symbols("x, y, z, u, v") + + s = LineOver1DRangeSeries(cos(x), (x, -5, 5)) + assert not s.is_interactive + s = LineOver1DRangeSeries(u * cos(x), (x, -5, 5), params={u: 1}) + assert s.is_interactive + + s = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5)) + assert not s.is_interactive + s = Parametric2DLineSeries(u * cos(x), u * sin(x), (x, -5, 5), + params={u: 1}) + assert s.is_interactive + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5)) + assert not s.is_interactive + s = Parametric3DLineSeries(u * cos(x), u * sin(x), x, (x, -5, 5), + params={u: 1}) + assert s.is_interactive + + s = SurfaceOver2DRangeSeries(cos(x * y), (x, -5, 5), (y, -5, 5)) + assert not s.is_interactive + s = SurfaceOver2DRangeSeries(u * cos(x * y), (x, -5, 5), (y, -5, 5), + params={u: 1}) + assert s.is_interactive + + s = ContourSeries(cos(x * y), (x, -5, 5), (y, -5, 5)) + assert not s.is_interactive + s = ContourSeries(u * cos(x * y), (x, -5, 5), (y, -5, 5), + params={u: 1}) + assert s.is_interactive + + s = ParametricSurfaceSeries(u * cos(v), v * sin(u), u + v, + (u, -5, 5), (v, -5, 5)) + assert not s.is_interactive + s = ParametricSurfaceSeries(u * cos(v * x), v * sin(u), u + v, + (u, -5, 5), (v, -5, 5), params={x: 1}) + assert s.is_interactive + + +def test_lin_log_scale(): + # Verify that data series create the correct spacing in the data. + if not np: + skip("numpy not installed.") + + x, y, z = symbols("x, y, z") + + s = LineOver1DRangeSeries(x, (x, 1, 10), adaptive=False, n=50, + xscale="linear") + xx, _ = s.get_data() + assert np.isclose(xx[1] - xx[0], xx[-1] - xx[-2]) + + s = LineOver1DRangeSeries(x, (x, 1, 10), adaptive=False, n=50, + xscale="log") + xx, _ = s.get_data() + assert not np.isclose(xx[1] - xx[0], xx[-1] - xx[-2]) + + s = Parametric2DLineSeries( + cos(x), sin(x), (x, pi / 2, 1.5 * pi), adaptive=False, n=50, + xscale="linear") + _, _, param = s.get_data() + assert np.isclose(param[1] - param[0], param[-1] - param[-2]) + + s = Parametric2DLineSeries( + cos(x), sin(x), (x, pi / 2, 1.5 * pi), adaptive=False, n=50, + xscale="log") + _, _, param = s.get_data() + assert not np.isclose(param[1] - param[0], param[-1] - param[-2]) + + s = Parametric3DLineSeries( + cos(x), sin(x), x, (x, pi / 2, 1.5 * pi), adaptive=False, n=50, + xscale="linear") + _, _, _, param = s.get_data() + assert np.isclose(param[1] - param[0], param[-1] - param[-2]) + + s = Parametric3DLineSeries( + cos(x), sin(x), x, (x, pi / 2, 1.5 * pi), adaptive=False, n=50, + xscale="log") + _, _, _, param = s.get_data() + assert not np.isclose(param[1] - param[0], param[-1] - param[-2]) + + s = SurfaceOver2DRangeSeries( + cos(x ** 2 + y ** 2), (x, 1, 5), (y, 1, 5), n=10, + xscale="linear", yscale="linear") + xx, yy, _ = s.get_data() + assert np.isclose(xx[0, 1] - xx[0, 0], xx[0, -1] - xx[0, -2]) + assert np.isclose(yy[1, 0] - yy[0, 0], yy[-1, 0] - yy[-2, 0]) + + s = SurfaceOver2DRangeSeries( + cos(x ** 2 + y ** 2), (x, 1, 5), (y, 1, 5), n=10, + xscale="log", yscale="log") + xx, yy, _ = s.get_data() + assert not np.isclose(xx[0, 1] - xx[0, 0], xx[0, -1] - xx[0, -2]) + assert not np.isclose(yy[1, 0] - yy[0, 0], yy[-1, 0] - yy[-2, 0]) + + s = ImplicitSeries( + cos(x ** 2 + y ** 2) > 0, (x, 1, 5), (y, 1, 5), + n1=10, n2=10, xscale="linear", yscale="linear", adaptive=False) + xx, yy, _, _ = s.get_data() + assert np.isclose(xx[0, 1] - xx[0, 0], xx[0, -1] - xx[0, -2]) + assert np.isclose(yy[1, 0] - yy[0, 0], yy[-1, 0] - yy[-2, 0]) + + s = ImplicitSeries( + cos(x ** 2 + y ** 2) > 0, (x, 1, 5), (y, 1, 5), + n=10, xscale="log", yscale="log", adaptive=False) + xx, yy, _, _ = s.get_data() + assert not np.isclose(xx[0, 1] - xx[0, 0], xx[0, -1] - xx[0, -2]) + assert not np.isclose(yy[1, 0] - yy[0, 0], yy[-1, 0] - yy[-2, 0]) + + +def test_rendering_kw(): + # verify that each series exposes the `rendering_kw` attribute + if not np: + skip("numpy not installed.") + + u, v, x, y, z = symbols("u, v, x:z") + + s = List2DSeries([1, 2, 3], [4, 5, 6]) + assert isinstance(s.rendering_kw, dict) + + s = LineOver1DRangeSeries(1, (x, -5, 5)) + assert isinstance(s.rendering_kw, dict) + + s = Parametric2DLineSeries(sin(x), cos(x), (x, 0, pi)) + assert isinstance(s.rendering_kw, dict) + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2 * pi)) + assert isinstance(s.rendering_kw, dict) + + s = SurfaceOver2DRangeSeries(x + y, (x, -2, 2), (y, -3, 3)) + assert isinstance(s.rendering_kw, dict) + + s = ContourSeries(x + y, (x, -2, 2), (y, -3, 3)) + assert isinstance(s.rendering_kw, dict) + + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1)) + assert isinstance(s.rendering_kw, dict) + + +def test_data_shape(): + # Verify that the series produces the correct data shape when the input + # expression is a number. + if not np: + skip("numpy not installed.") + + u, x, y, z = symbols("u, x:z") + + # scalar expression: it should return a numpy ones array + s = LineOver1DRangeSeries(1, (x, -5, 5)) + xx, yy = s.get_data() + assert len(xx) == len(yy) + assert np.all(yy == 1) + + s = LineOver1DRangeSeries(1, (x, -5, 5), adaptive=False, n=10) + xx, yy = s.get_data() + assert len(xx) == len(yy) == 10 + assert np.all(yy == 1) + + s = Parametric2DLineSeries(sin(x), 1, (x, 0, pi)) + xx, yy, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(param)) + assert np.all(yy == 1) + + s = Parametric2DLineSeries(1, sin(x), (x, 0, pi)) + xx, yy, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(param)) + assert np.all(xx == 1) + + s = Parametric2DLineSeries(sin(x), 1, (x, 0, pi), adaptive=False) + xx, yy, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(param)) + assert np.all(yy == 1) + + s = Parametric2DLineSeries(1, sin(x), (x, 0, pi), adaptive=False) + xx, yy, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(param)) + assert np.all(xx == 1) + + s = Parametric3DLineSeries(cos(x), sin(x), 1, (x, 0, 2 * pi)) + xx, yy, zz, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(zz)) and (len(xx) == len(param)) + assert np.all(zz == 1) + + s = Parametric3DLineSeries(cos(x), 1, x, (x, 0, 2 * pi)) + xx, yy, zz, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(zz)) and (len(xx) == len(param)) + assert np.all(yy == 1) + + s = Parametric3DLineSeries(1, sin(x), x, (x, 0, 2 * pi)) + xx, yy, zz, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(zz)) and (len(xx) == len(param)) + assert np.all(xx == 1) + + s = SurfaceOver2DRangeSeries(1, (x, -2, 2), (y, -3, 3)) + xx, yy, zz = s.get_data() + assert (xx.shape == yy.shape) and (xx.shape == zz.shape) + assert np.all(zz == 1) + + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1)) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape + assert np.all(xx == 1) + + s = ParametricSurfaceSeries(1, 1, y, (x, 0, 1), (y, 0, 1)) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape + assert np.all(yy == 1) + + s = ParametricSurfaceSeries(x, 1, 1, (x, 0, 1), (y, 0, 1)) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape + assert np.all(zz == 1) + + +def test_only_integers(): + if not np: + skip("numpy not installed.") + + x, y, u, v = symbols("x, y, u, v") + + s = LineOver1DRangeSeries(sin(x), (x, -5.5, 4.5), "", + adaptive=False, only_integers=True) + xx, _ = s.get_data() + assert len(xx) == 10 + assert xx[0] == -5 and xx[-1] == 4 + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2 * pi), "", + adaptive=False, only_integers=True) + _, _, p = s.get_data() + assert len(p) == 7 + assert p[0] == 0 and p[-1] == 6 + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2 * pi), "", + adaptive=False, only_integers=True) + _, _, _, p = s.get_data() + assert len(p) == 7 + assert p[0] == 0 and p[-1] == 6 + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -5.5, 5.5), + (y, -3.5, 3.5), "", + adaptive=False, only_integers=True) + xx, yy, _ = s.get_data() + assert xx.shape == yy.shape == (7, 11) + assert np.allclose(xx[:, 0] - (-5) * np.ones(7), 0) + assert np.allclose(xx[0, :] - np.linspace(-5, 5, 11), 0) + assert np.allclose(yy[:, 0] - np.linspace(-3, 3, 7), 0) + assert np.allclose(yy[0, :] - (-3) * np.ones(11), 0) + + r = 2 + sin(7 * u + 5 * v) + expr = ( + r * cos(u) * sin(v), + r * sin(u) * sin(v), + r * cos(v) + ) + s = ParametricSurfaceSeries(*expr, (u, 0, 2 * pi), (v, 0, pi), "", + adaptive=False, only_integers=True) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape == (4, 7) + + # only_integers also works with scalar expressions + s = LineOver1DRangeSeries(1, (x, -5.5, 4.5), "", + adaptive=False, only_integers=True) + xx, _ = s.get_data() + assert len(xx) == 10 + assert xx[0] == -5 and xx[-1] == 4 + + s = Parametric2DLineSeries(cos(x), 1, (x, 0, 2 * pi), "", + adaptive=False, only_integers=True) + _, _, p = s.get_data() + assert len(p) == 7 + assert p[0] == 0 and p[-1] == 6 + + s = SurfaceOver2DRangeSeries(1, (x, -5.5, 5.5), (y, -3.5, 3.5), "", + adaptive=False, only_integers=True) + xx, yy, _ = s.get_data() + assert xx.shape == yy.shape == (7, 11) + assert np.allclose(xx[:, 0] - (-5) * np.ones(7), 0) + assert np.allclose(xx[0, :] - np.linspace(-5, 5, 11), 0) + assert np.allclose(yy[:, 0] - np.linspace(-3, 3, 7), 0) + assert np.allclose(yy[0, :] - (-3) * np.ones(11), 0) + + r = 2 + sin(7 * u + 5 * v) + expr = ( + r * cos(u) * sin(v), + 1, + r * cos(v) + ) + s = ParametricSurfaceSeries(*expr, (u, 0, 2 * pi), (v, 0, pi), "", + adaptive=False, only_integers=True) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape == (4, 7) + + +def test_is_point_is_filled(): + # verify that `is_point` and `is_filled` are attributes and that they + # they receive the correct values + if not np: + skip("numpy not installed.") + + x, u = symbols("x, u") + + s = LineOver1DRangeSeries(cos(x), (x, -5, 5), "", + is_point=False, is_filled=True) + assert (not s.is_point) and s.is_filled + s = LineOver1DRangeSeries(cos(x), (x, -5, 5), "", + is_point=True, is_filled=False) + assert s.is_point and (not s.is_filled) + + s = List2DSeries([0, 1, 2], [3, 4, 5], + is_point=False, is_filled=True) + assert (not s.is_point) and s.is_filled + s = List2DSeries([0, 1, 2], [3, 4, 5], + is_point=True, is_filled=False) + assert s.is_point and (not s.is_filled) + + s = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + is_point=False, is_filled=True) + assert (not s.is_point) and s.is_filled + s = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + is_point=True, is_filled=False) + assert s.is_point and (not s.is_filled) + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + is_point=False, is_filled=True) + assert (not s.is_point) and s.is_filled + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + is_point=True, is_filled=False) + assert s.is_point and (not s.is_filled) + + +def test_is_filled_2d(): + # verify that the is_filled attribute is exposed by the following series + x, y = symbols("x, y") + + expr = cos(x**2 + y**2) + ranges = (x, -2, 2), (y, -2, 2) + + s = ContourSeries(expr, *ranges) + assert s.is_filled + s = ContourSeries(expr, *ranges, is_filled=True) + assert s.is_filled + s = ContourSeries(expr, *ranges, is_filled=False) + assert not s.is_filled + + +def test_steps(): + if not np: + skip("numpy not installed.") + + x, u = symbols("x, u") + + def do_test(s1, s2): + if (not s1.is_parametric) and s1.is_2Dline: + xx1, _ = s1.get_data() + xx2, _ = s2.get_data() + elif s1.is_parametric and s1.is_2Dline: + xx1, _, _ = s1.get_data() + xx2, _, _ = s2.get_data() + elif (not s1.is_parametric) and s1.is_3Dline: + xx1, _, _ = s1.get_data() + xx2, _, _ = s2.get_data() + else: + xx1, _, _, _ = s1.get_data() + xx2, _, _, _ = s2.get_data() + assert len(xx1) != len(xx2) + + s1 = LineOver1DRangeSeries(cos(x), (x, -5, 5), "", + adaptive=False, n=40, steps=False) + s2 = LineOver1DRangeSeries(cos(x), (x, -5, 5), "", + adaptive=False, n=40, steps=True) + do_test(s1, s2) + + s1 = List2DSeries([0, 1, 2], [3, 4, 5], steps=False) + s2 = List2DSeries([0, 1, 2], [3, 4, 5], steps=True) + do_test(s1, s2) + + s1 = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + adaptive=False, n=40, steps=False) + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + adaptive=False, n=40, steps=True) + do_test(s1, s2) + + s1 = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + adaptive=False, n=40, steps=False) + s2 = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + adaptive=False, n=40, steps=True) + do_test(s1, s2) + + +def test_interactive_data(): + # verify that InteractiveSeries produces the same numerical data as their + # corresponding non-interactive series. + if not np: + skip("numpy not installed.") + + u, x, y, z = symbols("u, x:z") + + def do_test(data1, data2): + assert len(data1) == len(data2) + for d1, d2 in zip(data1, data2): + assert np.allclose(d1, d2) + + s1 = LineOver1DRangeSeries(u * cos(x), (x, -5, 5), params={u: 1}, n=50) + s2 = LineOver1DRangeSeries(cos(x), (x, -5, 5), adaptive=False, n=50) + do_test(s1.get_data(), s2.get_data()) + + s1 = Parametric2DLineSeries( + u * cos(x), u * sin(x), (x, -5, 5), params={u: 1}, n=50) + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + adaptive=False, n=50) + do_test(s1.get_data(), s2.get_data()) + + s1 = Parametric3DLineSeries( + u * cos(x), u * sin(x), u * x, (x, -5, 5), + params={u: 1}, n=50) + s2 = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + adaptive=False, n=50) + do_test(s1.get_data(), s2.get_data()) + + s1 = SurfaceOver2DRangeSeries( + u * cos(x ** 2 + y ** 2), (x, -3, 3), (y, -3, 3), + params={u: 1}, n1=50, n2=50,) + s2 = SurfaceOver2DRangeSeries( + cos(x ** 2 + y ** 2), (x, -3, 3), (y, -3, 3), + adaptive=False, n1=50, n2=50) + do_test(s1.get_data(), s2.get_data()) + + s1 = ParametricSurfaceSeries( + u * cos(x + y), sin(x + y), x - y, (x, -3, 3), (y, -3, 3), + params={u: 1}, n1=50, n2=50,) + s2 = ParametricSurfaceSeries( + cos(x + y), sin(x + y), x - y, (x, -3, 3), (y, -3, 3), + adaptive=False, n1=50, n2=50,) + do_test(s1.get_data(), s2.get_data()) + + # real part of a complex function evaluated over a real line with numpy + expr = re((z ** 2 + 1) / (z ** 2 - 1)) + s1 = LineOver1DRangeSeries(u * expr, (z, -3, 3), adaptive=False, n=50, + modules=None, params={u: 1}) + s2 = LineOver1DRangeSeries(expr, (z, -3, 3), adaptive=False, n=50, + modules=None) + do_test(s1.get_data(), s2.get_data()) + + # real part of a complex function evaluated over a real line with mpmath + expr = re((z ** 2 + 1) / (z ** 2 - 1)) + s1 = LineOver1DRangeSeries(u * expr, (z, -3, 3), n=50, modules="mpmath", + params={u: 1}) + s2 = LineOver1DRangeSeries(expr, (z, -3, 3), + adaptive=False, n=50, modules="mpmath") + do_test(s1.get_data(), s2.get_data()) + + +def test_list2dseries_interactive(): + if not np: + skip("numpy not installed.") + + x, y, u = symbols("x, y, u") + + s = List2DSeries([1, 2, 3], [1, 2, 3]) + assert not s.is_interactive + + # symbolic expressions as coordinates, but no ``params`` + raises(ValueError, lambda: List2DSeries([cos(x)], [sin(x)])) + + # too few parameters + raises(ValueError, + lambda: List2DSeries([cos(x), y], [sin(x), 2], params={u: 1})) + + s = List2DSeries([cos(x)], [sin(x)], params={x: 1}) + assert s.is_interactive + + s = List2DSeries([x, 2, 3, 4], [4, 3, 2, x], params={x: 3}) + xx, yy = s.get_data() + assert np.allclose(xx, [3, 2, 3, 4]) + assert np.allclose(yy, [4, 3, 2, 3]) + assert not s.is_parametric + + # numeric lists + params is present -> interactive series and + # lists are converted to Tuple. + s = List2DSeries([1, 2, 3], [1, 2, 3], params={x: 1}) + assert s.is_interactive + assert isinstance(s.list_x, Tuple) + assert isinstance(s.list_y, Tuple) + + +def test_mpmath(): + # test that the argument of complex functions evaluated with mpmath + # might be different than the one computed with Numpy (different + # behaviour at branch cuts) + if not np: + skip("numpy not installed.") + + z, u = symbols("z, u") + + s1 = LineOver1DRangeSeries(im(sqrt(-z)), (z, 1e-03, 5), + adaptive=True, modules=None, force_real_eval=True) + s2 = LineOver1DRangeSeries(im(sqrt(-z)), (z, 1e-03, 5), + adaptive=True, modules="mpmath", force_real_eval=True) + xx1, yy1 = s1.get_data() + xx2, yy2 = s2.get_data() + assert np.all(yy1 < 0) + assert np.all(yy2 > 0) + + s1 = LineOver1DRangeSeries(im(sqrt(-z)), (z, -5, 5), + adaptive=False, n=20, modules=None, force_real_eval=True) + s2 = LineOver1DRangeSeries(im(sqrt(-z)), (z, -5, 5), + adaptive=False, n=20, modules="mpmath", force_real_eval=True) + xx1, yy1 = s1.get_data() + xx2, yy2 = s2.get_data() + assert np.allclose(xx1, xx2) + assert not np.allclose(yy1, yy2) + + +def test_str(): + u, x, y, z = symbols("u, x:z") + + s = LineOver1DRangeSeries(cos(x), (x, -4, 3)) + assert str(s) == "cartesian line: cos(x) for x over (-4.0, 3.0)" + + d = {"return": "real"} + s = LineOver1DRangeSeries(cos(x), (x, -4, 3), **d) + assert str(s) == "cartesian line: re(cos(x)) for x over (-4.0, 3.0)" + + d = {"return": "imag"} + s = LineOver1DRangeSeries(cos(x), (x, -4, 3), **d) + assert str(s) == "cartesian line: im(cos(x)) for x over (-4.0, 3.0)" + + d = {"return": "abs"} + s = LineOver1DRangeSeries(cos(x), (x, -4, 3), **d) + assert str(s) == "cartesian line: abs(cos(x)) for x over (-4.0, 3.0)" + + d = {"return": "arg"} + s = LineOver1DRangeSeries(cos(x), (x, -4, 3), **d) + assert str(s) == "cartesian line: arg(cos(x)) for x over (-4.0, 3.0)" + + s = LineOver1DRangeSeries(cos(u * x), (x, -4, 3), params={u: 1}) + assert str(s) == "interactive cartesian line: cos(u*x) for x over (-4.0, 3.0) and parameters (u,)" + + s = LineOver1DRangeSeries(cos(u * x), (x, -u, 3*y), params={u: 1, y: 1}) + assert str(s) == "interactive cartesian line: cos(u*x) for x over (-u, 3*y) and parameters (u, y)" + + s = Parametric2DLineSeries(cos(x), sin(x), (x, -4, 3)) + assert str(s) == "parametric cartesian line: (cos(x), sin(x)) for x over (-4.0, 3.0)" + + s = Parametric2DLineSeries(cos(u * x), sin(x), (x, -4, 3), params={u: 1}) + assert str(s) == "interactive parametric cartesian line: (cos(u*x), sin(x)) for x over (-4.0, 3.0) and parameters (u,)" + + s = Parametric2DLineSeries(cos(u * x), sin(x), (x, -u, 3*y), params={u: 1, y:1}) + assert str(s) == "interactive parametric cartesian line: (cos(u*x), sin(x)) for x over (-u, 3*y) and parameters (u, y)" + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -4, 3)) + assert str(s) == "3D parametric cartesian line: (cos(x), sin(x), x) for x over (-4.0, 3.0)" + + s = Parametric3DLineSeries(cos(u*x), sin(x), x, (x, -4, 3), params={u: 1}) + assert str(s) == "interactive 3D parametric cartesian line: (cos(u*x), sin(x), x) for x over (-4.0, 3.0) and parameters (u,)" + + s = Parametric3DLineSeries(cos(u*x), sin(x), x, (x, -u, 3*y), params={u: 1, y: 1}) + assert str(s) == "interactive 3D parametric cartesian line: (cos(u*x), sin(x), x) for x over (-u, 3*y) and parameters (u, y)" + + s = SurfaceOver2DRangeSeries(cos(x * y), (x, -4, 3), (y, -2, 5)) + assert str(s) == "cartesian surface: cos(x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0)" + + s = SurfaceOver2DRangeSeries(cos(u * x * y), (x, -4, 3), (y, -2, 5), params={u: 1}) + assert str(s) == "interactive cartesian surface: cos(u*x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0) and parameters (u,)" + + s = SurfaceOver2DRangeSeries(cos(u * x * y), (x, -4*u, 3), (y, -2, 5*u), params={u: 1}) + assert str(s) == "interactive cartesian surface: cos(u*x*y) for x over (-4*u, 3.0) and y over (-2.0, 5*u) and parameters (u,)" + + s = ContourSeries(cos(x * y), (x, -4, 3), (y, -2, 5)) + assert str(s) == "contour: cos(x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0)" + + s = ContourSeries(cos(u * x * y), (x, -4, 3), (y, -2, 5), params={u: 1}) + assert str(s) == "interactive contour: cos(u*x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0) and parameters (u,)" + + s = ParametricSurfaceSeries(cos(x * y), sin(x * y), x * y, + (x, -4, 3), (y, -2, 5)) + assert str(s) == "parametric cartesian surface: (cos(x*y), sin(x*y), x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0)" + + s = ParametricSurfaceSeries(cos(u * x * y), sin(x * y), x * y, + (x, -4, 3), (y, -2, 5), params={u: 1}) + assert str(s) == "interactive parametric cartesian surface: (cos(u*x*y), sin(x*y), x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0) and parameters (u,)" + + s = ImplicitSeries(x < y, (x, -5, 4), (y, -3, 2)) + assert str(s) == "Implicit expression: x < y for x over (-5.0, 4.0) and y over (-3.0, 2.0)" + + +def test_use_cm(): + # verify that the `use_cm` attribute is implemented. + if not np: + skip("numpy not installed.") + + u, x, y, z = symbols("u, x:z") + + s = List2DSeries([1, 2, 3, 4], [5, 6, 7, 8], use_cm=True) + assert s.use_cm + s = List2DSeries([1, 2, 3, 4], [5, 6, 7, 8], use_cm=False) + assert not s.use_cm + + s = Parametric2DLineSeries(cos(x), sin(x), (x, -4, 3), use_cm=True) + assert s.use_cm + s = Parametric2DLineSeries(cos(x), sin(x), (x, -4, 3), use_cm=False) + assert not s.use_cm + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -4, 3), + use_cm=True) + assert s.use_cm + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -4, 3), + use_cm=False) + assert not s.use_cm + + s = SurfaceOver2DRangeSeries(cos(x * y), (x, -4, 3), (y, -2, 5), + use_cm=True) + assert s.use_cm + s = SurfaceOver2DRangeSeries(cos(x * y), (x, -4, 3), (y, -2, 5), + use_cm=False) + assert not s.use_cm + + s = ParametricSurfaceSeries(cos(x * y), sin(x * y), x * y, + (x, -4, 3), (y, -2, 5), use_cm=True) + assert s.use_cm + s = ParametricSurfaceSeries(cos(x * y), sin(x * y), x * y, + (x, -4, 3), (y, -2, 5), use_cm=False) + assert not s.use_cm + + +def test_surface_use_cm(): + # verify that SurfaceOver2DRangeSeries and ParametricSurfaceSeries get + # the same value for use_cm + + x, y, u, v = symbols("x, y, u, v") + + # they read the same value from default settings + s1 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2)) + s2 = ParametricSurfaceSeries(u * cos(v), u * sin(v), u, + (u, 0, 1), (v, 0 , 2*pi)) + assert s1.use_cm == s2.use_cm + + # they get the same value + s1 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + use_cm=False) + s2 = ParametricSurfaceSeries(u * cos(v), u * sin(v), u, + (u, 0, 1), (v, 0 , 2*pi), use_cm=False) + assert s1.use_cm == s2.use_cm + + # they get the same value + s1 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + use_cm=True) + s2 = ParametricSurfaceSeries(u * cos(v), u * sin(v), u, + (u, 0, 1), (v, 0 , 2*pi), use_cm=True) + assert s1.use_cm == s2.use_cm + + +def test_sums(): + # test that data series are able to deal with sums + if not np: + skip("numpy not installed.") + + x, y, u = symbols("x, y, u") + + def do_test(data1, data2): + assert len(data1) == len(data2) + for d1, d2 in zip(data1, data2): + assert np.allclose(d1, d2) + + s = LineOver1DRangeSeries(Sum(1 / x ** y, (x, 1, 1000)), (y, 2, 10), + adaptive=False, only_integers=True) + xx, yy = s.get_data() + + s1 = LineOver1DRangeSeries(Sum(1 / x, (x, 1, y)), (y, 2, 10), + adaptive=False, only_integers=True) + xx1, yy1 = s1.get_data() + + s2 = LineOver1DRangeSeries(Sum(u / x, (x, 1, y)), (y, 2, 10), + params={u: 1}, only_integers=True) + xx2, yy2 = s2.get_data() + xx1 = xx1.astype(float) + xx2 = xx2.astype(float) + do_test([xx1, yy1], [xx2, yy2]) + + s = LineOver1DRangeSeries(Sum(1 / x, (x, 1, y)), (y, 2, 10), + adaptive=True) + with warns( + UserWarning, + match="The evaluation with NumPy/SciPy failed", + test_stacklevel=False, + ): + raises(TypeError, lambda: s.get_data()) + + +def test_apply_transforms(): + # verify that transformation functions get applied to the output + # of data series + if not np: + skip("numpy not installed.") + + x, y, z, u, v = symbols("x:z, u, v") + + s1 = LineOver1DRangeSeries(cos(x), (x, -2*pi, 2*pi), adaptive=False, n=10) + s2 = LineOver1DRangeSeries(cos(x), (x, -2*pi, 2*pi), adaptive=False, n=10, + tx=np.rad2deg) + s3 = LineOver1DRangeSeries(cos(x), (x, -2*pi, 2*pi), adaptive=False, n=10, + ty=np.rad2deg) + s4 = LineOver1DRangeSeries(cos(x), (x, -2*pi, 2*pi), adaptive=False, n=10, + tx=np.rad2deg, ty=np.rad2deg) + + x1, y1 = s1.get_data() + x2, y2 = s2.get_data() + x3, y3 = s3.get_data() + x4, y4 = s4.get_data() + assert np.isclose(x1[0], -2*np.pi) and np.isclose(x1[-1], 2*np.pi) + assert (y1.min() < -0.9) and (y1.max() > 0.9) + assert np.isclose(x2[0], -360) and np.isclose(x2[-1], 360) + assert (y2.min() < -0.9) and (y2.max() > 0.9) + assert np.isclose(x3[0], -2*np.pi) and np.isclose(x3[-1], 2*np.pi) + assert (y3.min() < -52) and (y3.max() > 52) + assert np.isclose(x4[0], -360) and np.isclose(x4[-1], 360) + assert (y4.min() < -52) and (y4.max() > 52) + + xx = np.linspace(-2*np.pi, 2*np.pi, 10) + yy = np.cos(xx) + s1 = List2DSeries(xx, yy) + s2 = List2DSeries(xx, yy, tx=np.rad2deg, ty=np.rad2deg) + x1, y1 = s1.get_data() + x2, y2 = s2.get_data() + assert np.isclose(x1[0], -2*np.pi) and np.isclose(x1[-1], 2*np.pi) + assert (y1.min() < -0.9) and (y1.max() > 0.9) + assert np.isclose(x2[0], -360) and np.isclose(x2[-1], 360) + assert (y2.min() < -52) and (y2.max() > 52) + + s1 = Parametric2DLineSeries( + sin(x), cos(x), (x, -pi, pi), adaptive=False, n=10) + s2 = Parametric2DLineSeries( + sin(x), cos(x), (x, -pi, pi), adaptive=False, n=10, + tx=np.rad2deg, ty=np.rad2deg, tp=np.rad2deg) + x1, y1, a1 = s1.get_data() + x2, y2, a2 = s2.get_data() + assert np.allclose(x1, np.deg2rad(x2)) + assert np.allclose(y1, np.deg2rad(y2)) + assert np.allclose(a1, np.deg2rad(a2)) + + s1 = Parametric3DLineSeries( + sin(x), cos(x), x, (x, -pi, pi), adaptive=False, n=10) + s2 = Parametric3DLineSeries( + sin(x), cos(x), x, (x, -pi, pi), adaptive=False, n=10, tp=np.rad2deg) + x1, y1, z1, a1 = s1.get_data() + x2, y2, z2, a2 = s2.get_data() + assert np.allclose(x1, x2) + assert np.allclose(y1, y2) + assert np.allclose(z1, z2) + assert np.allclose(a1, np.deg2rad(a2)) + + s1 = SurfaceOver2DRangeSeries( + cos(x**2 + y**2), (x, -2*pi, 2*pi), (y, -2*pi, 2*pi), + adaptive=False, n1=10, n2=10) + s2 = SurfaceOver2DRangeSeries( + cos(x**2 + y**2), (x, -2*pi, 2*pi), (y, -2*pi, 2*pi), + adaptive=False, n1=10, n2=10, + tx=np.rad2deg, ty=lambda x: 2*x, tz=lambda x: 3*x) + x1, y1, z1 = s1.get_data() + x2, y2, z2 = s2.get_data() + assert np.allclose(x1, np.deg2rad(x2)) + assert np.allclose(y1, y2 / 2) + assert np.allclose(z1, z2 / 3) + + s1 = ParametricSurfaceSeries( + u + v, u - v, u * v, (u, 0, 2*pi), (v, 0, pi), + adaptive=False, n1=10, n2=10) + s2 = ParametricSurfaceSeries( + u + v, u - v, u * v, (u, 0, 2*pi), (v, 0, pi), + adaptive=False, n1=10, n2=10, + tx=np.rad2deg, ty=lambda x: 2*x, tz=lambda x: 3*x) + x1, y1, z1, u1, v1 = s1.get_data() + x2, y2, z2, u2, v2 = s2.get_data() + assert np.allclose(x1, np.deg2rad(x2)) + assert np.allclose(y1, y2 / 2) + assert np.allclose(z1, z2 / 3) + assert np.allclose(u1, u2) + assert np.allclose(v1, v2) + + +def test_series_labels(): + # verify that series return the correct label, depending on the plot + # type and input arguments. If the user set custom label on a data series, + # it should returned un-modified. + if not np: + skip("numpy not installed.") + + x, y, z, u, v = symbols("x, y, z, u, v") + wrapper = "$%s$" + + expr = cos(x) + s1 = LineOver1DRangeSeries(expr, (x, -2, 2), None) + s2 = LineOver1DRangeSeries(expr, (x, -2, 2), "test") + assert s1.get_label(False) == str(expr) + assert s1.get_label(True) == wrapper % latex(expr) + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + + s1 = List2DSeries([0, 1, 2, 3], [0, 1, 2, 3], "test") + assert s1.get_label(False) == "test" + assert s1.get_label(True) == "test" + + expr = (cos(x), sin(x)) + s1 = Parametric2DLineSeries(*expr, (x, -2, 2), None, use_cm=True) + s2 = Parametric2DLineSeries(*expr, (x, -2, 2), "test", use_cm=True) + s3 = Parametric2DLineSeries(*expr, (x, -2, 2), None, use_cm=False) + s4 = Parametric2DLineSeries(*expr, (x, -2, 2), "test", use_cm=False) + assert s1.get_label(False) == "x" + assert s1.get_label(True) == wrapper % "x" + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + assert s3.get_label(False) == str(expr) + assert s3.get_label(True) == wrapper % latex(expr) + assert s4.get_label(False) == "test" + assert s4.get_label(True) == "test" + + expr = (cos(x), sin(x), x) + s1 = Parametric3DLineSeries(*expr, (x, -2, 2), None, use_cm=True) + s2 = Parametric3DLineSeries(*expr, (x, -2, 2), "test", use_cm=True) + s3 = Parametric3DLineSeries(*expr, (x, -2, 2), None, use_cm=False) + s4 = Parametric3DLineSeries(*expr, (x, -2, 2), "test", use_cm=False) + assert s1.get_label(False) == "x" + assert s1.get_label(True) == wrapper % "x" + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + assert s3.get_label(False) == str(expr) + assert s3.get_label(True) == wrapper % latex(expr) + assert s4.get_label(False) == "test" + assert s4.get_label(True) == "test" + + expr = cos(x**2 + y**2) + s1 = SurfaceOver2DRangeSeries(expr, (x, -2, 2), (y, -2, 2), None) + s2 = SurfaceOver2DRangeSeries(expr, (x, -2, 2), (y, -2, 2), "test") + assert s1.get_label(False) == str(expr) + assert s1.get_label(True) == wrapper % latex(expr) + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + + expr = (cos(x - y), sin(x + y), x - y) + s1 = ParametricSurfaceSeries(*expr, (x, -2, 2), (y, -2, 2), None) + s2 = ParametricSurfaceSeries(*expr, (x, -2, 2), (y, -2, 2), "test") + assert s1.get_label(False) == str(expr) + assert s1.get_label(True) == wrapper % latex(expr) + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + + expr = Eq(cos(x - y), 0) + s1 = ImplicitSeries(expr, (x, -10, 10), (y, -10, 10), None) + s2 = ImplicitSeries(expr, (x, -10, 10), (y, -10, 10), "test") + assert s1.get_label(False) == str(expr) + assert s1.get_label(True) == wrapper % latex(expr) + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + + +def test_is_polar_2d_parametric(): + # verify that Parametric2DLineSeries isable to apply polar discretization, + # which is used when polar_plot is executed with polar_axis=True + if not np: + skip("numpy not installed.") + + t, u = symbols("t u") + + # NOTE: a sufficiently big n must be provided, or else tests + # are going to fail + # No colormap + f = sin(4 * t) + s1 = Parametric2DLineSeries(f * cos(t), f * sin(t), (t, 0, 2*pi), + adaptive=False, n=10, is_polar=False, use_cm=False) + x1, y1, p1 = s1.get_data() + s2 = Parametric2DLineSeries(f * cos(t), f * sin(t), (t, 0, 2*pi), + adaptive=False, n=10, is_polar=True, use_cm=False) + th, r, p2 = s2.get_data() + assert (not np.allclose(x1, th)) and (not np.allclose(y1, r)) + assert np.allclose(p1, p2) + + # With colormap + s3 = Parametric2DLineSeries(f * cos(t), f * sin(t), (t, 0, 2*pi), + adaptive=False, n=10, is_polar=False, color_func=lambda t: 2*t) + x3, y3, p3 = s3.get_data() + s4 = Parametric2DLineSeries(f * cos(t), f * sin(t), (t, 0, 2*pi), + adaptive=False, n=10, is_polar=True, color_func=lambda t: 2*t) + th4, r4, p4 = s4.get_data() + assert np.allclose(p3, p4) and (not np.allclose(p1, p3)) + assert np.allclose(x3, x1) and np.allclose(y3, y1) + assert np.allclose(th4, th) and np.allclose(r4, r) + + +def test_is_polar_3d(): + # verify that SurfaceOver2DRangeSeries is able to apply + # polar discretization + if not np: + skip("numpy not installed.") + + x, y, t = symbols("x, y, t") + expr = (x**2 - 1)**2 + s1 = SurfaceOver2DRangeSeries(expr, (x, 0, 1.5), (y, 0, 2 * pi), + n=10, adaptive=False, is_polar=False) + s2 = SurfaceOver2DRangeSeries(expr, (x, 0, 1.5), (y, 0, 2 * pi), + n=10, adaptive=False, is_polar=True) + x1, y1, z1 = s1.get_data() + x2, y2, z2 = s2.get_data() + x22, y22 = x1 * np.cos(y1), x1 * np.sin(y1) + assert np.allclose(x2, x22) + assert np.allclose(y2, y22) + + +def test_color_func(): + # verify that eval_color_func produces the expected results in order to + # maintain back compatibility with the old sympy.plotting module + if not np: + skip("numpy not installed.") + + x, y, z, u, v = symbols("x, y, z, u, v") + + # color func: returns x, y, color and s is parametric + xx = np.linspace(-3, 3, 10) + yy1 = np.cos(xx) + s = List2DSeries(xx, yy1, color_func=lambda x, y: 2 * x, use_cm=True) + xxs, yys, col = s.get_data() + assert np.allclose(xx, xxs) + assert np.allclose(yy1, yys) + assert np.allclose(2 * xx, col) + assert s.is_parametric + + s = List2DSeries(xx, yy1, color_func=lambda x, y: 2 * x, use_cm=False) + assert len(s.get_data()) == 2 + assert not s.is_parametric + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda t: t) + xx, yy, col = s.get_data() + assert (not np.allclose(xx, col)) and (not np.allclose(yy, col)) + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda x, y: x * y) + xx, yy, col = s.get_data() + assert np.allclose(col, xx * yy) + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda x, y, t: x * y * t) + xx, yy, col = s.get_data() + assert np.allclose(col, xx * yy * np.linspace(0, 2*np.pi, 10)) + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda t: t) + xx, yy, zz, col = s.get_data() + assert (not np.allclose(xx, col)) and (not np.allclose(yy, col)) + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda x, y, z: x * y * z) + xx, yy, zz, col = s.get_data() + assert np.allclose(col, xx * yy * zz) + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda x, y, z, t: x * y * z * t) + xx, yy, zz, col = s.get_data() + assert np.allclose(col, xx * yy * zz * np.linspace(0, 2*np.pi, 10)) + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + adaptive=False, n1=10, n2=10, color_func=lambda x: x) + xx, yy, zz = s.get_data() + col = s.eval_color_func(xx, yy, zz) + assert np.allclose(xx, col) + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + adaptive=False, n1=10, n2=10, color_func=lambda x, y: x * y) + xx, yy, zz = s.get_data() + col = s.eval_color_func(xx, yy, zz) + assert np.allclose(xx * yy, col) + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + adaptive=False, n1=10, n2=10, color_func=lambda x, y, z: x * y * z) + xx, yy, zz = s.get_data() + col = s.eval_color_func(xx, yy, zz) + assert np.allclose(xx * yy * zz, col) + + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda u:u) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(uu, col) + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda u, v: u * v) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(uu * vv, col) + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda x, y, z: x * y * z) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(xx * yy * zz, col) + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda x, y, z, u, v: x * y * z * u * v) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(xx * yy * zz * uu * vv, col) + + # Interactive Series + s = List2DSeries([0, 1, 2, x], [x, 2, 3, 4], + color_func=lambda x, y: 2 * x, params={x: 1}, use_cm=True) + xx, yy, col = s.get_data() + assert np.allclose(xx, [0, 1, 2, 1]) + assert np.allclose(yy, [1, 2, 3, 4]) + assert np.allclose(2 * xx, col) + assert s.is_parametric and s.use_cm + + s = List2DSeries([0, 1, 2, x], [x, 2, 3, 4], + color_func=lambda x, y: 2 * x, params={x: 1}, use_cm=False) + assert len(s.get_data()) == 2 + assert not s.is_parametric + + +def test_color_func_scalar_val(): + # verify that eval_color_func returns a numpy array even when color_func + # evaluates to a scalar value + if not np: + skip("numpy not installed.") + + x, y = symbols("x, y") + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda t: 1) + xx, yy, col = s.get_data() + assert np.allclose(col, np.ones(xx.shape)) + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda t: 1) + xx, yy, zz, col = s.get_data() + assert np.allclose(col, np.ones(xx.shape)) + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + adaptive=False, n1=10, n2=10, color_func=lambda x: 1) + xx, yy, zz = s.get_data() + assert np.allclose(s.eval_color_func(xx), np.ones(xx.shape)) + + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda u: 1) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(col, np.ones(xx.shape)) + + +def test_color_func_expression(): + # verify that color_func is able to deal with instances of Expr: they will + # be lambdified with the same signature used for the main expression. + if not np: + skip("numpy not installed.") + + x, y = symbols("x, y") + + s1 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + color_func=sin(x), adaptive=False, n=10, use_cm=True) + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + color_func=lambda x: np.cos(x), adaptive=False, n=10, use_cm=True) + # the following statement should not raise errors + d1 = s1.get_data() + assert callable(s1.color_func) + d2 = s2.get_data() + assert not np.allclose(d1[-1], d2[-1]) + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -pi, pi), (y, -pi, pi), + color_func=sin(x**2 + y**2), adaptive=False, n1=5, n2=5) + # the following statement should not raise errors + s.get_data() + assert callable(s.color_func) + + xx = [1, 2, 3, 4, 5] + yy = [1, 2, 3, 4, 5] + raises(TypeError, + lambda : List2DSeries(xx, yy, use_cm=True, color_func=sin(x))) + + +def test_line_surface_color(): + # verify the back-compatibility with the old sympy.plotting module. + # By setting line_color or surface_color to be a callable, it will set + # the color_func attribute. + + x, y, z = symbols("x, y, z") + + s = LineOver1DRangeSeries(sin(x), (x, -5, 5), adaptive=False, n=10, + line_color=lambda x: x) + assert (s.line_color is None) and callable(s.color_func) + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, line_color=lambda t: t) + assert (s.line_color is None) and callable(s.color_func) + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + n1=10, n2=10, surface_color=lambda x: x) + assert (s.surface_color is None) and callable(s.color_func) + + +def test_complex_adaptive_false(): + # verify that series with adaptive=False is evaluated with discretized + # ranges of type complex. + if not np: + skip("numpy not installed.") + + x, y, u = symbols("x y u") + + def do_test(data1, data2): + assert len(data1) == len(data2) + for d1, d2 in zip(data1, data2): + assert np.allclose(d1, d2) + + expr1 = sqrt(x) * exp(-x**2) + expr2 = sqrt(u * x) * exp(-x**2) + s1 = LineOver1DRangeSeries(im(expr1), (x, -5, 5), adaptive=False, n=10) + s2 = LineOver1DRangeSeries(im(expr2), (x, -5, 5), + adaptive=False, n=10, params={u: 1}) + data1 = s1.get_data() + data2 = s2.get_data() + + do_test(data1, data2) + assert (not np.allclose(data1[1], 0)) and (not np.allclose(data2[1], 0)) + + s1 = Parametric2DLineSeries(re(expr1), im(expr1), (x, -pi, pi), + adaptive=False, n=10) + s2 = Parametric2DLineSeries(re(expr2), im(expr2), (x, -pi, pi), + adaptive=False, n=10, params={u: 1}) + data1 = s1.get_data() + data2 = s2.get_data() + do_test(data1, data2) + assert (not np.allclose(data1[1], 0)) and (not np.allclose(data2[1], 0)) + + s1 = SurfaceOver2DRangeSeries(im(expr1), (x, -5, 5), (y, -10, 10), + adaptive=False, n1=30, n2=3) + s2 = SurfaceOver2DRangeSeries(im(expr2), (x, -5, 5), (y, -10, 10), + adaptive=False, n1=30, n2=3, params={u: 1}) + data1 = s1.get_data() + data2 = s2.get_data() + do_test(data1, data2) + assert (not np.allclose(data1[1], 0)) and (not np.allclose(data2[1], 0)) + + +def test_expr_is_lambda_function(): + # verify that when a numpy function is provided, the series will be able + # to evaluate it. Also, label should be empty in order to prevent some + # backend from crashing. + if not np: + skip("numpy not installed.") + + f = lambda x: np.cos(x) + s1 = LineOver1DRangeSeries(f, ("x", -5, 5), adaptive=True, depth=3) + s1.get_data() + s2 = LineOver1DRangeSeries(f, ("x", -5, 5), adaptive=False, n=10) + s2.get_data() + assert s1.label == s2.label == "" + + fx = lambda x: np.cos(x) + fy = lambda x: np.sin(x) + s1 = Parametric2DLineSeries(fx, fy, ("x", 0, 2*pi), + adaptive=True, adaptive_goal=0.1) + s1.get_data() + s2 = Parametric2DLineSeries(fx, fy, ("x", 0, 2*pi), + adaptive=False, n=10) + s2.get_data() + assert s1.label == s2.label == "" + + fz = lambda x: x + s1 = Parametric3DLineSeries(fx, fy, fz, ("x", 0, 2*pi), + adaptive=True, adaptive_goal=0.1) + s1.get_data() + s2 = Parametric3DLineSeries(fx, fy, fz, ("x", 0, 2*pi), + adaptive=False, n=10) + s2.get_data() + assert s1.label == s2.label == "" + + f = lambda x, y: np.cos(x**2 + y**2) + s1 = SurfaceOver2DRangeSeries(f, ("a", -2, 2), ("b", -3, 3), + adaptive=False, n1=10, n2=10) + s1.get_data() + s2 = ContourSeries(f, ("a", -2, 2), ("b", -3, 3), + adaptive=False, n1=10, n2=10) + s2.get_data() + assert s1.label == s2.label == "" + + fx = lambda u, v: np.cos(u + v) + fy = lambda u, v: np.sin(u - v) + fz = lambda u, v: u * v + s1 = ParametricSurfaceSeries(fx, fy, fz, ("u", 0, pi), ("v", 0, 2*pi), + adaptive=False, n1=10, n2=10) + s1.get_data() + assert s1.label == "" + + raises(TypeError, lambda: List2DSeries(lambda t: t, lambda t: t)) + raises(TypeError, lambda : ImplicitSeries(lambda t: np.sin(t), + ("x", -5, 5), ("y", -6, 6))) + + +def test_show_in_legend_lines(): + # verify that lines series correctly set the show_in_legend attribute + x, u = symbols("x, u") + + s = LineOver1DRangeSeries(cos(x), (x, -2, 2), "test", show_in_legend=True) + assert s.show_in_legend + s = LineOver1DRangeSeries(cos(x), (x, -2, 2), "test", show_in_legend=False) + assert not s.show_in_legend + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 1), "test", + show_in_legend=True) + assert s.show_in_legend + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 1), "test", + show_in_legend=False) + assert not s.show_in_legend + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 1), "test", + show_in_legend=True) + assert s.show_in_legend + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 1), "test", + show_in_legend=False) + assert not s.show_in_legend + + +@XFAIL +def test_particular_case_1_with_adaptive_true(): + # Verify that symbolic expressions and numerical lambda functions are + # evaluated with the same algorithm. + if not np: + skip("numpy not installed.") + + # NOTE: xfail because sympy's adaptive algorithm is not deterministic + + def do_test(a, b): + with warns( + RuntimeWarning, + match="invalid value encountered in scalar power", + test_stacklevel=False, + ): + d1 = a.get_data() + d2 = b.get_data() + for t, v in zip(d1, d2): + assert np.allclose(t, v) + + n = symbols("n") + a = S(2) / 3 + epsilon = 0.01 + xn = (n**3 + n**2)**(S(1)/3) - (n**3 - n**2)**(S(1)/3) + expr = Abs(xn - a) - epsilon + math_func = lambdify([n], expr) + s1 = LineOver1DRangeSeries(expr, (n, -10, 10), "", + adaptive=True, depth=3) + s2 = LineOver1DRangeSeries(math_func, ("n", -10, 10), "", + adaptive=True, depth=3) + do_test(s1, s2) + + +def test_particular_case_1_with_adaptive_false(): + # Verify that symbolic expressions and numerical lambda functions are + # evaluated with the same algorithm. In particular, uniform evaluation + # is going to use np.vectorize, which correctly evaluates the following + # mathematical function. + if not np: + skip("numpy not installed.") + + def do_test(a, b): + d1 = a.get_data() + d2 = b.get_data() + for t, v in zip(d1, d2): + assert np.allclose(t, v) + + n = symbols("n") + a = S(2) / 3 + epsilon = 0.01 + xn = (n**3 + n**2)**(S(1)/3) - (n**3 - n**2)**(S(1)/3) + expr = Abs(xn - a) - epsilon + math_func = lambdify([n], expr) + + s3 = LineOver1DRangeSeries(expr, (n, -10, 10), "", + adaptive=False, n=10) + s4 = LineOver1DRangeSeries(math_func, ("n", -10, 10), "", + adaptive=False, n=10) + do_test(s3, s4) + + +def test_complex_params_number_eval(): + # The main expression contains terms like sqrt(xi - 1), with + # parameter (0 <= xi <= 1). + # There shouldn't be any NaN values on the output. + if not np: + skip("numpy not installed.") + + xi, wn, x0, v0, t = symbols("xi, omega_n, x0, v0, t") + x = Function("x")(t) + eq = x.diff(t, 2) + 2 * xi * wn * x.diff(t) + wn**2 * x + sol = dsolve(eq, x, ics={x.subs(t, 0): x0, x.diff(t).subs(t, 0): v0}) + params = { + wn: 0.5, + xi: 0.25, + x0: 0.45, + v0: 0.0 + } + s = LineOver1DRangeSeries(sol.rhs, (t, 0, 100), adaptive=False, n=5, + params=params) + x, y = s.get_data() + assert not np.isnan(x).any() + assert not np.isnan(y).any() + + + # Fourier Series of a sawtooth wave + # The main expression contains a Sum with a symbolic upper range. + # The lambdified code looks like: + # sum(blablabla for for n in range(1, m+1)) + # But range requires integer numbers, whereas per above example, the series + # casts parameters to complex. Verify that the series is able to detect + # upper bounds in summations and cast it to int in order to get successful + # evaluation + x, T, n, m = symbols("x, T, n, m") + fs = S(1) / 2 - (1 / pi) * Sum(sin(2 * n * pi * x / T) / n, (n, 1, m)) + params = { + T: 4.5, + m: 5 + } + s = LineOver1DRangeSeries(fs, (x, 0, 10), adaptive=False, n=5, + params=params) + x, y = s.get_data() + assert not np.isnan(x).any() + assert not np.isnan(y).any() + + +def test_complex_range_line_plot_1(): + # verify that univariate functions are evaluated with a complex + # data range (with zero imaginary part). There shouldn't be any + # NaN value in the output. + if not np: + skip("numpy not installed.") + + x, u = symbols("x, u") + expr1 = im(sqrt(x) * exp(-x**2)) + expr2 = im(sqrt(u * x) * exp(-x**2)) + s1 = LineOver1DRangeSeries(expr1, (x, -10, 10), adaptive=True, + adaptive_goal=0.1) + s2 = LineOver1DRangeSeries(expr1, (x, -10, 10), adaptive=False, n=30) + s3 = LineOver1DRangeSeries(expr2, (x, -10, 10), adaptive=False, n=30, + params={u: 1}) + + with ignore_warnings(RuntimeWarning): + data1 = s1.get_data() + data2 = s2.get_data() + data3 = s3.get_data() + + assert not np.isnan(data1[1]).any() + assert not np.isnan(data2[1]).any() + assert not np.isnan(data3[1]).any() + assert np.allclose(data2[0], data3[0]) and np.allclose(data2[1], data3[1]) + + +@XFAIL +def test_complex_range_line_plot_2(): + # verify that univariate functions are evaluated with a complex + # data range (with non-zero imaginary part). There shouldn't be any + # NaN value in the output. + if not np: + skip("numpy not installed.") + + # NOTE: xfail because sympy's adaptive algorithm is unable to deal with + # complex number. + + x, u = symbols("x, u") + + # adaptive and uniform meshing should produce the same data. + # because of the adaptive nature, just compare the first and last points + # of both series. + s1 = LineOver1DRangeSeries(abs(sqrt(x)), (x, -5-2j, 5-2j), adaptive=True) + s2 = LineOver1DRangeSeries(abs(sqrt(x)), (x, -5-2j, 5-2j), adaptive=False, + n=10) + with warns( + RuntimeWarning, + match="invalid value encountered in sqrt", + test_stacklevel=False, + ): + d1 = s1.get_data() + d2 = s2.get_data() + xx1 = [d1[0][0], d1[0][-1]] + xx2 = [d2[0][0], d2[0][-1]] + yy1 = [d1[1][0], d1[1][-1]] + yy2 = [d2[1][0], d2[1][-1]] + assert np.allclose(xx1, xx2) + assert np.allclose(yy1, yy2) + + +def test_force_real_eval(): + # verify that force_real_eval=True produces inconsistent results when + # compared with evaluation of complex domain. + if not np: + skip("numpy not installed.") + + x = symbols("x") + + expr = im(sqrt(x) * exp(-x**2)) + s1 = LineOver1DRangeSeries(expr, (x, -10, 10), adaptive=False, n=10, + force_real_eval=False) + s2 = LineOver1DRangeSeries(expr, (x, -10, 10), adaptive=False, n=10, + force_real_eval=True) + d1 = s1.get_data() + with ignore_warnings(RuntimeWarning): + d2 = s2.get_data() + assert not np.allclose(d1[1], 0) + assert np.allclose(d2[1], 0) + + +def test_contour_series_show_clabels(): + # verify that a contour series has the abiliy to set the visibility of + # labels to contour lines + + x, y = symbols("x, y") + s = ContourSeries(cos(x*y), (x, -2, 2), (y, -2, 2)) + assert s.show_clabels + + s = ContourSeries(cos(x*y), (x, -2, 2), (y, -2, 2), clabels=True) + assert s.show_clabels + + s = ContourSeries(cos(x*y), (x, -2, 2), (y, -2, 2), clabels=False) + assert not s.show_clabels + + +def test_LineOver1DRangeSeries_complex_range(): + # verify that LineOver1DRangeSeries can accept a complex range + # if the imaginary part of the start and end values are the same + + x = symbols("x") + + LineOver1DRangeSeries(sqrt(x), (x, -10, 10)) + LineOver1DRangeSeries(sqrt(x), (x, -10-2j, 10-2j)) + raises(ValueError, + lambda : LineOver1DRangeSeries(sqrt(x), (x, -10-2j, 10+2j))) + + +def test_symbolic_plotting_ranges(): + # verify that data series can use symbolic plotting ranges + if not np: + skip("numpy not installed.") + + x, y, z, a, b = symbols("x, y, z, a, b") + + def do_test(s1, s2, new_params): + d1 = s1.get_data() + d2 = s2.get_data() + for u, v in zip(d1, d2): + assert np.allclose(u, v) + s2.params = new_params + d2 = s2.get_data() + for u, v in zip(d1, d2): + assert not np.allclose(u, v) + + s1 = LineOver1DRangeSeries(sin(x), (x, 0, 1), adaptive=False, n=10) + s2 = LineOver1DRangeSeries(sin(x), (x, a, b), params={a: 0, b: 1}, + adaptive=False, n=10) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : LineOver1DRangeSeries(sin(x), (x, a, b), params={a: 1}, n=10)) + + s1 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 1), adaptive=False, n=10) + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, a, b), params={a: 0, b: 1}, + adaptive=False, n=10) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : Parametric2DLineSeries(cos(x), sin(x), (x, a, b), + params={a: 0}, adaptive=False, n=10)) + + s1 = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 1), + adaptive=False, n=10) + s2 = Parametric3DLineSeries(cos(x), sin(x), x, (x, a, b), + params={a: 0, b: 1}, adaptive=False, n=10) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : Parametric3DLineSeries(cos(x), sin(x), x, (x, a, b), + params={a: 0}, adaptive=False, n=10)) + + s1 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -pi, pi), (y, -pi, pi), + adaptive=False, n1=5, n2=5) + s2 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -pi * a, pi * a), + (y, -pi * b, pi * b), params={a: 1, b: 1}, + adaptive=False, n1=5, n2=5) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : SurfaceOver2DRangeSeries(cos(x**2 + y**2), + (x, -pi * a, pi * a), (y, -pi * b, pi * b), params={a: 1}, + adaptive=False, n1=5, n2=5)) + # one range symbol is included into another range's minimum or maximum val + raises(ValueError, + lambda : SurfaceOver2DRangeSeries(cos(x**2 + y**2), + (x, -pi * a + y, pi * a), (y, -pi * b, pi * b), params={a: 1}, + adaptive=False, n1=5, n2=5)) + + s1 = ParametricSurfaceSeries( + cos(x - y), sin(x + y), x - y, (x, -2, 2), (y, -2, 2), n1=5, n2=5) + s2 = ParametricSurfaceSeries( + cos(x - y), sin(x + y), x - y, (x, -2 * a, 2), (y, -2, 2 * b), + params={a: 1, b: 1}, n1=5, n2=5) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : ParametricSurfaceSeries( + cos(x - y), sin(x + y), x - y, (x, -2 * a, 2), (y, -2, 2 * b), + params={a: 1}, n1=5, n2=5)) + + +def test_exclude_points(): + # verify that exclude works as expected + if not np: + skip("numpy not installed.") + + x = symbols("x") + + expr = (floor(x) + S.Half) / (1 - (x - S.Half)**2) + + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some", + test_stacklevel=False, + ): + s = LineOver1DRangeSeries(expr, (x, -3.5, 3.5), adaptive=False, n=100, + exclude=list(range(-3, 4))) + xx, yy = s.get_data() + assert not np.isnan(xx).any() + assert np.count_nonzero(np.isnan(yy)) == 7 + assert len(xx) > 100 + + e1 = log(floor(x)) * cos(x) + e2 = log(floor(x)) * sin(x) + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some", + test_stacklevel=False, + ): + s = Parametric2DLineSeries(e1, e2, (x, 1, 12), adaptive=False, n=100, + exclude=list(range(1, 13))) + xx, yy, pp = s.get_data() + assert not np.isnan(pp).any() + assert np.count_nonzero(np.isnan(xx)) == 11 + assert np.count_nonzero(np.isnan(yy)) == 11 + assert len(xx) > 100 + + +def test_unwrap(): + # verify that unwrap works as expected + if not np: + skip("numpy not installed.") + + x, y = symbols("x, y") + expr = 1 / (x**3 + 2*x**2 + x) + expr = arg(expr.subs(x, I*y*2*pi)) + s1 = LineOver1DRangeSeries(expr, (y, 1e-05, 1e05), xscale="log", + adaptive=False, n=10, unwrap=False) + s2 = LineOver1DRangeSeries(expr, (y, 1e-05, 1e05), xscale="log", + adaptive=False, n=10, unwrap=True) + s3 = LineOver1DRangeSeries(expr, (y, 1e-05, 1e05), xscale="log", + adaptive=False, n=10, unwrap={"period": 4}) + x1, y1 = s1.get_data() + x2, y2 = s2.get_data() + x3, y3 = s3.get_data() + assert np.allclose(x1, x2) + # there must not be nan values in the results of these evaluations + assert all(not np.isnan(t).any() for t in [y1, y2, y3]) + assert not np.allclose(y1, y2) + assert not np.allclose(y1, y3) + assert not np.allclose(y2, y3) diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/test_textplot.py b/phivenv/Lib/site-packages/sympy/plotting/tests/test_textplot.py new file mode 100644 index 0000000000000000000000000000000000000000..928085c627e5230f2ac4a8ce0bbac5354ab35d51 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/tests/test_textplot.py @@ -0,0 +1,203 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.plotting.textplot import textplot_str + +from sympy.utilities.exceptions import ignore_warnings + + +def test_axes_alignment(): + x = Symbol('x') + lines = [ + ' 1 | ..', + ' | ... ', + ' | .. ', + ' | ... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' 0 |--------------------------...--------------------------', + ' | ... ', + ' | .. ', + ' | ... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' -1 |_______________________________________________________', + ' -1 0 1' + ] + assert lines == list(textplot_str(x, -1, 1)) + + lines = [ + ' 1 | ..', + ' | .... ', + ' | ... ', + ' | ... ', + ' | .... ', + ' | ... ', + ' | ... ', + ' | .... ', + ' 0 |--------------------------...--------------------------', + ' | .... ', + ' | ... ', + ' | ... ', + ' | .... ', + ' | ... ', + ' | ... ', + ' | .... ', + ' -1 |_______________________________________________________', + ' -1 0 1' + ] + assert lines == list(textplot_str(x, -1, 1, H=17)) + + +def test_singularity(): + x = Symbol('x') + lines = [ + ' 54 | . ', + ' | ', + ' | ', + ' | ', + ' | ',' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' 27.5 |--.----------------------------------------------------', + ' | ', + ' | ', + ' | ', + ' | . ', + ' | \\ ', + ' | \\ ', + ' | .. ', + ' | ... ', + ' | ............. ', + ' 1 |_______________________________________________________', + ' 0 0.5 1' + ] + assert lines == list(textplot_str(1/x, 0, 1)) + + lines = [ + ' 0 | ......', + ' | ........ ', + ' | ........ ', + ' | ...... ', + ' | ..... ', + ' | .... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' | / ', + ' -2 |-------..----------------------------------------------', + ' | / ', + ' | / ', + ' | / ', + ' | . ', + ' | ', + ' | . ', + ' | ', + ' | ', + ' | ', + ' -4 |_______________________________________________________', + ' 0 0.5 1' + ] + # RuntimeWarning: divide by zero encountered in log + with ignore_warnings(RuntimeWarning): + assert lines == list(textplot_str(log(x), 0, 1)) + + +def test_sinc(): + x = Symbol('x') + lines = [ + ' 1 | . . ', + ' | . . ', + ' | ', + ' | . . ', + ' | ', + ' | . . ', + ' | ', + ' | ', + ' | . . ', + ' | ', + ' 0.4 |-------------------------------------------------------', + ' | . . ', + ' | ', + ' | . . ', + ' | ', + ' | ..... ..... ', + ' | .. \\ . . / .. ', + ' | / \\ / \\ ', + ' |/ \\ . . / \\', + ' | \\ / \\ / ', + ' -0.2 |_______________________________________________________', + ' -10 0 10' + ] + # RuntimeWarning: invalid value encountered in double_scalars + with ignore_warnings(RuntimeWarning): + assert lines == list(textplot_str(sin(x)/x, -10, 10)) + + +def test_imaginary(): + x = Symbol('x') + lines = [ + ' 1 | ..', + ' | .. ', + ' | ... ', + ' | .. ', + ' | .. ', + ' | .. ', + ' | .. ', + ' | .. ', + ' | .. ', + ' | / ', + ' 0.5 |----------------------------------/--------------------', + ' | .. ', + ' | / ', + ' | . ', + ' | ', + ' | . ', + ' | . ', + ' | ', + ' | ', + ' | ', + ' 0 |_______________________________________________________', + ' -1 0 1' + ] + # RuntimeWarning: invalid value encountered in sqrt + with ignore_warnings(RuntimeWarning): + assert list(textplot_str(sqrt(x), -1, 1)) == lines + + lines = [ + ' 1 | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' 0 |-------------------------------------------------------', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' -1 |_______________________________________________________', + ' -1 0 1' + ] + assert list(textplot_str(S.ImaginaryUnit, -1, 1)) == lines diff --git a/phivenv/Lib/site-packages/sympy/plotting/tests/test_utils.py b/phivenv/Lib/site-packages/sympy/plotting/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4206a8b001319552c2e2be1aeb46057e6f708912 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/tests/test_utils.py @@ -0,0 +1,110 @@ +from pytest import raises +from sympy import ( + symbols, Expr, Tuple, Integer, cos, solveset, FiniteSet, ImageSet) +from sympy.plotting.utils import ( + _create_ranges, _plot_sympify, extract_solution) +from sympy.physics.mechanics import ReferenceFrame, Vector as MechVector +from sympy.vector import CoordSys3D, Vector + + +def test_plot_sympify(): + x, y = symbols("x, y") + + # argument is already sympified + args = x + y + r = _plot_sympify(args) + assert r == args + + # one argument needs to be sympified + args = (x + y, 1) + r = _plot_sympify(args) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(r[0], Expr) + assert isinstance(r[1], Integer) + + # string and dict should not be sympified + args = (x + y, (x, 0, 1), "str", 1, {1: 1, 2: 2.0}) + r = _plot_sympify(args) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 5 + assert isinstance(r[0], Expr) + assert isinstance(r[1], Tuple) + assert isinstance(r[2], str) + assert isinstance(r[3], Integer) + assert isinstance(r[4], dict) and isinstance(r[4][1], int) and isinstance(r[4][2], float) + + # nested arguments containing strings + args = ((x + y, (y, 0, 1), "a"), (x + 1, (x, 0, 1), "$f_{1}$")) + r = _plot_sympify(args) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(r[0], Tuple) + assert isinstance(r[0][1], Tuple) + assert isinstance(r[0][1][1], Integer) + assert isinstance(r[0][2], str) + assert isinstance(r[1], Tuple) + assert isinstance(r[1][1], Tuple) + assert isinstance(r[1][1][1], Integer) + assert isinstance(r[1][2], str) + + # vectors from sympy.physics.vectors module are not sympified + # vectors from sympy.vectors are sympified + # in both cases, no error should be raised + R = ReferenceFrame("R") + v1 = 2 * R.x + R.y + C = CoordSys3D("C") + v2 = 2 * C.i + C.j + args = (v1, v2) + r = _plot_sympify(args) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(v1, MechVector) + assert isinstance(v2, Vector) + + +def test_create_ranges(): + x, y = symbols("x, y") + + # user don't provide any range -> return a default range + r = _create_ranges({x}, [], 1) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 1 + assert isinstance(r[0], (Tuple, tuple)) + assert r[0] == (x, -10, 10) + + r = _create_ranges({x, y}, [], 2) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(r[0], (Tuple, tuple)) + assert isinstance(r[1], (Tuple, tuple)) + assert r[0] == (x, -10, 10) or (y, -10, 10) + assert r[1] == (y, -10, 10) or (x, -10, 10) + assert r[0] != r[1] + + # not enough ranges provided by the user -> create default ranges + r = _create_ranges( + {x, y}, + [ + (x, 0, 1), + ], + 2, + ) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(r[0], (Tuple, tuple)) + assert isinstance(r[1], (Tuple, tuple)) + assert r[0] == (x, 0, 1) or (y, -10, 10) + assert r[1] == (y, -10, 10) or (x, 0, 1) + assert r[0] != r[1] + + # too many free symbols + raises(ValueError, lambda: _create_ranges({x, y}, [], 1)) + raises(ValueError, lambda: _create_ranges({x, y}, [(x, 0, 5), (y, 0, 1)], 1)) + + +def test_extract_solution(): + x = symbols("x") + + sol = solveset(cos(10 * x)) + assert sol.has(ImageSet) + res = extract_solution(sol) + assert len(res) == 20 + assert isinstance(res, FiniteSet) + + res = extract_solution(sol, 20) + assert len(res) == 40 + assert isinstance(res, FiniteSet) diff --git a/phivenv/Lib/site-packages/sympy/plotting/textplot.py b/phivenv/Lib/site-packages/sympy/plotting/textplot.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1f2b639d6c387a6a36cf89fe36bc7717c92b2b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/textplot.py @@ -0,0 +1,168 @@ +from sympy.core.numbers import Float +from sympy.core.symbol import Dummy +from sympy.utilities.lambdify import lambdify + +import math + + +def is_valid(x): + """Check if a floating point number is valid""" + if x is None: + return False + if isinstance(x, complex): + return False + return not math.isinf(x) and not math.isnan(x) + + +def rescale(y, W, H, mi, ma): + """Rescale the given array `y` to fit into the integer values + between `0` and `H-1` for the values between ``mi`` and ``ma``. + """ + y_new = [] + + norm = ma - mi + offset = (ma + mi) / 2 + + for x in range(W): + if is_valid(y[x]): + normalized = (y[x] - offset) / norm + if not is_valid(normalized): + y_new.append(None) + else: + rescaled = Float((normalized*H + H/2) * (H-1)/H).round() + rescaled = int(rescaled) + y_new.append(rescaled) + else: + y_new.append(None) + return y_new + + +def linspace(start, stop, num): + return [start + (stop - start) * x / (num-1) for x in range(num)] + + +def textplot_str(expr, a, b, W=55, H=21): + """Generator for the lines of the plot""" + free = expr.free_symbols + if len(free) > 1: + raise ValueError( + "The expression must have a single variable. (Got {})" + .format(free)) + x = free.pop() if free else Dummy() + f = lambdify([x], expr) + if isinstance(a, complex): + if a.imag == 0: + a = a.real + if isinstance(b, complex): + if b.imag == 0: + b = b.real + a = float(a) + b = float(b) + + # Calculate function values + x = linspace(a, b, W) + y = [] + for val in x: + try: + y.append(f(val)) + # Not sure what exceptions to catch here or why... + except (ValueError, TypeError, ZeroDivisionError): + y.append(None) + + # Normalize height to screen space + y_valid = list(filter(is_valid, y)) + if y_valid: + ma = max(y_valid) + mi = min(y_valid) + if ma == mi: + if ma: + mi, ma = sorted([0, 2*ma]) + else: + mi, ma = -1, 1 + else: + mi, ma = -1, 1 + y_range = ma - mi + precision = math.floor(math.log10(y_range)) - 1 + precision *= -1 + mi = round(mi, precision) + ma = round(ma, precision) + y = rescale(y, W, H, mi, ma) + + y_bins = linspace(mi, ma, H) + + # Draw plot + margin = 7 + for h in range(H - 1, -1, -1): + s = [' '] * W + for i in range(W): + if y[i] == h: + if (i == 0 or y[i - 1] == h - 1) and (i == W - 1 or y[i + 1] == h + 1): + s[i] = '/' + elif (i == 0 or y[i - 1] == h + 1) and (i == W - 1 or y[i + 1] == h - 1): + s[i] = '\\' + else: + s[i] = '.' + + if h == 0: + for i in range(W): + s[i] = '_' + + # Print y values + if h in (0, H//2, H - 1): + prefix = ("%g" % y_bins[h]).rjust(margin)[:margin] + else: + prefix = " "*margin + s = "".join(s) + if h == H//2: + s = s.replace(" ", "-") + yield prefix + " |" + s + + # Print x values + bottom = " " * (margin + 2) + bottom += ("%g" % x[0]).ljust(W//2) + if W % 2 == 1: + bottom += ("%g" % x[W//2]).ljust(W//2) + else: + bottom += ("%g" % x[W//2]).ljust(W//2-1) + bottom += "%g" % x[-1] + yield bottom + + +def textplot(expr, a, b, W=55, H=21): + r""" + Print a crude ASCII art plot of the SymPy expression 'expr' (which + should contain a single symbol, e.g. x or something else) over the + interval [a, b]. + + Examples + ======== + + >>> from sympy import Symbol, sin + >>> from sympy.plotting import textplot + >>> t = Symbol('t') + >>> textplot(sin(t)*t, 0, 15) + 14 | ... + | . + | . + | . + | . + | ... + | / . . + | / + | / . + | . . . + 1.5 |----.......-------------------------------------------- + |.... \ . . + | \ / . + | .. / . + | \ / . + | .... + | . + | . . + | + | . . + -11 |_______________________________________________________ + 0 7.5 15 + """ + for line in textplot_str(expr, a, b, W, H): + print(line) diff --git a/phivenv/Lib/site-packages/sympy/plotting/utils.py b/phivenv/Lib/site-packages/sympy/plotting/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3213dea09b5a98e96094e7dffbd9b992c7d2b87e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/plotting/utils.py @@ -0,0 +1,323 @@ +from sympy.core.containers import Tuple +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.function import AppliedUndef +from sympy.core.relational import Relational +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.logic.boolalg import BooleanFunction +from sympy.sets.fancysets import ImageSet +from sympy.sets.sets import FiniteSet +from sympy.tensor.indexed import Indexed + + +def _get_free_symbols(exprs): + """Returns the free symbols of a symbolic expression. + + If the expression contains any of these elements, assume that they are + the "free symbols" of the expression: + + * indexed objects + * applied undefined function (useful for sympy.physics.mechanics module) + """ + if not isinstance(exprs, (list, tuple, set)): + exprs = [exprs] + if all(callable(e) for e in exprs): + return set() + + free = set().union(*[e.atoms(Indexed) for e in exprs]) + free = free.union(*[e.atoms(AppliedUndef) for e in exprs]) + return free or set().union(*[e.free_symbols for e in exprs]) + + +def extract_solution(set_sol, n=10): + """Extract numerical solutions from a set solution (computed by solveset, + linsolve, nonlinsolve). Often, it is not trivial do get something useful + out of them. + + Parameters + ========== + + n : int, optional + In order to replace ImageSet with FiniteSet, an iterator is created + for each ImageSet contained in `set_sol`, starting from 0 up to `n`. + Default value: 10. + """ + images = set_sol.find(ImageSet) + for im in images: + it = iter(im) + s = FiniteSet(*[next(it) for n in range(0, n)]) + set_sol = set_sol.subs(im, s) + return set_sol + + +def _plot_sympify(args): + """This function recursively loop over the arguments passed to the plot + functions: the sympify function will be applied to all arguments except + those of type string/dict. + + Generally, users can provide the following arguments to a plot function: + + expr, range1 [tuple, opt], ..., label [str, opt], rendering_kw [dict, opt] + + `expr, range1, ...` can be sympified, whereas `label, rendering_kw` can't. + In particular, whenever a special character like $, {, }, ... is used in + the `label`, sympify will raise an error. + """ + if isinstance(args, Expr): + return args + + args = list(args) + for i, a in enumerate(args): + if isinstance(a, (list, tuple)): + args[i] = Tuple(*_plot_sympify(a), sympify=False) + elif not (isinstance(a, (str, dict)) or callable(a) + # NOTE: check if it is a vector from sympy.physics.vector module + # without importing the module (because it slows down SymPy's + # import process and triggers SymPy's optional-dependencies + # tests to fail). + or ((a.__class__.__name__ == "Vector") and not isinstance(a, Basic)) + ): + args[i] = sympify(a) + return args + + +def _create_ranges(exprs, ranges, npar, label="", params=None): + """This function does two things: + + 1. Check if the number of free symbols is in agreement with the type of + plot chosen. For example, plot() requires 1 free symbol; + plot3d() requires 2 free symbols. + 2. Sometime users create plots without providing ranges for the variables. + Here we create the necessary ranges. + + Parameters + ========== + + exprs : iterable + The expressions from which to extract the free symbols + ranges : iterable + The limiting ranges provided by the user + npar : int + The number of free symbols required by the plot functions. + For example, + npar=1 for plot, npar=2 for plot3d, ... + params : dict + A dictionary mapping symbols to parameters for interactive plot. + """ + get_default_range = lambda symbol: Tuple(symbol, -10, 10) + + free_symbols = _get_free_symbols(exprs) + if params is not None: + free_symbols = free_symbols.difference(params.keys()) + + if len(free_symbols) > npar: + raise ValueError( + "Too many free symbols.\n" + + "Expected {} free symbols.\n".format(npar) + + "Received {}: {}".format(len(free_symbols), free_symbols) + ) + + if len(ranges) > npar: + raise ValueError( + "Too many ranges. Received %s, expected %s" % (len(ranges), npar)) + + # free symbols in the ranges provided by the user + rfs = set().union([r[0] for r in ranges]) + if len(rfs) != len(ranges): + raise ValueError("Multiple ranges with the same symbol") + + if len(ranges) < npar: + symbols = free_symbols.difference(rfs) + if symbols != set(): + # add a range for each missing free symbols + for s in symbols: + ranges.append(get_default_range(s)) + # if there is still room, fill them with dummys + for i in range(npar - len(ranges)): + ranges.append(get_default_range(Dummy())) + + if len(free_symbols) == npar: + # there could be times when this condition is not met, for example + # plotting the function f(x, y) = x (which is a plane); in this case, + # free_symbols = {x} whereas rfs = {x, y} (or x and Dummy) + rfs = set().union([r[0] for r in ranges]) + if len(free_symbols.difference(rfs)) > 0: + raise ValueError( + "Incompatible free symbols of the expressions with " + "the ranges.\n" + + "Free symbols in the expressions: {}\n".format(free_symbols) + + "Free symbols in the ranges: {}".format(rfs) + ) + return ranges + + +def _is_range(r): + """A range is defined as (symbol, start, end). start and end should + be numbers. + """ + # TODO: prange check goes here + return ( + isinstance(r, Tuple) + and (len(r) == 3) + and (not isinstance(r.args[1], str)) and r.args[1].is_number + and (not isinstance(r.args[2], str)) and r.args[2].is_number + ) + + +def _unpack_args(*args): + """Given a list/tuple of arguments previously processed by _plot_sympify() + and/or _check_arguments(), separates and returns its components: + expressions, ranges, label and rendering keywords. + + Examples + ======== + + >>> from sympy import cos, sin, symbols + >>> from sympy.plotting.utils import _plot_sympify, _unpack_args + >>> x, y = symbols('x, y') + >>> args = (sin(x), (x, -10, 10), "f1") + >>> args = _plot_sympify(args) + >>> _unpack_args(*args) + ([sin(x)], [(x, -10, 10)], 'f1', None) + + >>> args = (sin(x**2 + y**2), (x, -2, 2), (y, -3, 3), "f2") + >>> args = _plot_sympify(args) + >>> _unpack_args(*args) + ([sin(x**2 + y**2)], [(x, -2, 2), (y, -3, 3)], 'f2', None) + + >>> args = (sin(x + y), cos(x - y), x + y, (x, -2, 2), (y, -3, 3), "f3") + >>> args = _plot_sympify(args) + >>> _unpack_args(*args) + ([sin(x + y), cos(x - y), x + y], [(x, -2, 2), (y, -3, 3)], 'f3', None) + """ + ranges = [t for t in args if _is_range(t)] + labels = [t for t in args if isinstance(t, str)] + label = None if not labels else labels[0] + rendering_kw = [t for t in args if isinstance(t, dict)] + rendering_kw = None if not rendering_kw else rendering_kw[0] + # NOTE: why None? because args might have been preprocessed by + # _check_arguments, so None might represent the rendering_kw + results = [not (_is_range(a) or isinstance(a, (str, dict)) or (a is None)) for a in args] + exprs = [a for a, b in zip(args, results) if b] + return exprs, ranges, label, rendering_kw + + +def _check_arguments(args, nexpr, npar, **kwargs): + """Checks the arguments and converts into tuples of the + form (exprs, ranges, label, rendering_kw). + + Parameters + ========== + + args + The arguments provided to the plot functions + nexpr + The number of sub-expression forming an expression to be plotted. + For example: + nexpr=1 for plot. + nexpr=2 for plot_parametric: a curve is represented by a tuple of two + elements. + nexpr=1 for plot3d. + nexpr=3 for plot3d_parametric_line: a curve is represented by a tuple + of three elements. + npar + The number of free symbols required by the plot functions. For example, + npar=1 for plot, npar=2 for plot3d, ... + **kwargs : + keyword arguments passed to the plotting function. It will be used to + verify if ``params`` has ben provided. + + Examples + ======== + + .. plot:: + :context: reset + :format: doctest + :include-source: True + + >>> from sympy import cos, sin, symbols + >>> from sympy.plotting.plot import _check_arguments + >>> x = symbols('x') + >>> _check_arguments([cos(x), sin(x)], 2, 1) + [(cos(x), sin(x), (x, -10, 10), None, None)] + + >>> _check_arguments([cos(x), sin(x), "test"], 2, 1) + [(cos(x), sin(x), (x, -10, 10), 'test', None)] + + >>> _check_arguments([cos(x), sin(x), "test", {"a": 0, "b": 1}], 2, 1) + [(cos(x), sin(x), (x, -10, 10), 'test', {'a': 0, 'b': 1})] + + >>> _check_arguments([x, x**2], 1, 1) + [(x, (x, -10, 10), None, None), (x**2, (x, -10, 10), None, None)] + """ + if not args: + return [] + output = [] + params = kwargs.get("params", None) + + if all(isinstance(a, (Expr, Relational, BooleanFunction)) for a in args[:nexpr]): + # In this case, with a single plot command, we are plotting either: + # 1. one expression + # 2. multiple expressions over the same range + + exprs, ranges, label, rendering_kw = _unpack_args(*args) + free_symbols = set().union(*[e.free_symbols for e in exprs]) + ranges = _create_ranges(exprs, ranges, npar, label, params) + + if nexpr > 1: + # in case of plot_parametric or plot3d_parametric_line, there will + # be 2 or 3 expressions defining a curve. Group them together. + if len(exprs) == nexpr: + exprs = (tuple(exprs),) + for expr in exprs: + # need this if-else to deal with both plot/plot3d and + # plot_parametric/plot3d_parametric_line + is_expr = isinstance(expr, (Expr, Relational, BooleanFunction)) + e = (expr,) if is_expr else expr + output.append((*e, *ranges, label, rendering_kw)) + + else: + # In this case, we are plotting multiple expressions, each one with its + # range. Each "expression" to be plotted has the following form: + # (expr, range, label) where label is optional + + _, ranges, labels, rendering_kw = _unpack_args(*args) + labels = [labels] if labels else [] + + # number of expressions + n = (len(ranges) + len(labels) + + (len(rendering_kw) if rendering_kw is not None else 0)) + new_args = args[:-n] if n > 0 else args + + # at this point, new_args might just be [expr]. But I need it to be + # [[expr]] in order to be able to loop over + # [expr, range [opt], label [opt]] + if not isinstance(new_args[0], (list, tuple, Tuple)): + new_args = [new_args] + + # Each arg has the form (expr1, expr2, ..., range1 [optional], ..., + # label [optional], rendering_kw [optional]) + for arg in new_args: + # look for "local" range and label. If there is not, use "global". + l = [a for a in arg if isinstance(a, str)] + if not l: + l = labels + r = [a for a in arg if _is_range(a)] + if not r: + r = ranges.copy() + rend_kw = [a for a in arg if isinstance(a, dict)] + rend_kw = rendering_kw if len(rend_kw) == 0 else rend_kw[0] + + # NOTE: arg = arg[:nexpr] may raise an exception if lambda + # functions are used. Execute the following instead: + arg = [arg[i] for i in range(nexpr)] + free_symbols = set() + if all(not callable(a) for a in arg): + free_symbols = free_symbols.union(*[a.free_symbols for a in arg]) + if len(r) != npar: + r = _create_ranges(arg, r, npar, "", params) + + label = None if not l else l[0] + output.append((*arg, *r, label, rend_kw)) + return output diff --git a/phivenv/Lib/site-packages/sympy/polys/appellseqs.py b/phivenv/Lib/site-packages/sympy/polys/appellseqs.py new file mode 100644 index 0000000000000000000000000000000000000000..ac10fe3d1f1e60ccdf46cdae4eb5b8a969500a3e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/appellseqs.py @@ -0,0 +1,269 @@ +r""" +Efficient functions for generating Appell sequences. + +An Appell sequence is a zero-indexed sequence of polynomials `p_i(x)` +satisfying `p_{i+1}'(x)=(i+1)p_i(x)` for all `i`. This definition leads +to the following iterative algorithm: + +.. math :: p_0(x) = c_0,\ p_i(x) = i \int_0^x p_{i-1}(t)\,dt + c_i + +The constant coefficients `c_i` are usually determined from the +just-evaluated integral and `i`. + +Appell sequences satisfy the following identity from umbral calculus: + +.. math :: p_n(x+y) = \sum_{k=0}^n \binom{n}{k} p_k(x) y^{n-k} + +References +========== + +.. [1] https://en.wikipedia.org/wiki/Appell_sequence +.. [2] Peter Luschny, "An introduction to the Bernoulli function", + https://arxiv.org/abs/2009.06743 +""" +from sympy.polys.densearith import dup_mul_ground, dup_sub_ground, dup_quo_ground +from sympy.polys.densetools import dup_eval, dup_integrate +from sympy.polys.domains import ZZ, QQ +from sympy.polys.polytools import named_poly +from sympy.utilities import public + +def dup_bernoulli(n, K): + """Low-level implementation of Bernoulli polynomials.""" + if n < 1: + return [K.one] + p = [K.one, K(-1,2)] + for i in range(2, n+1): + p = dup_integrate(dup_mul_ground(p, K(i), K), 1, K) + if i % 2 == 0: + p = dup_sub_ground(p, dup_eval(p, K(1,2), K) * K(1<<(i-1), (1<>> from sympy import summation + >>> from sympy.abc import x + >>> from sympy.polys import bernoulli_poly + >>> bernoulli_poly(5, x) + x**5 - 5*x**4/2 + 5*x**3/3 - x/6 + + >>> def psum(p, a, b): + ... return (bernoulli_poly(p+1,b+1) - bernoulli_poly(p+1,a)) / (p+1) + >>> psum(4, -6, 27) + 3144337 + >>> summation(x**4, (x, -6, 27)) + 3144337 + + >>> psum(1, 1, x).factor() + x*(x + 1)/2 + >>> psum(2, 1, x).factor() + x*(x + 1)*(2*x + 1)/6 + >>> psum(3, 1, x).factor() + x**2*(x + 1)**2/4 + + Parameters + ========== + + n : int + Degree of the polynomial. + x : optional + polys : bool, optional + If True, return a Poly, otherwise (default) return an expression. + + See Also + ======== + + sympy.functions.combinatorial.numbers.bernoulli + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bernoulli_polynomials + """ + return named_poly(n, dup_bernoulli, QQ, "Bernoulli polynomial", (x,), polys) + +def dup_bernoulli_c(n, K): + """Low-level implementation of central Bernoulli polynomials.""" + p = [K.one] + for i in range(1, n+1): + p = dup_integrate(dup_mul_ground(p, K(i), K), 1, K) + if i % 2 == 0: + p = dup_sub_ground(p, dup_eval(p, K.one, K) * K((1<<(i-1))-1, (1<>> from sympy import bernoulli, euler, genocchi + >>> from sympy.abc import x + >>> from sympy.polys import andre_poly + >>> andre_poly(9, x) + x**9 - 36*x**7 + 630*x**5 - 5124*x**3 + 12465*x + + >>> [andre_poly(n, 0) for n in range(11)] + [1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521] + >>> [euler(n) for n in range(11)] + [1, 0, -1, 0, 5, 0, -61, 0, 1385, 0, -50521] + >>> [andre_poly(n-1, 1) * n / (4**n - 2**n) for n in range(1, 11)] + [1/2, 1/6, 0, -1/30, 0, 1/42, 0, -1/30, 0, 5/66] + >>> [bernoulli(n) for n in range(1, 11)] + [1/2, 1/6, 0, -1/30, 0, 1/42, 0, -1/30, 0, 5/66] + >>> [-andre_poly(n-1, -1) * n / (-2)**(n-1) for n in range(1, 11)] + [-1, -1, 0, 1, 0, -3, 0, 17, 0, -155] + >>> [genocchi(n) for n in range(1, 11)] + [-1, -1, 0, 1, 0, -3, 0, 17, 0, -155] + + >>> [abs(andre_poly(n, n%2)) for n in range(11)] + [1, 1, 1, 2, 5, 16, 61, 272, 1385, 7936, 50521] + + Parameters + ========== + + n : int + Degree of the polynomial. + x : optional + polys : bool, optional + If True, return a Poly, otherwise (default) return an expression. + + See Also + ======== + + sympy.functions.combinatorial.numbers.andre + + References + ========== + + .. [1] Peter Luschny, "An introduction to the Bernoulli function", + https://arxiv.org/abs/2009.06743 + """ + return named_poly(n, dup_andre, ZZ, "Andre polynomial", (x,), polys) diff --git a/phivenv/Lib/site-packages/sympy/polys/compatibility.py b/phivenv/Lib/site-packages/sympy/polys/compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..eb239d282a738d1e5611a2249d313ff1d3b7671c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/compatibility.py @@ -0,0 +1,1152 @@ +"""Compatibility interface between dense and sparse polys. """ + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from sympy.core.expr import Expr + from sympy.polys.domains.domain import Domain + from sympy.polys.orderings import MonomialOrder + from sympy.polys.rings import PolyElement + +from sympy.polys.densearith import dup_add_term +from sympy.polys.densearith import dmp_add_term +from sympy.polys.densearith import dup_sub_term +from sympy.polys.densearith import dmp_sub_term +from sympy.polys.densearith import dup_mul_term +from sympy.polys.densearith import dmp_mul_term +from sympy.polys.densearith import dup_add_ground +from sympy.polys.densearith import dmp_add_ground +from sympy.polys.densearith import dup_sub_ground +from sympy.polys.densearith import dmp_sub_ground +from sympy.polys.densearith import dup_mul_ground +from sympy.polys.densearith import dmp_mul_ground +from sympy.polys.densearith import dup_quo_ground +from sympy.polys.densearith import dmp_quo_ground +from sympy.polys.densearith import dup_exquo_ground +from sympy.polys.densearith import dmp_exquo_ground +from sympy.polys.densearith import dup_lshift +from sympy.polys.densearith import dup_rshift +from sympy.polys.densearith import dup_abs +from sympy.polys.densearith import dmp_abs +from sympy.polys.densearith import dup_neg +from sympy.polys.densearith import dmp_neg +from sympy.polys.densearith import dup_add +from sympy.polys.densearith import dmp_add +from sympy.polys.densearith import dup_sub +from sympy.polys.densearith import dmp_sub +from sympy.polys.densearith import dup_add_mul +from sympy.polys.densearith import dmp_add_mul +from sympy.polys.densearith import dup_sub_mul +from sympy.polys.densearith import dmp_sub_mul +from sympy.polys.densearith import dup_mul +from sympy.polys.densearith import dmp_mul +from sympy.polys.densearith import dup_sqr +from sympy.polys.densearith import dmp_sqr +from sympy.polys.densearith import dup_pow +from sympy.polys.densearith import dmp_pow +from sympy.polys.densearith import dup_pdiv +from sympy.polys.densearith import dup_prem +from sympy.polys.densearith import dup_pquo +from sympy.polys.densearith import dup_pexquo +from sympy.polys.densearith import dmp_pdiv +from sympy.polys.densearith import dmp_prem +from sympy.polys.densearith import dmp_pquo +from sympy.polys.densearith import dmp_pexquo +from sympy.polys.densearith import dup_rr_div +from sympy.polys.densearith import dmp_rr_div +from sympy.polys.densearith import dup_ff_div +from sympy.polys.densearith import dmp_ff_div +from sympy.polys.densearith import dup_div +from sympy.polys.densearith import dup_rem +from sympy.polys.densearith import dup_quo +from sympy.polys.densearith import dup_exquo +from sympy.polys.densearith import dmp_div +from sympy.polys.densearith import dmp_rem +from sympy.polys.densearith import dmp_quo +from sympy.polys.densearith import dmp_exquo +from sympy.polys.densearith import dup_max_norm +from sympy.polys.densearith import dmp_max_norm +from sympy.polys.densearith import dup_l1_norm +from sympy.polys.densearith import dmp_l1_norm +from sympy.polys.densearith import dup_l2_norm_squared +from sympy.polys.densearith import dmp_l2_norm_squared +from sympy.polys.densearith import dup_expand +from sympy.polys.densearith import dmp_expand +from sympy.polys.densebasic import dup_LC +from sympy.polys.densebasic import dmp_LC +from sympy.polys.densebasic import dup_TC +from sympy.polys.densebasic import dmp_TC +from sympy.polys.densebasic import dmp_ground_LC +from sympy.polys.densebasic import dmp_ground_TC +from sympy.polys.densebasic import dup_degree +from sympy.polys.densebasic import dmp_degree +from sympy.polys.densebasic import dmp_degree_in +from sympy.polys.densebasic import dmp_to_dict +from sympy.polys.densetools import dup_integrate +from sympy.polys.densetools import dmp_integrate +from sympy.polys.densetools import dmp_integrate_in +from sympy.polys.densetools import dup_diff +from sympy.polys.densetools import dmp_diff +from sympy.polys.densetools import dmp_diff_in +from sympy.polys.densetools import dup_eval +from sympy.polys.densetools import dmp_eval +from sympy.polys.densetools import dmp_eval_in +from sympy.polys.densetools import dmp_eval_tail +from sympy.polys.densetools import dmp_diff_eval_in +from sympy.polys.densetools import dup_trunc +from sympy.polys.densetools import dmp_trunc +from sympy.polys.densetools import dmp_ground_trunc +from sympy.polys.densetools import dup_monic +from sympy.polys.densetools import dmp_ground_monic +from sympy.polys.densetools import dup_content +from sympy.polys.densetools import dmp_ground_content +from sympy.polys.densetools import dup_primitive +from sympy.polys.densetools import dmp_ground_primitive +from sympy.polys.densetools import dup_extract +from sympy.polys.densetools import dmp_ground_extract +from sympy.polys.densetools import dup_real_imag +from sympy.polys.densetools import dup_mirror +from sympy.polys.densetools import dup_scale +from sympy.polys.densetools import dup_shift +from sympy.polys.densetools import dmp_shift +from sympy.polys.densetools import dup_transform +from sympy.polys.densetools import dup_compose +from sympy.polys.densetools import dmp_compose +from sympy.polys.densetools import dup_decompose +from sympy.polys.densetools import dmp_lift +from sympy.polys.densetools import dup_sign_variations +from sympy.polys.densetools import dup_clear_denoms +from sympy.polys.densetools import dmp_clear_denoms +from sympy.polys.densetools import dup_revert +from sympy.polys.euclidtools import dup_half_gcdex +from sympy.polys.euclidtools import dmp_half_gcdex +from sympy.polys.euclidtools import dup_gcdex +from sympy.polys.euclidtools import dmp_gcdex +from sympy.polys.euclidtools import dup_invert +from sympy.polys.euclidtools import dmp_invert +from sympy.polys.euclidtools import dup_euclidean_prs +from sympy.polys.euclidtools import dmp_euclidean_prs +from sympy.polys.euclidtools import dup_primitive_prs +from sympy.polys.euclidtools import dmp_primitive_prs +from sympy.polys.euclidtools import dup_inner_subresultants +from sympy.polys.euclidtools import dup_subresultants +from sympy.polys.euclidtools import dup_prs_resultant +from sympy.polys.euclidtools import dup_resultant +from sympy.polys.euclidtools import dmp_inner_subresultants +from sympy.polys.euclidtools import dmp_subresultants +from sympy.polys.euclidtools import dmp_prs_resultant +from sympy.polys.euclidtools import dmp_zz_modular_resultant +from sympy.polys.euclidtools import dmp_zz_collins_resultant +from sympy.polys.euclidtools import dmp_qq_collins_resultant +from sympy.polys.euclidtools import dmp_resultant +from sympy.polys.euclidtools import dup_discriminant +from sympy.polys.euclidtools import dmp_discriminant +from sympy.polys.euclidtools import dup_rr_prs_gcd +from sympy.polys.euclidtools import dup_ff_prs_gcd +from sympy.polys.euclidtools import dmp_rr_prs_gcd +from sympy.polys.euclidtools import dmp_ff_prs_gcd +from sympy.polys.euclidtools import dup_zz_heu_gcd +from sympy.polys.euclidtools import dmp_zz_heu_gcd +from sympy.polys.euclidtools import dup_qq_heu_gcd +from sympy.polys.euclidtools import dmp_qq_heu_gcd +from sympy.polys.euclidtools import dup_inner_gcd +from sympy.polys.euclidtools import dmp_inner_gcd +from sympy.polys.euclidtools import dup_gcd +from sympy.polys.euclidtools import dmp_gcd +from sympy.polys.euclidtools import dup_rr_lcm +from sympy.polys.euclidtools import dup_ff_lcm +from sympy.polys.euclidtools import dup_lcm +from sympy.polys.euclidtools import dmp_rr_lcm +from sympy.polys.euclidtools import dmp_ff_lcm +from sympy.polys.euclidtools import dmp_lcm +from sympy.polys.euclidtools import dmp_content +from sympy.polys.euclidtools import dmp_primitive +from sympy.polys.euclidtools import dup_cancel +from sympy.polys.euclidtools import dmp_cancel +from sympy.polys.factortools import dup_trial_division +from sympy.polys.factortools import dmp_trial_division +from sympy.polys.factortools import dup_zz_mignotte_bound +from sympy.polys.factortools import dmp_zz_mignotte_bound +from sympy.polys.factortools import dup_zz_hensel_step +from sympy.polys.factortools import dup_zz_hensel_lift +from sympy.polys.factortools import dup_zz_zassenhaus +from sympy.polys.factortools import dup_zz_irreducible_p +from sympy.polys.factortools import dup_cyclotomic_p +from sympy.polys.factortools import dup_zz_cyclotomic_poly +from sympy.polys.factortools import dup_zz_cyclotomic_factor +from sympy.polys.factortools import dup_zz_factor_sqf +from sympy.polys.factortools import dup_zz_factor +from sympy.polys.factortools import dmp_zz_wang_non_divisors +from sympy.polys.factortools import dmp_zz_wang_lead_coeffs +from sympy.polys.factortools import dup_zz_diophantine +from sympy.polys.factortools import dmp_zz_diophantine +from sympy.polys.factortools import dmp_zz_wang_hensel_lifting +from sympy.polys.factortools import dmp_zz_wang +from sympy.polys.factortools import dmp_zz_factor +from sympy.polys.factortools import dup_qq_i_factor +from sympy.polys.factortools import dup_zz_i_factor +from sympy.polys.factortools import dmp_qq_i_factor +from sympy.polys.factortools import dmp_zz_i_factor +from sympy.polys.factortools import dup_ext_factor +from sympy.polys.factortools import dmp_ext_factor +from sympy.polys.factortools import dup_gf_factor +from sympy.polys.factortools import dmp_gf_factor +from sympy.polys.factortools import dup_factor_list +from sympy.polys.factortools import dup_factor_list_include +from sympy.polys.factortools import dmp_factor_list +from sympy.polys.factortools import dmp_factor_list_include +from sympy.polys.factortools import dup_irreducible_p +from sympy.polys.factortools import dmp_irreducible_p +from sympy.polys.rootisolation import dup_sturm +from sympy.polys.rootisolation import dup_root_upper_bound +from sympy.polys.rootisolation import dup_root_lower_bound +from sympy.polys.rootisolation import dup_step_refine_real_root +from sympy.polys.rootisolation import dup_inner_refine_real_root +from sympy.polys.rootisolation import dup_outer_refine_real_root +from sympy.polys.rootisolation import dup_refine_real_root +from sympy.polys.rootisolation import dup_inner_isolate_real_roots +from sympy.polys.rootisolation import dup_inner_isolate_positive_roots +from sympy.polys.rootisolation import dup_inner_isolate_negative_roots +from sympy.polys.rootisolation import dup_isolate_real_roots_sqf +from sympy.polys.rootisolation import dup_isolate_real_roots +from sympy.polys.rootisolation import dup_isolate_real_roots_list +from sympy.polys.rootisolation import dup_count_real_roots +from sympy.polys.rootisolation import dup_count_complex_roots +from sympy.polys.rootisolation import dup_isolate_complex_roots_sqf +from sympy.polys.rootisolation import dup_isolate_all_roots_sqf +from sympy.polys.rootisolation import dup_isolate_all_roots + +from sympy.polys.sqfreetools import ( + dup_sqf_p, dmp_sqf_p, dmp_norm, dup_sqf_norm, dmp_sqf_norm, + dup_gf_sqf_part, dmp_gf_sqf_part, dup_sqf_part, dmp_sqf_part, + dup_gf_sqf_list, dmp_gf_sqf_list, dup_sqf_list, dup_sqf_list_include, + dmp_sqf_list, dmp_sqf_list_include, dup_gff_list, dmp_gff_list) + +from sympy.polys.galoistools import ( + gf_degree, gf_LC, gf_TC, gf_strip, gf_from_dict, + gf_to_dict, gf_from_int_poly, gf_to_int_poly, gf_neg, gf_add_ground, gf_sub_ground, + gf_mul_ground, gf_quo_ground, gf_add, gf_sub, gf_mul, gf_sqr, gf_add_mul, gf_sub_mul, + gf_expand, gf_div, gf_rem, gf_quo, gf_exquo, gf_lshift, gf_rshift, gf_pow, gf_pow_mod, + gf_gcd, gf_lcm, gf_cofactors, gf_gcdex, gf_monic, gf_diff, gf_eval, gf_multi_eval, + gf_compose, gf_compose_mod, gf_trace_map, gf_random, gf_irreducible, gf_irred_p_ben_or, + gf_irred_p_rabin, gf_irreducible_p, gf_sqf_p, gf_sqf_part, gf_Qmatrix, + gf_berlekamp, gf_ddf_zassenhaus, gf_edf_zassenhaus, gf_ddf_shoup, gf_edf_shoup, + gf_zassenhaus, gf_shoup, gf_factor_sqf, gf_factor) + +from sympy.utilities import public + +@public +class IPolys: + + gens: tuple[PolyElement, ...] + symbols: tuple[Expr, ...] + ngens: int + domain: Domain + order: MonomialOrder + + def drop(self, gen): + pass + + def clone(self, symbols=None, domain=None, order=None): + pass + + def to_ground(self): + pass + + def ground_new(self, element): + pass + + def domain_new(self, element): + pass + + def from_dict(self, d): + pass + + def wrap(self, element): + from sympy.polys.rings import PolyElement + if isinstance(element, PolyElement): + if element.ring == self: + return element + else: + raise NotImplementedError("domain conversions") + else: + return self.ground_new(element) + + def to_dense(self, element): + return self.wrap(element).to_dense() + + def from_dense(self, element): + return self.from_dict(dmp_to_dict(element, self.ngens-1, self.domain)) + + def dup_add_term(self, f, c, i): + return self.from_dense(dup_add_term(self.to_dense(f), c, i, self.domain)) + def dmp_add_term(self, f, c, i): + return self.from_dense(dmp_add_term(self.to_dense(f), self.wrap(c).drop(0).to_dense(), i, self.ngens-1, self.domain)) + def dup_sub_term(self, f, c, i): + return self.from_dense(dup_sub_term(self.to_dense(f), c, i, self.domain)) + def dmp_sub_term(self, f, c, i): + return self.from_dense(dmp_sub_term(self.to_dense(f), self.wrap(c).drop(0).to_dense(), i, self.ngens-1, self.domain)) + def dup_mul_term(self, f, c, i): + return self.from_dense(dup_mul_term(self.to_dense(f), c, i, self.domain)) + def dmp_mul_term(self, f, c, i): + return self.from_dense(dmp_mul_term(self.to_dense(f), self.wrap(c).drop(0).to_dense(), i, self.ngens-1, self.domain)) + + def dup_add_ground(self, f, c): + return self.from_dense(dup_add_ground(self.to_dense(f), c, self.domain)) + def dmp_add_ground(self, f, c): + return self.from_dense(dmp_add_ground(self.to_dense(f), c, self.ngens-1, self.domain)) + def dup_sub_ground(self, f, c): + return self.from_dense(dup_sub_ground(self.to_dense(f), c, self.domain)) + def dmp_sub_ground(self, f, c): + return self.from_dense(dmp_sub_ground(self.to_dense(f), c, self.ngens-1, self.domain)) + def dup_mul_ground(self, f, c): + return self.from_dense(dup_mul_ground(self.to_dense(f), c, self.domain)) + def dmp_mul_ground(self, f, c): + return self.from_dense(dmp_mul_ground(self.to_dense(f), c, self.ngens-1, self.domain)) + def dup_quo_ground(self, f, c): + return self.from_dense(dup_quo_ground(self.to_dense(f), c, self.domain)) + def dmp_quo_ground(self, f, c): + return self.from_dense(dmp_quo_ground(self.to_dense(f), c, self.ngens-1, self.domain)) + def dup_exquo_ground(self, f, c): + return self.from_dense(dup_exquo_ground(self.to_dense(f), c, self.domain)) + def dmp_exquo_ground(self, f, c): + return self.from_dense(dmp_exquo_ground(self.to_dense(f), c, self.ngens-1, self.domain)) + + def dup_lshift(self, f, n): + return self.from_dense(dup_lshift(self.to_dense(f), n, self.domain)) + def dup_rshift(self, f, n): + return self.from_dense(dup_rshift(self.to_dense(f), n, self.domain)) + + def dup_abs(self, f): + return self.from_dense(dup_abs(self.to_dense(f), self.domain)) + def dmp_abs(self, f): + return self.from_dense(dmp_abs(self.to_dense(f), self.ngens-1, self.domain)) + + def dup_neg(self, f): + return self.from_dense(dup_neg(self.to_dense(f), self.domain)) + def dmp_neg(self, f): + return self.from_dense(dmp_neg(self.to_dense(f), self.ngens-1, self.domain)) + + def dup_add(self, f, g): + return self.from_dense(dup_add(self.to_dense(f), self.to_dense(g), self.domain)) + def dmp_add(self, f, g): + return self.from_dense(dmp_add(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain)) + + def dup_sub(self, f, g): + return self.from_dense(dup_sub(self.to_dense(f), self.to_dense(g), self.domain)) + def dmp_sub(self, f, g): + return self.from_dense(dmp_sub(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain)) + + def dup_add_mul(self, f, g, h): + return self.from_dense(dup_add_mul(self.to_dense(f), self.to_dense(g), self.to_dense(h), self.domain)) + def dmp_add_mul(self, f, g, h): + return self.from_dense(dmp_add_mul(self.to_dense(f), self.to_dense(g), self.to_dense(h), self.ngens-1, self.domain)) + def dup_sub_mul(self, f, g, h): + return self.from_dense(dup_sub_mul(self.to_dense(f), self.to_dense(g), self.to_dense(h), self.domain)) + def dmp_sub_mul(self, f, g, h): + return self.from_dense(dmp_sub_mul(self.to_dense(f), self.to_dense(g), self.to_dense(h), self.ngens-1, self.domain)) + + def dup_mul(self, f, g): + return self.from_dense(dup_mul(self.to_dense(f), self.to_dense(g), self.domain)) + def dmp_mul(self, f, g): + return self.from_dense(dmp_mul(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain)) + + def dup_sqr(self, f): + return self.from_dense(dup_sqr(self.to_dense(f), self.domain)) + def dmp_sqr(self, f): + return self.from_dense(dmp_sqr(self.to_dense(f), self.ngens-1, self.domain)) + def dup_pow(self, f, n): + return self.from_dense(dup_pow(self.to_dense(f), n, self.domain)) + def dmp_pow(self, f, n): + return self.from_dense(dmp_pow(self.to_dense(f), n, self.ngens-1, self.domain)) + + def dup_pdiv(self, f, g): + q, r = dup_pdiv(self.to_dense(f), self.to_dense(g), self.domain) + return (self.from_dense(q), self.from_dense(r)) + def dup_prem(self, f, g): + return self.from_dense(dup_prem(self.to_dense(f), self.to_dense(g), self.domain)) + def dup_pquo(self, f, g): + return self.from_dense(dup_pquo(self.to_dense(f), self.to_dense(g), self.domain)) + def dup_pexquo(self, f, g): + return self.from_dense(dup_pexquo(self.to_dense(f), self.to_dense(g), self.domain)) + + def dmp_pdiv(self, f, g): + q, r = dmp_pdiv(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (self.from_dense(q), self.from_dense(r)) + def dmp_prem(self, f, g): + return self.from_dense(dmp_prem(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain)) + def dmp_pquo(self, f, g): + return self.from_dense(dmp_pquo(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain)) + def dmp_pexquo(self, f, g): + return self.from_dense(dmp_pexquo(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain)) + + def dup_rr_div(self, f, g): + q, r = dup_rr_div(self.to_dense(f), self.to_dense(g), self.domain) + return (self.from_dense(q), self.from_dense(r)) + def dmp_rr_div(self, f, g): + q, r = dmp_rr_div(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (self.from_dense(q), self.from_dense(r)) + def dup_ff_div(self, f, g): + q, r = dup_ff_div(self.to_dense(f), self.to_dense(g), self.domain) + return (self.from_dense(q), self.from_dense(r)) + def dmp_ff_div(self, f, g): + q, r = dmp_ff_div(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (self.from_dense(q), self.from_dense(r)) + + def dup_div(self, f, g): + q, r = dup_div(self.to_dense(f), self.to_dense(g), self.domain) + return (self.from_dense(q), self.from_dense(r)) + def dup_rem(self, f, g): + return self.from_dense(dup_rem(self.to_dense(f), self.to_dense(g), self.domain)) + def dup_quo(self, f, g): + return self.from_dense(dup_quo(self.to_dense(f), self.to_dense(g), self.domain)) + def dup_exquo(self, f, g): + return self.from_dense(dup_exquo(self.to_dense(f), self.to_dense(g), self.domain)) + + def dmp_div(self, f, g): + q, r = dmp_div(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (self.from_dense(q), self.from_dense(r)) + def dmp_rem(self, f, g): + return self.from_dense(dmp_rem(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain)) + def dmp_quo(self, f, g): + return self.from_dense(dmp_quo(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain)) + def dmp_exquo(self, f, g): + return self.from_dense(dmp_exquo(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain)) + + def dup_max_norm(self, f): + return dup_max_norm(self.to_dense(f), self.domain) + def dmp_max_norm(self, f): + return dmp_max_norm(self.to_dense(f), self.ngens-1, self.domain) + + def dup_l1_norm(self, f): + return dup_l1_norm(self.to_dense(f), self.domain) + def dmp_l1_norm(self, f): + return dmp_l1_norm(self.to_dense(f), self.ngens-1, self.domain) + + def dup_l2_norm_squared(self, f): + return dup_l2_norm_squared(self.to_dense(f), self.domain) + def dmp_l2_norm_squared(self, f): + return dmp_l2_norm_squared(self.to_dense(f), self.ngens-1, self.domain) + + def dup_expand(self, polys): + return self.from_dense(dup_expand(list(map(self.to_dense, polys)), self.domain)) + def dmp_expand(self, polys): + return self.from_dense(dmp_expand(list(map(self.to_dense, polys)), self.ngens-1, self.domain)) + + def dup_LC(self, f): + return dup_LC(self.to_dense(f), self.domain) + def dmp_LC(self, f): + LC = dmp_LC(self.to_dense(f), self.domain) + if isinstance(LC, list): + return self[1:].from_dense(LC) + else: + return LC + def dup_TC(self, f): + return dup_TC(self.to_dense(f), self.domain) + def dmp_TC(self, f): + TC = dmp_TC(self.to_dense(f), self.domain) + if isinstance(TC, list): + return self[1:].from_dense(TC) + else: + return TC + + def dmp_ground_LC(self, f): + return dmp_ground_LC(self.to_dense(f), self.ngens-1, self.domain) + def dmp_ground_TC(self, f): + return dmp_ground_TC(self.to_dense(f), self.ngens-1, self.domain) + + def dup_degree(self, f): + return dup_degree(self.to_dense(f)) + def dmp_degree(self, f): + return dmp_degree(self.to_dense(f), self.ngens-1) + def dmp_degree_in(self, f, j): + return dmp_degree_in(self.to_dense(f), j, self.ngens-1) + def dup_integrate(self, f, m): + return self.from_dense(dup_integrate(self.to_dense(f), m, self.domain)) + def dmp_integrate(self, f, m): + return self.from_dense(dmp_integrate(self.to_dense(f), m, self.ngens-1, self.domain)) + + def dup_diff(self, f, m): + return self.from_dense(dup_diff(self.to_dense(f), m, self.domain)) + def dmp_diff(self, f, m): + return self.from_dense(dmp_diff(self.to_dense(f), m, self.ngens-1, self.domain)) + + def dmp_diff_in(self, f, m, j): + return self.from_dense(dmp_diff_in(self.to_dense(f), m, j, self.ngens-1, self.domain)) + def dmp_integrate_in(self, f, m, j): + return self.from_dense(dmp_integrate_in(self.to_dense(f), m, j, self.ngens-1, self.domain)) + + def dup_eval(self, f, a): + return dup_eval(self.to_dense(f), a, self.domain) + def dmp_eval(self, f, a): + result = dmp_eval(self.to_dense(f), a, self.ngens-1, self.domain) + return self[1:].from_dense(result) + + def dmp_eval_in(self, f, a, j): + result = dmp_eval_in(self.to_dense(f), a, j, self.ngens-1, self.domain) + return self.drop(j).from_dense(result) + def dmp_diff_eval_in(self, f, m, a, j): + result = dmp_diff_eval_in(self.to_dense(f), m, a, j, self.ngens-1, self.domain) + return self.drop(j).from_dense(result) + + def dmp_eval_tail(self, f, A): + result = dmp_eval_tail(self.to_dense(f), A, self.ngens-1, self.domain) + if isinstance(result, list): + return self[:-len(A)].from_dense(result) + else: + return result + + def dup_trunc(self, f, p): + return self.from_dense(dup_trunc(self.to_dense(f), p, self.domain)) + def dmp_trunc(self, f, g): + return self.from_dense(dmp_trunc(self.to_dense(f), self[1:].to_dense(g), self.ngens-1, self.domain)) + def dmp_ground_trunc(self, f, p): + return self.from_dense(dmp_ground_trunc(self.to_dense(f), p, self.ngens-1, self.domain)) + + def dup_monic(self, f): + return self.from_dense(dup_monic(self.to_dense(f), self.domain)) + def dmp_ground_monic(self, f): + return self.from_dense(dmp_ground_monic(self.to_dense(f), self.ngens-1, self.domain)) + + def dup_extract(self, f, g): + c, F, G = dup_extract(self.to_dense(f), self.to_dense(g), self.domain) + return (c, self.from_dense(F), self.from_dense(G)) + def dmp_ground_extract(self, f, g): + c, F, G = dmp_ground_extract(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (c, self.from_dense(F), self.from_dense(G)) + + def dup_real_imag(self, f): + p, q = dup_real_imag(self.wrap(f).drop(1).to_dense(), self.domain) + return (self.from_dense(p), self.from_dense(q)) + + def dup_mirror(self, f): + return self.from_dense(dup_mirror(self.to_dense(f), self.domain)) + def dup_scale(self, f, a): + return self.from_dense(dup_scale(self.to_dense(f), a, self.domain)) + def dup_shift(self, f, a): + return self.from_dense(dup_shift(self.to_dense(f), a, self.domain)) + def dmp_shift(self, f, a): + return self.from_dense(dmp_shift(self.to_dense(f), a, self.ngens-1, self.domain)) + def dup_transform(self, f, p, q): + return self.from_dense(dup_transform(self.to_dense(f), self.to_dense(p), self.to_dense(q), self.domain)) + + def dup_compose(self, f, g): + return self.from_dense(dup_compose(self.to_dense(f), self.to_dense(g), self.domain)) + def dmp_compose(self, f, g): + return self.from_dense(dmp_compose(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain)) + + def dup_decompose(self, f): + components = dup_decompose(self.to_dense(f), self.domain) + return list(map(self.from_dense, components)) + + def dmp_lift(self, f): + result = dmp_lift(self.to_dense(f), self.ngens-1, self.domain) + return self.to_ground().from_dense(result) + + def dup_sign_variations(self, f): + return dup_sign_variations(self.to_dense(f), self.domain) + + def dup_clear_denoms(self, f, convert=False): + c, F = dup_clear_denoms(self.to_dense(f), self.domain, convert=convert) + if convert: + ring = self.clone(domain=self.domain.get_ring()) + else: + ring = self + return (c, ring.from_dense(F)) + def dmp_clear_denoms(self, f, convert=False): + c, F = dmp_clear_denoms(self.to_dense(f), self.ngens-1, self.domain, convert=convert) + if convert: + ring = self.clone(domain=self.domain.get_ring()) + else: + ring = self + return (c, ring.from_dense(F)) + + def dup_revert(self, f, n): + return self.from_dense(dup_revert(self.to_dense(f), n, self.domain)) + + def dup_half_gcdex(self, f, g): + s, h = dup_half_gcdex(self.to_dense(f), self.to_dense(g), self.domain) + return (self.from_dense(s), self.from_dense(h)) + def dmp_half_gcdex(self, f, g): + s, h = dmp_half_gcdex(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (self.from_dense(s), self.from_dense(h)) + def dup_gcdex(self, f, g): + s, t, h = dup_gcdex(self.to_dense(f), self.to_dense(g), self.domain) + return (self.from_dense(s), self.from_dense(t), self.from_dense(h)) + def dmp_gcdex(self, f, g): + s, t, h = dmp_gcdex(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (self.from_dense(s), self.from_dense(t), self.from_dense(h)) + + def dup_invert(self, f, g): + return self.from_dense(dup_invert(self.to_dense(f), self.to_dense(g), self.domain)) + def dmp_invert(self, f, g): + return self.from_dense(dmp_invert(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain)) + + def dup_euclidean_prs(self, f, g): + prs = dup_euclidean_prs(self.to_dense(f), self.to_dense(g), self.domain) + return list(map(self.from_dense, prs)) + def dmp_euclidean_prs(self, f, g): + prs = dmp_euclidean_prs(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return list(map(self.from_dense, prs)) + def dup_primitive_prs(self, f, g): + prs = dup_primitive_prs(self.to_dense(f), self.to_dense(g), self.domain) + return list(map(self.from_dense, prs)) + def dmp_primitive_prs(self, f, g): + prs = dmp_primitive_prs(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return list(map(self.from_dense, prs)) + + def dup_inner_subresultants(self, f, g): + prs, sres = dup_inner_subresultants(self.to_dense(f), self.to_dense(g), self.domain) + return (list(map(self.from_dense, prs)), sres) + def dmp_inner_subresultants(self, f, g): + prs, sres = dmp_inner_subresultants(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (list(map(self.from_dense, prs)), sres) + + def dup_subresultants(self, f, g): + prs = dup_subresultants(self.to_dense(f), self.to_dense(g), self.domain) + return list(map(self.from_dense, prs)) + def dmp_subresultants(self, f, g): + prs = dmp_subresultants(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return list(map(self.from_dense, prs)) + + def dup_prs_resultant(self, f, g): + res, prs = dup_prs_resultant(self.to_dense(f), self.to_dense(g), self.domain) + return (res, list(map(self.from_dense, prs))) + def dmp_prs_resultant(self, f, g): + res, prs = dmp_prs_resultant(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (self[1:].from_dense(res), list(map(self.from_dense, prs))) + + def dmp_zz_modular_resultant(self, f, g, p): + res = dmp_zz_modular_resultant(self.to_dense(f), self.to_dense(g), self.domain_new(p), self.ngens-1, self.domain) + return self[1:].from_dense(res) + def dmp_zz_collins_resultant(self, f, g): + res = dmp_zz_collins_resultant(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return self[1:].from_dense(res) + def dmp_qq_collins_resultant(self, f, g): + res = dmp_qq_collins_resultant(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return self[1:].from_dense(res) + + def dup_resultant(self, f, g): #, includePRS=False): + return dup_resultant(self.to_dense(f), self.to_dense(g), self.domain) #, includePRS=includePRS) + def dmp_resultant(self, f, g): #, includePRS=False): + res = dmp_resultant(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) #, includePRS=includePRS) + if isinstance(res, list): + return self[1:].from_dense(res) + else: + return res + + def dup_discriminant(self, f): + return dup_discriminant(self.to_dense(f), self.domain) + def dmp_discriminant(self, f): + disc = dmp_discriminant(self.to_dense(f), self.ngens-1, self.domain) + if isinstance(disc, list): + return self[1:].from_dense(disc) + else: + return disc + + def dup_rr_prs_gcd(self, f, g): + H, F, G = dup_rr_prs_gcd(self.to_dense(f), self.to_dense(g), self.domain) + return (self.from_dense(H), self.from_dense(F), self.from_dense(G)) + def dup_ff_prs_gcd(self, f, g): + H, F, G = dup_ff_prs_gcd(self.to_dense(f), self.to_dense(g), self.domain) + return (self.from_dense(H), self.from_dense(F), self.from_dense(G)) + def dmp_rr_prs_gcd(self, f, g): + H, F, G = dmp_rr_prs_gcd(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (self.from_dense(H), self.from_dense(F), self.from_dense(G)) + def dmp_ff_prs_gcd(self, f, g): + H, F, G = dmp_ff_prs_gcd(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (self.from_dense(H), self.from_dense(F), self.from_dense(G)) + def dup_zz_heu_gcd(self, f, g): + H, F, G = dup_zz_heu_gcd(self.to_dense(f), self.to_dense(g), self.domain) + return (self.from_dense(H), self.from_dense(F), self.from_dense(G)) + def dmp_zz_heu_gcd(self, f, g): + H, F, G = dmp_zz_heu_gcd(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (self.from_dense(H), self.from_dense(F), self.from_dense(G)) + def dup_qq_heu_gcd(self, f, g): + H, F, G = dup_qq_heu_gcd(self.to_dense(f), self.to_dense(g), self.domain) + return (self.from_dense(H), self.from_dense(F), self.from_dense(G)) + def dmp_qq_heu_gcd(self, f, g): + H, F, G = dmp_qq_heu_gcd(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (self.from_dense(H), self.from_dense(F), self.from_dense(G)) + def dup_inner_gcd(self, f, g): + H, F, G = dup_inner_gcd(self.to_dense(f), self.to_dense(g), self.domain) + return (self.from_dense(H), self.from_dense(F), self.from_dense(G)) + def dmp_inner_gcd(self, f, g): + H, F, G = dmp_inner_gcd(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return (self.from_dense(H), self.from_dense(F), self.from_dense(G)) + def dup_gcd(self, f, g): + H = dup_gcd(self.to_dense(f), self.to_dense(g), self.domain) + return self.from_dense(H) + def dmp_gcd(self, f, g): + H = dmp_gcd(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return self.from_dense(H) + def dup_rr_lcm(self, f, g): + H = dup_rr_lcm(self.to_dense(f), self.to_dense(g), self.domain) + return self.from_dense(H) + def dup_ff_lcm(self, f, g): + H = dup_ff_lcm(self.to_dense(f), self.to_dense(g), self.domain) + return self.from_dense(H) + def dup_lcm(self, f, g): + H = dup_lcm(self.to_dense(f), self.to_dense(g), self.domain) + return self.from_dense(H) + def dmp_rr_lcm(self, f, g): + H = dmp_rr_lcm(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return self.from_dense(H) + def dmp_ff_lcm(self, f, g): + H = dmp_ff_lcm(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return self.from_dense(H) + def dmp_lcm(self, f, g): + H = dmp_lcm(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain) + return self.from_dense(H) + + def dup_content(self, f): + cont = dup_content(self.to_dense(f), self.domain) + return cont + def dup_primitive(self, f): + cont, prim = dup_primitive(self.to_dense(f), self.domain) + return cont, self.from_dense(prim) + + def dmp_content(self, f): + cont = dmp_content(self.to_dense(f), self.ngens-1, self.domain) + if isinstance(cont, list): + return self[1:].from_dense(cont) + else: + return cont + def dmp_primitive(self, f): + cont, prim = dmp_primitive(self.to_dense(f), self.ngens-1, self.domain) + if isinstance(cont, list): + return (self[1:].from_dense(cont), self.from_dense(prim)) + else: + return (cont, self.from_dense(prim)) + + def dmp_ground_content(self, f): + cont = dmp_ground_content(self.to_dense(f), self.ngens-1, self.domain) + return cont + def dmp_ground_primitive(self, f): + cont, prim = dmp_ground_primitive(self.to_dense(f), self.ngens-1, self.domain) + return (cont, self.from_dense(prim)) + + def dup_cancel(self, f, g, include=True): + result = dup_cancel(self.to_dense(f), self.to_dense(g), self.domain, include=include) + if not include: + cf, cg, F, G = result + return (cf, cg, self.from_dense(F), self.from_dense(G)) + else: + F, G = result + return (self.from_dense(F), self.from_dense(G)) + def dmp_cancel(self, f, g, include=True): + result = dmp_cancel(self.to_dense(f), self.to_dense(g), self.ngens-1, self.domain, include=include) + if not include: + cf, cg, F, G = result + return (cf, cg, self.from_dense(F), self.from_dense(G)) + else: + F, G = result + return (self.from_dense(F), self.from_dense(G)) + + def dup_trial_division(self, f, factors): + factors = dup_trial_division(self.to_dense(f), list(map(self.to_dense, factors)), self.domain) + return [ (self.from_dense(g), k) for g, k in factors ] + def dmp_trial_division(self, f, factors): + factors = dmp_trial_division(self.to_dense(f), list(map(self.to_dense, factors)), self.ngens-1, self.domain) + return [ (self.from_dense(g), k) for g, k in factors ] + + def dup_zz_mignotte_bound(self, f): + return dup_zz_mignotte_bound(self.to_dense(f), self.domain) + def dmp_zz_mignotte_bound(self, f): + return dmp_zz_mignotte_bound(self.to_dense(f), self.ngens-1, self.domain) + + def dup_zz_hensel_step(self, m, f, g, h, s, t): + D = self.to_dense + G, H, S, T = dup_zz_hensel_step(m, D(f), D(g), D(h), D(s), D(t), self.domain) + return (self.from_dense(G), self.from_dense(H), self.from_dense(S), self.from_dense(T)) + def dup_zz_hensel_lift(self, p, f, f_list, l): + D = self.to_dense + polys = dup_zz_hensel_lift(p, D(f), list(map(D, f_list)), l, self.domain) + return list(map(self.from_dense, polys)) + + def dup_zz_zassenhaus(self, f): + factors = dup_zz_zassenhaus(self.to_dense(f), self.domain) + return [ (self.from_dense(g), k) for g, k in factors ] + + def dup_zz_irreducible_p(self, f): + return dup_zz_irreducible_p(self.to_dense(f), self.domain) + def dup_cyclotomic_p(self, f, irreducible=False): + return dup_cyclotomic_p(self.to_dense(f), self.domain, irreducible=irreducible) + def dup_zz_cyclotomic_poly(self, n): + F = dup_zz_cyclotomic_poly(n, self.domain) + return self.from_dense(F) + def dup_zz_cyclotomic_factor(self, f): + result = dup_zz_cyclotomic_factor(self.to_dense(f), self.domain) + if result is None: + return result + else: + return list(map(self.from_dense, result)) + + # E: List[ZZ], cs: ZZ, ct: ZZ + def dmp_zz_wang_non_divisors(self, E, cs, ct): + return dmp_zz_wang_non_divisors(E, cs, ct, self.domain) + + # f: Poly, T: List[(Poly, int)], ct: ZZ, A: List[ZZ] + #def dmp_zz_wang_test_points(f, T, ct, A): + # dmp_zz_wang_test_points(self.to_dense(f), T, ct, A, self.ngens-1, self.domain) + + # f: Poly, T: List[(Poly, int)], cs: ZZ, E: List[ZZ], H: List[Poly], A: List[ZZ] + def dmp_zz_wang_lead_coeffs(self, f, T, cs, E, H, A): + mv = self[1:] + T = [ (mv.to_dense(t), k) for t, k in T ] + uv = self[:1] + H = list(map(uv.to_dense, H)) + f, HH, CC = dmp_zz_wang_lead_coeffs(self.to_dense(f), T, cs, E, H, A, self.ngens-1, self.domain) + return self.from_dense(f), list(map(uv.from_dense, HH)), list(map(mv.from_dense, CC)) + + # f: List[Poly], m: int, p: ZZ + def dup_zz_diophantine(self, F, m, p): + result = dup_zz_diophantine(list(map(self.to_dense, F)), m, p, self.domain) + return list(map(self.from_dense, result)) + + # f: List[Poly], c: List[Poly], A: List[ZZ], d: int, p: ZZ + def dmp_zz_diophantine(self, F, c, A, d, p): + result = dmp_zz_diophantine(list(map(self.to_dense, F)), self.to_dense(c), A, d, p, self.ngens-1, self.domain) + return list(map(self.from_dense, result)) + + # f: Poly, H: List[Poly], LC: List[Poly], A: List[ZZ], p: ZZ + def dmp_zz_wang_hensel_lifting(self, f, H, LC, A, p): + uv = self[:1] + mv = self[1:] + H = list(map(uv.to_dense, H)) + LC = list(map(mv.to_dense, LC)) + result = dmp_zz_wang_hensel_lifting(self.to_dense(f), H, LC, A, p, self.ngens-1, self.domain) + return list(map(self.from_dense, result)) + + def dmp_zz_wang(self, f, mod=None, seed=None): + factors = dmp_zz_wang(self.to_dense(f), self.ngens-1, self.domain, mod=mod, seed=seed) + return [ self.from_dense(g) for g in factors ] + + def dup_zz_factor_sqf(self, f): + coeff, factors = dup_zz_factor_sqf(self.to_dense(f), self.domain) + return (coeff, [ self.from_dense(g) for g in factors ]) + + def dup_zz_factor(self, f): + coeff, factors = dup_zz_factor(self.to_dense(f), self.domain) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + def dmp_zz_factor(self, f): + coeff, factors = dmp_zz_factor(self.to_dense(f), self.ngens-1, self.domain) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + + def dup_qq_i_factor(self, f): + coeff, factors = dup_qq_i_factor(self.to_dense(f), self.domain) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + def dmp_qq_i_factor(self, f): + coeff, factors = dmp_qq_i_factor(self.to_dense(f), self.ngens-1, self.domain) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + + def dup_zz_i_factor(self, f): + coeff, factors = dup_zz_i_factor(self.to_dense(f), self.domain) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + def dmp_zz_i_factor(self, f): + coeff, factors = dmp_zz_i_factor(self.to_dense(f), self.ngens-1, self.domain) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + + def dup_ext_factor(self, f): + coeff, factors = dup_ext_factor(self.to_dense(f), self.domain) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + def dmp_ext_factor(self, f): + coeff, factors = dmp_ext_factor(self.to_dense(f), self.ngens-1, self.domain) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + + def dup_gf_factor(self, f): + coeff, factors = dup_gf_factor(self.to_dense(f), self.domain) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + def dmp_gf_factor(self, f): + coeff, factors = dmp_gf_factor(self.to_dense(f), self.ngens-1, self.domain) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + + def dup_factor_list(self, f): + coeff, factors = dup_factor_list(self.to_dense(f), self.domain) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + def dup_factor_list_include(self, f): + factors = dup_factor_list_include(self.to_dense(f), self.domain) + return [ (self.from_dense(g), k) for g, k in factors ] + + def dmp_factor_list(self, f): + coeff, factors = dmp_factor_list(self.to_dense(f), self.ngens-1, self.domain) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + def dmp_factor_list_include(self, f): + factors = dmp_factor_list_include(self.to_dense(f), self.ngens-1, self.domain) + return [ (self.from_dense(g), k) for g, k in factors ] + + def dup_irreducible_p(self, f): + return dup_irreducible_p(self.to_dense(f), self.domain) + def dmp_irreducible_p(self, f): + return dmp_irreducible_p(self.to_dense(f), self.ngens-1, self.domain) + + def dup_sturm(self, f): + seq = dup_sturm(self.to_dense(f), self.domain) + return list(map(self.from_dense, seq)) + + def dup_sqf_p(self, f): + return dup_sqf_p(self.to_dense(f), self.domain) + def dmp_sqf_p(self, f): + return dmp_sqf_p(self.to_dense(f), self.ngens-1, self.domain) + + def dmp_norm(self, f): + n = dmp_norm(self.to_dense(f), self.ngens-1, self.domain) + return self.to_ground().from_dense(n) + + def dup_sqf_norm(self, f): + s, F, R = dup_sqf_norm(self.to_dense(f), self.domain) + return (s, self.from_dense(F), self.to_ground().from_dense(R)) + def dmp_sqf_norm(self, f): + s, F, R = dmp_sqf_norm(self.to_dense(f), self.ngens-1, self.domain) + return (s, self.from_dense(F), self.to_ground().from_dense(R)) + + def dup_gf_sqf_part(self, f): + return self.from_dense(dup_gf_sqf_part(self.to_dense(f), self.domain)) + def dmp_gf_sqf_part(self, f): + return self.from_dense(dmp_gf_sqf_part(self.to_dense(f), self.domain)) + def dup_sqf_part(self, f): + return self.from_dense(dup_sqf_part(self.to_dense(f), self.domain)) + def dmp_sqf_part(self, f): + return self.from_dense(dmp_sqf_part(self.to_dense(f), self.ngens-1, self.domain)) + + def dup_gf_sqf_list(self, f, all=False): + coeff, factors = dup_gf_sqf_list(self.to_dense(f), self.domain, all=all) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + def dmp_gf_sqf_list(self, f, all=False): + coeff, factors = dmp_gf_sqf_list(self.to_dense(f), self.ngens-1, self.domain, all=all) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + + def dup_sqf_list(self, f, all=False): + coeff, factors = dup_sqf_list(self.to_dense(f), self.domain, all=all) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + def dup_sqf_list_include(self, f, all=False): + factors = dup_sqf_list_include(self.to_dense(f), self.domain, all=all) + return [ (self.from_dense(g), k) for g, k in factors ] + def dmp_sqf_list(self, f, all=False): + coeff, factors = dmp_sqf_list(self.to_dense(f), self.ngens-1, self.domain, all=all) + return (coeff, [ (self.from_dense(g), k) for g, k in factors ]) + def dmp_sqf_list_include(self, f, all=False): + factors = dmp_sqf_list_include(self.to_dense(f), self.ngens-1, self.domain, all=all) + return [ (self.from_dense(g), k) for g, k in factors ] + + def dup_gff_list(self, f): + factors = dup_gff_list(self.to_dense(f), self.domain) + return [ (self.from_dense(g), k) for g, k in factors ] + def dmp_gff_list(self, f): + factors = dmp_gff_list(self.to_dense(f), self.ngens-1, self.domain) + return [ (self.from_dense(g), k) for g, k in factors ] + + def dup_root_upper_bound(self, f): + return dup_root_upper_bound(self.to_dense(f), self.domain) + def dup_root_lower_bound(self, f): + return dup_root_lower_bound(self.to_dense(f), self.domain) + + def dup_step_refine_real_root(self, f, M, fast=False): + return dup_step_refine_real_root(self.to_dense(f), M, self.domain, fast=fast) + def dup_inner_refine_real_root(self, f, M, eps=None, steps=None, disjoint=None, fast=False, mobius=False): + return dup_inner_refine_real_root(self.to_dense(f), M, self.domain, eps=eps, steps=steps, disjoint=disjoint, fast=fast, mobius=mobius) + def dup_outer_refine_real_root(self, f, s, t, eps=None, steps=None, disjoint=None, fast=False): + return dup_outer_refine_real_root(self.to_dense(f), s, t, self.domain, eps=eps, steps=steps, disjoint=disjoint, fast=fast) + def dup_refine_real_root(self, f, s, t, eps=None, steps=None, disjoint=None, fast=False): + return dup_refine_real_root(self.to_dense(f), s, t, self.domain, eps=eps, steps=steps, disjoint=disjoint, fast=fast) + def dup_inner_isolate_real_roots(self, f, eps=None, fast=False): + return dup_inner_isolate_real_roots(self.to_dense(f), self.domain, eps=eps, fast=fast) + def dup_inner_isolate_positive_roots(self, f, eps=None, inf=None, sup=None, fast=False, mobius=False): + return dup_inner_isolate_positive_roots(self.to_dense(f), self.domain, eps=eps, inf=inf, sup=sup, fast=fast, mobius=mobius) + def dup_inner_isolate_negative_roots(self, f, inf=None, sup=None, eps=None, fast=False, mobius=False): + return dup_inner_isolate_negative_roots(self.to_dense(f), self.domain, inf=inf, sup=sup, eps=eps, fast=fast, mobius=mobius) + def dup_isolate_real_roots_sqf(self, f, eps=None, inf=None, sup=None, fast=False, blackbox=False): + return dup_isolate_real_roots_sqf(self.to_dense(f), self.domain, eps=eps, inf=inf, sup=sup, fast=fast, blackbox=blackbox) + def dup_isolate_real_roots(self, f, eps=None, inf=None, sup=None, basis=False, fast=False): + return dup_isolate_real_roots(self.to_dense(f), self.domain, eps=eps, inf=inf, sup=sup, basis=basis, fast=fast) + def dup_isolate_real_roots_list(self, polys, eps=None, inf=None, sup=None, strict=False, basis=False, fast=False): + return dup_isolate_real_roots_list(list(map(self.to_dense, polys)), self.domain, eps=eps, inf=inf, sup=sup, strict=strict, basis=basis, fast=fast) + def dup_count_real_roots(self, f, inf=None, sup=None): + return dup_count_real_roots(self.to_dense(f), self.domain, inf=inf, sup=sup) + def dup_count_complex_roots(self, f, inf=None, sup=None, exclude=None): + return dup_count_complex_roots(self.to_dense(f), self.domain, inf=inf, sup=sup, exclude=exclude) + def dup_isolate_complex_roots_sqf(self, f, eps=None, inf=None, sup=None, blackbox=False): + return dup_isolate_complex_roots_sqf(self.to_dense(f), self.domain, eps=eps, inf=inf, sup=sup, blackbox=blackbox) + def dup_isolate_all_roots_sqf(self, f, eps=None, inf=None, sup=None, fast=False, blackbox=False): + return dup_isolate_all_roots_sqf(self.to_dense(f), self.domain, eps=eps, inf=inf, sup=sup, fast=fast, blackbox=blackbox) + def dup_isolate_all_roots(self, f, eps=None, inf=None, sup=None, fast=False): + return dup_isolate_all_roots(self.to_dense(f), self.domain, eps=eps, inf=inf, sup=sup, fast=fast) + + def fateman_poly_F_1(self): + from sympy.polys.specialpolys import dmp_fateman_poly_F_1 + return tuple(map(self.from_dense, dmp_fateman_poly_F_1(self.ngens-1, self.domain))) + def fateman_poly_F_2(self): + from sympy.polys.specialpolys import dmp_fateman_poly_F_2 + return tuple(map(self.from_dense, dmp_fateman_poly_F_2(self.ngens-1, self.domain))) + def fateman_poly_F_3(self): + from sympy.polys.specialpolys import dmp_fateman_poly_F_3 + return tuple(map(self.from_dense, dmp_fateman_poly_F_3(self.ngens-1, self.domain))) + + def to_gf_dense(self, element): + return gf_strip([ self.domain.dom.convert(c, self.domain) for c in self.wrap(element).to_dense() ]) + + def from_gf_dense(self, element): + return self.from_dict(dmp_to_dict(element, self.ngens-1, self.domain.dom)) + + def gf_degree(self, f): + return gf_degree(self.to_gf_dense(f)) + + def gf_LC(self, f): + return gf_LC(self.to_gf_dense(f), self.domain.dom) + def gf_TC(self, f): + return gf_TC(self.to_gf_dense(f), self.domain.dom) + + def gf_strip(self, f): + return self.from_gf_dense(gf_strip(self.to_gf_dense(f))) + def gf_trunc(self, f): + return self.from_gf_dense(gf_strip(self.to_gf_dense(f), self.domain.mod)) + def gf_normal(self, f): + return self.from_gf_dense(gf_strip(self.to_gf_dense(f), self.domain.mod, self.domain.dom)) + + def gf_from_dict(self, f): + return self.from_gf_dense(gf_from_dict(f, self.domain.mod, self.domain.dom)) + def gf_to_dict(self, f, symmetric=True): + return gf_to_dict(self.to_gf_dense(f), self.domain.mod, symmetric=symmetric) + + def gf_from_int_poly(self, f): + return self.from_gf_dense(gf_from_int_poly(f, self.domain.mod)) + def gf_to_int_poly(self, f, symmetric=True): + return gf_to_int_poly(self.to_gf_dense(f), self.domain.mod, symmetric=symmetric) + + def gf_neg(self, f): + return self.from_gf_dense(gf_neg(self.to_gf_dense(f), self.domain.mod, self.domain.dom)) + + def gf_add_ground(self, f, a): + return self.from_gf_dense(gf_add_ground(self.to_gf_dense(f), a, self.domain.mod, self.domain.dom)) + def gf_sub_ground(self, f, a): + return self.from_gf_dense(gf_sub_ground(self.to_gf_dense(f), a, self.domain.mod, self.domain.dom)) + def gf_mul_ground(self, f, a): + return self.from_gf_dense(gf_mul_ground(self.to_gf_dense(f), a, self.domain.mod, self.domain.dom)) + def gf_quo_ground(self, f, a): + return self.from_gf_dense(gf_quo_ground(self.to_gf_dense(f), a, self.domain.mod, self.domain.dom)) + + def gf_add(self, f, g): + return self.from_gf_dense(gf_add(self.to_gf_dense(f), self.to_gf_dense(g), self.domain.mod, self.domain.dom)) + def gf_sub(self, f, g): + return self.from_gf_dense(gf_sub(self.to_gf_dense(f), self.to_gf_dense(g), self.domain.mod, self.domain.dom)) + def gf_mul(self, f, g): + return self.from_gf_dense(gf_mul(self.to_gf_dense(f), self.to_gf_dense(g), self.domain.mod, self.domain.dom)) + def gf_sqr(self, f): + return self.from_gf_dense(gf_sqr(self.to_gf_dense(f), self.domain.mod, self.domain.dom)) + + def gf_add_mul(self, f, g, h): + return self.from_gf_dense(gf_add_mul(self.to_gf_dense(f), self.to_gf_dense(g), self.to_gf_dense(h), self.domain.mod, self.domain.dom)) + def gf_sub_mul(self, f, g, h): + return self.from_gf_dense(gf_sub_mul(self.to_gf_dense(f), self.to_gf_dense(g), self.to_gf_dense(h), self.domain.mod, self.domain.dom)) + + def gf_expand(self, F): + return self.from_gf_dense(gf_expand(list(map(self.to_gf_dense, F)), self.domain.mod, self.domain.dom)) + + def gf_div(self, f, g): + q, r = gf_div(self.to_gf_dense(f), self.to_gf_dense(g), self.domain.mod, self.domain.dom) + return self.from_gf_dense(q), self.from_gf_dense(r) + def gf_rem(self, f, g): + return self.from_gf_dense(gf_rem(self.to_gf_dense(f), self.to_gf_dense(g), self.domain.mod, self.domain.dom)) + def gf_quo(self, f, g): + return self.from_gf_dense(gf_quo(self.to_gf_dense(f), self.to_gf_dense(g), self.domain.mod, self.domain.dom)) + def gf_exquo(self, f, g): + return self.from_gf_dense(gf_exquo(self.to_gf_dense(f), self.to_gf_dense(g), self.domain.mod, self.domain.dom)) + + def gf_lshift(self, f, n): + return self.from_gf_dense(gf_lshift(self.to_gf_dense(f), n, self.domain.dom)) + def gf_rshift(self, f, n): + return self.from_gf_dense(gf_rshift(self.to_gf_dense(f), n, self.domain.dom)) + + def gf_pow(self, f, n): + return self.from_gf_dense(gf_pow(self.to_gf_dense(f), n, self.domain.mod, self.domain.dom)) + def gf_pow_mod(self, f, n, g): + return self.from_gf_dense(gf_pow_mod(self.to_gf_dense(f), n, self.to_gf_dense(g), self.domain.mod, self.domain.dom)) + + def gf_cofactors(self, f, g): + h, cff, cfg = gf_cofactors(self.to_gf_dense(f), self.to_gf_dense(g), self.domain.mod, self.domain.dom) + return self.from_gf_dense(h), self.from_gf_dense(cff), self.from_gf_dense(cfg) + def gf_gcd(self, f, g): + return self.from_gf_dense(gf_gcd(self.to_gf_dense(f), self.to_gf_dense(g), self.domain.mod, self.domain.dom)) + def gf_lcm(self, f, g): + return self.from_gf_dense(gf_lcm(self.to_gf_dense(f), self.to_gf_dense(g), self.domain.mod, self.domain.dom)) + def gf_gcdex(self, f, g): + return self.from_gf_dense(gf_gcdex(self.to_gf_dense(f), self.to_gf_dense(g), self.domain.mod, self.domain.dom)) + + def gf_monic(self, f): + return self.from_gf_dense(gf_monic(self.to_gf_dense(f), self.domain.mod, self.domain.dom)) + def gf_diff(self, f): + return self.from_gf_dense(gf_diff(self.to_gf_dense(f), self.domain.mod, self.domain.dom)) + + def gf_eval(self, f, a): + return gf_eval(self.to_gf_dense(f), a, self.domain.mod, self.domain.dom) + def gf_multi_eval(self, f, A): + return gf_multi_eval(self.to_gf_dense(f), A, self.domain.mod, self.domain.dom) + + def gf_compose(self, f, g): + return self.from_gf_dense(gf_compose(self.to_gf_dense(f), self.to_gf_dense(g), self.domain.mod, self.domain.dom)) + def gf_compose_mod(self, g, h, f): + return self.from_gf_dense(gf_compose_mod(self.to_gf_dense(g), self.to_gf_dense(h), self.to_gf_dense(f), self.domain.mod, self.domain.dom)) + + def gf_trace_map(self, a, b, c, n, f): + a = self.to_gf_dense(a) + b = self.to_gf_dense(b) + c = self.to_gf_dense(c) + f = self.to_gf_dense(f) + U, V = gf_trace_map(a, b, c, n, f, self.domain.mod, self.domain.dom) + return self.from_gf_dense(U), self.from_gf_dense(V) + + def gf_random(self, n): + return self.from_gf_dense(gf_random(n, self.domain.mod, self.domain.dom)) + def gf_irreducible(self, n): + return self.from_gf_dense(gf_irreducible(n, self.domain.mod, self.domain.dom)) + + def gf_irred_p_ben_or(self, f): + return gf_irred_p_ben_or(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + def gf_irred_p_rabin(self, f): + return gf_irred_p_rabin(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + def gf_irreducible_p(self, f): + return gf_irreducible_p(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + def gf_sqf_p(self, f): + return gf_sqf_p(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + + def gf_sqf_part(self, f): + return self.from_gf_dense(gf_sqf_part(self.to_gf_dense(f), self.domain.mod, self.domain.dom)) + def gf_sqf_list(self, f, all=False): + coeff, factors = gf_sqf_part(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + return coeff, [ (self.from_gf_dense(g), k) for g, k in factors ] + + def gf_Qmatrix(self, f): + return gf_Qmatrix(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + def gf_berlekamp(self, f): + factors = gf_berlekamp(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + return [ self.from_gf_dense(g) for g in factors ] + + def gf_ddf_zassenhaus(self, f): + factors = gf_ddf_zassenhaus(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + return [ (self.from_gf_dense(g), k) for g, k in factors ] + def gf_edf_zassenhaus(self, f, n): + factors = gf_edf_zassenhaus(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + return [ self.from_gf_dense(g) for g in factors ] + + def gf_ddf_shoup(self, f): + factors = gf_ddf_shoup(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + return [ (self.from_gf_dense(g), k) for g, k in factors ] + def gf_edf_shoup(self, f, n): + factors = gf_edf_shoup(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + return [ self.from_gf_dense(g) for g in factors ] + + def gf_zassenhaus(self, f): + factors = gf_zassenhaus(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + return [ self.from_gf_dense(g) for g in factors ] + def gf_shoup(self, f): + factors = gf_shoup(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + return [ self.from_gf_dense(g) for g in factors ] + + def gf_factor_sqf(self, f, method=None): + coeff, factors = gf_factor_sqf(self.to_gf_dense(f), self.domain.mod, self.domain.dom, method=method) + return coeff, [ self.from_gf_dense(g) for g in factors ] + def gf_factor(self, f): + coeff, factors = gf_factor(self.to_gf_dense(f), self.domain.mod, self.domain.dom) + return coeff, [ (self.from_gf_dense(g), k) for g, k in factors ] diff --git a/phivenv/Lib/site-packages/sympy/polys/constructor.py b/phivenv/Lib/site-packages/sympy/polys/constructor.py new file mode 100644 index 0000000000000000000000000000000000000000..49ce4782b987419ee8b736974f8755301380bdda --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/constructor.py @@ -0,0 +1,387 @@ +"""Tools for constructing domains for expressions. """ +from math import prod + +from sympy.core import sympify +from sympy.core.evalf import pure_complex +from sympy.core.sorting import ordered +from sympy.polys.domains import ZZ, QQ, ZZ_I, QQ_I, EX +from sympy.polys.domains.complexfield import ComplexField +from sympy.polys.domains.realfield import RealField +from sympy.polys.polyoptions import build_options +from sympy.polys.polyutils import parallel_dict_from_basic +from sympy.utilities import public + + +def _construct_simple(coeffs, opt): + """Handle simple domains, e.g.: ZZ, QQ, RR and algebraic domains. """ + rationals = floats = complexes = algebraics = False + float_numbers = [] + + if opt.extension is True: + is_algebraic = lambda coeff: coeff.is_number and coeff.is_algebraic + else: + is_algebraic = lambda coeff: False + + for coeff in coeffs: + if coeff.is_Rational: + if not coeff.is_Integer: + rationals = True + elif coeff.is_Float: + if algebraics: + # there are both reals and algebraics -> EX + return False + else: + floats = True + float_numbers.append(coeff) + else: + is_complex = pure_complex(coeff) + if is_complex: + complexes = True + x, y = is_complex + if x.is_Rational and y.is_Rational: + if not (x.is_Integer and y.is_Integer): + rationals = True + continue + else: + floats = True + if x.is_Float: + float_numbers.append(x) + if y.is_Float: + float_numbers.append(y) + elif is_algebraic(coeff): + if floats: + # there are both algebraics and reals -> EX + return False + algebraics = True + else: + # this is a composite domain, e.g. ZZ[X], EX + return None + + # Use the maximum precision of all coefficients for the RR or CC + # precision + max_prec = max(c._prec for c in float_numbers) if float_numbers else 53 + + if algebraics: + domain, result = _construct_algebraic(coeffs, opt) + else: + if floats and complexes: + domain = ComplexField(prec=max_prec) + elif floats: + domain = RealField(prec=max_prec) + elif rationals or opt.field: + domain = QQ_I if complexes else QQ + else: + domain = ZZ_I if complexes else ZZ + + result = [domain.from_sympy(coeff) for coeff in coeffs] + + return domain, result + + +def _construct_algebraic(coeffs, opt): + """We know that coefficients are algebraic so construct the extension. """ + from sympy.polys.numberfields import primitive_element + + exts = set() + + def build_trees(args): + trees = [] + for a in args: + if a.is_Rational: + tree = ('Q', QQ.from_sympy(a)) + elif a.is_Add: + tree = ('+', build_trees(a.args)) + elif a.is_Mul: + tree = ('*', build_trees(a.args)) + else: + tree = ('e', a) + exts.add(a) + trees.append(tree) + return trees + + trees = build_trees(coeffs) + exts = list(ordered(exts)) + + g, span, H = primitive_element(exts, ex=True, polys=True) + root = sum(s*ext for s, ext in zip(span, exts)) + + domain, g = QQ.algebraic_field((g, root)), g.rep.to_list() + + exts_dom = [domain.dtype.from_list(h, g, QQ) for h in H] + exts_map = dict(zip(exts, exts_dom)) + + def convert_tree(tree): + op, args = tree + if op == 'Q': + return domain.dtype.from_list([args], g, QQ) + elif op == '+': + return sum((convert_tree(a) for a in args), domain.zero) + elif op == '*': + return prod(convert_tree(a) for a in args) + elif op == 'e': + return exts_map[args] + else: + raise RuntimeError + + result = [convert_tree(tree) for tree in trees] + + return domain, result + + +def _construct_composite(coeffs, opt): + """Handle composite domains, e.g.: ZZ[X], QQ[X], ZZ(X), QQ(X). """ + numers, denoms = [], [] + + for coeff in coeffs: + numer, denom = coeff.as_numer_denom() + + numers.append(numer) + denoms.append(denom) + + polys, gens = parallel_dict_from_basic(numers + denoms) # XXX: sorting + if not gens: + return None + + if opt.composite is None: + if any(gen.is_number and gen.is_algebraic for gen in gens): + return None # generators are number-like so lets better use EX + + all_symbols = set() + + for gen in gens: + symbols = gen.free_symbols + + if all_symbols & symbols: + return None # there could be algebraic relations between generators + else: + all_symbols |= symbols + + n = len(gens) + k = len(polys)//2 + + numers = polys[:k] + denoms = polys[k:] + + if opt.field: + fractions = True + else: + fractions, zeros = False, (0,)*n + + for denom in denoms: + if len(denom) > 1 or zeros not in denom: + fractions = True + break + + coeffs = set() + + if not fractions: + for numer, denom in zip(numers, denoms): + denom = denom[zeros] + + for monom, coeff in numer.items(): + coeff /= denom + coeffs.add(coeff) + numer[monom] = coeff + else: + for numer, denom in zip(numers, denoms): + coeffs.update(list(numer.values())) + coeffs.update(list(denom.values())) + + rationals = floats = complexes = False + float_numbers = [] + + for coeff in coeffs: + if coeff.is_Rational: + if not coeff.is_Integer: + rationals = True + elif coeff.is_Float: + floats = True + float_numbers.append(coeff) + else: + is_complex = pure_complex(coeff) + if is_complex is not None: + complexes = True + x, y = is_complex + if x.is_Rational and y.is_Rational: + if not (x.is_Integer and y.is_Integer): + rationals = True + else: + floats = True + if x.is_Float: + float_numbers.append(x) + if y.is_Float: + float_numbers.append(y) + + max_prec = max(c._prec for c in float_numbers) if float_numbers else 53 + + if floats and complexes: + ground = ComplexField(prec=max_prec) + elif floats: + ground = RealField(prec=max_prec) + elif complexes: + if rationals: + ground = QQ_I + else: + ground = ZZ_I + elif rationals: + ground = QQ + else: + ground = ZZ + + result = [] + + if not fractions: + domain = ground.poly_ring(*gens) + + for numer in numers: + for monom, coeff in numer.items(): + numer[monom] = ground.from_sympy(coeff) + + result.append(domain(numer)) + else: + domain = ground.frac_field(*gens) + + for numer, denom in zip(numers, denoms): + for monom, coeff in numer.items(): + numer[monom] = ground.from_sympy(coeff) + + for monom, coeff in denom.items(): + denom[monom] = ground.from_sympy(coeff) + + result.append(domain((numer, denom))) + + return domain, result + + +def _construct_expression(coeffs, opt): + """The last resort case, i.e. use the expression domain. """ + domain, result = EX, [] + + for coeff in coeffs: + result.append(domain.from_sympy(coeff)) + + return domain, result + + +@public +def construct_domain(obj, **args): + """Construct a minimal domain for a list of expressions. + + Explanation + =========== + + Given a list of normal SymPy expressions (of type :py:class:`~.Expr`) + ``construct_domain`` will find a minimal :py:class:`~.Domain` that can + represent those expressions. The expressions will be converted to elements + of the domain and both the domain and the domain elements are returned. + + Parameters + ========== + + obj: list or dict + The expressions to build a domain for. + + **args: keyword arguments + Options that affect the choice of domain. + + Returns + ======= + + (K, elements): Domain and list of domain elements + The domain K that can represent the expressions and the list or dict + of domain elements representing the same expressions as elements of K. + + Examples + ======== + + Given a list of :py:class:`~.Integer` ``construct_domain`` will return the + domain :ref:`ZZ` and a list of integers as elements of :ref:`ZZ`. + + >>> from sympy import construct_domain, S + >>> expressions = [S(2), S(3), S(4)] + >>> K, elements = construct_domain(expressions) + >>> K + ZZ + >>> elements + [2, 3, 4] + >>> type(elements[0]) # doctest: +SKIP + + >>> type(expressions[0]) + + + If there are any :py:class:`~.Rational` then :ref:`QQ` is returned + instead. + + >>> construct_domain([S(1)/2, S(3)/4]) + (QQ, [1/2, 3/4]) + + If there are symbols then a polynomial ring :ref:`K[x]` is returned. + + >>> from sympy import symbols + >>> x, y = symbols('x, y') + >>> construct_domain([2*x + 1, S(3)/4]) + (QQ[x], [2*x + 1, 3/4]) + >>> construct_domain([2*x + 1, y]) + (ZZ[x,y], [2*x + 1, y]) + + If any symbols appear with negative powers then a rational function field + :ref:`K(x)` will be returned. + + >>> construct_domain([y/x, x/(1 - y)]) + (ZZ(x,y), [y/x, -x/(y - 1)]) + + Irrational algebraic numbers will result in the :ref:`EX` domain by + default. The keyword argument ``extension=True`` leads to the construction + of an algebraic number field :ref:`QQ(a)`. + + >>> from sympy import sqrt + >>> construct_domain([sqrt(2)]) + (EX, [EX(sqrt(2))]) + >>> construct_domain([sqrt(2)], extension=True) # doctest: +SKIP + (QQ, [ANP([1, 0], [1, 0, -2], QQ)]) + + See also + ======== + + Domain + Expr + """ + opt = build_options(args) + + if hasattr(obj, '__iter__'): + if isinstance(obj, dict): + if not obj: + monoms, coeffs = [], [] + else: + monoms, coeffs = list(zip(*list(obj.items()))) + else: + coeffs = obj + else: + coeffs = [obj] + + coeffs = list(map(sympify, coeffs)) + result = _construct_simple(coeffs, opt) + + if result is not None: + if result is not False: + domain, coeffs = result + else: + domain, coeffs = _construct_expression(coeffs, opt) + else: + if opt.composite is False: + result = None + else: + result = _construct_composite(coeffs, opt) + + if result is not None: + domain, coeffs = result + else: + domain, coeffs = _construct_expression(coeffs, opt) + + if hasattr(obj, '__iter__'): + if isinstance(obj, dict): + return domain, dict(list(zip(monoms, coeffs))) + else: + return domain, coeffs + else: + return domain, coeffs[0] diff --git a/phivenv/Lib/site-packages/sympy/polys/densearith.py b/phivenv/Lib/site-packages/sympy/polys/densearith.py new file mode 100644 index 0000000000000000000000000000000000000000..1088691ca3fb020e9074c1c7c017c1baaba637c8 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/densearith.py @@ -0,0 +1,1875 @@ +"""Arithmetics for dense recursive polynomials in ``K[x]`` or ``K[X]``. """ + + +from sympy.polys.densebasic import ( + dup_slice, + dup_LC, dmp_LC, + dup_degree, dmp_degree, + dup_strip, dmp_strip, + dmp_zero_p, dmp_zero, + dmp_one_p, dmp_one, + dmp_ground, dmp_zeros) +from sympy.polys.polyerrors import (ExactQuotientFailed, PolynomialDivisionFailed) + +def dup_add_term(f, c, i, K): + """ + Add ``c*x**i`` to ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_add_term(x**2 - 1, ZZ(2), 4) + 2*x**4 + x**2 - 1 + + """ + if not c: + return f + + n = len(f) + m = n - i - 1 + + if i == n - 1: + return dup_strip([f[0] + c] + f[1:]) + else: + if i >= n: + return [c] + [K.zero]*(i - n) + f + else: + return f[:m] + [f[m] + c] + f[m + 1:] + + +def dmp_add_term(f, c, i, u, K): + """ + Add ``c(x_2..x_u)*x_0**i`` to ``f`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_add_term(x*y + 1, 2, 2) + 2*x**2 + x*y + 1 + + """ + if not u: + return dup_add_term(f, c, i, K) + + v = u - 1 + + if dmp_zero_p(c, v): + return f + + n = len(f) + m = n - i - 1 + + if i == n - 1: + return dmp_strip([dmp_add(f[0], c, v, K)] + f[1:], u) + else: + if i >= n: + return [c] + dmp_zeros(i - n, v, K) + f + else: + return f[:m] + [dmp_add(f[m], c, v, K)] + f[m + 1:] + + +def dup_sub_term(f, c, i, K): + """ + Subtract ``c*x**i`` from ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_sub_term(2*x**4 + x**2 - 1, ZZ(2), 4) + x**2 - 1 + + """ + if not c: + return f + + n = len(f) + m = n - i - 1 + + if i == n - 1: + return dup_strip([f[0] - c] + f[1:]) + else: + if i >= n: + return [-c] + [K.zero]*(i - n) + f + else: + return f[:m] + [f[m] - c] + f[m + 1:] + + +def dmp_sub_term(f, c, i, u, K): + """ + Subtract ``c(x_2..x_u)*x_0**i`` from ``f`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_sub_term(2*x**2 + x*y + 1, 2, 2) + x*y + 1 + + """ + if not u: + return dup_add_term(f, -c, i, K) + + v = u - 1 + + if dmp_zero_p(c, v): + return f + + n = len(f) + m = n - i - 1 + + if i == n - 1: + return dmp_strip([dmp_sub(f[0], c, v, K)] + f[1:], u) + else: + if i >= n: + return [dmp_neg(c, v, K)] + dmp_zeros(i - n, v, K) + f + else: + return f[:m] + [dmp_sub(f[m], c, v, K)] + f[m + 1:] + + +def dup_mul_term(f, c, i, K): + """ + Multiply ``f`` by ``c*x**i`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_mul_term(x**2 - 1, ZZ(3), 2) + 3*x**4 - 3*x**2 + + """ + if not c or not f: + return [] + else: + return [ cf * c for cf in f ] + [K.zero]*i + + +def dmp_mul_term(f, c, i, u, K): + """ + Multiply ``f`` by ``c(x_2..x_u)*x_0**i`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_mul_term(x**2*y + x, 3*y, 2) + 3*x**4*y**2 + 3*x**3*y + + """ + if not u: + return dup_mul_term(f, c, i, K) + + v = u - 1 + + if dmp_zero_p(f, u): + return f + if dmp_zero_p(c, v): + return dmp_zero(u) + else: + return [ dmp_mul(cf, c, v, K) for cf in f ] + dmp_zeros(i, v, K) + + +def dup_add_ground(f, c, K): + """ + Add an element of the ground domain to ``f``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_add_ground(x**3 + 2*x**2 + 3*x + 4, ZZ(4)) + x**3 + 2*x**2 + 3*x + 8 + + """ + return dup_add_term(f, c, 0, K) + + +def dmp_add_ground(f, c, u, K): + """ + Add an element of the ground domain to ``f``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_add_ground(x**3 + 2*x**2 + 3*x + 4, ZZ(4)) + x**3 + 2*x**2 + 3*x + 8 + + """ + return dmp_add_term(f, dmp_ground(c, u - 1), 0, u, K) + + +def dup_sub_ground(f, c, K): + """ + Subtract an element of the ground domain from ``f``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_sub_ground(x**3 + 2*x**2 + 3*x + 4, ZZ(4)) + x**3 + 2*x**2 + 3*x + + """ + return dup_sub_term(f, c, 0, K) + + +def dmp_sub_ground(f, c, u, K): + """ + Subtract an element of the ground domain from ``f``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_sub_ground(x**3 + 2*x**2 + 3*x + 4, ZZ(4)) + x**3 + 2*x**2 + 3*x + + """ + return dmp_sub_term(f, dmp_ground(c, u - 1), 0, u, K) + + +def dup_mul_ground(f, c, K): + """ + Multiply ``f`` by a constant value in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_mul_ground(x**2 + 2*x - 1, ZZ(3)) + 3*x**2 + 6*x - 3 + + """ + if not c or not f: + return [] + else: + return [ cf * c for cf in f ] + + +def dmp_mul_ground(f, c, u, K): + """ + Multiply ``f`` by a constant value in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_mul_ground(2*x + 2*y, ZZ(3)) + 6*x + 6*y + + """ + if not u: + return dup_mul_ground(f, c, K) + + v = u - 1 + + return [ dmp_mul_ground(cf, c, v, K) for cf in f ] + + +def dup_quo_ground(f, c, K): + """ + Quotient by a constant in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x = ring("x", ZZ) + >>> R.dup_quo_ground(3*x**2 + 2, ZZ(2)) + x**2 + 1 + + >>> R, x = ring("x", QQ) + >>> R.dup_quo_ground(3*x**2 + 2, QQ(2)) + 3/2*x**2 + 1 + + """ + if not c: + raise ZeroDivisionError('polynomial division') + if not f: + return f + + if K.is_Field: + return [ K.quo(cf, c) for cf in f ] + else: + return [ cf // c for cf in f ] + + +def dmp_quo_ground(f, c, u, K): + """ + Quotient by a constant in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x,y = ring("x,y", ZZ) + >>> R.dmp_quo_ground(2*x**2*y + 3*x, ZZ(2)) + x**2*y + x + + >>> R, x,y = ring("x,y", QQ) + >>> R.dmp_quo_ground(2*x**2*y + 3*x, QQ(2)) + x**2*y + 3/2*x + + """ + if not u: + return dup_quo_ground(f, c, K) + + v = u - 1 + + return [ dmp_quo_ground(cf, c, v, K) for cf in f ] + + +def dup_exquo_ground(f, c, K): + """ + Exact quotient by a constant in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + >>> R.dup_exquo_ground(x**2 + 2, QQ(2)) + 1/2*x**2 + 1 + + """ + if not c: + raise ZeroDivisionError('polynomial division') + if not f: + return f + + return [ K.exquo(cf, c) for cf in f ] + + +def dmp_exquo_ground(f, c, u, K): + """ + Exact quotient by a constant in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x,y = ring("x,y", QQ) + + >>> R.dmp_exquo_ground(x**2*y + 2*x, QQ(2)) + 1/2*x**2*y + x + + """ + if not u: + return dup_exquo_ground(f, c, K) + + v = u - 1 + + return [ dmp_exquo_ground(cf, c, v, K) for cf in f ] + + +def dup_lshift(f, n, K): + """ + Efficiently multiply ``f`` by ``x**n`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_lshift(x**2 + 1, 2) + x**4 + x**2 + + """ + if not f: + return f + else: + return f + [K.zero]*n + + +def dup_rshift(f, n, K): + """ + Efficiently divide ``f`` by ``x**n`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_rshift(x**4 + x**2, 2) + x**2 + 1 + >>> R.dup_rshift(x**4 + x**2 + 2, 2) + x**2 + 1 + + """ + return f[:-n] + + +def dup_abs(f, K): + """ + Make all coefficients positive in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_abs(x**2 - 1) + x**2 + 1 + + """ + return [ K.abs(coeff) for coeff in f ] + + +def dmp_abs(f, u, K): + """ + Make all coefficients positive in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_abs(x**2*y - x) + x**2*y + x + + """ + if not u: + return dup_abs(f, K) + + v = u - 1 + + return [ dmp_abs(cf, v, K) for cf in f ] + + +def dup_neg(f, K): + """ + Negate a polynomial in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_neg(x**2 - 1) + -x**2 + 1 + + """ + return [ -coeff for coeff in f ] + + +def dmp_neg(f, u, K): + """ + Negate a polynomial in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_neg(x**2*y - x) + -x**2*y + x + + """ + if not u: + return dup_neg(f, K) + + v = u - 1 + + return [ dmp_neg(cf, v, K) for cf in f ] + + +def dup_add(f, g, K): + """ + Add dense polynomials in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_add(x**2 - 1, x - 2) + x**2 + x - 3 + + """ + if not f: + return g + if not g: + return f + + df = dup_degree(f) + dg = dup_degree(g) + + if df == dg: + return dup_strip([ a + b for a, b in zip(f, g) ]) + else: + k = abs(df - dg) + + if df > dg: + h, f = f[:k], f[k:] + else: + h, g = g[:k], g[k:] + + return h + [ a + b for a, b in zip(f, g) ] + + +def dmp_add(f, g, u, K): + """ + Add dense polynomials in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_add(x**2 + y, x**2*y + x) + x**2*y + x**2 + x + y + + """ + if not u: + return dup_add(f, g, K) + + df = dmp_degree(f, u) + + if df < 0: + return g + + dg = dmp_degree(g, u) + + if dg < 0: + return f + + v = u - 1 + + if df == dg: + return dmp_strip([ dmp_add(a, b, v, K) for a, b in zip(f, g) ], u) + else: + k = abs(df - dg) + + if df > dg: + h, f = f[:k], f[k:] + else: + h, g = g[:k], g[k:] + + return h + [ dmp_add(a, b, v, K) for a, b in zip(f, g) ] + + +def dup_sub(f, g, K): + """ + Subtract dense polynomials in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_sub(x**2 - 1, x - 2) + x**2 - x + 1 + + """ + if not f: + return dup_neg(g, K) + if not g: + return f + + df = dup_degree(f) + dg = dup_degree(g) + + if df == dg: + return dup_strip([ a - b for a, b in zip(f, g) ]) + else: + k = abs(df - dg) + + if df > dg: + h, f = f[:k], f[k:] + else: + h, g = dup_neg(g[:k], K), g[k:] + + return h + [ a - b for a, b in zip(f, g) ] + + +def dmp_sub(f, g, u, K): + """ + Subtract dense polynomials in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_sub(x**2 + y, x**2*y + x) + -x**2*y + x**2 - x + y + + """ + if not u: + return dup_sub(f, g, K) + + df = dmp_degree(f, u) + + if df < 0: + return dmp_neg(g, u, K) + + dg = dmp_degree(g, u) + + if dg < 0: + return f + + v = u - 1 + + if df == dg: + return dmp_strip([ dmp_sub(a, b, v, K) for a, b in zip(f, g) ], u) + else: + k = abs(df - dg) + + if df > dg: + h, f = f[:k], f[k:] + else: + h, g = dmp_neg(g[:k], u, K), g[k:] + + return h + [ dmp_sub(a, b, v, K) for a, b in zip(f, g) ] + + +def dup_add_mul(f, g, h, K): + """ + Returns ``f + g*h`` where ``f, g, h`` are in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_add_mul(x**2 - 1, x - 2, x + 2) + 2*x**2 - 5 + + """ + return dup_add(f, dup_mul(g, h, K), K) + + +def dmp_add_mul(f, g, h, u, K): + """ + Returns ``f + g*h`` where ``f, g, h`` are in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_add_mul(x**2 + y, x, x + 2) + 2*x**2 + 2*x + y + + """ + return dmp_add(f, dmp_mul(g, h, u, K), u, K) + + +def dup_sub_mul(f, g, h, K): + """ + Returns ``f - g*h`` where ``f, g, h`` are in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_sub_mul(x**2 - 1, x - 2, x + 2) + 3 + + """ + return dup_sub(f, dup_mul(g, h, K), K) + + +def dmp_sub_mul(f, g, h, u, K): + """ + Returns ``f - g*h`` where ``f, g, h`` are in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_sub_mul(x**2 + y, x, x + 2) + -2*x + y + + """ + return dmp_sub(f, dmp_mul(g, h, u, K), u, K) + + +def dup_mul(f, g, K): + """ + Multiply dense polynomials in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_mul(x - 2, x + 2) + x**2 - 4 + + """ + if f == g: + return dup_sqr(f, K) + + if not (f and g): + return [] + + df = dup_degree(f) + dg = dup_degree(g) + + n = max(df, dg) + 1 + + if n < 100 or not K.is_Exact: + h = [] + + for i in range(0, df + dg + 1): + coeff = K.zero + + for j in range(max(0, i - dg), min(df, i) + 1): + coeff += f[j]*g[i - j] + + h.append(coeff) + + return dup_strip(h) + else: + # Use Karatsuba's algorithm (divide and conquer), see e.g.: + # Joris van der Hoeven, Relax But Don't Be Too Lazy, + # J. Symbolic Computation, 11 (2002), section 3.1.1. + n2 = n//2 + + fl, gl = dup_slice(f, 0, n2, K), dup_slice(g, 0, n2, K) + + fh = dup_rshift(dup_slice(f, n2, n, K), n2, K) + gh = dup_rshift(dup_slice(g, n2, n, K), n2, K) + + lo, hi = dup_mul(fl, gl, K), dup_mul(fh, gh, K) + + mid = dup_mul(dup_add(fl, fh, K), dup_add(gl, gh, K), K) + mid = dup_sub(mid, dup_add(lo, hi, K), K) + + return dup_add(dup_add(lo, dup_lshift(mid, n2, K), K), + dup_lshift(hi, 2*n2, K), K) + + +def dmp_mul(f, g, u, K): + """ + Multiply dense polynomials in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_mul(x*y + 1, x) + x**2*y + x + + """ + if not u: + return dup_mul(f, g, K) + + if f == g: + return dmp_sqr(f, u, K) + + df = dmp_degree(f, u) + + if df < 0: + return f + + dg = dmp_degree(g, u) + + if dg < 0: + return g + + h, v = [], u - 1 + + for i in range(0, df + dg + 1): + coeff = dmp_zero(v) + + for j in range(max(0, i - dg), min(df, i) + 1): + coeff = dmp_add(coeff, dmp_mul(f[j], g[i - j], v, K), v, K) + + h.append(coeff) + + return dmp_strip(h, u) + + +def dup_sqr(f, K): + """ + Square dense polynomials in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_sqr(x**2 + 1) + x**4 + 2*x**2 + 1 + + """ + df, h = len(f) - 1, [] + + for i in range(0, 2*df + 1): + c = K.zero + + jmin = max(0, i - df) + jmax = min(i, df) + + n = jmax - jmin + 1 + + jmax = jmin + n // 2 - 1 + + for j in range(jmin, jmax + 1): + c += f[j]*f[i - j] + + c += c + + if n & 1: + elem = f[jmax + 1] + c += elem**2 + + h.append(c) + + return dup_strip(h) + + +def dmp_sqr(f, u, K): + """ + Square dense polynomials in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_sqr(x**2 + x*y + y**2) + x**4 + 2*x**3*y + 3*x**2*y**2 + 2*x*y**3 + y**4 + + """ + if not u: + return dup_sqr(f, K) + + df = dmp_degree(f, u) + + if df < 0: + return f + + h, v = [], u - 1 + + for i in range(0, 2*df + 1): + c = dmp_zero(v) + + jmin = max(0, i - df) + jmax = min(i, df) + + n = jmax - jmin + 1 + + jmax = jmin + n // 2 - 1 + + for j in range(jmin, jmax + 1): + c = dmp_add(c, dmp_mul(f[j], f[i - j], v, K), v, K) + + c = dmp_mul_ground(c, K(2), v, K) + + if n & 1: + elem = dmp_sqr(f[jmax + 1], v, K) + c = dmp_add(c, elem, v, K) + + h.append(c) + + return dmp_strip(h, u) + + +def dup_pow(f, n, K): + """ + Raise ``f`` to the ``n``-th power in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_pow(x - 2, 3) + x**3 - 6*x**2 + 12*x - 8 + + """ + if not n: + return [K.one] + if n < 0: + raise ValueError("Cannot raise polynomial to a negative power") + if n == 1 or not f or f == [K.one]: + return f + + g = [K.one] + + while True: + n, m = n//2, n + + if m % 2: + g = dup_mul(g, f, K) + + if not n: + break + + f = dup_sqr(f, K) + + return g + + +def dmp_pow(f, n, u, K): + """ + Raise ``f`` to the ``n``-th power in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_pow(x*y + 1, 3) + x**3*y**3 + 3*x**2*y**2 + 3*x*y + 1 + + """ + if not u: + return dup_pow(f, n, K) + + if not n: + return dmp_one(u, K) + if n < 0: + raise ValueError("Cannot raise polynomial to a negative power") + if n == 1 or dmp_zero_p(f, u) or dmp_one_p(f, u, K): + return f + + g = dmp_one(u, K) + + while True: + n, m = n//2, n + + if m & 1: + g = dmp_mul(g, f, u, K) + + if not n: + break + + f = dmp_sqr(f, u, K) + + return g + + +def dup_pdiv(f, g, K): + """ + Polynomial pseudo-division in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_pdiv(x**2 + 1, 2*x - 4) + (2*x + 4, 20) + + """ + df = dup_degree(f) + dg = dup_degree(g) + + q, r, dr = [], f, df + + if not g: + raise ZeroDivisionError("polynomial division") + elif df < dg: + return q, r + + N = df - dg + 1 + lc_g = dup_LC(g, K) + + while True: + lc_r = dup_LC(r, K) + j, N = dr - dg, N - 1 + + Q = dup_mul_ground(q, lc_g, K) + q = dup_add_term(Q, lc_r, j, K) + + R = dup_mul_ground(r, lc_g, K) + G = dup_mul_term(g, lc_r, j, K) + r = dup_sub(R, G, K) + + _dr, dr = dr, dup_degree(r) + + if dr < dg: + break + elif not (dr < _dr): + raise PolynomialDivisionFailed(f, g, K) + + c = lc_g**N + + q = dup_mul_ground(q, c, K) + r = dup_mul_ground(r, c, K) + + return q, r + + +def dup_prem(f, g, K): + """ + Polynomial pseudo-remainder in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_prem(x**2 + 1, 2*x - 4) + 20 + + """ + df = dup_degree(f) + dg = dup_degree(g) + + r, dr = f, df + + if not g: + raise ZeroDivisionError("polynomial division") + elif df < dg: + return r + + N = df - dg + 1 + lc_g = dup_LC(g, K) + + while True: + lc_r = dup_LC(r, K) + j, N = dr - dg, N - 1 + + R = dup_mul_ground(r, lc_g, K) + G = dup_mul_term(g, lc_r, j, K) + r = dup_sub(R, G, K) + + _dr, dr = dr, dup_degree(r) + + if dr < dg: + break + elif not (dr < _dr): + raise PolynomialDivisionFailed(f, g, K) + + return dup_mul_ground(r, lc_g**N, K) + + +def dup_pquo(f, g, K): + """ + Polynomial exact pseudo-quotient in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_pquo(x**2 - 1, 2*x - 2) + 2*x + 2 + + >>> R.dup_pquo(x**2 + 1, 2*x - 4) + 2*x + 4 + + """ + return dup_pdiv(f, g, K)[0] + + +def dup_pexquo(f, g, K): + """ + Polynomial pseudo-quotient in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_pexquo(x**2 - 1, 2*x - 2) + 2*x + 2 + + >>> R.dup_pexquo(x**2 + 1, 2*x - 4) + Traceback (most recent call last): + ... + ExactQuotientFailed: [2, -4] does not divide [1, 0, 1] + + """ + q, r = dup_pdiv(f, g, K) + + if not r: + return q + else: + raise ExactQuotientFailed(f, g) + + +def dmp_pdiv(f, g, u, K): + """ + Polynomial pseudo-division in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_pdiv(x**2 + x*y, 2*x + 2) + (2*x + 2*y - 2, -4*y + 4) + + """ + if not u: + return dup_pdiv(f, g, K) + + df = dmp_degree(f, u) + dg = dmp_degree(g, u) + + if dg < 0: + raise ZeroDivisionError("polynomial division") + + q, r, dr = dmp_zero(u), f, df + + if df < dg: + return q, r + + N = df - dg + 1 + lc_g = dmp_LC(g, K) + + while True: + lc_r = dmp_LC(r, K) + j, N = dr - dg, N - 1 + + Q = dmp_mul_term(q, lc_g, 0, u, K) + q = dmp_add_term(Q, lc_r, j, u, K) + + R = dmp_mul_term(r, lc_g, 0, u, K) + G = dmp_mul_term(g, lc_r, j, u, K) + r = dmp_sub(R, G, u, K) + + _dr, dr = dr, dmp_degree(r, u) + + if dr < dg: + break + elif not (dr < _dr): + raise PolynomialDivisionFailed(f, g, K) + + c = dmp_pow(lc_g, N, u - 1, K) + + q = dmp_mul_term(q, c, 0, u, K) + r = dmp_mul_term(r, c, 0, u, K) + + return q, r + + +def dmp_prem(f, g, u, K): + """ + Polynomial pseudo-remainder in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_prem(x**2 + x*y, 2*x + 2) + -4*y + 4 + + """ + if not u: + return dup_prem(f, g, K) + + df = dmp_degree(f, u) + dg = dmp_degree(g, u) + + if dg < 0: + raise ZeroDivisionError("polynomial division") + + r, dr = f, df + + if df < dg: + return r + + N = df - dg + 1 + lc_g = dmp_LC(g, K) + + while True: + lc_r = dmp_LC(r, K) + j, N = dr - dg, N - 1 + + R = dmp_mul_term(r, lc_g, 0, u, K) + G = dmp_mul_term(g, lc_r, j, u, K) + r = dmp_sub(R, G, u, K) + + _dr, dr = dr, dmp_degree(r, u) + + if dr < dg: + break + elif not (dr < _dr): + raise PolynomialDivisionFailed(f, g, K) + + c = dmp_pow(lc_g, N, u - 1, K) + + return dmp_mul_term(r, c, 0, u, K) + + +def dmp_pquo(f, g, u, K): + """ + Polynomial exact pseudo-quotient in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = x**2 + x*y + >>> g = 2*x + 2*y + >>> h = 2*x + 2 + + >>> R.dmp_pquo(f, g) + 2*x + + >>> R.dmp_pquo(f, h) + 2*x + 2*y - 2 + + """ + return dmp_pdiv(f, g, u, K)[0] + + +def dmp_pexquo(f, g, u, K): + """ + Polynomial pseudo-quotient in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = x**2 + x*y + >>> g = 2*x + 2*y + >>> h = 2*x + 2 + + >>> R.dmp_pexquo(f, g) + 2*x + + >>> R.dmp_pexquo(f, h) + Traceback (most recent call last): + ... + ExactQuotientFailed: [[2], [2]] does not divide [[1], [1, 0], []] + + """ + q, r = dmp_pdiv(f, g, u, K) + + if dmp_zero_p(r, u): + return q + else: + raise ExactQuotientFailed(f, g) + + +def dup_rr_div(f, g, K): + """ + Univariate division with remainder over a ring. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_rr_div(x**2 + 1, 2*x - 4) + (0, x**2 + 1) + + """ + df = dup_degree(f) + dg = dup_degree(g) + + q, r, dr = [], f, df + + if not g: + raise ZeroDivisionError("polynomial division") + elif df < dg: + return q, r + + lc_g = dup_LC(g, K) + + while True: + lc_r = dup_LC(r, K) + + if lc_r % lc_g: + break + + c = K.exquo(lc_r, lc_g) + j = dr - dg + + q = dup_add_term(q, c, j, K) + h = dup_mul_term(g, c, j, K) + r = dup_sub(r, h, K) + + _dr, dr = dr, dup_degree(r) + + if dr < dg: + break + elif not (dr < _dr): + raise PolynomialDivisionFailed(f, g, K) + + return q, r + + +def dmp_rr_div(f, g, u, K): + """ + Multivariate division with remainder over a ring. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_rr_div(x**2 + x*y, 2*x + 2) + (0, x**2 + x*y) + + """ + if not u: + return dup_rr_div(f, g, K) + + df = dmp_degree(f, u) + dg = dmp_degree(g, u) + + if dg < 0: + raise ZeroDivisionError("polynomial division") + + q, r, dr = dmp_zero(u), f, df + + if df < dg: + return q, r + + lc_g, v = dmp_LC(g, K), u - 1 + + while True: + lc_r = dmp_LC(r, K) + c, R = dmp_rr_div(lc_r, lc_g, v, K) + + if not dmp_zero_p(R, v): + break + + j = dr - dg + + q = dmp_add_term(q, c, j, u, K) + h = dmp_mul_term(g, c, j, u, K) + r = dmp_sub(r, h, u, K) + + _dr, dr = dr, dmp_degree(r, u) + + if dr < dg: + break + elif not (dr < _dr): + raise PolynomialDivisionFailed(f, g, K) + + return q, r + + +def dup_ff_div(f, g, K): + """ + Polynomial division with remainder over a field. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + >>> R.dup_ff_div(x**2 + 1, 2*x - 4) + (1/2*x + 1, 5) + + """ + df = dup_degree(f) + dg = dup_degree(g) + + q, r, dr = [], f, df + + if not g: + raise ZeroDivisionError("polynomial division") + elif df < dg: + return q, r + + lc_g = dup_LC(g, K) + + while True: + lc_r = dup_LC(r, K) + + c = K.exquo(lc_r, lc_g) + j = dr - dg + + q = dup_add_term(q, c, j, K) + h = dup_mul_term(g, c, j, K) + r = dup_sub(r, h, K) + + _dr, dr = dr, dup_degree(r) + + if dr < dg: + break + elif dr == _dr and not K.is_Exact: + # remove leading term created by rounding error + r = dup_strip(r[1:]) + dr = dup_degree(r) + if dr < dg: + break + elif not (dr < _dr): + raise PolynomialDivisionFailed(f, g, K) + + return q, r + + +def dmp_ff_div(f, g, u, K): + """ + Polynomial division with remainder over a field. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x,y = ring("x,y", QQ) + + >>> R.dmp_ff_div(x**2 + x*y, 2*x + 2) + (1/2*x + 1/2*y - 1/2, -y + 1) + + """ + if not u: + return dup_ff_div(f, g, K) + + df = dmp_degree(f, u) + dg = dmp_degree(g, u) + + if dg < 0: + raise ZeroDivisionError("polynomial division") + + q, r, dr = dmp_zero(u), f, df + + if df < dg: + return q, r + + lc_g, v = dmp_LC(g, K), u - 1 + + while True: + lc_r = dmp_LC(r, K) + c, R = dmp_ff_div(lc_r, lc_g, v, K) + + if not dmp_zero_p(R, v): + break + + j = dr - dg + + q = dmp_add_term(q, c, j, u, K) + h = dmp_mul_term(g, c, j, u, K) + r = dmp_sub(r, h, u, K) + + _dr, dr = dr, dmp_degree(r, u) + + if dr < dg: + break + elif not (dr < _dr): + raise PolynomialDivisionFailed(f, g, K) + + return q, r + + +def dup_div(f, g, K): + """ + Polynomial division with remainder in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x = ring("x", ZZ) + >>> R.dup_div(x**2 + 1, 2*x - 4) + (0, x**2 + 1) + + >>> R, x = ring("x", QQ) + >>> R.dup_div(x**2 + 1, 2*x - 4) + (1/2*x + 1, 5) + + """ + if K.is_Field: + return dup_ff_div(f, g, K) + else: + return dup_rr_div(f, g, K) + + +def dup_rem(f, g, K): + """ + Returns polynomial remainder in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x = ring("x", ZZ) + >>> R.dup_rem(x**2 + 1, 2*x - 4) + x**2 + 1 + + >>> R, x = ring("x", QQ) + >>> R.dup_rem(x**2 + 1, 2*x - 4) + 5 + + """ + return dup_div(f, g, K)[1] + + +def dup_quo(f, g, K): + """ + Returns exact polynomial quotient in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x = ring("x", ZZ) + >>> R.dup_quo(x**2 + 1, 2*x - 4) + 0 + + >>> R, x = ring("x", QQ) + >>> R.dup_quo(x**2 + 1, 2*x - 4) + 1/2*x + 1 + + """ + return dup_div(f, g, K)[0] + + +def dup_exquo(f, g, K): + """ + Returns polynomial quotient in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_exquo(x**2 - 1, x - 1) + x + 1 + + >>> R.dup_exquo(x**2 + 1, 2*x - 4) + Traceback (most recent call last): + ... + ExactQuotientFailed: [2, -4] does not divide [1, 0, 1] + + """ + q, r = dup_div(f, g, K) + + if not r: + return q + else: + raise ExactQuotientFailed(f, g) + + +def dmp_div(f, g, u, K): + """ + Polynomial division with remainder in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x,y = ring("x,y", ZZ) + >>> R.dmp_div(x**2 + x*y, 2*x + 2) + (0, x**2 + x*y) + + >>> R, x,y = ring("x,y", QQ) + >>> R.dmp_div(x**2 + x*y, 2*x + 2) + (1/2*x + 1/2*y - 1/2, -y + 1) + + """ + if K.is_Field: + return dmp_ff_div(f, g, u, K) + else: + return dmp_rr_div(f, g, u, K) + + +def dmp_rem(f, g, u, K): + """ + Returns polynomial remainder in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x,y = ring("x,y", ZZ) + >>> R.dmp_rem(x**2 + x*y, 2*x + 2) + x**2 + x*y + + >>> R, x,y = ring("x,y", QQ) + >>> R.dmp_rem(x**2 + x*y, 2*x + 2) + -y + 1 + + """ + return dmp_div(f, g, u, K)[1] + + +def dmp_quo(f, g, u, K): + """ + Returns exact polynomial quotient in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x,y = ring("x,y", ZZ) + >>> R.dmp_quo(x**2 + x*y, 2*x + 2) + 0 + + >>> R, x,y = ring("x,y", QQ) + >>> R.dmp_quo(x**2 + x*y, 2*x + 2) + 1/2*x + 1/2*y - 1/2 + + """ + return dmp_div(f, g, u, K)[0] + + +def dmp_exquo(f, g, u, K): + """ + Returns polynomial quotient in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = x**2 + x*y + >>> g = x + y + >>> h = 2*x + 2 + + >>> R.dmp_exquo(f, g) + x + + >>> R.dmp_exquo(f, h) + Traceback (most recent call last): + ... + ExactQuotientFailed: [[2], [2]] does not divide [[1], [1, 0], []] + + """ + q, r = dmp_div(f, g, u, K) + + if dmp_zero_p(r, u): + return q + else: + raise ExactQuotientFailed(f, g) + + +def dup_max_norm(f, K): + """ + Returns maximum norm of a polynomial in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_max_norm(-x**2 + 2*x - 3) + 3 + + """ + if not f: + return K.zero + else: + return max(dup_abs(f, K)) + + +def dmp_max_norm(f, u, K): + """ + Returns maximum norm of a polynomial in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_max_norm(2*x*y - x - 3) + 3 + + """ + if not u: + return dup_max_norm(f, K) + + v = u - 1 + + return max(dmp_max_norm(c, v, K) for c in f) + + +def dup_l1_norm(f, K): + """ + Returns l1 norm of a polynomial in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_l1_norm(2*x**3 - 3*x**2 + 1) + 6 + + """ + if not f: + return K.zero + else: + return sum(dup_abs(f, K)) + + +def dmp_l1_norm(f, u, K): + """ + Returns l1 norm of a polynomial in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_l1_norm(2*x*y - x - 3) + 6 + + """ + if not u: + return dup_l1_norm(f, K) + + v = u - 1 + + return sum(dmp_l1_norm(c, v, K) for c in f) + + +def dup_l2_norm_squared(f, K): + """ + Returns squared l2 norm of a polynomial in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_l2_norm_squared(2*x**3 - 3*x**2 + 1) + 14 + + """ + return sum([coeff**2 for coeff in f], K.zero) + + +def dmp_l2_norm_squared(f, u, K): + """ + Returns squared l2 norm of a polynomial in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_l2_norm_squared(2*x*y - x - 3) + 14 + + """ + if not u: + return dup_l2_norm_squared(f, K) + + v = u - 1 + + return sum(dmp_l2_norm_squared(c, v, K) for c in f) + + +def dup_expand(polys, K): + """ + Multiply together several polynomials in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_expand([x**2 - 1, x, 2]) + 2*x**3 - 2*x + + """ + if not polys: + return [K.one] + + f = polys[0] + + for g in polys[1:]: + f = dup_mul(f, g, K) + + return f + + +def dmp_expand(polys, u, K): + """ + Multiply together several polynomials in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_expand([x**2 + y**2, x + 1]) + x**3 + x**2 + x*y**2 + y**2 + + """ + if not polys: + return dmp_one(u, K) + + f = polys[0] + + for g in polys[1:]: + f = dmp_mul(f, g, u, K) + + return f diff --git a/phivenv/Lib/site-packages/sympy/polys/densebasic.py b/phivenv/Lib/site-packages/sympy/polys/densebasic.py new file mode 100644 index 0000000000000000000000000000000000000000..b3a8a9497302b1af5bca20de100b7ae41e96b439 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/densebasic.py @@ -0,0 +1,1887 @@ +"""Basic tools for dense recursive polynomials in ``K[x]`` or ``K[X]``. """ + + +from sympy.core import igcd +from sympy.polys.monomials import monomial_min, monomial_div +from sympy.polys.orderings import monomial_key + +import random + + +ninf = float('-inf') + + +def poly_LC(f, K): + """ + Return leading coefficient of ``f``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import poly_LC + + >>> poly_LC([], ZZ) + 0 + >>> poly_LC([ZZ(1), ZZ(2), ZZ(3)], ZZ) + 1 + + """ + if not f: + return K.zero + else: + return f[0] + + +def poly_TC(f, K): + """ + Return trailing coefficient of ``f``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import poly_TC + + >>> poly_TC([], ZZ) + 0 + >>> poly_TC([ZZ(1), ZZ(2), ZZ(3)], ZZ) + 3 + + """ + if not f: + return K.zero + else: + return f[-1] + +dup_LC = dmp_LC = poly_LC +dup_TC = dmp_TC = poly_TC + + +def dmp_ground_LC(f, u, K): + """ + Return the ground leading coefficient. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_ground_LC + + >>> f = ZZ.map([[[1], [2, 3]]]) + + >>> dmp_ground_LC(f, 2, ZZ) + 1 + + """ + while u: + f = dmp_LC(f, K) + u -= 1 + + return dup_LC(f, K) + + +def dmp_ground_TC(f, u, K): + """ + Return the ground trailing coefficient. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_ground_TC + + >>> f = ZZ.map([[[1], [2, 3]]]) + + >>> dmp_ground_TC(f, 2, ZZ) + 3 + + """ + while u: + f = dmp_TC(f, K) + u -= 1 + + return dup_TC(f, K) + + +def dmp_true_LT(f, u, K): + """ + Return the leading term ``c * x_1**n_1 ... x_k**n_k``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_true_LT + + >>> f = ZZ.map([[4], [2, 0], [3, 0, 0]]) + + >>> dmp_true_LT(f, 1, ZZ) + ((2, 0), 4) + + """ + monom = [] + + while u: + monom.append(len(f) - 1) + f, u = f[0], u - 1 + + if not f: + monom.append(0) + else: + monom.append(len(f) - 1) + + return tuple(monom), dup_LC(f, K) + + +def dup_degree(f): + """ + Return the leading degree of ``f`` in ``K[x]``. + + Note that the degree of 0 is negative infinity (``float('-inf')``). + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_degree + + >>> f = ZZ.map([1, 2, 0, 3]) + + >>> dup_degree(f) + 3 + + """ + if not f: + return ninf + return len(f) - 1 + + +def dmp_degree(f, u): + """ + Return the leading degree of ``f`` in ``x_0`` in ``K[X]``. + + Note that the degree of 0 is negative infinity (``float('-inf')``). + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_degree + + >>> dmp_degree([[[]]], 2) + -inf + + >>> f = ZZ.map([[2], [1, 2, 3]]) + + >>> dmp_degree(f, 1) + 1 + + """ + if dmp_zero_p(f, u): + return ninf + else: + return len(f) - 1 + + +def _rec_degree_in(g, v, i, j): + """Recursive helper function for :func:`dmp_degree_in`.""" + if i == j: + return dmp_degree(g, v) + + v, i = v - 1, i + 1 + + return max(_rec_degree_in(c, v, i, j) for c in g) + + +def dmp_degree_in(f, j, u): + """ + Return the leading degree of ``f`` in ``x_j`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_degree_in + + >>> f = ZZ.map([[2], [1, 2, 3]]) + + >>> dmp_degree_in(f, 0, 1) + 1 + >>> dmp_degree_in(f, 1, 1) + 2 + + """ + if not j: + return dmp_degree(f, u) + if j < 0 or j > u: + raise IndexError("0 <= j <= %s expected, got %s" % (u, j)) + + return _rec_degree_in(f, u, 0, j) + + +def _rec_degree_list(g, v, i, degs): + """Recursive helper for :func:`dmp_degree_list`.""" + degs[i] = max(degs[i], dmp_degree(g, v)) + + if v > 0: + v, i = v - 1, i + 1 + + for c in g: + _rec_degree_list(c, v, i, degs) + + +def dmp_degree_list(f, u): + """ + Return a list of degrees of ``f`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_degree_list + + >>> f = ZZ.map([[1], [1, 2, 3]]) + + >>> dmp_degree_list(f, 1) + (1, 2) + + """ + degs = [ninf]*(u + 1) + _rec_degree_list(f, u, 0, degs) + return tuple(degs) + + +def dup_strip(f): + """ + Remove leading zeros from ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys.densebasic import dup_strip + + >>> dup_strip([0, 0, 1, 2, 3, 0]) + [1, 2, 3, 0] + + """ + if not f or f[0]: + return f + + i = 0 + + for cf in f: + if cf: + break + else: + i += 1 + + return f[i:] + + +def dmp_strip(f, u): + """ + Remove leading zeros from ``f`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys.densebasic import dmp_strip + + >>> dmp_strip([[], [0, 1, 2], [1]], 1) + [[0, 1, 2], [1]] + + """ + if not u: + return dup_strip(f) + + if dmp_zero_p(f, u): + return f + + i, v = 0, u - 1 + + for c in f: + if not dmp_zero_p(c, v): + break + else: + i += 1 + + if i == len(f): + return dmp_zero(u) + else: + return f[i:] + + +def _rec_validate(f, g, i, K): + """Recursive helper for :func:`dmp_validate`.""" + if not isinstance(g, list): + if K is not None and not K.of_type(g): + raise TypeError("%s in %s in not of type %s" % (g, f, K.dtype)) + + return {i - 1} + elif not g: + return {i} + else: + levels = set() + + for c in g: + levels |= _rec_validate(f, c, i + 1, K) + + return levels + + +def _rec_strip(g, v): + """Recursive helper for :func:`_rec_strip`.""" + if not v: + return dup_strip(g) + + w = v - 1 + + return dmp_strip([ _rec_strip(c, w) for c in g ], v) + + +def dmp_validate(f, K=None): + """ + Return the number of levels in ``f`` and recursively strip it. + + Examples + ======== + + >>> from sympy.polys.densebasic import dmp_validate + + >>> dmp_validate([[], [0, 1, 2], [1]]) + ([[1, 2], [1]], 1) + + >>> dmp_validate([[1], 1]) + Traceback (most recent call last): + ... + ValueError: invalid data structure for a multivariate polynomial + + """ + levels = _rec_validate(f, f, 0, K) + + u = levels.pop() + + if not levels: + return _rec_strip(f, u), u + else: + raise ValueError( + "invalid data structure for a multivariate polynomial") + + +def dup_reverse(f): + """ + Compute ``x**n * f(1/x)``, i.e.: reverse ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_reverse + + >>> f = ZZ.map([1, 2, 3, 0]) + + >>> dup_reverse(f) + [3, 2, 1] + + """ + return dup_strip(list(reversed(f))) + + +def dup_copy(f): + """ + Create a new copy of a polynomial ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_copy + + >>> f = ZZ.map([1, 2, 3, 0]) + + >>> dup_copy([1, 2, 3, 0]) + [1, 2, 3, 0] + + """ + return list(f) + + +def dmp_copy(f, u): + """ + Create a new copy of a polynomial ``f`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_copy + + >>> f = ZZ.map([[1], [1, 2]]) + + >>> dmp_copy(f, 1) + [[1], [1, 2]] + + """ + if not u: + return list(f) + + v = u - 1 + + return [ dmp_copy(c, v) for c in f ] + + +def dup_to_tuple(f): + """ + Convert `f` into a tuple. + + This is needed for hashing. This is similar to dup_copy(). + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_copy + + >>> f = ZZ.map([1, 2, 3, 0]) + + >>> dup_copy([1, 2, 3, 0]) + [1, 2, 3, 0] + + """ + return tuple(f) + + +def dmp_to_tuple(f, u): + """ + Convert `f` into a nested tuple of tuples. + + This is needed for hashing. This is similar to dmp_copy(). + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_to_tuple + + >>> f = ZZ.map([[1], [1, 2]]) + + >>> dmp_to_tuple(f, 1) + ((1,), (1, 2)) + + """ + if not u: + return tuple(f) + v = u - 1 + + return tuple(dmp_to_tuple(c, v) for c in f) + + +def dup_normal(f, K): + """ + Normalize univariate polynomial in the given domain. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_normal + + >>> dup_normal([0, 1, 2, 3], ZZ) + [1, 2, 3] + + """ + return dup_strip([ K.normal(c) for c in f ]) + + +def dmp_normal(f, u, K): + """ + Normalize a multivariate polynomial in the given domain. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_normal + + >>> dmp_normal([[], [0, 1, 2]], 1, ZZ) + [[1, 2]] + + """ + if not u: + return dup_normal(f, K) + + v = u - 1 + + return dmp_strip([ dmp_normal(c, v, K) for c in f ], u) + + +def dup_convert(f, K0, K1): + """ + Convert the ground domain of ``f`` from ``K0`` to ``K1``. + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_convert + + >>> R, x = ring("x", ZZ) + + >>> dup_convert([R(1), R(2)], R.to_domain(), ZZ) + [1, 2] + >>> dup_convert([ZZ(1), ZZ(2)], ZZ, R.to_domain()) + [1, 2] + + """ + if K0 is not None and K0 == K1: + return f + else: + return dup_strip([ K1.convert(c, K0) for c in f ]) + + +def dmp_convert(f, u, K0, K1): + """ + Convert the ground domain of ``f`` from ``K0`` to ``K1``. + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_convert + + >>> R, x = ring("x", ZZ) + + >>> dmp_convert([[R(1)], [R(2)]], 1, R.to_domain(), ZZ) + [[1], [2]] + >>> dmp_convert([[ZZ(1)], [ZZ(2)]], 1, ZZ, R.to_domain()) + [[1], [2]] + + """ + if not u: + return dup_convert(f, K0, K1) + if K0 is not None and K0 == K1: + return f + + v = u - 1 + + return dmp_strip([ dmp_convert(c, v, K0, K1) for c in f ], u) + + +def dup_from_sympy(f, K): + """ + Convert the ground domain of ``f`` from SymPy to ``K``. + + Examples + ======== + + >>> from sympy import S + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_from_sympy + + >>> dup_from_sympy([S(1), S(2)], ZZ) == [ZZ(1), ZZ(2)] + True + + """ + return dup_strip([ K.from_sympy(c) for c in f ]) + + +def dmp_from_sympy(f, u, K): + """ + Convert the ground domain of ``f`` from SymPy to ``K``. + + Examples + ======== + + >>> from sympy import S + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_from_sympy + + >>> dmp_from_sympy([[S(1)], [S(2)]], 1, ZZ) == [[ZZ(1)], [ZZ(2)]] + True + + """ + if not u: + return dup_from_sympy(f, K) + + v = u - 1 + + return dmp_strip([ dmp_from_sympy(c, v, K) for c in f ], u) + + +def dup_nth(f, n, K): + """ + Return the ``n``-th coefficient of ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_nth + + >>> f = ZZ.map([1, 2, 3]) + + >>> dup_nth(f, 0, ZZ) + 3 + >>> dup_nth(f, 4, ZZ) + 0 + + """ + if n < 0: + raise IndexError("'n' must be non-negative, got %i" % n) + elif n >= len(f): + return K.zero + else: + return f[dup_degree(f) - n] + + +def dmp_nth(f, n, u, K): + """ + Return the ``n``-th coefficient of ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_nth + + >>> f = ZZ.map([[1], [2], [3]]) + + >>> dmp_nth(f, 0, 1, ZZ) + [3] + >>> dmp_nth(f, 4, 1, ZZ) + [] + + """ + if n < 0: + raise IndexError("'n' must be non-negative, got %i" % n) + elif n >= len(f): + return dmp_zero(u - 1) + else: + return f[dmp_degree(f, u) - n] + + +def dmp_ground_nth(f, N, u, K): + """ + Return the ground ``n``-th coefficient of ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_ground_nth + + >>> f = ZZ.map([[1], [2, 3]]) + + >>> dmp_ground_nth(f, (0, 1), 1, ZZ) + 2 + + """ + v = u + + for n in N: + if n < 0: + raise IndexError("`n` must be non-negative, got %i" % n) + elif n >= len(f): + return K.zero + else: + d = dmp_degree(f, v) + if d == ninf: + d = -1 + f, v = f[d - n], v - 1 + + return f + + +def dmp_zero_p(f, u): + """ + Return ``True`` if ``f`` is zero in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys.densebasic import dmp_zero_p + + >>> dmp_zero_p([[[[[]]]]], 4) + True + >>> dmp_zero_p([[[[[1]]]]], 4) + False + + """ + while u: + if len(f) != 1: + return False + + f = f[0] + u -= 1 + + return not f + + +def dmp_zero(u): + """ + Return a multivariate zero. + + Examples + ======== + + >>> from sympy.polys.densebasic import dmp_zero + + >>> dmp_zero(4) + [[[[[]]]]] + + """ + r = [] + + for i in range(u): + r = [r] + + return r + + +def dmp_one_p(f, u, K): + """ + Return ``True`` if ``f`` is one in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_one_p + + >>> dmp_one_p([[[ZZ(1)]]], 2, ZZ) + True + + """ + return dmp_ground_p(f, K.one, u) + + +def dmp_one(u, K): + """ + Return a multivariate one over ``K``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_one + + >>> dmp_one(2, ZZ) + [[[1]]] + + """ + return dmp_ground(K.one, u) + + +def dmp_ground_p(f, c, u): + """ + Return True if ``f`` is constant in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys.densebasic import dmp_ground_p + + >>> dmp_ground_p([[[3]]], 3, 2) + True + >>> dmp_ground_p([[[4]]], None, 2) + True + + """ + if c is not None and not c: + return dmp_zero_p(f, u) + + while u: + if len(f) != 1: + return False + f = f[0] + u -= 1 + + if c is None: + return len(f) <= 1 + else: + return f == [c] + + +def dmp_ground(c, u): + """ + Return a multivariate constant. + + Examples + ======== + + >>> from sympy.polys.densebasic import dmp_ground + + >>> dmp_ground(3, 5) + [[[[[[3]]]]]] + >>> dmp_ground(1, -1) + 1 + + """ + if not c: + return dmp_zero(u) + + for i in range(u + 1): + c = [c] + + return c + + +def dmp_zeros(n, u, K): + """ + Return a list of multivariate zeros. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_zeros + + >>> dmp_zeros(3, 2, ZZ) + [[[[]]], [[[]]], [[[]]]] + >>> dmp_zeros(3, -1, ZZ) + [0, 0, 0] + + """ + if not n: + return [] + + if u < 0: + return [K.zero]*n + else: + return [ dmp_zero(u) for i in range(n) ] + + +def dmp_grounds(c, n, u): + """ + Return a list of multivariate constants. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_grounds + + >>> dmp_grounds(ZZ(4), 3, 2) + [[[[4]]], [[[4]]], [[[4]]]] + >>> dmp_grounds(ZZ(4), 3, -1) + [4, 4, 4] + + """ + if not n: + return [] + + if u < 0: + return [c]*n + else: + return [ dmp_ground(c, u) for i in range(n) ] + + +def dmp_negative_p(f, u, K): + """ + Return ``True`` if ``LC(f)`` is negative. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_negative_p + + >>> dmp_negative_p([[ZZ(1)], [-ZZ(1)]], 1, ZZ) + False + >>> dmp_negative_p([[-ZZ(1)], [ZZ(1)]], 1, ZZ) + True + + """ + return K.is_negative(dmp_ground_LC(f, u, K)) + + +def dmp_positive_p(f, u, K): + """ + Return ``True`` if ``LC(f)`` is positive. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_positive_p + + >>> dmp_positive_p([[ZZ(1)], [-ZZ(1)]], 1, ZZ) + True + >>> dmp_positive_p([[-ZZ(1)], [ZZ(1)]], 1, ZZ) + False + + """ + return K.is_positive(dmp_ground_LC(f, u, K)) + + +def dup_from_dict(f, K): + """ + Create a ``K[x]`` polynomial from a ``dict``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_from_dict + + >>> dup_from_dict({(0,): ZZ(7), (2,): ZZ(5), (4,): ZZ(1)}, ZZ) + [1, 0, 5, 0, 7] + >>> dup_from_dict({}, ZZ) + [] + + """ + if not f: + return [] + + n, h = max(f.keys()), [] + + if isinstance(n, int): + for k in range(n, -1, -1): + h.append(f.get(k, K.zero)) + else: + (n,) = n + + for k in range(n, -1, -1): + h.append(f.get((k,), K.zero)) + + return dup_strip(h) + + +def dup_from_raw_dict(f, K): + """ + Create a ``K[x]`` polynomial from a raw ``dict``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_from_raw_dict + + >>> dup_from_raw_dict({0: ZZ(7), 2: ZZ(5), 4: ZZ(1)}, ZZ) + [1, 0, 5, 0, 7] + + """ + if not f: + return [] + + n, h = max(f.keys()), [] + + for k in range(n, -1, -1): + h.append(f.get(k, K.zero)) + + return dup_strip(h) + + +def dmp_from_dict(f, u, K): + """ + Create a ``K[X]`` polynomial from a ``dict``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_from_dict + + >>> dmp_from_dict({(0, 0): ZZ(3), (0, 1): ZZ(2), (2, 1): ZZ(1)}, 1, ZZ) + [[1, 0], [], [2, 3]] + >>> dmp_from_dict({}, 0, ZZ) + [] + + """ + if not u: + return dup_from_dict(f, K) + if not f: + return dmp_zero(u) + + coeffs = {} + + for monom, coeff in f.items(): + head, tail = monom[0], monom[1:] + + if head in coeffs: + coeffs[head][tail] = coeff + else: + coeffs[head] = { tail: coeff } + + n, v, h = max(coeffs.keys()), u - 1, [] + + for k in range(n, -1, -1): + coeff = coeffs.get(k) + + if coeff is not None: + h.append(dmp_from_dict(coeff, v, K)) + else: + h.append(dmp_zero(v)) + + return dmp_strip(h, u) + + +def dup_to_dict(f, K=None, zero=False): + """ + Convert ``K[x]`` polynomial to a ``dict``. + + Examples + ======== + + >>> from sympy.polys.densebasic import dup_to_dict + + >>> dup_to_dict([1, 0, 5, 0, 7]) + {(0,): 7, (2,): 5, (4,): 1} + >>> dup_to_dict([]) + {} + + """ + if not f and zero: + return {(0,): K.zero} + + n, result = len(f) - 1, {} + + for k in range(0, n + 1): + if f[n - k]: + result[(k,)] = f[n - k] + + return result + + +def dup_to_raw_dict(f, K=None, zero=False): + """ + Convert a ``K[x]`` polynomial to a raw ``dict``. + + Examples + ======== + + >>> from sympy.polys.densebasic import dup_to_raw_dict + + >>> dup_to_raw_dict([1, 0, 5, 0, 7]) + {0: 7, 2: 5, 4: 1} + + """ + if not f and zero: + return {0: K.zero} + + n, result = len(f) - 1, {} + + for k in range(0, n + 1): + if f[n - k]: + result[k] = f[n - k] + + return result + + +def dmp_to_dict(f, u, K=None, zero=False): + """ + Convert a ``K[X]`` polynomial to a ``dict````. + + Examples + ======== + + >>> from sympy.polys.densebasic import dmp_to_dict + + >>> dmp_to_dict([[1, 0], [], [2, 3]], 1) + {(0, 0): 3, (0, 1): 2, (2, 1): 1} + >>> dmp_to_dict([], 0) + {} + + """ + if not u: + return dup_to_dict(f, K, zero=zero) + + if dmp_zero_p(f, u) and zero: + return {(0,)*(u + 1): K.zero} + + n, v, result = dmp_degree(f, u), u - 1, {} + + if n == ninf: + n = -1 + + for k in range(0, n + 1): + h = dmp_to_dict(f[n - k], v) + + for exp, coeff in h.items(): + result[(k,) + exp] = coeff + + return result + + +def dmp_swap(f, i, j, u, K): + """ + Transform ``K[..x_i..x_j..]`` to ``K[..x_j..x_i..]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_swap + + >>> f = ZZ.map([[[2], [1, 0]], []]) + + >>> dmp_swap(f, 0, 1, 2, ZZ) + [[[2], []], [[1, 0], []]] + >>> dmp_swap(f, 1, 2, 2, ZZ) + [[[1], [2, 0]], [[]]] + >>> dmp_swap(f, 0, 2, 2, ZZ) + [[[1, 0]], [[2, 0], []]] + + """ + if i < 0 or j < 0 or i > u or j > u: + raise IndexError("0 <= i < j <= %s expected" % u) + elif i == j: + return f + + F, H = dmp_to_dict(f, u), {} + + for exp, coeff in F.items(): + H[exp[:i] + (exp[j],) + + exp[i + 1:j] + + (exp[i],) + exp[j + 1:]] = coeff + + return dmp_from_dict(H, u, K) + + +def dmp_permute(f, P, u, K): + """ + Return a polynomial in ``K[x_{P(1)},..,x_{P(n)}]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_permute + + >>> f = ZZ.map([[[2], [1, 0]], []]) + + >>> dmp_permute(f, [1, 0, 2], 2, ZZ) + [[[2], []], [[1, 0], []]] + >>> dmp_permute(f, [1, 2, 0], 2, ZZ) + [[[1], []], [[2, 0], []]] + + """ + F, H = dmp_to_dict(f, u), {} + + for exp, coeff in F.items(): + new_exp = [0]*len(exp) + + for e, p in zip(exp, P): + new_exp[p] = e + + H[tuple(new_exp)] = coeff + + return dmp_from_dict(H, u, K) + + +def dmp_nest(f, l, K): + """ + Return a multivariate value nested ``l``-levels. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_nest + + >>> dmp_nest([[ZZ(1)]], 2, ZZ) + [[[[1]]]] + + """ + if not isinstance(f, list): + return dmp_ground(f, l) + + for i in range(l): + f = [f] + + return f + + +def dmp_raise(f, l, u, K): + """ + Return a multivariate polynomial raised ``l``-levels. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_raise + + >>> f = ZZ.map([[], [1, 2]]) + + >>> dmp_raise(f, 2, 1, ZZ) + [[[[]]], [[[1]], [[2]]]] + + """ + if not l: + return f + + if not u: + if not f: + return dmp_zero(l) + + k = l - 1 + + return [ dmp_ground(c, k) for c in f ] + + v = u - 1 + + return [ dmp_raise(c, l, v, K) for c in f ] + + +def dup_deflate(f, K): + """ + Map ``x**m`` to ``y`` in a polynomial in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_deflate + + >>> f = ZZ.map([1, 0, 0, 1, 0, 0, 1]) + + >>> dup_deflate(f, ZZ) + (3, [1, 1, 1]) + + """ + if dup_degree(f) <= 0: + return 1, f + + g = 0 + + for i in range(len(f)): + if not f[-i - 1]: + continue + + g = igcd(g, i) + + if g == 1: + return 1, f + + return g, f[::g] + + +def dmp_deflate(f, u, K): + """ + Map ``x_i**m_i`` to ``y_i`` in a polynomial in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_deflate + + >>> f = ZZ.map([[1, 0, 0, 2], [], [3, 0, 0, 4]]) + + >>> dmp_deflate(f, 1, ZZ) + ((2, 3), [[1, 2], [3, 4]]) + + """ + if dmp_zero_p(f, u): + return (1,)*(u + 1), f + + F = dmp_to_dict(f, u) + B = [0]*(u + 1) + + for M in F.keys(): + for i, m in enumerate(M): + B[i] = igcd(B[i], m) + + for i, b in enumerate(B): + if not b: + B[i] = 1 + + B = tuple(B) + + if all(b == 1 for b in B): + return B, f + + H = {} + + for A, coeff in F.items(): + N = [ a // b for a, b in zip(A, B) ] + H[tuple(N)] = coeff + + return B, dmp_from_dict(H, u, K) + + +def dup_multi_deflate(polys, K): + """ + Map ``x**m`` to ``y`` in a set of polynomials in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_multi_deflate + + >>> f = ZZ.map([1, 0, 2, 0, 3]) + >>> g = ZZ.map([4, 0, 0]) + + >>> dup_multi_deflate((f, g), ZZ) + (2, ([1, 2, 3], [4, 0])) + + """ + G = 0 + + for p in polys: + if dup_degree(p) <= 0: + return 1, polys + + g = 0 + + for i in range(len(p)): + if not p[-i - 1]: + continue + + g = igcd(g, i) + + if g == 1: + return 1, polys + + G = igcd(G, g) + + return G, tuple([ p[::G] for p in polys ]) + + +def dmp_multi_deflate(polys, u, K): + """ + Map ``x_i**m_i`` to ``y_i`` in a set of polynomials in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_multi_deflate + + >>> f = ZZ.map([[1, 0, 0, 2], [], [3, 0, 0, 4]]) + >>> g = ZZ.map([[1, 0, 2], [], [3, 0, 4]]) + + >>> dmp_multi_deflate((f, g), 1, ZZ) + ((2, 1), ([[1, 0, 0, 2], [3, 0, 0, 4]], [[1, 0, 2], [3, 0, 4]])) + + """ + if not u: + M, H = dup_multi_deflate(polys, K) + return (M,), H + + F, B = [], [0]*(u + 1) + + for p in polys: + f = dmp_to_dict(p, u) + + if not dmp_zero_p(p, u): + for M in f.keys(): + for i, m in enumerate(M): + B[i] = igcd(B[i], m) + + F.append(f) + + for i, b in enumerate(B): + if not b: + B[i] = 1 + + B = tuple(B) + + if all(b == 1 for b in B): + return B, polys + + H = [] + + for f in F: + h = {} + + for A, coeff in f.items(): + N = [ a // b for a, b in zip(A, B) ] + h[tuple(N)] = coeff + + H.append(dmp_from_dict(h, u, K)) + + return B, tuple(H) + + +def dup_inflate(f, m, K): + """ + Map ``y`` to ``x**m`` in a polynomial in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_inflate + + >>> f = ZZ.map([1, 1, 1]) + + >>> dup_inflate(f, 3, ZZ) + [1, 0, 0, 1, 0, 0, 1] + + """ + if m <= 0: + raise IndexError("'m' must be positive, got %s" % m) + if m == 1 or not f: + return f + + result = [f[0]] + + for coeff in f[1:]: + result.extend([K.zero]*(m - 1)) + result.append(coeff) + + return result + + +def _rec_inflate(g, M, v, i, K): + """Recursive helper for :func:`dmp_inflate`.""" + if not v: + return dup_inflate(g, M[i], K) + if M[i] <= 0: + raise IndexError("all M[i] must be positive, got %s" % M[i]) + + w, j = v - 1, i + 1 + + g = [ _rec_inflate(c, M, w, j, K) for c in g ] + + result = [g[0]] + + for coeff in g[1:]: + for _ in range(1, M[i]): + result.append(dmp_zero(w)) + + result.append(coeff) + + return result + + +def dmp_inflate(f, M, u, K): + """ + Map ``y_i`` to ``x_i**k_i`` in a polynomial in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_inflate + + >>> f = ZZ.map([[1, 2], [3, 4]]) + + >>> dmp_inflate(f, (2, 3), 1, ZZ) + [[1, 0, 0, 2], [], [3, 0, 0, 4]] + + """ + if not u: + return dup_inflate(f, M[0], K) + + if all(m == 1 for m in M): + return f + else: + return _rec_inflate(f, M, u, 0, K) + + +def dmp_exclude(f, u, K): + """ + Exclude useless levels from ``f``. + + Return the levels excluded, the new excluded ``f``, and the new ``u``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_exclude + + >>> f = ZZ.map([[[1]], [[1], [2]]]) + + >>> dmp_exclude(f, 2, ZZ) + ([2], [[1], [1, 2]], 1) + + """ + if not u or dmp_ground_p(f, None, u): + return [], f, u + + J, F = [], dmp_to_dict(f, u) + + for j in range(0, u + 1): + for monom in F.keys(): + if monom[j]: + break + else: + J.append(j) + + if not J: + return [], f, u + + f = {} + + for monom, coeff in F.items(): + monom = list(monom) + + for j in reversed(J): + del monom[j] + + f[tuple(monom)] = coeff + + u -= len(J) + + return J, dmp_from_dict(f, u, K), u + + +def dmp_include(f, J, u, K): + """ + Include useless levels in ``f``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_include + + >>> f = ZZ.map([[1], [1, 2]]) + + >>> dmp_include(f, [2], 1, ZZ) + [[[1]], [[1], [2]]] + + """ + if not J: + return f + + F, f = dmp_to_dict(f, u), {} + + for monom, coeff in F.items(): + monom = list(monom) + + for j in J: + monom.insert(j, 0) + + f[tuple(monom)] = coeff + + u += len(J) + + return dmp_from_dict(f, u, K) + + +def dmp_inject(f, u, K, front=False): + """ + Convert ``f`` from ``K[X][Y]`` to ``K[X,Y]``. + + Examples + ======== + + >>> from sympy.polys.rings import ring + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_inject + + >>> R, x,y = ring("x,y", ZZ) + + >>> dmp_inject([R(1), x + 2], 0, R.to_domain()) + ([[[1]], [[1], [2]]], 2) + >>> dmp_inject([R(1), x + 2], 0, R.to_domain(), front=True) + ([[[1]], [[1, 2]]], 2) + + """ + f, h = dmp_to_dict(f, u), {} + + v = K.ngens - 1 + + for f_monom, g in f.items(): + g = g.to_dict() + + for g_monom, c in g.items(): + if front: + h[g_monom + f_monom] = c + else: + h[f_monom + g_monom] = c + + w = u + v + 1 + + return dmp_from_dict(h, w, K.dom), w + + +def dmp_eject(f, u, K, front=False): + """ + Convert ``f`` from ``K[X,Y]`` to ``K[X][Y]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_eject + + >>> dmp_eject([[[1]], [[1], [2]]], 2, ZZ['x', 'y']) + [1, x + 2] + + """ + f, h = dmp_to_dict(f, u), {} + + n = K.ngens + v = u - K.ngens + 1 + + for monom, c in f.items(): + if front: + g_monom, f_monom = monom[:n], monom[n:] + else: + g_monom, f_monom = monom[-n:], monom[:-n] + + if f_monom in h: + h[f_monom][g_monom] = c + else: + h[f_monom] = {g_monom: c} + + for monom, c in h.items(): + h[monom] = K(c) + + return dmp_from_dict(h, v - 1, K) + + +def dup_terms_gcd(f, K): + """ + Remove GCD of terms from ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_terms_gcd + + >>> f = ZZ.map([1, 0, 1, 0, 0]) + + >>> dup_terms_gcd(f, ZZ) + (2, [1, 0, 1]) + + """ + if dup_TC(f, K) or not f: + return 0, f + + i = 0 + + for c in reversed(f): + if not c: + i += 1 + else: + break + + return i, f[:-i] + + +def dmp_terms_gcd(f, u, K): + """ + Remove GCD of terms from ``f`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_terms_gcd + + >>> f = ZZ.map([[1, 0], [1, 0, 0], [], []]) + + >>> dmp_terms_gcd(f, 1, ZZ) + ((2, 1), [[1], [1, 0]]) + + """ + if dmp_ground_TC(f, u, K) or dmp_zero_p(f, u): + return (0,)*(u + 1), f + + F = dmp_to_dict(f, u) + G = monomial_min(*list(F.keys())) + + if all(g == 0 for g in G): + return G, f + + f = {} + + for monom, coeff in F.items(): + f[monomial_div(monom, G)] = coeff + + return G, dmp_from_dict(f, u, K) + + +def _rec_list_terms(g, v, monom): + """Recursive helper for :func:`dmp_list_terms`.""" + d, terms = dmp_degree(g, v), [] + + if not v: + for i, c in enumerate(g): + if not c: + continue + + terms.append((monom + (d - i,), c)) + else: + w = v - 1 + + for i, c in enumerate(g): + terms.extend(_rec_list_terms(c, w, monom + (d - i,))) + + return terms + + +def dmp_list_terms(f, u, K, order=None): + """ + List all non-zero terms from ``f`` in the given order ``order``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_list_terms + + >>> f = ZZ.map([[1, 1], [2, 3]]) + + >>> dmp_list_terms(f, 1, ZZ) + [((1, 1), 1), ((1, 0), 1), ((0, 1), 2), ((0, 0), 3)] + >>> dmp_list_terms(f, 1, ZZ, order='grevlex') + [((1, 1), 1), ((1, 0), 1), ((0, 1), 2), ((0, 0), 3)] + + """ + def sort(terms, O): + return sorted(terms, key=lambda term: O(term[0]), reverse=True) + + terms = _rec_list_terms(f, u, ()) + + if not terms: + return [((0,)*(u + 1), K.zero)] + + if order is None: + return terms + else: + return sort(terms, monomial_key(order)) + + +def dup_apply_pairs(f, g, h, args, K): + """ + Apply ``h`` to pairs of coefficients of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_apply_pairs + + >>> h = lambda x, y, z: 2*x + y - z + + >>> dup_apply_pairs([1, 2, 3], [3, 2, 1], h, (1,), ZZ) + [4, 5, 6] + + """ + n, m = len(f), len(g) + + if n != m: + if n > m: + g = [K.zero]*(n - m) + g + else: + f = [K.zero]*(m - n) + f + + result = [] + + for a, b in zip(f, g): + result.append(h(a, b, *args)) + + return dup_strip(result) + + +def dmp_apply_pairs(f, g, h, args, u, K): + """ + Apply ``h`` to pairs of coefficients of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dmp_apply_pairs + + >>> h = lambda x, y, z: 2*x + y - z + + >>> dmp_apply_pairs([[1], [2, 3]], [[3], [2, 1]], h, (1,), 1, ZZ) + [[4], [5, 6]] + + """ + if not u: + return dup_apply_pairs(f, g, h, args, K) + + n, m, v = len(f), len(g), u - 1 + + if n != m: + if n > m: + g = dmp_zeros(n - m, v, K) + g + else: + f = dmp_zeros(m - n, v, K) + f + + result = [] + + for a, b in zip(f, g): + result.append(dmp_apply_pairs(a, b, h, args, v, K)) + + return dmp_strip(result, u) + + +def dup_slice(f, m, n, K): + """Take a continuous subsequence of terms of ``f`` in ``K[x]``. """ + k = len(f) + + if k >= m: + M = k - m + else: + M = 0 + if k >= n: + N = k - n + else: + N = 0 + + f = f[N:M] + + while f and f[0] == K.zero: + f.pop(0) + + if not f: + return [] + else: + return f + [K.zero]*m + + +def dmp_slice(f, m, n, u, K): + """Take a continuous subsequence of terms of ``f`` in ``K[X]``. """ + return dmp_slice_in(f, m, n, 0, u, K) + + +def dmp_slice_in(f, m, n, j, u, K): + """Take a continuous subsequence of terms of ``f`` in ``x_j`` in ``K[X]``. """ + if j < 0 or j > u: + raise IndexError("-%s <= j < %s expected, got %s" % (u, u, j)) + + if not u: + return dup_slice(f, m, n, K) + + f, g = dmp_to_dict(f, u), {} + + for monom, coeff in f.items(): + k = monom[j] + + if k < m or k >= n: + monom = monom[:j] + (0,) + monom[j + 1:] + + if monom in g: + g[monom] += coeff + else: + g[monom] = coeff + + return dmp_from_dict(g, u, K) + + +def dup_random(n, a, b, K): + """ + Return a polynomial of degree ``n`` with coefficients in ``[a, b]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.densebasic import dup_random + + >>> dup_random(3, -10, 10, ZZ) #doctest: +SKIP + [-2, -8, 9, -4] + + """ + f = [ K.convert(random.randint(a, b)) for _ in range(0, n + 1) ] + + while not f[0]: + f[0] = K.convert(random.randint(a, b)) + + return f diff --git a/phivenv/Lib/site-packages/sympy/polys/densetools.py b/phivenv/Lib/site-packages/sympy/polys/densetools.py new file mode 100644 index 0000000000000000000000000000000000000000..122bf778a4843847be6db17708887416ba458f49 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/densetools.py @@ -0,0 +1,1438 @@ +"""Advanced tools for dense recursive polynomials in ``K[x]`` or ``K[X]``. """ + + +from sympy.polys.densearith import ( + dup_add_term, dmp_add_term, + dup_lshift, + dup_add, dmp_add, + dup_sub, dmp_sub, + dup_mul, dmp_mul, + dup_sqr, + dup_div, + dup_rem, dmp_rem, + dup_mul_ground, dmp_mul_ground, + dup_quo_ground, dmp_quo_ground, + dup_exquo_ground, dmp_exquo_ground, +) +from sympy.polys.densebasic import ( + dup_strip, dmp_strip, + dup_convert, dmp_convert, + dup_degree, dmp_degree, + dmp_to_dict, + dmp_from_dict, + dup_LC, dmp_LC, dmp_ground_LC, + dup_TC, dmp_TC, + dmp_zero, dmp_ground, + dmp_zero_p, + dup_to_raw_dict, dup_from_raw_dict, + dmp_zeros, + dmp_include, +) +from sympy.polys.polyerrors import ( + MultivariatePolynomialError, + DomainError +) + +from math import ceil as _ceil, log2 as _log2 + + +def dup_integrate(f, m, K): + """ + Computes the indefinite integral of ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + >>> R.dup_integrate(x**2 + 2*x, 1) + 1/3*x**3 + x**2 + >>> R.dup_integrate(x**2 + 2*x, 2) + 1/12*x**4 + 1/3*x**3 + + """ + if m <= 0 or not f: + return f + + g = [K.zero]*m + + for i, c in enumerate(reversed(f)): + n = i + 1 + + for j in range(1, m): + n *= i + j + 1 + + g.insert(0, K.exquo(c, K(n))) + + return g + + +def dmp_integrate(f, m, u, K): + """ + Computes the indefinite integral of ``f`` in ``x_0`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x,y = ring("x,y", QQ) + + >>> R.dmp_integrate(x + 2*y, 1) + 1/2*x**2 + 2*x*y + >>> R.dmp_integrate(x + 2*y, 2) + 1/6*x**3 + x**2*y + + """ + if not u: + return dup_integrate(f, m, K) + + if m <= 0 or dmp_zero_p(f, u): + return f + + g, v = dmp_zeros(m, u - 1, K), u - 1 + + for i, c in enumerate(reversed(f)): + n = i + 1 + + for j in range(1, m): + n *= i + j + 1 + + g.insert(0, dmp_quo_ground(c, K(n), v, K)) + + return g + + +def _rec_integrate_in(g, m, v, i, j, K): + """Recursive helper for :func:`dmp_integrate_in`.""" + if i == j: + return dmp_integrate(g, m, v, K) + + w, i = v - 1, i + 1 + + return dmp_strip([ _rec_integrate_in(c, m, w, i, j, K) for c in g ], v) + + +def dmp_integrate_in(f, m, j, u, K): + """ + Computes the indefinite integral of ``f`` in ``x_j`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x,y = ring("x,y", QQ) + + >>> R.dmp_integrate_in(x + 2*y, 1, 0) + 1/2*x**2 + 2*x*y + >>> R.dmp_integrate_in(x + 2*y, 1, 1) + x*y + y**2 + + """ + if j < 0 or j > u: + raise IndexError("0 <= j <= u expected, got u = %d, j = %d" % (u, j)) + + return _rec_integrate_in(f, m, u, 0, j, K) + + +def dup_diff(f, m, K): + """ + ``m``-th order derivative of a polynomial in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_diff(x**3 + 2*x**2 + 3*x + 4, 1) + 3*x**2 + 4*x + 3 + >>> R.dup_diff(x**3 + 2*x**2 + 3*x + 4, 2) + 6*x + 4 + + """ + if m <= 0: + return f + + n = dup_degree(f) + + if n < m: + return [] + + deriv = [] + + if m == 1: + for coeff in f[:-m]: + deriv.append(K(n)*coeff) + n -= 1 + else: + for coeff in f[:-m]: + k = n + + for i in range(n - 1, n - m, -1): + k *= i + + deriv.append(K(k)*coeff) + n -= 1 + + return dup_strip(deriv) + + +def dmp_diff(f, m, u, K): + """ + ``m``-th order derivative in ``x_0`` of a polynomial in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = x*y**2 + 2*x*y + 3*x + 2*y**2 + 3*y + 1 + + >>> R.dmp_diff(f, 1) + y**2 + 2*y + 3 + >>> R.dmp_diff(f, 2) + 0 + + """ + if not u: + return dup_diff(f, m, K) + if m <= 0: + return f + + n = dmp_degree(f, u) + + if n < m: + return dmp_zero(u) + + deriv, v = [], u - 1 + + if m == 1: + for coeff in f[:-m]: + deriv.append(dmp_mul_ground(coeff, K(n), v, K)) + n -= 1 + else: + for coeff in f[:-m]: + k = n + + for i in range(n - 1, n - m, -1): + k *= i + + deriv.append(dmp_mul_ground(coeff, K(k), v, K)) + n -= 1 + + return dmp_strip(deriv, u) + + +def _rec_diff_in(g, m, v, i, j, K): + """Recursive helper for :func:`dmp_diff_in`.""" + if i == j: + return dmp_diff(g, m, v, K) + + w, i = v - 1, i + 1 + + return dmp_strip([ _rec_diff_in(c, m, w, i, j, K) for c in g ], v) + + +def dmp_diff_in(f, m, j, u, K): + """ + ``m``-th order derivative in ``x_j`` of a polynomial in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = x*y**2 + 2*x*y + 3*x + 2*y**2 + 3*y + 1 + + >>> R.dmp_diff_in(f, 1, 0) + y**2 + 2*y + 3 + >>> R.dmp_diff_in(f, 1, 1) + 2*x*y + 2*x + 4*y + 3 + + """ + if j < 0 or j > u: + raise IndexError("0 <= j <= %s expected, got %s" % (u, j)) + + return _rec_diff_in(f, m, u, 0, j, K) + + +def dup_eval(f, a, K): + """ + Evaluate a polynomial at ``x = a`` in ``K[x]`` using Horner scheme. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_eval(x**2 + 2*x + 3, 2) + 11 + + """ + if not a: + return K.convert(dup_TC(f, K)) + + result = K.zero + + for c in f: + result *= a + result += c + + return result + + +def dmp_eval(f, a, u, K): + """ + Evaluate a polynomial at ``x_0 = a`` in ``K[X]`` using the Horner scheme. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_eval(2*x*y + 3*x + y + 2, 2) + 5*y + 8 + + """ + if not u: + return dup_eval(f, a, K) + + if not a: + return dmp_TC(f, K) + + result, v = dmp_LC(f, K), u - 1 + + for coeff in f[1:]: + result = dmp_mul_ground(result, a, v, K) + result = dmp_add(result, coeff, v, K) + + return result + + +def _rec_eval_in(g, a, v, i, j, K): + """Recursive helper for :func:`dmp_eval_in`.""" + if i == j: + return dmp_eval(g, a, v, K) + + v, i = v - 1, i + 1 + + return dmp_strip([ _rec_eval_in(c, a, v, i, j, K) for c in g ], v) + + +def dmp_eval_in(f, a, j, u, K): + """ + Evaluate a polynomial at ``x_j = a`` in ``K[X]`` using the Horner scheme. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = 2*x*y + 3*x + y + 2 + + >>> R.dmp_eval_in(f, 2, 0) + 5*y + 8 + >>> R.dmp_eval_in(f, 2, 1) + 7*x + 4 + + """ + if j < 0 or j > u: + raise IndexError("0 <= j <= %s expected, got %s" % (u, j)) + + return _rec_eval_in(f, a, u, 0, j, K) + + +def _rec_eval_tail(g, i, A, u, K): + """Recursive helper for :func:`dmp_eval_tail`.""" + if i == u: + return dup_eval(g, A[-1], K) + else: + h = [ _rec_eval_tail(c, i + 1, A, u, K) for c in g ] + + if i < u - len(A) + 1: + return h + else: + return dup_eval(h, A[-u + i - 1], K) + + +def dmp_eval_tail(f, A, u, K): + """ + Evaluate a polynomial at ``x_j = a_j, ...`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = 2*x*y + 3*x + y + 2 + + >>> R.dmp_eval_tail(f, [2]) + 7*x + 4 + >>> R.dmp_eval_tail(f, [2, 2]) + 18 + + """ + if not A: + return f + + if dmp_zero_p(f, u): + return dmp_zero(u - len(A)) + + e = _rec_eval_tail(f, 0, A, u, K) + + if u == len(A) - 1: + return e + else: + return dmp_strip(e, u - len(A)) + + +def _rec_diff_eval(g, m, a, v, i, j, K): + """Recursive helper for :func:`dmp_diff_eval`.""" + if i == j: + return dmp_eval(dmp_diff(g, m, v, K), a, v, K) + + v, i = v - 1, i + 1 + + return dmp_strip([ _rec_diff_eval(c, m, a, v, i, j, K) for c in g ], v) + + +def dmp_diff_eval_in(f, m, a, j, u, K): + """ + Differentiate and evaluate a polynomial in ``x_j`` at ``a`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = x*y**2 + 2*x*y + 3*x + 2*y**2 + 3*y + 1 + + >>> R.dmp_diff_eval_in(f, 1, 2, 0) + y**2 + 2*y + 3 + >>> R.dmp_diff_eval_in(f, 1, 2, 1) + 6*x + 11 + + """ + if j > u: + raise IndexError("-%s <= j < %s expected, got %s" % (u, u, j)) + if not j: + return dmp_eval(dmp_diff(f, m, u, K), a, u, K) + + return _rec_diff_eval(f, m, a, u, 0, j, K) + + +def dup_trunc(f, p, K): + """ + Reduce a ``K[x]`` polynomial modulo a constant ``p`` in ``K``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_trunc(2*x**3 + 3*x**2 + 5*x + 7, ZZ(3)) + -x**3 - x + 1 + + """ + if K.is_ZZ: + g = [] + + for c in f: + c = c % p + + if c > p // 2: + g.append(c - p) + else: + g.append(c) + elif K.is_FiniteField: + # XXX: python-flint's nmod does not support % + pi = int(p) + g = [ K(int(c) % pi) for c in f ] + else: + g = [ c % p for c in f ] + + return dup_strip(g) + + +def dmp_trunc(f, p, u, K): + """ + Reduce a ``K[X]`` polynomial modulo a polynomial ``p`` in ``K[Y]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = 3*x**2*y + 8*x**2 + 5*x*y + 6*x + 2*y + 3 + >>> g = (y - 1).drop(x) + + >>> R.dmp_trunc(f, g) + 11*x**2 + 11*x + 5 + + """ + return dmp_strip([ dmp_rem(c, p, u - 1, K) for c in f ], u) + + +def dmp_ground_trunc(f, p, u, K): + """ + Reduce a ``K[X]`` polynomial modulo a constant ``p`` in ``K``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = 3*x**2*y + 8*x**2 + 5*x*y + 6*x + 2*y + 3 + + >>> R.dmp_ground_trunc(f, ZZ(3)) + -x**2 - x*y - y + + """ + if not u: + return dup_trunc(f, p, K) + + v = u - 1 + + return dmp_strip([ dmp_ground_trunc(c, p, v, K) for c in f ], u) + + +def dup_monic(f, K): + """ + Divide all coefficients by ``LC(f)`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x = ring("x", ZZ) + >>> R.dup_monic(3*x**2 + 6*x + 9) + x**2 + 2*x + 3 + + >>> R, x = ring("x", QQ) + >>> R.dup_monic(3*x**2 + 4*x + 2) + x**2 + 4/3*x + 2/3 + + """ + if not f: + return f + + lc = dup_LC(f, K) + + if K.is_one(lc): + return f + else: + return dup_exquo_ground(f, lc, K) + + +def dmp_ground_monic(f, u, K): + """ + Divide all coefficients by ``LC(f)`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x,y = ring("x,y", ZZ) + >>> f = 3*x**2*y + 6*x**2 + 3*x*y + 9*y + 3 + + >>> R.dmp_ground_monic(f) + x**2*y + 2*x**2 + x*y + 3*y + 1 + + >>> R, x,y = ring("x,y", QQ) + >>> f = 3*x**2*y + 8*x**2 + 5*x*y + 6*x + 2*y + 3 + + >>> R.dmp_ground_monic(f) + x**2*y + 8/3*x**2 + 5/3*x*y + 2*x + 2/3*y + 1 + + """ + if not u: + return dup_monic(f, K) + + if dmp_zero_p(f, u): + return f + + lc = dmp_ground_LC(f, u, K) + + if K.is_one(lc): + return f + else: + return dmp_exquo_ground(f, lc, u, K) + + +def dup_content(f, K): + """ + Compute the GCD of coefficients of ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x = ring("x", ZZ) + >>> f = 6*x**2 + 8*x + 12 + + >>> R.dup_content(f) + 2 + + >>> R, x = ring("x", QQ) + >>> f = 6*x**2 + 8*x + 12 + + >>> R.dup_content(f) + 2 + + """ + from sympy.polys.domains import QQ + + if not f: + return K.zero + + cont = K.zero + + if K == QQ: + for c in f: + cont = K.gcd(cont, c) + else: + for c in f: + cont = K.gcd(cont, c) + + if K.is_one(cont): + break + + return cont + + +def dmp_ground_content(f, u, K): + """ + Compute the GCD of coefficients of ``f`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x,y = ring("x,y", ZZ) + >>> f = 2*x*y + 6*x + 4*y + 12 + + >>> R.dmp_ground_content(f) + 2 + + >>> R, x,y = ring("x,y", QQ) + >>> f = 2*x*y + 6*x + 4*y + 12 + + >>> R.dmp_ground_content(f) + 2 + + """ + from sympy.polys.domains import QQ + + if not u: + return dup_content(f, K) + + if dmp_zero_p(f, u): + return K.zero + + cont, v = K.zero, u - 1 + + if K == QQ: + for c in f: + cont = K.gcd(cont, dmp_ground_content(c, v, K)) + else: + for c in f: + cont = K.gcd(cont, dmp_ground_content(c, v, K)) + + if K.is_one(cont): + break + + return cont + + +def dup_primitive(f, K): + """ + Compute content and the primitive form of ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x = ring("x", ZZ) + >>> f = 6*x**2 + 8*x + 12 + + >>> R.dup_primitive(f) + (2, 3*x**2 + 4*x + 6) + + >>> R, x = ring("x", QQ) + >>> f = 6*x**2 + 8*x + 12 + + >>> R.dup_primitive(f) + (2, 3*x**2 + 4*x + 6) + + """ + if not f: + return K.zero, f + + cont = dup_content(f, K) + + if K.is_one(cont): + return cont, f + else: + return cont, dup_quo_ground(f, cont, K) + + +def dmp_ground_primitive(f, u, K): + """ + Compute content and the primitive form of ``f`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ, QQ + + >>> R, x,y = ring("x,y", ZZ) + >>> f = 2*x*y + 6*x + 4*y + 12 + + >>> R.dmp_ground_primitive(f) + (2, x*y + 3*x + 2*y + 6) + + >>> R, x,y = ring("x,y", QQ) + >>> f = 2*x*y + 6*x + 4*y + 12 + + >>> R.dmp_ground_primitive(f) + (2, x*y + 3*x + 2*y + 6) + + """ + if not u: + return dup_primitive(f, K) + + if dmp_zero_p(f, u): + return K.zero, f + + cont = dmp_ground_content(f, u, K) + + if K.is_one(cont): + return cont, f + else: + return cont, dmp_quo_ground(f, cont, u, K) + + +def dup_extract(f, g, K): + """ + Extract common content from a pair of polynomials in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_extract(6*x**2 + 12*x + 18, 4*x**2 + 8*x + 12) + (2, 3*x**2 + 6*x + 9, 2*x**2 + 4*x + 6) + + """ + fc = dup_content(f, K) + gc = dup_content(g, K) + + gcd = K.gcd(fc, gc) + + if not K.is_one(gcd): + f = dup_quo_ground(f, gcd, K) + g = dup_quo_ground(g, gcd, K) + + return gcd, f, g + + +def dmp_ground_extract(f, g, u, K): + """ + Extract common content from a pair of polynomials in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_ground_extract(6*x*y + 12*x + 18, 4*x*y + 8*x + 12) + (2, 3*x*y + 6*x + 9, 2*x*y + 4*x + 6) + + """ + fc = dmp_ground_content(f, u, K) + gc = dmp_ground_content(g, u, K) + + gcd = K.gcd(fc, gc) + + if not K.is_one(gcd): + f = dmp_quo_ground(f, gcd, u, K) + g = dmp_quo_ground(g, gcd, u, K) + + return gcd, f, g + + +def dup_real_imag(f, K): + """ + Find ``f1`` and ``f2``, such that ``f(x+I*y) = f1(x,y) + f2(x,y)*I``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dup_real_imag(x**3 + x**2 + x + 1) + (x**3 + x**2 - 3*x*y**2 + x - y**2 + 1, 3*x**2*y + 2*x*y - y**3 + y) + + >>> from sympy.abc import x, y, z + >>> from sympy import I + >>> (z**3 + z**2 + z + 1).subs(z, x+I*y).expand().collect(I) + x**3 + x**2 - 3*x*y**2 + x - y**2 + I*(3*x**2*y + 2*x*y - y**3 + y) + 1 + + """ + if not K.is_ZZ and not K.is_QQ: + raise DomainError("computing real and imaginary parts is not supported over %s" % K) + + f1 = dmp_zero(1) + f2 = dmp_zero(1) + + if not f: + return f1, f2 + + g = [[[K.one, K.zero]], [[K.one], []]] + h = dmp_ground(f[0], 2) + + for c in f[1:]: + h = dmp_mul(h, g, 2, K) + h = dmp_add_term(h, dmp_ground(c, 1), 0, 2, K) + + H = dup_to_raw_dict(h) + + for k, h in H.items(): + m = k % 4 + + if not m: + f1 = dmp_add(f1, h, 1, K) + elif m == 1: + f2 = dmp_add(f2, h, 1, K) + elif m == 2: + f1 = dmp_sub(f1, h, 1, K) + else: + f2 = dmp_sub(f2, h, 1, K) + + return f1, f2 + + +def dup_mirror(f, K): + """ + Evaluate efficiently the composition ``f(-x)`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_mirror(x**3 + 2*x**2 - 4*x + 2) + -x**3 + 2*x**2 + 4*x + 2 + + """ + f = list(f) + + for i in range(len(f) - 2, -1, -2): + f[i] = -f[i] + + return f + + +def dup_scale(f, a, K): + """ + Evaluate efficiently composition ``f(a*x)`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_scale(x**2 - 2*x + 1, ZZ(2)) + 4*x**2 - 4*x + 1 + + """ + f, n, b = list(f), len(f) - 1, a + + for i in range(n - 1, -1, -1): + f[i], b = b*f[i], b*a + + return f + + +def dup_shift(f, a, K): + """ + Evaluate efficiently Taylor shift ``f(x + a)`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_shift(x**2 - 2*x + 1, ZZ(2)) + x**2 + 2*x + 1 + + """ + f, n = list(f), len(f) - 1 + + for i in range(n, 0, -1): + for j in range(0, i): + f[j + 1] += a*f[j] + + return f + + +def dmp_shift(f, a, u, K): + """ + Evaluate efficiently Taylor shift ``f(X + A)`` in ``K[X]``. + + Examples + ======== + + >>> from sympy import symbols, ring, ZZ + >>> x, y = symbols('x y') + >>> R, _, _ = ring([x, y], ZZ) + + >>> p = x**2*y + 2*x*y + 3*x + 4*y + 5 + + >>> R.dmp_shift(R(p), [ZZ(1), ZZ(2)]) + x**2*y + 2*x**2 + 4*x*y + 11*x + 7*y + 22 + + >>> p.subs({x: x + 1, y: y + 2}).expand() + x**2*y + 2*x**2 + 4*x*y + 11*x + 7*y + 22 + """ + if not u: + return dup_shift(f, a[0], K) + + if dmp_zero_p(f, u): + return f + + a0, a1 = a[0], a[1:] + + if any(a1): + f = [ dmp_shift(c, a1, u-1, K) for c in f ] + else: + f = list(f) + + if a0: + n = len(f) - 1 + + for i in range(n, 0, -1): + for j in range(0, i): + afj = dmp_mul_ground(f[j], a0, u-1, K) + f[j + 1] = dmp_add(f[j + 1], afj, u-1, K) + + return dmp_strip(f, u) + + +def dup_transform(f, p, q, K): + """ + Evaluate functional transformation ``q**n * f(p/q)`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_transform(x**2 - 2*x + 1, x**2 + 1, x - 1) + x**4 - 2*x**3 + 5*x**2 - 4*x + 4 + + """ + if not f: + return [] + + n = len(f) - 1 + h, Q = [f[0]], [[K.one]] + + for i in range(0, n): + Q.append(dup_mul(Q[-1], q, K)) + + for c, q in zip(f[1:], Q[1:]): + h = dup_mul(h, p, K) + q = dup_mul_ground(q, c, K) + h = dup_add(h, q, K) + + return h + + +def dup_compose(f, g, K): + """ + Evaluate functional composition ``f(g)`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_compose(x**2 + x, x - 1) + x**2 - x + + """ + if len(g) <= 1: + return dup_strip([dup_eval(f, dup_LC(g, K), K)]) + + if not f: + return [] + + h = [f[0]] + + for c in f[1:]: + h = dup_mul(h, g, K) + h = dup_add_term(h, c, 0, K) + + return h + + +def dmp_compose(f, g, u, K): + """ + Evaluate functional composition ``f(g)`` in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_compose(x*y + 2*x + y, y) + y**2 + 3*y + + """ + if not u: + return dup_compose(f, g, K) + + if dmp_zero_p(f, u): + return f + + h = [f[0]] + + for c in f[1:]: + h = dmp_mul(h, g, u, K) + h = dmp_add_term(h, c, 0, u, K) + + return h + + +def _dup_right_decompose(f, s, K): + """Helper function for :func:`_dup_decompose`.""" + n = len(f) - 1 + lc = dup_LC(f, K) + + f = dup_to_raw_dict(f) + g = { s: K.one } + + r = n // s + + for i in range(1, s): + coeff = K.zero + + for j in range(0, i): + if not n + j - i in f: + continue + + if not s - j in g: + continue + + fc, gc = f[n + j - i], g[s - j] + coeff += (i - r*j)*fc*gc + + g[s - i] = K.quo(coeff, i*r*lc) + + return dup_from_raw_dict(g, K) + + +def _dup_left_decompose(f, h, K): + """Helper function for :func:`_dup_decompose`.""" + g, i = {}, 0 + + while f: + q, r = dup_div(f, h, K) + + if dup_degree(r) > 0: + return None + else: + g[i] = dup_LC(r, K) + f, i = q, i + 1 + + return dup_from_raw_dict(g, K) + + +def _dup_decompose(f, K): + """Helper function for :func:`dup_decompose`.""" + df = len(f) - 1 + + for s in range(2, df): + if df % s != 0: + continue + + h = _dup_right_decompose(f, s, K) + + if h is not None: + g = _dup_left_decompose(f, h, K) + + if g is not None: + return g, h + + return None + + +def dup_decompose(f, K): + """ + Computes functional decomposition of ``f`` in ``K[x]``. + + Given a univariate polynomial ``f`` with coefficients in a field of + characteristic zero, returns list ``[f_1, f_2, ..., f_n]``, where:: + + f = f_1 o f_2 o ... f_n = f_1(f_2(... f_n)) + + and ``f_2, ..., f_n`` are monic and homogeneous polynomials of at + least second degree. + + Unlike factorization, complete functional decompositions of + polynomials are not unique, consider examples: + + 1. ``f o g = f(x + b) o (g - b)`` + 2. ``x**n o x**m = x**m o x**n`` + 3. ``T_n o T_m = T_m o T_n`` + + where ``T_n`` and ``T_m`` are Chebyshev polynomials. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_decompose(x**4 - 2*x**3 + x**2) + [x**2, x**2 - x] + + References + ========== + + .. [1] [Kozen89]_ + + """ + F = [] + + while True: + result = _dup_decompose(f, K) + + if result is not None: + f, h = result + F = [h] + F + else: + break + + return [f] + F + + +def dmp_alg_inject(f, u, K): + """ + Convert polynomial from ``K(a)[X]`` to ``K[a,X]``. + + Examples + ======== + + >>> from sympy.polys.densetools import dmp_alg_inject + >>> from sympy import QQ, sqrt + + >>> K = QQ.algebraic_field(sqrt(2)) + + >>> p = [K.from_sympy(sqrt(2)), K.zero, K.one] + >>> P, lev, dom = dmp_alg_inject(p, 0, K) + >>> P + [[1, 0, 0], [1]] + >>> lev + 1 + >>> dom + QQ + + """ + if K.is_GaussianRing or K.is_GaussianField: + return _dmp_alg_inject_gaussian(f, u, K) + elif K.is_Algebraic: + return _dmp_alg_inject_alg(f, u, K) + else: + raise DomainError('computation can be done only in an algebraic domain') + + +def _dmp_alg_inject_gaussian(f, u, K): + """Helper function for :func:`dmp_alg_inject`.""" + f, h = dmp_to_dict(f, u), {} + + for f_monom, g in f.items(): + x, y = g.x, g.y + if x: + h[(0,) + f_monom] = x + if y: + h[(1,) + f_monom] = y + + F = dmp_from_dict(h, u + 1, K.dom) + + return F, u + 1, K.dom + + +def _dmp_alg_inject_alg(f, u, K): + """Helper function for :func:`dmp_alg_inject`.""" + f, h = dmp_to_dict(f, u), {} + + for f_monom, g in f.items(): + for g_monom, c in g.to_dict().items(): + h[g_monom + f_monom] = c + + F = dmp_from_dict(h, u + 1, K.dom) + + return F, u + 1, K.dom + + +def dmp_lift(f, u, K): + """ + Convert algebraic coefficients to integers in ``K[X]``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> from sympy import I + + >>> K = QQ.algebraic_field(I) + >>> R, x = ring("x", K) + + >>> f = x**2 + K([QQ(1), QQ(0)])*x + K([QQ(2), QQ(0)]) + + >>> R.dmp_lift(f) + x**4 + x**2 + 4*x + 4 + + """ + # Circular import. Probably dmp_lift should be moved to euclidtools + from .euclidtools import dmp_resultant + + F, v, K2 = dmp_alg_inject(f, u, K) + + p_a = K.mod.to_list() + P_A = dmp_include(p_a, list(range(1, v + 1)), 0, K2) + + return dmp_resultant(F, P_A, v, K2) + + +def dup_sign_variations(f, K): + """ + Compute the number of sign variations of ``f`` in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_sign_variations(x**4 - x**2 - x + 1) + 2 + + """ + def is_negative_sympy(a): + if not a: + # XXX: requires zero equivalence testing in the domain + return False + else: + # XXX: This is inefficient. It should not be necessary to use a + # symbolic expression here at least for algebraic fields. If the + # domain elements can be numerically evaluated to real values with + # precision then this should work. We first need to rule out zero + # elements though. + return bool(K.to_sympy(a) < 0) + + # XXX: There should be a way to check for real numeric domains and + # Domain.is_negative should be fixed to handle all real numeric domains. + # It should not be necessary to special case all these different domains + # in this otherwise generic function. + if K.is_ZZ or K.is_QQ or K.is_RR: + is_negative = K.is_negative + elif K.is_AlgebraicField and K.ext.is_comparable: + is_negative = is_negative_sympy + elif ((K.is_PolynomialRing or K.is_FractionField) and len(K.symbols) == 1 and + (K.dom.is_ZZ or K.dom.is_QQ or K.is_AlgebraicField) and + K.symbols[0].is_transcendental and K.symbols[0].is_comparable): + # We can handle a polynomial ring like QQ[E] if there is a single + # transcendental generator because then zero equivalence is assured. + is_negative = is_negative_sympy + else: + raise DomainError("sign variation counting not supported over %s" % K) + + prev, k = K.zero, 0 + + for coeff in f: + if is_negative(coeff*prev): + k += 1 + + if coeff: + prev = coeff + + return k + + +def dup_clear_denoms(f, K0, K1=None, convert=False): + """ + Clear denominators, i.e. transform ``K_0`` to ``K_1``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + >>> f = QQ(1,2)*x + QQ(1,3) + + >>> R.dup_clear_denoms(f, convert=False) + (6, 3*x + 2) + >>> R.dup_clear_denoms(f, convert=True) + (6, 3*x + 2) + + """ + if K1 is None: + if K0.has_assoc_Ring: + K1 = K0.get_ring() + else: + K1 = K0 + + common = K1.one + + for c in f: + common = K1.lcm(common, K0.denom(c)) + + if K1.is_one(common): + if not convert: + return common, f + else: + return common, dup_convert(f, K0, K1) + + # Use quo rather than exquo to handle inexact domains by discarding the + # remainder. + f = [K0.numer(c)*K1.quo(common, K0.denom(c)) for c in f] + + if not convert: + return common, dup_convert(f, K1, K0) + else: + return common, f + + +def _rec_clear_denoms(g, v, K0, K1): + """Recursive helper for :func:`dmp_clear_denoms`.""" + common = K1.one + + if not v: + for c in g: + common = K1.lcm(common, K0.denom(c)) + else: + w = v - 1 + + for c in g: + common = K1.lcm(common, _rec_clear_denoms(c, w, K0, K1)) + + return common + + +def dmp_clear_denoms(f, u, K0, K1=None, convert=False): + """ + Clear denominators, i.e. transform ``K_0`` to ``K_1``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x,y = ring("x,y", QQ) + + >>> f = QQ(1,2)*x + QQ(1,3)*y + 1 + + >>> R.dmp_clear_denoms(f, convert=False) + (6, 3*x + 2*y + 6) + >>> R.dmp_clear_denoms(f, convert=True) + (6, 3*x + 2*y + 6) + + """ + if not u: + return dup_clear_denoms(f, K0, K1, convert=convert) + + if K1 is None: + if K0.has_assoc_Ring: + K1 = K0.get_ring() + else: + K1 = K0 + + common = _rec_clear_denoms(f, u, K0, K1) + + if not K1.is_one(common): + f = dmp_mul_ground(f, common, u, K0) + + if not convert: + return common, f + else: + return common, dmp_convert(f, u, K0, K1) + + +def dup_revert(f, n, K): + """ + Compute ``f**(-1)`` mod ``x**n`` using Newton iteration. + + This function computes first ``2**n`` terms of a polynomial that + is a result of inversion of a polynomial modulo ``x**n``. This is + useful to efficiently compute series expansion of ``1/f``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + >>> f = -QQ(1,720)*x**6 + QQ(1,24)*x**4 - QQ(1,2)*x**2 + 1 + + >>> R.dup_revert(f, 8) + 61/720*x**6 + 5/24*x**4 + 1/2*x**2 + 1 + + """ + g = [K.revert(dup_TC(f, K))] + h = [K.one, K.zero, K.zero] + + N = int(_ceil(_log2(n))) + + for i in range(1, N + 1): + a = dup_mul_ground(g, K(2), K) + b = dup_mul(f, dup_sqr(g, K), K) + g = dup_rem(dup_sub(a, b, K), h, K) + h = dup_lshift(h, dup_degree(h), K) + + return g + + +def dmp_revert(f, g, u, K): + """ + Compute ``f**(-1)`` mod ``x**n`` using Newton iteration. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x,y = ring("x,y", QQ) + + """ + if not u: + return dup_revert(f, g, K) + else: + raise MultivariatePolynomialError(f, g) diff --git a/phivenv/Lib/site-packages/sympy/polys/dispersion.py b/phivenv/Lib/site-packages/sympy/polys/dispersion.py new file mode 100644 index 0000000000000000000000000000000000000000..699277d221f24b9bff42c55c3bb34fe5783ae7a1 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/dispersion.py @@ -0,0 +1,212 @@ +from sympy.core import S +from sympy.polys import Poly + + +def dispersionset(p, q=None, *gens, **args): + r"""Compute the *dispersion set* of two polynomials. + + For two polynomials `f(x)` and `g(x)` with `\deg f > 0` + and `\deg g > 0` the dispersion set `\operatorname{J}(f, g)` is defined as: + + .. math:: + \operatorname{J}(f, g) + & := \{a \in \mathbb{N}_0 | \gcd(f(x), g(x+a)) \neq 1\} \\ + & = \{a \in \mathbb{N}_0 | \deg \gcd(f(x), g(x+a)) \geq 1\} + + For a single polynomial one defines `\operatorname{J}(f) := \operatorname{J}(f, f)`. + + Examples + ======== + + >>> from sympy import poly + >>> from sympy.polys.dispersion import dispersion, dispersionset + >>> from sympy.abc import x + + Dispersion set and dispersion of a simple polynomial: + + >>> fp = poly((x - 3)*(x + 3), x) + >>> sorted(dispersionset(fp)) + [0, 6] + >>> dispersion(fp) + 6 + + Note that the definition of the dispersion is not symmetric: + + >>> fp = poly(x**4 - 3*x**2 + 1, x) + >>> gp = fp.shift(-3) + >>> sorted(dispersionset(fp, gp)) + [2, 3, 4] + >>> dispersion(fp, gp) + 4 + >>> sorted(dispersionset(gp, fp)) + [] + >>> dispersion(gp, fp) + -oo + + Computing the dispersion also works over field extensions: + + >>> from sympy import sqrt + >>> fp = poly(x**2 + sqrt(5)*x - 1, x, domain='QQ') + >>> gp = poly(x**2 + (2 + sqrt(5))*x + sqrt(5), x, domain='QQ') + >>> sorted(dispersionset(fp, gp)) + [2] + >>> sorted(dispersionset(gp, fp)) + [1, 4] + + We can even perform the computations for polynomials + having symbolic coefficients: + + >>> from sympy.abc import a + >>> fp = poly(4*x**4 + (4*a + 8)*x**3 + (a**2 + 6*a + 4)*x**2 + (a**2 + 2*a)*x, x) + >>> sorted(dispersionset(fp)) + [0, 1] + + See Also + ======== + + dispersion + + References + ========== + + .. [1] [ManWright94]_ + .. [2] [Koepf98]_ + .. [3] [Abramov71]_ + .. [4] [Man93]_ + """ + # Check for valid input + same = False if q is not None else True + if same: + q = p + + p = Poly(p, *gens, **args) + q = Poly(q, *gens, **args) + + if not p.is_univariate or not q.is_univariate: + raise ValueError("Polynomials need to be univariate") + + # The generator + if not p.gen == q.gen: + raise ValueError("Polynomials must have the same generator") + gen = p.gen + + # We define the dispersion of constant polynomials to be zero + if p.degree() < 1 or q.degree() < 1: + return {0} + + # Factor p and q over the rationals + fp = p.factor_list() + fq = q.factor_list() if not same else fp + + # Iterate over all pairs of factors + J = set() + for s, unused in fp[1]: + for t, unused in fq[1]: + m = s.degree() + n = t.degree() + if n != m: + continue + an = s.LC() + bn = t.LC() + if not (an - bn).is_zero: + continue + # Note that the roles of `s` and `t` below are switched + # w.r.t. the original paper. This is for consistency + # with the description in the book of W. Koepf. + anm1 = s.coeff_monomial(gen**(m-1)) + bnm1 = t.coeff_monomial(gen**(n-1)) + alpha = (anm1 - bnm1) / S(n*bn) + if not alpha.is_integer: + continue + if alpha < 0 or alpha in J: + continue + if n > 1 and not (s - t.shift(alpha)).is_zero: + continue + J.add(alpha) + + return J + + +def dispersion(p, q=None, *gens, **args): + r"""Compute the *dispersion* of polynomials. + + For two polynomials `f(x)` and `g(x)` with `\deg f > 0` + and `\deg g > 0` the dispersion `\operatorname{dis}(f, g)` is defined as: + + .. math:: + \operatorname{dis}(f, g) + & := \max\{ J(f,g) \cup \{0\} \} \\ + & = \max\{ \{a \in \mathbb{N} | \gcd(f(x), g(x+a)) \neq 1\} \cup \{0\} \} + + and for a single polynomial `\operatorname{dis}(f) := \operatorname{dis}(f, f)`. + Note that we make the definition `\max\{\} := -\infty`. + + Examples + ======== + + >>> from sympy import poly + >>> from sympy.polys.dispersion import dispersion, dispersionset + >>> from sympy.abc import x + + Dispersion set and dispersion of a simple polynomial: + + >>> fp = poly((x - 3)*(x + 3), x) + >>> sorted(dispersionset(fp)) + [0, 6] + >>> dispersion(fp) + 6 + + Note that the definition of the dispersion is not symmetric: + + >>> fp = poly(x**4 - 3*x**2 + 1, x) + >>> gp = fp.shift(-3) + >>> sorted(dispersionset(fp, gp)) + [2, 3, 4] + >>> dispersion(fp, gp) + 4 + >>> sorted(dispersionset(gp, fp)) + [] + >>> dispersion(gp, fp) + -oo + + The maximum of an empty set is defined to be `-\infty` + as seen in this example. + + Computing the dispersion also works over field extensions: + + >>> from sympy import sqrt + >>> fp = poly(x**2 + sqrt(5)*x - 1, x, domain='QQ') + >>> gp = poly(x**2 + (2 + sqrt(5))*x + sqrt(5), x, domain='QQ') + >>> sorted(dispersionset(fp, gp)) + [2] + >>> sorted(dispersionset(gp, fp)) + [1, 4] + + We can even perform the computations for polynomials + having symbolic coefficients: + + >>> from sympy.abc import a + >>> fp = poly(4*x**4 + (4*a + 8)*x**3 + (a**2 + 6*a + 4)*x**2 + (a**2 + 2*a)*x, x) + >>> sorted(dispersionset(fp)) + [0, 1] + + See Also + ======== + + dispersionset + + References + ========== + + .. [1] [ManWright94]_ + .. [2] [Koepf98]_ + .. [3] [Abramov71]_ + .. [4] [Man93]_ + """ + J = dispersionset(p, q, *gens, **args) + if not J: + # Definition for maximum of empty set + j = S.NegativeInfinity + else: + j = max(J) + return j diff --git a/phivenv/Lib/site-packages/sympy/polys/distributedmodules.py b/phivenv/Lib/site-packages/sympy/polys/distributedmodules.py new file mode 100644 index 0000000000000000000000000000000000000000..df4581e58951a9c29b9e5b085311f5e6cb00f381 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/distributedmodules.py @@ -0,0 +1,739 @@ +r""" +Sparse distributed elements of free modules over multivariate (generalized) +polynomial rings. + +This code and its data structures are very much like the distributed +polynomials, except that the first "exponent" of the monomial is +a module generator index. That is, the multi-exponent ``(i, e_1, ..., e_n)`` +represents the "monomial" `x_1^{e_1} \cdots x_n^{e_n} f_i` of the free module +`F` generated by `f_1, \ldots, f_r` over (a localization of) the ring +`K[x_1, \ldots, x_n]`. A module element is simply stored as a list of terms +ordered by the monomial order. Here a term is a pair of a multi-exponent and a +coefficient. In general, this coefficient should never be zero (since it can +then be omitted). The zero module element is stored as an empty list. + +The main routines are ``sdm_nf_mora`` and ``sdm_groebner`` which can be used +to compute, respectively, weak normal forms and standard bases. They work with +arbitrary (not necessarily global) monomial orders. + +In general, product orders have to be used to construct valid monomial orders +for modules. However, ``lex`` can be used as-is. + +Note that the "level" (number of variables, i.e. parameter u+1 in +distributedpolys.py) is never needed in this code. + +The main reference for this file is [SCA], +"A Singular Introduction to Commutative Algebra". +""" + + +from itertools import permutations + +from sympy.polys.monomials import ( + monomial_mul, monomial_lcm, monomial_div, monomial_deg +) + +from sympy.polys.polytools import Poly +from sympy.polys.polyutils import parallel_dict_from_expr +from sympy.core.singleton import S +from sympy.core.sympify import sympify + +# Additional monomial tools. + + +def sdm_monomial_mul(M, X): + """ + Multiply tuple ``X`` representing a monomial of `K[X]` into the tuple + ``M`` representing a monomial of `F`. + + Examples + ======== + + Multiplying `xy^3` into `x f_1` yields `x^2 y^3 f_1`: + + >>> from sympy.polys.distributedmodules import sdm_monomial_mul + >>> sdm_monomial_mul((1, 1, 0), (1, 3)) + (1, 2, 3) + """ + return (M[0],) + monomial_mul(X, M[1:]) + + +def sdm_monomial_deg(M): + """ + Return the total degree of ``M``. + + Examples + ======== + + For example, the total degree of `x^2 y f_5` is 3: + + >>> from sympy.polys.distributedmodules import sdm_monomial_deg + >>> sdm_monomial_deg((5, 2, 1)) + 3 + """ + return monomial_deg(M[1:]) + + +def sdm_monomial_lcm(A, B): + r""" + Return the "least common multiple" of ``A`` and ``B``. + + IF `A = M e_j` and `B = N e_j`, where `M` and `N` are polynomial monomials, + this returns `\lcm(M, N) e_j`. Note that ``A`` and ``B`` involve distinct + monomials. + + Otherwise the result is undefined. + + Examples + ======== + + >>> from sympy.polys.distributedmodules import sdm_monomial_lcm + >>> sdm_monomial_lcm((1, 2, 3), (1, 0, 5)) + (1, 2, 5) + """ + return (A[0],) + monomial_lcm(A[1:], B[1:]) + + +def sdm_monomial_divides(A, B): + """ + Does there exist a (polynomial) monomial X such that XA = B? + + Examples + ======== + + Positive examples: + + In the following examples, the monomial is given in terms of x, y and the + generator(s), f_1, f_2 etc. The tuple form of that monomial is used in + the call to sdm_monomial_divides. + Note: the generator appears last in the expression but first in the tuple + and other factors appear in the same order that they appear in the monomial + expression. + + `A = f_1` divides `B = f_1` + + >>> from sympy.polys.distributedmodules import sdm_monomial_divides + >>> sdm_monomial_divides((1, 0, 0), (1, 0, 0)) + True + + `A = f_1` divides `B = x^2 y f_1` + + >>> sdm_monomial_divides((1, 0, 0), (1, 2, 1)) + True + + `A = xy f_5` divides `B = x^2 y f_5` + + >>> sdm_monomial_divides((5, 1, 1), (5, 2, 1)) + True + + Negative examples: + + `A = f_1` does not divide `B = f_2` + + >>> sdm_monomial_divides((1, 0, 0), (2, 0, 0)) + False + + `A = x f_1` does not divide `B = f_1` + + >>> sdm_monomial_divides((1, 1, 0), (1, 0, 0)) + False + + `A = xy^2 f_5` does not divide `B = y f_5` + + >>> sdm_monomial_divides((5, 1, 2), (5, 0, 1)) + False + """ + return A[0] == B[0] and all(a <= b for a, b in zip(A[1:], B[1:])) + + +# The actual distributed modules code. + +def sdm_LC(f, K): + """Returns the leading coefficient of ``f``. """ + if not f: + return K.zero + else: + return f[0][1] + + +def sdm_to_dict(f): + """Make a dictionary from a distributed polynomial. """ + return dict(f) + + +def sdm_from_dict(d, O): + """ + Create an sdm from a dictionary. + + Here ``O`` is the monomial order to use. + + Examples + ======== + + >>> from sympy.polys.distributedmodules import sdm_from_dict + >>> from sympy.polys import QQ, lex + >>> dic = {(1, 1, 0): QQ(1), (1, 0, 0): QQ(2), (0, 1, 0): QQ(0)} + >>> sdm_from_dict(dic, lex) + [((1, 1, 0), 1), ((1, 0, 0), 2)] + """ + return sdm_strip(sdm_sort(list(d.items()), O)) + + +def sdm_sort(f, O): + """Sort terms in ``f`` using the given monomial order ``O``. """ + return sorted(f, key=lambda term: O(term[0]), reverse=True) + + +def sdm_strip(f): + """Remove terms with zero coefficients from ``f`` in ``K[X]``. """ + return [ (monom, coeff) for monom, coeff in f if coeff ] + + +def sdm_add(f, g, O, K): + """ + Add two module elements ``f``, ``g``. + + Addition is done over the ground field ``K``, monomials are ordered + according to ``O``. + + Examples + ======== + + All examples use lexicographic order. + + `(xy f_1) + (f_2) = f_2 + xy f_1` + + >>> from sympy.polys.distributedmodules import sdm_add + >>> from sympy.polys import lex, QQ + >>> sdm_add([((1, 1, 1), QQ(1))], [((2, 0, 0), QQ(1))], lex, QQ) + [((2, 0, 0), 1), ((1, 1, 1), 1)] + + `(xy f_1) + (-xy f_1)` = 0` + + >>> sdm_add([((1, 1, 1), QQ(1))], [((1, 1, 1), QQ(-1))], lex, QQ) + [] + + `(f_1) + (2f_1) = 3f_1` + + >>> sdm_add([((1, 0, 0), QQ(1))], [((1, 0, 0), QQ(2))], lex, QQ) + [((1, 0, 0), 3)] + + `(yf_1) + (xf_1) = xf_1 + yf_1` + + >>> sdm_add([((1, 0, 1), QQ(1))], [((1, 1, 0), QQ(1))], lex, QQ) + [((1, 1, 0), 1), ((1, 0, 1), 1)] + """ + h = dict(f) + + for monom, c in g: + if monom in h: + coeff = h[monom] + c + + if not coeff: + del h[monom] + else: + h[monom] = coeff + else: + h[monom] = c + + return sdm_from_dict(h, O) + + +def sdm_LM(f): + r""" + Returns the leading monomial of ``f``. + + Only valid if `f \ne 0`. + + Examples + ======== + + >>> from sympy.polys.distributedmodules import sdm_LM, sdm_from_dict + >>> from sympy.polys import QQ, lex + >>> dic = {(1, 2, 3): QQ(1), (4, 0, 0): QQ(1), (4, 0, 1): QQ(1)} + >>> sdm_LM(sdm_from_dict(dic, lex)) + (4, 0, 1) + """ + return f[0][0] + + +def sdm_LT(f): + r""" + Returns the leading term of ``f``. + + Only valid if `f \ne 0`. + + Examples + ======== + + >>> from sympy.polys.distributedmodules import sdm_LT, sdm_from_dict + >>> from sympy.polys import QQ, lex + >>> dic = {(1, 2, 3): QQ(1), (4, 0, 0): QQ(2), (4, 0, 1): QQ(3)} + >>> sdm_LT(sdm_from_dict(dic, lex)) + ((4, 0, 1), 3) + """ + return f[0] + + +def sdm_mul_term(f, term, O, K): + """ + Multiply a distributed module element ``f`` by a (polynomial) term ``term``. + + Multiplication of coefficients is done over the ground field ``K``, and + monomials are ordered according to ``O``. + + Examples + ======== + + `0 f_1 = 0` + + >>> from sympy.polys.distributedmodules import sdm_mul_term + >>> from sympy.polys import lex, QQ + >>> sdm_mul_term([((1, 0, 0), QQ(1))], ((0, 0), QQ(0)), lex, QQ) + [] + + `x 0 = 0` + + >>> sdm_mul_term([], ((1, 0), QQ(1)), lex, QQ) + [] + + `(x) (f_1) = xf_1` + + >>> sdm_mul_term([((1, 0, 0), QQ(1))], ((1, 0), QQ(1)), lex, QQ) + [((1, 1, 0), 1)] + + `(2xy) (3x f_1 + 4y f_2) = 8xy^2 f_2 + 6x^2y f_1` + + >>> f = [((2, 0, 1), QQ(4)), ((1, 1, 0), QQ(3))] + >>> sdm_mul_term(f, ((1, 1), QQ(2)), lex, QQ) + [((2, 1, 2), 8), ((1, 2, 1), 6)] + """ + X, c = term + + if not f or not c: + return [] + else: + if K.is_one(c): + return [ (sdm_monomial_mul(f_M, X), f_c) for f_M, f_c in f ] + else: + return [ (sdm_monomial_mul(f_M, X), f_c * c) for f_M, f_c in f ] + + +def sdm_zero(): + """Return the zero module element.""" + return [] + + +def sdm_deg(f): + """ + Degree of ``f``. + + This is the maximum of the degrees of all its monomials. + Invalid if ``f`` is zero. + + Examples + ======== + + >>> from sympy.polys.distributedmodules import sdm_deg + >>> sdm_deg([((1, 2, 3), 1), ((10, 0, 1), 1), ((2, 3, 4), 4)]) + 7 + """ + return max(sdm_monomial_deg(M[0]) for M in f) + + +# Conversion + +def sdm_from_vector(vec, O, K, **opts): + """ + Create an sdm from an iterable of expressions. + + Coefficients are created in the ground field ``K``, and terms are ordered + according to monomial order ``O``. Named arguments are passed on to the + polys conversion code and can be used to specify for example generators. + + Examples + ======== + + >>> from sympy.polys.distributedmodules import sdm_from_vector + >>> from sympy.abc import x, y, z + >>> from sympy.polys import QQ, lex + >>> sdm_from_vector([x**2+y**2, 2*z], lex, QQ) + [((1, 0, 0, 1), 2), ((0, 2, 0, 0), 1), ((0, 0, 2, 0), 1)] + """ + dics, gens = parallel_dict_from_expr(sympify(vec), **opts) + dic = {} + for i, d in enumerate(dics): + for k, v in d.items(): + dic[(i,) + k] = K.convert(v) + return sdm_from_dict(dic, O) + + +def sdm_to_vector(f, gens, K, n=None): + """ + Convert sdm ``f`` into a list of polynomial expressions. + + The generators for the polynomial ring are specified via ``gens``. The rank + of the module is guessed, or passed via ``n``. The ground field is assumed + to be ``K``. + + Examples + ======== + + >>> from sympy.polys.distributedmodules import sdm_to_vector + >>> from sympy.abc import x, y, z + >>> from sympy.polys import QQ + >>> f = [((1, 0, 0, 1), QQ(2)), ((0, 2, 0, 0), QQ(1)), ((0, 0, 2, 0), QQ(1))] + >>> sdm_to_vector(f, [x, y, z], QQ) + [x**2 + y**2, 2*z] + """ + dic = sdm_to_dict(f) + dics = {} + for k, v in dic.items(): + dics.setdefault(k[0], []).append((k[1:], v)) + n = n or len(dics) + res = [] + for k in range(n): + if k in dics: + res.append(Poly(dict(dics[k]), gens=gens, domain=K).as_expr()) + else: + res.append(S.Zero) + return res + +# Algorithms. + + +def sdm_spoly(f, g, O, K, phantom=None): + """ + Compute the generalized s-polynomial of ``f`` and ``g``. + + The ground field is assumed to be ``K``, and monomials ordered according to + ``O``. + + This is invalid if either of ``f`` or ``g`` is zero. + + If the leading terms of `f` and `g` involve different basis elements of + `F`, their s-poly is defined to be zero. Otherwise it is a certain linear + combination of `f` and `g` in which the leading terms cancel. + See [SCA, defn 2.3.6] for details. + + If ``phantom`` is not ``None``, it should be a pair of module elements on + which to perform the same operation(s) as on ``f`` and ``g``. The in this + case both results are returned. + + Examples + ======== + + >>> from sympy.polys.distributedmodules import sdm_spoly + >>> from sympy.polys import QQ, lex + >>> f = [((2, 1, 1), QQ(1)), ((1, 0, 1), QQ(1))] + >>> g = [((2, 3, 0), QQ(1))] + >>> h = [((1, 2, 3), QQ(1))] + >>> sdm_spoly(f, h, lex, QQ) + [] + >>> sdm_spoly(f, g, lex, QQ) + [((1, 2, 1), 1)] + """ + if not f or not g: + return sdm_zero() + LM1 = sdm_LM(f) + LM2 = sdm_LM(g) + if LM1[0] != LM2[0]: + return sdm_zero() + LM1 = LM1[1:] + LM2 = LM2[1:] + lcm = monomial_lcm(LM1, LM2) + m1 = monomial_div(lcm, LM1) + m2 = monomial_div(lcm, LM2) + c = K.quo(-sdm_LC(f, K), sdm_LC(g, K)) + r1 = sdm_add(sdm_mul_term(f, (m1, K.one), O, K), + sdm_mul_term(g, (m2, c), O, K), O, K) + if phantom is None: + return r1 + r2 = sdm_add(sdm_mul_term(phantom[0], (m1, K.one), O, K), + sdm_mul_term(phantom[1], (m2, c), O, K), O, K) + return r1, r2 + + +def sdm_ecart(f): + """ + Compute the ecart of ``f``. + + This is defined to be the difference of the total degree of `f` and the + total degree of the leading monomial of `f` [SCA, defn 2.3.7]. + + Invalid if f is zero. + + Examples + ======== + + >>> from sympy.polys.distributedmodules import sdm_ecart + >>> sdm_ecart([((1, 2, 3), 1), ((1, 0, 1), 1)]) + 0 + >>> sdm_ecart([((2, 2, 1), 1), ((1, 5, 1), 1)]) + 3 + """ + return sdm_deg(f) - sdm_monomial_deg(sdm_LM(f)) + + +def sdm_nf_mora(f, G, O, K, phantom=None): + r""" + Compute a weak normal form of ``f`` with respect to ``G`` and order ``O``. + + The ground field is assumed to be ``K``, and monomials ordered according to + ``O``. + + Weak normal forms are defined in [SCA, defn 2.3.3]. They are not unique. + This function deterministically computes a weak normal form, depending on + the order of `G`. + + The most important property of a weak normal form is the following: if + `R` is the ring associated with the monomial ordering (if the ordering is + global, we just have `R = K[x_1, \ldots, x_n]`, otherwise it is a certain + localization thereof), `I` any ideal of `R` and `G` a standard basis for + `I`, then for any `f \in R`, we have `f \in I` if and only if + `NF(f | G) = 0`. + + This is the generalized Mora algorithm for computing weak normal forms with + respect to arbitrary monomial orders [SCA, algorithm 2.3.9]. + + If ``phantom`` is not ``None``, it should be a pair of "phantom" arguments + on which to perform the same computations as on ``f``, ``G``, both results + are then returned. + """ + from itertools import repeat + h = f + T = list(G) + if phantom is not None: + # "phantom" variables with suffix p + hp = phantom[0] + Tp = list(phantom[1]) + phantom = True + else: + Tp = repeat([]) + phantom = False + while h: + # TODO better data structure!!! + Th = [(g, sdm_ecart(g), gp) for g, gp in zip(T, Tp) + if sdm_monomial_divides(sdm_LM(g), sdm_LM(h))] + if not Th: + break + g, _, gp = min(Th, key=lambda x: x[1]) + if sdm_ecart(g) > sdm_ecart(h): + T.append(h) + if phantom: + Tp.append(hp) + if phantom: + h, hp = sdm_spoly(h, g, O, K, phantom=(hp, gp)) + else: + h = sdm_spoly(h, g, O, K) + if phantom: + return h, hp + return h + + +def sdm_nf_buchberger(f, G, O, K, phantom=None): + r""" + Compute a weak normal form of ``f`` with respect to ``G`` and order ``O``. + + The ground field is assumed to be ``K``, and monomials ordered according to + ``O``. + + This is the standard Buchberger algorithm for computing weak normal forms with + respect to *global* monomial orders [SCA, algorithm 1.6.10]. + + If ``phantom`` is not ``None``, it should be a pair of "phantom" arguments + on which to perform the same computations as on ``f``, ``G``, both results + are then returned. + """ + from itertools import repeat + h = f + T = list(G) + if phantom is not None: + # "phantom" variables with suffix p + hp = phantom[0] + Tp = list(phantom[1]) + phantom = True + else: + Tp = repeat([]) + phantom = False + while h: + try: + g, gp = next((g, gp) for g, gp in zip(T, Tp) + if sdm_monomial_divides(sdm_LM(g), sdm_LM(h))) + except StopIteration: + break + if phantom: + h, hp = sdm_spoly(h, g, O, K, phantom=(hp, gp)) + else: + h = sdm_spoly(h, g, O, K) + if phantom: + return h, hp + return h + + +def sdm_nf_buchberger_reduced(f, G, O, K): + r""" + Compute a reduced normal form of ``f`` with respect to ``G`` and order ``O``. + + The ground field is assumed to be ``K``, and monomials ordered according to + ``O``. + + In contrast to weak normal forms, reduced normal forms *are* unique, but + their computation is more expensive. + + This is the standard Buchberger algorithm for computing reduced normal forms + with respect to *global* monomial orders [SCA, algorithm 1.6.11]. + + The ``pantom`` option is not supported, so this normal form cannot be used + as a normal form for the "extended" groebner algorithm. + """ + h = sdm_zero() + g = f + while g: + g = sdm_nf_buchberger(g, G, O, K) + if g: + h = sdm_add(h, [sdm_LT(g)], O, K) + g = g[1:] + return h + + +def sdm_groebner(G, NF, O, K, extended=False): + """ + Compute a minimal standard basis of ``G`` with respect to order ``O``. + + The algorithm uses a normal form ``NF``, for example ``sdm_nf_mora``. + The ground field is assumed to be ``K``, and monomials ordered according + to ``O``. + + Let `N` denote the submodule generated by elements of `G`. A standard + basis for `N` is a subset `S` of `N`, such that `in(S) = in(N)`, where for + any subset `X` of `F`, `in(X)` denotes the submodule generated by the + initial forms of elements of `X`. [SCA, defn 2.3.2] + + A standard basis is called minimal if no subset of it is a standard basis. + + One may show that standard bases are always generating sets. + + Minimal standard bases are not unique. This algorithm computes a + deterministic result, depending on the particular order of `G`. + + If ``extended=True``, also compute the transition matrix from the initial + generators to the groebner basis. That is, return a list of coefficient + vectors, expressing the elements of the groebner basis in terms of the + elements of ``G``. + + This functions implements the "sugar" strategy, see + + Giovini et al: "One sugar cube, please" OR Selection strategies in + Buchberger algorithm. + """ + + # The critical pair set. + # A critical pair is stored as (i, j, s, t) where (i, j) defines the pair + # (by indexing S), s is the sugar of the pair, and t is the lcm of their + # leading monomials. + P = [] + + # The eventual standard basis. + S = [] + Sugars = [] + + def Ssugar(i, j): + """Compute the sugar of the S-poly corresponding to (i, j).""" + LMi = sdm_LM(S[i]) + LMj = sdm_LM(S[j]) + return max(Sugars[i] - sdm_monomial_deg(LMi), + Sugars[j] - sdm_monomial_deg(LMj)) \ + + sdm_monomial_deg(sdm_monomial_lcm(LMi, LMj)) + + ourkey = lambda p: (p[2], O(p[3]), p[1]) + + def update(f, sugar, P): + """Add f with sugar ``sugar`` to S, update P.""" + if not f: + return P + k = len(S) + S.append(f) + Sugars.append(sugar) + + LMf = sdm_LM(f) + + def removethis(pair): + i, j, s, t = pair + if LMf[0] != t[0]: + return False + tik = sdm_monomial_lcm(LMf, sdm_LM(S[i])) + tjk = sdm_monomial_lcm(LMf, sdm_LM(S[j])) + return tik != t and tjk != t and sdm_monomial_divides(tik, t) and \ + sdm_monomial_divides(tjk, t) + # apply the chain criterion + P = [p for p in P if not removethis(p)] + + # new-pair set + N = [(i, k, Ssugar(i, k), sdm_monomial_lcm(LMf, sdm_LM(S[i]))) + for i in range(k) if LMf[0] == sdm_LM(S[i])[0]] + # TODO apply the product criterion? + N.sort(key=ourkey) + remove = set() + for i, p in enumerate(N): + for j in range(i + 1, len(N)): + if sdm_monomial_divides(p[3], N[j][3]): + remove.add(j) + + # TODO mergesort? + P.extend(reversed([p for i, p in enumerate(N) if i not in remove])) + P.sort(key=ourkey, reverse=True) + # NOTE reverse-sort, because we want to pop from the end + return P + + # Figure out the number of generators in the ground ring. + try: + # NOTE: we look for the first non-zero vector, take its first monomial + # the number of generators in the ring is one less than the length + # (since the zeroth entry is for the module generators) + numgens = len(next(x[0] for x in G if x)[0]) - 1 + except StopIteration: + # No non-zero elements in G ... + if extended: + return [], [] + return [] + + # This list will store expressions of the elements of S in terms of the + # initial generators + coefficients = [] + + # First add all the elements of G to S + for i, f in enumerate(G): + P = update(f, sdm_deg(f), P) + if extended and f: + coefficients.append(sdm_from_dict({(i,) + (0,)*numgens: K(1)}, O)) + + # Now carry out the buchberger algorithm. + while P: + i, j, s, t = P.pop() + f, g = S[i], S[j] + if extended: + sp, coeff = sdm_spoly(f, g, O, K, + phantom=(coefficients[i], coefficients[j])) + h, hcoeff = NF(sp, S, O, K, phantom=(coeff, coefficients)) + if h: + coefficients.append(hcoeff) + else: + h = NF(sdm_spoly(f, g, O, K), S, O, K) + P = update(h, Ssugar(i, j), P) + + # Finally interreduce the standard basis. + # (TODO again, better data structures) + S = {(tuple(f), i) for i, f in enumerate(S)} + for (a, ai), (b, bi) in permutations(S, 2): + A = sdm_LM(a) + B = sdm_LM(b) + if sdm_monomial_divides(A, B) and (b, bi) in S and (a, ai) in S: + S.remove((b, bi)) + + L = sorted(((list(f), i) for f, i in S), key=lambda p: O(sdm_LM(p[0])), + reverse=True) + res = [x[0] for x in L] + if extended: + return res, [coefficients[i] for _, i in L] + return res diff --git a/phivenv/Lib/site-packages/sympy/polys/domainmatrix.py b/phivenv/Lib/site-packages/sympy/polys/domainmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..c0ccaaa4cb96e0c49da58d8e9128c1b6fa551ade --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/domainmatrix.py @@ -0,0 +1,12 @@ +""" +Stub module to expose DomainMatrix which has now moved to +sympy.polys.matrices package. It should now be imported as: + + >>> from sympy.polys.matrices import DomainMatrix + +This module might be removed in future. +""" + +from sympy.polys.matrices.domainmatrix import DomainMatrix + +__all__ = ['DomainMatrix'] diff --git a/phivenv/Lib/site-packages/sympy/polys/euclidtools.py b/phivenv/Lib/site-packages/sympy/polys/euclidtools.py new file mode 100644 index 0000000000000000000000000000000000000000..768a44a94930f05e701e9f27a8b0f570a3312314 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/euclidtools.py @@ -0,0 +1,1912 @@ +"""Euclidean algorithms, GCDs, LCMs and polynomial remainder sequences. """ + + +from sympy.polys.densearith import ( + dup_sub_mul, + dup_neg, dmp_neg, + dmp_add, + dmp_sub, + dup_mul, dmp_mul, + dmp_pow, + dup_div, dmp_div, + dup_rem, + dup_quo, dmp_quo, + dup_prem, dmp_prem, + dup_mul_ground, dmp_mul_ground, + dmp_mul_term, + dup_quo_ground, dmp_quo_ground, + dup_max_norm, dmp_max_norm) +from sympy.polys.densebasic import ( + dup_strip, dmp_raise, + dmp_zero, dmp_one, dmp_ground, + dmp_one_p, dmp_zero_p, + dmp_zeros, + dup_degree, dmp_degree, dmp_degree_in, + dup_LC, dmp_LC, dmp_ground_LC, + dmp_multi_deflate, dmp_inflate, + dup_convert, dmp_convert, + dmp_apply_pairs) +from sympy.polys.densetools import ( + dup_clear_denoms, dmp_clear_denoms, + dup_diff, dmp_diff, + dup_eval, dmp_eval, dmp_eval_in, + dup_trunc, dmp_ground_trunc, + dup_monic, dmp_ground_monic, + dup_primitive, dmp_ground_primitive, + dup_extract, dmp_ground_extract) +from sympy.polys.galoistools import ( + gf_int, gf_crt) +from sympy.polys.polyconfig import query +from sympy.polys.polyerrors import ( + MultivariatePolynomialError, + HeuristicGCDFailed, + HomomorphismFailed, + NotInvertible, + DomainError) + + + + +def dup_half_gcdex(f, g, K): + """ + Half extended Euclidean algorithm in `F[x]`. + + Returns ``(s, h)`` such that ``h = gcd(f, g)`` and ``s*f = h (mod g)``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + >>> f = x**4 - 2*x**3 - 6*x**2 + 12*x + 15 + >>> g = x**3 + x**2 - 4*x - 4 + + >>> R.dup_half_gcdex(f, g) + (-1/5*x + 3/5, x + 1) + + """ + if not K.is_Field: + raise DomainError("Cannot compute half extended GCD over %s" % K) + + a, b = [K.one], [] + + while g: + q, r = dup_div(f, g, K) + f, g = g, r + a, b = b, dup_sub_mul(a, q, b, K) + + a = dup_quo_ground(a, dup_LC(f, K), K) + f = dup_monic(f, K) + + return a, f + + +def dmp_half_gcdex(f, g, u, K): + """ + Half extended Euclidean algorithm in `F[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + """ + if not u: + return dup_half_gcdex(f, g, K) + else: + raise MultivariatePolynomialError(f, g) + + +def dup_gcdex(f, g, K): + """ + Extended Euclidean algorithm in `F[x]`. + + Returns ``(s, t, h)`` such that ``h = gcd(f, g)`` and ``s*f + t*g = h``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + >>> f = x**4 - 2*x**3 - 6*x**2 + 12*x + 15 + >>> g = x**3 + x**2 - 4*x - 4 + + >>> R.dup_gcdex(f, g) + (-1/5*x + 3/5, 1/5*x**2 - 6/5*x + 2, x + 1) + + """ + s, h = dup_half_gcdex(f, g, K) + + F = dup_sub_mul(h, s, f, K) + t = dup_quo(F, g, K) + + return s, t, h + + +def dmp_gcdex(f, g, u, K): + """ + Extended Euclidean algorithm in `F[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + """ + if not u: + return dup_gcdex(f, g, K) + else: + raise MultivariatePolynomialError(f, g) + + +def dup_invert(f, g, K): + """ + Compute multiplicative inverse of `f` modulo `g` in `F[x]`. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + >>> f = x**2 - 1 + >>> g = 2*x - 1 + >>> h = x - 1 + + >>> R.dup_invert(f, g) + -4/3 + + >>> R.dup_invert(f, h) + Traceback (most recent call last): + ... + NotInvertible: zero divisor + + """ + s, h = dup_half_gcdex(f, g, K) + + if h == [K.one]: + return dup_rem(s, g, K) + else: + raise NotInvertible("zero divisor") + + +def dmp_invert(f, g, u, K): + """ + Compute multiplicative inverse of `f` modulo `g` in `F[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + """ + if not u: + return dup_invert(f, g, K) + else: + raise MultivariatePolynomialError(f, g) + + +def dup_euclidean_prs(f, g, K): + """ + Euclidean polynomial remainder sequence (PRS) in `K[x]`. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + >>> f = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + >>> g = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + + >>> prs = R.dup_euclidean_prs(f, g) + + >>> prs[0] + x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + >>> prs[1] + 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + >>> prs[2] + -5/9*x**4 + 1/9*x**2 - 1/3 + >>> prs[3] + -117/25*x**2 - 9*x + 441/25 + >>> prs[4] + 233150/19773*x - 102500/6591 + >>> prs[5] + -1288744821/543589225 + + """ + prs = [f, g] + h = dup_rem(f, g, K) + + while h: + prs.append(h) + f, g = g, h + h = dup_rem(f, g, K) + + return prs + + +def dmp_euclidean_prs(f, g, u, K): + """ + Euclidean polynomial remainder sequence (PRS) in `K[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + """ + if not u: + return dup_euclidean_prs(f, g, K) + else: + raise MultivariatePolynomialError(f, g) + + +def dup_primitive_prs(f, g, K): + """ + Primitive polynomial remainder sequence (PRS) in `K[x]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> f = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + >>> g = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + + >>> prs = R.dup_primitive_prs(f, g) + + >>> prs[0] + x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + >>> prs[1] + 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + >>> prs[2] + -5*x**4 + x**2 - 3 + >>> prs[3] + 13*x**2 + 25*x - 49 + >>> prs[4] + 4663*x - 6150 + >>> prs[5] + 1 + + """ + prs = [f, g] + _, h = dup_primitive(dup_prem(f, g, K), K) + + while h: + prs.append(h) + f, g = g, h + _, h = dup_primitive(dup_prem(f, g, K), K) + + return prs + + +def dmp_primitive_prs(f, g, u, K): + """ + Primitive polynomial remainder sequence (PRS) in `K[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + """ + if not u: + return dup_primitive_prs(f, g, K) + else: + raise MultivariatePolynomialError(f, g) + + +def dup_inner_subresultants(f, g, K): + """ + Subresultant PRS algorithm in `K[x]`. + + Computes the subresultant polynomial remainder sequence (PRS) + and the non-zero scalar subresultants of `f` and `g`. + By [1] Thm. 3, these are the constants '-c' (- to optimize + computation of sign). + The first subdeterminant is set to 1 by convention to match + the polynomial and the scalar subdeterminants. + If 'deg(f) < deg(g)', the subresultants of '(g,f)' are computed. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_inner_subresultants(x**2 + 1, x**2 - 1) + ([x**2 + 1, x**2 - 1, -2], [1, 1, 4]) + + References + ========== + + .. [1] W.S. Brown, The Subresultant PRS Algorithm. + ACM Transaction of Mathematical Software 4 (1978) 237-249 + + """ + n = dup_degree(f) + m = dup_degree(g) + + if n < m: + f, g = g, f + n, m = m, n + + if not f: + return [], [] + + if not g: + return [f], [K.one] + + R = [f, g] + d = n - m + + b = (-K.one)**(d + 1) + + h = dup_prem(f, g, K) + h = dup_mul_ground(h, b, K) + + lc = dup_LC(g, K) + c = lc**d + + # Conventional first scalar subdeterminant is 1 + S = [K.one, c] + c = -c + + while h: + k = dup_degree(h) + R.append(h) + + f, g, m, d = g, h, k, m - k + + b = -lc * c**d + + h = dup_prem(f, g, K) + h = dup_quo_ground(h, b, K) + + lc = dup_LC(g, K) + + if d > 1: # abnormal case + q = c**(d - 1) + c = K.quo((-lc)**d, q) + else: + c = -lc + + S.append(-c) + + return R, S + + +def dup_subresultants(f, g, K): + """ + Computes subresultant PRS of two polynomials in `K[x]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_subresultants(x**2 + 1, x**2 - 1) + [x**2 + 1, x**2 - 1, -2] + + """ + return dup_inner_subresultants(f, g, K)[0] + + +def dup_prs_resultant(f, g, K): + """ + Resultant algorithm in `K[x]` using subresultant PRS. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_prs_resultant(x**2 + 1, x**2 - 1) + (4, [x**2 + 1, x**2 - 1, -2]) + + """ + if not f or not g: + return (K.zero, []) + + R, S = dup_inner_subresultants(f, g, K) + + if dup_degree(R[-1]) > 0: + return (K.zero, R) + + return S[-1], R + + +def dup_resultant(f, g, K, includePRS=False): + """ + Computes resultant of two polynomials in `K[x]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_resultant(x**2 + 1, x**2 - 1) + 4 + + """ + if includePRS: + return dup_prs_resultant(f, g, K) + return dup_prs_resultant(f, g, K)[0] + + +def dmp_inner_subresultants(f, g, u, K): + """ + Subresultant PRS algorithm in `K[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = 3*x**2*y - y**3 - 4 + >>> g = x**2 + x*y**3 - 9 + + >>> a = 3*x*y**4 + y**3 - 27*y + 4 + >>> b = -3*y**10 - 12*y**7 + y**6 - 54*y**4 + 8*y**3 + 729*y**2 - 216*y + 16 + + >>> prs = [f, g, a, b] + >>> sres = [[1], [1], [3, 0, 0, 0, 0], [-3, 0, 0, -12, 1, 0, -54, 8, 729, -216, 16]] + + >>> R.dmp_inner_subresultants(f, g) == (prs, sres) + True + + """ + if not u: + return dup_inner_subresultants(f, g, K) + + n = dmp_degree(f, u) + m = dmp_degree(g, u) + + if n < m: + f, g = g, f + n, m = m, n + + if dmp_zero_p(f, u): + return [], [] + + v = u - 1 + if dmp_zero_p(g, u): + return [f], [dmp_ground(K.one, v)] + + R = [f, g] + d = n - m + + b = dmp_pow(dmp_ground(-K.one, v), d + 1, v, K) + + h = dmp_prem(f, g, u, K) + h = dmp_mul_term(h, b, 0, u, K) + + lc = dmp_LC(g, K) + c = dmp_pow(lc, d, v, K) + + S = [dmp_ground(K.one, v), c] + c = dmp_neg(c, v, K) + + while not dmp_zero_p(h, u): + k = dmp_degree(h, u) + R.append(h) + + f, g, m, d = g, h, k, m - k + + b = dmp_mul(dmp_neg(lc, v, K), + dmp_pow(c, d, v, K), v, K) + + h = dmp_prem(f, g, u, K) + h = [ dmp_quo(ch, b, v, K) for ch in h ] + + lc = dmp_LC(g, K) + + if d > 1: + p = dmp_pow(dmp_neg(lc, v, K), d, v, K) + q = dmp_pow(c, d - 1, v, K) + c = dmp_quo(p, q, v, K) + else: + c = dmp_neg(lc, v, K) + + S.append(dmp_neg(c, v, K)) + + return R, S + + +def dmp_subresultants(f, g, u, K): + """ + Computes subresultant PRS of two polynomials in `K[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = 3*x**2*y - y**3 - 4 + >>> g = x**2 + x*y**3 - 9 + + >>> a = 3*x*y**4 + y**3 - 27*y + 4 + >>> b = -3*y**10 - 12*y**7 + y**6 - 54*y**4 + 8*y**3 + 729*y**2 - 216*y + 16 + + >>> R.dmp_subresultants(f, g) == [f, g, a, b] + True + + """ + return dmp_inner_subresultants(f, g, u, K)[0] + + +def dmp_prs_resultant(f, g, u, K): + """ + Resultant algorithm in `K[X]` using subresultant PRS. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = 3*x**2*y - y**3 - 4 + >>> g = x**2 + x*y**3 - 9 + + >>> a = 3*x*y**4 + y**3 - 27*y + 4 + >>> b = -3*y**10 - 12*y**7 + y**6 - 54*y**4 + 8*y**3 + 729*y**2 - 216*y + 16 + + >>> res, prs = R.dmp_prs_resultant(f, g) + + >>> res == b # resultant has n-1 variables + False + >>> res == b.drop(x) + True + >>> prs == [f, g, a, b] + True + + """ + if not u: + return dup_prs_resultant(f, g, K) + + if dmp_zero_p(f, u) or dmp_zero_p(g, u): + return (dmp_zero(u - 1), []) + + R, S = dmp_inner_subresultants(f, g, u, K) + + if dmp_degree(R[-1], u) > 0: + return (dmp_zero(u - 1), R) + + return S[-1], R + + +def dmp_zz_modular_resultant(f, g, p, u, K): + """ + Compute resultant of `f` and `g` modulo a prime `p`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = x + y + 2 + >>> g = 2*x*y + x + 3 + + >>> R.dmp_zz_modular_resultant(f, g, 5) + -2*y**2 + 1 + + """ + if not u: + return gf_int(dup_prs_resultant(f, g, K)[0] % p, p) + + v = u - 1 + + n = dmp_degree(f, u) + m = dmp_degree(g, u) + + N = dmp_degree_in(f, 1, u) + M = dmp_degree_in(g, 1, u) + + B = n*M + m*N + + D, a = [K.one], -K.one + r = dmp_zero(v) + + while dup_degree(D) <= B: + while True: + a += K.one + + if a == p: + raise HomomorphismFailed('no luck') + + F = dmp_eval_in(f, gf_int(a, p), 1, u, K) + + if dmp_degree(F, v) == n: + G = dmp_eval_in(g, gf_int(a, p), 1, u, K) + + if dmp_degree(G, v) == m: + break + + R = dmp_zz_modular_resultant(F, G, p, v, K) + e = dmp_eval(r, a, v, K) + + if not v: + R = dup_strip([R]) + e = dup_strip([e]) + else: + R = [R] + e = [e] + + d = K.invert(dup_eval(D, a, K), p) + d = dup_mul_ground(D, d, K) + d = dmp_raise(d, v, 0, K) + + c = dmp_mul(d, dmp_sub(R, e, v, K), v, K) + r = dmp_add(r, c, v, K) + + r = dmp_ground_trunc(r, p, v, K) + + D = dup_mul(D, [K.one, -a], K) + D = dup_trunc(D, p, K) + + return r + + +def _collins_crt(r, R, P, p, K): + """Wrapper of CRT for Collins's resultant algorithm. """ + return gf_int(gf_crt([r, R], [P, p], K), P*p) + + +def dmp_zz_collins_resultant(f, g, u, K): + """ + Collins's modular resultant algorithm in `Z[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = x + y + 2 + >>> g = 2*x*y + x + 3 + + >>> R.dmp_zz_collins_resultant(f, g) + -2*y**2 - 5*y + 1 + + """ + + n = dmp_degree(f, u) + m = dmp_degree(g, u) + + if n < 0 or m < 0: + return dmp_zero(u - 1) + + A = dmp_max_norm(f, u, K) + B = dmp_max_norm(g, u, K) + + a = dmp_ground_LC(f, u, K) + b = dmp_ground_LC(g, u, K) + + v = u - 1 + + B = K(2)*K.factorial(K(n + m))*A**m*B**n + r, p, P = dmp_zero(v), K.one, K.one + + from sympy.ntheory import nextprime + + while P <= B: + p = K(nextprime(p)) + + while not (a % p) or not (b % p): + p = K(nextprime(p)) + + F = dmp_ground_trunc(f, p, u, K) + G = dmp_ground_trunc(g, p, u, K) + + try: + R = dmp_zz_modular_resultant(F, G, p, u, K) + except HomomorphismFailed: + continue + + if K.is_one(P): + r = R + else: + r = dmp_apply_pairs(r, R, _collins_crt, (P, p, K), v, K) + + P *= p + + return r + + +def dmp_qq_collins_resultant(f, g, u, K0): + """ + Collins's modular resultant algorithm in `Q[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x,y = ring("x,y", QQ) + + >>> f = QQ(1,2)*x + y + QQ(2,3) + >>> g = 2*x*y + x + 3 + + >>> R.dmp_qq_collins_resultant(f, g) + -2*y**2 - 7/3*y + 5/6 + + """ + n = dmp_degree(f, u) + m = dmp_degree(g, u) + + if n < 0 or m < 0: + return dmp_zero(u - 1) + + K1 = K0.get_ring() + + cf, f = dmp_clear_denoms(f, u, K0, K1) + cg, g = dmp_clear_denoms(g, u, K0, K1) + + f = dmp_convert(f, u, K0, K1) + g = dmp_convert(g, u, K0, K1) + + r = dmp_zz_collins_resultant(f, g, u, K1) + r = dmp_convert(r, u - 1, K1, K0) + + c = K0.convert(cf**m * cg**n, K1) + + return dmp_quo_ground(r, c, u - 1, K0) + + +def dmp_resultant(f, g, u, K, includePRS=False): + """ + Computes resultant of two polynomials in `K[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> f = 3*x**2*y - y**3 - 4 + >>> g = x**2 + x*y**3 - 9 + + >>> R.dmp_resultant(f, g) + -3*y**10 - 12*y**7 + y**6 - 54*y**4 + 8*y**3 + 729*y**2 - 216*y + 16 + + """ + if not u: + return dup_resultant(f, g, K, includePRS=includePRS) + + if includePRS: + return dmp_prs_resultant(f, g, u, K) + + if K.is_Field: + if K.is_QQ and query('USE_COLLINS_RESULTANT'): + return dmp_qq_collins_resultant(f, g, u, K) + else: + if K.is_ZZ and query('USE_COLLINS_RESULTANT'): + return dmp_zz_collins_resultant(f, g, u, K) + + return dmp_prs_resultant(f, g, u, K)[0] + + +def dup_discriminant(f, K): + """ + Computes discriminant of a polynomial in `K[x]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_discriminant(x**2 + 2*x + 3) + -8 + + """ + d = dup_degree(f) + + if d <= 0: + return K.zero + else: + s = (-1)**((d*(d - 1)) // 2) + c = dup_LC(f, K) + + r = dup_resultant(f, dup_diff(f, 1, K), K) + + return K.quo(r, c*K(s)) + + +def dmp_discriminant(f, u, K): + """ + Computes discriminant of a polynomial in `K[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y,z,t = ring("x,y,z,t", ZZ) + + >>> R.dmp_discriminant(x**2*y + x*z + t) + -4*y*t + z**2 + + """ + if not u: + return dup_discriminant(f, K) + + d, v = dmp_degree(f, u), u - 1 + + if d <= 0: + return dmp_zero(v) + else: + s = (-1)**((d*(d - 1)) // 2) + c = dmp_LC(f, K) + + r = dmp_resultant(f, dmp_diff(f, 1, u, K), u, K) + c = dmp_mul_ground(c, K(s), v, K) + + return dmp_quo(r, c, v, K) + + +def _dup_rr_trivial_gcd(f, g, K): + """Handle trivial cases in GCD algorithm over a ring. """ + if not (f or g): + return [], [], [] + elif not f: + if K.is_nonnegative(dup_LC(g, K)): + return g, [], [K.one] + else: + return dup_neg(g, K), [], [-K.one] + elif not g: + if K.is_nonnegative(dup_LC(f, K)): + return f, [K.one], [] + else: + return dup_neg(f, K), [-K.one], [] + + return None + + +def _dup_ff_trivial_gcd(f, g, K): + """Handle trivial cases in GCD algorithm over a field. """ + if not (f or g): + return [], [], [] + elif not f: + return dup_monic(g, K), [], [dup_LC(g, K)] + elif not g: + return dup_monic(f, K), [dup_LC(f, K)], [] + else: + return None + + +def _dmp_rr_trivial_gcd(f, g, u, K): + """Handle trivial cases in GCD algorithm over a ring. """ + zero_f = dmp_zero_p(f, u) + zero_g = dmp_zero_p(g, u) + if_contain_one = dmp_one_p(f, u, K) or dmp_one_p(g, u, K) + + if zero_f and zero_g: + return tuple(dmp_zeros(3, u, K)) + elif zero_f: + if K.is_nonnegative(dmp_ground_LC(g, u, K)): + return g, dmp_zero(u), dmp_one(u, K) + else: + return dmp_neg(g, u, K), dmp_zero(u), dmp_ground(-K.one, u) + elif zero_g: + if K.is_nonnegative(dmp_ground_LC(f, u, K)): + return f, dmp_one(u, K), dmp_zero(u) + else: + return dmp_neg(f, u, K), dmp_ground(-K.one, u), dmp_zero(u) + elif if_contain_one: + return dmp_one(u, K), f, g + elif query('USE_SIMPLIFY_GCD'): + return _dmp_simplify_gcd(f, g, u, K) + else: + return None + + +def _dmp_ff_trivial_gcd(f, g, u, K): + """Handle trivial cases in GCD algorithm over a field. """ + zero_f = dmp_zero_p(f, u) + zero_g = dmp_zero_p(g, u) + + if zero_f and zero_g: + return tuple(dmp_zeros(3, u, K)) + elif zero_f: + return (dmp_ground_monic(g, u, K), + dmp_zero(u), + dmp_ground(dmp_ground_LC(g, u, K), u)) + elif zero_g: + return (dmp_ground_monic(f, u, K), + dmp_ground(dmp_ground_LC(f, u, K), u), + dmp_zero(u)) + elif query('USE_SIMPLIFY_GCD'): + return _dmp_simplify_gcd(f, g, u, K) + else: + return None + + +def _dmp_simplify_gcd(f, g, u, K): + """Try to eliminate `x_0` from GCD computation in `K[X]`. """ + df = dmp_degree(f, u) + dg = dmp_degree(g, u) + + if df > 0 and dg > 0: + return None + + if not (df or dg): + F = dmp_LC(f, K) + G = dmp_LC(g, K) + else: + if not df: + F = dmp_LC(f, K) + G = dmp_content(g, u, K) + else: + F = dmp_content(f, u, K) + G = dmp_LC(g, K) + + v = u - 1 + h = dmp_gcd(F, G, v, K) + + cff = [ dmp_quo(cf, h, v, K) for cf in f ] + cfg = [ dmp_quo(cg, h, v, K) for cg in g ] + + return [h], cff, cfg + + +def dup_rr_prs_gcd(f, g, K): + """ + Computes polynomial GCD using subresultants over a ring. + + Returns ``(h, cff, cfg)`` such that ``a = gcd(f, g)``, ``cff = quo(f, h)``, + and ``cfg = quo(g, h)``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_rr_prs_gcd(x**2 - 1, x**2 - 3*x + 2) + (x - 1, x + 1, x - 2) + + """ + result = _dup_rr_trivial_gcd(f, g, K) + + if result is not None: + return result + + fc, F = dup_primitive(f, K) + gc, G = dup_primitive(g, K) + + c = K.gcd(fc, gc) + + h = dup_subresultants(F, G, K)[-1] + _, h = dup_primitive(h, K) + + c *= K.canonical_unit(dup_LC(h, K)) + + h = dup_mul_ground(h, c, K) + + cff = dup_quo(f, h, K) + cfg = dup_quo(g, h, K) + + return h, cff, cfg + + +def dup_ff_prs_gcd(f, g, K): + """ + Computes polynomial GCD using subresultants over a field. + + Returns ``(h, cff, cfg)`` such that ``a = gcd(f, g)``, ``cff = quo(f, h)``, + and ``cfg = quo(g, h)``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + >>> R.dup_ff_prs_gcd(x**2 - 1, x**2 - 3*x + 2) + (x - 1, x + 1, x - 2) + + """ + result = _dup_ff_trivial_gcd(f, g, K) + + if result is not None: + return result + + h = dup_subresultants(f, g, K)[-1] + h = dup_monic(h, K) + + cff = dup_quo(f, h, K) + cfg = dup_quo(g, h, K) + + return h, cff, cfg + + +def dmp_rr_prs_gcd(f, g, u, K): + """ + Computes polynomial GCD using subresultants over a ring. + + Returns ``(h, cff, cfg)`` such that ``a = gcd(f, g)``, ``cff = quo(f, h)``, + and ``cfg = quo(g, h)``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y, = ring("x,y", ZZ) + + >>> f = x**2 + 2*x*y + y**2 + >>> g = x**2 + x*y + + >>> R.dmp_rr_prs_gcd(f, g) + (x + y, x + y, x) + + """ + if not u: + return dup_rr_prs_gcd(f, g, K) + + result = _dmp_rr_trivial_gcd(f, g, u, K) + + if result is not None: + return result + + fc, F = dmp_primitive(f, u, K) + gc, G = dmp_primitive(g, u, K) + + h = dmp_subresultants(F, G, u, K)[-1] + c, _, _ = dmp_rr_prs_gcd(fc, gc, u - 1, K) + + _, h = dmp_primitive(h, u, K) + h = dmp_mul_term(h, c, 0, u, K) + + unit = K.canonical_unit(dmp_ground_LC(h, u, K)) + + if unit != K.one: + h = dmp_mul_ground(h, unit, u, K) + + cff = dmp_quo(f, h, u, K) + cfg = dmp_quo(g, h, u, K) + + return h, cff, cfg + + +def dmp_ff_prs_gcd(f, g, u, K): + """ + Computes polynomial GCD using subresultants over a field. + + Returns ``(h, cff, cfg)`` such that ``a = gcd(f, g)``, ``cff = quo(f, h)``, + and ``cfg = quo(g, h)``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x,y, = ring("x,y", QQ) + + >>> f = QQ(1,2)*x**2 + x*y + QQ(1,2)*y**2 + >>> g = x**2 + x*y + + >>> R.dmp_ff_prs_gcd(f, g) + (x + y, 1/2*x + 1/2*y, x) + + """ + if not u: + return dup_ff_prs_gcd(f, g, K) + + result = _dmp_ff_trivial_gcd(f, g, u, K) + + if result is not None: + return result + + fc, F = dmp_primitive(f, u, K) + gc, G = dmp_primitive(g, u, K) + + h = dmp_subresultants(F, G, u, K)[-1] + c, _, _ = dmp_ff_prs_gcd(fc, gc, u - 1, K) + + _, h = dmp_primitive(h, u, K) + h = dmp_mul_term(h, c, 0, u, K) + h = dmp_ground_monic(h, u, K) + + cff = dmp_quo(f, h, u, K) + cfg = dmp_quo(g, h, u, K) + + return h, cff, cfg + +HEU_GCD_MAX = 6 + + +def _dup_zz_gcd_interpolate(h, x, K): + """Interpolate polynomial GCD from integer GCD. """ + f = [] + + while h: + g = h % x + + if g > x // 2: + g -= x + + f.insert(0, g) + h = (h - g) // x + + return f + + +def dup_zz_heu_gcd(f, g, K): + """ + Heuristic polynomial GCD in `Z[x]`. + + Given univariate polynomials `f` and `g` in `Z[x]`, returns + their GCD and cofactors, i.e. polynomials ``h``, ``cff`` and ``cfg`` + such that:: + + h = gcd(f, g), cff = quo(f, h) and cfg = quo(g, h) + + The algorithm is purely heuristic which means it may fail to compute + the GCD. This will be signaled by raising an exception. In this case + you will need to switch to another GCD method. + + The algorithm computes the polynomial GCD by evaluating polynomials + f and g at certain points and computing (fast) integer GCD of those + evaluations. The polynomial GCD is recovered from the integer image + by interpolation. The final step is to verify if the result is the + correct GCD. This gives cofactors as a side effect. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_zz_heu_gcd(x**2 - 1, x**2 - 3*x + 2) + (x - 1, x + 1, x - 2) + + References + ========== + + .. [1] [Liao95]_ + + """ + result = _dup_rr_trivial_gcd(f, g, K) + + if result is not None: + return result + + df = dup_degree(f) + dg = dup_degree(g) + + gcd, f, g = dup_extract(f, g, K) + + if df == 0 or dg == 0: + return [gcd], f, g + + f_norm = dup_max_norm(f, K) + g_norm = dup_max_norm(g, K) + + B = K(2*min(f_norm, g_norm) + 29) + + x = max(min(B, 99*K.sqrt(B)), + 2*min(f_norm // abs(dup_LC(f, K)), + g_norm // abs(dup_LC(g, K))) + 4) + + for i in range(0, HEU_GCD_MAX): + ff = dup_eval(f, x, K) + gg = dup_eval(g, x, K) + + if ff and gg: + h = K.gcd(ff, gg) + + cff = ff // h + cfg = gg // h + + h = _dup_zz_gcd_interpolate(h, x, K) + h = dup_primitive(h, K)[1] + + cff_, r = dup_div(f, h, K) + + if not r: + cfg_, r = dup_div(g, h, K) + + if not r: + h = dup_mul_ground(h, gcd, K) + return h, cff_, cfg_ + + cff = _dup_zz_gcd_interpolate(cff, x, K) + + h, r = dup_div(f, cff, K) + + if not r: + cfg_, r = dup_div(g, h, K) + + if not r: + h = dup_mul_ground(h, gcd, K) + return h, cff, cfg_ + + cfg = _dup_zz_gcd_interpolate(cfg, x, K) + + h, r = dup_div(g, cfg, K) + + if not r: + cff_, r = dup_div(f, h, K) + + if not r: + h = dup_mul_ground(h, gcd, K) + return h, cff_, cfg + + x = 73794*x * K.sqrt(K.sqrt(x)) // 27011 + + raise HeuristicGCDFailed('no luck') + + +def _dmp_zz_gcd_interpolate(h, x, v, K): + """Interpolate polynomial GCD from integer GCD. """ + f = [] + + while not dmp_zero_p(h, v): + g = dmp_ground_trunc(h, x, v, K) + f.insert(0, g) + + h = dmp_sub(h, g, v, K) + h = dmp_quo_ground(h, x, v, K) + + if K.is_negative(dmp_ground_LC(f, v + 1, K)): + return dmp_neg(f, v + 1, K) + else: + return f + + +def dmp_zz_heu_gcd(f, g, u, K): + """ + Heuristic polynomial GCD in `Z[X]`. + + Given univariate polynomials `f` and `g` in `Z[X]`, returns + their GCD and cofactors, i.e. polynomials ``h``, ``cff`` and ``cfg`` + such that:: + + h = gcd(f, g), cff = quo(f, h) and cfg = quo(g, h) + + The algorithm is purely heuristic which means it may fail to compute + the GCD. This will be signaled by raising an exception. In this case + you will need to switch to another GCD method. + + The algorithm computes the polynomial GCD by evaluating polynomials + f and g at certain points and computing (fast) integer GCD of those + evaluations. The polynomial GCD is recovered from the integer image + by interpolation. The evaluation process reduces f and g variable by + variable into a large integer. The final step is to verify if the + interpolated polynomial is the correct GCD. This gives cofactors of + the input polynomials as a side effect. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y, = ring("x,y", ZZ) + + >>> f = x**2 + 2*x*y + y**2 + >>> g = x**2 + x*y + + >>> R.dmp_zz_heu_gcd(f, g) + (x + y, x + y, x) + + References + ========== + + .. [1] [Liao95]_ + + """ + if not u: + return dup_zz_heu_gcd(f, g, K) + + result = _dmp_rr_trivial_gcd(f, g, u, K) + + if result is not None: + return result + + gcd, f, g = dmp_ground_extract(f, g, u, K) + + f_norm = dmp_max_norm(f, u, K) + g_norm = dmp_max_norm(g, u, K) + + B = K(2*min(f_norm, g_norm) + 29) + + x = max(min(B, 99*K.sqrt(B)), + 2*min(f_norm // abs(dmp_ground_LC(f, u, K)), + g_norm // abs(dmp_ground_LC(g, u, K))) + 4) + + for i in range(0, HEU_GCD_MAX): + ff = dmp_eval(f, x, u, K) + gg = dmp_eval(g, x, u, K) + + v = u - 1 + + if not (dmp_zero_p(ff, v) or dmp_zero_p(gg, v)): + h, cff, cfg = dmp_zz_heu_gcd(ff, gg, v, K) + + h = _dmp_zz_gcd_interpolate(h, x, v, K) + h = dmp_ground_primitive(h, u, K)[1] + + cff_, r = dmp_div(f, h, u, K) + + if dmp_zero_p(r, u): + cfg_, r = dmp_div(g, h, u, K) + + if dmp_zero_p(r, u): + h = dmp_mul_ground(h, gcd, u, K) + return h, cff_, cfg_ + + cff = _dmp_zz_gcd_interpolate(cff, x, v, K) + + h, r = dmp_div(f, cff, u, K) + + if dmp_zero_p(r, u): + cfg_, r = dmp_div(g, h, u, K) + + if dmp_zero_p(r, u): + h = dmp_mul_ground(h, gcd, u, K) + return h, cff, cfg_ + + cfg = _dmp_zz_gcd_interpolate(cfg, x, v, K) + + h, r = dmp_div(g, cfg, u, K) + + if dmp_zero_p(r, u): + cff_, r = dmp_div(f, h, u, K) + + if dmp_zero_p(r, u): + h = dmp_mul_ground(h, gcd, u, K) + return h, cff_, cfg + + x = 73794*x * K.sqrt(K.sqrt(x)) // 27011 + + raise HeuristicGCDFailed('no luck') + + +def dup_qq_heu_gcd(f, g, K0): + """ + Heuristic polynomial GCD in `Q[x]`. + + Returns ``(h, cff, cfg)`` such that ``a = gcd(f, g)``, + ``cff = quo(f, h)``, and ``cfg = quo(g, h)``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + >>> f = QQ(1,2)*x**2 + QQ(7,4)*x + QQ(3,2) + >>> g = QQ(1,2)*x**2 + x + + >>> R.dup_qq_heu_gcd(f, g) + (x + 2, 1/2*x + 3/4, 1/2*x) + + """ + result = _dup_ff_trivial_gcd(f, g, K0) + + if result is not None: + return result + + K1 = K0.get_ring() + + cf, f = dup_clear_denoms(f, K0, K1) + cg, g = dup_clear_denoms(g, K0, K1) + + f = dup_convert(f, K0, K1) + g = dup_convert(g, K0, K1) + + h, cff, cfg = dup_zz_heu_gcd(f, g, K1) + + h = dup_convert(h, K1, K0) + + c = dup_LC(h, K0) + h = dup_monic(h, K0) + + cff = dup_convert(cff, K1, K0) + cfg = dup_convert(cfg, K1, K0) + + cff = dup_mul_ground(cff, K0.quo(c, cf), K0) + cfg = dup_mul_ground(cfg, K0.quo(c, cg), K0) + + return h, cff, cfg + + +def dmp_qq_heu_gcd(f, g, u, K0): + """ + Heuristic polynomial GCD in `Q[X]`. + + Returns ``(h, cff, cfg)`` such that ``a = gcd(f, g)``, + ``cff = quo(f, h)``, and ``cfg = quo(g, h)``. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x,y, = ring("x,y", QQ) + + >>> f = QQ(1,4)*x**2 + x*y + y**2 + >>> g = QQ(1,2)*x**2 + x*y + + >>> R.dmp_qq_heu_gcd(f, g) + (x + 2*y, 1/4*x + 1/2*y, 1/2*x) + + """ + result = _dmp_ff_trivial_gcd(f, g, u, K0) + + if result is not None: + return result + + K1 = K0.get_ring() + + cf, f = dmp_clear_denoms(f, u, K0, K1) + cg, g = dmp_clear_denoms(g, u, K0, K1) + + f = dmp_convert(f, u, K0, K1) + g = dmp_convert(g, u, K0, K1) + + h, cff, cfg = dmp_zz_heu_gcd(f, g, u, K1) + + h = dmp_convert(h, u, K1, K0) + + c = dmp_ground_LC(h, u, K0) + h = dmp_ground_monic(h, u, K0) + + cff = dmp_convert(cff, u, K1, K0) + cfg = dmp_convert(cfg, u, K1, K0) + + cff = dmp_mul_ground(cff, K0.quo(c, cf), u, K0) + cfg = dmp_mul_ground(cfg, K0.quo(c, cg), u, K0) + + return h, cff, cfg + + +def dup_inner_gcd(f, g, K): + """ + Computes polynomial GCD and cofactors of `f` and `g` in `K[x]`. + + Returns ``(h, cff, cfg)`` such that ``a = gcd(f, g)``, + ``cff = quo(f, h)``, and ``cfg = quo(g, h)``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_inner_gcd(x**2 - 1, x**2 - 3*x + 2) + (x - 1, x + 1, x - 2) + + """ + # XXX: This used to check for K.is_Exact but leads to awkward results when + # the domain is something like RR[z] e.g.: + # + # >>> g, p, q = Poly(1, x).cancel(Poly(51.05*x*y - 1.0, x)) + # >>> g + # 1.0 + # >>> p + # Poly(17592186044421.0, x, domain='RR[y]') + # >>> q + # Poly(898081097567692.0*y*x - 17592186044421.0, x, domain='RR[y]')) + # + # Maybe it would be better to flatten into multivariate polynomials first. + if K.is_RR or K.is_CC: + try: + exact = K.get_exact() + except DomainError: + return [K.one], f, g + + f = dup_convert(f, K, exact) + g = dup_convert(g, K, exact) + + h, cff, cfg = dup_inner_gcd(f, g, exact) + + h = dup_convert(h, exact, K) + cff = dup_convert(cff, exact, K) + cfg = dup_convert(cfg, exact, K) + + return h, cff, cfg + elif K.is_Field: + if K.is_QQ and query('USE_HEU_GCD'): + try: + return dup_qq_heu_gcd(f, g, K) + except HeuristicGCDFailed: + pass + + return dup_ff_prs_gcd(f, g, K) + else: + if K.is_ZZ and query('USE_HEU_GCD'): + try: + return dup_zz_heu_gcd(f, g, K) + except HeuristicGCDFailed: + pass + + return dup_rr_prs_gcd(f, g, K) + + +def _dmp_inner_gcd(f, g, u, K): + """Helper function for `dmp_inner_gcd()`. """ + if not K.is_Exact: + try: + exact = K.get_exact() + except DomainError: + return dmp_one(u, K), f, g + + f = dmp_convert(f, u, K, exact) + g = dmp_convert(g, u, K, exact) + + h, cff, cfg = _dmp_inner_gcd(f, g, u, exact) + + h = dmp_convert(h, u, exact, K) + cff = dmp_convert(cff, u, exact, K) + cfg = dmp_convert(cfg, u, exact, K) + + return h, cff, cfg + elif K.is_Field: + if K.is_QQ and query('USE_HEU_GCD'): + try: + return dmp_qq_heu_gcd(f, g, u, K) + except HeuristicGCDFailed: + pass + + return dmp_ff_prs_gcd(f, g, u, K) + else: + if K.is_ZZ and query('USE_HEU_GCD'): + try: + return dmp_zz_heu_gcd(f, g, u, K) + except HeuristicGCDFailed: + pass + + return dmp_rr_prs_gcd(f, g, u, K) + + +def dmp_inner_gcd(f, g, u, K): + """ + Computes polynomial GCD and cofactors of `f` and `g` in `K[X]`. + + Returns ``(h, cff, cfg)`` such that ``a = gcd(f, g)``, + ``cff = quo(f, h)``, and ``cfg = quo(g, h)``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y, = ring("x,y", ZZ) + + >>> f = x**2 + 2*x*y + y**2 + >>> g = x**2 + x*y + + >>> R.dmp_inner_gcd(f, g) + (x + y, x + y, x) + + """ + if not u: + return dup_inner_gcd(f, g, K) + + J, (f, g) = dmp_multi_deflate((f, g), u, K) + h, cff, cfg = _dmp_inner_gcd(f, g, u, K) + + return (dmp_inflate(h, J, u, K), + dmp_inflate(cff, J, u, K), + dmp_inflate(cfg, J, u, K)) + + +def dup_gcd(f, g, K): + """ + Computes polynomial GCD of `f` and `g` in `K[x]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_gcd(x**2 - 1, x**2 - 3*x + 2) + x - 1 + + """ + return dup_inner_gcd(f, g, K)[0] + + +def dmp_gcd(f, g, u, K): + """ + Computes polynomial GCD of `f` and `g` in `K[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y, = ring("x,y", ZZ) + + >>> f = x**2 + 2*x*y + y**2 + >>> g = x**2 + x*y + + >>> R.dmp_gcd(f, g) + x + y + + """ + return dmp_inner_gcd(f, g, u, K)[0] + + +def dup_rr_lcm(f, g, K): + """ + Computes polynomial LCM over a ring in `K[x]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_rr_lcm(x**2 - 1, x**2 - 3*x + 2) + x**3 - 2*x**2 - x + 2 + + """ + if not f or not g: + return dmp_zero(0) + + fc, f = dup_primitive(f, K) + gc, g = dup_primitive(g, K) + + c = K.lcm(fc, gc) + + h = dup_quo(dup_mul(f, g, K), + dup_gcd(f, g, K), K) + + u = K.canonical_unit(dup_LC(h, K)) + + return dup_mul_ground(h, c*u, K) + + +def dup_ff_lcm(f, g, K): + """ + Computes polynomial LCM over a field in `K[x]`. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x = ring("x", QQ) + + >>> f = QQ(1,2)*x**2 + QQ(7,4)*x + QQ(3,2) + >>> g = QQ(1,2)*x**2 + x + + >>> R.dup_ff_lcm(f, g) + x**3 + 7/2*x**2 + 3*x + + """ + h = dup_quo(dup_mul(f, g, K), + dup_gcd(f, g, K), K) + + return dup_monic(h, K) + + +def dup_lcm(f, g, K): + """ + Computes polynomial LCM of `f` and `g` in `K[x]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_lcm(x**2 - 1, x**2 - 3*x + 2) + x**3 - 2*x**2 - x + 2 + + """ + if K.is_Field: + return dup_ff_lcm(f, g, K) + else: + return dup_rr_lcm(f, g, K) + + +def dmp_rr_lcm(f, g, u, K): + """ + Computes polynomial LCM over a ring in `K[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y, = ring("x,y", ZZ) + + >>> f = x**2 + 2*x*y + y**2 + >>> g = x**2 + x*y + + >>> R.dmp_rr_lcm(f, g) + x**3 + 2*x**2*y + x*y**2 + + """ + fc, f = dmp_ground_primitive(f, u, K) + gc, g = dmp_ground_primitive(g, u, K) + + c = K.lcm(fc, gc) + + h = dmp_quo(dmp_mul(f, g, u, K), + dmp_gcd(f, g, u, K), u, K) + + return dmp_mul_ground(h, c, u, K) + + +def dmp_ff_lcm(f, g, u, K): + """ + Computes polynomial LCM over a field in `K[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, QQ + >>> R, x,y, = ring("x,y", QQ) + + >>> f = QQ(1,4)*x**2 + x*y + y**2 + >>> g = QQ(1,2)*x**2 + x*y + + >>> R.dmp_ff_lcm(f, g) + x**3 + 4*x**2*y + 4*x*y**2 + + """ + h = dmp_quo(dmp_mul(f, g, u, K), + dmp_gcd(f, g, u, K), u, K) + + return dmp_ground_monic(h, u, K) + + +def dmp_lcm(f, g, u, K): + """ + Computes polynomial LCM of `f` and `g` in `K[X]`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y, = ring("x,y", ZZ) + + >>> f = x**2 + 2*x*y + y**2 + >>> g = x**2 + x*y + + >>> R.dmp_lcm(f, g) + x**3 + 2*x**2*y + x*y**2 + + """ + if not u: + return dup_lcm(f, g, K) + + if K.is_Field: + return dmp_ff_lcm(f, g, u, K) + else: + return dmp_rr_lcm(f, g, u, K) + + +def dmp_content(f, u, K): + """ + Returns GCD of multivariate coefficients. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y, = ring("x,y", ZZ) + + >>> R.dmp_content(2*x*y + 6*x + 4*y + 12) + 2*y + 6 + + """ + cont, v = dmp_LC(f, K), u - 1 + + if dmp_zero_p(f, u): + return cont + + for c in f[1:]: + cont = dmp_gcd(cont, c, v, K) + + if dmp_one_p(cont, v, K): + break + + if K.is_negative(dmp_ground_LC(cont, v, K)): + return dmp_neg(cont, v, K) + else: + return cont + + +def dmp_primitive(f, u, K): + """ + Returns multivariate content and a primitive polynomial. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y, = ring("x,y", ZZ) + + >>> R.dmp_primitive(2*x*y + 6*x + 4*y + 12) + (2*y + 6, x + 2) + + """ + cont, v = dmp_content(f, u, K), u - 1 + + if dmp_zero_p(f, u) or dmp_one_p(cont, v, K): + return cont, f + else: + return cont, [ dmp_quo(c, cont, v, K) for c in f ] + + +def dup_cancel(f, g, K, include=True): + """ + Cancel common factors in a rational function `f/g`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_cancel(2*x**2 - 2, x**2 - 2*x + 1) + (2*x + 2, x - 1) + + """ + return dmp_cancel(f, g, 0, K, include=include) + + +def dmp_cancel(f, g, u, K, include=True): + """ + Cancel common factors in a rational function `f/g`. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_cancel(2*x**2 - 2, x**2 - 2*x + 1) + (2*x + 2, x - 1) + + """ + K0 = None + + if K.is_Field and K.has_assoc_Ring: + K0, K = K, K.get_ring() + + cq, f = dmp_clear_denoms(f, u, K0, K, convert=True) + cp, g = dmp_clear_denoms(g, u, K0, K, convert=True) + else: + cp, cq = K.one, K.one + + _, p, q = dmp_inner_gcd(f, g, u, K) + + if K0 is not None: + _, cp, cq = K.cofactors(cp, cq) + + p = dmp_convert(p, u, K, K0) + q = dmp_convert(q, u, K, K0) + + K = K0 + + p_neg = K.is_negative(dmp_ground_LC(p, u, K)) + q_neg = K.is_negative(dmp_ground_LC(q, u, K)) + + if p_neg and q_neg: + p, q = dmp_neg(p, u, K), dmp_neg(q, u, K) + elif p_neg: + cp, p = -cp, dmp_neg(p, u, K) + elif q_neg: + cp, q = -cp, dmp_neg(q, u, K) + + if not include: + return cp, cq, p, q + + p = dmp_mul_ground(p, cp, u, K) + q = dmp_mul_ground(q, cq, u, K) + + return p, q diff --git a/phivenv/Lib/site-packages/sympy/polys/factortools.py b/phivenv/Lib/site-packages/sympy/polys/factortools.py new file mode 100644 index 0000000000000000000000000000000000000000..021a6b06cb8802748deef6c69448ebc50503269b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/factortools.py @@ -0,0 +1,1648 @@ +"""Polynomial factorization routines in characteristic zero. """ + +from sympy.external.gmpy import GROUND_TYPES + +from sympy.core.random import _randint + +from sympy.polys.galoistools import ( + gf_from_int_poly, gf_to_int_poly, + gf_lshift, gf_add_mul, gf_mul, + gf_div, gf_rem, + gf_gcdex, + gf_sqf_p, + gf_factor_sqf, gf_factor) + +from sympy.polys.densebasic import ( + dup_LC, dmp_LC, dmp_ground_LC, + dup_TC, + dup_convert, dmp_convert, + dup_degree, dmp_degree, + dmp_degree_in, dmp_degree_list, + dmp_from_dict, + dmp_zero_p, + dmp_one, + dmp_nest, dmp_raise, + dup_strip, + dmp_ground, + dup_inflate, + dmp_exclude, dmp_include, + dmp_inject, dmp_eject, + dup_terms_gcd, dmp_terms_gcd) + +from sympy.polys.densearith import ( + dup_neg, dmp_neg, + dup_add, dmp_add, + dup_sub, dmp_sub, + dup_mul, dmp_mul, + dup_sqr, + dmp_pow, + dup_div, dmp_div, + dup_quo, dmp_quo, + dmp_expand, + dmp_add_mul, + dup_sub_mul, dmp_sub_mul, + dup_lshift, + dup_max_norm, dmp_max_norm, + dup_l1_norm, + dup_mul_ground, dmp_mul_ground, + dup_quo_ground, dmp_quo_ground) + +from sympy.polys.densetools import ( + dup_clear_denoms, dmp_clear_denoms, + dup_trunc, dmp_ground_trunc, + dup_content, + dup_monic, dmp_ground_monic, + dup_primitive, dmp_ground_primitive, + dmp_eval_tail, + dmp_eval_in, dmp_diff_eval_in, + dup_shift, dmp_shift, dup_mirror) + +from sympy.polys.euclidtools import ( + dmp_primitive, + dup_inner_gcd, dmp_inner_gcd) + +from sympy.polys.sqfreetools import ( + dup_sqf_p, + dup_sqf_norm, dmp_sqf_norm, + dup_sqf_part, dmp_sqf_part, + _dup_check_degrees, _dmp_check_degrees, + ) + +from sympy.polys.polyutils import _sort_factors +from sympy.polys.polyconfig import query + +from sympy.polys.polyerrors import ( + ExtraneousFactors, DomainError, CoercionFailed, EvaluationFailed) + +from sympy.utilities import subsets + +from math import ceil as _ceil, log as _log, log2 as _log2 + + +if GROUND_TYPES == 'flint': + from flint import fmpz_poly +else: + fmpz_poly = None + + +def dup_trial_division(f, factors, K): + """ + Determine multiplicities of factors for a univariate polynomial + using trial division. + + An error will be raised if any factor does not divide ``f``. + """ + result = [] + + for factor in factors: + k = 0 + + while True: + q, r = dup_div(f, factor, K) + + if not r: + f, k = q, k + 1 + else: + break + + if k == 0: + raise RuntimeError("trial division failed") + + result.append((factor, k)) + + return _sort_factors(result) + + +def dmp_trial_division(f, factors, u, K): + """ + Determine multiplicities of factors for a multivariate polynomial + using trial division. + + An error will be raised if any factor does not divide ``f``. + """ + result = [] + + for factor in factors: + k = 0 + + while True: + q, r = dmp_div(f, factor, u, K) + + if dmp_zero_p(r, u): + f, k = q, k + 1 + else: + break + + if k == 0: + raise RuntimeError("trial division failed") + + result.append((factor, k)) + + return _sort_factors(result) + + +def dup_zz_mignotte_bound(f, K): + """ + The Knuth-Cohen variant of Mignotte bound for + univariate polynomials in ``K[x]``. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> f = x**3 + 14*x**2 + 56*x + 64 + >>> R.dup_zz_mignotte_bound(f) + 152 + + By checking ``factor(f)`` we can see that max coeff is 8 + + Also consider a case that ``f`` is irreducible for example + ``f = 2*x**2 + 3*x + 4``. To avoid a bug for these cases, we return the + bound plus the max coefficient of ``f`` + + >>> f = 2*x**2 + 3*x + 4 + >>> R.dup_zz_mignotte_bound(f) + 6 + + Lastly, to see the difference between the new and the old Mignotte bound + consider the irreducible polynomial: + + >>> f = 87*x**7 + 4*x**6 + 80*x**5 + 17*x**4 + 9*x**3 + 12*x**2 + 49*x + 26 + >>> R.dup_zz_mignotte_bound(f) + 744 + + The new Mignotte bound is 744 whereas the old one (SymPy 1.5.1) is 1937664. + + + References + ========== + + ..[1] [Abbott13]_ + + """ + from sympy.functions.combinatorial.factorials import binomial + d = dup_degree(f) + delta = _ceil(d / 2) + delta2 = _ceil(delta / 2) + + # euclidean-norm + eucl_norm = K.sqrt( sum( cf**2 for cf in f ) ) + + # biggest values of binomial coefficients (p. 538 of reference) + t1 = binomial(delta - 1, delta2) + t2 = binomial(delta - 1, delta2 - 1) + + lc = K.abs(dup_LC(f, K)) # leading coefficient + bound = t1 * eucl_norm + t2 * lc # (p. 538 of reference) + bound += dup_max_norm(f, K) # add max coeff for irreducible polys + bound = _ceil(bound / 2) * 2 # round up to even integer + + return bound + +def dmp_zz_mignotte_bound(f, u, K): + """Mignotte bound for multivariate polynomials in `K[X]`. """ + a = dmp_max_norm(f, u, K) + b = abs(dmp_ground_LC(f, u, K)) + n = sum(dmp_degree_list(f, u)) + + return K.sqrt(K(n + 1))*2**n*a*b + + +def dup_zz_hensel_step(m, f, g, h, s, t, K): + """ + One step in Hensel lifting in `Z[x]`. + + Given positive integer `m` and `Z[x]` polynomials `f`, `g`, `h`, `s` + and `t` such that:: + + f = g*h (mod m) + s*g + t*h = 1 (mod m) + + lc(f) is not a zero divisor (mod m) + lc(h) = 1 + + deg(f) = deg(g) + deg(h) + deg(s) < deg(h) + deg(t) < deg(g) + + returns polynomials `G`, `H`, `S` and `T`, such that:: + + f = G*H (mod m**2) + S*G + T*H = 1 (mod m**2) + + References + ========== + + .. [1] [Gathen99]_ + + """ + M = m**2 + + e = dup_sub_mul(f, g, h, K) + e = dup_trunc(e, M, K) + + q, r = dup_div(dup_mul(s, e, K), h, K) + + q = dup_trunc(q, M, K) + r = dup_trunc(r, M, K) + + u = dup_add(dup_mul(t, e, K), dup_mul(q, g, K), K) + G = dup_trunc(dup_add(g, u, K), M, K) + H = dup_trunc(dup_add(h, r, K), M, K) + + u = dup_add(dup_mul(s, G, K), dup_mul(t, H, K), K) + b = dup_trunc(dup_sub(u, [K.one], K), M, K) + + c, d = dup_div(dup_mul(s, b, K), H, K) + + c = dup_trunc(c, M, K) + d = dup_trunc(d, M, K) + + u = dup_add(dup_mul(t, b, K), dup_mul(c, G, K), K) + S = dup_trunc(dup_sub(s, d, K), M, K) + T = dup_trunc(dup_sub(t, u, K), M, K) + + return G, H, S, T + + +def dup_zz_hensel_lift(p, f, f_list, l, K): + r""" + Multifactor Hensel lifting in `Z[x]`. + + Given a prime `p`, polynomial `f` over `Z[x]` such that `lc(f)` + is a unit modulo `p`, monic pair-wise coprime polynomials `f_i` + over `Z[x]` satisfying:: + + f = lc(f) f_1 ... f_r (mod p) + + and a positive integer `l`, returns a list of monic polynomials + `F_1,\ F_2,\ \dots,\ F_r` satisfying:: + + f = lc(f) F_1 ... F_r (mod p**l) + + F_i = f_i (mod p), i = 1..r + + References + ========== + + .. [1] [Gathen99]_ + + """ + r = len(f_list) + lc = dup_LC(f, K) + + if r == 1: + F = dup_mul_ground(f, K.gcdex(lc, p**l)[0], K) + return [ dup_trunc(F, p**l, K) ] + + m = p + k = r // 2 + d = int(_ceil(_log2(l))) + + g = gf_from_int_poly([lc], p) + + for f_i in f_list[:k]: + g = gf_mul(g, gf_from_int_poly(f_i, p), p, K) + + h = gf_from_int_poly(f_list[k], p) + + for f_i in f_list[k + 1:]: + h = gf_mul(h, gf_from_int_poly(f_i, p), p, K) + + s, t, _ = gf_gcdex(g, h, p, K) + + g = gf_to_int_poly(g, p) + h = gf_to_int_poly(h, p) + s = gf_to_int_poly(s, p) + t = gf_to_int_poly(t, p) + + for _ in range(1, d + 1): + (g, h, s, t), m = dup_zz_hensel_step(m, f, g, h, s, t, K), m**2 + + return dup_zz_hensel_lift(p, g, f_list[:k], l, K) \ + + dup_zz_hensel_lift(p, h, f_list[k:], l, K) + +def _test_pl(fc, q, pl): + if q > pl // 2: + q = q - pl + if not q: + return True + return fc % q == 0 + +def dup_zz_zassenhaus(f, K): + """Factor primitive square-free polynomials in `Z[x]`. """ + n = dup_degree(f) + + if n == 1: + return [f] + + from sympy.ntheory import isprime + + fc = f[-1] + A = dup_max_norm(f, K) + b = dup_LC(f, K) + B = int(abs(K.sqrt(K(n + 1))*2**n*A*b)) + C = int((n + 1)**(2*n)*A**(2*n - 1)) + gamma = int(_ceil(2*_log2(C))) + bound = int(2*gamma*_log(gamma)) + a = [] + # choose a prime number `p` such that `f` be square free in Z_p + # if there are many factors in Z_p, choose among a few different `p` + # the one with fewer factors + for px in range(3, bound + 1): + if not isprime(px) or b % px == 0: + continue + + px = K.convert(px) + + F = gf_from_int_poly(f, px) + + if not gf_sqf_p(F, px, K): + continue + fsqfx = gf_factor_sqf(F, px, K)[1] + a.append((px, fsqfx)) + if len(fsqfx) < 15 or len(a) > 4: + break + p, fsqf = min(a, key=lambda x: len(x[1])) + + l = int(_ceil(_log(2*B + 1, p))) + + modular = [gf_to_int_poly(ff, p) for ff in fsqf] + + g = dup_zz_hensel_lift(p, f, modular, l, K) + + sorted_T = range(len(g)) + T = set(sorted_T) + factors, s = [], 1 + pl = p**l + + while 2*s <= len(T): + for S in subsets(sorted_T, s): + # lift the constant coefficient of the product `G` of the factors + # in the subset `S`; if it is does not divide `fc`, `G` does + # not divide the input polynomial + + if b == 1: + q = 1 + for i in S: + q = q*g[i][-1] + q = q % pl + if not _test_pl(fc, q, pl): + continue + else: + G = [b] + for i in S: + G = dup_mul(G, g[i], K) + G = dup_trunc(G, pl, K) + G = dup_primitive(G, K)[1] + q = G[-1] + if q and fc % q != 0: + continue + + H = [b] + S = set(S) + T_S = T - S + + if b == 1: + G = [b] + for i in S: + G = dup_mul(G, g[i], K) + G = dup_trunc(G, pl, K) + + for i in T_S: + H = dup_mul(H, g[i], K) + + H = dup_trunc(H, pl, K) + + G_norm = dup_l1_norm(G, K) + H_norm = dup_l1_norm(H, K) + + if G_norm*H_norm <= B: + T = T_S + sorted_T = [i for i in sorted_T if i not in S] + + G = dup_primitive(G, K)[1] + f = dup_primitive(H, K)[1] + + factors.append(G) + b = dup_LC(f, K) + + break + else: + s += 1 + + return factors + [f] + + +def dup_zz_irreducible_p(f, K): + """Test irreducibility using Eisenstein's criterion. """ + lc = dup_LC(f, K) + tc = dup_TC(f, K) + + e_fc = dup_content(f[1:], K) + + if e_fc: + from sympy.ntheory import factorint + e_ff = factorint(int(e_fc)) + + for p in e_ff.keys(): + if (lc % p) and (tc % p**2): + return True + + +def dup_cyclotomic_p(f, K, irreducible=False): + """ + Efficiently test if ``f`` is a cyclotomic polynomial. + + Examples + ======== + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> f = x**16 + x**14 - x**10 + x**8 - x**6 + x**2 + 1 + >>> R.dup_cyclotomic_p(f) + False + + >>> g = x**16 + x**14 - x**10 - x**8 - x**6 + x**2 + 1 + >>> R.dup_cyclotomic_p(g) + True + + References + ========== + + Bradford, Russell J., and James H. Davenport. "Effective tests for + cyclotomic polynomials." In International Symposium on Symbolic and + Algebraic Computation, pp. 244-251. Springer, Berlin, Heidelberg, 1988. + + """ + if K.is_QQ: + try: + K0, K = K, K.get_ring() + f = dup_convert(f, K0, K) + except CoercionFailed: + return False + elif not K.is_ZZ: + return False + + lc = dup_LC(f, K) + tc = dup_TC(f, K) + + if lc != 1 or (tc != -1 and tc != 1): + return False + + if not irreducible: + coeff, factors = dup_factor_list(f, K) + + if coeff != K.one or factors != [(f, 1)]: + return False + + n = dup_degree(f) + g, h = [], [] + + for i in range(n, -1, -2): + g.insert(0, f[i]) + + for i in range(n - 1, -1, -2): + h.insert(0, f[i]) + + g = dup_sqr(dup_strip(g), K) + h = dup_sqr(dup_strip(h), K) + + F = dup_sub(g, dup_lshift(h, 1, K), K) + + if K.is_negative(dup_LC(F, K)): + F = dup_neg(F, K) + + if F == f: + return True + + g = dup_mirror(f, K) + + if K.is_negative(dup_LC(g, K)): + g = dup_neg(g, K) + + if F == g and dup_cyclotomic_p(g, K): + return True + + G = dup_sqf_part(F, K) + + if dup_sqr(G, K) == F and dup_cyclotomic_p(G, K): + return True + + return False + + +def dup_zz_cyclotomic_poly(n, K): + """Efficiently generate n-th cyclotomic polynomial. """ + from sympy.ntheory import factorint + h = [K.one, -K.one] + + for p, k in factorint(n).items(): + h = dup_quo(dup_inflate(h, p, K), h, K) + h = dup_inflate(h, p**(k - 1), K) + + return h + + +def _dup_cyclotomic_decompose(n, K): + from sympy.ntheory import factorint + + H = [[K.one, -K.one]] + + for p, k in factorint(n).items(): + Q = [ dup_quo(dup_inflate(h, p, K), h, K) for h in H ] + H.extend(Q) + + for i in range(1, k): + Q = [ dup_inflate(q, p, K) for q in Q ] + H.extend(Q) + + return H + + +def dup_zz_cyclotomic_factor(f, K): + """ + Efficiently factor polynomials `x**n - 1` and `x**n + 1` in `Z[x]`. + + Given a univariate polynomial `f` in `Z[x]` returns a list of factors + of `f`, provided that `f` is in the form `x**n - 1` or `x**n + 1` for + `n >= 1`. Otherwise returns None. + + Factorization is performed using cyclotomic decomposition of `f`, + which makes this method much faster that any other direct factorization + approach (e.g. Zassenhaus's). + + References + ========== + + .. [1] [Weisstein09]_ + + """ + lc_f, tc_f = dup_LC(f, K), dup_TC(f, K) + + if dup_degree(f) <= 0: + return None + + if lc_f != 1 or tc_f not in [-1, 1]: + return None + + if any(bool(cf) for cf in f[1:-1]): + return None + + n = dup_degree(f) + F = _dup_cyclotomic_decompose(n, K) + + if not K.is_one(tc_f): + return F + else: + H = [] + + for h in _dup_cyclotomic_decompose(2*n, K): + if h not in F: + H.append(h) + + return H + + +def dup_zz_factor_sqf(f, K): + """Factor square-free (non-primitive) polynomials in `Z[x]`. """ + cont, g = dup_primitive(f, K) + + n = dup_degree(g) + + if dup_LC(g, K) < 0: + cont, g = -cont, dup_neg(g, K) + + if n <= 0: + return cont, [] + elif n == 1: + return cont, [g] + + if query('USE_IRREDUCIBLE_IN_FACTOR'): + if dup_zz_irreducible_p(g, K): + return cont, [g] + + factors = None + + if query('USE_CYCLOTOMIC_FACTOR'): + factors = dup_zz_cyclotomic_factor(g, K) + + if factors is None: + factors = dup_zz_zassenhaus(g, K) + + return cont, _sort_factors(factors, multiple=False) + + +def dup_zz_factor(f, K): + """ + Factor (non square-free) polynomials in `Z[x]`. + + Given a univariate polynomial `f` in `Z[x]` computes its complete + factorization `f_1, ..., f_n` into irreducibles over integers:: + + f = content(f) f_1**k_1 ... f_n**k_n + + The factorization is computed by reducing the input polynomial + into a primitive square-free polynomial and factoring it using + Zassenhaus algorithm. Trial division is used to recover the + multiplicities of factors. + + The result is returned as a tuple consisting of:: + + (content(f), [(f_1, k_1), ..., (f_n, k_n)) + + Examples + ======== + + Consider the polynomial `f = 2*x**4 - 2`:: + + >>> from sympy.polys import ring, ZZ + >>> R, x = ring("x", ZZ) + + >>> R.dup_zz_factor(2*x**4 - 2) + (2, [(x - 1, 1), (x + 1, 1), (x**2 + 1, 1)]) + + In result we got the following factorization:: + + f = 2 (x - 1) (x + 1) (x**2 + 1) + + Note that this is a complete factorization over integers, + however over Gaussian integers we can factor the last term. + + By default, polynomials `x**n - 1` and `x**n + 1` are factored + using cyclotomic decomposition to speedup computations. To + disable this behaviour set cyclotomic=False. + + References + ========== + + .. [1] [Gathen99]_ + + """ + if GROUND_TYPES == 'flint': + f_flint = fmpz_poly(f[::-1]) + cont, factors = f_flint.factor() + factors = [(fac.coeffs()[::-1], exp) for fac, exp in factors] + return cont, _sort_factors(factors) + + cont, g = dup_primitive(f, K) + + n = dup_degree(g) + + if dup_LC(g, K) < 0: + cont, g = -cont, dup_neg(g, K) + + if n <= 0: + return cont, [] + elif n == 1: + return cont, [(g, 1)] + + if query('USE_IRREDUCIBLE_IN_FACTOR'): + if dup_zz_irreducible_p(g, K): + return cont, [(g, 1)] + + g = dup_sqf_part(g, K) + H = None + + if query('USE_CYCLOTOMIC_FACTOR'): + H = dup_zz_cyclotomic_factor(g, K) + + if H is None: + H = dup_zz_zassenhaus(g, K) + + factors = dup_trial_division(f, H, K) + + _dup_check_degrees(f, factors) + + return cont, factors + + +def dmp_zz_wang_non_divisors(E, cs, ct, K): + """Wang/EEZ: Compute a set of valid divisors. """ + result = [ cs*ct ] + + for q in E: + q = abs(q) + + for r in reversed(result): + while r != 1: + r = K.gcd(r, q) + q = q // r + + if K.is_one(q): + return None + + result.append(q) + + return result[1:] + + +def dmp_zz_wang_test_points(f, T, ct, A, u, K): + """Wang/EEZ: Test evaluation points for suitability. """ + if not dmp_eval_tail(dmp_LC(f, K), A, u - 1, K): + raise EvaluationFailed('no luck') + + g = dmp_eval_tail(f, A, u, K) + + if not dup_sqf_p(g, K): + raise EvaluationFailed('no luck') + + c, h = dup_primitive(g, K) + + if K.is_negative(dup_LC(h, K)): + c, h = -c, dup_neg(h, K) + + v = u - 1 + + E = [ dmp_eval_tail(t, A, v, K) for t, _ in T ] + D = dmp_zz_wang_non_divisors(E, c, ct, K) + + if D is not None: + return c, h, E + else: + raise EvaluationFailed('no luck') + + +def dmp_zz_wang_lead_coeffs(f, T, cs, E, H, A, u, K): + """Wang/EEZ: Compute correct leading coefficients. """ + C, J, v = [], [0]*len(E), u - 1 + + for h in H: + c = dmp_one(v, K) + d = dup_LC(h, K)*cs + + for i in reversed(range(len(E))): + k, e, (t, _) = 0, E[i], T[i] + + while not (d % e): + d, k = d//e, k + 1 + + if k != 0: + c, J[i] = dmp_mul(c, dmp_pow(t, k, v, K), v, K), 1 + + C.append(c) + + if not all(J): + raise ExtraneousFactors # pragma: no cover + + CC, HH = [], [] + + for c, h in zip(C, H): + d = dmp_eval_tail(c, A, v, K) + lc = dup_LC(h, K) + + if K.is_one(cs): + cc = lc//d + else: + g = K.gcd(lc, d) + d, cc = d//g, lc//g + h, cs = dup_mul_ground(h, d, K), cs//d + + c = dmp_mul_ground(c, cc, v, K) + + CC.append(c) + HH.append(h) + + if K.is_one(cs): + return f, HH, CC + + CCC, HHH = [], [] + + for c, h in zip(CC, HH): + CCC.append(dmp_mul_ground(c, cs, v, K)) + HHH.append(dmp_mul_ground(h, cs, 0, K)) + + f = dmp_mul_ground(f, cs**(len(H) - 1), u, K) + + return f, HHH, CCC + + +def dup_zz_diophantine(F, m, p, K): + """Wang/EEZ: Solve univariate Diophantine equations. """ + if len(F) == 2: + a, b = F + + f = gf_from_int_poly(a, p) + g = gf_from_int_poly(b, p) + + s, t, G = gf_gcdex(g, f, p, K) + + s = gf_lshift(s, m, K) + t = gf_lshift(t, m, K) + + q, s = gf_div(s, f, p, K) + + t = gf_add_mul(t, q, g, p, K) + + s = gf_to_int_poly(s, p) + t = gf_to_int_poly(t, p) + + result = [s, t] + else: + G = [F[-1]] + + for f in reversed(F[1:-1]): + G.insert(0, dup_mul(f, G[0], K)) + + S, T = [], [[1]] + + for f, g in zip(F, G): + t, s = dmp_zz_diophantine([g, f], T[-1], [], 0, p, 1, K) + T.append(t) + S.append(s) + + result, S = [], S + [T[-1]] + + for s, f in zip(S, F): + s = gf_from_int_poly(s, p) + f = gf_from_int_poly(f, p) + + r = gf_rem(gf_lshift(s, m, K), f, p, K) + s = gf_to_int_poly(r, p) + + result.append(s) + + return result + + +def dmp_zz_diophantine(F, c, A, d, p, u, K): + """Wang/EEZ: Solve multivariate Diophantine equations. """ + if not A: + S = [ [] for _ in F ] + n = dup_degree(c) + + for i, coeff in enumerate(c): + if not coeff: + continue + + T = dup_zz_diophantine(F, n - i, p, K) + + for j, (s, t) in enumerate(zip(S, T)): + t = dup_mul_ground(t, coeff, K) + S[j] = dup_trunc(dup_add(s, t, K), p, K) + else: + n = len(A) + e = dmp_expand(F, u, K) + + a, A = A[-1], A[:-1] + B, G = [], [] + + for f in F: + B.append(dmp_quo(e, f, u, K)) + G.append(dmp_eval_in(f, a, n, u, K)) + + C = dmp_eval_in(c, a, n, u, K) + + v = u - 1 + + S = dmp_zz_diophantine(G, C, A, d, p, v, K) + S = [ dmp_raise(s, 1, v, K) for s in S ] + + for s, b in zip(S, B): + c = dmp_sub_mul(c, s, b, u, K) + + c = dmp_ground_trunc(c, p, u, K) + + m = dmp_nest([K.one, -a], n, K) + M = dmp_one(n, K) + + for k in range(0, d): + if dmp_zero_p(c, u): + break + + M = dmp_mul(M, m, u, K) + C = dmp_diff_eval_in(c, k + 1, a, n, u, K) + + if not dmp_zero_p(C, v): + C = dmp_quo_ground(C, K.factorial(K(k) + 1), v, K) + T = dmp_zz_diophantine(G, C, A, d, p, v, K) + + for i, t in enumerate(T): + T[i] = dmp_mul(dmp_raise(t, 1, v, K), M, u, K) + + for i, (s, t) in enumerate(zip(S, T)): + S[i] = dmp_add(s, t, u, K) + + for t, b in zip(T, B): + c = dmp_sub_mul(c, t, b, u, K) + + c = dmp_ground_trunc(c, p, u, K) + + S = [ dmp_ground_trunc(s, p, u, K) for s in S ] + + return S + + +def dmp_zz_wang_hensel_lifting(f, H, LC, A, p, u, K): + """Wang/EEZ: Parallel Hensel lifting algorithm. """ + S, n, v = [f], len(A), u - 1 + + H = list(H) + + for i, a in enumerate(reversed(A[1:])): + s = dmp_eval_in(S[0], a, n - i, u - i, K) + S.insert(0, dmp_ground_trunc(s, p, v - i, K)) + + d = max(dmp_degree_list(f, u)[1:]) + + for j, s, a in zip(range(2, n + 2), S, A): + G, w = list(H), j - 1 + + I, J = A[:j - 2], A[j - 1:] + + for i, (h, lc) in enumerate(zip(H, LC)): + lc = dmp_ground_trunc(dmp_eval_tail(lc, J, v, K), p, w - 1, K) + H[i] = [lc] + dmp_raise(h[1:], 1, w - 1, K) + + m = dmp_nest([K.one, -a], w, K) + M = dmp_one(w, K) + + c = dmp_sub(s, dmp_expand(H, w, K), w, K) + + dj = dmp_degree_in(s, w, w) + + for k in range(0, dj): + if dmp_zero_p(c, w): + break + + M = dmp_mul(M, m, w, K) + C = dmp_diff_eval_in(c, k + 1, a, w, w, K) + + if not dmp_zero_p(C, w - 1): + C = dmp_quo_ground(C, K.factorial(K(k) + 1), w - 1, K) + T = dmp_zz_diophantine(G, C, I, d, p, w - 1, K) + + for i, (h, t) in enumerate(zip(H, T)): + h = dmp_add_mul(h, dmp_raise(t, 1, w - 1, K), M, w, K) + H[i] = dmp_ground_trunc(h, p, w, K) + + h = dmp_sub(s, dmp_expand(H, w, K), w, K) + c = dmp_ground_trunc(h, p, w, K) + + if dmp_expand(H, u, K) != f: + raise ExtraneousFactors # pragma: no cover + else: + return H + + +def dmp_zz_wang(f, u, K, mod=None, seed=None): + r""" + Factor primitive square-free polynomials in `Z[X]`. + + Given a multivariate polynomial `f` in `Z[x_1,...,x_n]`, which is + primitive and square-free in `x_1`, computes factorization of `f` into + irreducibles over integers. + + The procedure is based on Wang's Enhanced Extended Zassenhaus + algorithm. The algorithm works by viewing `f` as a univariate polynomial + in `Z[x_2,...,x_n][x_1]`, for which an evaluation mapping is computed:: + + x_2 -> a_2, ..., x_n -> a_n + + where `a_i`, for `i = 2, \dots, n`, are carefully chosen integers. The + mapping is used to transform `f` into a univariate polynomial in `Z[x_1]`, + which can be factored efficiently using Zassenhaus algorithm. The last + step is to lift univariate factors to obtain true multivariate + factors. For this purpose a parallel Hensel lifting procedure is used. + + The parameter ``seed`` is passed to _randint and can be used to seed randint + (when an integer) or (for testing purposes) can be a sequence of numbers. + + References + ========== + + .. [1] [Wang78]_ + .. [2] [Geddes92]_ + + """ + from sympy.ntheory import nextprime + + randint = _randint(seed) + + ct, T = dmp_zz_factor(dmp_LC(f, K), u - 1, K) + + b = dmp_zz_mignotte_bound(f, u, K) + p = K(nextprime(b)) + + if mod is None: + if u == 1: + mod = 2 + else: + mod = 1 + + history, configs, A, r = set(), [], [K.zero]*u, None + + try: + cs, s, E = dmp_zz_wang_test_points(f, T, ct, A, u, K) + + _, H = dup_zz_factor_sqf(s, K) + + r = len(H) + + if r == 1: + return [f] + + configs = [(s, cs, E, H, A)] + except EvaluationFailed: + pass + + eez_num_configs = query('EEZ_NUMBER_OF_CONFIGS') + eez_num_tries = query('EEZ_NUMBER_OF_TRIES') + eez_mod_step = query('EEZ_MODULUS_STEP') + + while len(configs) < eez_num_configs: + for _ in range(eez_num_tries): + A = [ K(randint(-mod, mod)) for _ in range(u) ] + + if tuple(A) not in history: + history.add(tuple(A)) + else: + continue + + try: + cs, s, E = dmp_zz_wang_test_points(f, T, ct, A, u, K) + except EvaluationFailed: + continue + + _, H = dup_zz_factor_sqf(s, K) + + rr = len(H) + + if r is not None: + if rr != r: # pragma: no cover + if rr < r: + configs, r = [], rr + else: + continue + else: + r = rr + + if r == 1: + return [f] + + configs.append((s, cs, E, H, A)) + + if len(configs) == eez_num_configs: + break + else: + mod += eez_mod_step + + s_norm, s_arg, i = None, 0, 0 + + for s, _, _, _, _ in configs: + _s_norm = dup_max_norm(s, K) + + if s_norm is not None: + if _s_norm < s_norm: + s_norm = _s_norm + s_arg = i + else: + s_norm = _s_norm + + i += 1 + + _, cs, E, H, A = configs[s_arg] + orig_f = f + + try: + f, H, LC = dmp_zz_wang_lead_coeffs(f, T, cs, E, H, A, u, K) + factors = dmp_zz_wang_hensel_lifting(f, H, LC, A, p, u, K) + except ExtraneousFactors: # pragma: no cover + if query('EEZ_RESTART_IF_NEEDED'): + return dmp_zz_wang(orig_f, u, K, mod + 1) + else: + raise ExtraneousFactors( + "we need to restart algorithm with better parameters") + + result = [] + + for f in factors: + _, f = dmp_ground_primitive(f, u, K) + + if K.is_negative(dmp_ground_LC(f, u, K)): + f = dmp_neg(f, u, K) + + result.append(f) + + return result + + +def dmp_zz_factor(f, u, K): + r""" + Factor (non square-free) polynomials in `Z[X]`. + + Given a multivariate polynomial `f` in `Z[x]` computes its complete + factorization `f_1, \dots, f_n` into irreducibles over integers:: + + f = content(f) f_1**k_1 ... f_n**k_n + + The factorization is computed by reducing the input polynomial + into a primitive square-free polynomial and factoring it using + Enhanced Extended Zassenhaus (EEZ) algorithm. Trial division + is used to recover the multiplicities of factors. + + The result is returned as a tuple consisting of:: + + (content(f), [(f_1, k_1), ..., (f_n, k_n)) + + Consider polynomial `f = 2*(x**2 - y**2)`:: + + >>> from sympy.polys import ring, ZZ + >>> R, x,y = ring("x,y", ZZ) + + >>> R.dmp_zz_factor(2*x**2 - 2*y**2) + (2, [(x - y, 1), (x + y, 1)]) + + In result we got the following factorization:: + + f = 2 (x - y) (x + y) + + References + ========== + + .. [1] [Gathen99]_ + + """ + if not u: + return dup_zz_factor(f, K) + + if dmp_zero_p(f, u): + return K.zero, [] + + cont, g = dmp_ground_primitive(f, u, K) + + if dmp_ground_LC(g, u, K) < 0: + cont, g = -cont, dmp_neg(g, u, K) + + if all(d <= 0 for d in dmp_degree_list(g, u)): + return cont, [] + + G, g = dmp_primitive(g, u, K) + + factors = [] + + if dmp_degree(g, u) > 0: + g = dmp_sqf_part(g, u, K) + H = dmp_zz_wang(g, u, K) + factors = dmp_trial_division(f, H, u, K) + + for g, k in dmp_zz_factor(G, u - 1, K)[1]: + factors.insert(0, ([g], k)) + + _dmp_check_degrees(f, u, factors) + + return cont, _sort_factors(factors) + + +def dup_qq_i_factor(f, K0): + """Factor univariate polynomials into irreducibles in `QQ_I[x]`. """ + # Factor in QQ + K1 = K0.as_AlgebraicField() + f = dup_convert(f, K0, K1) + coeff, factors = dup_factor_list(f, K1) + factors = [(dup_convert(fac, K1, K0), i) for fac, i in factors] + coeff = K0.convert(coeff, K1) + return coeff, factors + + +def dup_zz_i_factor(f, K0): + """Factor univariate polynomials into irreducibles in `ZZ_I[x]`. """ + # First factor in QQ_I + K1 = K0.get_field() + f = dup_convert(f, K0, K1) + coeff, factors = dup_qq_i_factor(f, K1) + + new_factors = [] + for fac, i in factors: + # Extract content + fac_denom, fac_num = dup_clear_denoms(fac, K1) + fac_num_ZZ_I = dup_convert(fac_num, K1, K0) + content, fac_prim = dmp_ground_primitive(fac_num_ZZ_I, 0, K0) + + coeff = (coeff * content ** i) // fac_denom ** i + new_factors.append((fac_prim, i)) + + factors = new_factors + coeff = K0.convert(coeff, K1) + return coeff, factors + + +def dmp_qq_i_factor(f, u, K0): + """Factor multivariate polynomials into irreducibles in `QQ_I[X]`. """ + # Factor in QQ + K1 = K0.as_AlgebraicField() + f = dmp_convert(f, u, K0, K1) + coeff, factors = dmp_factor_list(f, u, K1) + factors = [(dmp_convert(fac, u, K1, K0), i) for fac, i in factors] + coeff = K0.convert(coeff, K1) + return coeff, factors + + +def dmp_zz_i_factor(f, u, K0): + """Factor multivariate polynomials into irreducibles in `ZZ_I[X]`. """ + # First factor in QQ_I + K1 = K0.get_field() + f = dmp_convert(f, u, K0, K1) + coeff, factors = dmp_qq_i_factor(f, u, K1) + + new_factors = [] + for fac, i in factors: + # Extract content + fac_denom, fac_num = dmp_clear_denoms(fac, u, K1) + fac_num_ZZ_I = dmp_convert(fac_num, u, K1, K0) + content, fac_prim = dmp_ground_primitive(fac_num_ZZ_I, u, K0) + + coeff = (coeff * content ** i) // fac_denom ** i + new_factors.append((fac_prim, i)) + + factors = new_factors + coeff = K0.convert(coeff, K1) + return coeff, factors + + +def dup_ext_factor(f, K): + r"""Factor univariate polynomials over algebraic number fields. + + The domain `K` must be an algebraic number field `k(a)` (see :ref:`QQ(a)`). + + Examples + ======== + + First define the algebraic number field `K = \mathbb{Q}(\sqrt{2})`: + + >>> from sympy import QQ, sqrt + >>> from sympy.polys.factortools import dup_ext_factor + >>> K = QQ.algebraic_field(sqrt(2)) + + We can now factorise the polynomial `x^2 - 2` over `K`: + + >>> p = [K(1), K(0), K(-2)] # x^2 - 2 + >>> p1 = [K(1), -K.unit] # x - sqrt(2) + >>> p2 = [K(1), +K.unit] # x + sqrt(2) + >>> dup_ext_factor(p, K) == (K.one, [(p1, 1), (p2, 1)]) + True + + Usually this would be done at a higher level: + + >>> from sympy import factor + >>> from sympy.abc import x + >>> factor(x**2 - 2, extension=sqrt(2)) + (x - sqrt(2))*(x + sqrt(2)) + + Explanation + =========== + + Uses Trager's algorithm. In particular this function is algorithm + ``alg_factor`` from [Trager76]_. + + If `f` is a polynomial in `k(a)[x]` then its norm `g(x)` is a polynomial in + `k[x]`. If `g(x)` is square-free and has irreducible factors `g_1(x)`, + `g_2(x)`, `\cdots` then the irreducible factors of `f` in `k(a)[x]` are + given by `f_i(x) = \gcd(f(x), g_i(x))` where the GCD is computed in + `k(a)[x]`. + + The first step in Trager's algorithm is to find an integer shift `s` so + that `f(x-sa)` has square-free norm. Then the norm is factorized in `k[x]` + and the GCD of (shifted) `f` with each factor gives the shifted factors of + `f`. At the end the shift is undone to recover the unshifted factors of `f` + in `k(a)[x]`. + + The algorithm reduces the problem of factorization in `k(a)[x]` to + factorization in `k[x]` with the main additional steps being to compute the + norm (a resultant calculation in `k[x,y]`) and some polynomial GCDs in + `k(a)[x]`. + + In practice in SymPy the base field `k` will be the rationals :ref:`QQ` and + this function factorizes a polynomial with coefficients in an algebraic + number field like `\mathbb{Q}(\sqrt{2})`. + + See Also + ======== + + dmp_ext_factor: + Analogous function for multivariate polynomials over ``k(a)``. + dup_sqf_norm: + Subroutine ``sqfr_norm`` also from [Trager76]_. + sympy.polys.polytools.factor: + The high-level function that ultimately uses this function as needed. + """ + n, lc = dup_degree(f), dup_LC(f, K) + + f = dup_monic(f, K) + + if n <= 0: + return lc, [] + if n == 1: + return lc, [(f, 1)] + + f, F = dup_sqf_part(f, K), f + s, g, r = dup_sqf_norm(f, K) + + factors = dup_factor_list_include(r, K.dom) + + if len(factors) == 1: + return lc, [(f, n//dup_degree(f))] + + H = s*K.unit + + for i, (factor, _) in enumerate(factors): + h = dup_convert(factor, K.dom, K) + h, _, g = dup_inner_gcd(h, g, K) + h = dup_shift(h, H, K) + factors[i] = h + + factors = dup_trial_division(F, factors, K) + + _dup_check_degrees(F, factors) + + return lc, factors + + +def dmp_ext_factor(f, u, K): + r"""Factor multivariate polynomials over algebraic number fields. + + The domain `K` must be an algebraic number field `k(a)` (see :ref:`QQ(a)`). + + Examples + ======== + + First define the algebraic number field `K = \mathbb{Q}(\sqrt{2})`: + + >>> from sympy import QQ, sqrt + >>> from sympy.polys.factortools import dmp_ext_factor + >>> K = QQ.algebraic_field(sqrt(2)) + + We can now factorise the polynomial `x^2 y^2 - 2` over `K`: + + >>> p = [[K(1),K(0),K(0)], [], [K(-2)]] # x**2*y**2 - 2 + >>> p1 = [[K(1),K(0)], [-K.unit]] # x*y - sqrt(2) + >>> p2 = [[K(1),K(0)], [+K.unit]] # x*y + sqrt(2) + >>> dmp_ext_factor(p, 1, K) == (K.one, [(p1, 1), (p2, 1)]) + True + + Usually this would be done at a higher level: + + >>> from sympy import factor + >>> from sympy.abc import x, y + >>> factor(x**2*y**2 - 2, extension=sqrt(2)) + (x*y - sqrt(2))*(x*y + sqrt(2)) + + Explanation + =========== + + This is Trager's algorithm for multivariate polynomials. In particular this + function is algorithm ``alg_factor`` from [Trager76]_. + + See :func:`dup_ext_factor` for explanation. + + See Also + ======== + + dup_ext_factor: + Analogous function for univariate polynomials over ``k(a)``. + dmp_sqf_norm: + Multivariate version of subroutine ``sqfr_norm`` also from [Trager76]_. + sympy.polys.polytools.factor: + The high-level function that ultimately uses this function as needed. + """ + if not u: + return dup_ext_factor(f, K) + + lc = dmp_ground_LC(f, u, K) + f = dmp_ground_monic(f, u, K) + + if all(d <= 0 for d in dmp_degree_list(f, u)): + return lc, [] + + f, F = dmp_sqf_part(f, u, K), f + s, g, r = dmp_sqf_norm(f, u, K) + + factors = dmp_factor_list_include(r, u, K.dom) + + if len(factors) == 1: + factors = [f] + else: + for i, (factor, _) in enumerate(factors): + h = dmp_convert(factor, u, K.dom, K) + h, _, g = dmp_inner_gcd(h, g, u, K) + a = [si*K.unit for si in s] + h = dmp_shift(h, a, u, K) + factors[i] = h + + result = dmp_trial_division(F, factors, u, K) + + _dmp_check_degrees(F, u, result) + + return lc, result + + +def dup_gf_factor(f, K): + """Factor univariate polynomials over finite fields. """ + f = dup_convert(f, K, K.dom) + + coeff, factors = gf_factor(f, K.mod, K.dom) + + for i, (f, k) in enumerate(factors): + factors[i] = (dup_convert(f, K.dom, K), k) + + return K.convert(coeff, K.dom), factors + + +def dmp_gf_factor(f, u, K): + """Factor multivariate polynomials over finite fields. """ + raise NotImplementedError('multivariate polynomials over finite fields') + + +def dup_factor_list(f, K0): + """Factor univariate polynomials into irreducibles in `K[x]`. """ + j, f = dup_terms_gcd(f, K0) + cont, f = dup_primitive(f, K0) + + if K0.is_FiniteField: + coeff, factors = dup_gf_factor(f, K0) + elif K0.is_Algebraic: + coeff, factors = dup_ext_factor(f, K0) + elif K0.is_GaussianRing: + coeff, factors = dup_zz_i_factor(f, K0) + elif K0.is_GaussianField: + coeff, factors = dup_qq_i_factor(f, K0) + else: + if not K0.is_Exact: + K0_inexact, K0 = K0, K0.get_exact() + f = dup_convert(f, K0_inexact, K0) + else: + K0_inexact = None + + if K0.is_Field: + K = K0.get_ring() + + denom, f = dup_clear_denoms(f, K0, K) + f = dup_convert(f, K0, K) + else: + K = K0 + + if K.is_ZZ: + coeff, factors = dup_zz_factor(f, K) + elif K.is_Poly: + f, u = dmp_inject(f, 0, K) + + coeff, factors = dmp_factor_list(f, u, K.dom) + + for i, (f, k) in enumerate(factors): + factors[i] = (dmp_eject(f, u, K), k) + + coeff = K.convert(coeff, K.dom) + else: # pragma: no cover + raise DomainError('factorization not supported over %s' % K0) + + if K0.is_Field: + for i, (f, k) in enumerate(factors): + factors[i] = (dup_convert(f, K, K0), k) + + coeff = K0.convert(coeff, K) + coeff = K0.quo(coeff, denom) + + if K0_inexact: + for i, (f, k) in enumerate(factors): + max_norm = dup_max_norm(f, K0) + f = dup_quo_ground(f, max_norm, K0) + f = dup_convert(f, K0, K0_inexact) + factors[i] = (f, k) + coeff = K0.mul(coeff, K0.pow(max_norm, k)) + + coeff = K0_inexact.convert(coeff, K0) + K0 = K0_inexact + + if j: + factors.insert(0, ([K0.one, K0.zero], j)) + + return coeff*cont, _sort_factors(factors) + + +def dup_factor_list_include(f, K): + """Factor univariate polynomials into irreducibles in `K[x]`. """ + coeff, factors = dup_factor_list(f, K) + + if not factors: + return [(dup_strip([coeff]), 1)] + else: + g = dup_mul_ground(factors[0][0], coeff, K) + return [(g, factors[0][1])] + factors[1:] + + +def dmp_factor_list(f, u, K0): + """Factor multivariate polynomials into irreducibles in `K[X]`. """ + if not u: + return dup_factor_list(f, K0) + + J, f = dmp_terms_gcd(f, u, K0) + cont, f = dmp_ground_primitive(f, u, K0) + + if K0.is_FiniteField: # pragma: no cover + coeff, factors = dmp_gf_factor(f, u, K0) + elif K0.is_Algebraic: + coeff, factors = dmp_ext_factor(f, u, K0) + elif K0.is_GaussianRing: + coeff, factors = dmp_zz_i_factor(f, u, K0) + elif K0.is_GaussianField: + coeff, factors = dmp_qq_i_factor(f, u, K0) + else: + if not K0.is_Exact: + K0_inexact, K0 = K0, K0.get_exact() + f = dmp_convert(f, u, K0_inexact, K0) + else: + K0_inexact = None + + if K0.is_Field: + K = K0.get_ring() + + denom, f = dmp_clear_denoms(f, u, K0, K) + f = dmp_convert(f, u, K0, K) + else: + K = K0 + + if K.is_ZZ: + levels, f, v = dmp_exclude(f, u, K) + coeff, factors = dmp_zz_factor(f, v, K) + + for i, (f, k) in enumerate(factors): + factors[i] = (dmp_include(f, levels, v, K), k) + elif K.is_Poly: + f, v = dmp_inject(f, u, K) + + coeff, factors = dmp_factor_list(f, v, K.dom) + + for i, (f, k) in enumerate(factors): + factors[i] = (dmp_eject(f, v, K), k) + + coeff = K.convert(coeff, K.dom) + else: # pragma: no cover + raise DomainError('factorization not supported over %s' % K0) + + if K0.is_Field: + for i, (f, k) in enumerate(factors): + factors[i] = (dmp_convert(f, u, K, K0), k) + + coeff = K0.convert(coeff, K) + coeff = K0.quo(coeff, denom) + + if K0_inexact: + for i, (f, k) in enumerate(factors): + max_norm = dmp_max_norm(f, u, K0) + f = dmp_quo_ground(f, max_norm, u, K0) + f = dmp_convert(f, u, K0, K0_inexact) + factors[i] = (f, k) + coeff = K0.mul(coeff, K0.pow(max_norm, k)) + + coeff = K0_inexact.convert(coeff, K0) + K0 = K0_inexact + + for i, j in enumerate(reversed(J)): + if not j: + continue + + term = {(0,)*(u - i) + (1,) + (0,)*i: K0.one} + factors.insert(0, (dmp_from_dict(term, u, K0), j)) + + return coeff*cont, _sort_factors(factors) + + +def dmp_factor_list_include(f, u, K): + """Factor multivariate polynomials into irreducibles in `K[X]`. """ + if not u: + return dup_factor_list_include(f, K) + + coeff, factors = dmp_factor_list(f, u, K) + + if not factors: + return [(dmp_ground(coeff, u), 1)] + else: + g = dmp_mul_ground(factors[0][0], coeff, u, K) + return [(g, factors[0][1])] + factors[1:] + + +def dup_irreducible_p(f, K): + """ + Returns ``True`` if a univariate polynomial ``f`` has no factors + over its domain. + """ + return dmp_irreducible_p(f, 0, K) + + +def dmp_irreducible_p(f, u, K): + """ + Returns ``True`` if a multivariate polynomial ``f`` has no factors + over its domain. + """ + _, factors = dmp_factor_list(f, u, K) + + if not factors: + return True + elif len(factors) > 1: + return False + else: + _, k = factors[0] + return k == 1 diff --git a/phivenv/Lib/site-packages/sympy/polys/fglmtools.py b/phivenv/Lib/site-packages/sympy/polys/fglmtools.py new file mode 100644 index 0000000000000000000000000000000000000000..d68fe5bc2a40741e39b89163d393f7b57e6b1c49 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/fglmtools.py @@ -0,0 +1,153 @@ +"""Implementation of matrix FGLM Groebner basis conversion algorithm. """ + + +from sympy.polys.monomials import monomial_mul, monomial_div + +def matrix_fglm(F, ring, O_to): + """ + Converts the reduced Groebner basis ``F`` of a zero-dimensional + ideal w.r.t. ``O_from`` to a reduced Groebner basis + w.r.t. ``O_to``. + + References + ========== + + .. [1] J.C. Faugere, P. Gianni, D. Lazard, T. Mora (1994). Efficient + Computation of Zero-dimensional Groebner Bases by Change of + Ordering + """ + domain = ring.domain + ngens = ring.ngens + + ring_to = ring.clone(order=O_to) + + old_basis = _basis(F, ring) + M = _representing_matrices(old_basis, F, ring) + + # V contains the normalforms (wrt O_from) of S + S = [ring.zero_monom] + V = [[domain.one] + [domain.zero] * (len(old_basis) - 1)] + G = [] + + L = [(i, 0) for i in range(ngens)] # (i, j) corresponds to x_i * S[j] + L.sort(key=lambda k_l: O_to(_incr_k(S[k_l[1]], k_l[0])), reverse=True) + t = L.pop() + + P = _identity_matrix(len(old_basis), domain) + + while True: + s = len(S) + v = _matrix_mul(M[t[0]], V[t[1]]) + _lambda = _matrix_mul(P, v) + + if all(_lambda[i] == domain.zero for i in range(s, len(old_basis))): + # there is a linear combination of v by V + lt = ring.term_new(_incr_k(S[t[1]], t[0]), domain.one) + rest = ring.from_dict({S[i]: _lambda[i] for i in range(s)}) + + g = (lt - rest).set_ring(ring_to) + if g: + G.append(g) + else: + # v is linearly independent from V + P = _update(s, _lambda, P) + S.append(_incr_k(S[t[1]], t[0])) + V.append(v) + + L.extend([(i, s) for i in range(ngens)]) + L = list(set(L)) + L.sort(key=lambda k_l: O_to(_incr_k(S[k_l[1]], k_l[0])), reverse=True) + + L = [(k, l) for (k, l) in L if all(monomial_div(_incr_k(S[l], k), g.LM) is None for g in G)] + + if not L: + G = [ g.monic() for g in G ] + return sorted(G, key=lambda g: O_to(g.LM), reverse=True) + + t = L.pop() + + +def _incr_k(m, k): + return tuple(list(m[:k]) + [m[k] + 1] + list(m[k + 1:])) + + +def _identity_matrix(n, domain): + M = [[domain.zero]*n for _ in range(n)] + + for i in range(n): + M[i][i] = domain.one + + return M + + +def _matrix_mul(M, v): + return [sum(row[i] * v[i] for i in range(len(v))) for row in M] + + +def _update(s, _lambda, P): + """ + Update ``P`` such that for the updated `P'` `P' v = e_{s}`. + """ + k = min(j for j in range(s, len(_lambda)) if _lambda[j] != 0) + + for r in range(len(_lambda)): + if r != k: + P[r] = [P[r][j] - (P[k][j] * _lambda[r]) / _lambda[k] for j in range(len(P[r]))] + + P[k] = [P[k][j] / _lambda[k] for j in range(len(P[k]))] + P[k], P[s] = P[s], P[k] + + return P + + +def _representing_matrices(basis, G, ring): + r""" + Compute the matrices corresponding to the linear maps `m \mapsto + x_i m` for all variables `x_i`. + """ + domain = ring.domain + u = ring.ngens-1 + + def var(i): + return tuple([0] * i + [1] + [0] * (u - i)) + + def representing_matrix(m): + M = [[domain.zero] * len(basis) for _ in range(len(basis))] + + for i, v in enumerate(basis): + r = ring.term_new(monomial_mul(m, v), domain.one).rem(G) + + for monom, coeff in r.terms(): + j = basis.index(monom) + M[j][i] = coeff + + return M + + return [representing_matrix(var(i)) for i in range(u + 1)] + + +def _basis(G, ring): + r""" + Computes a list of monomials which are not divisible by the leading + monomials wrt to ``O`` of ``G``. These monomials are a basis of + `K[X_1, \ldots, X_n]/(G)`. + """ + order = ring.order + + leading_monomials = [g.LM for g in G] + candidates = [ring.zero_monom] + basis = [] + + while candidates: + t = candidates.pop() + basis.append(t) + + new_candidates = [_incr_k(t, k) for k in range(ring.ngens) + if all(monomial_div(_incr_k(t, k), lmg) is None + for lmg in leading_monomials)] + candidates.extend(new_candidates) + candidates.sort(key=order, reverse=True) + + basis = list(set(basis)) + + return sorted(basis, key=order) diff --git a/phivenv/Lib/site-packages/sympy/polys/fields.py b/phivenv/Lib/site-packages/sympy/polys/fields.py new file mode 100644 index 0000000000000000000000000000000000000000..ee844df55690af0b140132249990b335d926b6d4 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/fields.py @@ -0,0 +1,639 @@ +"""Sparse rational function fields. """ + +from __future__ import annotations +from functools import reduce + +from operator import add, mul, lt, le, gt, ge + +from sympy.core.expr import Expr +from sympy.core.mod import Mod +from sympy.core.numbers import Exp1 +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import CantSympify, sympify +from sympy.functions.elementary.exponential import ExpBase +from sympy.polys.domains.domain import Domain +from sympy.polys.domains.domainelement import DomainElement +from sympy.polys.domains.fractionfield import FractionField +from sympy.polys.domains.polynomialring import PolynomialRing +from sympy.polys.constructor import construct_domain +from sympy.polys.orderings import lex, MonomialOrder +from sympy.polys.polyerrors import CoercionFailed +from sympy.polys.polyoptions import build_options +from sympy.polys.polyutils import _parallel_dict_from_expr +from sympy.polys.rings import PolyRing, PolyElement +from sympy.printing.defaults import DefaultPrinting +from sympy.utilities import public +from sympy.utilities.iterables import is_sequence +from sympy.utilities.magic import pollute + +@public +def field(symbols, domain, order=lex): + """Construct new rational function field returning (field, x1, ..., xn). """ + _field = FracField(symbols, domain, order) + return (_field,) + _field.gens + +@public +def xfield(symbols, domain, order=lex): + """Construct new rational function field returning (field, (x1, ..., xn)). """ + _field = FracField(symbols, domain, order) + return (_field, _field.gens) + +@public +def vfield(symbols, domain, order=lex): + """Construct new rational function field and inject generators into global namespace. """ + _field = FracField(symbols, domain, order) + pollute([ sym.name for sym in _field.symbols ], _field.gens) + return _field + +@public +def sfield(exprs, *symbols, **options): + """Construct a field deriving generators and domain + from options and input expressions. + + Parameters + ========== + + exprs : py:class:`~.Expr` or sequence of :py:class:`~.Expr` (sympifiable) + + symbols : sequence of :py:class:`~.Symbol`/:py:class:`~.Expr` + + options : keyword arguments understood by :py:class:`~.Options` + + Examples + ======== + + >>> from sympy import exp, log, symbols, sfield + + >>> x = symbols("x") + >>> K, f = sfield((x*log(x) + 4*x**2)*exp(1/x + log(x)/3)/x**2) + >>> K + Rational function field in x, exp(1/x), log(x), x**(1/3) over ZZ with lex order + >>> f + (4*x**2*(exp(1/x)) + x*(exp(1/x))*(log(x)))/((x**(1/3))**5) + """ + single = False + if not is_sequence(exprs): + exprs, single = [exprs], True + + exprs = list(map(sympify, exprs)) + opt = build_options(symbols, options) + numdens = [] + for expr in exprs: + numdens.extend(expr.as_numer_denom()) + reps, opt = _parallel_dict_from_expr(numdens, opt) + + if opt.domain is None: + # NOTE: this is inefficient because construct_domain() automatically + # performs conversion to the target domain. It shouldn't do this. + coeffs = sum([list(rep.values()) for rep in reps], []) + opt.domain, _ = construct_domain(coeffs, opt=opt) + + _field = FracField(opt.gens, opt.domain, opt.order) + fracs = [] + for i in range(0, len(reps), 2): + fracs.append(_field(tuple(reps[i:i+2]))) + + if single: + return (_field, fracs[0]) + else: + return (_field, fracs) + + +class FracField(DefaultPrinting): + """Multivariate distributed rational function field. """ + + ring: PolyRing + gens: tuple[FracElement, ...] + symbols: tuple[Expr, ...] + ngens: int + domain: Domain + order: MonomialOrder + + def __new__(cls, symbols, domain, order=lex): + ring = PolyRing(symbols, domain, order) + symbols = ring.symbols + ngens = ring.ngens + domain = ring.domain + order = ring.order + + _hash_tuple = (cls.__name__, symbols, ngens, domain, order) + + obj = object.__new__(cls) + obj._hash_tuple = _hash_tuple + obj._hash = hash(_hash_tuple) + obj.ring = ring + obj.symbols = symbols + obj.ngens = ngens + obj.domain = domain + obj.order = order + + obj.dtype = FracElement(obj, ring.zero).raw_new + + obj.zero = obj.dtype(ring.zero) + obj.one = obj.dtype(ring.one) + + obj.gens = obj._gens() + + for symbol, generator in zip(obj.symbols, obj.gens): + if isinstance(symbol, Symbol): + name = symbol.name + + if not hasattr(obj, name): + setattr(obj, name, generator) + + return obj + + def _gens(self): + """Return a list of polynomial generators. """ + return tuple([ self.dtype(gen) for gen in self.ring.gens ]) + + def __getnewargs__(self): + return (self.symbols, self.domain, self.order) + + def __hash__(self): + return self._hash + + def index(self, gen): + if self.is_element(gen): + return self.ring.index(gen.to_poly()) + else: + raise ValueError("expected a %s, got %s instead" % (self.dtype,gen)) + + def __eq__(self, other): + return isinstance(other, FracField) and \ + (self.symbols, self.ngens, self.domain, self.order) == \ + (other.symbols, other.ngens, other.domain, other.order) + + def __ne__(self, other): + return not self == other + + def is_element(self, element): + """True if ``element`` is an element of this field. False otherwise. """ + return isinstance(element, FracElement) and element.field == self + + def raw_new(self, numer, denom=None): + return self.dtype(numer, denom) + + def new(self, numer, denom=None): + if denom is None: denom = self.ring.one + numer, denom = numer.cancel(denom) + return self.raw_new(numer, denom) + + def domain_new(self, element): + return self.domain.convert(element) + + def ground_new(self, element): + try: + return self.new(self.ring.ground_new(element)) + except CoercionFailed: + domain = self.domain + + if not domain.is_Field and domain.has_assoc_Field: + ring = self.ring + ground_field = domain.get_field() + element = ground_field.convert(element) + numer = ring.ground_new(ground_field.numer(element)) + denom = ring.ground_new(ground_field.denom(element)) + return self.raw_new(numer, denom) + else: + raise + + def field_new(self, element): + if isinstance(element, FracElement): + if self == element.field: + return element + + if isinstance(self.domain, FractionField) and \ + self.domain.field == element.field: + return self.ground_new(element) + elif isinstance(self.domain, PolynomialRing) and \ + self.domain.ring.to_field() == element.field: + return self.ground_new(element) + else: + raise NotImplementedError("conversion") + elif isinstance(element, PolyElement): + denom, numer = element.clear_denoms() + + if isinstance(self.domain, PolynomialRing) and \ + numer.ring == self.domain.ring: + numer = self.ring.ground_new(numer) + elif isinstance(self.domain, FractionField) and \ + numer.ring == self.domain.field.to_ring(): + numer = self.ring.ground_new(numer) + else: + numer = numer.set_ring(self.ring) + + denom = self.ring.ground_new(denom) + return self.raw_new(numer, denom) + elif isinstance(element, tuple) and len(element) == 2: + numer, denom = list(map(self.ring.ring_new, element)) + return self.new(numer, denom) + elif isinstance(element, str): + raise NotImplementedError("parsing") + elif isinstance(element, Expr): + return self.from_expr(element) + else: + return self.ground_new(element) + + __call__ = field_new + + def _rebuild_expr(self, expr, mapping): + domain = self.domain + powers = tuple((gen, gen.as_base_exp()) for gen in mapping.keys() + if gen.is_Pow or isinstance(gen, ExpBase)) + + def _rebuild(expr): + generator = mapping.get(expr) + + if generator is not None: + return generator + elif expr.is_Add: + return reduce(add, list(map(_rebuild, expr.args))) + elif expr.is_Mul: + return reduce(mul, list(map(_rebuild, expr.args))) + elif expr.is_Pow or isinstance(expr, (ExpBase, Exp1)): + b, e = expr.as_base_exp() + # look for bg**eg whose integer power may be b**e + for gen, (bg, eg) in powers: + if bg == b and Mod(e, eg) == 0: + return mapping.get(gen)**int(e/eg) + if e.is_Integer and e is not S.One: + return _rebuild(b)**int(e) + elif mapping.get(1/expr) is not None: + return 1/mapping.get(1/expr) + + try: + return domain.convert(expr) + except CoercionFailed: + if not domain.is_Field and domain.has_assoc_Field: + return domain.get_field().convert(expr) + else: + raise + + return _rebuild(expr) + + def from_expr(self, expr): + mapping = dict(list(zip(self.symbols, self.gens))) + + try: + frac = self._rebuild_expr(sympify(expr), mapping) + except CoercionFailed: + raise ValueError("expected an expression convertible to a rational function in %s, got %s" % (self, expr)) + else: + return self.field_new(frac) + + def to_domain(self): + return FractionField(self) + + def to_ring(self): + return PolyRing(self.symbols, self.domain, self.order) + +class FracElement(DomainElement, DefaultPrinting, CantSympify): + """Element of multivariate distributed rational function field. """ + + def __init__(self, field, numer, denom=None): + if denom is None: + denom = field.ring.one + elif not denom: + raise ZeroDivisionError("zero denominator") + + self.field = field + self.numer = numer + self.denom = denom + + def raw_new(f, numer, denom=None): + return f.__class__(f.field, numer, denom) + + def new(f, numer, denom): + return f.raw_new(*numer.cancel(denom)) + + def to_poly(f): + if f.denom != 1: + raise ValueError("f.denom should be 1") + return f.numer + + def parent(self): + return self.field.to_domain() + + def __getnewargs__(self): + return (self.field, self.numer, self.denom) + + _hash = None + + def __hash__(self): + _hash = self._hash + if _hash is None: + self._hash = _hash = hash((self.field, self.numer, self.denom)) + return _hash + + def copy(self): + return self.raw_new(self.numer.copy(), self.denom.copy()) + + def set_field(self, new_field): + if self.field == new_field: + return self + else: + new_ring = new_field.ring + numer = self.numer.set_ring(new_ring) + denom = self.denom.set_ring(new_ring) + return new_field.new(numer, denom) + + def as_expr(self, *symbols): + return self.numer.as_expr(*symbols)/self.denom.as_expr(*symbols) + + def __eq__(f, g): + if isinstance(g, FracElement) and f.field == g.field: + return f.numer == g.numer and f.denom == g.denom + else: + return f.numer == g and f.denom == f.field.ring.one + + def __ne__(f, g): + return not f == g + + def __bool__(f): + return bool(f.numer) + + def sort_key(self): + return (self.denom.sort_key(), self.numer.sort_key()) + + def _cmp(f1, f2, op): + if f1.field.is_element(f2): + return op(f1.sort_key(), f2.sort_key()) + else: + return NotImplemented + + def __lt__(f1, f2): + return f1._cmp(f2, lt) + def __le__(f1, f2): + return f1._cmp(f2, le) + def __gt__(f1, f2): + return f1._cmp(f2, gt) + def __ge__(f1, f2): + return f1._cmp(f2, ge) + + def __pos__(f): + """Negate all coefficients in ``f``. """ + return f.raw_new(f.numer, f.denom) + + def __neg__(f): + """Negate all coefficients in ``f``. """ + return f.raw_new(-f.numer, f.denom) + + def _extract_ground(self, element): + domain = self.field.domain + + try: + element = domain.convert(element) + except CoercionFailed: + if not domain.is_Field and domain.has_assoc_Field: + ground_field = domain.get_field() + + try: + element = ground_field.convert(element) + except CoercionFailed: + pass + else: + return -1, ground_field.numer(element), ground_field.denom(element) + + return 0, None, None + else: + return 1, element, None + + def __add__(f, g): + """Add rational functions ``f`` and ``g``. """ + field = f.field + + if not g: + return f + elif not f: + return g + elif field.is_element(g): + if f.denom == g.denom: + return f.new(f.numer + g.numer, f.denom) + else: + return f.new(f.numer*g.denom + f.denom*g.numer, f.denom*g.denom) + elif field.ring.is_element(g): + return f.new(f.numer + f.denom*g, f.denom) + else: + if isinstance(g, FracElement): + if isinstance(field.domain, FractionField) and field.domain.field == g.field: + pass + elif isinstance(g.field.domain, FractionField) and g.field.domain.field == field: + return g.__radd__(f) + else: + return NotImplemented + elif isinstance(g, PolyElement): + if isinstance(field.domain, PolynomialRing) and field.domain.ring == g.ring: + pass + else: + return g.__radd__(f) + + return f.__radd__(g) + + def __radd__(f, c): + if f.field.ring.is_element(c): + return f.new(f.numer + f.denom*c, f.denom) + + op, g_numer, g_denom = f._extract_ground(c) + + if op == 1: + return f.new(f.numer + f.denom*g_numer, f.denom) + elif not op: + return NotImplemented + else: + return f.new(f.numer*g_denom + f.denom*g_numer, f.denom*g_denom) + + def __sub__(f, g): + """Subtract rational functions ``f`` and ``g``. """ + field = f.field + + if not g: + return f + elif not f: + return -g + elif field.is_element(g): + if f.denom == g.denom: + return f.new(f.numer - g.numer, f.denom) + else: + return f.new(f.numer*g.denom - f.denom*g.numer, f.denom*g.denom) + elif field.ring.is_element(g): + return f.new(f.numer - f.denom*g, f.denom) + else: + if isinstance(g, FracElement): + if isinstance(field.domain, FractionField) and field.domain.field == g.field: + pass + elif isinstance(g.field.domain, FractionField) and g.field.domain.field == field: + return g.__rsub__(f) + else: + return NotImplemented + elif isinstance(g, PolyElement): + if isinstance(field.domain, PolynomialRing) and field.domain.ring == g.ring: + pass + else: + return g.__rsub__(f) + + op, g_numer, g_denom = f._extract_ground(g) + + if op == 1: + return f.new(f.numer - f.denom*g_numer, f.denom) + elif not op: + return NotImplemented + else: + return f.new(f.numer*g_denom - f.denom*g_numer, f.denom*g_denom) + + def __rsub__(f, c): + if f.field.ring.is_element(c): + return f.new(-f.numer + f.denom*c, f.denom) + + op, g_numer, g_denom = f._extract_ground(c) + + if op == 1: + return f.new(-f.numer + f.denom*g_numer, f.denom) + elif not op: + return NotImplemented + else: + return f.new(-f.numer*g_denom + f.denom*g_numer, f.denom*g_denom) + + def __mul__(f, g): + """Multiply rational functions ``f`` and ``g``. """ + field = f.field + + if not f or not g: + return field.zero + elif field.is_element(g): + return f.new(f.numer*g.numer, f.denom*g.denom) + elif field.ring.is_element(g): + return f.new(f.numer*g, f.denom) + else: + if isinstance(g, FracElement): + if isinstance(field.domain, FractionField) and field.domain.field == g.field: + pass + elif isinstance(g.field.domain, FractionField) and g.field.domain.field == field: + return g.__rmul__(f) + else: + return NotImplemented + elif isinstance(g, PolyElement): + if isinstance(field.domain, PolynomialRing) and field.domain.ring == g.ring: + pass + else: + return g.__rmul__(f) + + return f.__rmul__(g) + + def __rmul__(f, c): + if f.field.ring.is_element(c): + return f.new(f.numer*c, f.denom) + + op, g_numer, g_denom = f._extract_ground(c) + + if op == 1: + return f.new(f.numer*g_numer, f.denom) + elif not op: + return NotImplemented + else: + return f.new(f.numer*g_numer, f.denom*g_denom) + + def __truediv__(f, g): + """Computes quotient of fractions ``f`` and ``g``. """ + field = f.field + + if not g: + raise ZeroDivisionError + elif field.is_element(g): + return f.new(f.numer*g.denom, f.denom*g.numer) + elif field.ring.is_element(g): + return f.new(f.numer, f.denom*g) + else: + if isinstance(g, FracElement): + if isinstance(field.domain, FractionField) and field.domain.field == g.field: + pass + elif isinstance(g.field.domain, FractionField) and g.field.domain.field == field: + return g.__rtruediv__(f) + else: + return NotImplemented + elif isinstance(g, PolyElement): + if isinstance(field.domain, PolynomialRing) and field.domain.ring == g.ring: + pass + else: + return g.__rtruediv__(f) + + op, g_numer, g_denom = f._extract_ground(g) + + if op == 1: + return f.new(f.numer, f.denom*g_numer) + elif not op: + return NotImplemented + else: + return f.new(f.numer*g_denom, f.denom*g_numer) + + def __rtruediv__(f, c): + if not f: + raise ZeroDivisionError + elif f.field.ring.is_element(c): + return f.new(f.denom*c, f.numer) + + op, g_numer, g_denom = f._extract_ground(c) + + if op == 1: + return f.new(f.denom*g_numer, f.numer) + elif not op: + return NotImplemented + else: + return f.new(f.denom*g_numer, f.numer*g_denom) + + def __pow__(f, n): + """Raise ``f`` to a non-negative power ``n``. """ + if n >= 0: + return f.raw_new(f.numer**n, f.denom**n) + elif not f: + raise ZeroDivisionError + else: + return f.raw_new(f.denom**-n, f.numer**-n) + + def diff(f, x): + """Computes partial derivative in ``x``. + + Examples + ======== + + >>> from sympy.polys.fields import field + >>> from sympy.polys.domains import ZZ + + >>> _, x, y, z = field("x,y,z", ZZ) + >>> ((x**2 + y)/(z + 1)).diff(x) + 2*x/(z + 1) + + """ + x = x.to_poly() + return f.new(f.numer.diff(x)*f.denom - f.numer*f.denom.diff(x), f.denom**2) + + def __call__(f, *values): + if 0 < len(values) <= f.field.ngens: + return f.evaluate(list(zip(f.field.gens, values))) + else: + raise ValueError("expected at least 1 and at most %s values, got %s" % (f.field.ngens, len(values))) + + def evaluate(f, x, a=None): + if isinstance(x, list) and a is None: + x = [ (X.to_poly(), a) for X, a in x ] + numer, denom = f.numer.evaluate(x), f.denom.evaluate(x) + else: + x = x.to_poly() + numer, denom = f.numer.evaluate(x, a), f.denom.evaluate(x, a) + + field = numer.ring.to_field() + return field.new(numer, denom) + + def subs(f, x, a=None): + if isinstance(x, list) and a is None: + x = [ (X.to_poly(), a) for X, a in x ] + numer, denom = f.numer.subs(x), f.denom.subs(x) + else: + x = x.to_poly() + numer, denom = f.numer.subs(x, a), f.denom.subs(x, a) + + return f.new(numer, denom) + + def compose(f, x, a=None): + raise NotImplementedError diff --git a/phivenv/Lib/site-packages/sympy/polys/galoistools.py b/phivenv/Lib/site-packages/sympy/polys/galoistools.py new file mode 100644 index 0000000000000000000000000000000000000000..b09f85057eced59b8054c6007f2b291a35a2fafb --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/galoistools.py @@ -0,0 +1,2532 @@ +"""Dense univariate polynomials with coefficients in Galois fields. """ + +from math import ceil as _ceil, sqrt as _sqrt, prod + +from sympy.core.random import uniform, _randint +from sympy.external.gmpy import SYMPY_INTS, MPZ, invert +from sympy.polys.polyconfig import query +from sympy.polys.polyerrors import ExactQuotientFailed +from sympy.polys.polyutils import _sort_factors + + +def gf_crt(U, M, K=None): + """ + Chinese Remainder Theorem. + + Given a set of integer residues ``u_0,...,u_n`` and a set of + co-prime integer moduli ``m_0,...,m_n``, returns an integer + ``u``, such that ``u = u_i mod m_i`` for ``i = ``0,...,n``. + + Examples + ======== + + Consider a set of residues ``U = [49, 76, 65]`` + and a set of moduli ``M = [99, 97, 95]``. Then we have:: + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_crt + + >>> gf_crt([49, 76, 65], [99, 97, 95], ZZ) + 639985 + + This is the correct result because:: + + >>> [639985 % m for m in [99, 97, 95]] + [49, 76, 65] + + Note: this is a low-level routine with no error checking. + + See Also + ======== + + sympy.ntheory.modular.crt : a higher level crt routine + sympy.ntheory.modular.solve_congruence + + """ + p = prod(M, start=K.one) + v = K.zero + + for u, m in zip(U, M): + e = p // m + s, _, _ = K.gcdex(e, m) + v += e*(u*s % m) + + return v % p + + +def gf_crt1(M, K): + """ + First part of the Chinese Remainder Theorem. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_crt, gf_crt1, gf_crt2 + >>> U = [49, 76, 65] + >>> M = [99, 97, 95] + + The following two codes have the same result. + + >>> gf_crt(U, M, ZZ) + 639985 + + >>> p, E, S = gf_crt1(M, ZZ) + >>> gf_crt2(U, M, p, E, S, ZZ) + 639985 + + However, it is faster when we want to fix ``M`` and + compute for multiple U, i.e. the following cases: + + >>> p, E, S = gf_crt1(M, ZZ) + >>> Us = [[49, 76, 65], [23, 42, 67]] + >>> for U in Us: + ... print(gf_crt2(U, M, p, E, S, ZZ)) + 639985 + 236237 + + See Also + ======== + + sympy.ntheory.modular.crt1 : a higher level crt routine + sympy.polys.galoistools.gf_crt + sympy.polys.galoistools.gf_crt2 + + """ + E, S = [], [] + p = prod(M, start=K.one) + + for m in M: + E.append(p // m) + S.append(K.gcdex(E[-1], m)[0] % m) + + return p, E, S + + +def gf_crt2(U, M, p, E, S, K): + """ + Second part of the Chinese Remainder Theorem. + + See ``gf_crt1`` for usage. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_crt2 + + >>> U = [49, 76, 65] + >>> M = [99, 97, 95] + >>> p = 912285 + >>> E = [9215, 9405, 9603] + >>> S = [62, 24, 12] + + >>> gf_crt2(U, M, p, E, S, ZZ) + 639985 + + See Also + ======== + + sympy.ntheory.modular.crt2 : a higher level crt routine + sympy.polys.galoistools.gf_crt + sympy.polys.galoistools.gf_crt1 + + """ + v = K.zero + + for u, m, e, s in zip(U, M, E, S): + v += e*(u*s % m) + + return v % p + + +def gf_int(a, p): + """ + Coerce ``a mod p`` to an integer in the range ``[-p/2, p/2]``. + + Examples + ======== + + >>> from sympy.polys.galoistools import gf_int + + >>> gf_int(2, 7) + 2 + >>> gf_int(5, 7) + -2 + + """ + if a <= p // 2: + return a + else: + return a - p + + +def gf_degree(f): + """ + Return the leading degree of ``f``. + + Examples + ======== + + >>> from sympy.polys.galoistools import gf_degree + + >>> gf_degree([1, 1, 2, 0]) + 3 + >>> gf_degree([]) + -1 + + """ + return len(f) - 1 + + +def gf_LC(f, K): + """ + Return the leading coefficient of ``f``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_LC + + >>> gf_LC([3, 0, 1], ZZ) + 3 + + """ + if not f: + return K.zero + else: + return f[0] + + +def gf_TC(f, K): + """ + Return the trailing coefficient of ``f``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_TC + + >>> gf_TC([3, 0, 1], ZZ) + 1 + + """ + if not f: + return K.zero + else: + return f[-1] + + +def gf_strip(f): + """ + Remove leading zeros from ``f``. + + + Examples + ======== + + >>> from sympy.polys.galoistools import gf_strip + + >>> gf_strip([0, 0, 0, 3, 0, 1]) + [3, 0, 1] + + """ + if not f or f[0]: + return f + + k = 0 + + for coeff in f: + if coeff: + break + else: + k += 1 + + return f[k:] + + +def gf_trunc(f, p): + """ + Reduce all coefficients modulo ``p``. + + Examples + ======== + + >>> from sympy.polys.galoistools import gf_trunc + + >>> gf_trunc([7, -2, 3], 5) + [2, 3, 3] + + """ + return gf_strip([ a % p for a in f ]) + + +def gf_normal(f, p, K): + """ + Normalize all coefficients in ``K``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_normal + + >>> gf_normal([5, 10, 21, -3], 5, ZZ) + [1, 2] + + """ + return gf_trunc(list(map(K, f)), p) + + +def gf_from_dict(f, p, K): + """ + Create a ``GF(p)[x]`` polynomial from a dict. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_from_dict + + >>> gf_from_dict({10: ZZ(4), 4: ZZ(33), 0: ZZ(-1)}, 5, ZZ) + [4, 0, 0, 0, 0, 0, 3, 0, 0, 0, 4] + + """ + n, h = max(f.keys()), [] + + if isinstance(n, SYMPY_INTS): + for k in range(n, -1, -1): + h.append(f.get(k, K.zero) % p) + else: + (n,) = n + + for k in range(n, -1, -1): + h.append(f.get((k,), K.zero) % p) + + return gf_trunc(h, p) + + +def gf_to_dict(f, p, symmetric=True): + """ + Convert a ``GF(p)[x]`` polynomial to a dict. + + Examples + ======== + + >>> from sympy.polys.galoistools import gf_to_dict + + >>> gf_to_dict([4, 0, 0, 0, 0, 0, 3, 0, 0, 0, 4], 5) + {0: -1, 4: -2, 10: -1} + >>> gf_to_dict([4, 0, 0, 0, 0, 0, 3, 0, 0, 0, 4], 5, symmetric=False) + {0: 4, 4: 3, 10: 4} + + """ + n, result = gf_degree(f), {} + + for k in range(0, n + 1): + if symmetric: + a = gf_int(f[n - k], p) + else: + a = f[n - k] + + if a: + result[k] = a + + return result + + +def gf_from_int_poly(f, p): + """ + Create a ``GF(p)[x]`` polynomial from ``Z[x]``. + + Examples + ======== + + >>> from sympy.polys.galoistools import gf_from_int_poly + + >>> gf_from_int_poly([7, -2, 3], 5) + [2, 3, 3] + + """ + return gf_trunc(f, p) + + +def gf_to_int_poly(f, p, symmetric=True): + """ + Convert a ``GF(p)[x]`` polynomial to ``Z[x]``. + + + Examples + ======== + + >>> from sympy.polys.galoistools import gf_to_int_poly + + >>> gf_to_int_poly([2, 3, 3], 5) + [2, -2, -2] + >>> gf_to_int_poly([2, 3, 3], 5, symmetric=False) + [2, 3, 3] + + """ + if symmetric: + return [ gf_int(c, p) for c in f ] + else: + return f + + +def gf_neg(f, p, K): + """ + Negate a polynomial in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_neg + + >>> gf_neg([3, 2, 1, 0], 5, ZZ) + [2, 3, 4, 0] + + """ + return [ -coeff % p for coeff in f ] + + +def gf_add_ground(f, a, p, K): + """ + Compute ``f + a`` where ``f`` in ``GF(p)[x]`` and ``a`` in ``GF(p)``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_add_ground + + >>> gf_add_ground([3, 2, 4], 2, 5, ZZ) + [3, 2, 1] + + """ + if not f: + a = a % p + else: + a = (f[-1] + a) % p + + if len(f) > 1: + return f[:-1] + [a] + + if not a: + return [] + else: + return [a] + + +def gf_sub_ground(f, a, p, K): + """ + Compute ``f - a`` where ``f`` in ``GF(p)[x]`` and ``a`` in ``GF(p)``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_sub_ground + + >>> gf_sub_ground([3, 2, 4], 2, 5, ZZ) + [3, 2, 2] + + """ + if not f: + a = -a % p + else: + a = (f[-1] - a) % p + + if len(f) > 1: + return f[:-1] + [a] + + if not a: + return [] + else: + return [a] + + +def gf_mul_ground(f, a, p, K): + """ + Compute ``f * a`` where ``f`` in ``GF(p)[x]`` and ``a`` in ``GF(p)``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_mul_ground + + >>> gf_mul_ground([3, 2, 4], 2, 5, ZZ) + [1, 4, 3] + + """ + if not a: + return [] + else: + return [ (a*b) % p for b in f ] + + +def gf_quo_ground(f, a, p, K): + """ + Compute ``f/a`` where ``f`` in ``GF(p)[x]`` and ``a`` in ``GF(p)``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_quo_ground + + >>> gf_quo_ground(ZZ.map([3, 2, 4]), ZZ(2), 5, ZZ) + [4, 1, 2] + + """ + return gf_mul_ground(f, K.invert(a, p), p, K) + + +def gf_add(f, g, p, K): + """ + Add polynomials in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_add + + >>> gf_add([3, 2, 4], [2, 2, 2], 5, ZZ) + [4, 1] + + """ + if not f: + return g + if not g: + return f + + df = gf_degree(f) + dg = gf_degree(g) + + if df == dg: + return gf_strip([ (a + b) % p for a, b in zip(f, g) ]) + else: + k = abs(df - dg) + + if df > dg: + h, f = f[:k], f[k:] + else: + h, g = g[:k], g[k:] + + return h + [ (a + b) % p for a, b in zip(f, g) ] + + +def gf_sub(f, g, p, K): + """ + Subtract polynomials in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_sub + + >>> gf_sub([3, 2, 4], [2, 2, 2], 5, ZZ) + [1, 0, 2] + + """ + if not g: + return f + if not f: + return gf_neg(g, p, K) + + df = gf_degree(f) + dg = gf_degree(g) + + if df == dg: + return gf_strip([ (a - b) % p for a, b in zip(f, g) ]) + else: + k = abs(df - dg) + + if df > dg: + h, f = f[:k], f[k:] + else: + h, g = gf_neg(g[:k], p, K), g[k:] + + return h + [ (a - b) % p for a, b in zip(f, g) ] + + +def gf_mul(f, g, p, K): + """ + Multiply polynomials in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_mul + + >>> gf_mul([3, 2, 4], [2, 2, 2], 5, ZZ) + [1, 0, 3, 2, 3] + + """ + df = gf_degree(f) + dg = gf_degree(g) + + dh = df + dg + h = [0]*(dh + 1) + + for i in range(0, dh + 1): + coeff = K.zero + + for j in range(max(0, i - dg), min(i, df) + 1): + coeff += f[j]*g[i - j] + + h[i] = coeff % p + + return gf_strip(h) + + +def gf_sqr(f, p, K): + """ + Square polynomials in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_sqr + + >>> gf_sqr([3, 2, 4], 5, ZZ) + [4, 2, 3, 1, 1] + + """ + df = gf_degree(f) + + dh = 2*df + h = [0]*(dh + 1) + + for i in range(0, dh + 1): + coeff = K.zero + + jmin = max(0, i - df) + jmax = min(i, df) + + n = jmax - jmin + 1 + + jmax = jmin + n // 2 - 1 + + for j in range(jmin, jmax + 1): + coeff += f[j]*f[i - j] + + coeff += coeff + + if n & 1: + elem = f[jmax + 1] + coeff += elem**2 + + h[i] = coeff % p + + return gf_strip(h) + + +def gf_add_mul(f, g, h, p, K): + """ + Returns ``f + g*h`` where ``f``, ``g``, ``h`` in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_add_mul + >>> gf_add_mul([3, 2, 4], [2, 2, 2], [1, 4], 5, ZZ) + [2, 3, 2, 2] + """ + return gf_add(f, gf_mul(g, h, p, K), p, K) + + +def gf_sub_mul(f, g, h, p, K): + """ + Compute ``f - g*h`` where ``f``, ``g``, ``h`` in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_sub_mul + + >>> gf_sub_mul([3, 2, 4], [2, 2, 2], [1, 4], 5, ZZ) + [3, 3, 2, 1] + + """ + return gf_sub(f, gf_mul(g, h, p, K), p, K) + + +def gf_expand(F, p, K): + """ + Expand results of :func:`~.factor` in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_expand + + >>> gf_expand([([3, 2, 4], 1), ([2, 2], 2), ([3, 1], 3)], 5, ZZ) + [4, 3, 0, 3, 0, 1, 4, 1] + + """ + if isinstance(F, tuple): + lc, F = F + else: + lc = K.one + + g = [lc] + + for f, k in F: + f = gf_pow(f, k, p, K) + g = gf_mul(g, f, p, K) + + return g + + +def gf_div(f, g, p, K): + """ + Division with remainder in ``GF(p)[x]``. + + Given univariate polynomials ``f`` and ``g`` with coefficients in a + finite field with ``p`` elements, returns polynomials ``q`` and ``r`` + (quotient and remainder) such that ``f = q*g + r``. + + Consider polynomials ``x**3 + x + 1`` and ``x**2 + x`` in GF(2):: + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_div, gf_add_mul + + >>> gf_div(ZZ.map([1, 0, 1, 1]), ZZ.map([1, 1, 0]), 2, ZZ) + ([1, 1], [1]) + + As result we obtained quotient ``x + 1`` and remainder ``1``, thus:: + + >>> gf_add_mul(ZZ.map([1]), ZZ.map([1, 1]), ZZ.map([1, 1, 0]), 2, ZZ) + [1, 0, 1, 1] + + References + ========== + + .. [1] [Monagan93]_ + .. [2] [Gathen99]_ + + """ + df = gf_degree(f) + dg = gf_degree(g) + + if not g: + raise ZeroDivisionError("polynomial division") + elif df < dg: + return [], f + + inv = K.invert(g[0], p) + + h, dq, dr = list(f), df - dg, dg - 1 + + for i in range(0, df + 1): + coeff = h[i] + + for j in range(max(0, dg - i), min(df - i, dr) + 1): + coeff -= h[i + j - dg] * g[dg - j] + + if i <= dq: + coeff *= inv + + h[i] = coeff % p + + return h[:dq + 1], gf_strip(h[dq + 1:]) + + +def gf_rem(f, g, p, K): + """ + Compute polynomial remainder in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_rem + + >>> gf_rem(ZZ.map([1, 0, 1, 1]), ZZ.map([1, 1, 0]), 2, ZZ) + [1] + + """ + return gf_div(f, g, p, K)[1] + + +def gf_quo(f, g, p, K): + """ + Compute exact quotient in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_quo + + >>> gf_quo(ZZ.map([1, 0, 1, 1]), ZZ.map([1, 1, 0]), 2, ZZ) + [1, 1] + >>> gf_quo(ZZ.map([1, 0, 3, 2, 3]), ZZ.map([2, 2, 2]), 5, ZZ) + [3, 2, 4] + + """ + df = gf_degree(f) + dg = gf_degree(g) + + if not g: + raise ZeroDivisionError("polynomial division") + elif df < dg: + return [] + + inv = K.invert(g[0], p) + + h, dq, dr = f[:], df - dg, dg - 1 + + for i in range(0, dq + 1): + coeff = h[i] + + for j in range(max(0, dg - i), min(df - i, dr) + 1): + coeff -= h[i + j - dg] * g[dg - j] + + h[i] = (coeff * inv) % p + + return h[:dq + 1] + + +def gf_exquo(f, g, p, K): + """ + Compute polynomial quotient in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_exquo + + >>> gf_exquo(ZZ.map([1, 0, 3, 2, 3]), ZZ.map([2, 2, 2]), 5, ZZ) + [3, 2, 4] + + >>> gf_exquo(ZZ.map([1, 0, 1, 1]), ZZ.map([1, 1, 0]), 2, ZZ) + Traceback (most recent call last): + ... + ExactQuotientFailed: [1, 1, 0] does not divide [1, 0, 1, 1] + + """ + q, r = gf_div(f, g, p, K) + + if not r: + return q + else: + raise ExactQuotientFailed(f, g) + + +def gf_lshift(f, n, K): + """ + Efficiently multiply ``f`` by ``x**n``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_lshift + + >>> gf_lshift([3, 2, 4], 4, ZZ) + [3, 2, 4, 0, 0, 0, 0] + + """ + if not f: + return f + else: + return f + [K.zero]*n + + +def gf_rshift(f, n, K): + """ + Efficiently divide ``f`` by ``x**n``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_rshift + + >>> gf_rshift([1, 2, 3, 4, 0], 3, ZZ) + ([1, 2], [3, 4, 0]) + + """ + if not n: + return f, [] + else: + return f[:-n], f[-n:] + + +def gf_pow(f, n, p, K): + """ + Compute ``f**n`` in ``GF(p)[x]`` using repeated squaring. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_pow + + >>> gf_pow([3, 2, 4], 3, 5, ZZ) + [2, 4, 4, 2, 2, 1, 4] + + """ + if not n: + return [K.one] + elif n == 1: + return f + elif n == 2: + return gf_sqr(f, p, K) + + h = [K.one] + + while True: + if n & 1: + h = gf_mul(h, f, p, K) + n -= 1 + + n >>= 1 + + if not n: + break + + f = gf_sqr(f, p, K) + + return h + +def gf_frobenius_monomial_base(g, p, K): + """ + return the list of ``x**(i*p) mod g in Z_p`` for ``i = 0, .., n - 1`` + where ``n = gf_degree(g)`` + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_frobenius_monomial_base + >>> g = ZZ.map([1, 0, 2, 1]) + >>> gf_frobenius_monomial_base(g, 5, ZZ) + [[1], [4, 4, 2], [1, 2]] + + """ + n = gf_degree(g) + if n == 0: + return [] + b = [0]*n + b[0] = [1] + if p < n: + for i in range(1, n): + mon = gf_lshift(b[i - 1], p, K) + b[i] = gf_rem(mon, g, p, K) + elif n > 1: + b[1] = gf_pow_mod([K.one, K.zero], p, g, p, K) + for i in range(2, n): + b[i] = gf_mul(b[i - 1], b[1], p, K) + b[i] = gf_rem(b[i], g, p, K) + + return b + +def gf_frobenius_map(f, g, b, p, K): + """ + compute gf_pow_mod(f, p, g, p, K) using the Frobenius map + + Parameters + ========== + + f, g : polynomials in ``GF(p)[x]`` + b : frobenius monomial base + p : prime number + K : domain + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_frobenius_monomial_base, gf_frobenius_map + >>> f = ZZ.map([2, 1, 0, 1]) + >>> g = ZZ.map([1, 0, 2, 1]) + >>> p = 5 + >>> b = gf_frobenius_monomial_base(g, p, ZZ) + >>> r = gf_frobenius_map(f, g, b, p, ZZ) + >>> gf_frobenius_map(f, g, b, p, ZZ) + [4, 0, 3] + """ + m = gf_degree(g) + if gf_degree(f) >= m: + f = gf_rem(f, g, p, K) + if not f: + return [] + n = gf_degree(f) + sf = [f[-1]] + for i in range(1, n + 1): + v = gf_mul_ground(b[i], f[n - i], p, K) + sf = gf_add(sf, v, p, K) + return sf + +def _gf_pow_pnm1d2(f, n, g, b, p, K): + """ + utility function for ``gf_edf_zassenhaus`` + Compute ``f**((p**n - 1) // 2)`` in ``GF(p)[x]/(g)`` + ``f**((p**n - 1) // 2) = (f*f**p*...*f**(p**n - 1))**((p - 1) // 2)`` + """ + f = gf_rem(f, g, p, K) + h = f + r = f + for i in range(1, n): + h = gf_frobenius_map(h, g, b, p, K) + r = gf_mul(r, h, p, K) + r = gf_rem(r, g, p, K) + + res = gf_pow_mod(r, (p - 1)//2, g, p, K) + return res + +def gf_pow_mod(f, n, g, p, K): + """ + Compute ``f**n`` in ``GF(p)[x]/(g)`` using repeated squaring. + + Given polynomials ``f`` and ``g`` in ``GF(p)[x]`` and a non-negative + integer ``n``, efficiently computes ``f**n (mod g)`` i.e. the remainder + of ``f**n`` from division by ``g``, using the repeated squaring algorithm. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_pow_mod + + >>> gf_pow_mod(ZZ.map([3, 2, 4]), 3, ZZ.map([1, 1]), 5, ZZ) + [] + + References + ========== + + .. [1] [Gathen99]_ + + """ + if not n: + return [K.one] + elif n == 1: + return gf_rem(f, g, p, K) + elif n == 2: + return gf_rem(gf_sqr(f, p, K), g, p, K) + + h = [K.one] + + while True: + if n & 1: + h = gf_mul(h, f, p, K) + h = gf_rem(h, g, p, K) + n -= 1 + + n >>= 1 + + if not n: + break + + f = gf_sqr(f, p, K) + f = gf_rem(f, g, p, K) + + return h + + +def gf_gcd(f, g, p, K): + """ + Euclidean Algorithm in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_gcd + + >>> gf_gcd(ZZ.map([3, 2, 4]), ZZ.map([2, 2, 3]), 5, ZZ) + [1, 3] + + """ + while g: + f, g = g, gf_rem(f, g, p, K) + + return gf_monic(f, p, K)[1] + + +def gf_lcm(f, g, p, K): + """ + Compute polynomial LCM in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_lcm + + >>> gf_lcm(ZZ.map([3, 2, 4]), ZZ.map([2, 2, 3]), 5, ZZ) + [1, 2, 0, 4] + + """ + if not f or not g: + return [] + + h = gf_quo(gf_mul(f, g, p, K), + gf_gcd(f, g, p, K), p, K) + + return gf_monic(h, p, K)[1] + + +def gf_cofactors(f, g, p, K): + """ + Compute polynomial GCD and cofactors in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_cofactors + + >>> gf_cofactors(ZZ.map([3, 2, 4]), ZZ.map([2, 2, 3]), 5, ZZ) + ([1, 3], [3, 3], [2, 1]) + + """ + if not f and not g: + return ([], [], []) + + h = gf_gcd(f, g, p, K) + + return (h, gf_quo(f, h, p, K), + gf_quo(g, h, p, K)) + + +def gf_gcdex(f, g, p, K): + """ + Extended Euclidean Algorithm in ``GF(p)[x]``. + + Given polynomials ``f`` and ``g`` in ``GF(p)[x]``, computes polynomials + ``s``, ``t`` and ``h``, such that ``h = gcd(f, g)`` and ``s*f + t*g = h``. + The typical application of EEA is solving polynomial diophantine equations. + + Consider polynomials ``f = (x + 7) (x + 1)``, ``g = (x + 7) (x**2 + 1)`` + in ``GF(11)[x]``. Application of Extended Euclidean Algorithm gives:: + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_gcdex, gf_mul, gf_add + + >>> s, t, g = gf_gcdex(ZZ.map([1, 8, 7]), ZZ.map([1, 7, 1, 7]), 11, ZZ) + >>> s, t, g + ([5, 6], [6], [1, 7]) + + As result we obtained polynomials ``s = 5*x + 6`` and ``t = 6``, and + additionally ``gcd(f, g) = x + 7``. This is correct because:: + + >>> S = gf_mul(s, ZZ.map([1, 8, 7]), 11, ZZ) + >>> T = gf_mul(t, ZZ.map([1, 7, 1, 7]), 11, ZZ) + + >>> gf_add(S, T, 11, ZZ) == [1, 7] + True + + References + ========== + + .. [1] [Gathen99]_ + + """ + if not (f or g): + return [K.one], [], [] + + p0, r0 = gf_monic(f, p, K) + p1, r1 = gf_monic(g, p, K) + + if not f: + return [], [K.invert(p1, p)], r1 + if not g: + return [K.invert(p0, p)], [], r0 + + s0, s1 = [K.invert(p0, p)], [] + t0, t1 = [], [K.invert(p1, p)] + + while True: + Q, R = gf_div(r0, r1, p, K) + + if not R: + break + + (lc, r1), r0 = gf_monic(R, p, K), r1 + + inv = K.invert(lc, p) + + s = gf_sub_mul(s0, s1, Q, p, K) + t = gf_sub_mul(t0, t1, Q, p, K) + + s1, s0 = gf_mul_ground(s, inv, p, K), s1 + t1, t0 = gf_mul_ground(t, inv, p, K), t1 + + return s1, t1, r1 + + +def gf_monic(f, p, K): + """ + Compute LC and a monic polynomial in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_monic + + >>> gf_monic(ZZ.map([3, 2, 4]), 5, ZZ) + (3, [1, 4, 3]) + + """ + if not f: + return K.zero, [] + else: + lc = f[0] + + if K.is_one(lc): + return lc, list(f) + else: + return lc, gf_quo_ground(f, lc, p, K) + + +def gf_diff(f, p, K): + """ + Differentiate polynomial in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_diff + + >>> gf_diff([3, 2, 4], 5, ZZ) + [1, 2] + + """ + df = gf_degree(f) + + h, n = [K.zero]*df, df + + for coeff in f[:-1]: + coeff *= K(n) + coeff %= p + + if coeff: + h[df - n] = coeff + + n -= 1 + + return gf_strip(h) + + +def gf_eval(f, a, p, K): + """ + Evaluate ``f(a)`` in ``GF(p)`` using Horner scheme. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_eval + + >>> gf_eval([3, 2, 4], 2, 5, ZZ) + 0 + + """ + result = K.zero + + for c in f: + result *= a + result += c + result %= p + + return result + + +def gf_multi_eval(f, A, p, K): + """ + Evaluate ``f(a)`` for ``a`` in ``[a_1, ..., a_n]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_multi_eval + + >>> gf_multi_eval([3, 2, 4], [0, 1, 2, 3, 4], 5, ZZ) + [4, 4, 0, 2, 0] + + """ + return [ gf_eval(f, a, p, K) for a in A ] + + +def gf_compose(f, g, p, K): + """ + Compute polynomial composition ``f(g)`` in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_compose + + >>> gf_compose([3, 2, 4], [2, 2, 2], 5, ZZ) + [2, 4, 0, 3, 0] + + """ + if len(g) <= 1: + return gf_strip([gf_eval(f, gf_LC(g, K), p, K)]) + + if not f: + return [] + + h = [f[0]] + + for c in f[1:]: + h = gf_mul(h, g, p, K) + h = gf_add_ground(h, c, p, K) + + return h + + +def gf_compose_mod(g, h, f, p, K): + """ + Compute polynomial composition ``g(h)`` in ``GF(p)[x]/(f)``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_compose_mod + + >>> gf_compose_mod(ZZ.map([3, 2, 4]), ZZ.map([2, 2, 2]), ZZ.map([4, 3]), 5, ZZ) + [4] + + """ + if not g: + return [] + + comp = [g[0]] + + for a in g[1:]: + comp = gf_mul(comp, h, p, K) + comp = gf_add_ground(comp, a, p, K) + comp = gf_rem(comp, f, p, K) + + return comp + + +def gf_trace_map(a, b, c, n, f, p, K): + """ + Compute polynomial trace map in ``GF(p)[x]/(f)``. + + Given a polynomial ``f`` in ``GF(p)[x]``, polynomials ``a``, ``b``, + ``c`` in the quotient ring ``GF(p)[x]/(f)`` such that ``b = c**t + (mod f)`` for some positive power ``t`` of ``p``, and a positive + integer ``n``, returns a mapping:: + + a -> a**t**n, a + a**t + a**t**2 + ... + a**t**n (mod f) + + In factorization context, ``b = x**p mod f`` and ``c = x mod f``. + This way we can efficiently compute trace polynomials in equal + degree factorization routine, much faster than with other methods, + like iterated Frobenius algorithm, for large degrees. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_trace_map + + >>> gf_trace_map([1, 2], [4, 4], [1, 1], 4, [3, 2, 4], 5, ZZ) + ([1, 3], [1, 3]) + + References + ========== + + .. [1] [Gathen92]_ + + """ + u = gf_compose_mod(a, b, f, p, K) + v = b + + if n & 1: + U = gf_add(a, u, p, K) + V = b + else: + U = a + V = c + + n >>= 1 + + while n: + u = gf_add(u, gf_compose_mod(u, v, f, p, K), p, K) + v = gf_compose_mod(v, v, f, p, K) + + if n & 1: + U = gf_add(U, gf_compose_mod(u, V, f, p, K), p, K) + V = gf_compose_mod(v, V, f, p, K) + + n >>= 1 + + return gf_compose_mod(a, V, f, p, K), U + +def _gf_trace_map(f, n, g, b, p, K): + """ + utility for ``gf_edf_shoup`` + """ + f = gf_rem(f, g, p, K) + h = f + r = f + for i in range(1, n): + h = gf_frobenius_map(h, g, b, p, K) + r = gf_add(r, h, p, K) + r = gf_rem(r, g, p, K) + return r + + +def gf_random(n, p, K): + """ + Generate a random polynomial in ``GF(p)[x]`` of degree ``n``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_random + >>> gf_random(10, 5, ZZ) #doctest: +SKIP + [1, 2, 3, 2, 1, 1, 1, 2, 0, 4, 2] + + """ + pi = int(p) + return [K.one] + [ K(int(uniform(0, pi))) for i in range(0, n) ] + + +def gf_irreducible(n, p, K): + """ + Generate random irreducible polynomial of degree ``n`` in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_irreducible + >>> gf_irreducible(10, 5, ZZ) #doctest: +SKIP + [1, 4, 2, 2, 3, 2, 4, 1, 4, 0, 4] + + """ + while True: + f = gf_random(n, p, K) + if gf_irreducible_p(f, p, K): + return f + + +def gf_irred_p_ben_or(f, p, K): + """ + Ben-Or's polynomial irreducibility test over finite fields. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_irred_p_ben_or + + >>> gf_irred_p_ben_or(ZZ.map([1, 4, 2, 2, 3, 2, 4, 1, 4, 0, 4]), 5, ZZ) + True + >>> gf_irred_p_ben_or(ZZ.map([3, 2, 4]), 5, ZZ) + False + + """ + n = gf_degree(f) + + if n <= 1: + return True + + _, f = gf_monic(f, p, K) + if n < 5: + H = h = gf_pow_mod([K.one, K.zero], p, f, p, K) + + for i in range(0, n//2): + g = gf_sub(h, [K.one, K.zero], p, K) + + if gf_gcd(f, g, p, K) == [K.one]: + h = gf_compose_mod(h, H, f, p, K) + else: + return False + else: + b = gf_frobenius_monomial_base(f, p, K) + H = h = gf_frobenius_map([K.one, K.zero], f, b, p, K) + for i in range(0, n//2): + g = gf_sub(h, [K.one, K.zero], p, K) + if gf_gcd(f, g, p, K) == [K.one]: + h = gf_frobenius_map(h, f, b, p, K) + else: + return False + + return True + + +def gf_irred_p_rabin(f, p, K): + """ + Rabin's polynomial irreducibility test over finite fields. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_irred_p_rabin + + >>> gf_irred_p_rabin(ZZ.map([1, 4, 2, 2, 3, 2, 4, 1, 4, 0, 4]), 5, ZZ) + True + >>> gf_irred_p_rabin(ZZ.map([3, 2, 4]), 5, ZZ) + False + + """ + n = gf_degree(f) + + if n <= 1: + return True + + _, f = gf_monic(f, p, K) + + x = [K.one, K.zero] + + from sympy.ntheory import factorint + + indices = { n//d for d in factorint(n) } + + b = gf_frobenius_monomial_base(f, p, K) + h = b[1] + + for i in range(1, n): + if i in indices: + g = gf_sub(h, x, p, K) + + if gf_gcd(f, g, p, K) != [K.one]: + return False + + h = gf_frobenius_map(h, f, b, p, K) + + return h == x + +_irred_methods = { + 'ben-or': gf_irred_p_ben_or, + 'rabin': gf_irred_p_rabin, +} + + +def gf_irreducible_p(f, p, K): + """ + Test irreducibility of a polynomial ``f`` in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_irreducible_p + + >>> gf_irreducible_p(ZZ.map([1, 4, 2, 2, 3, 2, 4, 1, 4, 0, 4]), 5, ZZ) + True + >>> gf_irreducible_p(ZZ.map([3, 2, 4]), 5, ZZ) + False + + """ + method = query('GF_IRRED_METHOD') + + if method is not None: + irred = _irred_methods[method](f, p, K) + else: + irred = gf_irred_p_rabin(f, p, K) + + return irred + + +def gf_sqf_p(f, p, K): + """ + Return ``True`` if ``f`` is square-free in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_sqf_p + + >>> gf_sqf_p(ZZ.map([3, 2, 4]), 5, ZZ) + True + >>> gf_sqf_p(ZZ.map([2, 4, 4, 2, 2, 1, 4]), 5, ZZ) + False + + """ + _, f = gf_monic(f, p, K) + + if not f: + return True + else: + return gf_gcd(f, gf_diff(f, p, K), p, K) == [K.one] + + +def gf_sqf_part(f, p, K): + """ + Return square-free part of a ``GF(p)[x]`` polynomial. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_sqf_part + + >>> gf_sqf_part(ZZ.map([1, 1, 3, 0, 1, 0, 2, 2, 1]), 5, ZZ) + [1, 4, 3] + + """ + _, sqf = gf_sqf_list(f, p, K) + + g = [K.one] + + for f, _ in sqf: + g = gf_mul(g, f, p, K) + + return g + + +def gf_sqf_list(f, p, K, all=False): + """ + Return the square-free decomposition of a ``GF(p)[x]`` polynomial. + + Given a polynomial ``f`` in ``GF(p)[x]``, returns the leading coefficient + of ``f`` and a square-free decomposition ``f_1**e_1 f_2**e_2 ... f_k**e_k`` + such that all ``f_i`` are monic polynomials and ``(f_i, f_j)`` for ``i != j`` + are co-prime and ``e_1 ... e_k`` are given in increasing order. All trivial + terms (i.e. ``f_i = 1``) are not included in the output. + + Consider polynomial ``f = x**11 + 1`` over ``GF(11)[x]``:: + + >>> from sympy.polys.domains import ZZ + + >>> from sympy.polys.galoistools import ( + ... gf_from_dict, gf_diff, gf_sqf_list, gf_pow, + ... ) + ... # doctest: +NORMALIZE_WHITESPACE + + >>> f = gf_from_dict({11: ZZ(1), 0: ZZ(1)}, 11, ZZ) + + Note that ``f'(x) = 0``:: + + >>> gf_diff(f, 11, ZZ) + [] + + This phenomenon does not happen in characteristic zero. However we can + still compute square-free decomposition of ``f`` using ``gf_sqf()``:: + + >>> gf_sqf_list(f, 11, ZZ) + (1, [([1, 1], 11)]) + + We obtained factorization ``f = (x + 1)**11``. This is correct because:: + + >>> gf_pow([1, 1], 11, 11, ZZ) == f + True + + References + ========== + + .. [1] [Geddes92]_ + + """ + n, sqf, factors, r = 1, False, [], int(p) + + lc, f = gf_monic(f, p, K) + + if gf_degree(f) < 1: + return lc, [] + + while True: + F = gf_diff(f, p, K) + + if F != []: + g = gf_gcd(f, F, p, K) + h = gf_quo(f, g, p, K) + + i = 1 + + while h != [K.one]: + G = gf_gcd(g, h, p, K) + H = gf_quo(h, G, p, K) + + if gf_degree(H) > 0: + factors.append((H, i*n)) + + g, h, i = gf_quo(g, G, p, K), G, i + 1 + + if g == [K.one]: + sqf = True + else: + f = g + + if not sqf: + d = gf_degree(f) // r + + for i in range(0, d + 1): + f[i] = f[i*r] + + f, n = f[:d + 1], n*r + else: + break + + if all: + raise ValueError("'all=True' is not supported yet") + + return lc, factors + + +def gf_Qmatrix(f, p, K): + """ + Calculate Berlekamp's ``Q`` matrix. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_Qmatrix + + >>> gf_Qmatrix([3, 2, 4], 5, ZZ) + [[1, 0], + [3, 4]] + + >>> gf_Qmatrix([1, 0, 0, 0, 1], 5, ZZ) + [[1, 0, 0, 0], + [0, 4, 0, 0], + [0, 0, 1, 0], + [0, 0, 0, 4]] + + """ + n, r = gf_degree(f), int(p) + + q = [K.one] + [K.zero]*(n - 1) + Q = [list(q)] + [[]]*(n - 1) + + for i in range(1, (n - 1)*r + 1): + qq, c = [(-q[-1]*f[-1]) % p], q[-1] + + for j in range(1, n): + qq.append((q[j - 1] - c*f[-j - 1]) % p) + + if not (i % r): + Q[i//r] = list(qq) + + q = qq + + return Q + + +def gf_Qbasis(Q, p, K): + """ + Compute a basis of the kernel of ``Q``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_Qmatrix, gf_Qbasis + + >>> gf_Qbasis(gf_Qmatrix([1, 0, 0, 0, 1], 5, ZZ), 5, ZZ) + [[1, 0, 0, 0], [0, 0, 1, 0]] + + >>> gf_Qbasis(gf_Qmatrix([3, 2, 4], 5, ZZ), 5, ZZ) + [[1, 0]] + + """ + Q, n = [ list(q) for q in Q ], len(Q) + + for k in range(0, n): + Q[k][k] = (Q[k][k] - K.one) % p + + for k in range(0, n): + for i in range(k, n): + if Q[k][i]: + break + else: + continue + + inv = K.invert(Q[k][i], p) + + for j in range(0, n): + Q[j][i] = (Q[j][i]*inv) % p + + for j in range(0, n): + t = Q[j][k] + Q[j][k] = Q[j][i] + Q[j][i] = t + + for i in range(0, n): + if i != k: + q = Q[k][i] + + for j in range(0, n): + Q[j][i] = (Q[j][i] - Q[j][k]*q) % p + + for i in range(0, n): + for j in range(0, n): + if i == j: + Q[i][j] = (K.one - Q[i][j]) % p + else: + Q[i][j] = (-Q[i][j]) % p + + basis = [] + + for q in Q: + if any(q): + basis.append(q) + + return basis + + +def gf_berlekamp(f, p, K): + """ + Factor a square-free ``f`` in ``GF(p)[x]`` for small ``p``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_berlekamp + + >>> gf_berlekamp([1, 0, 0, 0, 1], 5, ZZ) + [[1, 0, 2], [1, 0, 3]] + + """ + Q = gf_Qmatrix(f, p, K) + V = gf_Qbasis(Q, p, K) + + for i, v in enumerate(V): + V[i] = gf_strip(list(reversed(v))) + + factors = [f] + + for k in range(1, len(V)): + for f in list(factors): + s = K.zero + + while s < p: + g = gf_sub_ground(V[k], s, p, K) + h = gf_gcd(f, g, p, K) + + if h != [K.one] and h != f: + factors.remove(f) + + f = gf_quo(f, h, p, K) + factors.extend([f, h]) + + if len(factors) == len(V): + return _sort_factors(factors, multiple=False) + + s += K.one + + return _sort_factors(factors, multiple=False) + + +def gf_ddf_zassenhaus(f, p, K): + """ + Cantor-Zassenhaus: Deterministic Distinct Degree Factorization + + Given a monic square-free polynomial ``f`` in ``GF(p)[x]``, computes + partial distinct degree factorization ``f_1 ... f_d`` of ``f`` where + ``deg(f_i) != deg(f_j)`` for ``i != j``. The result is returned as a + list of pairs ``(f_i, e_i)`` where ``deg(f_i) > 0`` and ``e_i > 0`` + is an argument to the equal degree factorization routine. + + Consider the polynomial ``x**15 - 1`` in ``GF(11)[x]``:: + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_from_dict + + >>> f = gf_from_dict({15: ZZ(1), 0: ZZ(-1)}, 11, ZZ) + + Distinct degree factorization gives:: + + >>> from sympy.polys.galoistools import gf_ddf_zassenhaus + + >>> gf_ddf_zassenhaus(f, 11, ZZ) + [([1, 0, 0, 0, 0, 10], 1), ([1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], 2)] + + which means ``x**15 - 1 = (x**5 - 1) (x**10 + x**5 + 1)``. To obtain + factorization into irreducibles, use equal degree factorization + procedure (EDF) with each of the factors. + + References + ========== + + .. [1] [Gathen99]_ + .. [2] [Geddes92]_ + + """ + i, g, factors = 1, [K.one, K.zero], [] + + b = gf_frobenius_monomial_base(f, p, K) + while 2*i <= gf_degree(f): + g = gf_frobenius_map(g, f, b, p, K) + h = gf_gcd(f, gf_sub(g, [K.one, K.zero], p, K), p, K) + + if h != [K.one]: + factors.append((h, i)) + + f = gf_quo(f, h, p, K) + g = gf_rem(g, f, p, K) + b = gf_frobenius_monomial_base(f, p, K) + + i += 1 + + if f != [K.one]: + return factors + [(f, gf_degree(f))] + else: + return factors + + +def gf_edf_zassenhaus(f, n, p, K): + """ + Cantor-Zassenhaus: Probabilistic Equal Degree Factorization + + Given a monic square-free polynomial ``f`` in ``GF(p)[x]`` and + an integer ``n``, such that ``n`` divides ``deg(f)``, returns all + irreducible factors ``f_1,...,f_d`` of ``f``, each of degree ``n``. + EDF procedure gives complete factorization over Galois fields. + + Consider the square-free polynomial ``f = x**3 + x**2 + x + 1`` in + ``GF(5)[x]``. Let's compute its irreducible factors of degree one:: + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_edf_zassenhaus + + >>> gf_edf_zassenhaus([1,1,1,1], 1, 5, ZZ) + [[1, 1], [1, 2], [1, 3]] + + Notes + ===== + + The case p == 2 is handled by Cohen's Algorithm 3.4.8. The case p odd is + as in Geddes Algorithm 8.9 (or Cohen's Algorithm 3.4.6). + + References + ========== + + .. [1] [Gathen99]_ + .. [2] [Geddes92]_ Algorithm 8.9 + .. [3] [Cohen93]_ Algorithm 3.4.8 + + """ + factors = [f] + + if gf_degree(f) <= n: + return factors + + N = gf_degree(f) // n + if p != 2: + b = gf_frobenius_monomial_base(f, p, K) + + t = [K.one, K.zero] + while len(factors) < N: + if p == 2: + h = r = t + + for i in range(n - 1): + r = gf_pow_mod(r, 2, f, p, K) + h = gf_add(h, r, p, K) + + g = gf_gcd(f, h, p, K) + t += [K.zero, K.zero] + else: + r = gf_random(2 * n - 1, p, K) + h = _gf_pow_pnm1d2(r, n, f, b, p, K) + g = gf_gcd(f, gf_sub_ground(h, K.one, p, K), p, K) + + if g != [K.one] and g != f: + factors = gf_edf_zassenhaus(g, n, p, K) \ + + gf_edf_zassenhaus(gf_quo(f, g, p, K), n, p, K) + + return _sort_factors(factors, multiple=False) + + +def gf_ddf_shoup(f, p, K): + """ + Kaltofen-Shoup: Deterministic Distinct Degree Factorization + + Given a monic square-free polynomial ``f`` in ``GF(p)[x]``, computes + partial distinct degree factorization ``f_1,...,f_d`` of ``f`` where + ``deg(f_i) != deg(f_j)`` for ``i != j``. The result is returned as a + list of pairs ``(f_i, e_i)`` where ``deg(f_i) > 0`` and ``e_i > 0`` + is an argument to the equal degree factorization routine. + + This algorithm is an improved version of Zassenhaus algorithm for + large ``deg(f)`` and modulus ``p`` (especially for ``deg(f) ~ lg(p)``). + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_ddf_shoup, gf_from_dict + + >>> f = gf_from_dict({6: ZZ(1), 5: ZZ(-1), 4: ZZ(1), 3: ZZ(1), 1: ZZ(-1)}, 3, ZZ) + + >>> gf_ddf_shoup(f, 3, ZZ) + [([1, 1, 0], 1), ([1, 1, 0, 1, 2], 2)] + + References + ========== + + .. [1] [Kaltofen98]_ + .. [2] [Shoup95]_ + .. [3] [Gathen92]_ + + """ + n = gf_degree(f) + k = int(_ceil(_sqrt(n//2))) + b = gf_frobenius_monomial_base(f, p, K) + h = gf_frobenius_map([K.one, K.zero], f, b, p, K) + # U[i] = x**(p**i) + U = [[K.one, K.zero], h] + [K.zero]*(k - 1) + + for i in range(2, k + 1): + U[i] = gf_frobenius_map(U[i-1], f, b, p, K) + + h, U = U[k], U[:k] + # V[i] = x**(p**(k*(i+1))) + V = [h] + [K.zero]*(k - 1) + + for i in range(1, k): + V[i] = gf_compose_mod(V[i - 1], h, f, p, K) + + factors = [] + + for i, v in enumerate(V): + h, j = [K.one], k - 1 + + for u in U: + g = gf_sub(v, u, p, K) + h = gf_mul(h, g, p, K) + h = gf_rem(h, f, p, K) + + g = gf_gcd(f, h, p, K) + f = gf_quo(f, g, p, K) + + for u in reversed(U): + h = gf_sub(v, u, p, K) + F = gf_gcd(g, h, p, K) + + if F != [K.one]: + factors.append((F, k*(i + 1) - j)) + + g, j = gf_quo(g, F, p, K), j - 1 + + if f != [K.one]: + factors.append((f, gf_degree(f))) + + return factors + +def gf_edf_shoup(f, n, p, K): + """ + Gathen-Shoup: Probabilistic Equal Degree Factorization + + Given a monic square-free polynomial ``f`` in ``GF(p)[x]`` and integer + ``n`` such that ``n`` divides ``deg(f)``, returns all irreducible factors + ``f_1,...,f_d`` of ``f``, each of degree ``n``. This is a complete + factorization over Galois fields. + + This algorithm is an improved version of Zassenhaus algorithm for + large ``deg(f)`` and modulus ``p`` (especially for ``deg(f) ~ lg(p)``). + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_edf_shoup + + >>> gf_edf_shoup(ZZ.map([1, 2837, 2277]), 1, 2917, ZZ) + [[1, 852], [1, 1985]] + + References + ========== + + .. [1] [Shoup91]_ + .. [2] [Gathen92]_ + + """ + N, q = gf_degree(f), int(p) + + if not N: + return [] + if N <= n: + return [f] + + factors, x = [f], [K.one, K.zero] + + r = gf_random(N - 1, p, K) + + if p == 2: + h = gf_pow_mod(x, q, f, p, K) + H = gf_trace_map(r, h, x, n - 1, f, p, K)[1] + h1 = gf_gcd(f, H, p, K) + h2 = gf_quo(f, h1, p, K) + + factors = gf_edf_shoup(h1, n, p, K) \ + + gf_edf_shoup(h2, n, p, K) + else: + b = gf_frobenius_monomial_base(f, p, K) + H = _gf_trace_map(r, n, f, b, p, K) + h = gf_pow_mod(H, (q - 1)//2, f, p, K) + + h1 = gf_gcd(f, h, p, K) + h2 = gf_gcd(f, gf_sub_ground(h, K.one, p, K), p, K) + h3 = gf_quo(f, gf_mul(h1, h2, p, K), p, K) + + factors = gf_edf_shoup(h1, n, p, K) \ + + gf_edf_shoup(h2, n, p, K) \ + + gf_edf_shoup(h3, n, p, K) + + return _sort_factors(factors, multiple=False) + + +def gf_zassenhaus(f, p, K): + """ + Factor a square-free ``f`` in ``GF(p)[x]`` for medium ``p``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_zassenhaus + + >>> gf_zassenhaus(ZZ.map([1, 4, 3]), 5, ZZ) + [[1, 1], [1, 3]] + + """ + factors = [] + + for factor, n in gf_ddf_zassenhaus(f, p, K): + factors += gf_edf_zassenhaus(factor, n, p, K) + + return _sort_factors(factors, multiple=False) + + +def gf_shoup(f, p, K): + """ + Factor a square-free ``f`` in ``GF(p)[x]`` for large ``p``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_shoup + + >>> gf_shoup(ZZ.map([1, 4, 3]), 5, ZZ) + [[1, 1], [1, 3]] + + """ + factors = [] + + for factor, n in gf_ddf_shoup(f, p, K): + factors += gf_edf_shoup(factor, n, p, K) + + return _sort_factors(factors, multiple=False) + +_factor_methods = { + 'berlekamp': gf_berlekamp, # ``p`` : small + 'zassenhaus': gf_zassenhaus, # ``p`` : medium + 'shoup': gf_shoup, # ``p`` : large +} + + +def gf_factor_sqf(f, p, K, method=None): + """ + Factor a square-free polynomial ``f`` in ``GF(p)[x]``. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_factor_sqf + + >>> gf_factor_sqf(ZZ.map([3, 2, 4]), 5, ZZ) + (3, [[1, 1], [1, 3]]) + + """ + lc, f = gf_monic(f, p, K) + + if gf_degree(f) < 1: + return lc, [] + + method = method or query('GF_FACTOR_METHOD') + + if method is not None: + factors = _factor_methods[method](f, p, K) + else: + factors = gf_zassenhaus(f, p, K) + + return lc, factors + + +def gf_factor(f, p, K): + """ + Factor (non square-free) polynomials in ``GF(p)[x]``. + + Given a possibly non square-free polynomial ``f`` in ``GF(p)[x]``, + returns its complete factorization into irreducibles:: + + f_1(x)**e_1 f_2(x)**e_2 ... f_d(x)**e_d + + where each ``f_i`` is a monic polynomial and ``gcd(f_i, f_j) == 1``, + for ``i != j``. The result is given as a tuple consisting of the + leading coefficient of ``f`` and a list of factors of ``f`` with + their multiplicities. + + The algorithm proceeds by first computing square-free decomposition + of ``f`` and then iteratively factoring each of square-free factors. + + Consider a non square-free polynomial ``f = (7*x + 1) (x + 2)**2`` in + ``GF(11)[x]``. We obtain its factorization into irreducibles as follows:: + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.galoistools import gf_factor + + >>> gf_factor(ZZ.map([5, 2, 7, 2]), 11, ZZ) + (5, [([1, 2], 1), ([1, 8], 2)]) + + We arrived with factorization ``f = 5 (x + 2) (x + 8)**2``. We did not + recover the exact form of the input polynomial because we requested to + get monic factors of ``f`` and its leading coefficient separately. + + Square-free factors of ``f`` can be factored into irreducibles over + ``GF(p)`` using three very different methods: + + Berlekamp + efficient for very small values of ``p`` (usually ``p < 25``) + Cantor-Zassenhaus + efficient on average input and with "typical" ``p`` + Shoup-Kaltofen-Gathen + efficient with very large inputs and modulus + + If you want to use a specific factorization method, instead of the default + one, set ``GF_FACTOR_METHOD`` with one of ``berlekamp``, ``zassenhaus`` or + ``shoup`` values. + + References + ========== + + .. [1] [Gathen99]_ + + """ + lc, f = gf_monic(f, p, K) + + if gf_degree(f) < 1: + return lc, [] + + factors = [] + + for g, n in gf_sqf_list(f, p, K)[1]: + for h in gf_factor_sqf(g, p, K)[1]: + factors.append((h, n)) + + return lc, _sort_factors(factors) + + +def gf_value(f, a): + """ + Value of polynomial 'f' at 'a' in field R. + + Examples + ======== + + >>> from sympy.polys.galoistools import gf_value + + >>> gf_value([1, 7, 2, 4], 11) + 2204 + + """ + result = 0 + for c in f: + result *= a + result += c + return result + + +def linear_congruence(a, b, m): + """ + Returns the values of x satisfying a*x congruent b mod(m) + + Here m is positive integer and a, b are natural numbers. + This function returns only those values of x which are distinct mod(m). + + Examples + ======== + + >>> from sympy.polys.galoistools import linear_congruence + + >>> linear_congruence(3, 12, 15) + [4, 9, 14] + + There are 3 solutions distinct mod(15) since gcd(a, m) = gcd(3, 15) = 3. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Linear_congruence_theorem + + """ + from sympy.polys.polytools import gcdex + if a % m == 0: + if b % m == 0: + return list(range(m)) + else: + return [] + r, _, g = gcdex(a, m) + if b % g != 0: + return [] + return [(r * b // g + t * m // g) % m for t in range(g)] + + +def _raise_mod_power(x, s, p, f): + """ + Used in gf_csolve to generate solutions of f(x) cong 0 mod(p**(s + 1)) + from the solutions of f(x) cong 0 mod(p**s). + + Examples + ======== + + >>> from sympy.polys.galoistools import _raise_mod_power + >>> from sympy.polys.galoistools import csolve_prime + + These is the solutions of f(x) = x**2 + x + 7 cong 0 mod(3) + + >>> f = [1, 1, 7] + >>> csolve_prime(f, 3) + [1] + >>> [ i for i in range(3) if not (i**2 + i + 7) % 3] + [1] + + The solutions of f(x) cong 0 mod(9) are constructed from the + values returned from _raise_mod_power: + + >>> x, s, p = 1, 1, 3 + >>> V = _raise_mod_power(x, s, p, f) + >>> [x + v * p**s for v in V] + [1, 4, 7] + + And these are confirmed with the following: + + >>> [ i for i in range(3**2) if not (i**2 + i + 7) % 3**2] + [1, 4, 7] + + """ + from sympy.polys.domains import ZZ + f_f = gf_diff(f, p, ZZ) + alpha = gf_value(f_f, x) + beta = - gf_value(f, x) // p**s + return linear_congruence(alpha, beta, p) + + +def _csolve_prime_las_vegas(f, p, seed=None): + r""" Solutions of `f(x) \equiv 0 \pmod{p}`, `f(0) \not\equiv 0 \pmod{p}`. + + Explanation + =========== + + This algorithm is classified as the Las Vegas method. + That is, it always returns the correct answer and solves the problem + fast in many cases, but if it is unlucky, it does not answer forever. + + Suppose the polynomial f is not a zero polynomial. Assume further + that it is of degree at most p-1 and `f(0)\not\equiv 0 \pmod{p}`. + These assumptions are not an essential part of the algorithm, + only that it is more convenient for the function calling this + function to resolve them. + + Note that `x^{p-1} - 1 \equiv \prod_{a=1}^{p-1}(x - a) \pmod{p}`. + Thus, the greatest common divisor with f is `\prod_{s \in S}(x - s)`, + with S being the set of solutions to f. Furthermore, + when a is randomly determined, `(x+a)^{(p-1)/2}-1` is + a polynomial with (p-1)/2 randomly chosen solutions. + The greatest common divisor of f may be a nontrivial factor of f. + + When p is large and the degree of f is small, + it is faster than naive solution methods. + + Parameters + ========== + + f : polynomial + p : prime number + + Returns + ======= + + list[int] + a list of solutions, sorted in ascending order + by integers in the range [1, p). The same value + does not exist in the list even if there is + a multiple solution. If no solution exists, returns []. + + Examples + ======== + + >>> from sympy.polys.galoistools import _csolve_prime_las_vegas + >>> _csolve_prime_las_vegas([1, 4, 3], 7) # x^2 + 4x + 3 = 0 (mod 7) + [4, 6] + >>> _csolve_prime_las_vegas([5, 7, 1, 9], 11) # 5x^3 + 7x^2 + x + 9 = 0 (mod 11) + [1, 5, 8] + + References + ========== + + .. [1] R. Crandall and C. Pomerance "Prime Numbers", 2nd Ed., Algorithm 2.3.10 + + """ + from sympy.polys.domains import ZZ + from sympy.ntheory import sqrt_mod + randint = _randint(seed) + root = set() + g = gf_pow_mod([1, 0], p - 1, f, p, ZZ) + g = gf_sub_ground(g, 1, p, ZZ) + # We want to calculate gcd(x**(p-1) - 1, f(x)) + factors = [gf_gcd(f, g, p, ZZ)] + while factors: + f = factors.pop() + # If the degree is small, solve directly + if len(f) <= 1: + continue + if len(f) == 2: + root.add(-invert(f[0], p) * f[1] % p) + continue + if len(f) == 3: + inv = invert(f[0], p) + b = f[1] * inv % p + b = (b + p * (b % 2)) // 2 + root.update((r - b) % p for r in + sqrt_mod(b**2 - f[2] * inv, p, all_roots=True)) + continue + while True: + # Determine `a` randomly and + # compute gcd((x+a)**((p-1)//2)-1, f(x)) + a = randint(0, p - 1) + g = gf_pow_mod([1, a], (p - 1) // 2, f, p, ZZ) + g = gf_sub_ground(g, 1, p, ZZ) + g = gf_gcd(f, g, p, ZZ) + if 1 < len(g) < len(f): + factors.append(g) + factors.append(gf_div(f, g, p, ZZ)[0]) + break + return sorted(root) + + +def csolve_prime(f, p, e=1): + r""" Solutions of `f(x) \equiv 0 \pmod{p^e}`. + + Parameters + ========== + + f : polynomial + p : prime number + e : positive integer + + Returns + ======= + + list[int] + a list of solutions, sorted in ascending order + by integers in the range [1, p**e). The same value + does not exist in the list even if there is + a multiple solution. If no solution exists, returns []. + + Examples + ======== + + >>> from sympy.polys.galoistools import csolve_prime + >>> csolve_prime([1, 1, 7], 3, 1) + [1] + >>> csolve_prime([1, 1, 7], 3, 2) + [1, 4, 7] + + Solutions [7, 4, 1] (mod 3**2) are generated by ``_raise_mod_power()`` + from solution [1] (mod 3). + """ + from sympy.polys.domains import ZZ + g = [MPZ(int(c)) for c in f] + # Convert to polynomial of degree at most p-1 + for i in range(len(g) - p): + g[i + p - 1] += g[i] + g[i] = 0 + g = gf_trunc(g, p) + # Checks whether g(x) is divisible by x + k = 0 + while k < len(g) and g[len(g) - k - 1] == 0: + k += 1 + if k: + g = g[:-k] + root_zero = [0] + else: + root_zero = [] + if g == []: + X1 = list(range(p)) + elif len(g)**2 < p: + # The conditions under which `_csolve_prime_las_vegas` is faster than + # a naive solution are worth considering. + X1 = root_zero + _csolve_prime_las_vegas(g, p) + else: + X1 = root_zero + [i for i in range(p) if gf_eval(g, i, p, ZZ) == 0] + if e == 1: + return X1 + X = [] + S = list(zip(X1, [1]*len(X1))) + while S: + x, s = S.pop() + if s == e: + X.append(x) + else: + s1 = s + 1 + ps = p**s + S.extend([(x + v*ps, s1) for v in _raise_mod_power(x, s, p, f)]) + return sorted(X) + + +def gf_csolve(f, n): + """ + To solve f(x) congruent 0 mod(n). + + n is divided into canonical factors and f(x) cong 0 mod(p**e) will be + solved for each factor. Applying the Chinese Remainder Theorem to the + results returns the final answers. + + Examples + ======== + + Solve [1, 1, 7] congruent 0 mod(189): + + >>> from sympy.polys.galoistools import gf_csolve + >>> gf_csolve([1, 1, 7], 189) + [13, 49, 76, 112, 139, 175] + + See Also + ======== + + sympy.ntheory.residue_ntheory.polynomial_congruence : a higher level solving routine + + References + ========== + + .. [1] 'An introduction to the Theory of Numbers' 5th Edition by Ivan Niven, + Zuckerman and Montgomery. + + """ + from sympy.polys.domains import ZZ + from sympy.ntheory import factorint + P = factorint(n) + X = [csolve_prime(f, p, e) for p, e in P.items()] + pools = list(map(tuple, X)) + perms = [[]] + for pool in pools: + perms = [x + [y] for x in perms for y in pool] + dist_factors = [pow(p, e) for p, e in P.items()] + return sorted([gf_crt(per, dist_factors, ZZ) for per in perms]) diff --git a/phivenv/Lib/site-packages/sympy/polys/groebnertools.py b/phivenv/Lib/site-packages/sympy/polys/groebnertools.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5c2f228ab4f4182e4c8fff68d974aa25c9d531 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/groebnertools.py @@ -0,0 +1,862 @@ +"""Groebner bases algorithms. """ + + +from sympy.core.symbol import Dummy +from sympy.polys.monomials import monomial_mul, monomial_lcm, monomial_divides, term_div +from sympy.polys.orderings import lex +from sympy.polys.polyerrors import DomainError +from sympy.polys.polyconfig import query + +def groebner(seq, ring, method=None): + """ + Computes Groebner basis for a set of polynomials in `K[X]`. + + Wrapper around the (default) improved Buchberger and the other algorithms + for computing Groebner bases. The choice of algorithm can be changed via + ``method`` argument or :func:`sympy.polys.polyconfig.setup`, where + ``method`` can be either ``buchberger`` or ``f5b``. + + """ + if method is None: + method = query('groebner') + + _groebner_methods = { + 'buchberger': _buchberger, + 'f5b': _f5b, + } + + try: + _groebner = _groebner_methods[method] + except KeyError: + raise ValueError("'%s' is not a valid Groebner bases algorithm (valid are 'buchberger' and 'f5b')" % method) + + domain, orig = ring.domain, None + + if not domain.is_Field or not domain.has_assoc_Field: + try: + orig, ring = ring, ring.clone(domain=domain.get_field()) + except DomainError: + raise DomainError("Cannot compute a Groebner basis over %s" % domain) + else: + seq = [ s.set_ring(ring) for s in seq ] + + G = _groebner(seq, ring) + + if orig is not None: + G = [ g.clear_denoms()[1].set_ring(orig) for g in G ] + + return G + +def _buchberger(f, ring): + """ + Computes Groebner basis for a set of polynomials in `K[X]`. + + Given a set of multivariate polynomials `F`, finds another + set `G`, such that Ideal `F = Ideal G` and `G` is a reduced + Groebner basis. + + The resulting basis is unique and has monic generators if the + ground domains is a field. Otherwise the result is non-unique + but Groebner bases over e.g. integers can be computed (if the + input polynomials are monic). + + Groebner bases can be used to choose specific generators for a + polynomial ideal. Because these bases are unique you can check + for ideal equality by comparing the Groebner bases. To see if + one polynomial lies in an ideal, divide by the elements in the + base and see if the remainder vanishes. + + They can also be used to solve systems of polynomial equations + as, by choosing lexicographic ordering, you can eliminate one + variable at a time, provided that the ideal is zero-dimensional + (finite number of solutions). + + Notes + ===== + + Algorithm used: an improved version of Buchberger's algorithm + as presented in T. Becker, V. Weispfenning, Groebner Bases: A + Computational Approach to Commutative Algebra, Springer, 1993, + page 232. + + References + ========== + + .. [1] [Bose03]_ + .. [2] [Giovini91]_ + .. [3] [Ajwa95]_ + .. [4] [Cox97]_ + + """ + order = ring.order + + monomial_mul = ring.monomial_mul + monomial_div = ring.monomial_div + monomial_lcm = ring.monomial_lcm + + def select(P): + # normal selection strategy + # select the pair with minimum LCM(LM(f), LM(g)) + pr = min(P, key=lambda pair: order(monomial_lcm(f[pair[0]].LM, f[pair[1]].LM))) + return pr + + def normal(g, J): + h = g.rem([ f[j] for j in J ]) + + if not h: + return None + else: + h = h.monic() + + if h not in I: + I[h] = len(f) + f.append(h) + + return h.LM, I[h] + + def update(G, B, ih): + # update G using the set of critical pairs B and h + # [BW] page 230 + h = f[ih] + mh = h.LM + + # filter new pairs (h, g), g in G + C = G.copy() + D = set() + + while C: + # select a pair (h, g) by popping an element from C + ig = C.pop() + g = f[ig] + mg = g.LM + LCMhg = monomial_lcm(mh, mg) + + def lcm_divides(ip): + # LCM(LM(h), LM(p)) divides LCM(LM(h), LM(g)) + m = monomial_lcm(mh, f[ip].LM) + return monomial_div(LCMhg, m) + + # HT(h) and HT(g) disjoint: mh*mg == LCMhg + if monomial_mul(mh, mg) == LCMhg or ( + not any(lcm_divides(ipx) for ipx in C) and + not any(lcm_divides(pr[1]) for pr in D)): + D.add((ih, ig)) + + E = set() + + while D: + # select h, g from D (h the same as above) + ih, ig = D.pop() + mg = f[ig].LM + LCMhg = monomial_lcm(mh, mg) + + if not monomial_mul(mh, mg) == LCMhg: + E.add((ih, ig)) + + # filter old pairs + B_new = set() + + while B: + # select g1, g2 from B (-> CP) + ig1, ig2 = B.pop() + mg1 = f[ig1].LM + mg2 = f[ig2].LM + LCM12 = monomial_lcm(mg1, mg2) + + # if HT(h) does not divide lcm(HT(g1), HT(g2)) + if not monomial_div(LCM12, mh) or \ + monomial_lcm(mg1, mh) == LCM12 or \ + monomial_lcm(mg2, mh) == LCM12: + B_new.add((ig1, ig2)) + + B_new |= E + + # filter polynomials + G_new = set() + + while G: + ig = G.pop() + mg = f[ig].LM + + if not monomial_div(mg, mh): + G_new.add(ig) + + G_new.add(ih) + + return G_new, B_new + # end of update ################################ + + if not f: + return [] + + # replace f with a reduced list of initial polynomials; see [BW] page 203 + f1 = f[:] + + while True: + f = f1[:] + f1 = [] + + for i in range(len(f)): + p = f[i] + r = p.rem(f[:i]) + + if r: + f1.append(r.monic()) + + if f == f1: + break + + I = {} # ip = I[p]; p = f[ip] + F = set() # set of indices of polynomials + G = set() # set of indices of intermediate would-be Groebner basis + CP = set() # set of pairs of indices of critical pairs + + for i, h in enumerate(f): + I[h] = i + F.add(i) + + ##################################### + # algorithm GROEBNERNEWS2 in [BW] page 232 + + while F: + # select p with minimum monomial according to the monomial ordering + h = min([f[x] for x in F], key=lambda f: order(f.LM)) + ih = I[h] + F.remove(ih) + G, CP = update(G, CP, ih) + + # count the number of critical pairs which reduce to zero + reductions_to_zero = 0 + + while CP: + ig1, ig2 = select(CP) + CP.remove((ig1, ig2)) + + h = spoly(f[ig1], f[ig2], ring) + # ordering divisors is on average more efficient [Cox] page 111 + G1 = sorted(G, key=lambda g: order(f[g].LM)) + ht = normal(h, G1) + + if ht: + G, CP = update(G, CP, ht[1]) + else: + reductions_to_zero += 1 + + ###################################### + # now G is a Groebner basis; reduce it + Gr = set() + + for ig in G: + ht = normal(f[ig], G - {ig}) + + if ht: + Gr.add(ht[1]) + + Gr = [f[ig] for ig in Gr] + + # order according to the monomial ordering + Gr = sorted(Gr, key=lambda f: order(f.LM), reverse=True) + + return Gr + +def spoly(p1, p2, ring): + """ + Compute LCM(LM(p1), LM(p2))/LM(p1)*p1 - LCM(LM(p1), LM(p2))/LM(p2)*p2 + This is the S-poly provided p1 and p2 are monic + """ + LM1 = p1.LM + LM2 = p2.LM + LCM12 = ring.monomial_lcm(LM1, LM2) + m1 = ring.monomial_div(LCM12, LM1) + m2 = ring.monomial_div(LCM12, LM2) + s1 = p1.mul_monom(m1) + s2 = p2.mul_monom(m2) + s = s1 - s2 + return s + +# F5B + +# convenience functions + + +def Sign(f): + return f[0] + + +def Polyn(f): + return f[1] + + +def Num(f): + return f[2] + + +def sig(monomial, index): + return (monomial, index) + + +def lbp(signature, polynomial, number): + return (signature, polynomial, number) + +# signature functions + + +def sig_cmp(u, v, order): + """ + Compare two signatures by extending the term order to K[X]^n. + + u < v iff + - the index of v is greater than the index of u + or + - the index of v is equal to the index of u and u[0] < v[0] w.r.t. order + + u > v otherwise + """ + if u[1] > v[1]: + return -1 + if u[1] == v[1]: + #if u[0] == v[0]: + # return 0 + if order(u[0]) < order(v[0]): + return -1 + return 1 + + +def sig_key(s, order): + """ + Key for comparing two signatures. + + s = (m, k), t = (n, l) + + s < t iff [k > l] or [k == l and m < n] + s > t otherwise + """ + return (-s[1], order(s[0])) + + +def sig_mult(s, m): + """ + Multiply a signature by a monomial. + + The product of a signature (m, i) and a monomial n is defined as + (m * t, i). + """ + return sig(monomial_mul(s[0], m), s[1]) + +# labeled polynomial functions + + +def lbp_sub(f, g): + """ + Subtract labeled polynomial g from f. + + The signature and number of the difference of f and g are signature + and number of the maximum of f and g, w.r.t. lbp_cmp. + """ + if sig_cmp(Sign(f), Sign(g), Polyn(f).ring.order) < 0: + max_poly = g + else: + max_poly = f + + ret = Polyn(f) - Polyn(g) + + return lbp(Sign(max_poly), ret, Num(max_poly)) + + +def lbp_mul_term(f, cx): + """ + Multiply a labeled polynomial with a term. + + The product of a labeled polynomial (s, p, k) by a monomial is + defined as (m * s, m * p, k). + """ + return lbp(sig_mult(Sign(f), cx[0]), Polyn(f).mul_term(cx), Num(f)) + + +def lbp_cmp(f, g): + """ + Compare two labeled polynomials. + + f < g iff + - Sign(f) < Sign(g) + or + - Sign(f) == Sign(g) and Num(f) > Num(g) + + f > g otherwise + """ + if sig_cmp(Sign(f), Sign(g), Polyn(f).ring.order) == -1: + return -1 + if Sign(f) == Sign(g): + if Num(f) > Num(g): + return -1 + #if Num(f) == Num(g): + # return 0 + return 1 + + +def lbp_key(f): + """ + Key for comparing two labeled polynomials. + """ + return (sig_key(Sign(f), Polyn(f).ring.order), -Num(f)) + +# algorithm and helper functions + + +def critical_pair(f, g, ring): + """ + Compute the critical pair corresponding to two labeled polynomials. + + A critical pair is a tuple (um, f, vm, g), where um and vm are + terms such that um * f - vm * g is the S-polynomial of f and g (so, + wlog assume um * f > vm * g). + For performance sake, a critical pair is represented as a tuple + (Sign(um * f), um, f, Sign(vm * g), vm, g), since um * f creates + a new, relatively expensive object in memory, whereas Sign(um * + f) and um are lightweight and f (in the tuple) is a reference to + an already existing object in memory. + """ + domain = ring.domain + + ltf = Polyn(f).LT + ltg = Polyn(g).LT + lt = (monomial_lcm(ltf[0], ltg[0]), domain.one) + + um = term_div(lt, ltf, domain) + vm = term_div(lt, ltg, domain) + + # The full information is not needed (now), so only the product + # with the leading term is considered: + fr = lbp_mul_term(lbp(Sign(f), Polyn(f).leading_term(), Num(f)), um) + gr = lbp_mul_term(lbp(Sign(g), Polyn(g).leading_term(), Num(g)), vm) + + # return in proper order, such that the S-polynomial is just + # u_first * f_first - u_second * f_second: + if lbp_cmp(fr, gr) == -1: + return (Sign(gr), vm, g, Sign(fr), um, f) + else: + return (Sign(fr), um, f, Sign(gr), vm, g) + + +def cp_cmp(c, d): + """ + Compare two critical pairs c and d. + + c < d iff + - lbp(c[0], _, Num(c[2]) < lbp(d[0], _, Num(d[2])) (this + corresponds to um_c * f_c and um_d * f_d) + or + - lbp(c[0], _, Num(c[2]) >< lbp(d[0], _, Num(d[2])) and + lbp(c[3], _, Num(c[5])) < lbp(d[3], _, Num(d[5])) (this + corresponds to vm_c * g_c and vm_d * g_d) + + c > d otherwise + """ + zero = Polyn(c[2]).ring.zero + + c0 = lbp(c[0], zero, Num(c[2])) + d0 = lbp(d[0], zero, Num(d[2])) + + r = lbp_cmp(c0, d0) + + if r == -1: + return -1 + if r == 0: + c1 = lbp(c[3], zero, Num(c[5])) + d1 = lbp(d[3], zero, Num(d[5])) + + r = lbp_cmp(c1, d1) + + if r == -1: + return -1 + #if r == 0: + # return 0 + return 1 + + +def cp_key(c, ring): + """ + Key for comparing critical pairs. + """ + return (lbp_key(lbp(c[0], ring.zero, Num(c[2]))), lbp_key(lbp(c[3], ring.zero, Num(c[5])))) + + +def s_poly(cp): + """ + Compute the S-polynomial of a critical pair. + + The S-polynomial of a critical pair cp is cp[1] * cp[2] - cp[4] * cp[5]. + """ + return lbp_sub(lbp_mul_term(cp[2], cp[1]), lbp_mul_term(cp[5], cp[4])) + + +def is_rewritable_or_comparable(sign, num, B): + """ + Check if a labeled polynomial is redundant by checking if its + signature and number imply rewritability or comparability. + + (sign, num) is comparable if there exists a labeled polynomial + h in B, such that sign[1] (the index) is less than Sign(h)[1] + and sign[0] is divisible by the leading monomial of h. + + (sign, num) is rewritable if there exists a labeled polynomial + h in B, such thatsign[1] is equal to Sign(h)[1], num < Num(h) + and sign[0] is divisible by Sign(h)[0]. + """ + for h in B: + # comparable + if sign[1] < Sign(h)[1]: + if monomial_divides(Polyn(h).LM, sign[0]): + return True + + # rewritable + if sign[1] == Sign(h)[1]: + if num < Num(h): + if monomial_divides(Sign(h)[0], sign[0]): + return True + return False + + +def f5_reduce(f, B): + """ + F5-reduce a labeled polynomial f by B. + + Continuously searches for non-zero labeled polynomial h in B, such + that the leading term lt_h of h divides the leading term lt_f of + f and Sign(lt_h * h) < Sign(f). If such a labeled polynomial h is + found, f gets replaced by f - lt_f / lt_h * h. If no such h can be + found or f is 0, f is no further F5-reducible and f gets returned. + + A polynomial that is reducible in the usual sense need not be + F5-reducible, e.g.: + + >>> from sympy.polys.groebnertools import lbp, sig, f5_reduce, Polyn + >>> from sympy.polys import ring, QQ, lex + + >>> R, x,y,z = ring("x,y,z", QQ, lex) + + >>> f = lbp(sig((1, 1, 1), 4), x, 3) + >>> g = lbp(sig((0, 0, 0), 2), x, 2) + + >>> Polyn(f).rem([Polyn(g)]) + 0 + >>> f5_reduce(f, [g]) + (((1, 1, 1), 4), x, 3) + + """ + order = Polyn(f).ring.order + domain = Polyn(f).ring.domain + + if not Polyn(f): + return f + + while True: + g = f + + for h in B: + if Polyn(h): + if monomial_divides(Polyn(h).LM, Polyn(f).LM): + t = term_div(Polyn(f).LT, Polyn(h).LT, domain) + if sig_cmp(sig_mult(Sign(h), t[0]), Sign(f), order) < 0: + # The following check need not be done and is in general slower than without. + #if not is_rewritable_or_comparable(Sign(gp), Num(gp), B): + hp = lbp_mul_term(h, t) + f = lbp_sub(f, hp) + break + + if g == f or not Polyn(f): + return f + + +def _f5b(F, ring): + """ + Computes a reduced Groebner basis for the ideal generated by F. + + f5b is an implementation of the F5B algorithm by Yao Sun and + Dingkang Wang. Similarly to Buchberger's algorithm, the algorithm + proceeds by computing critical pairs, computing the S-polynomial, + reducing it and adjoining the reduced S-polynomial if it is not 0. + + Unlike Buchberger's algorithm, each polynomial contains additional + information, namely a signature and a number. The signature + specifies the path of computation (i.e. from which polynomial in + the original basis was it derived and how), the number says when + the polynomial was added to the basis. With this information it + is (often) possible to decide if an S-polynomial will reduce to + 0 and can be discarded. + + Optimizations include: Reducing the generators before computing + a Groebner basis, removing redundant critical pairs when a new + polynomial enters the basis and sorting the critical pairs and + the current basis. + + Once a Groebner basis has been found, it gets reduced. + + References + ========== + + .. [1] Yao Sun, Dingkang Wang: "A New Proof for the Correctness of F5 + (F5-Like) Algorithm", https://arxiv.org/abs/1004.0084 (specifically + v4) + + .. [2] Thomas Becker, Volker Weispfenning, Groebner bases: A computational + approach to commutative algebra, 1993, p. 203, 216 + """ + order = ring.order + + # reduce polynomials (like in Mario Pernici's implementation) (Becker, Weispfenning, p. 203) + B = F + while True: + F = B + B = [] + + for i in range(len(F)): + p = F[i] + r = p.rem(F[:i]) + + if r: + B.append(r) + + if F == B: + break + + # basis + B = [lbp(sig(ring.zero_monom, i + 1), F[i], i + 1) for i in range(len(F))] + B.sort(key=lambda f: order(Polyn(f).LM), reverse=True) + + # critical pairs + CP = [critical_pair(B[i], B[j], ring) for i in range(len(B)) for j in range(i + 1, len(B))] + CP.sort(key=lambda cp: cp_key(cp, ring), reverse=True) + + k = len(B) + + reductions_to_zero = 0 + + while len(CP): + cp = CP.pop() + + # discard redundant critical pairs: + if is_rewritable_or_comparable(cp[0], Num(cp[2]), B): + continue + if is_rewritable_or_comparable(cp[3], Num(cp[5]), B): + continue + + s = s_poly(cp) + + p = f5_reduce(s, B) + + p = lbp(Sign(p), Polyn(p).monic(), k + 1) + + if Polyn(p): + # remove old critical pairs, that become redundant when adding p: + indices = [] + for i, cp in enumerate(CP): + if is_rewritable_or_comparable(cp[0], Num(cp[2]), [p]): + indices.append(i) + elif is_rewritable_or_comparable(cp[3], Num(cp[5]), [p]): + indices.append(i) + + for i in reversed(indices): + del CP[i] + + # only add new critical pairs that are not made redundant by p: + for g in B: + if Polyn(g): + cp = critical_pair(p, g, ring) + if is_rewritable_or_comparable(cp[0], Num(cp[2]), [p]): + continue + elif is_rewritable_or_comparable(cp[3], Num(cp[5]), [p]): + continue + + CP.append(cp) + + # sort (other sorting methods/selection strategies were not as successful) + CP.sort(key=lambda cp: cp_key(cp, ring), reverse=True) + + # insert p into B: + m = Polyn(p).LM + if order(m) <= order(Polyn(B[-1]).LM): + B.append(p) + else: + for i, q in enumerate(B): + if order(m) > order(Polyn(q).LM): + B.insert(i, p) + break + + k += 1 + + #print(len(B), len(CP), "%d critical pairs removed" % len(indices)) + else: + reductions_to_zero += 1 + + # reduce Groebner basis: + H = [Polyn(g).monic() for g in B] + H = red_groebner(H, ring) + + return sorted(H, key=lambda f: order(f.LM), reverse=True) + + +def red_groebner(G, ring): + """ + Compute reduced Groebner basis, from BeckerWeispfenning93, p. 216 + + Selects a subset of generators, that already generate the ideal + and computes a reduced Groebner basis for them. + """ + def reduction(P): + """ + The actual reduction algorithm. + """ + Q = [] + for i, p in enumerate(P): + h = p.rem(P[:i] + P[i + 1:]) + if h: + Q.append(h) + + return [p.monic() for p in Q] + + F = G + H = [] + + while F: + f0 = F.pop() + + if not any(monomial_divides(f.LM, f0.LM) for f in F + H): + H.append(f0) + + # Becker, Weispfenning, p. 217: H is Groebner basis of the ideal generated by G. + return reduction(H) + + +def is_groebner(G, ring): + """ + Check if G is a Groebner basis. + """ + for i in range(len(G)): + for j in range(i + 1, len(G)): + s = spoly(G[i], G[j], ring) + s = s.rem(G) + if s: + return False + + return True + + +def is_minimal(G, ring): + """ + Checks if G is a minimal Groebner basis. + """ + order = ring.order + domain = ring.domain + + G.sort(key=lambda g: order(g.LM)) + + for i, g in enumerate(G): + if g.LC != domain.one: + return False + + for h in G[:i] + G[i + 1:]: + if monomial_divides(h.LM, g.LM): + return False + + return True + + +def is_reduced(G, ring): + """ + Checks if G is a reduced Groebner basis. + """ + order = ring.order + domain = ring.domain + + G.sort(key=lambda g: order(g.LM)) + + for i, g in enumerate(G): + if g.LC != domain.one: + return False + + for term in g.terms(): + for h in G[:i] + G[i + 1:]: + if monomial_divides(h.LM, term[0]): + return False + + return True + +def groebner_lcm(f, g): + """ + Computes LCM of two polynomials using Groebner bases. + + The LCM is computed as the unique generator of the intersection + of the two ideals generated by `f` and `g`. The approach is to + compute a Groebner basis with respect to lexicographic ordering + of `t*f` and `(1 - t)*g`, where `t` is an unrelated variable and + then filtering out the solution that does not contain `t`. + + References + ========== + + .. [1] [Cox97]_ + + """ + if f.ring != g.ring: + raise ValueError("Values should be equal") + + ring = f.ring + domain = ring.domain + + if not f or not g: + return ring.zero + + if len(f) <= 1 and len(g) <= 1: + monom = monomial_lcm(f.LM, g.LM) + coeff = domain.lcm(f.LC, g.LC) + return ring.term_new(monom, coeff) + + fc, f = f.primitive() + gc, g = g.primitive() + + lcm = domain.lcm(fc, gc) + + f_terms = [ ((1,) + monom, coeff) for monom, coeff in f.terms() ] + g_terms = [ ((0,) + monom, coeff) for monom, coeff in g.terms() ] \ + + [ ((1,) + monom,-coeff) for monom, coeff in g.terms() ] + + t = Dummy("t") + t_ring = ring.clone(symbols=(t,) + ring.symbols, order=lex) + + F = t_ring.from_terms(f_terms) + G = t_ring.from_terms(g_terms) + + basis = groebner([F, G], t_ring) + + def is_independent(h, j): + return not any(monom[j] for monom in h.monoms()) + + H = [ h for h in basis if is_independent(h, 0) ] + + h_terms = [ (monom[1:], coeff*lcm) for monom, coeff in H[0].terms() ] + h = ring.from_terms(h_terms) + + return h + +def groebner_gcd(f, g): + """Computes GCD of two polynomials using Groebner bases. """ + if f.ring != g.ring: + raise ValueError("Values should be equal") + domain = f.ring.domain + + if not domain.is_Field: + fc, f = f.primitive() + gc, g = g.primitive() + gcd = domain.gcd(fc, gc) + + H = (f*g).quo([groebner_lcm(f, g)]) + + if len(H) != 1: + raise ValueError("Length should be 1") + h = H[0] + + if not domain.is_Field: + return gcd*h + else: + return h.monic() diff --git a/phivenv/Lib/site-packages/sympy/polys/heuristicgcd.py b/phivenv/Lib/site-packages/sympy/polys/heuristicgcd.py new file mode 100644 index 0000000000000000000000000000000000000000..ea9eeac952e88552d729f0bd3073dee21b6ab68b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/heuristicgcd.py @@ -0,0 +1,149 @@ +"""Heuristic polynomial GCD algorithm (HEUGCD). """ + +from .polyerrors import HeuristicGCDFailed + +HEU_GCD_MAX = 6 + +def heugcd(f, g): + """ + Heuristic polynomial GCD in ``Z[X]``. + + Given univariate polynomials ``f`` and ``g`` in ``Z[X]``, returns + their GCD and cofactors, i.e. polynomials ``h``, ``cff`` and ``cfg`` + such that:: + + h = gcd(f, g), cff = quo(f, h) and cfg = quo(g, h) + + The algorithm is purely heuristic which means it may fail to compute + the GCD. This will be signaled by raising an exception. In this case + you will need to switch to another GCD method. + + The algorithm computes the polynomial GCD by evaluating polynomials + ``f`` and ``g`` at certain points and computing (fast) integer GCD + of those evaluations. The polynomial GCD is recovered from the integer + image by interpolation. The evaluation process reduces f and g variable + by variable into a large integer. The final step is to verify if the + interpolated polynomial is the correct GCD. This gives cofactors of + the input polynomials as a side effect. + + Examples + ======== + + >>> from sympy.polys.heuristicgcd import heugcd + >>> from sympy.polys import ring, ZZ + + >>> R, x,y, = ring("x,y", ZZ) + + >>> f = x**2 + 2*x*y + y**2 + >>> g = x**2 + x*y + + >>> h, cff, cfg = heugcd(f, g) + >>> h, cff, cfg + (x + y, x + y, x) + + >>> cff*h == f + True + >>> cfg*h == g + True + + References + ========== + + .. [1] [Liao95]_ + + """ + assert f.ring == g.ring and f.ring.domain.is_ZZ + + ring = f.ring + x0 = ring.gens[0] + domain = ring.domain + + gcd, f, g = f.extract_ground(g) + + f_norm = f.max_norm() + g_norm = g.max_norm() + + B = domain(2*min(f_norm, g_norm) + 29) + + x = max(min(B, 99*domain.sqrt(B)), + 2*min(f_norm // abs(f.LC), + g_norm // abs(g.LC)) + 4) + + for i in range(0, HEU_GCD_MAX): + ff = f.evaluate(x0, x) + gg = g.evaluate(x0, x) + + if ff and gg: + if ring.ngens == 1: + h, cff, cfg = domain.cofactors(ff, gg) + else: + h, cff, cfg = heugcd(ff, gg) + + h = _gcd_interpolate(h, x, ring) + h = h.primitive()[1] + + cff_, r = f.div(h) + + if not r: + cfg_, r = g.div(h) + + if not r: + h = h.mul_ground(gcd) + return h, cff_, cfg_ + + cff = _gcd_interpolate(cff, x, ring) + + h, r = f.div(cff) + + if not r: + cfg_, r = g.div(h) + + if not r: + h = h.mul_ground(gcd) + return h, cff, cfg_ + + cfg = _gcd_interpolate(cfg, x, ring) + + h, r = g.div(cfg) + + if not r: + cff_, r = f.div(h) + + if not r: + h = h.mul_ground(gcd) + return h, cff_, cfg + + x = 73794*x * domain.sqrt(domain.sqrt(x)) // 27011 + + raise HeuristicGCDFailed('no luck') + +def _gcd_interpolate(h, x, ring): + """Interpolate polynomial GCD from integer GCD. """ + f, i = ring.zero, 0 + + # TODO: don't expose poly repr implementation details + if ring.ngens == 1: + while h: + g = h % x + if g > x // 2: g -= x + h = (h - g) // x + + # f += X**i*g + if g: + f[(i,)] = g + i += 1 + else: + while h: + g = h.trunc_ground(x) + h = (h - g).quo_ground(x) + + # f += X**i*g + if g: + for monom, coeff in g.iterterms(): + f[(i,) + monom] = coeff + i += 1 + + if f.LC < 0: + return -f + else: + return f diff --git a/phivenv/Lib/site-packages/sympy/polys/modulargcd.py b/phivenv/Lib/site-packages/sympy/polys/modulargcd.py new file mode 100644 index 0000000000000000000000000000000000000000..6f0012316c499cfde85f56c5c37a3475f4175a4e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/modulargcd.py @@ -0,0 +1,2278 @@ +from sympy.core.symbol import Dummy +from sympy.ntheory import nextprime +from sympy.ntheory.modular import crt +from sympy.polys.domains import PolynomialRing +from sympy.polys.galoistools import ( + gf_gcd, gf_from_dict, gf_gcdex, gf_div, gf_lcm) +from sympy.polys.polyerrors import ModularGCDFailed + +from mpmath import sqrt +import random + + +def _trivial_gcd(f, g): + """ + Compute the GCD of two polynomials in trivial cases, i.e. when one + or both polynomials are zero. + """ + ring = f.ring + + if not (f or g): + return ring.zero, ring.zero, ring.zero + elif not f: + if g.LC < ring.domain.zero: + return -g, ring.zero, -ring.one + else: + return g, ring.zero, ring.one + elif not g: + if f.LC < ring.domain.zero: + return -f, -ring.one, ring.zero + else: + return f, ring.one, ring.zero + return None + + +def _gf_gcd(fp, gp, p): + r""" + Compute the GCD of two univariate polynomials in `\mathbb{Z}_p[x]`. + """ + dom = fp.ring.domain + + while gp: + rem = fp + deg = gp.degree() + lcinv = dom.invert(gp.LC, p) + + while True: + degrem = rem.degree() + if degrem < deg: + break + rem = (rem - gp.mul_monom((degrem - deg,)).mul_ground(lcinv * rem.LC)).trunc_ground(p) + + fp = gp + gp = rem + + return fp.mul_ground(dom.invert(fp.LC, p)).trunc_ground(p) + + +def _degree_bound_univariate(f, g): + r""" + Compute an upper bound for the degree of the GCD of two univariate + integer polynomials `f` and `g`. + + The function chooses a suitable prime `p` and computes the GCD of + `f` and `g` in `\mathbb{Z}_p[x]`. The choice of `p` guarantees that + the degree in `\mathbb{Z}_p[x]` is greater than or equal to the degree + in `\mathbb{Z}[x]`. + + Parameters + ========== + + f : PolyElement + univariate integer polynomial + g : PolyElement + univariate integer polynomial + + """ + gamma = f.ring.domain.gcd(f.LC, g.LC) + p = 1 + + p = nextprime(p) + while gamma % p == 0: + p = nextprime(p) + + fp = f.trunc_ground(p) + gp = g.trunc_ground(p) + hp = _gf_gcd(fp, gp, p) + deghp = hp.degree() + return deghp + + +def _chinese_remainder_reconstruction_univariate(hp, hq, p, q): + r""" + Construct a polynomial `h_{pq}` in `\mathbb{Z}_{p q}[x]` such that + + .. math :: + + h_{pq} = h_p \; \mathrm{mod} \, p + + h_{pq} = h_q \; \mathrm{mod} \, q + + for relatively prime integers `p` and `q` and polynomials + `h_p` and `h_q` in `\mathbb{Z}_p[x]` and `\mathbb{Z}_q[x]` + respectively. + + The coefficients of the polynomial `h_{pq}` are computed with the + Chinese Remainder Theorem. The symmetric representation in + `\mathbb{Z}_p[x]`, `\mathbb{Z}_q[x]` and `\mathbb{Z}_{p q}[x]` is used. + It is assumed that `h_p` and `h_q` have the same degree. + + Parameters + ========== + + hp : PolyElement + univariate integer polynomial with coefficients in `\mathbb{Z}_p` + hq : PolyElement + univariate integer polynomial with coefficients in `\mathbb{Z}_q` + p : Integer + modulus of `h_p`, relatively prime to `q` + q : Integer + modulus of `h_q`, relatively prime to `p` + + Examples + ======== + + >>> from sympy.polys.modulargcd import _chinese_remainder_reconstruction_univariate + >>> from sympy.polys import ring, ZZ + + >>> R, x = ring("x", ZZ) + >>> p = 3 + >>> q = 5 + + >>> hp = -x**3 - 1 + >>> hq = 2*x**3 - 2*x**2 + x + + >>> hpq = _chinese_remainder_reconstruction_univariate(hp, hq, p, q) + >>> hpq + 2*x**3 + 3*x**2 + 6*x + 5 + + >>> hpq.trunc_ground(p) == hp + True + >>> hpq.trunc_ground(q) == hq + True + + """ + n = hp.degree() + x = hp.ring.gens[0] + hpq = hp.ring.zero + + for i in range(n+1): + hpq[(i,)] = crt([p, q], [hp.coeff(x**i), hq.coeff(x**i)], symmetric=True)[0] + + hpq.strip_zero() + return hpq + + +def modgcd_univariate(f, g): + r""" + Computes the GCD of two polynomials in `\mathbb{Z}[x]` using a modular + algorithm. + + The algorithm computes the GCD of two univariate integer polynomials + `f` and `g` by computing the GCD in `\mathbb{Z}_p[x]` for suitable + primes `p` and then reconstructing the coefficients with the Chinese + Remainder Theorem. Trial division is only made for candidates which + are very likely the desired GCD. + + Parameters + ========== + + f : PolyElement + univariate integer polynomial + g : PolyElement + univariate integer polynomial + + Returns + ======= + + h : PolyElement + GCD of the polynomials `f` and `g` + cff : PolyElement + cofactor of `f`, i.e. `\frac{f}{h}` + cfg : PolyElement + cofactor of `g`, i.e. `\frac{g}{h}` + + Examples + ======== + + >>> from sympy.polys.modulargcd import modgcd_univariate + >>> from sympy.polys import ring, ZZ + + >>> R, x = ring("x", ZZ) + + >>> f = x**5 - 1 + >>> g = x - 1 + + >>> h, cff, cfg = modgcd_univariate(f, g) + >>> h, cff, cfg + (x - 1, x**4 + x**3 + x**2 + x + 1, 1) + + >>> cff * h == f + True + >>> cfg * h == g + True + + >>> f = 6*x**2 - 6 + >>> g = 2*x**2 + 4*x + 2 + + >>> h, cff, cfg = modgcd_univariate(f, g) + >>> h, cff, cfg + (2*x + 2, 3*x - 3, x + 1) + + >>> cff * h == f + True + >>> cfg * h == g + True + + References + ========== + + 1. [Monagan00]_ + + """ + assert f.ring == g.ring and f.ring.domain.is_ZZ + + result = _trivial_gcd(f, g) + if result is not None: + return result + + ring = f.ring + + cf, f = f.primitive() + cg, g = g.primitive() + ch = ring.domain.gcd(cf, cg) + + bound = _degree_bound_univariate(f, g) + if bound == 0: + return ring(ch), f.mul_ground(cf // ch), g.mul_ground(cg // ch) + + gamma = ring.domain.gcd(f.LC, g.LC) + m = 1 + p = 1 + + while True: + p = nextprime(p) + while gamma % p == 0: + p = nextprime(p) + + fp = f.trunc_ground(p) + gp = g.trunc_ground(p) + hp = _gf_gcd(fp, gp, p) + deghp = hp.degree() + + if deghp > bound: + continue + elif deghp < bound: + m = 1 + bound = deghp + continue + + hp = hp.mul_ground(gamma).trunc_ground(p) + if m == 1: + m = p + hlastm = hp + continue + + hm = _chinese_remainder_reconstruction_univariate(hp, hlastm, p, m) + m *= p + + if not hm == hlastm: + hlastm = hm + continue + + h = hm.quo_ground(hm.content()) + fquo, frem = f.div(h) + gquo, grem = g.div(h) + if not frem and not grem: + if h.LC < 0: + ch = -ch + h = h.mul_ground(ch) + cff = fquo.mul_ground(cf // ch) + cfg = gquo.mul_ground(cg // ch) + return h, cff, cfg + + +def _primitive(f, p): + r""" + Compute the content and the primitive part of a polynomial in + `\mathbb{Z}_p[x_0, \ldots, x_{k-2}, y] \cong \mathbb{Z}_p[y][x_0, \ldots, x_{k-2}]`. + + Parameters + ========== + + f : PolyElement + integer polynomial in `\mathbb{Z}_p[x0, \ldots, x{k-2}, y]` + p : Integer + modulus of `f` + + Returns + ======= + + contf : PolyElement + integer polynomial in `\mathbb{Z}_p[y]`, content of `f` + ppf : PolyElement + primitive part of `f`, i.e. `\frac{f}{contf}` + + Examples + ======== + + >>> from sympy.polys.modulargcd import _primitive + >>> from sympy.polys import ring, ZZ + + >>> R, x, y = ring("x, y", ZZ) + >>> p = 3 + + >>> f = x**2*y**2 + x**2*y - y**2 - y + >>> _primitive(f, p) + (y**2 + y, x**2 - 1) + + >>> R, x, y, z = ring("x, y, z", ZZ) + + >>> f = x*y*z - y**2*z**2 + >>> _primitive(f, p) + (z, x*y - y**2*z) + + """ + ring = f.ring + dom = ring.domain + k = ring.ngens + + coeffs = {} + for monom, coeff in f.iterterms(): + if monom[:-1] not in coeffs: + coeffs[monom[:-1]] = {} + coeffs[monom[:-1]][monom[-1]] = coeff + + cont = [] + for coeff in iter(coeffs.values()): + cont = gf_gcd(cont, gf_from_dict(coeff, p, dom), p, dom) + + yring = ring.clone(symbols=ring.symbols[k-1]) + contf = yring.from_dense(cont).trunc_ground(p) + + return contf, f.quo(contf.set_ring(ring)) + + +def _deg(f): + r""" + Compute the degree of a multivariate polynomial + `f \in K[x_0, \ldots, x_{k-2}, y] \cong K[y][x_0, \ldots, x_{k-2}]`. + + Parameters + ========== + + f : PolyElement + polynomial in `K[x_0, \ldots, x_{k-2}, y]` + + Returns + ======= + + degf : Integer tuple + degree of `f` in `x_0, \ldots, x_{k-2}` + + Examples + ======== + + >>> from sympy.polys.modulargcd import _deg + >>> from sympy.polys import ring, ZZ + + >>> R, x, y = ring("x, y", ZZ) + + >>> f = x**2*y**2 + x**2*y - 1 + >>> _deg(f) + (2,) + + >>> R, x, y, z = ring("x, y, z", ZZ) + + >>> f = x**2*y**2 + x**2*y - 1 + >>> _deg(f) + (2, 2) + + >>> f = x*y*z - y**2*z**2 + >>> _deg(f) + (1, 1) + + """ + k = f.ring.ngens + degf = (0,) * (k-1) + for monom in f.itermonoms(): + if monom[:-1] > degf: + degf = monom[:-1] + return degf + + +def _LC(f): + r""" + Compute the leading coefficient of a multivariate polynomial + `f \in K[x_0, \ldots, x_{k-2}, y] \cong K[y][x_0, \ldots, x_{k-2}]`. + + Parameters + ========== + + f : PolyElement + polynomial in `K[x_0, \ldots, x_{k-2}, y]` + + Returns + ======= + + lcf : PolyElement + polynomial in `K[y]`, leading coefficient of `f` + + Examples + ======== + + >>> from sympy.polys.modulargcd import _LC + >>> from sympy.polys import ring, ZZ + + >>> R, x, y = ring("x, y", ZZ) + + >>> f = x**2*y**2 + x**2*y - 1 + >>> _LC(f) + y**2 + y + + >>> R, x, y, z = ring("x, y, z", ZZ) + + >>> f = x**2*y**2 + x**2*y - 1 + >>> _LC(f) + 1 + + >>> f = x*y*z - y**2*z**2 + >>> _LC(f) + z + + """ + ring = f.ring + k = ring.ngens + yring = ring.clone(symbols=ring.symbols[k-1]) + y = yring.gens[0] + degf = _deg(f) + + lcf = yring.zero + for monom, coeff in f.iterterms(): + if monom[:-1] == degf: + lcf += coeff*y**monom[-1] + return lcf + + +def _swap(f, i): + """ + Make the variable `x_i` the leading one in a multivariate polynomial `f`. + """ + ring = f.ring + fswap = ring.zero + for monom, coeff in f.iterterms(): + monomswap = (monom[i],) + monom[:i] + monom[i+1:] + fswap[monomswap] = coeff + return fswap + + +def _degree_bound_bivariate(f, g): + r""" + Compute upper degree bounds for the GCD of two bivariate + integer polynomials `f` and `g`. + + The GCD is viewed as a polynomial in `\mathbb{Z}[y][x]` and the + function returns an upper bound for its degree and one for the degree + of its content. This is done by choosing a suitable prime `p` and + computing the GCD of the contents of `f \; \mathrm{mod} \, p` and + `g \; \mathrm{mod} \, p`. The choice of `p` guarantees that the degree + of the content in `\mathbb{Z}_p[y]` is greater than or equal to the + degree in `\mathbb{Z}[y]`. To obtain the degree bound in the variable + `x`, the polynomials are evaluated at `y = a` for a suitable + `a \in \mathbb{Z}_p` and then their GCD in `\mathbb{Z}_p[x]` is + computed. If no such `a` exists, i.e. the degree in `\mathbb{Z}_p[x]` + is always smaller than the one in `\mathbb{Z}[y][x]`, then the bound is + set to the minimum of the degrees of `f` and `g` in `x`. + + Parameters + ========== + + f : PolyElement + bivariate integer polynomial + g : PolyElement + bivariate integer polynomial + + Returns + ======= + + xbound : Integer + upper bound for the degree of the GCD of the polynomials `f` and + `g` in the variable `x` + ycontbound : Integer + upper bound for the degree of the content of the GCD of the + polynomials `f` and `g` in the variable `y` + + References + ========== + + 1. [Monagan00]_ + + """ + ring = f.ring + + gamma1 = ring.domain.gcd(f.LC, g.LC) + gamma2 = ring.domain.gcd(_swap(f, 1).LC, _swap(g, 1).LC) + badprimes = gamma1 * gamma2 + p = 1 + + p = nextprime(p) + while badprimes % p == 0: + p = nextprime(p) + + fp = f.trunc_ground(p) + gp = g.trunc_ground(p) + contfp, fp = _primitive(fp, p) + contgp, gp = _primitive(gp, p) + conthp = _gf_gcd(contfp, contgp, p) # polynomial in Z_p[y] + ycontbound = conthp.degree() + + # polynomial in Z_p[y] + delta = _gf_gcd(_LC(fp), _LC(gp), p) + + for a in range(p): + if not delta.evaluate(0, a) % p: + continue + fpa = fp.evaluate(1, a).trunc_ground(p) + gpa = gp.evaluate(1, a).trunc_ground(p) + hpa = _gf_gcd(fpa, gpa, p) + xbound = hpa.degree() + return xbound, ycontbound + + return min(fp.degree(), gp.degree()), ycontbound + + +def _chinese_remainder_reconstruction_multivariate(hp, hq, p, q): + r""" + Construct a polynomial `h_{pq}` in + `\mathbb{Z}_{p q}[x_0, \ldots, x_{k-1}]` such that + + .. math :: + + h_{pq} = h_p \; \mathrm{mod} \, p + + h_{pq} = h_q \; \mathrm{mod} \, q + + for relatively prime integers `p` and `q` and polynomials + `h_p` and `h_q` in `\mathbb{Z}_p[x_0, \ldots, x_{k-1}]` and + `\mathbb{Z}_q[x_0, \ldots, x_{k-1}]` respectively. + + The coefficients of the polynomial `h_{pq}` are computed with the + Chinese Remainder Theorem. The symmetric representation in + `\mathbb{Z}_p[x_0, \ldots, x_{k-1}]`, + `\mathbb{Z}_q[x_0, \ldots, x_{k-1}]` and + `\mathbb{Z}_{p q}[x_0, \ldots, x_{k-1}]` is used. + + Parameters + ========== + + hp : PolyElement + multivariate integer polynomial with coefficients in `\mathbb{Z}_p` + hq : PolyElement + multivariate integer polynomial with coefficients in `\mathbb{Z}_q` + p : Integer + modulus of `h_p`, relatively prime to `q` + q : Integer + modulus of `h_q`, relatively prime to `p` + + Examples + ======== + + >>> from sympy.polys.modulargcd import _chinese_remainder_reconstruction_multivariate + >>> from sympy.polys import ring, ZZ + + >>> R, x, y = ring("x, y", ZZ) + >>> p = 3 + >>> q = 5 + + >>> hp = x**3*y - x**2 - 1 + >>> hq = -x**3*y - 2*x*y**2 + 2 + + >>> hpq = _chinese_remainder_reconstruction_multivariate(hp, hq, p, q) + >>> hpq + 4*x**3*y + 5*x**2 + 3*x*y**2 + 2 + + >>> hpq.trunc_ground(p) == hp + True + >>> hpq.trunc_ground(q) == hq + True + + >>> R, x, y, z = ring("x, y, z", ZZ) + >>> p = 6 + >>> q = 5 + + >>> hp = 3*x**4 - y**3*z + z + >>> hq = -2*x**4 + z + + >>> hpq = _chinese_remainder_reconstruction_multivariate(hp, hq, p, q) + >>> hpq + 3*x**4 + 5*y**3*z + z + + >>> hpq.trunc_ground(p) == hp + True + >>> hpq.trunc_ground(q) == hq + True + + """ + hpmonoms = set(hp.monoms()) + hqmonoms = set(hq.monoms()) + monoms = hpmonoms.intersection(hqmonoms) + hpmonoms.difference_update(monoms) + hqmonoms.difference_update(monoms) + + domain = hp.ring.domain + zero = domain.zero + + hpq = hp.ring.zero + + if isinstance(hp.ring.domain, PolynomialRing): + crt_ = _chinese_remainder_reconstruction_multivariate + else: + def crt_(cp, cq, p, q): + return domain(crt([p, q], [cp, cq], symmetric=True)[0]) + + for monom in monoms: + hpq[monom] = crt_(hp[monom], hq[monom], p, q) + for monom in hpmonoms: + hpq[monom] = crt_(hp[monom], zero, p, q) + for monom in hqmonoms: + hpq[monom] = crt_(zero, hq[monom], p, q) + + return hpq + + +def _interpolate_multivariate(evalpoints, hpeval, ring, i, p, ground=False): + r""" + Reconstruct a polynomial `h_p` in `\mathbb{Z}_p[x_0, \ldots, x_{k-1}]` + from a list of evaluation points in `\mathbb{Z}_p` and a list of + polynomials in + `\mathbb{Z}_p[x_0, \ldots, x_{i-1}, x_{i+1}, \ldots, x_{k-1}]`, which + are the images of `h_p` evaluated in the variable `x_i`. + + It is also possible to reconstruct a parameter of the ground domain, + i.e. if `h_p` is a polynomial over `\mathbb{Z}_p[x_0, \ldots, x_{k-1}]`. + In this case, one has to set ``ground=True``. + + Parameters + ========== + + evalpoints : list of Integer objects + list of evaluation points in `\mathbb{Z}_p` + hpeval : list of PolyElement objects + list of polynomials in (resp. over) + `\mathbb{Z}_p[x_0, \ldots, x_{i-1}, x_{i+1}, \ldots, x_{k-1}]`, + images of `h_p` evaluated in the variable `x_i` + ring : PolyRing + `h_p` will be an element of this ring + i : Integer + index of the variable which has to be reconstructed + p : Integer + prime number, modulus of `h_p` + ground : Boolean + indicates whether `x_i` is in the ground domain, default is + ``False`` + + Returns + ======= + + hp : PolyElement + interpolated polynomial in (resp. over) + `\mathbb{Z}_p[x_0, \ldots, x_{k-1}]` + + """ + hp = ring.zero + + if ground: + domain = ring.domain.domain + y = ring.domain.gens[i] + else: + domain = ring.domain + y = ring.gens[i] + + for a, hpa in zip(evalpoints, hpeval): + numer = ring.one + denom = domain.one + for b in evalpoints: + if b == a: + continue + + numer *= y - b + denom *= a - b + + denom = domain.invert(denom, p) + coeff = numer.mul_ground(denom) + hp += hpa.set_ring(ring) * coeff + + return hp.trunc_ground(p) + + +def modgcd_bivariate(f, g): + r""" + Computes the GCD of two polynomials in `\mathbb{Z}[x, y]` using a + modular algorithm. + + The algorithm computes the GCD of two bivariate integer polynomials + `f` and `g` by calculating the GCD in `\mathbb{Z}_p[x, y]` for + suitable primes `p` and then reconstructing the coefficients with the + Chinese Remainder Theorem. To compute the bivariate GCD over + `\mathbb{Z}_p`, the polynomials `f \; \mathrm{mod} \, p` and + `g \; \mathrm{mod} \, p` are evaluated at `y = a` for certain + `a \in \mathbb{Z}_p` and then their univariate GCD in `\mathbb{Z}_p[x]` + is computed. Interpolating those yields the bivariate GCD in + `\mathbb{Z}_p[x, y]`. To verify the result in `\mathbb{Z}[x, y]`, trial + division is done, but only for candidates which are very likely the + desired GCD. + + Parameters + ========== + + f : PolyElement + bivariate integer polynomial + g : PolyElement + bivariate integer polynomial + + Returns + ======= + + h : PolyElement + GCD of the polynomials `f` and `g` + cff : PolyElement + cofactor of `f`, i.e. `\frac{f}{h}` + cfg : PolyElement + cofactor of `g`, i.e. `\frac{g}{h}` + + Examples + ======== + + >>> from sympy.polys.modulargcd import modgcd_bivariate + >>> from sympy.polys import ring, ZZ + + >>> R, x, y = ring("x, y", ZZ) + + >>> f = x**2 - y**2 + >>> g = x**2 + 2*x*y + y**2 + + >>> h, cff, cfg = modgcd_bivariate(f, g) + >>> h, cff, cfg + (x + y, x - y, x + y) + + >>> cff * h == f + True + >>> cfg * h == g + True + + >>> f = x**2*y - x**2 - 4*y + 4 + >>> g = x + 2 + + >>> h, cff, cfg = modgcd_bivariate(f, g) + >>> h, cff, cfg + (x + 2, x*y - x - 2*y + 2, 1) + + >>> cff * h == f + True + >>> cfg * h == g + True + + References + ========== + + 1. [Monagan00]_ + + """ + assert f.ring == g.ring and f.ring.domain.is_ZZ + + result = _trivial_gcd(f, g) + if result is not None: + return result + + ring = f.ring + + cf, f = f.primitive() + cg, g = g.primitive() + ch = ring.domain.gcd(cf, cg) + + xbound, ycontbound = _degree_bound_bivariate(f, g) + if xbound == ycontbound == 0: + return ring(ch), f.mul_ground(cf // ch), g.mul_ground(cg // ch) + + fswap = _swap(f, 1) + gswap = _swap(g, 1) + degyf = fswap.degree() + degyg = gswap.degree() + + ybound, xcontbound = _degree_bound_bivariate(fswap, gswap) + if ybound == xcontbound == 0: + return ring(ch), f.mul_ground(cf // ch), g.mul_ground(cg // ch) + + # TODO: to improve performance, choose the main variable here + + gamma1 = ring.domain.gcd(f.LC, g.LC) + gamma2 = ring.domain.gcd(fswap.LC, gswap.LC) + badprimes = gamma1 * gamma2 + m = 1 + p = 1 + + while True: + p = nextprime(p) + while badprimes % p == 0: + p = nextprime(p) + + fp = f.trunc_ground(p) + gp = g.trunc_ground(p) + contfp, fp = _primitive(fp, p) + contgp, gp = _primitive(gp, p) + conthp = _gf_gcd(contfp, contgp, p) # monic polynomial in Z_p[y] + degconthp = conthp.degree() + + if degconthp > ycontbound: + continue + elif degconthp < ycontbound: + m = 1 + ycontbound = degconthp + continue + + # polynomial in Z_p[y] + delta = _gf_gcd(_LC(fp), _LC(gp), p) + + degcontfp = contfp.degree() + degcontgp = contgp.degree() + degdelta = delta.degree() + + N = min(degyf - degcontfp, degyg - degcontgp, + ybound - ycontbound + degdelta) + 1 + + if p < N: + continue + + n = 0 + evalpoints = [] + hpeval = [] + unlucky = False + + for a in range(p): + deltaa = delta.evaluate(0, a) + if not deltaa % p: + continue + + fpa = fp.evaluate(1, a).trunc_ground(p) + gpa = gp.evaluate(1, a).trunc_ground(p) + hpa = _gf_gcd(fpa, gpa, p) # monic polynomial in Z_p[x] + deghpa = hpa.degree() + + if deghpa > xbound: + continue + elif deghpa < xbound: + m = 1 + xbound = deghpa + unlucky = True + break + + hpa = hpa.mul_ground(deltaa).trunc_ground(p) + evalpoints.append(a) + hpeval.append(hpa) + n += 1 + + if n == N: + break + + if unlucky: + continue + if n < N: + continue + + hp = _interpolate_multivariate(evalpoints, hpeval, ring, 1, p) + + hp = _primitive(hp, p)[1] + hp = hp * conthp.set_ring(ring) + degyhp = hp.degree(1) + + if degyhp > ybound: + continue + if degyhp < ybound: + m = 1 + ybound = degyhp + continue + + hp = hp.mul_ground(gamma1).trunc_ground(p) + if m == 1: + m = p + hlastm = hp + continue + + hm = _chinese_remainder_reconstruction_multivariate(hp, hlastm, p, m) + m *= p + + if not hm == hlastm: + hlastm = hm + continue + + h = hm.quo_ground(hm.content()) + fquo, frem = f.div(h) + gquo, grem = g.div(h) + if not frem and not grem: + if h.LC < 0: + ch = -ch + h = h.mul_ground(ch) + cff = fquo.mul_ground(cf // ch) + cfg = gquo.mul_ground(cg // ch) + return h, cff, cfg + + +def _modgcd_multivariate_p(f, g, p, degbound, contbound): + r""" + Compute the GCD of two polynomials in + `\mathbb{Z}_p[x_0, \ldots, x_{k-1}]`. + + The algorithm reduces the problem step by step by evaluating the + polynomials `f` and `g` at `x_{k-1} = a` for suitable + `a \in \mathbb{Z}_p` and then calls itself recursively to compute the GCD + in `\mathbb{Z}_p[x_0, \ldots, x_{k-2}]`. If these recursive calls are + successful for enough evaluation points, the GCD in `k` variables is + interpolated, otherwise the algorithm returns ``None``. Every time a GCD + or a content is computed, their degrees are compared with the bounds. If + a degree greater then the bound is encountered, then the current call + returns ``None`` and a new evaluation point has to be chosen. If at some + point the degree is smaller, the correspondent bound is updated and the + algorithm fails. + + Parameters + ========== + + f : PolyElement + multivariate integer polynomial with coefficients in `\mathbb{Z}_p` + g : PolyElement + multivariate integer polynomial with coefficients in `\mathbb{Z}_p` + p : Integer + prime number, modulus of `f` and `g` + degbound : list of Integer objects + ``degbound[i]`` is an upper bound for the degree of the GCD of `f` + and `g` in the variable `x_i` + contbound : list of Integer objects + ``contbound[i]`` is an upper bound for the degree of the content of + the GCD in `\mathbb{Z}_p[x_i][x_0, \ldots, x_{i-1}]`, + ``contbound[0]`` is not used can therefore be chosen + arbitrarily. + + Returns + ======= + + h : PolyElement + GCD of the polynomials `f` and `g` or ``None`` + + References + ========== + + 1. [Monagan00]_ + 2. [Brown71]_ + + """ + ring = f.ring + k = ring.ngens + + if k == 1: + h = _gf_gcd(f, g, p).trunc_ground(p) + degh = h.degree() + + if degh > degbound[0]: + return None + if degh < degbound[0]: + degbound[0] = degh + raise ModularGCDFailed + + return h + + degyf = f.degree(k-1) + degyg = g.degree(k-1) + + contf, f = _primitive(f, p) + contg, g = _primitive(g, p) + + conth = _gf_gcd(contf, contg, p) # polynomial in Z_p[y] + + degcontf = contf.degree() + degcontg = contg.degree() + degconth = conth.degree() + + if degconth > contbound[k-1]: + return None + if degconth < contbound[k-1]: + contbound[k-1] = degconth + raise ModularGCDFailed + + lcf = _LC(f) + lcg = _LC(g) + + delta = _gf_gcd(lcf, lcg, p) # polynomial in Z_p[y] + + evaltest = delta + + for i in range(k-1): + evaltest *= _gf_gcd(_LC(_swap(f, i)), _LC(_swap(g, i)), p) + + degdelta = delta.degree() + + N = min(degyf - degcontf, degyg - degcontg, + degbound[k-1] - contbound[k-1] + degdelta) + 1 + + if p < N: + return None + + n = 0 + d = 0 + evalpoints = [] + heval = [] + points = list(range(p)) + + while points: + a = random.sample(points, 1)[0] + points.remove(a) + + if not evaltest.evaluate(0, a) % p: + continue + + deltaa = delta.evaluate(0, a) % p + + fa = f.evaluate(k-1, a).trunc_ground(p) + ga = g.evaluate(k-1, a).trunc_ground(p) + + # polynomials in Z_p[x_0, ..., x_{k-2}] + ha = _modgcd_multivariate_p(fa, ga, p, degbound, contbound) + + if ha is None: + d += 1 + if d > n: + return None + continue + + if ha.is_ground: + h = conth.set_ring(ring).trunc_ground(p) + return h + + ha = ha.mul_ground(deltaa).trunc_ground(p) + + evalpoints.append(a) + heval.append(ha) + n += 1 + + if n == N: + h = _interpolate_multivariate(evalpoints, heval, ring, k-1, p) + + h = _primitive(h, p)[1] * conth.set_ring(ring) + degyh = h.degree(k-1) + + if degyh > degbound[k-1]: + return None + if degyh < degbound[k-1]: + degbound[k-1] = degyh + raise ModularGCDFailed + + return h + + return None + + +def modgcd_multivariate(f, g): + r""" + Compute the GCD of two polynomials in `\mathbb{Z}[x_0, \ldots, x_{k-1}]` + using a modular algorithm. + + The algorithm computes the GCD of two multivariate integer polynomials + `f` and `g` by calculating the GCD in + `\mathbb{Z}_p[x_0, \ldots, x_{k-1}]` for suitable primes `p` and then + reconstructing the coefficients with the Chinese Remainder Theorem. To + compute the multivariate GCD over `\mathbb{Z}_p` the recursive + subroutine :func:`_modgcd_multivariate_p` is used. To verify the result in + `\mathbb{Z}[x_0, \ldots, x_{k-1}]`, trial division is done, but only for + candidates which are very likely the desired GCD. + + Parameters + ========== + + f : PolyElement + multivariate integer polynomial + g : PolyElement + multivariate integer polynomial + + Returns + ======= + + h : PolyElement + GCD of the polynomials `f` and `g` + cff : PolyElement + cofactor of `f`, i.e. `\frac{f}{h}` + cfg : PolyElement + cofactor of `g`, i.e. `\frac{g}{h}` + + Examples + ======== + + >>> from sympy.polys.modulargcd import modgcd_multivariate + >>> from sympy.polys import ring, ZZ + + >>> R, x, y = ring("x, y", ZZ) + + >>> f = x**2 - y**2 + >>> g = x**2 + 2*x*y + y**2 + + >>> h, cff, cfg = modgcd_multivariate(f, g) + >>> h, cff, cfg + (x + y, x - y, x + y) + + >>> cff * h == f + True + >>> cfg * h == g + True + + >>> R, x, y, z = ring("x, y, z", ZZ) + + >>> f = x*z**2 - y*z**2 + >>> g = x**2*z + z + + >>> h, cff, cfg = modgcd_multivariate(f, g) + >>> h, cff, cfg + (z, x*z - y*z, x**2 + 1) + + >>> cff * h == f + True + >>> cfg * h == g + True + + References + ========== + + 1. [Monagan00]_ + 2. [Brown71]_ + + See also + ======== + + _modgcd_multivariate_p + + """ + assert f.ring == g.ring and f.ring.domain.is_ZZ + + result = _trivial_gcd(f, g) + if result is not None: + return result + + ring = f.ring + k = ring.ngens + + # divide out integer content + cf, f = f.primitive() + cg, g = g.primitive() + ch = ring.domain.gcd(cf, cg) + + gamma = ring.domain.gcd(f.LC, g.LC) + + badprimes = ring.domain.one + for i in range(k): + badprimes *= ring.domain.gcd(_swap(f, i).LC, _swap(g, i).LC) + + degbound = [min(fdeg, gdeg) for fdeg, gdeg in zip(f.degrees(), g.degrees())] + contbound = list(degbound) + + m = 1 + p = 1 + + while True: + p = nextprime(p) + while badprimes % p == 0: + p = nextprime(p) + + fp = f.trunc_ground(p) + gp = g.trunc_ground(p) + + try: + # monic GCD of fp, gp in Z_p[x_0, ..., x_{k-2}, y] + hp = _modgcd_multivariate_p(fp, gp, p, degbound, contbound) + except ModularGCDFailed: + m = 1 + continue + + if hp is None: + continue + + hp = hp.mul_ground(gamma).trunc_ground(p) + if m == 1: + m = p + hlastm = hp + continue + + hm = _chinese_remainder_reconstruction_multivariate(hp, hlastm, p, m) + m *= p + + if not hm == hlastm: + hlastm = hm + continue + + h = hm.primitive()[1] + fquo, frem = f.div(h) + gquo, grem = g.div(h) + if not frem and not grem: + if h.LC < 0: + ch = -ch + h = h.mul_ground(ch) + cff = fquo.mul_ground(cf // ch) + cfg = gquo.mul_ground(cg // ch) + return h, cff, cfg + + +def _gf_div(f, g, p): + r""" + Compute `\frac f g` modulo `p` for two univariate polynomials over + `\mathbb Z_p`. + """ + ring = f.ring + densequo, denserem = gf_div(f.to_dense(), g.to_dense(), p, ring.domain) + return ring.from_dense(densequo), ring.from_dense(denserem) + + +def _rational_function_reconstruction(c, p, m): + r""" + Reconstruct a rational function `\frac a b` in `\mathbb Z_p(t)` from + + .. math:: + + c = \frac a b \; \mathrm{mod} \, m, + + where `c` and `m` are polynomials in `\mathbb Z_p[t]` and `m` has + positive degree. + + The algorithm is based on the Euclidean Algorithm. In general, `m` is + not irreducible, so it is possible that `b` is not invertible modulo + `m`. In that case ``None`` is returned. + + Parameters + ========== + + c : PolyElement + univariate polynomial in `\mathbb Z[t]` + p : Integer + prime number + m : PolyElement + modulus, not necessarily irreducible + + Returns + ======= + + frac : FracElement + either `\frac a b` in `\mathbb Z(t)` or ``None`` + + References + ========== + + 1. [Hoeij04]_ + + """ + ring = c.ring + domain = ring.domain + M = m.degree() + N = M // 2 + D = M - N - 1 + + r0, s0 = m, ring.zero + r1, s1 = c, ring.one + + while r1.degree() > N: + quo = _gf_div(r0, r1, p)[0] + r0, r1 = r1, (r0 - quo*r1).trunc_ground(p) + s0, s1 = s1, (s0 - quo*s1).trunc_ground(p) + + a, b = r1, s1 + if b.degree() > D or _gf_gcd(b, m, p) != 1: + return None + + lc = b.LC + if lc != 1: + lcinv = domain.invert(lc, p) + a = a.mul_ground(lcinv).trunc_ground(p) + b = b.mul_ground(lcinv).trunc_ground(p) + + field = ring.to_field() + + return field(a) / field(b) + + +def _rational_reconstruction_func_coeffs(hm, p, m, ring, k): + r""" + Reconstruct every coefficient `c_h` of a polynomial `h` in + `\mathbb Z_p(t_k)[t_1, \ldots, t_{k-1}][x, z]` from the corresponding + coefficient `c_{h_m}` of a polynomial `h_m` in + `\mathbb Z_p[t_1, \ldots, t_k][x, z] \cong \mathbb Z_p[t_k][t_1, \ldots, t_{k-1}][x, z]` + such that + + .. math:: + + c_{h_m} = c_h \; \mathrm{mod} \, m, + + where `m \in \mathbb Z_p[t]`. + + The reconstruction is based on the Euclidean Algorithm. In general, `m` + is not irreducible, so it is possible that this fails for some + coefficient. In that case ``None`` is returned. + + Parameters + ========== + + hm : PolyElement + polynomial in `\mathbb Z[t_1, \ldots, t_k][x, z]` + p : Integer + prime number, modulus of `\mathbb Z_p` + m : PolyElement + modulus, polynomial in `\mathbb Z[t]`, not necessarily irreducible + ring : PolyRing + `\mathbb Z(t_k)[t_1, \ldots, t_{k-1}][x, z]`, `h` will be an + element of this ring + k : Integer + index of the parameter `t_k` which will be reconstructed + + Returns + ======= + + h : PolyElement + reconstructed polynomial in + `\mathbb Z(t_k)[t_1, \ldots, t_{k-1}][x, z]` or ``None`` + + See also + ======== + + _rational_function_reconstruction + + """ + h = ring.zero + + for monom, coeff in hm.iterterms(): + if k == 0: + coeffh = _rational_function_reconstruction(coeff, p, m) + + if not coeffh: + return None + + else: + coeffh = ring.domain.zero + for mon, c in coeff.drop_to_ground(k).iterterms(): + ch = _rational_function_reconstruction(c, p, m) + + if not ch: + return None + + coeffh[mon] = ch + + h[monom] = coeffh + + return h + + +def _gf_gcdex(f, g, p): + r""" + Extended Euclidean Algorithm for two univariate polynomials over + `\mathbb Z_p`. + + Returns polynomials `s, t` and `h`, such that `h` is the GCD of `f` and + `g` and `sf + tg = h \; \mathrm{mod} \, p`. + + """ + ring = f.ring + s, t, h = gf_gcdex(f.to_dense(), g.to_dense(), p, ring.domain) + return ring.from_dense(s), ring.from_dense(t), ring.from_dense(h) + + +def _trunc(f, minpoly, p): + r""" + Compute the reduced representation of a polynomial `f` in + `\mathbb Z_p[z] / (\check m_{\alpha}(z))[x]` + + Parameters + ========== + + f : PolyElement + polynomial in `\mathbb Z[x, z]` + minpoly : PolyElement + polynomial `\check m_{\alpha} \in \mathbb Z[z]`, not necessarily + irreducible + p : Integer + prime number, modulus of `\mathbb Z_p` + + Returns + ======= + + ftrunc : PolyElement + polynomial in `\mathbb Z[x, z]`, reduced modulo + `\check m_{\alpha}(z)` and `p` + + """ + ring = f.ring + minpoly = minpoly.set_ring(ring) + p_ = ring.ground_new(p) + + return f.trunc_ground(p).rem([minpoly, p_]).trunc_ground(p) + + +def _euclidean_algorithm(f, g, minpoly, p): + r""" + Compute the monic GCD of two univariate polynomials in + `\mathbb{Z}_p[z]/(\check m_{\alpha}(z))[x]` with the Euclidean + Algorithm. + + In general, `\check m_{\alpha}(z)` is not irreducible, so it is possible + that some leading coefficient is not invertible modulo + `\check m_{\alpha}(z)`. In that case ``None`` is returned. + + Parameters + ========== + + f, g : PolyElement + polynomials in `\mathbb Z[x, z]` + minpoly : PolyElement + polynomial in `\mathbb Z[z]`, not necessarily irreducible + p : Integer + prime number, modulus of `\mathbb Z_p` + + Returns + ======= + + h : PolyElement + GCD of `f` and `g` in `\mathbb Z[z, x]` or ``None``, coefficients + are in `\left[ -\frac{p-1} 2, \frac{p-1} 2 \right]` + + """ + ring = f.ring + + f = _trunc(f, minpoly, p) + g = _trunc(g, minpoly, p) + + while g: + rem = f + deg = g.degree(0) # degree in x + lcinv, _, gcd = _gf_gcdex(ring.dmp_LC(g), minpoly, p) + + if not gcd == 1: + return None + + while True: + degrem = rem.degree(0) # degree in x + if degrem < deg: + break + quo = (lcinv * ring.dmp_LC(rem)).set_ring(ring) + rem = _trunc(rem - g.mul_monom((degrem - deg, 0))*quo, minpoly, p) + + f = g + g = rem + + lcfinv = _gf_gcdex(ring.dmp_LC(f), minpoly, p)[0].set_ring(ring) + + return _trunc(f * lcfinv, minpoly, p) + + +def _trial_division(f, h, minpoly, p=None): + r""" + Check if `h` divides `f` in + `\mathbb K[t_1, \ldots, t_k][z]/(m_{\alpha}(z))`, where `\mathbb K` is + either `\mathbb Q` or `\mathbb Z_p`. + + This algorithm is based on pseudo division and does not use any + fractions. By default `\mathbb K` is `\mathbb Q`, if a prime number `p` + is given, `\mathbb Z_p` is chosen instead. + + Parameters + ========== + + f, h : PolyElement + polynomials in `\mathbb Z[t_1, \ldots, t_k][x, z]` + minpoly : PolyElement + polynomial `m_{\alpha}(z)` in `\mathbb Z[t_1, \ldots, t_k][z]` + p : Integer or None + if `p` is given, `\mathbb K` is set to `\mathbb Z_p` instead of + `\mathbb Q`, default is ``None`` + + Returns + ======= + + rem : PolyElement + remainder of `\frac f h` + + References + ========== + + .. [1] [Hoeij02]_ + + """ + ring = f.ring + + zxring = ring.clone(symbols=(ring.symbols[1], ring.symbols[0])) + + minpoly = minpoly.set_ring(ring) + + rem = f + + degrem = rem.degree() + degh = h.degree() + degm = minpoly.degree(1) + + lch = _LC(h).set_ring(ring) + lcm = minpoly.LC + + while rem and degrem >= degh: + # polynomial in Z[t_1, ..., t_k][z] + lcrem = _LC(rem).set_ring(ring) + rem = rem*lch - h.mul_monom((degrem - degh, 0))*lcrem + if p: + rem = rem.trunc_ground(p) + degrem = rem.degree(1) + + while rem and degrem >= degm: + # polynomial in Z[t_1, ..., t_k][x] + lcrem = _LC(rem.set_ring(zxring)).set_ring(ring) + rem = rem.mul_ground(lcm) - minpoly.mul_monom((0, degrem - degm))*lcrem + if p: + rem = rem.trunc_ground(p) + degrem = rem.degree(1) + + degrem = rem.degree() + + return rem + + +def _evaluate_ground(f, i, a): + r""" + Evaluate a polynomial `f` at `a` in the `i`-th variable of the ground + domain. + """ + ring = f.ring.clone(domain=f.ring.domain.ring.drop(i)) + fa = ring.zero + + for monom, coeff in f.iterterms(): + fa[monom] = coeff.evaluate(i, a) + + return fa + + +def _func_field_modgcd_p(f, g, minpoly, p): + r""" + Compute the GCD of two polynomials `f` and `g` in + `\mathbb Z_p(t_1, \ldots, t_k)[z]/(\check m_\alpha(z))[x]`. + + The algorithm reduces the problem step by step by evaluating the + polynomials `f` and `g` at `t_k = a` for suitable `a \in \mathbb Z_p` + and then calls itself recursively to compute the GCD in + `\mathbb Z_p(t_1, \ldots, t_{k-1})[z]/(\check m_\alpha(z))[x]`. If these + recursive calls are successful, the GCD over `k` variables is + interpolated, otherwise the algorithm returns ``None``. After + interpolation, Rational Function Reconstruction is used to obtain the + correct coefficients. If this fails, a new evaluation point has to be + chosen, otherwise the desired polynomial is obtained by clearing + denominators. The result is verified with a fraction free trial + division. + + Parameters + ========== + + f, g : PolyElement + polynomials in `\mathbb Z[t_1, \ldots, t_k][x, z]` + minpoly : PolyElement + polynomial in `\mathbb Z[t_1, \ldots, t_k][z]`, not necessarily + irreducible + p : Integer + prime number, modulus of `\mathbb Z_p` + + Returns + ======= + + h : PolyElement + primitive associate in `\mathbb Z[t_1, \ldots, t_k][x, z]` of the + GCD of the polynomials `f` and `g` or ``None``, coefficients are + in `\left[ -\frac{p-1} 2, \frac{p-1} 2 \right]` + + References + ========== + + 1. [Hoeij04]_ + + """ + ring = f.ring + domain = ring.domain # Z[t_1, ..., t_k] + + if isinstance(domain, PolynomialRing): + k = domain.ngens + else: + return _euclidean_algorithm(f, g, minpoly, p) + + if k == 1: + qdomain = domain.ring.to_field() + else: + qdomain = domain.ring.drop_to_ground(k - 1) + qdomain = qdomain.clone(domain=qdomain.domain.ring.to_field()) + + qring = ring.clone(domain=qdomain) # = Z(t_k)[t_1, ..., t_{k-1}][x, z] + + n = 1 + d = 1 + + # polynomial in Z_p[t_1, ..., t_k][z] + gamma = ring.dmp_LC(f) * ring.dmp_LC(g) + # polynomial in Z_p[t_1, ..., t_k] + delta = minpoly.LC + + evalpoints = [] + heval = [] + LMlist = [] + points = list(range(p)) + + while points: + a = random.sample(points, 1)[0] + points.remove(a) + + if k == 1: + test = delta.evaluate(k-1, a) % p == 0 + else: + test = delta.evaluate(k-1, a).trunc_ground(p) == 0 + + if test: + continue + + gammaa = _evaluate_ground(gamma, k-1, a) + minpolya = _evaluate_ground(minpoly, k-1, a) + + if gammaa.rem([minpolya, gammaa.ring(p)]) == 0: + continue + + fa = _evaluate_ground(f, k-1, a) + ga = _evaluate_ground(g, k-1, a) + + # polynomial in Z_p[x, t_1, ..., t_{k-1}, z]/(minpoly) + ha = _func_field_modgcd_p(fa, ga, minpolya, p) + + if ha is None: + d += 1 + if d > n: + return None + continue + + if ha == 1: + return ha + + LM = [ha.degree()] + [0]*(k-1) + if k > 1: + for monom, coeff in ha.iterterms(): + if monom[0] == LM[0] and coeff.LM > tuple(LM[1:]): + LM[1:] = coeff.LM + + evalpoints_a = [a] + heval_a = [ha] + if k == 1: + m = qring.domain.get_ring().one + else: + m = qring.domain.domain.get_ring().one + + t = m.ring.gens[0] + + for b, hb, LMhb in zip(evalpoints, heval, LMlist): + if LMhb == LM: + evalpoints_a.append(b) + heval_a.append(hb) + m *= (t - b) + + m = m.trunc_ground(p) + evalpoints.append(a) + heval.append(ha) + LMlist.append(LM) + n += 1 + + # polynomial in Z_p[t_1, ..., t_k][x, z] + h = _interpolate_multivariate(evalpoints_a, heval_a, ring, k-1, p, ground=True) + + # polynomial in Z_p(t_k)[t_1, ..., t_{k-1}][x, z] + h = _rational_reconstruction_func_coeffs(h, p, m, qring, k-1) + + if h is None: + continue + + if k == 1: + dom = qring.domain.field + den = dom.ring.one + + for coeff in h.itercoeffs(): + den = dom.ring.from_dense(gf_lcm(den.to_dense(), coeff.denom.to_dense(), + p, dom.domain)) + + else: + dom = qring.domain.domain.field + den = dom.ring.one + + for coeff in h.itercoeffs(): + for c in coeff.itercoeffs(): + den = dom.ring.from_dense(gf_lcm(den.to_dense(), c.denom.to_dense(), + p, dom.domain)) + + den = qring.domain_new(den.trunc_ground(p)) + h = ring(h.mul_ground(den).as_expr()).trunc_ground(p) + + if not _trial_division(f, h, minpoly, p) and not _trial_division(g, h, minpoly, p): + return h + + return None + + +def _integer_rational_reconstruction(c, m, domain): + r""" + Reconstruct a rational number `\frac a b` from + + .. math:: + + c = \frac a b \; \mathrm{mod} \, m, + + where `c` and `m` are integers. + + The algorithm is based on the Euclidean Algorithm. In general, `m` is + not a prime number, so it is possible that `b` is not invertible modulo + `m`. In that case ``None`` is returned. + + Parameters + ========== + + c : Integer + `c = \frac a b \; \mathrm{mod} \, m` + m : Integer + modulus, not necessarily prime + domain : IntegerRing + `a, b, c` are elements of ``domain`` + + Returns + ======= + + frac : Rational + either `\frac a b` in `\mathbb Q` or ``None`` + + References + ========== + + 1. [Wang81]_ + + """ + if c < 0: + c += m + + r0, s0 = m, domain.zero + r1, s1 = c, domain.one + + bound = sqrt(m / 2) # still correct if replaced by ZZ.sqrt(m // 2) ? + + while int(r1) >= bound: + quo = r0 // r1 + r0, r1 = r1, r0 - quo*r1 + s0, s1 = s1, s0 - quo*s1 + + if abs(int(s1)) >= bound: + return None + + if s1 < 0: + a, b = -r1, -s1 + elif s1 > 0: + a, b = r1, s1 + else: + return None + + field = domain.get_field() + + return field(a) / field(b) + + +def _rational_reconstruction_int_coeffs(hm, m, ring): + r""" + Reconstruct every rational coefficient `c_h` of a polynomial `h` in + `\mathbb Q[t_1, \ldots, t_k][x, z]` from the corresponding integer + coefficient `c_{h_m}` of a polynomial `h_m` in + `\mathbb Z[t_1, \ldots, t_k][x, z]` such that + + .. math:: + + c_{h_m} = c_h \; \mathrm{mod} \, m, + + where `m \in \mathbb Z`. + + The reconstruction is based on the Euclidean Algorithm. In general, + `m` is not a prime number, so it is possible that this fails for some + coefficient. In that case ``None`` is returned. + + Parameters + ========== + + hm : PolyElement + polynomial in `\mathbb Z[t_1, \ldots, t_k][x, z]` + m : Integer + modulus, not necessarily prime + ring : PolyRing + `\mathbb Q[t_1, \ldots, t_k][x, z]`, `h` will be an element of this + ring + + Returns + ======= + + h : PolyElement + reconstructed polynomial in `\mathbb Q[t_1, \ldots, t_k][x, z]` or + ``None`` + + See also + ======== + + _integer_rational_reconstruction + + """ + h = ring.zero + + if isinstance(ring.domain, PolynomialRing): + reconstruction = _rational_reconstruction_int_coeffs + domain = ring.domain.ring + else: + reconstruction = _integer_rational_reconstruction + domain = hm.ring.domain + + for monom, coeff in hm.iterterms(): + coeffh = reconstruction(coeff, m, domain) + + if not coeffh: + return None + + h[monom] = coeffh + + return h + + +def _func_field_modgcd_m(f, g, minpoly): + r""" + Compute the GCD of two polynomials in + `\mathbb Q(t_1, \ldots, t_k)[z]/(m_{\alpha}(z))[x]` using a modular + algorithm. + + The algorithm computes the GCD of two polynomials `f` and `g` by + calculating the GCD in + `\mathbb Z_p(t_1, \ldots, t_k)[z] / (\check m_{\alpha}(z))[x]` for + suitable primes `p` and the primitive associate `\check m_{\alpha}(z)` + of `m_{\alpha}(z)`. Then the coefficients are reconstructed with the + Chinese Remainder Theorem and Rational Reconstruction. To compute the + GCD over `\mathbb Z_p(t_1, \ldots, t_k)[z] / (\check m_{\alpha})[x]`, + the recursive subroutine ``_func_field_modgcd_p`` is used. To verify the + result in `\mathbb Q(t_1, \ldots, t_k)[z] / (m_{\alpha}(z))[x]`, a + fraction free trial division is used. + + Parameters + ========== + + f, g : PolyElement + polynomials in `\mathbb Z[t_1, \ldots, t_k][x, z]` + minpoly : PolyElement + irreducible polynomial in `\mathbb Z[t_1, \ldots, t_k][z]` + + Returns + ======= + + h : PolyElement + the primitive associate in `\mathbb Z[t_1, \ldots, t_k][x, z]` of + the GCD of `f` and `g` + + Examples + ======== + + >>> from sympy.polys.modulargcd import _func_field_modgcd_m + >>> from sympy.polys import ring, ZZ + + >>> R, x, z = ring('x, z', ZZ) + >>> minpoly = (z**2 - 2).drop(0) + + >>> f = x**2 + 2*x*z + 2 + >>> g = x + z + >>> _func_field_modgcd_m(f, g, minpoly) + x + z + + >>> D, t = ring('t', ZZ) + >>> R, x, z = ring('x, z', D) + >>> minpoly = (z**2-3).drop(0) + + >>> f = x**2 + (t + 1)*x*z + 3*t + >>> g = x*z + 3*t + >>> _func_field_modgcd_m(f, g, minpoly) + x + t*z + + References + ========== + + 1. [Hoeij04]_ + + See also + ======== + + _func_field_modgcd_p + + """ + ring = f.ring + domain = ring.domain + + if isinstance(domain, PolynomialRing): + k = domain.ngens + QQdomain = domain.ring.clone(domain=domain.domain.get_field()) + QQring = ring.clone(domain=QQdomain) + else: + k = 0 + QQring = ring.clone(domain=ring.domain.get_field()) + + cf, f = f.primitive() + cg, g = g.primitive() + + # polynomial in Z[t_1, ..., t_k][z] + gamma = ring.dmp_LC(f) * ring.dmp_LC(g) + # polynomial in Z[t_1, ..., t_k] + delta = minpoly.LC + + p = 1 + primes = [] + hplist = [] + LMlist = [] + + while True: + p = nextprime(p) + + if gamma.trunc_ground(p) == 0: + continue + + if k == 0: + test = (delta % p == 0) + else: + test = (delta.trunc_ground(p) == 0) + + if test: + continue + + fp = f.trunc_ground(p) + gp = g.trunc_ground(p) + minpolyp = minpoly.trunc_ground(p) + + hp = _func_field_modgcd_p(fp, gp, minpolyp, p) + + if hp is None: + continue + + if hp == 1: + return ring.one + + LM = [hp.degree()] + [0]*k + if k > 0: + for monom, coeff in hp.iterterms(): + if monom[0] == LM[0] and coeff.LM > tuple(LM[1:]): + LM[1:] = coeff.LM + + hm = hp + m = p + + for q, hq, LMhq in zip(primes, hplist, LMlist): + if LMhq == LM: + hm = _chinese_remainder_reconstruction_multivariate(hq, hm, q, m) + m *= q + + primes.append(p) + hplist.append(hp) + LMlist.append(LM) + + hm = _rational_reconstruction_int_coeffs(hm, m, QQring) + + if hm is None: + continue + + if k == 0: + h = hm.clear_denoms()[1] + else: + den = domain.domain.one + for coeff in hm.itercoeffs(): + den = domain.domain.lcm(den, coeff.clear_denoms()[0]) + h = hm.mul_ground(den) + + # convert back to Z[t_1, ..., t_k][x, z] from Q[t_1, ..., t_k][x, z] + h = h.set_ring(ring) + h = h.primitive()[1] + + if not (_trial_division(f.mul_ground(cf), h, minpoly) or + _trial_division(g.mul_ground(cg), h, minpoly)): + return h + + +def _to_ZZ_poly(f, ring): + r""" + Compute an associate of a polynomial + `f \in \mathbb Q(\alpha)[x_0, \ldots, x_{n-1}]` in + `\mathbb Z[x_1, \ldots, x_{n-1}][z] / (\check m_{\alpha}(z))[x_0]`, + where `\check m_{\alpha}(z) \in \mathbb Z[z]` is the primitive associate + of the minimal polynomial `m_{\alpha}(z)` of `\alpha` over + `\mathbb Q`. + + Parameters + ========== + + f : PolyElement + polynomial in `\mathbb Q(\alpha)[x_0, \ldots, x_{n-1}]` + ring : PolyRing + `\mathbb Z[x_1, \ldots, x_{n-1}][x_0, z]` + + Returns + ======= + + f_ : PolyElement + associate of `f` in + `\mathbb Z[x_1, \ldots, x_{n-1}][x_0, z]` + + """ + f_ = ring.zero + + if isinstance(ring.domain, PolynomialRing): + domain = ring.domain.domain + else: + domain = ring.domain + + den = domain.one + + for coeff in f.itercoeffs(): + for c in coeff.to_list(): + if c: + den = domain.lcm(den, c.denominator) + + for monom, coeff in f.iterterms(): + coeff = coeff.to_list() + m = ring.domain.one + if isinstance(ring.domain, PolynomialRing): + m = m.mul_monom(monom[1:]) + n = len(coeff) + + for i in range(n): + if coeff[i]: + c = domain.convert(coeff[i] * den) * m + + if (monom[0], n-i-1) not in f_: + f_[(monom[0], n-i-1)] = c + else: + f_[(monom[0], n-i-1)] += c + + return f_ + + +def _to_ANP_poly(f, ring): + r""" + Convert a polynomial + `f \in \mathbb Z[x_1, \ldots, x_{n-1}][z]/(\check m_{\alpha}(z))[x_0]` + to a polynomial in `\mathbb Q(\alpha)[x_0, \ldots, x_{n-1}]`, + where `\check m_{\alpha}(z) \in \mathbb Z[z]` is the primitive associate + of the minimal polynomial `m_{\alpha}(z)` of `\alpha` over + `\mathbb Q`. + + Parameters + ========== + + f : PolyElement + polynomial in `\mathbb Z[x_1, \ldots, x_{n-1}][x_0, z]` + ring : PolyRing + `\mathbb Q(\alpha)[x_0, \ldots, x_{n-1}]` + + Returns + ======= + + f_ : PolyElement + polynomial in `\mathbb Q(\alpha)[x_0, \ldots, x_{n-1}]` + + """ + domain = ring.domain + f_ = ring.zero + + if isinstance(f.ring.domain, PolynomialRing): + for monom, coeff in f.iterterms(): + for mon, coef in coeff.iterterms(): + m = (monom[0],) + mon + c = domain([domain.domain(coef)] + [0]*monom[1]) + + if m not in f_: + f_[m] = c + else: + f_[m] += c + + else: + for monom, coeff in f.iterterms(): + m = (monom[0],) + c = domain([domain.domain(coeff)] + [0]*monom[1]) + + if m not in f_: + f_[m] = c + else: + f_[m] += c + + return f_ + + +def _minpoly_from_dense(minpoly, ring): + r""" + Change representation of the minimal polynomial from ``DMP`` to + ``PolyElement`` for a given ring. + """ + minpoly_ = ring.zero + + for monom, coeff in minpoly.terms(): + minpoly_[monom] = ring.domain(coeff) + + return minpoly_ + + +def _primitive_in_x0(f): + r""" + Compute the content in `x_0` and the primitive part of a polynomial `f` + in + `\mathbb Q(\alpha)[x_0, x_1, \ldots, x_{n-1}] \cong \mathbb Q(\alpha)[x_1, \ldots, x_{n-1}][x_0]`. + """ + fring = f.ring + ring = fring.drop_to_ground(*range(1, fring.ngens)) + dom = ring.domain.ring + f_ = ring(f.as_expr()) + cont = dom.zero + + for coeff in f_.itercoeffs(): + cont = func_field_modgcd(cont, coeff)[0] + if cont == dom.one: + return cont, f + + return cont, f.quo(cont.set_ring(fring)) + + +# TODO: add support for algebraic function fields +def func_field_modgcd(f, g): + r""" + Compute the GCD of two polynomials `f` and `g` in + `\mathbb Q(\alpha)[x_0, \ldots, x_{n-1}]` using a modular algorithm. + + The algorithm first computes the primitive associate + `\check m_{\alpha}(z)` of the minimal polynomial `m_{\alpha}` in + `\mathbb{Z}[z]` and the primitive associates of `f` and `g` in + `\mathbb{Z}[x_1, \ldots, x_{n-1}][z]/(\check m_{\alpha})[x_0]`. Then it + computes the GCD in + `\mathbb Q(x_1, \ldots, x_{n-1})[z]/(m_{\alpha}(z))[x_0]`. + This is done by calculating the GCD in + `\mathbb{Z}_p(x_1, \ldots, x_{n-1})[z]/(\check m_{\alpha}(z))[x_0]` for + suitable primes `p` and then reconstructing the coefficients with the + Chinese Remainder Theorem and Rational Reconstruction. The GCD over + `\mathbb{Z}_p(x_1, \ldots, x_{n-1})[z]/(\check m_{\alpha}(z))[x_0]` is + computed with a recursive subroutine, which evaluates the polynomials at + `x_{n-1} = a` for suitable evaluation points `a \in \mathbb Z_p` and + then calls itself recursively until the ground domain does no longer + contain any parameters. For + `\mathbb{Z}_p[z]/(\check m_{\alpha}(z))[x_0]` the Euclidean Algorithm is + used. The results of those recursive calls are then interpolated and + Rational Function Reconstruction is used to obtain the correct + coefficients. The results, both in + `\mathbb Q(x_1, \ldots, x_{n-1})[z]/(m_{\alpha}(z))[x_0]` and + `\mathbb{Z}_p(x_1, \ldots, x_{n-1})[z]/(\check m_{\alpha}(z))[x_0]`, are + verified by a fraction free trial division. + + Apart from the above GCD computation some GCDs in + `\mathbb Q(\alpha)[x_1, \ldots, x_{n-1}]` have to be calculated, + because treating the polynomials as univariate ones can result in + a spurious content of the GCD. For this ``func_field_modgcd`` is + called recursively. + + Parameters + ========== + + f, g : PolyElement + polynomials in `\mathbb Q(\alpha)[x_0, \ldots, x_{n-1}]` + + Returns + ======= + + h : PolyElement + monic GCD of the polynomials `f` and `g` + cff : PolyElement + cofactor of `f`, i.e. `\frac f h` + cfg : PolyElement + cofactor of `g`, i.e. `\frac g h` + + Examples + ======== + + >>> from sympy.polys.modulargcd import func_field_modgcd + >>> from sympy.polys import AlgebraicField, QQ, ring + >>> from sympy import sqrt + + >>> A = AlgebraicField(QQ, sqrt(2)) + >>> R, x = ring('x', A) + + >>> f = x**2 - 2 + >>> g = x + sqrt(2) + + >>> h, cff, cfg = func_field_modgcd(f, g) + + >>> h == x + sqrt(2) + True + >>> cff * h == f + True + >>> cfg * h == g + True + + >>> R, x, y = ring('x, y', A) + + >>> f = x**2 + 2*sqrt(2)*x*y + 2*y**2 + >>> g = x + sqrt(2)*y + + >>> h, cff, cfg = func_field_modgcd(f, g) + + >>> h == x + sqrt(2)*y + True + >>> cff * h == f + True + >>> cfg * h == g + True + + >>> f = x + sqrt(2)*y + >>> g = x + y + + >>> h, cff, cfg = func_field_modgcd(f, g) + + >>> h == R.one + True + >>> cff * h == f + True + >>> cfg * h == g + True + + References + ========== + + 1. [Hoeij04]_ + + """ + ring = f.ring + domain = ring.domain + n = ring.ngens + + assert ring == g.ring and domain.is_Algebraic + + result = _trivial_gcd(f, g) + if result is not None: + return result + + z = Dummy('z') + + ZZring = ring.clone(symbols=ring.symbols + (z,), domain=domain.domain.get_ring()) + + if n == 1: + f_ = _to_ZZ_poly(f, ZZring) + g_ = _to_ZZ_poly(g, ZZring) + minpoly = ZZring.drop(0).from_dense(domain.mod.to_list()) + + h = _func_field_modgcd_m(f_, g_, minpoly) + h = _to_ANP_poly(h, ring) + + else: + # contx0f in Q(a)[x_1, ..., x_{n-1}], f in Q(a)[x_0, ..., x_{n-1}] + contx0f, f = _primitive_in_x0(f) + contx0g, g = _primitive_in_x0(g) + contx0h = func_field_modgcd(contx0f, contx0g)[0] + + ZZring_ = ZZring.drop_to_ground(*range(1, n)) + + f_ = _to_ZZ_poly(f, ZZring_) + g_ = _to_ZZ_poly(g, ZZring_) + minpoly = _minpoly_from_dense(domain.mod, ZZring_.drop(0)) + + h = _func_field_modgcd_m(f_, g_, minpoly) + h = _to_ANP_poly(h, ring) + + contx0h_, h = _primitive_in_x0(h) + h *= contx0h.set_ring(ring) + f *= contx0f.set_ring(ring) + g *= contx0g.set_ring(ring) + + h = h.quo_ground(h.LC) + + return h, f.quo(h), g.quo(h) diff --git a/phivenv/Lib/site-packages/sympy/polys/monomials.py b/phivenv/Lib/site-packages/sympy/polys/monomials.py new file mode 100644 index 0000000000000000000000000000000000000000..43a17223861f656b8a4a51fb4c0c934635ed4623 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/monomials.py @@ -0,0 +1,628 @@ +"""Tools and arithmetics for monomials of distributed polynomials. """ + + +from itertools import combinations_with_replacement, product +from textwrap import dedent + +from sympy.core.cache import cacheit +from sympy.core import Mul, S, Tuple, sympify +from sympy.polys.polyerrors import ExactQuotientFailed +from sympy.polys.polyutils import PicklableWithSlots, dict_from_expr +from sympy.utilities import public +from sympy.utilities.iterables import is_sequence, iterable + +@public +def itermonomials(variables, max_degrees, min_degrees=None): + r""" + ``max_degrees`` and ``min_degrees`` are either both integers or both lists. + Unless otherwise specified, ``min_degrees`` is either ``0`` or + ``[0, ..., 0]``. + + A generator of all monomials ``monom`` is returned, such that + either + ``min_degree <= total_degree(monom) <= max_degree``, + or + ``min_degrees[i] <= degree_list(monom)[i] <= max_degrees[i]``, + for all ``i``. + + Case I. ``max_degrees`` and ``min_degrees`` are both integers + ============================================================= + + Given a set of variables $V$ and a min_degree $N$ and a max_degree $M$ + generate a set of monomials of degree less than or equal to $N$ and greater + than or equal to $M$. The total number of monomials in commutative + variables is huge and is given by the following formula if $M = 0$: + + .. math:: + \frac{(\#V + N)!}{\#V! N!} + + For example if we would like to generate a dense polynomial of + a total degree $N = 50$ and $M = 0$, which is the worst case, in 5 + variables, assuming that exponents and all of coefficients are 32-bit long + and stored in an array we would need almost 80 GiB of memory! Fortunately + most polynomials, that we will encounter, are sparse. + + Consider monomials in commutative variables $x$ and $y$ + and non-commutative variables $a$ and $b$:: + + >>> from sympy import symbols + >>> from sympy.polys.monomials import itermonomials + >>> from sympy.polys.orderings import monomial_key + >>> from sympy.abc import x, y + + >>> sorted(itermonomials([x, y], 2), key=monomial_key('grlex', [y, x])) + [1, x, y, x**2, x*y, y**2] + + >>> sorted(itermonomials([x, y], 3), key=monomial_key('grlex', [y, x])) + [1, x, y, x**2, x*y, y**2, x**3, x**2*y, x*y**2, y**3] + + >>> a, b = symbols('a, b', commutative=False) + >>> set(itermonomials([a, b, x], 2)) + {1, a, a**2, b, b**2, x, x**2, a*b, b*a, x*a, x*b} + + >>> sorted(itermonomials([x, y], 2, 1), key=monomial_key('grlex', [y, x])) + [x, y, x**2, x*y, y**2] + + Case II. ``max_degrees`` and ``min_degrees`` are both lists + =========================================================== + + If ``max_degrees = [d_1, ..., d_n]`` and + ``min_degrees = [e_1, ..., e_n]``, the number of monomials generated + is: + + .. math:: + (d_1 - e_1 + 1) (d_2 - e_2 + 1) \cdots (d_n - e_n + 1) + + Let us generate all monomials ``monom`` in variables $x$ and $y$ + such that ``[1, 2][i] <= degree_list(monom)[i] <= [2, 4][i]``, + ``i = 0, 1`` :: + + >>> from sympy import symbols + >>> from sympy.polys.monomials import itermonomials + >>> from sympy.polys.orderings import monomial_key + >>> from sympy.abc import x, y + + >>> sorted(itermonomials([x, y], [2, 4], [1, 2]), reverse=True, key=monomial_key('lex', [x, y])) + [x**2*y**4, x**2*y**3, x**2*y**2, x*y**4, x*y**3, x*y**2] + """ + if is_sequence(max_degrees): + n = len(variables) + if len(max_degrees) != n: + raise ValueError('Argument sizes do not match') + if min_degrees is None: + min_degrees = [0]*n + elif not is_sequence(min_degrees): + raise ValueError('min_degrees is not a list') + else: + if len(min_degrees) != n: + raise ValueError('Argument sizes do not match') + if any(i < 0 for i in min_degrees): + raise ValueError("min_degrees cannot contain negative numbers") + if any(min_degrees[i] > max_degrees[i] for i in range(n)): + raise ValueError('min_degrees[i] must be <= max_degrees[i] for all i') + power_lists = [] + for var, min_d, max_d in zip(variables, min_degrees, max_degrees): + power_lists.append([var**i for i in range(min_d, max_d + 1)]) + for powers in product(*power_lists): + yield Mul(*powers) + else: + max_degree = max_degrees + if max_degree < 0: + raise ValueError("max_degrees cannot be negative") + if min_degrees is None: + min_degree = 0 + else: + if min_degrees < 0: + raise ValueError("min_degrees cannot be negative") + min_degree = min_degrees + if min_degree > max_degree: + return + if not variables or max_degree == 0: + yield S.One + return + # Force to list in case of passed tuple or other incompatible collection + variables = list(variables) + [S.One] + if all(variable.is_commutative for variable in variables): + it = combinations_with_replacement(variables, max_degree) + else: + it = product(variables, repeat=max_degree) + monomials_set = set() + d = max_degree - min_degree + for item in it: + count = 0 + for variable in item: + if variable == 1: + count += 1 + if d < count: + break + else: + monomials_set.add(Mul(*item)) + yield from monomials_set + +def monomial_count(V, N): + r""" + Computes the number of monomials. + + The number of monomials is given by the following formula: + + .. math:: + + \frac{(\#V + N)!}{\#V! N!} + + where `N` is a total degree and `V` is a set of variables. + + Examples + ======== + + >>> from sympy.polys.monomials import itermonomials, monomial_count + >>> from sympy.polys.orderings import monomial_key + >>> from sympy.abc import x, y + + >>> monomial_count(2, 2) + 6 + + >>> M = list(itermonomials([x, y], 2)) + + >>> sorted(M, key=monomial_key('grlex', [y, x])) + [1, x, y, x**2, x*y, y**2] + >>> len(M) + 6 + + """ + from sympy.functions.combinatorial.factorials import factorial + return factorial(V + N) / factorial(V) / factorial(N) + +def monomial_mul(A, B): + """ + Multiplication of tuples representing monomials. + + Examples + ======== + + Lets multiply `x**3*y**4*z` with `x*y**2`:: + + >>> from sympy.polys.monomials import monomial_mul + + >>> monomial_mul((3, 4, 1), (1, 2, 0)) + (4, 6, 1) + + which gives `x**4*y**5*z`. + + """ + return tuple([ a + b for a, b in zip(A, B) ]) + +def monomial_div(A, B): + """ + Division of tuples representing monomials. + + Examples + ======== + + Lets divide `x**3*y**4*z` by `x*y**2`:: + + >>> from sympy.polys.monomials import monomial_div + + >>> monomial_div((3, 4, 1), (1, 2, 0)) + (2, 2, 1) + + which gives `x**2*y**2*z`. However:: + + >>> monomial_div((3, 4, 1), (1, 2, 2)) is None + True + + `x*y**2*z**2` does not divide `x**3*y**4*z`. + + """ + C = monomial_ldiv(A, B) + + if all(c >= 0 for c in C): + return tuple(C) + else: + return None + +def monomial_ldiv(A, B): + """ + Division of tuples representing monomials. + + Examples + ======== + + Lets divide `x**3*y**4*z` by `x*y**2`:: + + >>> from sympy.polys.monomials import monomial_ldiv + + >>> monomial_ldiv((3, 4, 1), (1, 2, 0)) + (2, 2, 1) + + which gives `x**2*y**2*z`. + + >>> monomial_ldiv((3, 4, 1), (1, 2, 2)) + (2, 2, -1) + + which gives `x**2*y**2*z**-1`. + + """ + return tuple([ a - b for a, b in zip(A, B) ]) + +def monomial_pow(A, n): + """Return the n-th pow of the monomial. """ + return tuple([ a*n for a in A ]) + +def monomial_gcd(A, B): + """ + Greatest common divisor of tuples representing monomials. + + Examples + ======== + + Lets compute GCD of `x*y**4*z` and `x**3*y**2`:: + + >>> from sympy.polys.monomials import monomial_gcd + + >>> monomial_gcd((1, 4, 1), (3, 2, 0)) + (1, 2, 0) + + which gives `x*y**2`. + + """ + return tuple([ min(a, b) for a, b in zip(A, B) ]) + +def monomial_lcm(A, B): + """ + Least common multiple of tuples representing monomials. + + Examples + ======== + + Lets compute LCM of `x*y**4*z` and `x**3*y**2`:: + + >>> from sympy.polys.monomials import monomial_lcm + + >>> monomial_lcm((1, 4, 1), (3, 2, 0)) + (3, 4, 1) + + which gives `x**3*y**4*z`. + + """ + return tuple([ max(a, b) for a, b in zip(A, B) ]) + +def monomial_divides(A, B): + """ + Does there exist a monomial X such that XA == B? + + Examples + ======== + + >>> from sympy.polys.monomials import monomial_divides + >>> monomial_divides((1, 2), (3, 4)) + True + >>> monomial_divides((1, 2), (0, 2)) + False + """ + return all(a <= b for a, b in zip(A, B)) + +def monomial_max(*monoms): + """ + Returns maximal degree for each variable in a set of monomials. + + Examples + ======== + + Consider monomials `x**3*y**4*z**5`, `y**5*z` and `x**6*y**3*z**9`. + We wish to find out what is the maximal degree for each of `x`, `y` + and `z` variables:: + + >>> from sympy.polys.monomials import monomial_max + + >>> monomial_max((3,4,5), (0,5,1), (6,3,9)) + (6, 5, 9) + + """ + M = list(monoms[0]) + + for N in monoms[1:]: + for i, n in enumerate(N): + M[i] = max(M[i], n) + + return tuple(M) + +def monomial_min(*monoms): + """ + Returns minimal degree for each variable in a set of monomials. + + Examples + ======== + + Consider monomials `x**3*y**4*z**5`, `y**5*z` and `x**6*y**3*z**9`. + We wish to find out what is the minimal degree for each of `x`, `y` + and `z` variables:: + + >>> from sympy.polys.monomials import monomial_min + + >>> monomial_min((3,4,5), (0,5,1), (6,3,9)) + (0, 3, 1) + + """ + M = list(monoms[0]) + + for N in monoms[1:]: + for i, n in enumerate(N): + M[i] = min(M[i], n) + + return tuple(M) + +def monomial_deg(M): + """ + Returns the total degree of a monomial. + + Examples + ======== + + The total degree of `xy^2` is 3: + + >>> from sympy.polys.monomials import monomial_deg + >>> monomial_deg((1, 2)) + 3 + """ + return sum(M) + +def term_div(a, b, domain): + """Division of two terms in over a ring/field. """ + a_lm, a_lc = a + b_lm, b_lc = b + + monom = monomial_div(a_lm, b_lm) + + if domain.is_Field: + if monom is not None: + return monom, domain.quo(a_lc, b_lc) + else: + return None + else: + if not (monom is None or a_lc % b_lc): + return monom, domain.quo(a_lc, b_lc) + else: + return None + +class MonomialOps: + """Code generator of fast monomial arithmetic functions. """ + + @cacheit + def __new__(cls, ngens): + obj = super().__new__(cls) + obj.ngens = ngens + return obj + + def __getnewargs__(self): + return (self.ngens,) + + def _build(self, code, name): + ns = {} + exec(code, ns) + return ns[name] + + def _vars(self, name): + return [ "%s%s" % (name, i) for i in range(self.ngens) ] + + @cacheit + def mul(self): + name = "monomial_mul" + template = dedent("""\ + def %(name)s(A, B): + (%(A)s,) = A + (%(B)s,) = B + return (%(AB)s,) + """) + A = self._vars("a") + B = self._vars("b") + AB = [ "%s + %s" % (a, b) for a, b in zip(A, B) ] + code = template % {"name": name, "A": ", ".join(A), "B": ", ".join(B), "AB": ", ".join(AB)} + return self._build(code, name) + + @cacheit + def pow(self): + name = "monomial_pow" + template = dedent("""\ + def %(name)s(A, k): + (%(A)s,) = A + return (%(Ak)s,) + """) + A = self._vars("a") + Ak = [ "%s*k" % a for a in A ] + code = template % {"name": name, "A": ", ".join(A), "Ak": ", ".join(Ak)} + return self._build(code, name) + + @cacheit + def mulpow(self): + name = "monomial_mulpow" + template = dedent("""\ + def %(name)s(A, B, k): + (%(A)s,) = A + (%(B)s,) = B + return (%(ABk)s,) + """) + A = self._vars("a") + B = self._vars("b") + ABk = [ "%s + %s*k" % (a, b) for a, b in zip(A, B) ] + code = template % {"name": name, "A": ", ".join(A), "B": ", ".join(B), "ABk": ", ".join(ABk)} + return self._build(code, name) + + @cacheit + def ldiv(self): + name = "monomial_ldiv" + template = dedent("""\ + def %(name)s(A, B): + (%(A)s,) = A + (%(B)s,) = B + return (%(AB)s,) + """) + A = self._vars("a") + B = self._vars("b") + AB = [ "%s - %s" % (a, b) for a, b in zip(A, B) ] + code = template % {"name": name, "A": ", ".join(A), "B": ", ".join(B), "AB": ", ".join(AB)} + return self._build(code, name) + + @cacheit + def div(self): + name = "monomial_div" + template = dedent("""\ + def %(name)s(A, B): + (%(A)s,) = A + (%(B)s,) = B + %(RAB)s + return (%(R)s,) + """) + A = self._vars("a") + B = self._vars("b") + RAB = [ "r%(i)s = a%(i)s - b%(i)s\n if r%(i)s < 0: return None" % {"i": i} for i in range(self.ngens) ] + R = self._vars("r") + code = template % {"name": name, "A": ", ".join(A), "B": ", ".join(B), "RAB": "\n ".join(RAB), "R": ", ".join(R)} + return self._build(code, name) + + @cacheit + def lcm(self): + name = "monomial_lcm" + template = dedent("""\ + def %(name)s(A, B): + (%(A)s,) = A + (%(B)s,) = B + return (%(AB)s,) + """) + A = self._vars("a") + B = self._vars("b") + AB = [ "%s if %s >= %s else %s" % (a, a, b, b) for a, b in zip(A, B) ] + code = template % {"name": name, "A": ", ".join(A), "B": ", ".join(B), "AB": ", ".join(AB)} + return self._build(code, name) + + @cacheit + def gcd(self): + name = "monomial_gcd" + template = dedent("""\ + def %(name)s(A, B): + (%(A)s,) = A + (%(B)s,) = B + return (%(AB)s,) + """) + A = self._vars("a") + B = self._vars("b") + AB = [ "%s if %s <= %s else %s" % (a, a, b, b) for a, b in zip(A, B) ] + code = template % {"name": name, "A": ", ".join(A), "B": ", ".join(B), "AB": ", ".join(AB)} + return self._build(code, name) + +@public +class Monomial(PicklableWithSlots): + """Class representing a monomial, i.e. a product of powers. """ + + __slots__ = ('exponents', 'gens') + + def __init__(self, monom, gens=None): + if not iterable(monom): + rep, gens = dict_from_expr(sympify(monom), gens=gens) + if len(rep) == 1 and list(rep.values())[0] == 1: + monom = list(rep.keys())[0] + else: + raise ValueError("Expected a monomial got {}".format(monom)) + + self.exponents = tuple(map(int, monom)) + self.gens = gens + + def rebuild(self, exponents, gens=None): + return self.__class__(exponents, gens or self.gens) + + def __len__(self): + return len(self.exponents) + + def __iter__(self): + return iter(self.exponents) + + def __getitem__(self, item): + return self.exponents[item] + + def __hash__(self): + return hash((self.__class__.__name__, self.exponents, self.gens)) + + def __str__(self): + if self.gens: + return "*".join([ "%s**%s" % (gen, exp) for gen, exp in zip(self.gens, self.exponents) ]) + else: + return "%s(%s)" % (self.__class__.__name__, self.exponents) + + def as_expr(self, *gens): + """Convert a monomial instance to a SymPy expression. """ + gens = gens or self.gens + + if not gens: + raise ValueError( + "Cannot convert %s to an expression without generators" % self) + + return Mul(*[ gen**exp for gen, exp in zip(gens, self.exponents) ]) + + def __eq__(self, other): + if isinstance(other, Monomial): + exponents = other.exponents + elif isinstance(other, (tuple, Tuple)): + exponents = other + else: + return False + + return self.exponents == exponents + + def __ne__(self, other): + return not self == other + + def __mul__(self, other): + if isinstance(other, Monomial): + exponents = other.exponents + elif isinstance(other, (tuple, Tuple)): + exponents = other + else: + raise NotImplementedError + + return self.rebuild(monomial_mul(self.exponents, exponents)) + + def __truediv__(self, other): + if isinstance(other, Monomial): + exponents = other.exponents + elif isinstance(other, (tuple, Tuple)): + exponents = other + else: + raise NotImplementedError + + result = monomial_div(self.exponents, exponents) + + if result is not None: + return self.rebuild(result) + else: + raise ExactQuotientFailed(self, Monomial(other)) + + __floordiv__ = __truediv__ + + def __pow__(self, other): + n = int(other) + if n < 0: + raise ValueError("a non-negative integer expected, got %s" % other) + return self.rebuild(monomial_pow(self.exponents, n)) + + def gcd(self, other): + """Greatest common divisor of monomials. """ + if isinstance(other, Monomial): + exponents = other.exponents + elif isinstance(other, (tuple, Tuple)): + exponents = other + else: + raise TypeError( + "an instance of Monomial class expected, got %s" % other) + + return self.rebuild(monomial_gcd(self.exponents, exponents)) + + def lcm(self, other): + """Least common multiple of monomials. """ + if isinstance(other, Monomial): + exponents = other.exponents + elif isinstance(other, (tuple, Tuple)): + exponents = other + else: + raise TypeError( + "an instance of Monomial class expected, got %s" % other) + + return self.rebuild(monomial_lcm(self.exponents, exponents)) diff --git a/phivenv/Lib/site-packages/sympy/polys/multivariate_resultants.py b/phivenv/Lib/site-packages/sympy/polys/multivariate_resultants.py new file mode 100644 index 0000000000000000000000000000000000000000..b6c967a8b981e25e8e26745804a658ff7b90e9af --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/multivariate_resultants.py @@ -0,0 +1,473 @@ +""" +This module contains functions for two multivariate resultants. These +are: + +- Dixon's resultant. +- Macaulay's resultant. + +Multivariate resultants are used to identify whether a multivariate +system has common roots. That is when the resultant is equal to zero. +""" +from math import prod + +from sympy.core.mul import Mul +from sympy.matrices.dense import (Matrix, diag) +from sympy.polys.polytools import (Poly, degree_list, rem) +from sympy.simplify.simplify import simplify +from sympy.tensor.indexed import IndexedBase +from sympy.polys.monomials import itermonomials, monomial_deg +from sympy.polys.orderings import monomial_key +from sympy.polys.polytools import poly_from_expr, total_degree +from sympy.functions.combinatorial.factorials import binomial +from itertools import combinations_with_replacement +from sympy.utilities.exceptions import sympy_deprecation_warning + +class DixonResultant(): + """ + A class for retrieving the Dixon's resultant of a multivariate + system. + + Examples + ======== + + >>> from sympy import symbols + + >>> from sympy.polys.multivariate_resultants import DixonResultant + >>> x, y = symbols('x, y') + + >>> p = x + y + >>> q = x ** 2 + y ** 3 + >>> h = x ** 2 + y + + >>> dixon = DixonResultant(variables=[x, y], polynomials=[p, q, h]) + >>> poly = dixon.get_dixon_polynomial() + >>> matrix = dixon.get_dixon_matrix(polynomial=poly) + >>> matrix + Matrix([ + [ 0, 0, -1, 0, -1], + [ 0, -1, 0, -1, 0], + [-1, 0, 1, 0, 0], + [ 0, -1, 0, 0, 1], + [-1, 0, 0, 1, 0]]) + >>> matrix.det() + 0 + + See Also + ======== + + Notebook in examples: sympy/example/notebooks. + + References + ========== + + .. [1] [Kapur1994]_ + .. [2] [Palancz08]_ + + """ + + def __init__(self, polynomials, variables): + """ + A class that takes two lists, a list of polynomials and list of + variables. Returns the Dixon matrix of the multivariate system. + + Parameters + ---------- + polynomials : list of polynomials + A list of m n-degree polynomials + variables: list + A list of all n variables + """ + self.polynomials = polynomials + self.variables = variables + + self.n = len(self.variables) + self.m = len(self.polynomials) + + a = IndexedBase("alpha") + # A list of n alpha variables (the replacing variables) + self.dummy_variables = [a[i] for i in range(self.n)] + + # A list of the d_max of each variable. + self._max_degrees = [max(degree_list(poly)[i] for poly in self.polynomials) + for i in range(self.n)] + + @property + def max_degrees(self): + sympy_deprecation_warning( + """ + The max_degrees property of DixonResultant is deprecated. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-dixonresultant-properties", + ) + return self._max_degrees + + def get_dixon_polynomial(self): + r""" + Returns + ======= + + dixon_polynomial: polynomial + Dixon's polynomial is calculated as: + + delta = Delta(A) / ((x_1 - a_1) ... (x_n - a_n)) where, + + A = |p_1(x_1,... x_n), ..., p_n(x_1,... x_n)| + |p_1(a_1,... x_n), ..., p_n(a_1,... x_n)| + |... , ..., ...| + |p_1(a_1,... a_n), ..., p_n(a_1,... a_n)| + """ + if self.m != (self.n + 1): + raise ValueError('Method invalid for given combination.') + + # First row + rows = [self.polynomials] + + temp = list(self.variables) + + for idx in range(self.n): + temp[idx] = self.dummy_variables[idx] + substitution = dict(zip(self.variables, temp)) + rows.append([f.subs(substitution) for f in self.polynomials]) + + A = Matrix(rows) + + terms = zip(self.variables, self.dummy_variables) + product_of_differences = Mul(*[a - b for a, b in terms]) + dixon_polynomial = (A.det() / product_of_differences).factor() + + return poly_from_expr(dixon_polynomial, self.dummy_variables)[0] + + def get_upper_degree(self): + sympy_deprecation_warning( + """ + The get_upper_degree() method of DixonResultant is deprecated. Use + get_max_degrees() instead. + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-dixonresultant-properties" + ) + list_of_products = [self.variables[i] ** self._max_degrees[i] + for i in range(self.n)] + product = prod(list_of_products) + product = Poly(product).monoms() + + return monomial_deg(*product) + + def get_max_degrees(self, polynomial): + r""" + Returns a list of the maximum degree of each variable appearing + in the coefficients of the Dixon polynomial. The coefficients are + viewed as polys in $x_1, x_2, \dots, x_n$. + """ + deg_lists = [degree_list(Poly(poly, self.variables)) + for poly in polynomial.coeffs()] + + max_degrees = [max(degs) for degs in zip(*deg_lists)] + + return max_degrees + + def get_dixon_matrix(self, polynomial): + r""" + Construct the Dixon matrix from the coefficients of polynomial + \alpha. Each coefficient is viewed as a polynomial of x_1, ..., + x_n. + """ + + max_degrees = self.get_max_degrees(polynomial) + + # list of column headers of the Dixon matrix. + monomials = itermonomials(self.variables, max_degrees) + monomials = sorted(monomials, reverse=True, + key=monomial_key('lex', self.variables)) + + dixon_matrix = Matrix([[Poly(c, *self.variables).coeff_monomial(m) + for m in monomials] + for c in polynomial.coeffs()]) + + # remove columns if needed + if dixon_matrix.shape[0] != dixon_matrix.shape[1]: + keep = [column for column in range(dixon_matrix.shape[-1]) + if any(element != 0 for element + in dixon_matrix[:, column])] + + dixon_matrix = dixon_matrix[:, keep] + + return dixon_matrix + + def KSY_precondition(self, matrix): + """ + Test for the validity of the Kapur-Saxena-Yang precondition. + + The precondition requires that the column corresponding to the + monomial 1 = x_1 ^ 0 * x_2 ^ 0 * ... * x_n ^ 0 is not a linear + combination of the remaining ones. In SymPy notation this is + the last column. For the precondition to hold the last non-zero + row of the rref matrix should be of the form [0, 0, ..., 1]. + """ + if matrix.is_zero_matrix: + return False + + m, n = matrix.shape + + # simplify the matrix and keep only its non-zero rows + matrix = simplify(matrix.rref()[0]) + rows = [i for i in range(m) if any(matrix[i, j] != 0 for j in range(n))] + matrix = matrix[rows,:] + + condition = Matrix([[0]*(n-1) + [1]]) + + if matrix[-1,:] == condition: + return True + else: + return False + + def delete_zero_rows_and_columns(self, matrix): + """Remove the zero rows and columns of the matrix.""" + rows = [ + i for i in range(matrix.rows) if not matrix.row(i).is_zero_matrix] + cols = [ + j for j in range(matrix.cols) if not matrix.col(j).is_zero_matrix] + + return matrix[rows, cols] + + def product_leading_entries(self, matrix): + """Calculate the product of the leading entries of the matrix.""" + res = 1 + for row in range(matrix.rows): + for el in matrix.row(row): + if el != 0: + res = res * el + break + return res + + def get_KSY_Dixon_resultant(self, matrix): + """Calculate the Kapur-Saxena-Yang approach to the Dixon Resultant.""" + matrix = self.delete_zero_rows_and_columns(matrix) + _, U, _ = matrix.LUdecomposition() + matrix = self.delete_zero_rows_and_columns(simplify(U)) + + return self.product_leading_entries(matrix) + +class MacaulayResultant(): + """ + A class for calculating the Macaulay resultant. Note that the + polynomials must be homogenized and their coefficients must be + given as symbols. + + Examples + ======== + + >>> from sympy import symbols + + >>> from sympy.polys.multivariate_resultants import MacaulayResultant + >>> x, y, z = symbols('x, y, z') + + >>> a_0, a_1, a_2 = symbols('a_0, a_1, a_2') + >>> b_0, b_1, b_2 = symbols('b_0, b_1, b_2') + >>> c_0, c_1, c_2,c_3, c_4 = symbols('c_0, c_1, c_2, c_3, c_4') + + >>> f = a_0 * y - a_1 * x + a_2 * z + >>> g = b_1 * x ** 2 + b_0 * y ** 2 - b_2 * z ** 2 + >>> h = c_0 * y * z ** 2 - c_1 * x ** 3 + c_2 * x ** 2 * z - c_3 * x * z ** 2 + c_4 * z ** 3 + + >>> mac = MacaulayResultant(polynomials=[f, g, h], variables=[x, y, z]) + >>> mac.monomial_set + [x**4, x**3*y, x**3*z, x**2*y**2, x**2*y*z, x**2*z**2, x*y**3, + x*y**2*z, x*y*z**2, x*z**3, y**4, y**3*z, y**2*z**2, y*z**3, z**4] + >>> matrix = mac.get_matrix() + >>> submatrix = mac.get_submatrix(matrix) + >>> submatrix + Matrix([ + [-a_1, a_0, a_2, 0], + [ 0, -a_1, 0, 0], + [ 0, 0, -a_1, 0], + [ 0, 0, 0, -a_1]]) + + See Also + ======== + + Notebook in examples: sympy/example/notebooks. + + References + ========== + + .. [1] [Bruce97]_ + .. [2] [Stiller96]_ + + """ + def __init__(self, polynomials, variables): + """ + Parameters + ========== + + variables: list + A list of all n variables + polynomials : list of SymPy polynomials + A list of m n-degree polynomials + """ + self.polynomials = polynomials + self.variables = variables + self.n = len(variables) + + # A list of the d_max of each polynomial. + self.degrees = [total_degree(poly, *self.variables) for poly + in self.polynomials] + + self.degree_m = self._get_degree_m() + self.monomials_size = self.get_size() + + # The set T of all possible monomials of degree degree_m + self.monomial_set = self.get_monomials_of_certain_degree(self.degree_m) + + def _get_degree_m(self): + r""" + Returns + ======= + + degree_m: int + The degree_m is calculated as 1 + \sum_1 ^ n (d_i - 1), + where d_i is the degree of the i polynomial + """ + return 1 + sum(d - 1 for d in self.degrees) + + def get_size(self): + r""" + Returns + ======= + + size: int + The size of set T. Set T is the set of all possible + monomials of the n variables for degree equal to the + degree_m + """ + return binomial(self.degree_m + self.n - 1, self.n - 1) + + def get_monomials_of_certain_degree(self, degree): + """ + Returns + ======= + + monomials: list + A list of monomials of a certain degree. + """ + monomials = [Mul(*monomial) for monomial + in combinations_with_replacement(self.variables, + degree)] + + return sorted(monomials, reverse=True, + key=monomial_key('lex', self.variables)) + + def get_row_coefficients(self): + """ + Returns + ======= + + row_coefficients: list + The row coefficients of Macaulay's matrix + """ + row_coefficients = [] + divisible = [] + for i in range(self.n): + if i == 0: + degree = self.degree_m - self.degrees[i] + monomial = self.get_monomials_of_certain_degree(degree) + row_coefficients.append(monomial) + else: + divisible.append(self.variables[i - 1] ** + self.degrees[i - 1]) + degree = self.degree_m - self.degrees[i] + poss_rows = self.get_monomials_of_certain_degree(degree) + for div in divisible: + for p in poss_rows: + if rem(p, div) == 0: + poss_rows = [item for item in poss_rows + if item != p] + row_coefficients.append(poss_rows) + return row_coefficients + + def get_matrix(self): + """ + Returns + ======= + + macaulay_matrix: Matrix + The Macaulay numerator matrix + """ + rows = [] + row_coefficients = self.get_row_coefficients() + for i in range(self.n): + for multiplier in row_coefficients[i]: + coefficients = [] + poly = Poly(self.polynomials[i] * multiplier, + *self.variables) + + for mono in self.monomial_set: + coefficients.append(poly.coeff_monomial(mono)) + rows.append(coefficients) + + macaulay_matrix = Matrix(rows) + return macaulay_matrix + + def get_reduced_nonreduced(self): + r""" + Returns + ======= + + reduced: list + A list of the reduced monomials + non_reduced: list + A list of the monomials that are not reduced + + Definition + ========== + + A polynomial is said to be reduced in x_i, if its degree (the + maximum degree of its monomials) in x_i is less than d_i. A + polynomial that is reduced in all variables but one is said + simply to be reduced. + """ + divisible = [] + for m in self.monomial_set: + temp = [] + for i, v in enumerate(self.variables): + temp.append(bool(total_degree(m, v) >= self.degrees[i])) + divisible.append(temp) + reduced = [i for i, r in enumerate(divisible) + if sum(r) < self.n - 1] + non_reduced = [i for i, r in enumerate(divisible) + if sum(r) >= self.n -1] + + return reduced, non_reduced + + def get_submatrix(self, matrix): + r""" + Returns + ======= + + macaulay_submatrix: Matrix + The Macaulay denominator matrix. Columns that are non reduced are kept. + The row which contains one of the a_{i}s is dropped. a_{i}s + are the coefficients of x_i ^ {d_i}. + """ + reduced, non_reduced = self.get_reduced_nonreduced() + + # if reduced == [], then det(matrix) should be 1 + if reduced == []: + return diag([1]) + + # reduced != [] + reduction_set = [v ** self.degrees[i] for i, v + in enumerate(self.variables)] + + ais = [self.polynomials[i].coeff(reduction_set[i]) + for i in range(self.n)] + + reduced_matrix = matrix[:, reduced] + keep = [] + for row in range(reduced_matrix.rows): + check = [ai in reduced_matrix[row, :] for ai in ais] + if True not in check: + keep.append(row) + + return matrix[keep, non_reduced] diff --git a/phivenv/Lib/site-packages/sympy/polys/orderings.py b/phivenv/Lib/site-packages/sympy/polys/orderings.py new file mode 100644 index 0000000000000000000000000000000000000000..b6ed575d5103440e1e8ebda4c53c4149d3badf11 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/orderings.py @@ -0,0 +1,286 @@ +"""Definitions of monomial orderings. """ + +from __future__ import annotations + +__all__ = ["lex", "grlex", "grevlex", "ilex", "igrlex", "igrevlex"] + +from sympy.core import Symbol +from sympy.utilities.iterables import iterable + +class MonomialOrder: + """Base class for monomial orderings. """ + + alias: str | None = None + is_global: bool | None = None + is_default = False + + def __repr__(self): + return self.__class__.__name__ + "()" + + def __str__(self): + return self.alias + + def __call__(self, monomial): + raise NotImplementedError + + def __eq__(self, other): + return self.__class__ == other.__class__ + + def __hash__(self): + return hash(self.__class__) + + def __ne__(self, other): + return not (self == other) + +class LexOrder(MonomialOrder): + """Lexicographic order of monomials. """ + + alias = 'lex' + is_global = True + is_default = True + + def __call__(self, monomial): + return monomial + +class GradedLexOrder(MonomialOrder): + """Graded lexicographic order of monomials. """ + + alias = 'grlex' + is_global = True + + def __call__(self, monomial): + return (sum(monomial), monomial) + +class ReversedGradedLexOrder(MonomialOrder): + """Reversed graded lexicographic order of monomials. """ + + alias = 'grevlex' + is_global = True + + def __call__(self, monomial): + return (sum(monomial), tuple(reversed([-m for m in monomial]))) + +class ProductOrder(MonomialOrder): + """ + A product order built from other monomial orders. + + Given (not necessarily total) orders O1, O2, ..., On, their product order + P is defined as M1 > M2 iff there exists i such that O1(M1) = O2(M2), + ..., Oi(M1) = Oi(M2), O{i+1}(M1) > O{i+1}(M2). + + Product orders are typically built from monomial orders on different sets + of variables. + + ProductOrder is constructed by passing a list of pairs + [(O1, L1), (O2, L2), ...] where Oi are MonomialOrders and Li are callables. + Upon comparison, the Li are passed the total monomial, and should filter + out the part of the monomial to pass to Oi. + + Examples + ======== + + We can use a lexicographic order on x_1, x_2 and also on + y_1, y_2, y_3, and their product on {x_i, y_i} as follows: + + >>> from sympy.polys.orderings import lex, grlex, ProductOrder + >>> P = ProductOrder( + ... (lex, lambda m: m[:2]), # lex order on x_1 and x_2 of monomial + ... (grlex, lambda m: m[2:]) # grlex on y_1, y_2, y_3 + ... ) + >>> P((2, 1, 1, 0, 0)) > P((1, 10, 0, 2, 0)) + True + + Here the exponent `2` of `x_1` in the first monomial + (`x_1^2 x_2 y_1`) is bigger than the exponent `1` of `x_1` in the + second monomial (`x_1 x_2^10 y_2^2`), so the first monomial is greater + in the product ordering. + + >>> P((2, 1, 1, 0, 0)) < P((2, 1, 0, 2, 0)) + True + + Here the exponents of `x_1` and `x_2` agree, so the grlex order on + `y_1, y_2, y_3` is used to decide the ordering. In this case the monomial + `y_2^2` is ordered larger than `y_1`, since for the grlex order the degree + of the monomial is most important. + """ + + def __init__(self, *args): + self.args = args + + def __call__(self, monomial): + return tuple(O(lamda(monomial)) for (O, lamda) in self.args) + + def __repr__(self): + contents = [repr(x[0]) for x in self.args] + return self.__class__.__name__ + '(' + ", ".join(contents) + ')' + + def __str__(self): + contents = [str(x[0]) for x in self.args] + return self.__class__.__name__ + '(' + ", ".join(contents) + ')' + + def __eq__(self, other): + if not isinstance(other, ProductOrder): + return False + return self.args == other.args + + def __hash__(self): + return hash((self.__class__, self.args)) + + @property + def is_global(self): + if all(o.is_global is True for o, _ in self.args): + return True + if all(o.is_global is False for o, _ in self.args): + return False + return None + +class InverseOrder(MonomialOrder): + """ + The "inverse" of another monomial order. + + If O is any monomial order, we can construct another monomial order iO + such that `A >_{iO} B` if and only if `B >_O A`. This is useful for + constructing local orders. + + Note that many algorithms only work with *global* orders. + + For example, in the inverse lexicographic order on a single variable `x`, + high powers of `x` count as small: + + >>> from sympy.polys.orderings import lex, InverseOrder + >>> ilex = InverseOrder(lex) + >>> ilex((5,)) < ilex((0,)) + True + """ + + def __init__(self, O): + self.O = O + + def __str__(self): + return "i" + str(self.O) + + def __call__(self, monomial): + def inv(l): + if iterable(l): + return tuple(inv(x) for x in l) + return -l + return inv(self.O(monomial)) + + @property + def is_global(self): + if self.O.is_global is True: + return False + if self.O.is_global is False: + return True + return None + + def __eq__(self, other): + return isinstance(other, InverseOrder) and other.O == self.O + + def __hash__(self): + return hash((self.__class__, self.O)) + +lex = LexOrder() +grlex = GradedLexOrder() +grevlex = ReversedGradedLexOrder() +ilex = InverseOrder(lex) +igrlex = InverseOrder(grlex) +igrevlex = InverseOrder(grevlex) + +_monomial_key = { + 'lex': lex, + 'grlex': grlex, + 'grevlex': grevlex, + 'ilex': ilex, + 'igrlex': igrlex, + 'igrevlex': igrevlex +} + +def monomial_key(order=None, gens=None): + """ + Return a function defining admissible order on monomials. + + The result of a call to :func:`monomial_key` is a function which should + be used as a key to :func:`sorted` built-in function, to provide order + in a set of monomials of the same length. + + Currently supported monomial orderings are: + + 1. lex - lexicographic order (default) + 2. grlex - graded lexicographic order + 3. grevlex - reversed graded lexicographic order + 4. ilex, igrlex, igrevlex - the corresponding inverse orders + + If the ``order`` input argument is not a string but has ``__call__`` + attribute, then it will pass through with an assumption that the + callable object defines an admissible order on monomials. + + If the ``gens`` input argument contains a list of generators, the + resulting key function can be used to sort SymPy ``Expr`` objects. + + """ + if order is None: + order = lex + + if isinstance(order, Symbol): + order = str(order) + + if isinstance(order, str): + try: + order = _monomial_key[order] + except KeyError: + raise ValueError("supported monomial orderings are 'lex', 'grlex' and 'grevlex', got %r" % order) + if hasattr(order, '__call__'): + if gens is not None: + def _order(expr): + return order(expr.as_poly(*gens).degree_list()) + return _order + return order + else: + raise ValueError("monomial ordering specification must be a string or a callable, got %s" % order) + +class _ItemGetter: + """Helper class to return a subsequence of values.""" + + def __init__(self, seq): + self.seq = tuple(seq) + + def __call__(self, m): + return tuple(m[idx] for idx in self.seq) + + def __eq__(self, other): + if not isinstance(other, _ItemGetter): + return False + return self.seq == other.seq + +def build_product_order(arg, gens): + """ + Build a monomial order on ``gens``. + + ``arg`` should be a tuple of iterables. The first element of each iterable + should be a string or monomial order (will be passed to monomial_key), + the others should be subsets of the generators. This function will build + the corresponding product order. + + For example, build a product of two grlex orders: + + >>> from sympy.polys.orderings import build_product_order + >>> from sympy.abc import x, y, z, t + + >>> O = build_product_order((("grlex", x, y), ("grlex", z, t)), [x, y, z, t]) + >>> O((1, 2, 3, 4)) + ((3, (1, 2)), (7, (3, 4))) + + """ + gens2idx = {} + for i, g in enumerate(gens): + gens2idx[g] = i + order = [] + for expr in arg: + name = expr[0] + var = expr[1:] + + def makelambda(var): + return _ItemGetter(gens2idx[g] for g in var) + order.append((monomial_key(name), makelambda(var))) + return ProductOrder(*order) diff --git a/phivenv/Lib/site-packages/sympy/polys/orthopolys.py b/phivenv/Lib/site-packages/sympy/polys/orthopolys.py new file mode 100644 index 0000000000000000000000000000000000000000..ee82457703a2be172951ee38e3cd67f221a438a0 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/orthopolys.py @@ -0,0 +1,343 @@ +"""Efficient functions for generating orthogonal polynomials.""" +from sympy.core.symbol import Dummy +from sympy.polys.densearith import (dup_mul, dup_mul_ground, + dup_lshift, dup_sub, dup_add, dup_sub_term, dup_sub_ground, dup_sqr) +from sympy.polys.domains import ZZ, QQ +from sympy.polys.polytools import named_poly +from sympy.utilities import public + +def dup_jacobi(n, a, b, K): + """Low-level implementation of Jacobi polynomials.""" + if n < 1: + return [K.one] + m2, m1 = [K.one], [(a+b)/K(2) + K.one, (a-b)/K(2)] + for i in range(2, n+1): + den = K(i)*(a + b + i)*(a + b + K(2)*i - K(2)) + f0 = (a + b + K(2)*i - K.one) * (a*a - b*b) / (K(2)*den) + f1 = (a + b + K(2)*i - K.one) * (a + b + K(2)*i - K(2)) * (a + b + K(2)*i) / (K(2)*den) + f2 = (a + i - K.one)*(b + i - K.one)*(a + b + K(2)*i) / den + p0 = dup_mul_ground(m1, f0, K) + p1 = dup_mul_ground(dup_lshift(m1, 1, K), f1, K) + p2 = dup_mul_ground(m2, f2, K) + m2, m1 = m1, dup_sub(dup_add(p0, p1, K), p2, K) + return m1 + +@public +def jacobi_poly(n, a, b, x=None, polys=False): + r"""Generates the Jacobi polynomial `P_n^{(a,b)}(x)`. + + Parameters + ========== + + n : int + Degree of the polynomial. + a + Lower limit of minimal domain for the list of coefficients. + b + Upper limit of minimal domain for the list of coefficients. + x : optional + polys : bool, optional + If True, return a Poly, otherwise (default) return an expression. + """ + return named_poly(n, dup_jacobi, None, "Jacobi polynomial", (x, a, b), polys) + +def dup_gegenbauer(n, a, K): + """Low-level implementation of Gegenbauer polynomials.""" + if n < 1: + return [K.one] + m2, m1 = [K.one], [K(2)*a, K.zero] + for i in range(2, n+1): + p1 = dup_mul_ground(dup_lshift(m1, 1, K), K(2)*(a-K.one)/K(i) + K(2), K) + p2 = dup_mul_ground(m2, K(2)*(a-K.one)/K(i) + K.one, K) + m2, m1 = m1, dup_sub(p1, p2, K) + return m1 + +def gegenbauer_poly(n, a, x=None, polys=False): + r"""Generates the Gegenbauer polynomial `C_n^{(a)}(x)`. + + Parameters + ========== + + n : int + Degree of the polynomial. + x : optional + a + Decides minimal domain for the list of coefficients. + polys : bool, optional + If True, return a Poly, otherwise (default) return an expression. + """ + return named_poly(n, dup_gegenbauer, None, "Gegenbauer polynomial", (x, a), polys) + +def dup_chebyshevt(n, K): + """Low-level implementation of Chebyshev polynomials of the first kind.""" + if n < 1: + return [K.one] + # When n is small, it is faster to directly calculate the recurrence relation. + if n < 64: # The threshold serves as a heuristic + return _dup_chebyshevt_rec(n, K) + return _dup_chebyshevt_prod(n, K) + +def _dup_chebyshevt_rec(n, K): + r""" Chebyshev polynomials of the first kind using recurrence. + + Explanation + =========== + + Chebyshev polynomials of the first kind are defined by the recurrence + relation: + + .. math:: + T_0(x) &= 1\\ + T_1(x) &= x\\ + T_n(x) &= 2xT_{n-1}(x) - T_{n-2}(x) + + This function calculates the Chebyshev polynomial of the first kind using + the above recurrence relation. + + Parameters + ========== + + n : int + n is a nonnegative integer. + K : domain + + """ + m2, m1 = [K.one], [K.one, K.zero] + for _ in range(n - 1): + m2, m1 = m1, dup_sub(dup_mul_ground(dup_lshift(m1, 1, K), K(2), K), m2, K) + return m1 + +def _dup_chebyshevt_prod(n, K): + r""" Chebyshev polynomials of the first kind using recursive products. + + Explanation + =========== + + Computes Chebyshev polynomials of the first kind using + + .. math:: + T_{2n}(x) &= 2T_n^2(x) - 1\\ + T_{2n+1}(x) &= 2T_{n+1}(x)T_n(x) - x + + This is faster than ``_dup_chebyshevt_rec`` for large ``n``. + + Parameters + ========== + + n : int + n is a nonnegative integer. + K : domain + + """ + m2, m1 = [K.one, K.zero], [K(2), K.zero, -K.one] + for i in bin(n)[3:]: + c = dup_sub_term(dup_mul_ground(dup_mul(m1, m2, K), K(2), K), K.one, 1, K) + if i == '1': + m2, m1 = c, dup_sub_ground(dup_mul_ground(dup_sqr(m1, K), K(2), K), K.one, K) + else: + m2, m1 = dup_sub_ground(dup_mul_ground(dup_sqr(m2, K), K(2), K), K.one, K), c + return m2 + +def dup_chebyshevu(n, K): + """Low-level implementation of Chebyshev polynomials of the second kind.""" + if n < 1: + return [K.one] + m2, m1 = [K.one], [K(2), K.zero] + for i in range(2, n+1): + m2, m1 = m1, dup_sub(dup_mul_ground(dup_lshift(m1, 1, K), K(2), K), m2, K) + return m1 + +@public +def chebyshevt_poly(n, x=None, polys=False): + r"""Generates the Chebyshev polynomial of the first kind `T_n(x)`. + + Parameters + ========== + + n : int + Degree of the polynomial. + x : optional + polys : bool, optional + If True, return a Poly, otherwise (default) return an expression. + """ + return named_poly(n, dup_chebyshevt, ZZ, + "Chebyshev polynomial of the first kind", (x,), polys) + +@public +def chebyshevu_poly(n, x=None, polys=False): + r"""Generates the Chebyshev polynomial of the second kind `U_n(x)`. + + Parameters + ========== + + n : int + Degree of the polynomial. + x : optional + polys : bool, optional + If True, return a Poly, otherwise (default) return an expression. + """ + return named_poly(n, dup_chebyshevu, ZZ, + "Chebyshev polynomial of the second kind", (x,), polys) + +def dup_hermite(n, K): + """Low-level implementation of Hermite polynomials.""" + if n < 1: + return [K.one] + m2, m1 = [K.one], [K(2), K.zero] + for i in range(2, n+1): + a = dup_lshift(m1, 1, K) + b = dup_mul_ground(m2, K(i-1), K) + m2, m1 = m1, dup_mul_ground(dup_sub(a, b, K), K(2), K) + return m1 + +def dup_hermite_prob(n, K): + """Low-level implementation of probabilist's Hermite polynomials.""" + if n < 1: + return [K.one] + m2, m1 = [K.one], [K.one, K.zero] + for i in range(2, n+1): + a = dup_lshift(m1, 1, K) + b = dup_mul_ground(m2, K(i-1), K) + m2, m1 = m1, dup_sub(a, b, K) + return m1 + +@public +def hermite_poly(n, x=None, polys=False): + r"""Generates the Hermite polynomial `H_n(x)`. + + Parameters + ========== + + n : int + Degree of the polynomial. + x : optional + polys : bool, optional + If True, return a Poly, otherwise (default) return an expression. + """ + return named_poly(n, dup_hermite, ZZ, "Hermite polynomial", (x,), polys) + +@public +def hermite_prob_poly(n, x=None, polys=False): + r"""Generates the probabilist's Hermite polynomial `He_n(x)`. + + Parameters + ========== + + n : int + Degree of the polynomial. + x : optional + polys : bool, optional + If True, return a Poly, otherwise (default) return an expression. + """ + return named_poly(n, dup_hermite_prob, ZZ, + "probabilist's Hermite polynomial", (x,), polys) + +def dup_legendre(n, K): + """Low-level implementation of Legendre polynomials.""" + if n < 1: + return [K.one] + m2, m1 = [K.one], [K.one, K.zero] + for i in range(2, n+1): + a = dup_mul_ground(dup_lshift(m1, 1, K), K(2*i-1, i), K) + b = dup_mul_ground(m2, K(i-1, i), K) + m2, m1 = m1, dup_sub(a, b, K) + return m1 + +@public +def legendre_poly(n, x=None, polys=False): + r"""Generates the Legendre polynomial `P_n(x)`. + + Parameters + ========== + + n : int + Degree of the polynomial. + x : optional + polys : bool, optional + If True, return a Poly, otherwise (default) return an expression. + """ + return named_poly(n, dup_legendre, QQ, "Legendre polynomial", (x,), polys) + +def dup_laguerre(n, alpha, K): + """Low-level implementation of Laguerre polynomials.""" + m2, m1 = [K.zero], [K.one] + for i in range(1, n+1): + a = dup_mul(m1, [-K.one/K(i), (alpha-K.one)/K(i) + K(2)], K) + b = dup_mul_ground(m2, (alpha-K.one)/K(i) + K.one, K) + m2, m1 = m1, dup_sub(a, b, K) + return m1 + +@public +def laguerre_poly(n, x=None, alpha=0, polys=False): + r"""Generates the Laguerre polynomial `L_n^{(\alpha)}(x)`. + + Parameters + ========== + + n : int + Degree of the polynomial. + x : optional + alpha : optional + Decides minimal domain for the list of coefficients. + polys : bool, optional + If True, return a Poly, otherwise (default) return an expression. + """ + return named_poly(n, dup_laguerre, None, "Laguerre polynomial", (x, alpha), polys) + +def dup_spherical_bessel_fn(n, K): + """Low-level implementation of fn(n, x).""" + if n < 1: + return [K.one, K.zero] + m2, m1 = [K.one], [K.one, K.zero] + for i in range(2, n+1): + m2, m1 = m1, dup_sub(dup_mul_ground(dup_lshift(m1, 1, K), K(2*i-1), K), m2, K) + return dup_lshift(m1, 1, K) + +def dup_spherical_bessel_fn_minus(n, K): + """Low-level implementation of fn(-n, x).""" + m2, m1 = [K.one, K.zero], [K.zero] + for i in range(2, n+1): + m2, m1 = m1, dup_sub(dup_mul_ground(dup_lshift(m1, 1, K), K(3-2*i), K), m2, K) + return m1 + +def spherical_bessel_fn(n, x=None, polys=False): + """ + Coefficients for the spherical Bessel functions. + + These are only needed in the jn() function. + + The coefficients are calculated from: + + fn(0, z) = 1/z + fn(1, z) = 1/z**2 + fn(n-1, z) + fn(n+1, z) == (2*n+1)/z * fn(n, z) + + Parameters + ========== + + n : int + Degree of the polynomial. + x : optional + polys : bool, optional + If True, return a Poly, otherwise (default) return an expression. + + Examples + ======== + + >>> from sympy.polys.orthopolys import spherical_bessel_fn as fn + >>> from sympy import Symbol + >>> z = Symbol("z") + >>> fn(1, z) + z**(-2) + >>> fn(2, z) + -1/z + 3/z**3 + >>> fn(3, z) + -6/z**2 + 15/z**4 + >>> fn(4, z) + 1/z - 45/z**3 + 105/z**5 + + """ + if x is None: + x = Dummy("x") + f = dup_spherical_bessel_fn_minus if n < 0 else dup_spherical_bessel_fn + return named_poly(abs(n), f, ZZ, "", (QQ(1)/x,), polys) diff --git a/phivenv/Lib/site-packages/sympy/polys/partfrac.py b/phivenv/Lib/site-packages/sympy/polys/partfrac.py new file mode 100644 index 0000000000000000000000000000000000000000..dedc1bf0fba42128e869303ed9b12c598640a36c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/partfrac.py @@ -0,0 +1,496 @@ +"""Algorithms for partial fraction decomposition of rational functions. """ + + +from sympy.core import S, Add, sympify, Function, Lambda, Dummy +from sympy.core.traversal import preorder_traversal +from sympy.polys import Poly, RootSum, cancel, factor +from sympy.polys.polyerrors import PolynomialError +from sympy.polys.polyoptions import allowed_flags, set_defaults +from sympy.polys.polytools import parallel_poly_from_expr +from sympy.utilities import numbered_symbols, take, xthreaded, public + + +@xthreaded +@public +def apart(f, x=None, full=False, **options): + """ + Compute partial fraction decomposition of a rational function. + + Given a rational function ``f``, computes the partial fraction + decomposition of ``f``. Two algorithms are available: One is based on the + undetermined coefficients method, the other is Bronstein's full partial + fraction decomposition algorithm. + + The undetermined coefficients method (selected by ``full=False``) uses + polynomial factorization (and therefore accepts the same options as + factor) for the denominator. Per default it works over the rational + numbers, therefore decomposition of denominators with non-rational roots + (e.g. irrational, complex roots) is not supported by default (see options + of factor). + + Bronstein's algorithm can be selected by using ``full=True`` and allows a + decomposition of denominators with non-rational roots. A human-readable + result can be obtained via ``doit()`` (see examples below). + + Examples + ======== + + >>> from sympy.polys.partfrac import apart + >>> from sympy.abc import x, y + + By default, using the undetermined coefficients method: + + >>> apart(y/(x + 2)/(x + 1), x) + -y/(x + 2) + y/(x + 1) + + The undetermined coefficients method does not provide a result when the + denominators roots are not rational: + + >>> apart(y/(x**2 + x + 1), x) + y/(x**2 + x + 1) + + You can choose Bronstein's algorithm by setting ``full=True``: + + >>> apart(y/(x**2 + x + 1), x, full=True) + RootSum(_w**2 + _w + 1, Lambda(_a, (-2*_a*y/3 - y/3)/(-_a + x))) + + Calling ``doit()`` yields a human-readable result: + + >>> apart(y/(x**2 + x + 1), x, full=True).doit() + (-y/3 - 2*y*(-1/2 - sqrt(3)*I/2)/3)/(x + 1/2 + sqrt(3)*I/2) + (-y/3 - + 2*y*(-1/2 + sqrt(3)*I/2)/3)/(x + 1/2 - sqrt(3)*I/2) + + + See Also + ======== + + apart_list, assemble_partfrac_list + """ + allowed_flags(options, []) + + f = sympify(f) + + if f.is_Atom: + return f + else: + P, Q = f.as_numer_denom() + + _options = options.copy() + options = set_defaults(options, extension=True) + try: + (P, Q), opt = parallel_poly_from_expr((P, Q), x, **options) + except PolynomialError as msg: + if f.is_commutative: + raise PolynomialError(msg) + # non-commutative + if f.is_Mul: + c, nc = f.args_cnc(split_1=False) + nc = f.func(*nc) + if c: + c = apart(f.func._from_args(c), x=x, full=full, **_options) + return c*nc + else: + return nc + elif f.is_Add: + c = [] + nc = [] + for i in f.args: + if i.is_commutative: + c.append(i) + else: + try: + nc.append(apart(i, x=x, full=full, **_options)) + except NotImplementedError: + nc.append(i) + return apart(f.func(*c), x=x, full=full, **_options) + f.func(*nc) + else: + reps = [] + pot = preorder_traversal(f) + next(pot) + for e in pot: + try: + reps.append((e, apart(e, x=x, full=full, **_options))) + pot.skip() # this was handled successfully + except NotImplementedError: + pass + return f.xreplace(dict(reps)) + + if P.is_multivariate: + fc = f.cancel() + if fc != f: + return apart(fc, x=x, full=full, **_options) + + raise NotImplementedError( + "multivariate partial fraction decomposition") + + common, P, Q = P.cancel(Q) + + poly, P = P.div(Q, auto=True) + P, Q = P.rat_clear_denoms(Q) + + if Q.degree() <= 1: + partial = P/Q + else: + if not full: + partial = apart_undetermined_coeffs(P, Q) + else: + partial = apart_full_decomposition(P, Q) + + terms = S.Zero + + for term in Add.make_args(partial): + if term.has(RootSum): + terms += term + else: + terms += factor(term) + + return common*(poly.as_expr() + terms) + + +def apart_undetermined_coeffs(P, Q): + """Partial fractions via method of undetermined coefficients. """ + X = numbered_symbols(cls=Dummy) + partial, symbols = [], [] + + _, factors = Q.factor_list() + + for f, k in factors: + n, q = f.degree(), Q + + for i in range(1, k + 1): + coeffs, q = take(X, n), q.quo(f) + partial.append((coeffs, q, f, i)) + symbols.extend(coeffs) + + dom = Q.get_domain().inject(*symbols) + F = Poly(0, Q.gen, domain=dom) + + for i, (coeffs, q, f, k) in enumerate(partial): + h = Poly(coeffs, Q.gen, domain=dom) + partial[i] = (h, f, k) + q = q.set_domain(dom) + F += h*q + + system, result = [], S.Zero + + for (k,), coeff in F.terms(): + system.append(coeff - P.nth(k)) + + from sympy.solvers import solve + solution = solve(system, symbols) + + for h, f, k in partial: + h = h.as_expr().subs(solution) + result += h/f.as_expr()**k + + return result + + +def apart_full_decomposition(P, Q): + """ + Bronstein's full partial fraction decomposition algorithm. + + Given a univariate rational function ``f``, performing only GCD + operations over the algebraic closure of the initial ground domain + of definition, compute full partial fraction decomposition with + fractions having linear denominators. + + Note that no factorization of the initial denominator of ``f`` is + performed. The final decomposition is formed in terms of a sum of + :class:`RootSum` instances. + + References + ========== + + .. [1] [Bronstein93]_ + + """ + return assemble_partfrac_list(apart_list(P/Q, P.gens[0])) + + +@public +def apart_list(f, x=None, dummies=None, **options): + """ + Compute partial fraction decomposition of a rational function + and return the result in structured form. + + Given a rational function ``f`` compute the partial fraction decomposition + of ``f``. Only Bronstein's full partial fraction decomposition algorithm + is supported by this method. The return value is highly structured and + perfectly suited for further algorithmic treatment rather than being + human-readable. The function returns a tuple holding three elements: + + * The first item is the common coefficient, free of the variable `x` used + for decomposition. (It is an element of the base field `K`.) + + * The second item is the polynomial part of the decomposition. This can be + the zero polynomial. (It is an element of `K[x]`.) + + * The third part itself is a list of quadruples. Each quadruple + has the following elements in this order: + + - The (not necessarily irreducible) polynomial `D` whose roots `w_i` appear + in the linear denominator of a bunch of related fraction terms. (This item + can also be a list of explicit roots. However, at the moment ``apart_list`` + never returns a result this way, but the related ``assemble_partfrac_list`` + function accepts this format as input.) + + - The numerator of the fraction, written as a function of the root `w` + + - The linear denominator of the fraction *excluding its power exponent*, + written as a function of the root `w`. + + - The power to which the denominator has to be raised. + + On can always rebuild a plain expression by using the function ``assemble_partfrac_list``. + + Examples + ======== + + A first example: + + >>> from sympy.polys.partfrac import apart_list, assemble_partfrac_list + >>> from sympy.abc import x, t + + >>> f = (2*x**3 - 2*x) / (x**2 - 2*x + 1) + >>> pfd = apart_list(f) + >>> pfd + (1, + Poly(2*x + 4, x, domain='ZZ'), + [(Poly(_w - 1, _w, domain='ZZ'), Lambda(_a, 4), Lambda(_a, -_a + x), 1)]) + + >>> assemble_partfrac_list(pfd) + 2*x + 4 + 4/(x - 1) + + Second example: + + >>> f = (-2*x - 2*x**2) / (3*x**2 - 6*x) + >>> pfd = apart_list(f) + >>> pfd + (-1, + Poly(2/3, x, domain='QQ'), + [(Poly(_w - 2, _w, domain='ZZ'), Lambda(_a, 2), Lambda(_a, -_a + x), 1)]) + + >>> assemble_partfrac_list(pfd) + -2/3 - 2/(x - 2) + + Another example, showing symbolic parameters: + + >>> pfd = apart_list(t/(x**2 + x + t), x) + >>> pfd + (1, + Poly(0, x, domain='ZZ[t]'), + [(Poly(_w**2 + _w + t, _w, domain='ZZ[t]'), + Lambda(_a, -2*_a*t/(4*t - 1) - t/(4*t - 1)), + Lambda(_a, -_a + x), + 1)]) + + >>> assemble_partfrac_list(pfd) + RootSum(_w**2 + _w + t, Lambda(_a, (-2*_a*t/(4*t - 1) - t/(4*t - 1))/(-_a + x))) + + This example is taken from Bronstein's original paper: + + >>> f = 36 / (x**5 - 2*x**4 - 2*x**3 + 4*x**2 + x - 2) + >>> pfd = apart_list(f) + >>> pfd + (1, + Poly(0, x, domain='ZZ'), + [(Poly(_w - 2, _w, domain='ZZ'), Lambda(_a, 4), Lambda(_a, -_a + x), 1), + (Poly(_w**2 - 1, _w, domain='ZZ'), Lambda(_a, -3*_a - 6), Lambda(_a, -_a + x), 2), + (Poly(_w + 1, _w, domain='ZZ'), Lambda(_a, -4), Lambda(_a, -_a + x), 1)]) + + >>> assemble_partfrac_list(pfd) + -4/(x + 1) - 3/(x + 1)**2 - 9/(x - 1)**2 + 4/(x - 2) + + See also + ======== + + apart, assemble_partfrac_list + + References + ========== + + .. [1] [Bronstein93]_ + + """ + allowed_flags(options, []) + + f = sympify(f) + + if f.is_Atom: + return f + else: + P, Q = f.as_numer_denom() + + options = set_defaults(options, extension=True) + (P, Q), opt = parallel_poly_from_expr((P, Q), x, **options) + + if P.is_multivariate: + raise NotImplementedError( + "multivariate partial fraction decomposition") + + common, P, Q = P.cancel(Q) + + poly, P = P.div(Q, auto=True) + P, Q = P.rat_clear_denoms(Q) + + polypart = poly + + if dummies is None: + def dummies(name): + d = Dummy(name) + while True: + yield d + + dummies = dummies("w") + + rationalpart = apart_list_full_decomposition(P, Q, dummies) + + return (common, polypart, rationalpart) + + +def apart_list_full_decomposition(P, Q, dummygen): + """ + Bronstein's full partial fraction decomposition algorithm. + + Given a univariate rational function ``f``, performing only GCD + operations over the algebraic closure of the initial ground domain + of definition, compute full partial fraction decomposition with + fractions having linear denominators. + + Note that no factorization of the initial denominator of ``f`` is + performed. The final decomposition is formed in terms of a sum of + :class:`RootSum` instances. + + References + ========== + + .. [1] [Bronstein93]_ + + """ + P_orig, Q_orig, x, U = P, Q, P.gen, [] + + u = Function('u')(x) + a = Dummy('a') + + partial = [] + + for d, n in Q.sqf_list_include(all=True): + b = d.as_expr() + U += [ u.diff(x, n - 1) ] + + h = cancel(P_orig/Q_orig.quo(d**n)) / u**n + + H, subs = [h], [] + + for j in range(1, n): + H += [ H[-1].diff(x) / j ] + + for j in range(1, n + 1): + subs += [ (U[j - 1], b.diff(x, j) / j) ] + + for j in range(0, n): + P, Q = cancel(H[j]).as_numer_denom() + + for i in range(0, j + 1): + P = P.subs(*subs[j - i]) + + Q = Q.subs(*subs[0]) + + P = Poly(P, x) + Q = Poly(Q, x) + + G = P.gcd(d) + D = d.quo(G) + + B, g = Q.half_gcdex(D) + b = (P * B.quo(g)).rem(D) + + Dw = D.subs(x, next(dummygen)) + numer = Lambda(a, b.as_expr().subs(x, a)) + denom = Lambda(a, (x - a)) + exponent = n-j + + partial.append((Dw, numer, denom, exponent)) + + return partial + + +@public +def assemble_partfrac_list(partial_list): + r"""Reassemble a full partial fraction decomposition + from a structured result obtained by the function ``apart_list``. + + Examples + ======== + + This example is taken from Bronstein's original paper: + + >>> from sympy.polys.partfrac import apart_list, assemble_partfrac_list + >>> from sympy.abc import x + + >>> f = 36 / (x**5 - 2*x**4 - 2*x**3 + 4*x**2 + x - 2) + >>> pfd = apart_list(f) + >>> pfd + (1, + Poly(0, x, domain='ZZ'), + [(Poly(_w - 2, _w, domain='ZZ'), Lambda(_a, 4), Lambda(_a, -_a + x), 1), + (Poly(_w**2 - 1, _w, domain='ZZ'), Lambda(_a, -3*_a - 6), Lambda(_a, -_a + x), 2), + (Poly(_w + 1, _w, domain='ZZ'), Lambda(_a, -4), Lambda(_a, -_a + x), 1)]) + + >>> assemble_partfrac_list(pfd) + -4/(x + 1) - 3/(x + 1)**2 - 9/(x - 1)**2 + 4/(x - 2) + + If we happen to know some roots we can provide them easily inside the structure: + + >>> pfd = apart_list(2/(x**2-2)) + >>> pfd + (1, + Poly(0, x, domain='ZZ'), + [(Poly(_w**2 - 2, _w, domain='ZZ'), + Lambda(_a, _a/2), + Lambda(_a, -_a + x), + 1)]) + + >>> pfda = assemble_partfrac_list(pfd) + >>> pfda + RootSum(_w**2 - 2, Lambda(_a, _a/(-_a + x)))/2 + + >>> pfda.doit() + -sqrt(2)/(2*(x + sqrt(2))) + sqrt(2)/(2*(x - sqrt(2))) + + >>> from sympy import Dummy, Poly, Lambda, sqrt + >>> a = Dummy("a") + >>> pfd = (1, Poly(0, x, domain='ZZ'), [([sqrt(2),-sqrt(2)], Lambda(a, a/2), Lambda(a, -a + x), 1)]) + + >>> assemble_partfrac_list(pfd) + -sqrt(2)/(2*(x + sqrt(2))) + sqrt(2)/(2*(x - sqrt(2))) + + See Also + ======== + + apart, apart_list + """ + # Common factor + common = partial_list[0] + + # Polynomial part + polypart = partial_list[1] + pfd = polypart.as_expr() + + # Rational parts + for r, nf, df, ex in partial_list[2]: + if isinstance(r, Poly): + # Assemble in case the roots are given implicitly by a polynomials + an, nu = nf.variables, nf.expr + ad, de = df.variables, df.expr + # Hack to make dummies equal because Lambda created new Dummies + de = de.subs(ad[0], an[0]) + func = Lambda(tuple(an), nu/de**ex) + pfd += RootSum(r, func, auto=False, quadratic=False) + else: + # Assemble in case the roots are given explicitly by a list of algebraic numbers + for root in r: + pfd += nf(root)/df(root)**ex + + return common*pfd diff --git a/phivenv/Lib/site-packages/sympy/polys/polyclasses.py b/phivenv/Lib/site-packages/sympy/polys/polyclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..1cc0e0f368ab07d64837e057b841ea991d9de223 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/polyclasses.py @@ -0,0 +1,3186 @@ +"""OO layer for several polynomial representations. """ + +from __future__ import annotations + +from sympy.external.gmpy import GROUND_TYPES + +from sympy.utilities.exceptions import sympy_deprecation_warning + +from sympy.core.numbers import oo +from sympy.core.sympify import CantSympify +from sympy.polys.polyutils import PicklableWithSlots, _sort_factors +from sympy.polys.domains import Domain, ZZ, QQ + +from sympy.polys.polyerrors import ( + CoercionFailed, + ExactQuotientFailed, + DomainError, + NotInvertible, +) + +from sympy.polys.densebasic import ( + ninf, + dmp_validate, + dup_normal, dmp_normal, + dup_convert, dmp_convert, + dmp_from_sympy, + dup_strip, + dmp_degree_in, + dmp_degree_list, + dmp_negative_p, + dmp_ground_LC, + dmp_ground_TC, + dmp_ground_nth, + dmp_one, dmp_ground, + dmp_zero, dmp_zero_p, dmp_one_p, dmp_ground_p, + dup_from_dict, dmp_from_dict, + dmp_to_dict, + dmp_deflate, + dmp_inject, dmp_eject, + dmp_terms_gcd, + dmp_list_terms, dmp_exclude, + dup_slice, dmp_slice_in, dmp_permute, + dmp_to_tuple,) + +from sympy.polys.densearith import ( + dmp_add_ground, + dmp_sub_ground, + dmp_mul_ground, + dmp_quo_ground, + dmp_exquo_ground, + dmp_abs, + dmp_neg, + dmp_add, + dmp_sub, + dmp_mul, + dmp_sqr, + dmp_pow, + dmp_pdiv, + dmp_prem, + dmp_pquo, + dmp_pexquo, + dmp_div, + dmp_rem, + dmp_quo, + dmp_exquo, + dmp_add_mul, dmp_sub_mul, + dmp_max_norm, + dmp_l1_norm, + dmp_l2_norm_squared) + +from sympy.polys.densetools import ( + dmp_clear_denoms, + dmp_integrate_in, + dmp_diff_in, + dmp_eval_in, + dup_revert, + dmp_ground_trunc, + dmp_ground_content, + dmp_ground_primitive, + dmp_ground_monic, + dmp_compose, + dup_decompose, + dup_shift, + dmp_shift, + dup_transform, + dmp_lift) + +from sympy.polys.euclidtools import ( + dup_half_gcdex, dup_gcdex, dup_invert, + dmp_subresultants, + dmp_resultant, + dmp_discriminant, + dmp_inner_gcd, + dmp_gcd, + dmp_lcm, + dmp_cancel) + +from sympy.polys.sqfreetools import ( + dup_gff_list, + dmp_norm, + dmp_sqf_p, + dmp_sqf_norm, + dmp_sqf_part, + dmp_sqf_list, dmp_sqf_list_include) + +from sympy.polys.factortools import ( + dup_cyclotomic_p, dmp_irreducible_p, + dmp_factor_list, dmp_factor_list_include) + +from sympy.polys.rootisolation import ( + dup_isolate_real_roots_sqf, + dup_isolate_real_roots, + dup_isolate_all_roots_sqf, + dup_isolate_all_roots, + dup_refine_real_root, + dup_count_real_roots, + dup_count_complex_roots, + dup_sturm, + dup_cauchy_upper_bound, + dup_cauchy_lower_bound, + dup_mignotte_sep_bound_squared) + +from sympy.polys.polyerrors import ( + UnificationFailed, + PolynomialError) + + +if GROUND_TYPES == 'flint': + import flint + def _supported_flint_domain(D): + return D.is_ZZ or D.is_QQ or D.is_FF and D._is_flint +else: + flint = None + def _supported_flint_domain(D): + return False + + +class DMP(CantSympify): + """Dense Multivariate Polynomials over `K`. """ + + __slots__ = () + + lev: int + dom: Domain + + def __new__(cls, rep, dom, lev=None): + + if lev is None: + rep, lev = dmp_validate(rep) + elif not isinstance(rep, list): + raise CoercionFailed("expected list, got %s" % type(rep)) + + return cls.new(rep, dom, lev) + + @classmethod + def new(cls, rep, dom, lev): + # It would be too slow to call _validate_args always at runtime. + # Ideally this checking would be handled by a static type checker. + # + #cls._validate_args(rep, dom, lev) + if flint is not None: + if lev == 0 and _supported_flint_domain(dom): + return DUP_Flint._new(rep, dom, lev) + + return DMP_Python._new(rep, dom, lev) + + @property + def rep(f): + """Get the representation of ``f``. """ + + sympy_deprecation_warning(""" + Accessing the ``DMP.rep`` attribute is deprecated. The internal + representation of ``DMP`` instances can now be ``DUP_Flint`` when the + ground types are ``flint``. In this case the ``DMP`` instance does not + have a ``rep`` attribute. Use ``DMP.to_list()`` instead. Using + ``DMP.to_list()`` also works in previous versions of SymPy. + """, + deprecated_since_version="1.13", + active_deprecations_target="dmp-rep", + ) + + return f.to_list() + + def to_best(f): + """Convert to DUP_Flint if possible. + + This method should be used when the domain or level is changed and it + potentially becomes possible to convert from DMP_Python to DUP_Flint. + """ + if flint is not None: + if isinstance(f, DMP_Python) and f.lev == 0 and _supported_flint_domain(f.dom): + return DUP_Flint.new(f._rep, f.dom, f.lev) + + return f + + @classmethod + def _validate_args(cls, rep, dom, lev): + assert isinstance(dom, Domain) + assert isinstance(lev, int) and lev >= 0 + + def validate_rep(rep, lev): + assert isinstance(rep, list) + if lev == 0: + assert all(dom.of_type(c) for c in rep) + else: + for r in rep: + validate_rep(r, lev - 1) + + validate_rep(rep, lev) + + @classmethod + def from_dict(cls, rep, lev, dom): + rep = dmp_from_dict(rep, lev, dom) + return cls.new(rep, dom, lev) + + @classmethod + def from_list(cls, rep, lev, dom): + """Create an instance of ``cls`` given a list of native coefficients. """ + return cls.new(dmp_convert(rep, lev, None, dom), dom, lev) + + @classmethod + def from_sympy_list(cls, rep, lev, dom): + """Create an instance of ``cls`` given a list of SymPy coefficients. """ + return cls.new(dmp_from_sympy(rep, lev, dom), dom, lev) + + @classmethod + def from_monoms_coeffs(cls, monoms, coeffs, lev, dom): + return cls(dict(list(zip(monoms, coeffs))), dom, lev) + + def convert(f, dom): + """Convert ``f`` to a ``DMP`` over the new domain. """ + if f.dom == dom: + return f + elif f.lev or flint is None: + return f._convert(dom) + elif isinstance(f, DUP_Flint): + if _supported_flint_domain(dom): + return f._convert(dom) + else: + return f.to_DMP_Python()._convert(dom) + elif isinstance(f, DMP_Python): + if _supported_flint_domain(dom): + return f._convert(dom).to_DUP_Flint() + else: + return f._convert(dom) + else: + raise RuntimeError("unreachable code") + + def _convert(f, dom): + raise NotImplementedError + + @classmethod + def zero(cls, lev, dom): + return DMP(dmp_zero(lev), dom, lev) + + @classmethod + def one(cls, lev, dom): + return DMP(dmp_one(lev, dom), dom, lev) + + def _one(f): + raise NotImplementedError + + def __repr__(f): + return "%s(%s, %s)" % (f.__class__.__name__, f.to_list(), f.dom) + + def __hash__(f): + return hash((f.__class__.__name__, f.to_tuple(), f.lev, f.dom)) + + def __getnewargs__(self): + return self.to_list(), self.dom, self.lev + + def ground_new(f, coeff): + """Construct a new ground instance of ``f``. """ + raise NotImplementedError + + def unify_DMP(f, g): + """Unify and return ``DMP`` instances of ``f`` and ``g``. """ + if not isinstance(g, DMP) or f.lev != g.lev: + raise UnificationFailed("Cannot unify %s with %s" % (f, g)) + + if f.dom != g.dom: + dom = f.dom.unify(g.dom) + f = f.convert(dom) + g = g.convert(dom) + + return f, g + + def to_dict(f, zero=False): + """Convert ``f`` to a dict representation with native coefficients. """ + return dmp_to_dict(f.to_list(), f.lev, f.dom, zero=zero) + + def to_sympy_dict(f, zero=False): + """Convert ``f`` to a dict representation with SymPy coefficients. """ + rep = f.to_dict(zero=zero) + + for k, v in rep.items(): + rep[k] = f.dom.to_sympy(v) + + return rep + + def to_sympy_list(f): + """Convert ``f`` to a list representation with SymPy coefficients. """ + def sympify_nested_list(rep): + out = [] + for val in rep: + if isinstance(val, list): + out.append(sympify_nested_list(val)) + else: + out.append(f.dom.to_sympy(val)) + return out + + return sympify_nested_list(f.to_list()) + + def to_list(f): + """Convert ``f`` to a list representation with native coefficients. """ + raise NotImplementedError + + def to_tuple(f): + """ + Convert ``f`` to a tuple representation with native coefficients. + + This is needed for hashing. + """ + raise NotImplementedError + + def to_ring(f): + """Make the ground domain a ring. """ + return f.convert(f.dom.get_ring()) + + def to_field(f): + """Make the ground domain a field. """ + return f.convert(f.dom.get_field()) + + def to_exact(f): + """Make the ground domain exact. """ + return f.convert(f.dom.get_exact()) + + def slice(f, m, n, j=0): + """Take a continuous subsequence of terms of ``f``. """ + if not f.lev and not j: + return f._slice(m, n) + else: + return f._slice_lev(m, n, j) + + def _slice(f, m, n): + raise NotImplementedError + + def _slice_lev(f, m, n, j): + raise NotImplementedError + + def coeffs(f, order=None): + """Returns all non-zero coefficients from ``f`` in lex order. """ + return [ c for _, c in f.terms(order=order) ] + + def monoms(f, order=None): + """Returns all non-zero monomials from ``f`` in lex order. """ + return [ m for m, _ in f.terms(order=order) ] + + def terms(f, order=None): + """Returns all non-zero terms from ``f`` in lex order. """ + if f.is_zero: + zero_monom = (0,)*(f.lev + 1) + return [(zero_monom, f.dom.zero)] + else: + return f._terms(order=order) + + def _terms(f, order=None): + raise NotImplementedError + + def all_coeffs(f): + """Returns all coefficients from ``f``. """ + if f.lev: + raise PolynomialError('multivariate polynomials not supported') + + if not f: + return [f.dom.zero] + else: + return list(f.to_list()) + + def all_monoms(f): + """Returns all monomials from ``f``. """ + if f.lev: + raise PolynomialError('multivariate polynomials not supported') + + n = f.degree() + + if n < 0: + return [(0,)] + else: + return [ (n - i,) for i, c in enumerate(f.to_list()) ] + + def all_terms(f): + """Returns all terms from a ``f``. """ + if f.lev: + raise PolynomialError('multivariate polynomials not supported') + + n = f.degree() + + if n < 0: + return [((0,), f.dom.zero)] + else: + return [ ((n - i,), c) for i, c in enumerate(f.to_list()) ] + + def lift(f): + """Convert algebraic coefficients to rationals. """ + return f._lift().to_best() + + def _lift(f): + raise NotImplementedError + + def deflate(f): + """Reduce degree of `f` by mapping `x_i^m` to `y_i`. """ + raise NotImplementedError + + def inject(f, front=False): + """Inject ground domain generators into ``f``. """ + raise NotImplementedError + + def eject(f, dom, front=False): + """Eject selected generators into the ground domain. """ + raise NotImplementedError + + def exclude(f): + r""" + Remove useless generators from ``f``. + + Returns the removed generators and the new excluded ``f``. + + Examples + ======== + + >>> from sympy.polys.polyclasses import DMP + >>> from sympy.polys.domains import ZZ + + >>> DMP([[[ZZ(1)]], [[ZZ(1)], [ZZ(2)]]], ZZ).exclude() + ([2], DMP_Python([[1], [1, 2]], ZZ)) + + """ + J, F = f._exclude() + return J, F.to_best() + + def _exclude(f): + raise NotImplementedError + + def permute(f, P): + r""" + Returns a polynomial in `K[x_{P(1)}, ..., x_{P(n)}]`. + + Examples + ======== + + >>> from sympy.polys.polyclasses import DMP + >>> from sympy.polys.domains import ZZ + + >>> DMP([[[ZZ(2)], [ZZ(1), ZZ(0)]], [[]]], ZZ).permute([1, 0, 2]) + DMP_Python([[[2], []], [[1, 0], []]], ZZ) + + >>> DMP([[[ZZ(2)], [ZZ(1), ZZ(0)]], [[]]], ZZ).permute([1, 2, 0]) + DMP_Python([[[1], []], [[2, 0], []]], ZZ) + + """ + return f._permute(P) + + def _permute(f, P): + raise NotImplementedError + + def terms_gcd(f): + """Remove GCD of terms from the polynomial ``f``. """ + raise NotImplementedError + + def abs(f): + """Make all coefficients in ``f`` positive. """ + raise NotImplementedError + + def neg(f): + """Negate all coefficients in ``f``. """ + raise NotImplementedError + + def add_ground(f, c): + """Add an element of the ground domain to ``f``. """ + return f._add_ground(f.dom.convert(c)) + + def sub_ground(f, c): + """Subtract an element of the ground domain from ``f``. """ + return f._sub_ground(f.dom.convert(c)) + + def mul_ground(f, c): + """Multiply ``f`` by a an element of the ground domain. """ + return f._mul_ground(f.dom.convert(c)) + + def quo_ground(f, c): + """Quotient of ``f`` by a an element of the ground domain. """ + return f._quo_ground(f.dom.convert(c)) + + def exquo_ground(f, c): + """Exact quotient of ``f`` by a an element of the ground domain. """ + return f._exquo_ground(f.dom.convert(c)) + + def add(f, g): + """Add two multivariate polynomials ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._add(G) + + def sub(f, g): + """Subtract two multivariate polynomials ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._sub(G) + + def mul(f, g): + """Multiply two multivariate polynomials ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._mul(G) + + def sqr(f): + """Square a multivariate polynomial ``f``. """ + return f._sqr() + + def pow(f, n): + """Raise ``f`` to a non-negative power ``n``. """ + if not isinstance(n, int): + raise TypeError("``int`` expected, got %s" % type(n)) + return f._pow(n) + + def pdiv(f, g): + """Polynomial pseudo-division of ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._pdiv(G) + + def prem(f, g): + """Polynomial pseudo-remainder of ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._prem(G) + + def pquo(f, g): + """Polynomial pseudo-quotient of ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._pquo(G) + + def pexquo(f, g): + """Polynomial exact pseudo-quotient of ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._pexquo(G) + + def div(f, g): + """Polynomial division with remainder of ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._div(G) + + def rem(f, g): + """Computes polynomial remainder of ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._rem(G) + + def quo(f, g): + """Computes polynomial quotient of ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._quo(G) + + def exquo(f, g): + """Computes polynomial exact quotient of ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._exquo(G) + + def _add_ground(f, c): + raise NotImplementedError + + def _sub_ground(f, c): + raise NotImplementedError + + def _mul_ground(f, c): + raise NotImplementedError + + def _quo_ground(f, c): + raise NotImplementedError + + def _exquo_ground(f, c): + raise NotImplementedError + + def _add(f, g): + raise NotImplementedError + + def _sub(f, g): + raise NotImplementedError + + def _mul(f, g): + raise NotImplementedError + + def _sqr(f): + raise NotImplementedError + + def _pow(f, n): + raise NotImplementedError + + def _pdiv(f, g): + raise NotImplementedError + + def _prem(f, g): + raise NotImplementedError + + def _pquo(f, g): + raise NotImplementedError + + def _pexquo(f, g): + raise NotImplementedError + + def _div(f, g): + raise NotImplementedError + + def _rem(f, g): + raise NotImplementedError + + def _quo(f, g): + raise NotImplementedError + + def _exquo(f, g): + raise NotImplementedError + + def degree(f, j=0): + """Returns the leading degree of ``f`` in ``x_j``. """ + if not isinstance(j, int): + raise TypeError("``int`` expected, got %s" % type(j)) + + return f._degree(j) + + def _degree(f, j): + raise NotImplementedError + + def degree_list(f): + """Returns a list of degrees of ``f``. """ + raise NotImplementedError + + def total_degree(f): + """Returns the total degree of ``f``. """ + raise NotImplementedError + + def homogenize(f, s): + """Return homogeneous polynomial of ``f``""" + td = f.total_degree() + result = {} + new_symbol = (s == len(f.terms()[0][0])) + for term in f.terms(): + d = sum(term[0]) + if d < td: + i = td - d + else: + i = 0 + if new_symbol: + result[term[0] + (i,)] = term[1] + else: + l = list(term[0]) + l[s] += i + result[tuple(l)] = term[1] + return DMP.from_dict(result, f.lev + int(new_symbol), f.dom) + + def homogeneous_order(f): + """Returns the homogeneous order of ``f``. """ + if f.is_zero: + return -oo + + monoms = f.monoms() + tdeg = sum(monoms[0]) + + for monom in monoms: + _tdeg = sum(monom) + + if _tdeg != tdeg: + return None + + return tdeg + + def LC(f): + """Returns the leading coefficient of ``f``. """ + raise NotImplementedError + + def TC(f): + """Returns the trailing coefficient of ``f``. """ + raise NotImplementedError + + def nth(f, *N): + """Returns the ``n``-th coefficient of ``f``. """ + if all(isinstance(n, int) for n in N): + return f._nth(N) + else: + raise TypeError("a sequence of integers expected") + + def _nth(f, N): + raise NotImplementedError + + def max_norm(f): + """Returns maximum norm of ``f``. """ + raise NotImplementedError + + def l1_norm(f): + """Returns l1 norm of ``f``. """ + raise NotImplementedError + + def l2_norm_squared(f): + """Return squared l2 norm of ``f``. """ + raise NotImplementedError + + def clear_denoms(f): + """Clear denominators, but keep the ground domain. """ + raise NotImplementedError + + def integrate(f, m=1, j=0): + """Computes the ``m``-th order indefinite integral of ``f`` in ``x_j``. """ + if not isinstance(m, int): + raise TypeError("``int`` expected, got %s" % type(m)) + + if not isinstance(j, int): + raise TypeError("``int`` expected, got %s" % type(j)) + + return f._integrate(m, j) + + def _integrate(f, m, j): + raise NotImplementedError + + def diff(f, m=1, j=0): + """Computes the ``m``-th order derivative of ``f`` in ``x_j``. """ + if not isinstance(m, int): + raise TypeError("``int`` expected, got %s" % type(m)) + + if not isinstance(j, int): + raise TypeError("``int`` expected, got %s" % type(j)) + + return f._diff(m, j) + + def _diff(f, m, j): + raise NotImplementedError + + def eval(f, a, j=0): + """Evaluates ``f`` at the given point ``a`` in ``x_j``. """ + if not isinstance(j, int): + raise TypeError("``int`` expected, got %s" % type(j)) + elif not (0 <= j <= f.lev): + raise ValueError("invalid variable index %s" % j) + + if f.lev: + return f._eval_lev(a, j) + else: + return f._eval(a) + + def _eval(f, a): + raise NotImplementedError + + def _eval_lev(f, a, j): + raise NotImplementedError + + def half_gcdex(f, g): + """Half extended Euclidean algorithm, if univariate. """ + F, G = f.unify_DMP(g) + + if F.lev: + raise ValueError('univariate polynomial expected') + + return F._half_gcdex(G) + + def _half_gcdex(f, g): + raise NotImplementedError + + def gcdex(f, g): + """Extended Euclidean algorithm, if univariate. """ + F, G = f.unify_DMP(g) + + if F.lev: + raise ValueError('univariate polynomial expected') + + if not F.dom.is_Field: + raise DomainError('ground domain must be a field') + + return F._gcdex(G) + + def _gcdex(f, g): + raise NotImplementedError + + def invert(f, g): + """Invert ``f`` modulo ``g``, if possible. """ + F, G = f.unify_DMP(g) + + if F.lev: + raise ValueError('univariate polynomial expected') + + return F._invert(G) + + def _invert(f, g): + raise NotImplementedError + + def revert(f, n): + """Compute ``f**(-1)`` mod ``x**n``. """ + if f.lev: + raise ValueError('univariate polynomial expected') + + return f._revert(n) + + def _revert(f, n): + raise NotImplementedError + + def subresultants(f, g): + """Computes subresultant PRS sequence of ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._subresultants(G) + + def _subresultants(f, g): + raise NotImplementedError + + def resultant(f, g, includePRS=False): + """Computes resultant of ``f`` and ``g`` via PRS. """ + F, G = f.unify_DMP(g) + if includePRS: + return F._resultant_includePRS(G) + else: + return F._resultant(G) + + def _resultant(f, g, includePRS=False): + raise NotImplementedError + + def discriminant(f): + """Computes discriminant of ``f``. """ + raise NotImplementedError + + def cofactors(f, g): + """Returns GCD of ``f`` and ``g`` and their cofactors. """ + F, G = f.unify_DMP(g) + return F._cofactors(G) + + def _cofactors(f, g): + raise NotImplementedError + + def gcd(f, g): + """Returns polynomial GCD of ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._gcd(G) + + def _gcd(f, g): + raise NotImplementedError + + def lcm(f, g): + """Returns polynomial LCM of ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._lcm(G) + + def _lcm(f, g): + raise NotImplementedError + + def cancel(f, g, include=True): + """Cancel common factors in a rational function ``f/g``. """ + F, G = f.unify_DMP(g) + + if include: + return F._cancel_include(G) + else: + return F._cancel(G) + + def _cancel(f, g): + raise NotImplementedError + + def _cancel_include(f, g): + raise NotImplementedError + + def trunc(f, p): + """Reduce ``f`` modulo a constant ``p``. """ + return f._trunc(f.dom.convert(p)) + + def _trunc(f, p): + raise NotImplementedError + + def monic(f): + """Divides all coefficients by ``LC(f)``. """ + raise NotImplementedError + + def content(f): + """Returns GCD of polynomial coefficients. """ + raise NotImplementedError + + def primitive(f): + """Returns content and a primitive form of ``f``. """ + raise NotImplementedError + + def compose(f, g): + """Computes functional composition of ``f`` and ``g``. """ + F, G = f.unify_DMP(g) + return F._compose(G) + + def _compose(f, g): + raise NotImplementedError + + def decompose(f): + """Computes functional decomposition of ``f``. """ + if f.lev: + raise ValueError('univariate polynomial expected') + + return f._decompose() + + def _decompose(f): + raise NotImplementedError + + def shift(f, a): + """Efficiently compute Taylor shift ``f(x + a)``. """ + if f.lev: + raise ValueError('univariate polynomial expected') + + return f._shift(f.dom.convert(a)) + + def shift_list(f, a): + """Efficiently compute Taylor shift ``f(X + A)``. """ + a = [f.dom.convert(ai) for ai in a] + return f._shift_list(a) + + def _shift(f, a): + raise NotImplementedError + + def transform(f, p, q): + """Evaluate functional transformation ``q**n * f(p/q)``.""" + if f.lev: + raise ValueError('univariate polynomial expected') + + P, Q = p.unify_DMP(q) + F, P = f.unify_DMP(P) + F, Q = F.unify_DMP(Q) + + return F._transform(P, Q) + + def _transform(f, p, q): + raise NotImplementedError + + def sturm(f): + """Computes the Sturm sequence of ``f``. """ + if f.lev: + raise ValueError('univariate polynomial expected') + + return f._sturm() + + def _sturm(f): + raise NotImplementedError + + def cauchy_upper_bound(f): + """Computes the Cauchy upper bound on the roots of ``f``. """ + if f.lev: + raise ValueError('univariate polynomial expected') + + return f._cauchy_upper_bound() + + def _cauchy_upper_bound(f): + raise NotImplementedError + + def cauchy_lower_bound(f): + """Computes the Cauchy lower bound on the nonzero roots of ``f``. """ + if f.lev: + raise ValueError('univariate polynomial expected') + + return f._cauchy_lower_bound() + + def _cauchy_lower_bound(f): + raise NotImplementedError + + def mignotte_sep_bound_squared(f): + """Computes the squared Mignotte bound on root separations of ``f``. """ + if f.lev: + raise ValueError('univariate polynomial expected') + + return f._mignotte_sep_bound_squared() + + def _mignotte_sep_bound_squared(f): + raise NotImplementedError + + def gff_list(f): + """Computes greatest factorial factorization of ``f``. """ + if f.lev: + raise ValueError('univariate polynomial expected') + + return f._gff_list() + + def _gff_list(f): + raise NotImplementedError + + def norm(f): + """Computes ``Norm(f)``.""" + raise NotImplementedError + + def sqf_norm(f): + """Computes square-free norm of ``f``. """ + raise NotImplementedError + + def sqf_part(f): + """Computes square-free part of ``f``. """ + raise NotImplementedError + + def sqf_list(f, all=False): + """Returns a list of square-free factors of ``f``. """ + raise NotImplementedError + + def sqf_list_include(f, all=False): + """Returns a list of square-free factors of ``f``. """ + raise NotImplementedError + + def factor_list(f): + """Returns a list of irreducible factors of ``f``. """ + raise NotImplementedError + + def factor_list_include(f): + """Returns a list of irreducible factors of ``f``. """ + raise NotImplementedError + + def intervals(f, all=False, eps=None, inf=None, sup=None, fast=False, sqf=False): + """Compute isolating intervals for roots of ``f``. """ + if f.lev: + raise PolynomialError("Cannot isolate roots of a multivariate polynomial") + + if all and sqf: + return f._isolate_all_roots_sqf(eps=eps, inf=inf, sup=sup, fast=fast) + elif all and not sqf: + return f._isolate_all_roots(eps=eps, inf=inf, sup=sup, fast=fast) + elif not all and sqf: + return f._isolate_real_roots_sqf(eps=eps, inf=inf, sup=sup, fast=fast) + else: + return f._isolate_real_roots(eps=eps, inf=inf, sup=sup, fast=fast) + + def _isolate_all_roots(f, eps, inf, sup, fast): + raise NotImplementedError + + def _isolate_all_roots_sqf(f, eps, inf, sup, fast): + raise NotImplementedError + + def _isolate_real_roots(f, eps, inf, sup, fast): + raise NotImplementedError + + def _isolate_real_roots_sqf(f, eps, inf, sup, fast): + raise NotImplementedError + + def refine_root(f, s, t, eps=None, steps=None, fast=False): + """ + Refine an isolating interval to the given precision. + + ``eps`` should be a rational number. + + """ + if f.lev: + raise PolynomialError( + "Cannot refine a root of a multivariate polynomial") + + return f._refine_real_root(s, t, eps=eps, steps=steps, fast=fast) + + def _refine_real_root(f, s, t, eps, steps, fast): + raise NotImplementedError + + def count_real_roots(f, inf=None, sup=None): + """Return the number of real roots of ``f`` in ``[inf, sup]``. """ + raise NotImplementedError + + def count_complex_roots(f, inf=None, sup=None): + """Return the number of complex roots of ``f`` in ``[inf, sup]``. """ + raise NotImplementedError + + @property + def is_zero(f): + """Returns ``True`` if ``f`` is a zero polynomial. """ + raise NotImplementedError + + @property + def is_one(f): + """Returns ``True`` if ``f`` is a unit polynomial. """ + raise NotImplementedError + + @property + def is_ground(f): + """Returns ``True`` if ``f`` is an element of the ground domain. """ + raise NotImplementedError + + @property + def is_sqf(f): + """Returns ``True`` if ``f`` is a square-free polynomial. """ + raise NotImplementedError + + @property + def is_monic(f): + """Returns ``True`` if the leading coefficient of ``f`` is one. """ + raise NotImplementedError + + @property + def is_primitive(f): + """Returns ``True`` if the GCD of the coefficients of ``f`` is one. """ + raise NotImplementedError + + @property + def is_linear(f): + """Returns ``True`` if ``f`` is linear in all its variables. """ + raise NotImplementedError + + @property + def is_quadratic(f): + """Returns ``True`` if ``f`` is quadratic in all its variables. """ + raise NotImplementedError + + @property + def is_monomial(f): + """Returns ``True`` if ``f`` is zero or has only one term. """ + raise NotImplementedError + + @property + def is_homogeneous(f): + """Returns ``True`` if ``f`` is a homogeneous polynomial. """ + raise NotImplementedError + + @property + def is_irreducible(f): + """Returns ``True`` if ``f`` has no factors over its domain. """ + raise NotImplementedError + + @property + def is_cyclotomic(f): + """Returns ``True`` if ``f`` is a cyclotomic polynomial. """ + raise NotImplementedError + + def __abs__(f): + return f.abs() + + def __neg__(f): + return f.neg() + + def __add__(f, g): + if isinstance(g, DMP): + return f.add(g) + else: + try: + return f.add_ground(g) + except CoercionFailed: + return NotImplemented + + def __radd__(f, g): + return f.__add__(g) + + def __sub__(f, g): + if isinstance(g, DMP): + return f.sub(g) + else: + try: + return f.sub_ground(g) + except CoercionFailed: + return NotImplemented + + def __rsub__(f, g): + return (-f).__add__(g) + + def __mul__(f, g): + if isinstance(g, DMP): + return f.mul(g) + else: + try: + return f.mul_ground(g) + except CoercionFailed: + return NotImplemented + + def __rmul__(f, g): + return f.__mul__(g) + + def __truediv__(f, g): + if isinstance(g, DMP): + return f.exquo(g) + else: + try: + return f.mul_ground(g) + except CoercionFailed: + return NotImplemented + + def __rtruediv__(f, g): + if isinstance(g, DMP): + return g.exquo(f) + else: + try: + return f._one().mul_ground(g).exquo(f) + except CoercionFailed: + return NotImplemented + + def __pow__(f, n): + return f.pow(n) + + def __divmod__(f, g): + return f.div(g) + + def __mod__(f, g): + return f.rem(g) + + def __floordiv__(f, g): + if isinstance(g, DMP): + return f.quo(g) + else: + try: + return f.quo_ground(g) + except TypeError: + return NotImplemented + + def __eq__(f, g): + if f is g: + return True + if not isinstance(g, DMP): + return NotImplemented + try: + F, G = f.unify_DMP(g) + except UnificationFailed: + return False + else: + return F._strict_eq(G) + + def _strict_eq(f, g): + raise NotImplementedError + + def eq(f, g, strict=False): + if not strict: + return f == g + else: + return f._strict_eq(g) + + def ne(f, g, strict=False): + return not f.eq(g, strict=strict) + + def __lt__(f, g): + F, G = f.unify_DMP(g) + return F.to_list() < G.to_list() + + def __le__(f, g): + F, G = f.unify_DMP(g) + return F.to_list() <= G.to_list() + + def __gt__(f, g): + F, G = f.unify_DMP(g) + return F.to_list() > G.to_list() + + def __ge__(f, g): + F, G = f.unify_DMP(g) + return F.to_list() >= G.to_list() + + def __bool__(f): + return not f.is_zero + + +class DMP_Python(DMP): + """Dense Multivariate Polynomials over `K`. """ + + __slots__ = ('_rep', 'dom', 'lev') + + @classmethod + def _new(cls, rep, dom, lev): + obj = object.__new__(cls) + obj._rep = rep + obj.lev = lev + obj.dom = dom + return obj + + def _strict_eq(f, g): + if type(f) != type(g): + return False + return f.lev == g.lev and f.dom == g.dom and f._rep == g._rep + + def per(f, rep): + """Create a DMP out of the given representation. """ + return f._new(rep, f.dom, f.lev) + + def ground_new(f, coeff): + """Construct a new ground instance of ``f``. """ + return f._new(dmp_ground(coeff, f.lev), f.dom, f.lev) + + def _one(f): + return f.one(f.lev, f.dom) + + def unify(f, g): + """Unify representations of two multivariate polynomials. """ + # XXX: This function is not really used any more since there is + # unify_DMP now. + if not isinstance(g, DMP) or f.lev != g.lev: + raise UnificationFailed("Cannot unify %s with %s" % (f, g)) + + if f.dom == g.dom: + return f.lev, f.dom, f.per, f._rep, g._rep + else: + lev, dom = f.lev, f.dom.unify(g.dom) + + F = dmp_convert(f._rep, lev, f.dom, dom) + G = dmp_convert(g._rep, lev, g.dom, dom) + + def per(rep): + return f._new(rep, dom, lev) + + return lev, dom, per, F, G + + def to_DUP_Flint(f): + """Convert ``f`` to a Flint representation. """ + return DUP_Flint._new(f._rep, f.dom, f.lev) + + def to_list(f): + """Convert ``f`` to a list representation with native coefficients. """ + return list(f._rep) + + def to_tuple(f): + """Convert ``f`` to a tuple representation with native coefficients. """ + return dmp_to_tuple(f._rep, f.lev) + + def _convert(f, dom): + """Convert the ground domain of ``f``. """ + return f._new(dmp_convert(f._rep, f.lev, f.dom, dom), dom, f.lev) + + def _slice(f, m, n): + """Take a continuous subsequence of terms of ``f``. """ + rep = dup_slice(f._rep, m, n, f.dom) + return f._new(rep, f.dom, f.lev) + + def _slice_lev(f, m, n, j): + """Take a continuous subsequence of terms of ``f``. """ + rep = dmp_slice_in(f._rep, m, n, j, f.lev, f.dom) + return f._new(rep, f.dom, f.lev) + + def _terms(f, order=None): + """Returns all non-zero terms from ``f`` in lex order. """ + return dmp_list_terms(f._rep, f.lev, f.dom, order=order) + + def _lift(f): + """Convert algebraic coefficients to rationals. """ + r = dmp_lift(f._rep, f.lev, f.dom) + return f._new(r, f.dom.dom, f.lev) + + def deflate(f): + """Reduce degree of `f` by mapping `x_i^m` to `y_i`. """ + J, F = dmp_deflate(f._rep, f.lev, f.dom) + return J, f.per(F) + + def inject(f, front=False): + """Inject ground domain generators into ``f``. """ + F, lev = dmp_inject(f._rep, f.lev, f.dom, front=front) + # XXX: domain and level changed here + return f._new(F, f.dom.dom, lev) + + def eject(f, dom, front=False): + """Eject selected generators into the ground domain. """ + F = dmp_eject(f._rep, f.lev, dom, front=front) + # XXX: domain and level changed here + return f._new(F, dom, f.lev - len(dom.symbols)) + + def _exclude(f): + """Remove useless generators from ``f``. """ + J, F, u = dmp_exclude(f._rep, f.lev, f.dom) + # XXX: level changed here + return J, f._new(F, f.dom, u) + + def _permute(f, P): + """Returns a polynomial in `K[x_{P(1)}, ..., x_{P(n)}]`. """ + return f.per(dmp_permute(f._rep, P, f.lev, f.dom)) + + def terms_gcd(f): + """Remove GCD of terms from the polynomial ``f``. """ + J, F = dmp_terms_gcd(f._rep, f.lev, f.dom) + return J, f.per(F) + + def _add_ground(f, c): + """Add an element of the ground domain to ``f``. """ + return f.per(dmp_add_ground(f._rep, c, f.lev, f.dom)) + + def _sub_ground(f, c): + """Subtract an element of the ground domain from ``f``. """ + return f.per(dmp_sub_ground(f._rep, c, f.lev, f.dom)) + + def _mul_ground(f, c): + """Multiply ``f`` by a an element of the ground domain. """ + return f.per(dmp_mul_ground(f._rep, c, f.lev, f.dom)) + + def _quo_ground(f, c): + """Quotient of ``f`` by a an element of the ground domain. """ + return f.per(dmp_quo_ground(f._rep, c, f.lev, f.dom)) + + def _exquo_ground(f, c): + """Exact quotient of ``f`` by a an element of the ground domain. """ + return f.per(dmp_exquo_ground(f._rep, c, f.lev, f.dom)) + + def abs(f): + """Make all coefficients in ``f`` positive. """ + return f.per(dmp_abs(f._rep, f.lev, f.dom)) + + def neg(f): + """Negate all coefficients in ``f``. """ + return f.per(dmp_neg(f._rep, f.lev, f.dom)) + + def _add(f, g): + """Add two multivariate polynomials ``f`` and ``g``. """ + return f.per(dmp_add(f._rep, g._rep, f.lev, f.dom)) + + def _sub(f, g): + """Subtract two multivariate polynomials ``f`` and ``g``. """ + return f.per(dmp_sub(f._rep, g._rep, f.lev, f.dom)) + + def _mul(f, g): + """Multiply two multivariate polynomials ``f`` and ``g``. """ + return f.per(dmp_mul(f._rep, g._rep, f.lev, f.dom)) + + def sqr(f): + """Square a multivariate polynomial ``f``. """ + return f.per(dmp_sqr(f._rep, f.lev, f.dom)) + + def _pow(f, n): + """Raise ``f`` to a non-negative power ``n``. """ + return f.per(dmp_pow(f._rep, n, f.lev, f.dom)) + + def _pdiv(f, g): + """Polynomial pseudo-division of ``f`` and ``g``. """ + q, r = dmp_pdiv(f._rep, g._rep, f.lev, f.dom) + return f.per(q), f.per(r) + + def _prem(f, g): + """Polynomial pseudo-remainder of ``f`` and ``g``. """ + return f.per(dmp_prem(f._rep, g._rep, f.lev, f.dom)) + + def _pquo(f, g): + """Polynomial pseudo-quotient of ``f`` and ``g``. """ + return f.per(dmp_pquo(f._rep, g._rep, f.lev, f.dom)) + + def _pexquo(f, g): + """Polynomial exact pseudo-quotient of ``f`` and ``g``. """ + return f.per(dmp_pexquo(f._rep, g._rep, f.lev, f.dom)) + + def _div(f, g): + """Polynomial division with remainder of ``f`` and ``g``. """ + q, r = dmp_div(f._rep, g._rep, f.lev, f.dom) + return f.per(q), f.per(r) + + def _rem(f, g): + """Computes polynomial remainder of ``f`` and ``g``. """ + return f.per(dmp_rem(f._rep, g._rep, f.lev, f.dom)) + + def _quo(f, g): + """Computes polynomial quotient of ``f`` and ``g``. """ + return f.per(dmp_quo(f._rep, g._rep, f.lev, f.dom)) + + def _exquo(f, g): + """Computes polynomial exact quotient of ``f`` and ``g``. """ + return f.per(dmp_exquo(f._rep, g._rep, f.lev, f.dom)) + + def _degree(f, j=0): + """Returns the leading degree of ``f`` in ``x_j``. """ + return dmp_degree_in(f._rep, j, f.lev) + + def degree_list(f): + """Returns a list of degrees of ``f``. """ + return dmp_degree_list(f._rep, f.lev) + + def total_degree(f): + """Returns the total degree of ``f``. """ + return max(sum(m) for m in f.monoms()) + + def LC(f): + """Returns the leading coefficient of ``f``. """ + return dmp_ground_LC(f._rep, f.lev, f.dom) + + def TC(f): + """Returns the trailing coefficient of ``f``. """ + return dmp_ground_TC(f._rep, f.lev, f.dom) + + def _nth(f, N): + """Returns the ``n``-th coefficient of ``f``. """ + return dmp_ground_nth(f._rep, N, f.lev, f.dom) + + def max_norm(f): + """Returns maximum norm of ``f``. """ + return dmp_max_norm(f._rep, f.lev, f.dom) + + def l1_norm(f): + """Returns l1 norm of ``f``. """ + return dmp_l1_norm(f._rep, f.lev, f.dom) + + def l2_norm_squared(f): + """Return squared l2 norm of ``f``. """ + return dmp_l2_norm_squared(f._rep, f.lev, f.dom) + + def clear_denoms(f): + """Clear denominators, but keep the ground domain. """ + coeff, F = dmp_clear_denoms(f._rep, f.lev, f.dom) + return coeff, f.per(F) + + def _integrate(f, m=1, j=0): + """Computes the ``m``-th order indefinite integral of ``f`` in ``x_j``. """ + return f.per(dmp_integrate_in(f._rep, m, j, f.lev, f.dom)) + + def _diff(f, m=1, j=0): + """Computes the ``m``-th order derivative of ``f`` in ``x_j``. """ + return f.per(dmp_diff_in(f._rep, m, j, f.lev, f.dom)) + + def _eval(f, a): + return dmp_eval_in(f._rep, f.dom.convert(a), 0, f.lev, f.dom) + + def _eval_lev(f, a, j): + rep = dmp_eval_in(f._rep, f.dom.convert(a), j, f.lev, f.dom) + return f.new(rep, f.dom, f.lev - 1) + + def _half_gcdex(f, g): + """Half extended Euclidean algorithm, if univariate. """ + s, h = dup_half_gcdex(f._rep, g._rep, f.dom) + return f.per(s), f.per(h) + + def _gcdex(f, g): + """Extended Euclidean algorithm, if univariate. """ + s, t, h = dup_gcdex(f._rep, g._rep, f.dom) + return f.per(s), f.per(t), f.per(h) + + def _invert(f, g): + """Invert ``f`` modulo ``g``, if possible. """ + s = dup_invert(f._rep, g._rep, f.dom) + return f.per(s) + + def _revert(f, n): + """Compute ``f**(-1)`` mod ``x**n``. """ + return f.per(dup_revert(f._rep, n, f.dom)) + + def _subresultants(f, g): + """Computes subresultant PRS sequence of ``f`` and ``g``. """ + R = dmp_subresultants(f._rep, g._rep, f.lev, f.dom) + return list(map(f.per, R)) + + def _resultant_includePRS(f, g): + """Computes resultant of ``f`` and ``g`` via PRS. """ + res, R = dmp_resultant(f._rep, g._rep, f.lev, f.dom, includePRS=True) + if f.lev: + res = f.new(res, f.dom, f.lev - 1) + return res, list(map(f.per, R)) + + def _resultant(f, g): + res = dmp_resultant(f._rep, g._rep, f.lev, f.dom) + if f.lev: + res = f.new(res, f.dom, f.lev - 1) + return res + + def discriminant(f): + """Computes discriminant of ``f``. """ + res = dmp_discriminant(f._rep, f.lev, f.dom) + if f.lev: + res = f.new(res, f.dom, f.lev - 1) + return res + + def _cofactors(f, g): + """Returns GCD of ``f`` and ``g`` and their cofactors. """ + h, cff, cfg = dmp_inner_gcd(f._rep, g._rep, f.lev, f.dom) + return f.per(h), f.per(cff), f.per(cfg) + + def _gcd(f, g): + """Returns polynomial GCD of ``f`` and ``g``. """ + return f.per(dmp_gcd(f._rep, g._rep, f.lev, f.dom)) + + def _lcm(f, g): + """Returns polynomial LCM of ``f`` and ``g``. """ + return f.per(dmp_lcm(f._rep, g._rep, f.lev, f.dom)) + + def _cancel(f, g): + """Cancel common factors in a rational function ``f/g``. """ + cF, cG, F, G = dmp_cancel(f._rep, g._rep, f.lev, f.dom, include=False) + return cF, cG, f.per(F), f.per(G) + + def _cancel_include(f, g): + """Cancel common factors in a rational function ``f/g``. """ + F, G = dmp_cancel(f._rep, g._rep, f.lev, f.dom, include=True) + return f.per(F), f.per(G) + + def _trunc(f, p): + """Reduce ``f`` modulo a constant ``p``. """ + return f.per(dmp_ground_trunc(f._rep, p, f.lev, f.dom)) + + def monic(f): + """Divides all coefficients by ``LC(f)``. """ + return f.per(dmp_ground_monic(f._rep, f.lev, f.dom)) + + def content(f): + """Returns GCD of polynomial coefficients. """ + return dmp_ground_content(f._rep, f.lev, f.dom) + + def primitive(f): + """Returns content and a primitive form of ``f``. """ + cont, F = dmp_ground_primitive(f._rep, f.lev, f.dom) + return cont, f.per(F) + + def _compose(f, g): + """Computes functional composition of ``f`` and ``g``. """ + return f.per(dmp_compose(f._rep, g._rep, f.lev, f.dom)) + + def _decompose(f): + """Computes functional decomposition of ``f``. """ + return list(map(f.per, dup_decompose(f._rep, f.dom))) + + def _shift(f, a): + """Efficiently compute Taylor shift ``f(x + a)``. """ + return f.per(dup_shift(f._rep, a, f.dom)) + + def _shift_list(f, a): + """Efficiently compute Taylor shift ``f(X + A)``. """ + return f.per(dmp_shift(f._rep, a, f.lev, f.dom)) + + def _transform(f, p, q): + """Evaluate functional transformation ``q**n * f(p/q)``.""" + return f.per(dup_transform(f._rep, p._rep, q._rep, f.dom)) + + def _sturm(f): + """Computes the Sturm sequence of ``f``. """ + return list(map(f.per, dup_sturm(f._rep, f.dom))) + + def _cauchy_upper_bound(f): + """Computes the Cauchy upper bound on the roots of ``f``. """ + return dup_cauchy_upper_bound(f._rep, f.dom) + + def _cauchy_lower_bound(f): + """Computes the Cauchy lower bound on the nonzero roots of ``f``. """ + return dup_cauchy_lower_bound(f._rep, f.dom) + + def _mignotte_sep_bound_squared(f): + """Computes the squared Mignotte bound on root separations of ``f``. """ + return dup_mignotte_sep_bound_squared(f._rep, f.dom) + + def _gff_list(f): + """Computes greatest factorial factorization of ``f``. """ + return [ (f.per(g), k) for g, k in dup_gff_list(f._rep, f.dom) ] + + def norm(f): + """Computes ``Norm(f)``.""" + r = dmp_norm(f._rep, f.lev, f.dom) + return f.new(r, f.dom.dom, f.lev) + + def sqf_norm(f): + """Computes square-free norm of ``f``. """ + s, g, r = dmp_sqf_norm(f._rep, f.lev, f.dom) + return s, f.per(g), f.new(r, f.dom.dom, f.lev) + + def sqf_part(f): + """Computes square-free part of ``f``. """ + return f.per(dmp_sqf_part(f._rep, f.lev, f.dom)) + + def sqf_list(f, all=False): + """Returns a list of square-free factors of ``f``. """ + coeff, factors = dmp_sqf_list(f._rep, f.lev, f.dom, all) + return coeff, [ (f.per(g), k) for g, k in factors ] + + def sqf_list_include(f, all=False): + """Returns a list of square-free factors of ``f``. """ + factors = dmp_sqf_list_include(f._rep, f.lev, f.dom, all) + return [ (f.per(g), k) for g, k in factors ] + + def factor_list(f): + """Returns a list of irreducible factors of ``f``. """ + coeff, factors = dmp_factor_list(f._rep, f.lev, f.dom) + return coeff, [ (f.per(g), k) for g, k in factors ] + + def factor_list_include(f): + """Returns a list of irreducible factors of ``f``. """ + factors = dmp_factor_list_include(f._rep, f.lev, f.dom) + return [ (f.per(g), k) for g, k in factors ] + + def _isolate_real_roots(f, eps, inf, sup, fast): + return dup_isolate_real_roots(f._rep, f.dom, eps=eps, inf=inf, sup=sup, fast=fast) + + def _isolate_real_roots_sqf(f, eps, inf, sup, fast): + return dup_isolate_real_roots_sqf(f._rep, f.dom, eps=eps, inf=inf, sup=sup, fast=fast) + + def _isolate_all_roots(f, eps, inf, sup, fast): + return dup_isolate_all_roots(f._rep, f.dom, eps=eps, inf=inf, sup=sup, fast=fast) + + def _isolate_all_roots_sqf(f, eps, inf, sup, fast): + return dup_isolate_all_roots_sqf(f._rep, f.dom, eps=eps, inf=inf, sup=sup, fast=fast) + + def _refine_real_root(f, s, t, eps, steps, fast): + return dup_refine_real_root(f._rep, s, t, f.dom, eps=eps, steps=steps, fast=fast) + + def count_real_roots(f, inf=None, sup=None): + """Return the number of real roots of ``f`` in ``[inf, sup]``. """ + return dup_count_real_roots(f._rep, f.dom, inf=inf, sup=sup) + + def count_complex_roots(f, inf=None, sup=None): + """Return the number of complex roots of ``f`` in ``[inf, sup]``. """ + return dup_count_complex_roots(f._rep, f.dom, inf=inf, sup=sup) + + @property + def is_zero(f): + """Returns ``True`` if ``f`` is a zero polynomial. """ + return dmp_zero_p(f._rep, f.lev) + + @property + def is_one(f): + """Returns ``True`` if ``f`` is a unit polynomial. """ + return dmp_one_p(f._rep, f.lev, f.dom) + + @property + def is_ground(f): + """Returns ``True`` if ``f`` is an element of the ground domain. """ + return dmp_ground_p(f._rep, None, f.lev) + + @property + def is_sqf(f): + """Returns ``True`` if ``f`` is a square-free polynomial. """ + return dmp_sqf_p(f._rep, f.lev, f.dom) + + @property + def is_monic(f): + """Returns ``True`` if the leading coefficient of ``f`` is one. """ + return f.dom.is_one(dmp_ground_LC(f._rep, f.lev, f.dom)) + + @property + def is_primitive(f): + """Returns ``True`` if the GCD of the coefficients of ``f`` is one. """ + return f.dom.is_one(dmp_ground_content(f._rep, f.lev, f.dom)) + + @property + def is_linear(f): + """Returns ``True`` if ``f`` is linear in all its variables. """ + return all(sum(monom) <= 1 for monom in dmp_to_dict(f._rep, f.lev, f.dom).keys()) + + @property + def is_quadratic(f): + """Returns ``True`` if ``f`` is quadratic in all its variables. """ + return all(sum(monom) <= 2 for monom in dmp_to_dict(f._rep, f.lev, f.dom).keys()) + + @property + def is_monomial(f): + """Returns ``True`` if ``f`` is zero or has only one term. """ + return len(f.to_dict()) <= 1 + + @property + def is_homogeneous(f): + """Returns ``True`` if ``f`` is a homogeneous polynomial. """ + return f.homogeneous_order() is not None + + @property + def is_irreducible(f): + """Returns ``True`` if ``f`` has no factors over its domain. """ + return dmp_irreducible_p(f._rep, f.lev, f.dom) + + @property + def is_cyclotomic(f): + """Returns ``True`` if ``f`` is a cyclotomic polynomial. """ + if not f.lev: + return dup_cyclotomic_p(f._rep, f.dom) + else: + return False + + +class DUP_Flint(DMP): + """Dense Multivariate Polynomials over `K`. """ + + lev = 0 + + __slots__ = ('_rep', 'dom', '_cls') + + def __reduce__(self): + return self.__class__, (self.to_list(), self.dom, self.lev) + + @classmethod + def _new(cls, rep, dom, lev): + rep = cls._flint_poly(rep[::-1], dom, lev) + return cls.from_rep(rep, dom) + + def to_list(f): + """Convert ``f`` to a list representation with native coefficients. """ + return f._rep.coeffs()[::-1] + + @classmethod + def _flint_poly(cls, rep, dom, lev): + assert _supported_flint_domain(dom) + assert lev == 0 + flint_cls = cls._get_flint_poly_cls(dom) + return flint_cls(rep) + + @classmethod + def _get_flint_poly_cls(cls, dom): + if dom.is_ZZ: + return flint.fmpz_poly + elif dom.is_QQ: + return flint.fmpq_poly + elif dom.is_FF: + return dom._poly_ctx + else: + raise RuntimeError("Domain %s is not supported with flint" % dom) + + @classmethod + def from_rep(cls, rep, dom): + """Create a DMP from the given representation. """ + + if dom.is_ZZ: + assert isinstance(rep, flint.fmpz_poly) + _cls = flint.fmpz_poly + elif dom.is_QQ: + assert isinstance(rep, flint.fmpq_poly) + _cls = flint.fmpq_poly + elif dom.is_FF: + assert isinstance(rep, (flint.nmod_poly, flint.fmpz_mod_poly)) + c = dom.characteristic() + __cls = type(rep) + _cls = lambda e: __cls(e, c) + else: + raise RuntimeError("Domain %s is not supported with flint" % dom) + + obj = object.__new__(cls) + obj.dom = dom + obj._rep = rep + obj._cls = _cls + + return obj + + def _strict_eq(f, g): + if type(f) != type(g): + return False + return f.dom == g.dom and f._rep == g._rep + + def ground_new(f, coeff): + """Construct a new ground instance of ``f``. """ + return f.from_rep(f._cls([coeff]), f.dom) + + def _one(f): + return f.ground_new(f.dom.one) + + def unify(f, g): + """Unify representations of two polynomials. """ + raise RuntimeError + + def to_DMP_Python(f): + """Convert ``f`` to a Python native representation. """ + return DMP_Python._new(f.to_list(), f.dom, f.lev) + + def to_tuple(f): + """Convert ``f`` to a tuple representation with native coefficients. """ + return tuple(f.to_list()) + + def _convert(f, dom): + """Convert the ground domain of ``f``. """ + if dom == QQ and f.dom == ZZ: + return f.from_rep(flint.fmpq_poly(f._rep), dom) + elif _supported_flint_domain(dom) and _supported_flint_domain(f.dom): + # XXX: python-flint should provide a faster way to do this. + return f.to_DMP_Python()._convert(dom).to_DUP_Flint() + else: + raise RuntimeError(f"DUP_Flint: Cannot convert {f.dom} to {dom}") + + def _slice(f, m, n): + """Take a continuous subsequence of terms of ``f``. """ + coeffs = f._rep.coeffs()[m:n] + return f.from_rep(f._cls(coeffs), f.dom) + + def _slice_lev(f, m, n, j): + """Take a continuous subsequence of terms of ``f``. """ + # Only makes sense for multivariate polynomials + raise NotImplementedError + + def _terms(f, order=None): + """Returns all non-zero terms from ``f`` in lex order. """ + if order is None or order.alias == 'lex': + terms = [ ((n,), c) for n, c in enumerate(f._rep.coeffs()) if c ] + return terms[::-1] + else: + # XXX: InverseOrder (ilex) comes here. We could handle that case + # efficiently by reversing the coefficients but it is not clear + # how to test if the order is InverseOrder. + # + # Otherwise why would the order ever be different for univariate + # polynomials? + return f.to_DMP_Python()._terms(order=order) + + def _lift(f): + """Convert algebraic coefficients to rationals. """ + # This is for algebraic number fields which DUP_Flint does not support + raise NotImplementedError + + def deflate(f): + """Reduce degree of `f` by mapping `x_i^m` to `y_i`. """ + # XXX: Check because otherwise this segfaults with python-flint: + # + # >>> flint.fmpz_poly([]).deflation() + # Exception (fmpz_poly_deflate). Division by zero. + # Aborted (core dumped + # + if f.is_zero: + return (1,), f + g, n = f._rep.deflation() + return (n,), f.from_rep(g, f.dom) + + def inject(f, front=False): + """Inject ground domain generators into ``f``. """ + # Ground domain would need to be a poly ring + raise NotImplementedError + + def eject(f, dom, front=False): + """Eject selected generators into the ground domain. """ + # Only makes sense for multivariate polynomials + raise NotImplementedError + + def _exclude(f): + """Remove useless generators from ``f``. """ + # Only makes sense for multivariate polynomials + raise NotImplementedError + + def _permute(f, P): + """Returns a polynomial in `K[x_{P(1)}, ..., x_{P(n)}]`. """ + # Only makes sense for multivariate polynomials + raise NotImplementedError + + def terms_gcd(f): + """Remove GCD of terms from the polynomial ``f``. """ + # XXX: python-flint should have primitive, content, etc methods. + J, F = f.to_DMP_Python().terms_gcd() + return J, F.to_DUP_Flint() + + def _add_ground(f, c): + """Add an element of the ground domain to ``f``. """ + return f.from_rep(f._rep + c, f.dom) + + def _sub_ground(f, c): + """Subtract an element of the ground domain from ``f``. """ + return f.from_rep(f._rep - c, f.dom) + + def _mul_ground(f, c): + """Multiply ``f`` by a an element of the ground domain. """ + return f.from_rep(f._rep * c, f.dom) + + def _quo_ground(f, c): + """Quotient of ``f`` by a an element of the ground domain. """ + return f.from_rep(f._rep // c, f.dom) + + def _exquo_ground(f, c): + """Exact quotient of ``f`` by an element of the ground domain. """ + q, r = divmod(f._rep, c) + if r: + raise ExactQuotientFailed(f, c) + return f.from_rep(q, f.dom) + + def abs(f): + """Make all coefficients in ``f`` positive. """ + return f.to_DMP_Python().abs().to_DUP_Flint() + + def neg(f): + """Negate all coefficients in ``f``. """ + return f.from_rep(-f._rep, f.dom) + + def _add(f, g): + """Add two multivariate polynomials ``f`` and ``g``. """ + return f.from_rep(f._rep + g._rep, f.dom) + + def _sub(f, g): + """Subtract two multivariate polynomials ``f`` and ``g``. """ + return f.from_rep(f._rep - g._rep, f.dom) + + def _mul(f, g): + """Multiply two multivariate polynomials ``f`` and ``g``. """ + return f.from_rep(f._rep * g._rep, f.dom) + + def sqr(f): + """Square a multivariate polynomial ``f``. """ + return f.from_rep(f._rep ** 2, f.dom) + + def _pow(f, n): + """Raise ``f`` to a non-negative power ``n``. """ + return f.from_rep(f._rep ** n, f.dom) + + def _pdiv(f, g): + """Polynomial pseudo-division of ``f`` and ``g``. """ + d = f.degree() - g.degree() + 1 + q, r = divmod(g.LC()**d * f._rep, g._rep) + return f.from_rep(q, f.dom), f.from_rep(r, f.dom) + + def _prem(f, g): + """Polynomial pseudo-remainder of ``f`` and ``g``. """ + d = f.degree() - g.degree() + 1 + q = (g.LC()**d * f._rep) % g._rep + return f.from_rep(q, f.dom) + + def _pquo(f, g): + """Polynomial pseudo-quotient of ``f`` and ``g``. """ + d = f.degree() - g.degree() + 1 + r = (g.LC()**d * f._rep) // g._rep + return f.from_rep(r, f.dom) + + def _pexquo(f, g): + """Polynomial exact pseudo-quotient of ``f`` and ``g``. """ + d = f.degree() - g.degree() + 1 + q, r = divmod(g.LC()**d * f._rep, g._rep) + if r: + raise ExactQuotientFailed(f, g) + return f.from_rep(q, f.dom) + + def _div(f, g): + """Polynomial division with remainder of ``f`` and ``g``. """ + if f.dom.is_Field: + q, r = divmod(f._rep, g._rep) + return f.from_rep(q, f.dom), f.from_rep(r, f.dom) + else: + # XXX: python-flint defines division in ZZ[x] differently + q, r = f.to_DMP_Python()._div(g.to_DMP_Python()) + return q.to_DUP_Flint(), r.to_DUP_Flint() + + def _rem(f, g): + """Computes polynomial remainder of ``f`` and ``g``. """ + return f.from_rep(f._rep % g._rep, f.dom) + + def _quo(f, g): + """Computes polynomial quotient of ``f`` and ``g``. """ + return f.from_rep(f._rep // g._rep, f.dom) + + def _exquo(f, g): + """Computes polynomial exact quotient of ``f`` and ``g``. """ + q, r = f._div(g) + if r: + raise ExactQuotientFailed(f, g) + return q + + def _degree(f, j=0): + """Returns the leading degree of ``f`` in ``x_j``. """ + d = f._rep.degree() + if d == -1: + d = ninf + return d + + def degree_list(f): + """Returns a list of degrees of ``f``. """ + return ( f._degree() ,) + + def total_degree(f): + """Returns the total degree of ``f``. """ + return f._degree() + + def LC(f): + """Returns the leading coefficient of ``f``. """ + return f._rep[f._rep.degree()] + + def TC(f): + """Returns the trailing coefficient of ``f``. """ + return f._rep[0] + + def _nth(f, N): + """Returns the ``n``-th coefficient of ``f``. """ + [n] = N + return f._rep[n] + + def max_norm(f): + """Returns maximum norm of ``f``. """ + return f.to_DMP_Python().max_norm() + + def l1_norm(f): + """Returns l1 norm of ``f``. """ + return f.to_DMP_Python().l1_norm() + + def l2_norm_squared(f): + """Return squared l2 norm of ``f``. """ + return f.to_DMP_Python().l2_norm_squared() + + def clear_denoms(f): + """Clear denominators, but keep the ground domain. """ + R = f.dom + if R.is_QQ: + denom = f._rep.denom() + numer = f.from_rep(f._cls(f._rep.numer()), f.dom) + return denom, numer + elif R.is_ZZ or R.is_FiniteField: + return R.one, f + else: + raise NotImplementedError + + def _integrate(f, m=1, j=0): + """Computes the ``m``-th order indefinite integral of ``f`` in ``x_j``. """ + assert j == 0 + if f.dom.is_Field: + rep = f._rep + for i in range(m): + rep = rep.integral() + return f.from_rep(rep, f.dom) + else: + return f.to_DMP_Python()._integrate(m=m, j=j).to_DUP_Flint() + + def _diff(f, m=1, j=0): + """Computes the ``m``-th order derivative of ``f``. """ + assert j == 0 + rep = f._rep + for i in range(m): + rep = rep.derivative() + return f.from_rep(rep, f.dom) + + def _eval(f, a): + # XXX: This method is called with many different input types. Ideally + # we could use e.g. fmpz_poly.__call__ here but more thought needs to + # go into which types this is supposed to be called with and what types + # it should return. + return f.to_DMP_Python()._eval(a) + + def _eval_lev(f, a, j): + # Only makes sense for multivariate polynomials + raise NotImplementedError + + def _half_gcdex(f, g): + """Half extended Euclidean algorithm. """ + s, h = f.to_DMP_Python()._half_gcdex(g.to_DMP_Python()) + return s.to_DUP_Flint(), h.to_DUP_Flint() + + def _gcdex(f, g): + """Extended Euclidean algorithm. """ + h, s, t = f._rep.xgcd(g._rep) + return f.from_rep(s, f.dom), f.from_rep(t, f.dom), f.from_rep(h, f.dom) + + def _invert(f, g): + """Invert ``f`` modulo ``g``, if possible. """ + R = f.dom + if R.is_Field: + gcd, F_inv, _ = f._rep.xgcd(g._rep) + # XXX: Should be gcd != 1 but nmod_poly does not compare equal to + # other types. + if gcd != 0*gcd + 1: + raise NotInvertible("zero divisor") + return f.from_rep(F_inv, R) + else: + # fmpz_poly does not have xgcd or invert and this is not well + # defined in general. + return f.to_DMP_Python()._invert(g.to_DMP_Python()).to_DUP_Flint() + + def _revert(f, n): + """Compute ``f**(-1)`` mod ``x**n``. """ + # XXX: Use fmpz_series etc for reversion? + # Maybe python-flint should provide revert for fmpz_poly... + return f.to_DMP_Python()._revert(n).to_DUP_Flint() + + def _subresultants(f, g): + """Computes subresultant PRS sequence of ``f`` and ``g``. """ + # XXX: Maybe _fmpz_poly_pseudo_rem_cohen could be used... + R = f.to_DMP_Python()._subresultants(g.to_DMP_Python()) + return [ g.to_DUP_Flint() for g in R ] + + def _resultant_includePRS(f, g): + """Computes resultant of ``f`` and ``g`` via PRS. """ + # XXX: Maybe _fmpz_poly_pseudo_rem_cohen could be used... + res, R = f.to_DMP_Python()._resultant_includePRS(g.to_DMP_Python()) + return res, [ g.to_DUP_Flint() for g in R ] + + def _resultant(f, g): + """Computes resultant of ``f`` and ``g``. """ + # XXX: Use fmpz_mpoly etc when possible... + return f.to_DMP_Python()._resultant(g.to_DMP_Python()) + + def discriminant(f): + """Computes discriminant of ``f``. """ + # XXX: Use fmpz_mpoly etc when possible... + return f.to_DMP_Python().discriminant() + + def _cofactors(f, g): + """Returns GCD of ``f`` and ``g`` and their cofactors. """ + h = f.gcd(g) + return h, f.exquo(h), g.exquo(h) + + def _gcd(f, g): + """Returns polynomial GCD of ``f`` and ``g``. """ + return f.from_rep(f._rep.gcd(g._rep), f.dom) + + def _lcm(f, g): + """Returns polynomial LCM of ``f`` and ``g``. """ + # XXX: python-flint should have a lcm method + if not (f and g): + return f.ground_new(f.dom.zero) + + l = f._mul(g)._exquo(f._gcd(g)) + + if l.dom.is_Field: + l = l.monic() + elif l.LC() < 0: + l = l.neg() + + return l + + def _cancel(f, g): + """Cancel common factors in a rational function ``f/g``. """ + assert f.dom == g.dom + R = f.dom + + # Think carefully about how to handle denominators and coefficient + # canonicalisation if more domains are permitted... + assert R.is_ZZ or R.is_QQ or R.is_FiniteField + + if R.is_FiniteField: + h = f._gcd(g) + F, G = f.exquo(h), g.exquo(h) + return R.one, R.one, F, G + + if R.is_QQ: + cG, F = f.clear_denoms() + cF, G = g.clear_denoms() + else: + cG, F = R.one, f + cF, G = R.one, g + + cH = cF.gcd(cG) + cF, cG = cF // cH, cG // cH + + H = F._gcd(G) + F, G = F.exquo(H), G.exquo(H) + + f_neg = F.LC() < 0 + g_neg = G.LC() < 0 + + if f_neg and g_neg: + F, G = F.neg(), G.neg() + elif f_neg: + cF, F = -cF, F.neg() + elif g_neg: + cF, G = -cF, G.neg() + + return cF, cG, F, G + + def _cancel_include(f, g): + """Cancel common factors in a rational function ``f/g``. """ + cF, cG, F, G = f._cancel(g) + return F._mul_ground(cF), G._mul_ground(cG) + + def _trunc(f, p): + """Reduce ``f`` modulo a constant ``p``. """ + return f.to_DMP_Python()._trunc(p).to_DUP_Flint() + + def monic(f): + """Divides all coefficients by ``LC(f)``. """ + # XXX: python-flint should add monic + return f._exquo_ground(f.LC()) + + def content(f): + """Returns GCD of polynomial coefficients. """ + # XXX: python-flint should have a content method + return f.to_DMP_Python().content() + + def primitive(f): + """Returns content and a primitive form of ``f``. """ + cont = f.content() + if f.is_zero: + return f.dom.zero, f + prim = f._exquo_ground(cont) + return cont, prim + + def _compose(f, g): + """Computes functional composition of ``f`` and ``g``. """ + return f.from_rep(f._rep(g._rep), f.dom) + + def _decompose(f): + """Computes functional decomposition of ``f``. """ + return [ g.to_DUP_Flint() for g in f.to_DMP_Python()._decompose() ] + + def _shift(f, a): + """Efficiently compute Taylor shift ``f(x + a)``. """ + x_plus_a = f._cls([a, f.dom.one]) + return f.from_rep(f._rep(x_plus_a), f.dom) + + def _transform(f, p, q): + """Evaluate functional transformation ``q**n * f(p/q)``.""" + F, P, Q = f.to_DMP_Python(), p.to_DMP_Python(), q.to_DMP_Python() + return F.transform(P, Q).to_DUP_Flint() + + def _sturm(f): + """Computes the Sturm sequence of ``f``. """ + return [ g.to_DUP_Flint() for g in f.to_DMP_Python()._sturm() ] + + def _cauchy_upper_bound(f): + """Computes the Cauchy upper bound on the roots of ``f``. """ + return f.to_DMP_Python()._cauchy_upper_bound() + + def _cauchy_lower_bound(f): + """Computes the Cauchy lower bound on the nonzero roots of ``f``. """ + return f.to_DMP_Python()._cauchy_lower_bound() + + def _mignotte_sep_bound_squared(f): + """Computes the squared Mignotte bound on root separations of ``f``. """ + return f.to_DMP_Python()._mignotte_sep_bound_squared() + + def _gff_list(f): + """Computes greatest factorial factorization of ``f``. """ + F = f.to_DMP_Python() + return [ (g.to_DUP_Flint(), k) for g, k in F.gff_list() ] + + def norm(f): + """Computes ``Norm(f)``.""" + # This is for algebraic number fields which DUP_Flint does not support + raise NotImplementedError + + def sqf_norm(f): + """Computes square-free norm of ``f``. """ + # This is for algebraic number fields which DUP_Flint does not support + raise NotImplementedError + + def sqf_part(f): + """Computes square-free part of ``f``. """ + return f._exquo(f._gcd(f._diff())) + + def sqf_list(f, all=False): + """Returns a list of square-free factors of ``f``. """ + # XXX: python-flint should provide square free factorisation. + coeff, factors = f.to_DMP_Python().sqf_list(all=all) + return coeff, [ (g.to_DUP_Flint(), k) for g, k in factors ] + + def sqf_list_include(f, all=False): + """Returns a list of square-free factors of ``f``. """ + factors = f.to_DMP_Python().sqf_list_include(all=all) + return [ (g.to_DUP_Flint(), k) for g, k in factors ] + + def factor_list(f): + """Returns a list of irreducible factors of ``f``. """ + + if f.dom.is_ZZ or f.dom.is_FF: + # python-flint matches polys here + coeff, factors = f._rep.factor() + factors = [ (f.from_rep(g, f.dom), k) for g, k in factors ] + + elif f.dom.is_QQ: + # python-flint returns monic factors over QQ whereas polys returns + # denominator free factors. + coeff, factors = f._rep.factor() + factors_monic = [ (f.from_rep(g, f.dom), k) for g, k in factors ] + + # Absorb the denominators into coeff + factors = [] + for g, k in factors_monic: + d, g = g.clear_denoms() + coeff /= d**k + factors.append((g, k)) + + else: + # Check carefully when adding more domains here... + raise RuntimeError("Domain %s is not supported with flint" % f.dom) + + # We need to match the way that polys orders the factors + factors = f._sort_factors(factors) + + return coeff, factors + + def factor_list_include(f): + """Returns a list of irreducible factors of ``f``. """ + # XXX: factor_list_include seems to be broken in general: + # + # >>> Poly(2*(x - 1)**3, x).factor_list_include() + # [(Poly(2*x - 2, x, domain='ZZ'), 3)] + # + # Let's not try to implement it here. + factors = f.to_DMP_Python().factor_list_include() + return [ (g.to_DUP_Flint(), k) for g, k in factors ] + + def _sort_factors(f, factors): + """Sort a list of factors to canonical order. """ + # Convert the factors to lists and use _sort_factors from polys + factors = [ (g.to_list(), k) for g, k in factors ] + factors = _sort_factors(factors, multiple=True) + to_dup_flint = lambda g: f.from_rep(f._cls(g[::-1]), f.dom) + return [ (to_dup_flint(g), k) for g, k in factors ] + + def _isolate_real_roots(f, eps, inf, sup, fast): + return f.to_DMP_Python()._isolate_real_roots(eps, inf, sup, fast) + + def _isolate_real_roots_sqf(f, eps, inf, sup, fast): + return f.to_DMP_Python()._isolate_real_roots_sqf(eps, inf, sup, fast) + + def _isolate_all_roots(f, eps, inf, sup, fast): + # fmpz_poly and fmpq_poly have a complex_roots method that could be + # used here. It probably makes more sense to add analogous methods in + # python-flint though. + return f.to_DMP_Python()._isolate_all_roots(eps, inf, sup, fast) + + def _isolate_all_roots_sqf(f, eps, inf, sup, fast): + return f.to_DMP_Python()._isolate_all_roots_sqf(eps, inf, sup, fast) + + def _refine_real_root(f, s, t, eps, steps, fast): + return f.to_DMP_Python()._refine_real_root(s, t, eps, steps, fast) + + def count_real_roots(f, inf=None, sup=None): + """Return the number of real roots of ``f`` in ``[inf, sup]``. """ + return f.to_DMP_Python().count_real_roots(inf=inf, sup=sup) + + def count_complex_roots(f, inf=None, sup=None): + """Return the number of complex roots of ``f`` in ``[inf, sup]``. """ + return f.to_DMP_Python().count_complex_roots(inf=inf, sup=sup) + + @property + def is_zero(f): + """Returns ``True`` if ``f`` is a zero polynomial. """ + return not f._rep + + @property + def is_one(f): + """Returns ``True`` if ``f`` is a unit polynomial. """ + return f._rep == f.dom.one + + @property + def is_ground(f): + """Returns ``True`` if ``f`` is an element of the ground domain. """ + return f._rep.degree() <= 0 + + @property + def is_linear(f): + """Returns ``True`` if ``f`` is linear in all its variables. """ + return f._rep.degree() <= 1 + + @property + def is_quadratic(f): + """Returns ``True`` if ``f`` is quadratic in all its variables. """ + return f._rep.degree() <= 2 + + @property + def is_monomial(f): + """Returns ``True`` if ``f`` is zero or has only one term. """ + fr = f._rep + return fr.degree() < 0 or not any(fr[n] for n in range(fr.degree())) + + @property + def is_monic(f): + """Returns ``True`` if the leading coefficient of ``f`` is one. """ + return f.LC() == f.dom.one + + @property + def is_primitive(f): + """Returns ``True`` if the GCD of the coefficients of ``f`` is one. """ + return f.to_DMP_Python().is_primitive + + @property + def is_homogeneous(f): + """Returns ``True`` if ``f`` is a homogeneous polynomial. """ + return f.to_DMP_Python().is_homogeneous + + @property + def is_sqf(f): + """Returns ``True`` if ``f`` is a square-free polynomial. """ + g = f._rep.gcd(f._rep.derivative()) + return g.degree() <= 0 + + @property + def is_irreducible(f): + """Returns ``True`` if ``f`` has no factors over its domain. """ + _, factors = f._rep.factor() + if len(factors) == 0: + return True + elif len(factors) == 1: + return factors[0][1] == 1 + else: + return False + + @property + def is_cyclotomic(f): + """Returns ``True`` if ``f`` is a cyclotomic polynomial. """ + if f.dom.is_QQ: + try: + f = f.convert(ZZ) + except CoercionFailed: + return False + if f.dom.is_ZZ: + return bool(f._rep.is_cyclotomic()) + else: + # This is what dup_cyclotomic_p does... + return False + + +def init_normal_DMF(num, den, lev, dom): + return DMF(dmp_normal(num, lev, dom), + dmp_normal(den, lev, dom), dom, lev) + + +class DMF(PicklableWithSlots, CantSympify): + """Dense Multivariate Fractions over `K`. """ + + __slots__ = ('num', 'den', 'lev', 'dom') + + def __init__(self, rep, dom, lev=None): + num, den, lev = self._parse(rep, dom, lev) + num, den = dmp_cancel(num, den, lev, dom) + + self.num = num + self.den = den + self.lev = lev + self.dom = dom + + @classmethod + def new(cls, rep, dom, lev=None): + num, den, lev = cls._parse(rep, dom, lev) + + obj = object.__new__(cls) + + obj.num = num + obj.den = den + obj.lev = lev + obj.dom = dom + + return obj + + def ground_new(self, rep): + return self.new(rep, self.dom, self.lev) + + @classmethod + def _parse(cls, rep, dom, lev=None): + if isinstance(rep, tuple): + num, den = rep + + if lev is not None: + if isinstance(num, dict): + num = dmp_from_dict(num, lev, dom) + + if isinstance(den, dict): + den = dmp_from_dict(den, lev, dom) + else: + num, num_lev = dmp_validate(num) + den, den_lev = dmp_validate(den) + + if num_lev == den_lev: + lev = num_lev + else: + raise ValueError('inconsistent number of levels') + + if dmp_zero_p(den, lev): + raise ZeroDivisionError('fraction denominator') + + if dmp_zero_p(num, lev): + den = dmp_one(lev, dom) + else: + if dmp_negative_p(den, lev, dom): + num = dmp_neg(num, lev, dom) + den = dmp_neg(den, lev, dom) + else: + num = rep + + if lev is not None: + if isinstance(num, dict): + num = dmp_from_dict(num, lev, dom) + elif not isinstance(num, list): + num = dmp_ground(dom.convert(num), lev) + else: + num, lev = dmp_validate(num) + + den = dmp_one(lev, dom) + + return num, den, lev + + def __repr__(f): + return "%s((%s, %s), %s)" % (f.__class__.__name__, f.num, f.den, f.dom) + + def __hash__(f): + return hash((f.__class__.__name__, dmp_to_tuple(f.num, f.lev), + dmp_to_tuple(f.den, f.lev), f.lev, f.dom)) + + def poly_unify(f, g): + """Unify a multivariate fraction and a polynomial. """ + if not isinstance(g, DMP) or f.lev != g.lev: + raise UnificationFailed("Cannot unify %s with %s" % (f, g)) + + if f.dom == g.dom: + return (f.lev, f.dom, f.per, (f.num, f.den), g._rep) + else: + lev, dom = f.lev, f.dom.unify(g.dom) + + F = (dmp_convert(f.num, lev, f.dom, dom), + dmp_convert(f.den, lev, f.dom, dom)) + + G = dmp_convert(g._rep, lev, g.dom, dom) + + def per(num, den, cancel=True, kill=False, lev=lev): + if kill: + if not lev: + return num/den + else: + lev = lev - 1 + + if cancel: + num, den = dmp_cancel(num, den, lev, dom) + + return f.__class__.new((num, den), dom, lev) + + return lev, dom, per, F, G + + def frac_unify(f, g): + """Unify representations of two multivariate fractions. """ + if not isinstance(g, DMF) or f.lev != g.lev: + raise UnificationFailed("Cannot unify %s with %s" % (f, g)) + + if f.dom == g.dom: + return (f.lev, f.dom, f.per, (f.num, f.den), + (g.num, g.den)) + else: + lev, dom = f.lev, f.dom.unify(g.dom) + + F = (dmp_convert(f.num, lev, f.dom, dom), + dmp_convert(f.den, lev, f.dom, dom)) + + G = (dmp_convert(g.num, lev, g.dom, dom), + dmp_convert(g.den, lev, g.dom, dom)) + + def per(num, den, cancel=True, kill=False, lev=lev): + if kill: + if not lev: + return num/den + else: + lev = lev - 1 + + if cancel: + num, den = dmp_cancel(num, den, lev, dom) + + return f.__class__.new((num, den), dom, lev) + + return lev, dom, per, F, G + + def per(f, num, den, cancel=True, kill=False): + """Create a DMF out of the given representation. """ + lev, dom = f.lev, f.dom + + if kill: + if not lev: + return num/den + else: + lev -= 1 + + if cancel: + num, den = dmp_cancel(num, den, lev, dom) + + return f.__class__.new((num, den), dom, lev) + + def half_per(f, rep, kill=False): + """Create a DMP out of the given representation. """ + lev = f.lev + + if kill: + if not lev: + return rep + else: + lev -= 1 + + return DMP(rep, f.dom, lev) + + @classmethod + def zero(cls, lev, dom): + return cls.new(0, dom, lev) + + @classmethod + def one(cls, lev, dom): + return cls.new(1, dom, lev) + + def numer(f): + """Returns the numerator of ``f``. """ + return f.half_per(f.num) + + def denom(f): + """Returns the denominator of ``f``. """ + return f.half_per(f.den) + + def cancel(f): + """Remove common factors from ``f.num`` and ``f.den``. """ + return f.per(f.num, f.den) + + def neg(f): + """Negate all coefficients in ``f``. """ + return f.per(dmp_neg(f.num, f.lev, f.dom), f.den, cancel=False) + + def add_ground(f, c): + """Add an element of the ground domain to ``f``. """ + return f + f.ground_new(c) + + def add(f, g): + """Add two multivariate fractions ``f`` and ``g``. """ + if isinstance(g, DMP): + lev, dom, per, (F_num, F_den), G = f.poly_unify(g) + num, den = dmp_add_mul(F_num, F_den, G, lev, dom), F_den + else: + lev, dom, per, F, G = f.frac_unify(g) + (F_num, F_den), (G_num, G_den) = F, G + + num = dmp_add(dmp_mul(F_num, G_den, lev, dom), + dmp_mul(F_den, G_num, lev, dom), lev, dom) + den = dmp_mul(F_den, G_den, lev, dom) + + return per(num, den) + + def sub(f, g): + """Subtract two multivariate fractions ``f`` and ``g``. """ + if isinstance(g, DMP): + lev, dom, per, (F_num, F_den), G = f.poly_unify(g) + num, den = dmp_sub_mul(F_num, F_den, G, lev, dom), F_den + else: + lev, dom, per, F, G = f.frac_unify(g) + (F_num, F_den), (G_num, G_den) = F, G + + num = dmp_sub(dmp_mul(F_num, G_den, lev, dom), + dmp_mul(F_den, G_num, lev, dom), lev, dom) + den = dmp_mul(F_den, G_den, lev, dom) + + return per(num, den) + + def mul(f, g): + """Multiply two multivariate fractions ``f`` and ``g``. """ + if isinstance(g, DMP): + lev, dom, per, (F_num, F_den), G = f.poly_unify(g) + num, den = dmp_mul(F_num, G, lev, dom), F_den + else: + lev, dom, per, F, G = f.frac_unify(g) + (F_num, F_den), (G_num, G_den) = F, G + + num = dmp_mul(F_num, G_num, lev, dom) + den = dmp_mul(F_den, G_den, lev, dom) + + return per(num, den) + + def pow(f, n): + """Raise ``f`` to a non-negative power ``n``. """ + if isinstance(n, int): + num, den = f.num, f.den + if n < 0: + num, den, n = den, num, -n + return f.per(dmp_pow(num, n, f.lev, f.dom), + dmp_pow(den, n, f.lev, f.dom), cancel=False) + else: + raise TypeError("``int`` expected, got %s" % type(n)) + + def quo(f, g): + """Computes quotient of fractions ``f`` and ``g``. """ + if isinstance(g, DMP): + lev, dom, per, (F_num, F_den), G = f.poly_unify(g) + num, den = F_num, dmp_mul(F_den, G, lev, dom) + else: + lev, dom, per, F, G = f.frac_unify(g) + (F_num, F_den), (G_num, G_den) = F, G + + num = dmp_mul(F_num, G_den, lev, dom) + den = dmp_mul(F_den, G_num, lev, dom) + + return per(num, den) + + exquo = quo + + def invert(f, check=True): + """Computes inverse of a fraction ``f``. """ + return f.per(f.den, f.num, cancel=False) + + @property + def is_zero(f): + """Returns ``True`` if ``f`` is a zero fraction. """ + return dmp_zero_p(f.num, f.lev) + + @property + def is_one(f): + """Returns ``True`` if ``f`` is a unit fraction. """ + return dmp_one_p(f.num, f.lev, f.dom) and \ + dmp_one_p(f.den, f.lev, f.dom) + + def __neg__(f): + return f.neg() + + def __add__(f, g): + if isinstance(g, (DMP, DMF)): + return f.add(g) + elif g in f.dom: + return f.add_ground(f.dom.convert(g)) + + try: + return f.add(f.half_per(g)) + except (TypeError, CoercionFailed, NotImplementedError): + return NotImplemented + + def __radd__(f, g): + return f.__add__(g) + + def __sub__(f, g): + if isinstance(g, (DMP, DMF)): + return f.sub(g) + + try: + return f.sub(f.half_per(g)) + except (TypeError, CoercionFailed, NotImplementedError): + return NotImplemented + + def __rsub__(f, g): + return (-f).__add__(g) + + def __mul__(f, g): + if isinstance(g, (DMP, DMF)): + return f.mul(g) + + try: + return f.mul(f.half_per(g)) + except (TypeError, CoercionFailed, NotImplementedError): + return NotImplemented + + def __rmul__(f, g): + return f.__mul__(g) + + def __pow__(f, n): + return f.pow(n) + + def __truediv__(f, g): + if isinstance(g, (DMP, DMF)): + return f.quo(g) + + try: + return f.quo(f.half_per(g)) + except (TypeError, CoercionFailed, NotImplementedError): + return NotImplemented + + def __rtruediv__(self, g): + return self.invert(check=False)*g + + def __eq__(f, g): + try: + if isinstance(g, DMP): + _, _, _, (F_num, F_den), G = f.poly_unify(g) + + if f.lev == g.lev: + return dmp_one_p(F_den, f.lev, f.dom) and F_num == G + else: + _, _, _, F, G = f.frac_unify(g) + + if f.lev == g.lev: + return F == G + except UnificationFailed: + pass + + return False + + def __ne__(f, g): + try: + if isinstance(g, DMP): + _, _, _, (F_num, F_den), G = f.poly_unify(g) + + if f.lev == g.lev: + return not (dmp_one_p(F_den, f.lev, f.dom) and F_num == G) + else: + _, _, _, F, G = f.frac_unify(g) + + if f.lev == g.lev: + return F != G + except UnificationFailed: + pass + + return True + + def __lt__(f, g): + _, _, _, F, G = f.frac_unify(g) + return F < G + + def __le__(f, g): + _, _, _, F, G = f.frac_unify(g) + return F <= G + + def __gt__(f, g): + _, _, _, F, G = f.frac_unify(g) + return F > G + + def __ge__(f, g): + _, _, _, F, G = f.frac_unify(g) + return F >= G + + def __bool__(f): + return not dmp_zero_p(f.num, f.lev) + + +def init_normal_ANP(rep, mod, dom): + return ANP(dup_normal(rep, dom), + dup_normal(mod, dom), dom) + + +class ANP(CantSympify): + """Dense Algebraic Number Polynomials over a field. """ + + __slots__ = ('_rep', '_mod', 'dom') + + def __new__(cls, rep, mod, dom): + if isinstance(rep, DMP): + pass + elif type(rep) is dict: # don't use isinstance + rep = DMP(dup_from_dict(rep, dom), dom, 0) + else: + if isinstance(rep, list): + rep = [dom.convert(a) for a in rep] + else: + rep = [dom.convert(rep)] + rep = DMP(dup_strip(rep), dom, 0) + + if isinstance(mod, DMP): + pass + elif isinstance(mod, dict): + mod = DMP(dup_from_dict(mod, dom), dom, 0) + else: + mod = DMP(dup_strip(mod), dom, 0) + + return cls.new(rep, mod, dom) + + @classmethod + def new(cls, rep, mod, dom): + if not (rep.dom == mod.dom == dom): + raise RuntimeError("Inconsistent domain") + obj = super().__new__(cls) + obj._rep = rep + obj._mod = mod + obj.dom = dom + return obj + + # XXX: It should be possible to use __getnewargs__ rather than __reduce__ + # but it doesn't work for some reason. Probably this would be easier if + # python-flint supported pickling for polynomial types. + def __reduce__(self): + return ANP, (self.rep, self.mod, self.dom) + + @property + def rep(self): + return self._rep.to_list() + + @property + def mod(self): + return self.mod_to_list() + + def to_DMP(self): + return self._rep + + def mod_to_DMP(self): + return self._mod + + def per(f, rep): + return f.new(rep, f._mod, f.dom) + + def __repr__(f): + return "%s(%s, %s, %s)" % (f.__class__.__name__, f._rep.to_list(), f._mod.to_list(), f.dom) + + def __hash__(f): + return hash((f.__class__.__name__, f.to_tuple(), f._mod.to_tuple(), f.dom)) + + def convert(f, dom): + """Convert ``f`` to a ``ANP`` over a new domain. """ + if f.dom == dom: + return f + else: + return f.new(f._rep.convert(dom), f._mod.convert(dom), dom) + + def unify(f, g): + """Unify representations of two algebraic numbers. """ + + # XXX: This unify method is not used any more because unify_ANP is used + # instead. + + if not isinstance(g, ANP) or f.mod != g.mod: + raise UnificationFailed("Cannot unify %s with %s" % (f, g)) + + if f.dom == g.dom: + return f.dom, f.per, f.rep, g.rep, f.mod + else: + dom = f.dom.unify(g.dom) + + F = dup_convert(f.rep, f.dom, dom) + G = dup_convert(g.rep, g.dom, dom) + + if dom != f.dom and dom != g.dom: + mod = dup_convert(f.mod, f.dom, dom) + else: + if dom == f.dom: + mod = f.mod + else: + mod = g.mod + + per = lambda rep: ANP(rep, mod, dom) + + return dom, per, F, G, mod + + def unify_ANP(f, g): + """Unify and return ``DMP`` instances of ``f`` and ``g``. """ + if not isinstance(g, ANP) or f._mod != g._mod: + raise UnificationFailed("Cannot unify %s with %s" % (f, g)) + + # The domain is almost always QQ but there are some tests involving ZZ + if f.dom != g.dom: + dom = f.dom.unify(g.dom) + f = f.convert(dom) + g = g.convert(dom) + + return f._rep, g._rep, f._mod, f.dom + + @classmethod + def zero(cls, mod, dom): + return ANP(0, mod, dom) + + @classmethod + def one(cls, mod, dom): + return ANP(1, mod, dom) + + def to_dict(f): + """Convert ``f`` to a dict representation with native coefficients. """ + return f._rep.to_dict() + + def to_sympy_dict(f): + """Convert ``f`` to a dict representation with SymPy coefficients. """ + rep = dmp_to_dict(f.rep, 0, f.dom) + + for k, v in rep.items(): + rep[k] = f.dom.to_sympy(v) + + return rep + + def to_list(f): + """Convert ``f`` to a list representation with native coefficients. """ + return f._rep.to_list() + + def mod_to_list(f): + """Return ``f.mod`` as a list with native coefficients. """ + return f._mod.to_list() + + def to_sympy_list(f): + """Convert ``f`` to a list representation with SymPy coefficients. """ + return [ f.dom.to_sympy(c) for c in f.to_list() ] + + def to_tuple(f): + """ + Convert ``f`` to a tuple representation with native coefficients. + + This is needed for hashing. + """ + return f._rep.to_tuple() + + @classmethod + def from_list(cls, rep, mod, dom): + return ANP(dup_strip(list(map(dom.convert, rep))), mod, dom) + + def add_ground(f, c): + """Add an element of the ground domain to ``f``. """ + return f.per(f._rep.add_ground(c)) + + def sub_ground(f, c): + """Subtract an element of the ground domain from ``f``. """ + return f.per(f._rep.sub_ground(c)) + + def mul_ground(f, c): + """Multiply ``f`` by an element of the ground domain. """ + return f.per(f._rep.mul_ground(c)) + + def quo_ground(f, c): + """Quotient of ``f`` by an element of the ground domain. """ + return f.per(f._rep.quo_ground(c)) + + def neg(f): + return f.per(f._rep.neg()) + + def add(f, g): + F, G, mod, dom = f.unify_ANP(g) + return f.new(F.add(G), mod, dom) + + def sub(f, g): + F, G, mod, dom = f.unify_ANP(g) + return f.new(F.sub(G), mod, dom) + + def mul(f, g): + F, G, mod, dom = f.unify_ANP(g) + return f.new(F.mul(G).rem(mod), mod, dom) + + def pow(f, n): + """Raise ``f`` to a non-negative power ``n``. """ + if not isinstance(n, int): + raise TypeError("``int`` expected, got %s" % type(n)) + + mod = f._mod + F = f._rep + + if n < 0: + F, n = F.invert(mod), -n + + # XXX: Need a pow_mod method for DMP + return f.new(F.pow(n).rem(f._mod), mod, f.dom) + + def exquo(f, g): + F, G, mod, dom = f.unify_ANP(g) + return f.new(F.mul(G.invert(mod)).rem(mod), mod, dom) + + def div(f, g): + return f.exquo(g), f.zero(f._mod, f.dom) + + def quo(f, g): + return f.exquo(g) + + def rem(f, g): + F, G, mod, dom = f.unify_ANP(g) + s, h = F.half_gcdex(G) + + if h.is_one: + return f.zero(mod, dom) + else: + raise NotInvertible("zero divisor") + + def LC(f): + """Returns the leading coefficient of ``f``. """ + return f._rep.LC() + + def TC(f): + """Returns the trailing coefficient of ``f``. """ + return f._rep.TC() + + @property + def is_zero(f): + """Returns ``True`` if ``f`` is a zero algebraic number. """ + return f._rep.is_zero + + @property + def is_one(f): + """Returns ``True`` if ``f`` is a unit algebraic number. """ + return f._rep.is_one + + @property + def is_ground(f): + """Returns ``True`` if ``f`` is an element of the ground domain. """ + return f._rep.is_ground + + def __pos__(f): + return f + + def __neg__(f): + return f.neg() + + def __add__(f, g): + if isinstance(g, ANP): + return f.add(g) + try: + g = f.dom.convert(g) + except CoercionFailed: + return NotImplemented + else: + return f.add_ground(g) + + def __radd__(f, g): + return f.__add__(g) + + def __sub__(f, g): + if isinstance(g, ANP): + return f.sub(g) + try: + g = f.dom.convert(g) + except CoercionFailed: + return NotImplemented + else: + return f.sub_ground(g) + + def __rsub__(f, g): + return (-f).__add__(g) + + def __mul__(f, g): + if isinstance(g, ANP): + return f.mul(g) + try: + g = f.dom.convert(g) + except CoercionFailed: + return NotImplemented + else: + return f.mul_ground(g) + + def __rmul__(f, g): + return f.__mul__(g) + + def __pow__(f, n): + return f.pow(n) + + def __divmod__(f, g): + return f.div(g) + + def __mod__(f, g): + return f.rem(g) + + def __truediv__(f, g): + if isinstance(g, ANP): + return f.quo(g) + try: + g = f.dom.convert(g) + except CoercionFailed: + return NotImplemented + else: + return f.quo_ground(g) + + def __eq__(f, g): + try: + F, G, _, _ = f.unify_ANP(g) + except UnificationFailed: + return NotImplemented + return F == G + + def __ne__(f, g): + try: + F, G, _, _ = f.unify_ANP(g) + except UnificationFailed: + return NotImplemented + return F != G + + def __lt__(f, g): + F, G, _, _ = f.unify_ANP(g) + return F < G + + def __le__(f, g): + F, G, _, _ = f.unify_ANP(g) + return F <= G + + def __gt__(f, g): + F, G, _, _ = f.unify_ANP(g) + return F > G + + def __ge__(f, g): + F, G, _, _ = f.unify_ANP(g) + return F >= G + + def __bool__(f): + return bool(f._rep) diff --git a/phivenv/Lib/site-packages/sympy/polys/polyconfig.py b/phivenv/Lib/site-packages/sympy/polys/polyconfig.py new file mode 100644 index 0000000000000000000000000000000000000000..75731f7ac4e4f8784ff8f999cc3537bfa3c6659a --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/polyconfig.py @@ -0,0 +1,67 @@ +"""Configuration utilities for polynomial manipulation algorithms. """ + + +from contextlib import contextmanager + +_default_config = { + 'USE_COLLINS_RESULTANT': False, + 'USE_SIMPLIFY_GCD': True, + 'USE_HEU_GCD': True, + + 'USE_IRREDUCIBLE_IN_FACTOR': False, + 'USE_CYCLOTOMIC_FACTOR': True, + + 'EEZ_RESTART_IF_NEEDED': True, + 'EEZ_NUMBER_OF_CONFIGS': 3, + 'EEZ_NUMBER_OF_TRIES': 5, + 'EEZ_MODULUS_STEP': 2, + + 'GF_IRRED_METHOD': 'rabin', + 'GF_FACTOR_METHOD': 'zassenhaus', + + 'GROEBNER': 'buchberger', +} + +_current_config = {} + +@contextmanager +def using(**kwargs): + for k, v in kwargs.items(): + setup(k, v) + + yield + + for k in kwargs.keys(): + setup(k) + +def setup(key, value=None): + """Assign a value to (or reset) a configuration item. """ + key = key.upper() + + if value is not None: + _current_config[key] = value + else: + _current_config[key] = _default_config[key] + + +def query(key): + """Ask for a value of the given configuration item. """ + return _current_config.get(key.upper(), None) + + +def configure(): + """Initialized configuration of polys module. """ + from os import getenv + + for key, default in _default_config.items(): + value = getenv('SYMPY_' + key) + + if value is not None: + try: + _current_config[key] = eval(value) + except NameError: + _current_config[key] = value + else: + _current_config[key] = default + +configure() diff --git a/phivenv/Lib/site-packages/sympy/polys/polyerrors.py b/phivenv/Lib/site-packages/sympy/polys/polyerrors.py new file mode 100644 index 0000000000000000000000000000000000000000..79385ffaf6746386f8f108c3e02992dcaf4a4f55 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/polyerrors.py @@ -0,0 +1,183 @@ +"""Definitions of common exceptions for `polys` module. """ + + +from sympy.utilities import public + +@public +class BasePolynomialError(Exception): + """Base class for polynomial related exceptions. """ + + def new(self, *args): + raise NotImplementedError("abstract base class") + +@public +class ExactQuotientFailed(BasePolynomialError): + + def __init__(self, f, g, dom=None): + self.f, self.g, self.dom = f, g, dom + + def __str__(self): # pragma: no cover + from sympy.printing.str import sstr + + if self.dom is None: + return "%s does not divide %s" % (sstr(self.g), sstr(self.f)) + else: + return "%s does not divide %s in %s" % (sstr(self.g), sstr(self.f), sstr(self.dom)) + + def new(self, f, g): + return self.__class__(f, g, self.dom) + +@public +class PolynomialDivisionFailed(BasePolynomialError): + + def __init__(self, f, g, domain): + self.f = f + self.g = g + self.domain = domain + + def __str__(self): + if self.domain.is_EX: + msg = "You may want to use a different simplification algorithm. Note " \ + "that in general it's not possible to guarantee to detect zero " \ + "in this domain." + elif not self.domain.is_Exact: + msg = "Your working precision or tolerance of computations may be set " \ + "improperly. Adjust those parameters of the coefficient domain " \ + "and try again." + else: + msg = "Zero detection is guaranteed in this coefficient domain. This " \ + "may indicate a bug in SymPy or the domain is user defined and " \ + "doesn't implement zero detection properly." + + return "couldn't reduce degree in a polynomial division algorithm when " \ + "dividing %s by %s. This can happen when it's not possible to " \ + "detect zero in the coefficient domain. The domain of computation " \ + "is %s. %s" % (self.f, self.g, self.domain, msg) + +@public +class OperationNotSupported(BasePolynomialError): + + def __init__(self, poly, func): + self.poly = poly + self.func = func + + def __str__(self): # pragma: no cover + return "`%s` operation not supported by %s representation" % (self.func, self.poly.rep.__class__.__name__) + +@public +class HeuristicGCDFailed(BasePolynomialError): + pass + +class ModularGCDFailed(BasePolynomialError): + pass + +@public +class HomomorphismFailed(BasePolynomialError): + pass + +@public +class IsomorphismFailed(BasePolynomialError): + pass + +@public +class ExtraneousFactors(BasePolynomialError): + pass + +@public +class EvaluationFailed(BasePolynomialError): + pass + +@public +class RefinementFailed(BasePolynomialError): + pass + +@public +class CoercionFailed(BasePolynomialError): + pass + +@public +class NotInvertible(BasePolynomialError): + pass + +@public +class NotReversible(BasePolynomialError): + pass + +@public +class NotAlgebraic(BasePolynomialError): + pass + +@public +class DomainError(BasePolynomialError): + pass + +@public +class PolynomialError(BasePolynomialError): + pass + +@public +class UnificationFailed(BasePolynomialError): + pass + +@public +class UnsolvableFactorError(BasePolynomialError): + """Raised if ``roots`` is called with strict=True and a polynomial + having a factor whose solutions are not expressible in radicals + is encountered.""" + +@public +class GeneratorsError(BasePolynomialError): + pass + +@public +class GeneratorsNeeded(GeneratorsError): + pass + +@public +class ComputationFailed(BasePolynomialError): + + def __init__(self, func, nargs, exc): + self.func = func + self.nargs = nargs + self.exc = exc + + def __str__(self): + return "%s(%s) failed without generators" % (self.func, ', '.join(map(str, self.exc.exprs[:self.nargs]))) + +@public +class UnivariatePolynomialError(PolynomialError): + pass + +@public +class MultivariatePolynomialError(PolynomialError): + pass + +@public +class PolificationFailed(PolynomialError): + + def __init__(self, opt, origs, exprs, seq=False): + if not seq: + self.orig = origs + self.expr = exprs + self.origs = [origs] + self.exprs = [exprs] + else: + self.origs = origs + self.exprs = exprs + + self.opt = opt + self.seq = seq + + def __str__(self): # pragma: no cover + if not self.seq: + return "Cannot construct a polynomial from %s" % str(self.orig) + else: + return "Cannot construct polynomials from %s" % ', '.join(map(str, self.origs)) + +@public +class OptionError(BasePolynomialError): + pass + +@public +class FlagError(OptionError): + pass diff --git a/phivenv/Lib/site-packages/sympy/polys/polyfuncs.py b/phivenv/Lib/site-packages/sympy/polys/polyfuncs.py new file mode 100644 index 0000000000000000000000000000000000000000..b412123f7383c68177a88df8817e921d96f6d5af --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/polyfuncs.py @@ -0,0 +1,321 @@ +"""High-level polynomials manipulation functions. """ + + +from sympy.core import S, Basic, symbols, Dummy +from sympy.polys.polyerrors import ( + PolificationFailed, ComputationFailed, + MultivariatePolynomialError, OptionError) +from sympy.polys.polyoptions import allowed_flags, build_options +from sympy.polys.polytools import poly_from_expr, Poly +from sympy.polys.specialpolys import ( + symmetric_poly, interpolating_poly) +from sympy.polys.rings import sring +from sympy.utilities import numbered_symbols, take, public + +@public +def symmetrize(F, *gens, **args): + r""" + Rewrite a polynomial in terms of elementary symmetric polynomials. + + A symmetric polynomial is a multivariate polynomial that remains invariant + under any variable permutation, i.e., if `f = f(x_1, x_2, \dots, x_n)`, + then `f = f(x_{i_1}, x_{i_2}, \dots, x_{i_n})`, where + `(i_1, i_2, \dots, i_n)` is a permutation of `(1, 2, \dots, n)` (an + element of the group `S_n`). + + Returns a tuple of symmetric polynomials ``(f1, f2, ..., fn)`` such that + ``f = f1 + f2 + ... + fn``. + + Examples + ======== + + >>> from sympy.polys.polyfuncs import symmetrize + >>> from sympy.abc import x, y + + >>> symmetrize(x**2 + y**2) + (-2*x*y + (x + y)**2, 0) + + >>> symmetrize(x**2 + y**2, formal=True) + (s1**2 - 2*s2, 0, [(s1, x + y), (s2, x*y)]) + + >>> symmetrize(x**2 - y**2) + (-2*x*y + (x + y)**2, -2*y**2) + + >>> symmetrize(x**2 - y**2, formal=True) + (s1**2 - 2*s2, -2*y**2, [(s1, x + y), (s2, x*y)]) + + """ + allowed_flags(args, ['formal', 'symbols']) + + iterable = True + + if not hasattr(F, '__iter__'): + iterable = False + F = [F] + + R, F = sring(F, *gens, **args) + gens = R.symbols + + opt = build_options(gens, args) + symbols = opt.symbols + symbols = [next(symbols) for i in range(len(gens))] + + result = [] + + for f in F: + p, r, m = f.symmetrize() + result.append((p.as_expr(*symbols), r.as_expr(*gens))) + + polys = [(s, g.as_expr()) for s, (_, g) in zip(symbols, m)] + + if not opt.formal: + for i, (sym, non_sym) in enumerate(result): + result[i] = (sym.subs(polys), non_sym) + + if not iterable: + result, = result + + if not opt.formal: + return result + else: + if iterable: + return result, polys + else: + return result + (polys,) + + +@public +def horner(f, *gens, **args): + """ + Rewrite a polynomial in Horner form. + + Among other applications, evaluation of a polynomial at a point is optimal + when it is applied using the Horner scheme ([1]). + + Examples + ======== + + >>> from sympy.polys.polyfuncs import horner + >>> from sympy.abc import x, y, a, b, c, d, e + + >>> horner(9*x**4 + 8*x**3 + 7*x**2 + 6*x + 5) + x*(x*(x*(9*x + 8) + 7) + 6) + 5 + + >>> horner(a*x**4 + b*x**3 + c*x**2 + d*x + e) + e + x*(d + x*(c + x*(a*x + b))) + + >>> f = 4*x**2*y**2 + 2*x**2*y + 2*x*y**2 + x*y + + >>> horner(f, wrt=x) + x*(x*y*(4*y + 2) + y*(2*y + 1)) + + >>> horner(f, wrt=y) + y*(x*y*(4*x + 2) + x*(2*x + 1)) + + References + ========== + [1] - https://en.wikipedia.org/wiki/Horner_scheme + + """ + allowed_flags(args, []) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + return exc.expr + + form, gen = S.Zero, F.gen + + if F.is_univariate: + for coeff in F.all_coeffs(): + form = form*gen + coeff + else: + F, gens = Poly(F, gen), gens[1:] + + for coeff in F.all_coeffs(): + form = form*gen + horner(coeff, *gens, **args) + + return form + + +@public +def interpolate(data, x): + """ + Construct an interpolating polynomial for the data points + evaluated at point x (which can be symbolic or numeric). + + Examples + ======== + + >>> from sympy.polys.polyfuncs import interpolate + >>> from sympy.abc import a, b, x + + A list is interpreted as though it were paired with a range starting + from 1: + + >>> interpolate([1, 4, 9, 16], x) + x**2 + + This can be made explicit by giving a list of coordinates: + + >>> interpolate([(1, 1), (2, 4), (3, 9)], x) + x**2 + + The (x, y) coordinates can also be given as keys and values of a + dictionary (and the points need not be equispaced): + + >>> interpolate([(-1, 2), (1, 2), (2, 5)], x) + x**2 + 1 + >>> interpolate({-1: 2, 1: 2, 2: 5}, x) + x**2 + 1 + + If the interpolation is going to be used only once then the + value of interest can be passed instead of passing a symbol: + + >>> interpolate([1, 4, 9], 5) + 25 + + Symbolic coordinates are also supported: + + >>> [(i,interpolate((a, b), i)) for i in range(1, 4)] + [(1, a), (2, b), (3, -a + 2*b)] + """ + n = len(data) + + if isinstance(data, dict): + if x in data: + return S(data[x]) + X, Y = list(zip(*data.items())) + else: + if isinstance(data[0], tuple): + X, Y = list(zip(*data)) + if x in X: + return S(Y[X.index(x)]) + else: + if x in range(1, n + 1): + return S(data[x - 1]) + Y = list(data) + X = list(range(1, n + 1)) + + try: + return interpolating_poly(n, x, X, Y).expand() + except ValueError: + d = Dummy() + return interpolating_poly(n, d, X, Y).expand().subs(d, x) + + +@public +def rational_interpolate(data, degnum, X=symbols('x')): + """ + Returns a rational interpolation, where the data points are element of + any integral domain. + + The first argument contains the data (as a list of coordinates). The + ``degnum`` argument is the degree in the numerator of the rational + function. Setting it too high will decrease the maximal degree in the + denominator for the same amount of data. + + Examples + ======== + + >>> from sympy.polys.polyfuncs import rational_interpolate + + >>> data = [(1, -210), (2, -35), (3, 105), (4, 231), (5, 350), (6, 465)] + >>> rational_interpolate(data, 2) + (105*x**2 - 525)/(x + 1) + + Values do not need to be integers: + + >>> from sympy import sympify + >>> x = [1, 2, 3, 4, 5, 6] + >>> y = sympify("[-1, 0, 2, 22/5, 7, 68/7]") + >>> rational_interpolate(zip(x, y), 2) + (3*x**2 - 7*x + 2)/(x + 1) + + The symbol for the variable can be changed if needed: + >>> from sympy import symbols + >>> z = symbols('z') + >>> rational_interpolate(data, 2, X=z) + (105*z**2 - 525)/(z + 1) + + References + ========== + + .. [1] Algorithm is adapted from: + http://axiom-wiki.newsynthesis.org/RationalInterpolation + + """ + from sympy.matrices.dense import ones + + xdata, ydata = list(zip(*data)) + + k = len(xdata) - degnum - 1 + if k < 0: + raise OptionError("Too few values for the required degree.") + c = ones(degnum + k + 1, degnum + k + 2) + for j in range(max(degnum, k)): + for i in range(degnum + k + 1): + c[i, j + 1] = c[i, j]*xdata[i] + for j in range(k + 1): + for i in range(degnum + k + 1): + c[i, degnum + k + 1 - j] = -c[i, k - j]*ydata[i] + r = c.nullspace()[0] + return (sum(r[i] * X**i for i in range(degnum + 1)) + / sum(r[i + degnum + 1] * X**i for i in range(k + 1))) + + +@public +def viete(f, roots=None, *gens, **args): + """ + Generate Viete's formulas for ``f``. + + Examples + ======== + + >>> from sympy.polys.polyfuncs import viete + >>> from sympy import symbols + + >>> x, a, b, c, r1, r2 = symbols('x,a:c,r1:3') + + >>> viete(a*x**2 + b*x + c, [r1, r2], x) + [(r1 + r2, -b/a), (r1*r2, c/a)] + + """ + allowed_flags(args, []) + + if isinstance(roots, Basic): + gens, roots = (roots,) + gens, None + + try: + f, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('viete', 1, exc) + + if f.is_multivariate: + raise MultivariatePolynomialError( + "multivariate polynomials are not allowed") + + n = f.degree() + + if n < 1: + raise ValueError( + "Cannot derive Viete's formulas for a constant polynomial") + + if roots is None: + roots = numbered_symbols('r', start=1) + + roots = take(roots, n) + + if n != len(roots): + raise ValueError("required %s roots, got %s" % (n, len(roots))) + + lc, coeffs = f.LC(), f.all_coeffs() + result, sign = [], -1 + + for i, coeff in enumerate(coeffs[1:]): + poly = symmetric_poly(i + 1, roots) + coeff = sign*(coeff/lc) + result.append((poly, coeff)) + sign = -sign + + return result diff --git a/phivenv/Lib/site-packages/sympy/polys/polymatrix.py b/phivenv/Lib/site-packages/sympy/polys/polymatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..fb2a58efc3ebfd85507ac2b0cfd31230e55ded66 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/polymatrix.py @@ -0,0 +1,292 @@ +from sympy.core.expr import Expr +from sympy.core.symbol import Dummy +from sympy.core.sympify import _sympify + +from sympy.polys.polyerrors import CoercionFailed +from sympy.polys.polytools import Poly, parallel_poly_from_expr +from sympy.polys.domains import QQ + +from sympy.polys.matrices import DomainMatrix +from sympy.polys.matrices.domainscalar import DomainScalar + + +class MutablePolyDenseMatrix: + """ + A mutable matrix of objects from poly module or to operate with them. + + Examples + ======== + + >>> from sympy.polys.polymatrix import PolyMatrix + >>> from sympy import Symbol, Poly + >>> x = Symbol('x') + >>> pm1 = PolyMatrix([[Poly(x**2, x), Poly(-x, x)], [Poly(x**3, x), Poly(-1 + x, x)]]) + >>> v1 = PolyMatrix([[1, 0], [-1, 0]], x) + >>> pm1*v1 + PolyMatrix([ + [ x**2 + x, 0], + [x**3 - x + 1, 0]], ring=QQ[x]) + + >>> pm1.ring + ZZ[x] + + >>> v1*pm1 + PolyMatrix([ + [ x**2, -x], + [-x**2, x]], ring=QQ[x]) + + >>> pm2 = PolyMatrix([[Poly(x**2, x, domain='QQ'), Poly(0, x, domain='QQ'), Poly(1, x, domain='QQ'), \ + Poly(x**3, x, domain='QQ'), Poly(0, x, domain='QQ'), Poly(-x**3, x, domain='QQ')]]) + >>> v2 = PolyMatrix([1, 0, 0, 0, 0, 0], x) + >>> v2.ring + QQ[x] + >>> pm2*v2 + PolyMatrix([[x**2]], ring=QQ[x]) + + """ + + def __new__(cls, *args, ring=None): + + if not args: + # PolyMatrix(ring=QQ[x]) + if ring is None: + raise TypeError("The ring needs to be specified for an empty PolyMatrix") + rows, cols, items, gens = 0, 0, [], () + elif isinstance(args[0], list): + elements, gens = args[0], args[1:] + if not elements: + # PolyMatrix([]) + rows, cols, items = 0, 0, [] + elif isinstance(elements[0], (list, tuple)): + # PolyMatrix([[1, 2]], x) + rows, cols = len(elements), len(elements[0]) + items = [e for row in elements for e in row] + else: + # PolyMatrix([1, 2], x) + rows, cols = len(elements), 1 + items = elements + elif [type(a) for a in args[:3]] == [int, int, list]: + # PolyMatrix(2, 2, [1, 2, 3, 4], x) + rows, cols, items, gens = args[0], args[1], args[2], args[3:] + elif [type(a) for a in args[:3]] == [int, int, type(lambda: 0)]: + # PolyMatrix(2, 2, lambda i, j: i+j, x) + rows, cols, func, gens = args[0], args[1], args[2], args[3:] + items = [func(i, j) for i in range(rows) for j in range(cols)] + else: + raise TypeError("Invalid arguments") + + # PolyMatrix([[1]], x, y) vs PolyMatrix([[1]], (x, y)) + if len(gens) == 1 and isinstance(gens[0], tuple): + gens = gens[0] + # gens is now a tuple (x, y) + + return cls.from_list(rows, cols, items, gens, ring) + + @classmethod + def from_list(cls, rows, cols, items, gens, ring): + + # items can be Expr, Poly, or a mix of Expr and Poly + items = [_sympify(item) for item in items] + if items and all(isinstance(item, Poly) for item in items): + polys = True + else: + polys = False + + # Identify the ring for the polys + if ring is not None: + # Parse a domain string like 'QQ[x]' + if isinstance(ring, str): + ring = Poly(0, Dummy(), domain=ring).domain + elif polys: + p = items[0] + for p2 in items[1:]: + p, _ = p.unify(p2) + ring = p.domain[p.gens] + else: + items, info = parallel_poly_from_expr(items, gens, field=True) + ring = info['domain'][info['gens']] + polys = True + + # Efficiently convert when all elements are Poly + if polys: + p_ring = Poly(0, ring.symbols, domain=ring.domain) + to_ring = ring.ring.from_list + convert_poly = lambda p: to_ring(p.unify(p_ring)[0].rep.to_list()) + elements = [convert_poly(p) for p in items] + else: + convert_expr = ring.from_sympy + elements = [convert_expr(e.as_expr()) for e in items] + + # Convert to domain elements and construct DomainMatrix + elements_lol = [[elements[i*cols + j] for j in range(cols)] for i in range(rows)] + dm = DomainMatrix(elements_lol, (rows, cols), ring) + return cls.from_dm(dm) + + @classmethod + def from_dm(cls, dm): + obj = super().__new__(cls) + dm = dm.to_sparse() + R = dm.domain + obj._dm = dm + obj.ring = R + obj.domain = R.domain + obj.gens = R.symbols + return obj + + def to_Matrix(self): + return self._dm.to_Matrix() + + @classmethod + def from_Matrix(cls, other, *gens, ring=None): + return cls(*other.shape, other.flat(), *gens, ring=ring) + + def set_gens(self, gens): + return self.from_Matrix(self.to_Matrix(), gens) + + def __repr__(self): + if self.rows * self.cols: + return 'Poly' + repr(self.to_Matrix())[:-1] + f', ring={self.ring})' + else: + return f'PolyMatrix({self.rows}, {self.cols}, [], ring={self.ring})' + + @property + def shape(self): + return self._dm.shape + + @property + def rows(self): + return self.shape[0] + + @property + def cols(self): + return self.shape[1] + + def __len__(self): + return self.rows * self.cols + + def __getitem__(self, key): + + def to_poly(v): + ground = self._dm.domain.domain + gens = self._dm.domain.symbols + return Poly(v.to_dict(), gens, domain=ground) + + dm = self._dm + + if isinstance(key, slice): + items = dm.flat()[key] + return [to_poly(item) for item in items] + elif isinstance(key, int): + i, j = divmod(key, self.cols) + e = dm[i,j] + return to_poly(e.element) + + i, j = key + if isinstance(i, int) and isinstance(j, int): + return to_poly(dm[i, j].element) + else: + return self.from_dm(dm[i, j]) + + def __eq__(self, other): + if not isinstance(self, type(other)): + return NotImplemented + return self._dm == other._dm + + def __add__(self, other): + if isinstance(other, type(self)): + return self.from_dm(self._dm + other._dm) + return NotImplemented + + def __sub__(self, other): + if isinstance(other, type(self)): + return self.from_dm(self._dm - other._dm) + return NotImplemented + + def __mul__(self, other): + if isinstance(other, type(self)): + return self.from_dm(self._dm * other._dm) + elif isinstance(other, int): + other = _sympify(other) + if isinstance(other, Expr): + Kx = self.ring + try: + other_ds = DomainScalar(Kx.from_sympy(other), Kx) + except (CoercionFailed, ValueError): + other_ds = DomainScalar.from_sympy(other) + return self.from_dm(self._dm * other_ds) + return NotImplemented + + def __rmul__(self, other): + if isinstance(other, int): + other = _sympify(other) + if isinstance(other, Expr): + other_ds = DomainScalar.from_sympy(other) + return self.from_dm(other_ds * self._dm) + return NotImplemented + + def __truediv__(self, other): + + if isinstance(other, Poly): + other = other.as_expr() + elif isinstance(other, int): + other = _sympify(other) + if not isinstance(other, Expr): + return NotImplemented + + other = self.domain.from_sympy(other) + inverse = self.ring.convert_from(1/other, self.domain) + inverse = DomainScalar(inverse, self.ring) + dm = self._dm * inverse + return self.from_dm(dm) + + def __neg__(self): + return self.from_dm(-self._dm) + + def transpose(self): + return self.from_dm(self._dm.transpose()) + + def row_join(self, other): + dm = DomainMatrix.hstack(self._dm, other._dm) + return self.from_dm(dm) + + def col_join(self, other): + dm = DomainMatrix.vstack(self._dm, other._dm) + return self.from_dm(dm) + + def applyfunc(self, func): + M = self.to_Matrix().applyfunc(func) + return self.from_Matrix(M, self.gens) + + @classmethod + def eye(cls, n, gens): + return cls.from_dm(DomainMatrix.eye(n, QQ[gens])) + + @classmethod + def zeros(cls, m, n, gens): + return cls.from_dm(DomainMatrix.zeros((m, n), QQ[gens])) + + def rref(self, simplify='ignore', normalize_last='ignore'): + # If this is K[x] then computes RREF in ground field K. + if not (self.domain.is_Field and all(p.is_ground for p in self)): + raise ValueError("PolyMatrix rref is only for ground field elements") + dm = self._dm + dm_ground = dm.convert_to(dm.domain.domain) + dm_rref, pivots = dm_ground.rref() + dm_rref = dm_rref.convert_to(dm.domain) + return self.from_dm(dm_rref), pivots + + def nullspace(self): + # If this is K[x] then computes nullspace in ground field K. + if not (self.domain.is_Field and all(p.is_ground for p in self)): + raise ValueError("PolyMatrix nullspace is only for ground field elements") + dm = self._dm + K, Kx = self.domain, self.ring + dm_null_rows = dm.convert_to(K).nullspace(divide_last=True).convert_to(Kx) + dm_null = dm_null_rows.transpose() + dm_basis = [dm_null[:,i] for i in range(dm_null.shape[1])] + return [self.from_dm(dmvec) for dmvec in dm_basis] + + def rank(self): + return self.cols - len(self.nullspace()) + +MutablePolyMatrix = PolyMatrix = MutablePolyDenseMatrix diff --git a/phivenv/Lib/site-packages/sympy/polys/polyoptions.py b/phivenv/Lib/site-packages/sympy/polys/polyoptions.py new file mode 100644 index 0000000000000000000000000000000000000000..7b9bd989c4d5676aab32e65c62996137b3e0b73e --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/polyoptions.py @@ -0,0 +1,791 @@ +"""Options manager for :class:`~.Poly` and public API functions. """ + +from __future__ import annotations + +__all__ = ["Options"] + +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.sympify import sympify +from sympy.polys.polyerrors import GeneratorsError, OptionError, FlagError +from sympy.utilities import numbered_symbols, topological_sort, public +from sympy.utilities.iterables import has_dups, is_sequence + +import sympy.polys + +import re + +class Option: + """Base class for all kinds of options. """ + + option: str | None = None + + is_Flag = False + + requires: list[str] = [] + excludes: list[str] = [] + + after: list[str] = [] + before: list[str] = [] + + @classmethod + def default(cls): + return None + + @classmethod + def preprocess(cls, option): + return None + + @classmethod + def postprocess(cls, options): + pass + + +class Flag(Option): + """Base class for all kinds of flags. """ + + is_Flag = True + + +class BooleanOption(Option): + """An option that must have a boolean value or equivalent assigned. """ + + @classmethod + def preprocess(cls, value): + if value in [True, False]: + return bool(value) + else: + raise OptionError("'%s' must have a boolean value assigned, got %s" % (cls.option, value)) + + +class OptionType(type): + """Base type for all options that does registers options. """ + + def __init__(cls, *args, **kwargs): + @property + def getter(self): + try: + return self[cls.option] + except KeyError: + return cls.default() + + setattr(Options, cls.option, getter) + Options.__options__[cls.option] = cls + + +@public +class Options(dict): + """ + Options manager for polynomial manipulation module. + + Examples + ======== + + >>> from sympy.polys.polyoptions import Options + >>> from sympy.polys.polyoptions import build_options + + >>> from sympy.abc import x, y, z + + >>> Options((x, y, z), {'domain': 'ZZ'}) + {'auto': False, 'domain': ZZ, 'gens': (x, y, z)} + + >>> build_options((x, y, z), {'domain': 'ZZ'}) + {'auto': False, 'domain': ZZ, 'gens': (x, y, z)} + + **Options** + + * Expand --- boolean option + * Gens --- option + * Wrt --- option + * Sort --- option + * Order --- option + * Field --- boolean option + * Greedy --- boolean option + * Domain --- option + * Split --- boolean option + * Gaussian --- boolean option + * Extension --- option + * Modulus --- option + * Symmetric --- boolean option + * Strict --- boolean option + + **Flags** + + * Auto --- boolean flag + * Frac --- boolean flag + * Formal --- boolean flag + * Polys --- boolean flag + * Include --- boolean flag + * All --- boolean flag + * Gen --- flag + * Series --- boolean flag + + """ + + __order__ = None + __options__: dict[str, type[Option]] = {} + + gens: tuple[Expr, ...] + domain: sympy.polys.domains.Domain + + def __init__(self, gens, args, flags=None, strict=False): + dict.__init__(self) + + if gens and args.get('gens', ()): + raise OptionError( + "both '*gens' and keyword argument 'gens' supplied") + elif gens: + args = dict(args) + args['gens'] = gens + + defaults = args.pop('defaults', {}) + + def preprocess_options(args): + for option, value in args.items(): + try: + cls = self.__options__[option] + except KeyError: + raise OptionError("'%s' is not a valid option" % option) + + if issubclass(cls, Flag): + if flags is None or option not in flags: + if strict: + raise OptionError("'%s' flag is not allowed in this context" % option) + + if value is not None: + self[option] = cls.preprocess(value) + + preprocess_options(args) + + for key in dict(defaults): + if key in self: + del defaults[key] + else: + for option in self.keys(): + cls = self.__options__[option] + + if key in cls.excludes: + del defaults[key] + break + + preprocess_options(defaults) + + for option in self.keys(): + cls = self.__options__[option] + + for require_option in cls.requires: + if self.get(require_option) is None: + raise OptionError("'%s' option is only allowed together with '%s'" % (option, require_option)) + + for exclude_option in cls.excludes: + if self.get(exclude_option) is not None: + raise OptionError("'%s' option is not allowed together with '%s'" % (option, exclude_option)) + + for option in self.__order__: + self.__options__[option].postprocess(self) + + @classmethod + def _init_dependencies_order(cls): + """Resolve the order of options' processing. """ + if cls.__order__ is None: + vertices, edges = [], set() + + for name, option in cls.__options__.items(): + vertices.append(name) + + edges.update((_name, name) for _name in option.after) + + edges.update((name, _name) for _name in option.before) + + try: + cls.__order__ = topological_sort((vertices, list(edges))) + except ValueError: + raise RuntimeError( + "cycle detected in sympy.polys options framework") + + def clone(self, updates={}): + """Clone ``self`` and update specified options. """ + obj = dict.__new__(self.__class__) + + for option, value in self.items(): + obj[option] = value + + for option, value in updates.items(): + obj[option] = value + + return obj + + def __setattr__(self, attr, value): + if attr in self.__options__: + self[attr] = value + else: + super().__setattr__(attr, value) + + @property + def args(self): + args = {} + + for option, value in self.items(): + if value is not None and option != 'gens': + cls = self.__options__[option] + + if not issubclass(cls, Flag): + args[option] = value + + return args + + @property + def options(self): + options = {} + + for option, cls in self.__options__.items(): + if not issubclass(cls, Flag): + options[option] = getattr(self, option) + + return options + + @property + def flags(self): + flags = {} + + for option, cls in self.__options__.items(): + if issubclass(cls, Flag): + flags[option] = getattr(self, option) + + return flags + + +class Expand(BooleanOption, metaclass=OptionType): + """``expand`` option to polynomial manipulation functions. """ + + option = 'expand' + + requires: list[str] = [] + excludes: list[str] = [] + + @classmethod + def default(cls): + return True + + +class Gens(Option, metaclass=OptionType): + """``gens`` option to polynomial manipulation functions. """ + + option = 'gens' + + requires: list[str] = [] + excludes: list[str] = [] + + @classmethod + def default(cls): + return () + + @classmethod + def preprocess(cls, gens): + if isinstance(gens, Basic): + gens = (gens,) + elif len(gens) == 1 and is_sequence(gens[0]): + gens = gens[0] + + if gens == (None,): + gens = () + elif has_dups(gens): + raise GeneratorsError("duplicated generators: %s" % str(gens)) + elif any(gen.is_commutative is False for gen in gens): + raise GeneratorsError("non-commutative generators: %s" % str(gens)) + + return tuple(gens) + + +class Wrt(Option, metaclass=OptionType): + """``wrt`` option to polynomial manipulation functions. """ + + option = 'wrt' + + requires: list[str] = [] + excludes: list[str] = [] + + _re_split = re.compile(r"\s*,\s*|\s+") + + @classmethod + def preprocess(cls, wrt): + if isinstance(wrt, Basic): + return [str(wrt)] + elif isinstance(wrt, str): + wrt = wrt.strip() + if wrt.endswith(','): + raise OptionError('Bad input: missing parameter.') + if not wrt: + return [] + return list(cls._re_split.split(wrt)) + elif hasattr(wrt, '__getitem__'): + return list(map(str, wrt)) + else: + raise OptionError("invalid argument for 'wrt' option") + + +class Sort(Option, metaclass=OptionType): + """``sort`` option to polynomial manipulation functions. """ + + option = 'sort' + + requires: list[str] = [] + excludes: list[str] = [] + + @classmethod + def default(cls): + return [] + + @classmethod + def preprocess(cls, sort): + if isinstance(sort, str): + return [ gen.strip() for gen in sort.split('>') ] + elif hasattr(sort, '__getitem__'): + return list(map(str, sort)) + else: + raise OptionError("invalid argument for 'sort' option") + + +class Order(Option, metaclass=OptionType): + """``order`` option to polynomial manipulation functions. """ + + option = 'order' + + requires: list[str] = [] + excludes: list[str] = [] + + @classmethod + def default(cls): + return sympy.polys.orderings.lex + + @classmethod + def preprocess(cls, order): + return sympy.polys.orderings.monomial_key(order) + + +class Field(BooleanOption, metaclass=OptionType): + """``field`` option to polynomial manipulation functions. """ + + option = 'field' + + requires: list[str] = [] + excludes = ['domain', 'split', 'gaussian'] + + +class Greedy(BooleanOption, metaclass=OptionType): + """``greedy`` option to polynomial manipulation functions. """ + + option = 'greedy' + + requires: list[str] = [] + excludes = ['domain', 'split', 'gaussian', 'extension', 'modulus', 'symmetric'] + + +class Composite(BooleanOption, metaclass=OptionType): + """``composite`` option to polynomial manipulation functions. """ + + option = 'composite' + + @classmethod + def default(cls): + return None + + requires: list[str] = [] + excludes = ['domain', 'split', 'gaussian', 'extension', 'modulus', 'symmetric'] + + +class Domain(Option, metaclass=OptionType): + """``domain`` option to polynomial manipulation functions. """ + + option = 'domain' + + requires: list[str] = [] + excludes = ['field', 'greedy', 'split', 'gaussian', 'extension'] + + after = ['gens'] + + _re_realfield = re.compile(r"^(R|RR)(_(\d+))?$") + _re_complexfield = re.compile(r"^(C|CC)(_(\d+))?$") + _re_finitefield = re.compile(r"^(FF|GF)\((\d+)\)$") + _re_polynomial = re.compile(r"^(Z|ZZ|Q|QQ|ZZ_I|QQ_I|R|RR|C|CC)\[(.+)\]$") + _re_fraction = re.compile(r"^(Z|ZZ|Q|QQ)\((.+)\)$") + _re_algebraic = re.compile(r"^(Q|QQ)\<(.+)\>$") + + @classmethod + def preprocess(cls, domain): + if isinstance(domain, sympy.polys.domains.Domain): + return domain + elif hasattr(domain, 'to_domain'): + return domain.to_domain() + elif isinstance(domain, str): + if domain in ['Z', 'ZZ']: + return sympy.polys.domains.ZZ + + if domain in ['Q', 'QQ']: + return sympy.polys.domains.QQ + + if domain == 'ZZ_I': + return sympy.polys.domains.ZZ_I + + if domain == 'QQ_I': + return sympy.polys.domains.QQ_I + + if domain == 'EX': + return sympy.polys.domains.EX + + r = cls._re_realfield.match(domain) + + if r is not None: + _, _, prec = r.groups() + + if prec is None: + return sympy.polys.domains.RR + else: + return sympy.polys.domains.RealField(int(prec)) + + r = cls._re_complexfield.match(domain) + + if r is not None: + _, _, prec = r.groups() + + if prec is None: + return sympy.polys.domains.CC + else: + return sympy.polys.domains.ComplexField(int(prec)) + + r = cls._re_finitefield.match(domain) + + if r is not None: + return sympy.polys.domains.FF(int(r.groups()[1])) + + r = cls._re_polynomial.match(domain) + + if r is not None: + ground, gens = r.groups() + + gens = list(map(sympify, gens.split(','))) + + if ground in ['Z', 'ZZ']: + return sympy.polys.domains.ZZ.poly_ring(*gens) + elif ground in ['Q', 'QQ']: + return sympy.polys.domains.QQ.poly_ring(*gens) + elif ground in ['R', 'RR']: + return sympy.polys.domains.RR.poly_ring(*gens) + elif ground == 'ZZ_I': + return sympy.polys.domains.ZZ_I.poly_ring(*gens) + elif ground == 'QQ_I': + return sympy.polys.domains.QQ_I.poly_ring(*gens) + else: + return sympy.polys.domains.CC.poly_ring(*gens) + + r = cls._re_fraction.match(domain) + + if r is not None: + ground, gens = r.groups() + + gens = list(map(sympify, gens.split(','))) + + if ground in ['Z', 'ZZ']: + return sympy.polys.domains.ZZ.frac_field(*gens) + else: + return sympy.polys.domains.QQ.frac_field(*gens) + + r = cls._re_algebraic.match(domain) + + if r is not None: + gens = list(map(sympify, r.groups()[1].split(','))) + return sympy.polys.domains.QQ.algebraic_field(*gens) + + raise OptionError('expected a valid domain specification, got %s' % domain) + + @classmethod + def postprocess(cls, options): + if 'gens' in options and 'domain' in options and options['domain'].is_Composite and \ + (set(options['domain'].symbols) & set(options['gens'])): + raise GeneratorsError( + "ground domain and generators interfere together") + elif ('gens' not in options or not options['gens']) and \ + 'domain' in options and options['domain'] == sympy.polys.domains.EX: + raise GeneratorsError("you have to provide generators because EX domain was requested") + + +class Split(BooleanOption, metaclass=OptionType): + """``split`` option to polynomial manipulation functions. """ + + option = 'split' + + requires: list[str] = [] + excludes = ['field', 'greedy', 'domain', 'gaussian', 'extension', + 'modulus', 'symmetric'] + + @classmethod + def postprocess(cls, options): + if 'split' in options: + raise NotImplementedError("'split' option is not implemented yet") + + +class Gaussian(BooleanOption, metaclass=OptionType): + """``gaussian`` option to polynomial manipulation functions. """ + + option = 'gaussian' + + requires: list[str] = [] + excludes = ['field', 'greedy', 'domain', 'split', 'extension', + 'modulus', 'symmetric'] + + @classmethod + def postprocess(cls, options): + if 'gaussian' in options and options['gaussian'] is True: + options['domain'] = sympy.polys.domains.QQ_I + Extension.postprocess(options) + + +class Extension(Option, metaclass=OptionType): + """``extension`` option to polynomial manipulation functions. """ + + option = 'extension' + + requires: list[str] = [] + excludes = ['greedy', 'domain', 'split', 'gaussian', 'modulus', + 'symmetric'] + + @classmethod + def preprocess(cls, extension): + if extension == 1: + return bool(extension) + elif extension == 0: + raise OptionError("'False' is an invalid argument for 'extension'") + else: + if not hasattr(extension, '__iter__'): + extension = {extension} + else: + if not extension: + extension = None + else: + extension = set(extension) + + return extension + + @classmethod + def postprocess(cls, options): + if 'extension' in options and options['extension'] is not True: + options['domain'] = sympy.polys.domains.QQ.algebraic_field( + *options['extension']) + + +class Modulus(Option, metaclass=OptionType): + """``modulus`` option to polynomial manipulation functions. """ + + option = 'modulus' + + requires: list[str] = [] + excludes = ['greedy', 'split', 'domain', 'gaussian', 'extension'] + + @classmethod + def preprocess(cls, modulus): + modulus = sympify(modulus) + + if modulus.is_Integer and modulus > 0: + return int(modulus) + else: + raise OptionError( + "'modulus' must a positive integer, got %s" % modulus) + + @classmethod + def postprocess(cls, options): + if 'modulus' in options: + modulus = options['modulus'] + symmetric = options.get('symmetric', True) + options['domain'] = sympy.polys.domains.FF(modulus, symmetric) + + +class Symmetric(BooleanOption, metaclass=OptionType): + """``symmetric`` option to polynomial manipulation functions. """ + + option = 'symmetric' + + requires = ['modulus'] + excludes = ['greedy', 'domain', 'split', 'gaussian', 'extension'] + + +class Strict(BooleanOption, metaclass=OptionType): + """``strict`` option to polynomial manipulation functions. """ + + option = 'strict' + + @classmethod + def default(cls): + return True + + +class Auto(BooleanOption, Flag, metaclass=OptionType): + """``auto`` flag to polynomial manipulation functions. """ + + option = 'auto' + + after = ['field', 'domain', 'extension', 'gaussian'] + + @classmethod + def default(cls): + return True + + @classmethod + def postprocess(cls, options): + if ('domain' in options or 'field' in options) and 'auto' not in options: + options['auto'] = False + + +class Frac(BooleanOption, Flag, metaclass=OptionType): + """``auto`` option to polynomial manipulation functions. """ + + option = 'frac' + + @classmethod + def default(cls): + return False + + +class Formal(BooleanOption, Flag, metaclass=OptionType): + """``formal`` flag to polynomial manipulation functions. """ + + option = 'formal' + + @classmethod + def default(cls): + return False + + +class Polys(BooleanOption, Flag, metaclass=OptionType): + """``polys`` flag to polynomial manipulation functions. """ + + option = 'polys' + + +class Include(BooleanOption, Flag, metaclass=OptionType): + """``include`` flag to polynomial manipulation functions. """ + + option = 'include' + + @classmethod + def default(cls): + return False + + +class All(BooleanOption, Flag, metaclass=OptionType): + """``all`` flag to polynomial manipulation functions. """ + + option = 'all' + + @classmethod + def default(cls): + return False + + +class Gen(Flag, metaclass=OptionType): + """``gen`` flag to polynomial manipulation functions. """ + + option = 'gen' + + @classmethod + def default(cls): + return 0 + + @classmethod + def preprocess(cls, gen): + if isinstance(gen, (Basic, int)): + return gen + else: + raise OptionError("invalid argument for 'gen' option") + + +class Series(BooleanOption, Flag, metaclass=OptionType): + """``series`` flag to polynomial manipulation functions. """ + + option = 'series' + + @classmethod + def default(cls): + return False + + +class Symbols(Flag, metaclass=OptionType): + """``symbols`` flag to polynomial manipulation functions. """ + + option = 'symbols' + + @classmethod + def default(cls): + return numbered_symbols('s', start=1) + + @classmethod + def preprocess(cls, symbols): + if hasattr(symbols, '__iter__'): + return iter(symbols) + else: + raise OptionError("expected an iterator or iterable container, got %s" % symbols) + + +class Method(Flag, metaclass=OptionType): + """``method`` flag to polynomial manipulation functions. """ + + option = 'method' + + @classmethod + def preprocess(cls, method): + if isinstance(method, str): + return method.lower() + else: + raise OptionError("expected a string, got %s" % method) + + +def build_options(gens, args=None): + """Construct options from keyword arguments or ... options. """ + if args is None: + gens, args = (), gens + + if len(args) != 1 or 'opt' not in args or gens: + return Options(gens, args) + else: + return args['opt'] + + +def allowed_flags(args, flags): + """ + Allow specified flags to be used in the given context. + + Examples + ======== + + >>> from sympy.polys.polyoptions import allowed_flags + >>> from sympy.polys.domains import ZZ + + >>> allowed_flags({'domain': ZZ}, []) + + >>> allowed_flags({'domain': ZZ, 'frac': True}, []) + Traceback (most recent call last): + ... + FlagError: 'frac' flag is not allowed in this context + + >>> allowed_flags({'domain': ZZ, 'frac': True}, ['frac']) + + """ + flags = set(flags) + + for arg in args.keys(): + try: + if Options.__options__[arg].is_Flag and arg not in flags: + raise FlagError( + "'%s' flag is not allowed in this context" % arg) + except KeyError: + raise OptionError("'%s' is not a valid option" % arg) + + +def set_defaults(options, **defaults): + """Update options with default values. """ + if 'defaults' not in options: + options = dict(options) + options['defaults'] = defaults + + return options + +Options._init_dependencies_order() diff --git a/phivenv/Lib/site-packages/sympy/polys/polyquinticconst.py b/phivenv/Lib/site-packages/sympy/polys/polyquinticconst.py new file mode 100644 index 0000000000000000000000000000000000000000..3b17096fd2cf3b205c3b819eb11ffc2012ea125b --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/polyquinticconst.py @@ -0,0 +1,187 @@ +""" +Solving solvable quintics - An implementation of DS Dummit's paper + +Paper : +https://www.ams.org/journals/mcom/1991-57-195/S0025-5718-1991-1079014-X/S0025-5718-1991-1079014-X.pdf + +Mathematica notebook: +http://www.emba.uvm.edu/~ddummit/quintics/quintics.nb + +""" + + +from sympy.core import Symbol +from sympy.core.evalf import N +from sympy.core.numbers import I, Rational +from sympy.functions import sqrt +from sympy.polys.polytools import Poly +from sympy.utilities import public + +x = Symbol('x') + +@public +class PolyQuintic: + """Special functions for solvable quintics""" + def __init__(self, poly): + _, _, self.p, self.q, self.r, self.s = poly.all_coeffs() + self.zeta1 = Rational(-1, 4) + (sqrt(5)/4) + I*sqrt((sqrt(5)/8) + Rational(5, 8)) + self.zeta2 = (-sqrt(5)/4) - Rational(1, 4) + I*sqrt((-sqrt(5)/8) + Rational(5, 8)) + self.zeta3 = (-sqrt(5)/4) - Rational(1, 4) - I*sqrt((-sqrt(5)/8) + Rational(5, 8)) + self.zeta4 = Rational(-1, 4) + (sqrt(5)/4) - I*sqrt((sqrt(5)/8) + Rational(5, 8)) + + @property + def f20(self): + p, q, r, s = self.p, self.q, self.r, self.s + f20 = q**8 - 13*p*q**6*r + p**5*q**2*r**2 + 65*p**2*q**4*r**2 - 4*p**6*r**3 - 128*p**3*q**2*r**3 + 17*q**4*r**3 + 48*p**4*r**4 - 16*p*q**2*r**4 - 192*p**2*r**5 + 256*r**6 - 4*p**5*q**3*s - 12*p**2*q**5*s + 18*p**6*q*r*s + 12*p**3*q**3*r*s - 124*q**5*r*s + 196*p**4*q*r**2*s + 590*p*q**3*r**2*s - 160*p**2*q*r**3*s - 1600*q*r**4*s - 27*p**7*s**2 - 150*p**4*q**2*s**2 - 125*p*q**4*s**2 - 99*p**5*r*s**2 - 725*p**2*q**2*r*s**2 + 1200*p**3*r**2*s**2 + 3250*q**2*r**2*s**2 - 2000*p*r**3*s**2 - 1250*p*q*r*s**3 + 3125*p**2*s**4 - 9375*r*s**4-(2*p*q**6 - 19*p**2*q**4*r + 51*p**3*q**2*r**2 - 3*q**4*r**2 - 32*p**4*r**3 - 76*p*q**2*r**3 + 256*p**2*r**4 - 512*r**5 + 31*p**3*q**3*s + 58*q**5*s - 117*p**4*q*r*s - 105*p*q**3*r*s - 260*p**2*q*r**2*s + 2400*q*r**3*s + 108*p**5*s**2 + 325*p**2*q**2*s**2 - 525*p**3*r*s**2 - 2750*q**2*r*s**2 + 500*p*r**2*s**2 - 625*p*q*s**3 + 3125*s**4)*x+(p**2*q**4 - 6*p**3*q**2*r - 8*q**4*r + 9*p**4*r**2 + 76*p*q**2*r**2 - 136*p**2*r**3 + 400*r**4 - 50*p*q**3*s + 90*p**2*q*r*s - 1400*q*r**2*s + 625*q**2*s**2 + 500*p*r*s**2)*x**2-(2*q**4 - 21*p*q**2*r + 40*p**2*r**2 - 160*r**3 + 15*p**2*q*s + 400*q*r*s - 125*p*s**2)*x**3+(2*p*q**2 - 6*p**2*r + 40*r**2 - 50*q*s)*x**4 + 8*r*x**5 + x**6 + return Poly(f20, x) + + @property + def b(self): + p, q, r, s = self.p, self.q, self.r, self.s + b = ( [], [0,0,0,0,0,0], [0,0,0,0,0,0], [0,0,0,0,0,0], [0,0,0,0,0,0],) + + b[1][5] = 100*p**7*q**7 + 2175*p**4*q**9 + 10500*p*q**11 - 1100*p**8*q**5*r - 27975*p**5*q**7*r - 152950*p**2*q**9*r + 4125*p**9*q**3*r**2 + 128875*p**6*q**5*r**2 + 830525*p**3*q**7*r**2 - 59450*q**9*r**2 - 5400*p**10*q*r**3 - 243800*p**7*q**3*r**3 - 2082650*p**4*q**5*r**3 + 333925*p*q**7*r**3 + 139200*p**8*q*r**4 + 2406000*p**5*q**3*r**4 + 122600*p**2*q**5*r**4 - 1254400*p**6*q*r**5 - 3776000*p**3*q**3*r**5 - 1832000*q**5*r**5 + 4736000*p**4*q*r**6 + 6720000*p*q**3*r**6 - 6400000*p**2*q*r**7 + 900*p**9*q**4*s + 37400*p**6*q**6*s + 281625*p**3*q**8*s + 435000*q**10*s - 6750*p**10*q**2*r*s - 322300*p**7*q**4*r*s - 2718575*p**4*q**6*r*s - 4214250*p*q**8*r*s + 16200*p**11*r**2*s + 859275*p**8*q**2*r**2*s + 8925475*p**5*q**4*r**2*s + 14427875*p**2*q**6*r**2*s - 453600*p**9*r**3*s - 10038400*p**6*q**2*r**3*s - 17397500*p**3*q**4*r**3*s + 11333125*q**6*r**3*s + 4451200*p**7*r**4*s + 15850000*p**4*q**2*r**4*s - 34000000*p*q**4*r**4*s - 17984000*p**5*r**5*s + 10000000*p**2*q**2*r**5*s + 25600000*p**3*r**6*s + 8000000*q**2*r**6*s - 6075*p**11*q*s**2 + 83250*p**8*q**3*s**2 + 1282500*p**5*q**5*s**2 + 2862500*p**2*q**7*s**2 - 724275*p**9*q*r*s**2 - 9807250*p**6*q**3*r*s**2 - 28374375*p**3*q**5*r*s**2 - 22212500*q**7*r*s**2 + 8982000*p**7*q*r**2*s**2 + 39600000*p**4*q**3*r**2*s**2 + 61746875*p*q**5*r**2*s**2 + 1010000*p**5*q*r**3*s**2 + 1000000*p**2*q**3*r**3*s**2 - 78000000*p**3*q*r**4*s**2 - 30000000*q**3*r**4*s**2 - 80000000*p*q*r**5*s**2 + 759375*p**10*s**3 + 9787500*p**7*q**2*s**3 + 39062500*p**4*q**4*s**3 + 52343750*p*q**6*s**3 - 12301875*p**8*r*s**3 - 98175000*p**5*q**2*r*s**3 - 225078125*p**2*q**4*r*s**3 + 54900000*p**6*r**2*s**3 + 310000000*p**3*q**2*r**2*s**3 + 7890625*q**4*r**2*s**3 - 51250000*p**4*r**3*s**3 + 420000000*p*q**2*r**3*s**3 - 110000000*p**2*r**4*s**3 + 200000000*r**5*s**3 - 2109375*p**6*q*s**4 + 21093750*p**3*q**3*s**4 + 89843750*q**5*s**4 - 182343750*p**4*q*r*s**4 - 733203125*p*q**3*r*s**4 + 196875000*p**2*q*r**2*s**4 - 1125000000*q*r**3*s**4 + 158203125*p**5*s**5 + 566406250*p**2*q**2*s**5 - 101562500*p**3*r*s**5 + 1669921875*q**2*r*s**5 - 1250000000*p*r**2*s**5 + 1220703125*p*q*s**6 - 6103515625*s**7 + + b[1][4] = -1000*p**5*q**7 - 7250*p**2*q**9 + 10800*p**6*q**5*r + 96900*p**3*q**7*r + 52500*q**9*r - 37400*p**7*q**3*r**2 - 470850*p**4*q**5*r**2 - 640600*p*q**7*r**2 + 39600*p**8*q*r**3 + 983600*p**5*q**3*r**3 + 2848100*p**2*q**5*r**3 - 814400*p**6*q*r**4 - 6076000*p**3*q**3*r**4 - 2308000*q**5*r**4 + 5024000*p**4*q*r**5 + 9680000*p*q**3*r**5 - 9600000*p**2*q*r**6 - 13800*p**7*q**4*s - 94650*p**4*q**6*s + 26500*p*q**8*s + 86400*p**8*q**2*r*s + 816500*p**5*q**4*r*s + 257500*p**2*q**6*r*s - 91800*p**9*r**2*s - 1853700*p**6*q**2*r**2*s - 630000*p**3*q**4*r**2*s + 8971250*q**6*r**2*s + 2071200*p**7*r**3*s + 7240000*p**4*q**2*r**3*s - 29375000*p*q**4*r**3*s - 14416000*p**5*r**4*s + 5200000*p**2*q**2*r**4*s + 30400000*p**3*r**5*s + 12000000*q**2*r**5*s - 64800*p**9*q*s**2 - 567000*p**6*q**3*s**2 - 1655000*p**3*q**5*s**2 - 6987500*q**7*s**2 - 337500*p**7*q*r*s**2 - 8462500*p**4*q**3*r*s**2 + 5812500*p*q**5*r*s**2 + 24930000*p**5*q*r**2*s**2 + 69125000*p**2*q**3*r**2*s**2 - 103500000*p**3*q*r**3*s**2 - 30000000*q**3*r**3*s**2 - 90000000*p*q*r**4*s**2 + 708750*p**8*s**3 + 5400000*p**5*q**2*s**3 - 8906250*p**2*q**4*s**3 - 18562500*p**6*r*s**3 + 625000*p**3*q**2*r*s**3 - 29687500*q**4*r*s**3 + 75000000*p**4*r**2*s**3 + 416250000*p*q**2*r**2*s**3 - 60000000*p**2*r**3*s**3 + 300000000*r**4*s**3 - 71718750*p**4*q*s**4 - 189062500*p*q**3*s**4 - 210937500*p**2*q*r*s**4 - 1187500000*q*r**2*s**4 + 187500000*p**3*s**5 + 800781250*q**2*s**5 + 390625000*p*r*s**5 + + b[1][3] = 500*p**6*q**5 + 6350*p**3*q**7 + 19800*q**9 - 3750*p**7*q**3*r - 65100*p**4*q**5*r - 264950*p*q**7*r + 6750*p**8*q*r**2 + 209050*p**5*q**3*r**2 + 1217250*p**2*q**5*r**2 - 219000*p**6*q*r**3 - 2510000*p**3*q**3*r**3 - 1098500*q**5*r**3 + 2068000*p**4*q*r**4 + 5060000*p*q**3*r**4 - 5200000*p**2*q*r**5 + 6750*p**8*q**2*s + 96350*p**5*q**4*s + 346000*p**2*q**6*s - 20250*p**9*r*s - 459900*p**6*q**2*r*s - 1828750*p**3*q**4*r*s + 2930000*q**6*r*s + 594000*p**7*r**2*s + 4301250*p**4*q**2*r**2*s - 10906250*p*q**4*r**2*s - 5252000*p**5*r**3*s + 1450000*p**2*q**2*r**3*s + 12800000*p**3*r**4*s + 6500000*q**2*r**4*s - 74250*p**7*q*s**2 - 1418750*p**4*q**3*s**2 - 5956250*p*q**5*s**2 + 4297500*p**5*q*r*s**2 + 29906250*p**2*q**3*r*s**2 - 31500000*p**3*q*r**2*s**2 - 12500000*q**3*r**2*s**2 - 35000000*p*q*r**3*s**2 - 1350000*p**6*s**3 - 6093750*p**3*q**2*s**3 - 17500000*q**4*s**3 + 7031250*p**4*r*s**3 + 127812500*p*q**2*r*s**3 - 18750000*p**2*r**2*s**3 + 162500000*r**3*s**3 - 107812500*p**2*q*s**4 - 460937500*q*r*s**4 + 214843750*p*s**5 + + b[1][2] = -1950*p**4*q**5 - 14100*p*q**7 + 14350*p**5*q**3*r + 125600*p**2*q**5*r - 27900*p**6*q*r**2 - 402250*p**3*q**3*r**2 - 288250*q**5*r**2 + 436000*p**4*q*r**3 + 1345000*p*q**3*r**3 - 1400000*p**2*q*r**4 - 9450*p**6*q**2*s + 1250*p**3*q**4*s + 465000*q**6*s + 49950*p**7*r*s + 302500*p**4*q**2*r*s - 1718750*p*q**4*r*s - 834000*p**5*r**2*s - 437500*p**2*q**2*r**2*s + 3100000*p**3*r**3*s + 1750000*q**2*r**3*s + 292500*p**5*q*s**2 + 1937500*p**2*q**3*s**2 - 3343750*p**3*q*r*s**2 - 1875000*q**3*r*s**2 - 8125000*p*q*r**2*s**2 + 1406250*p**4*s**3 + 12343750*p*q**2*s**3 - 5312500*p**2*r*s**3 + 43750000*r**2*s**3 - 74218750*q*s**4 + + b[1][1] = 300*p**5*q**3 + 2150*p**2*q**5 - 1350*p**6*q*r - 21500*p**3*q**3*r - 61500*q**5*r + 42000*p**4*q*r**2 + 290000*p*q**3*r**2 - 300000*p**2*q*r**3 + 4050*p**7*s + 45000*p**4*q**2*s + 125000*p*q**4*s - 108000*p**5*r*s - 643750*p**2*q**2*r*s + 700000*p**3*r**2*s + 375000*q**2*r**2*s + 93750*p**3*q*s**2 + 312500*q**3*s**2 - 1875000*p*q*r*s**2 + 1406250*p**2*s**3 + 9375000*r*s**3 + + b[1][0] = -1250*p**3*q**3 - 9000*q**5 + 4500*p**4*q*r + 46250*p*q**3*r - 50000*p**2*q*r**2 - 6750*p**5*s - 43750*p**2*q**2*s + 75000*p**3*r*s + 62500*q**2*r*s - 156250*p*q*s**2 + 1562500*s**3 + + b[2][5] = 200*p**6*q**11 - 250*p**3*q**13 - 10800*q**15 - 3900*p**7*q**9*r - 3325*p**4*q**11*r + 181800*p*q**13*r + 26950*p**8*q**7*r**2 + 69625*p**5*q**9*r**2 - 1214450*p**2*q**11*r**2 - 78725*p**9*q**5*r**3 - 368675*p**6*q**7*r**3 + 4166325*p**3*q**9*r**3 + 1131100*q**11*r**3 + 73400*p**10*q**3*r**4 + 661950*p**7*q**5*r**4 - 9151950*p**4*q**7*r**4 - 16633075*p*q**9*r**4 + 36000*p**11*q*r**5 + 135600*p**8*q**3*r**5 + 17321400*p**5*q**5*r**5 + 85338300*p**2*q**7*r**5 - 832000*p**9*q*r**6 - 21379200*p**6*q**3*r**6 - 176044000*p**3*q**5*r**6 - 1410000*q**7*r**6 + 6528000*p**7*q*r**7 + 129664000*p**4*q**3*r**7 + 47344000*p*q**5*r**7 - 21504000*p**5*q*r**8 - 115200000*p**2*q**3*r**8 + 25600000*p**3*q*r**9 + 64000000*q**3*r**9 + 15700*p**8*q**8*s + 120525*p**5*q**10*s + 113250*p**2*q**12*s - 196900*p**9*q**6*r*s - 1776925*p**6*q**8*r*s - 3062475*p**3*q**10*r*s - 4153500*q**12*r*s + 857925*p**10*q**4*r**2*s + 10562775*p**7*q**6*r**2*s + 34866250*p**4*q**8*r**2*s + 73486750*p*q**10*r**2*s - 1333800*p**11*q**2*r**3*s - 29212625*p**8*q**4*r**3*s - 168729675*p**5*q**6*r**3*s - 427230750*p**2*q**8*r**3*s + 108000*p**12*r**4*s + 30384200*p**9*q**2*r**4*s + 324535100*p**6*q**4*r**4*s + 952666750*p**3*q**6*r**4*s - 38076875*q**8*r**4*s - 4296000*p**10*r**5*s - 213606400*p**7*q**2*r**5*s - 842060000*p**4*q**4*r**5*s - 95285000*p*q**6*r**5*s + 61184000*p**8*r**6*s + 567520000*p**5*q**2*r**6*s + 547000000*p**2*q**4*r**6*s - 390912000*p**6*r**7*s - 812800000*p**3*q**2*r**7*s - 924000000*q**4*r**7*s + 1152000000*p**4*r**8*s + 800000000*p*q**2*r**8*s - 1280000000*p**2*r**9*s + 141750*p**10*q**5*s**2 - 31500*p**7*q**7*s**2 - 11325000*p**4*q**9*s**2 - 31687500*p*q**11*s**2 - 1293975*p**11*q**3*r*s**2 - 4803800*p**8*q**5*r*s**2 + 71398250*p**5*q**7*r*s**2 + 227625000*p**2*q**9*r*s**2 + 3256200*p**12*q*r**2*s**2 + 43870125*p**9*q**3*r**2*s**2 + 64581500*p**6*q**5*r**2*s**2 + 56090625*p**3*q**7*r**2*s**2 + 260218750*q**9*r**2*s**2 - 74610000*p**10*q*r**3*s**2 - 662186500*p**7*q**3*r**3*s**2 - 1987747500*p**4*q**5*r**3*s**2 - 811928125*p*q**7*r**3*s**2 + 471286000*p**8*q*r**4*s**2 + 2106040000*p**5*q**3*r**4*s**2 + 792687500*p**2*q**5*r**4*s**2 - 135120000*p**6*q*r**5*s**2 + 2479000000*p**3*q**3*r**5*s**2 + 5242250000*q**5*r**5*s**2 - 6400000000*p**4*q*r**6*s**2 - 8620000000*p*q**3*r**6*s**2 + 13280000000*p**2*q*r**7*s**2 + 1600000000*q*r**8*s**2 + 273375*p**12*q**2*s**3 - 13612500*p**9*q**4*s**3 - 177250000*p**6*q**6*s**3 - 511015625*p**3*q**8*s**3 - 320937500*q**10*s**3 - 2770200*p**13*r*s**3 + 12595500*p**10*q**2*r*s**3 + 543950000*p**7*q**4*r*s**3 + 1612281250*p**4*q**6*r*s**3 + 968125000*p*q**8*r*s**3 + 77031000*p**11*r**2*s**3 + 373218750*p**8*q**2*r**2*s**3 + 1839765625*p**5*q**4*r**2*s**3 + 1818515625*p**2*q**6*r**2*s**3 - 776745000*p**9*r**3*s**3 - 6861075000*p**6*q**2*r**3*s**3 - 20014531250*p**3*q**4*r**3*s**3 - 13747812500*q**6*r**3*s**3 + 3768000000*p**7*r**4*s**3 + 35365000000*p**4*q**2*r**4*s**3 + 34441875000*p*q**4*r**4*s**3 - 9628000000*p**5*r**5*s**3 - 63230000000*p**2*q**2*r**5*s**3 + 13600000000*p**3*r**6*s**3 - 15000000000*q**2*r**6*s**3 - 10400000000*p*r**7*s**3 - 45562500*p**11*q*s**4 - 525937500*p**8*q**3*s**4 - 1364218750*p**5*q**5*s**4 - 1382812500*p**2*q**7*s**4 + 572062500*p**9*q*r*s**4 + 2473515625*p**6*q**3*r*s**4 + 13192187500*p**3*q**5*r*s**4 + 12703125000*q**7*r*s**4 - 451406250*p**7*q*r**2*s**4 - 18153906250*p**4*q**3*r**2*s**4 - 36908203125*p*q**5*r**2*s**4 - 9069375000*p**5*q*r**3*s**4 + 79957812500*p**2*q**3*r**3*s**4 + 5512500000*p**3*q*r**4*s**4 + 50656250000*q**3*r**4*s**4 + 74750000000*p*q*r**5*s**4 + 56953125*p**10*s**5 + 1381640625*p**7*q**2*s**5 - 781250000*p**4*q**4*s**5 + 878906250*p*q**6*s**5 - 2655703125*p**8*r*s**5 - 3223046875*p**5*q**2*r*s**5 - 35117187500*p**2*q**4*r*s**5 + 26573437500*p**6*r**2*s**5 + 14785156250*p**3*q**2*r**2*s**5 - 52050781250*q**4*r**2*s**5 - 103062500000*p**4*r**3*s**5 - 281796875000*p*q**2*r**3*s**5 + 146875000000*p**2*r**4*s**5 - 37500000000*r**5*s**5 - 8789062500*p**6*q*s**6 - 3906250000*p**3*q**3*s**6 + 1464843750*q**5*s**6 + 102929687500*p**4*q*r*s**6 + 297119140625*p*q**3*r*s**6 - 217773437500*p**2*q*r**2*s**6 + 167968750000*q*r**3*s**6 + 10986328125*p**5*s**7 + 98876953125*p**2*q**2*s**7 - 188964843750*p**3*r*s**7 - 278320312500*q**2*r*s**7 + 517578125000*p*r**2*s**7 - 610351562500*p*q*s**8 + 762939453125*s**9 + + b[2][4] = -200*p**7*q**9 + 1850*p**4*q**11 + 21600*p*q**13 + 3200*p**8*q**7*r - 19200*p**5*q**9*r - 316350*p**2*q**11*r - 19050*p**9*q**5*r**2 + 37400*p**6*q**7*r**2 + 1759250*p**3*q**9*r**2 + 440100*q**11*r**2 + 48750*p**10*q**3*r**3 + 190200*p**7*q**5*r**3 - 4604200*p**4*q**7*r**3 - 6072800*p*q**9*r**3 - 43200*p**11*q*r**4 - 834500*p**8*q**3*r**4 + 4916000*p**5*q**5*r**4 + 27926850*p**2*q**7*r**4 + 969600*p**9*q*r**5 + 2467200*p**6*q**3*r**5 - 45393200*p**3*q**5*r**5 - 5399500*q**7*r**5 - 7283200*p**7*q*r**6 + 10536000*p**4*q**3*r**6 + 41656000*p*q**5*r**6 + 22784000*p**5*q*r**7 - 35200000*p**2*q**3*r**7 - 25600000*p**3*q*r**8 + 96000000*q**3*r**8 - 3000*p**9*q**6*s + 40400*p**6*q**8*s + 136550*p**3*q**10*s - 1647000*q**12*s + 40500*p**10*q**4*r*s - 173600*p**7*q**6*r*s - 126500*p**4*q**8*r*s + 23969250*p*q**10*r*s - 153900*p**11*q**2*r**2*s - 486150*p**8*q**4*r**2*s - 4115800*p**5*q**6*r**2*s - 112653250*p**2*q**8*r**2*s + 129600*p**12*r**3*s + 2683350*p**9*q**2*r**3*s + 10906650*p**6*q**4*r**3*s + 187289500*p**3*q**6*r**3*s + 44098750*q**8*r**3*s - 4384800*p**10*r**4*s - 35660800*p**7*q**2*r**4*s - 175420000*p**4*q**4*r**4*s - 426538750*p*q**6*r**4*s + 60857600*p**8*r**5*s + 349436000*p**5*q**2*r**5*s + 900600000*p**2*q**4*r**5*s - 429568000*p**6*r**6*s - 1511200000*p**3*q**2*r**6*s - 1286000000*q**4*r**6*s + 1472000000*p**4*r**7*s + 1440000000*p*q**2*r**7*s - 1920000000*p**2*r**8*s - 36450*p**11*q**3*s**2 - 188100*p**8*q**5*s**2 - 5504750*p**5*q**7*s**2 - 37968750*p**2*q**9*s**2 + 255150*p**12*q*r*s**2 + 2754000*p**9*q**3*r*s**2 + 49196500*p**6*q**5*r*s**2 + 323587500*p**3*q**7*r*s**2 - 83250000*q**9*r*s**2 - 465750*p**10*q*r**2*s**2 - 31881500*p**7*q**3*r**2*s**2 - 415585000*p**4*q**5*r**2*s**2 + 1054775000*p*q**7*r**2*s**2 - 96823500*p**8*q*r**3*s**2 - 701490000*p**5*q**3*r**3*s**2 - 2953531250*p**2*q**5*r**3*s**2 + 1454560000*p**6*q*r**4*s**2 + 7670500000*p**3*q**3*r**4*s**2 + 5661062500*q**5*r**4*s**2 - 7785000000*p**4*q*r**5*s**2 - 9450000000*p*q**3*r**5*s**2 + 14000000000*p**2*q*r**6*s**2 + 2400000000*q*r**7*s**2 - 437400*p**13*s**3 - 10145250*p**10*q**2*s**3 - 121912500*p**7*q**4*s**3 - 576531250*p**4*q**6*s**3 - 528593750*p*q**8*s**3 + 12939750*p**11*r*s**3 + 313368750*p**8*q**2*r*s**3 + 2171812500*p**5*q**4*r*s**3 + 2381718750*p**2*q**6*r*s**3 - 124638750*p**9*r**2*s**3 - 3001575000*p**6*q**2*r**2*s**3 - 12259375000*p**3*q**4*r**2*s**3 - 9985312500*q**6*r**2*s**3 + 384000000*p**7*r**3*s**3 + 13997500000*p**4*q**2*r**3*s**3 + 20749531250*p*q**4*r**3*s**3 - 553500000*p**5*r**4*s**3 - 41835000000*p**2*q**2*r**4*s**3 + 5420000000*p**3*r**5*s**3 - 16300000000*q**2*r**5*s**3 - 17600000000*p*r**6*s**3 - 7593750*p**9*q*s**4 + 289218750*p**6*q**3*s**4 + 3591406250*p**3*q**5*s**4 + 5992187500*q**7*s**4 + 658125000*p**7*q*r*s**4 - 269531250*p**4*q**3*r*s**4 - 15882812500*p*q**5*r*s**4 - 4785000000*p**5*q*r**2*s**4 + 54375781250*p**2*q**3*r**2*s**4 - 5668750000*p**3*q*r**3*s**4 + 35867187500*q**3*r**3*s**4 + 113875000000*p*q*r**4*s**4 - 544218750*p**8*s**5 - 5407031250*p**5*q**2*s**5 - 14277343750*p**2*q**4*s**5 + 5421093750*p**6*r*s**5 - 24941406250*p**3*q**2*r*s**5 - 25488281250*q**4*r*s**5 - 11500000000*p**4*r**2*s**5 - 231894531250*p*q**2*r**2*s**5 - 6250000000*p**2*r**3*s**5 - 43750000000*r**4*s**5 + 35449218750*p**4*q*s**6 + 137695312500*p*q**3*s**6 + 34667968750*p**2*q*r*s**6 + 202148437500*q*r**2*s**6 - 33691406250*p**3*s**7 - 214843750000*q**2*s**7 - 31738281250*p*r*s**7 + + b[2][3] = -800*p**5*q**9 - 5400*p**2*q**11 + 5800*p**6*q**7*r + 48750*p**3*q**9*r + 16200*q**11*r - 3000*p**7*q**5*r**2 - 108350*p**4*q**7*r**2 - 263250*p*q**9*r**2 - 60700*p**8*q**3*r**3 - 386250*p**5*q**5*r**3 + 253100*p**2*q**7*r**3 + 127800*p**9*q*r**4 + 2326700*p**6*q**3*r**4 + 6565550*p**3*q**5*r**4 - 705750*q**7*r**4 - 2903200*p**7*q*r**5 - 21218000*p**4*q**3*r**5 + 1057000*p*q**5*r**5 + 20368000*p**5*q*r**6 + 33000000*p**2*q**3*r**6 - 43200000*p**3*q*r**7 + 52000000*q**3*r**7 + 6200*p**7*q**6*s + 188250*p**4*q**8*s + 931500*p*q**10*s - 73800*p**8*q**4*r*s - 1466850*p**5*q**6*r*s - 6894000*p**2*q**8*r*s + 315900*p**9*q**2*r**2*s + 4547000*p**6*q**4*r**2*s + 20362500*p**3*q**6*r**2*s + 15018750*q**8*r**2*s - 653400*p**10*r**3*s - 13897550*p**7*q**2*r**3*s - 76757500*p**4*q**4*r**3*s - 124207500*p*q**6*r**3*s + 18567600*p**8*r**4*s + 175911000*p**5*q**2*r**4*s + 253787500*p**2*q**4*r**4*s - 183816000*p**6*r**5*s - 706900000*p**3*q**2*r**5*s - 665750000*q**4*r**5*s + 740000000*p**4*r**6*s + 890000000*p*q**2*r**6*s - 1040000000*p**2*r**7*s - 763000*p**6*q**5*s**2 - 12375000*p**3*q**7*s**2 - 40500000*q**9*s**2 + 364500*p**10*q*r*s**2 + 15537000*p**7*q**3*r*s**2 + 154392500*p**4*q**5*r*s**2 + 372206250*p*q**7*r*s**2 - 25481250*p**8*q*r**2*s**2 - 386300000*p**5*q**3*r**2*s**2 - 996343750*p**2*q**5*r**2*s**2 + 459872500*p**6*q*r**3*s**2 + 2943937500*p**3*q**3*r**3*s**2 + 2437781250*q**5*r**3*s**2 - 2883750000*p**4*q*r**4*s**2 - 4343750000*p*q**3*r**4*s**2 + 5495000000*p**2*q*r**5*s**2 + 1300000000*q*r**6*s**2 - 364500*p**11*s**3 - 13668750*p**8*q**2*s**3 - 113406250*p**5*q**4*s**3 - 159062500*p**2*q**6*s**3 + 13972500*p**9*r*s**3 + 61537500*p**6*q**2*r*s**3 - 1622656250*p**3*q**4*r*s**3 - 2720625000*q**6*r*s**3 - 201656250*p**7*r**2*s**3 + 1949687500*p**4*q**2*r**2*s**3 + 4979687500*p*q**4*r**2*s**3 + 497125000*p**5*r**3*s**3 - 11150625000*p**2*q**2*r**3*s**3 + 2982500000*p**3*r**4*s**3 - 6612500000*q**2*r**4*s**3 - 10450000000*p*r**5*s**3 + 126562500*p**7*q*s**4 + 1443750000*p**4*q**3*s**4 + 281250000*p*q**5*s**4 - 1648125000*p**5*q*r*s**4 + 11271093750*p**2*q**3*r*s**4 - 4785156250*p**3*q*r**2*s**4 + 8808593750*q**3*r**2*s**4 + 52390625000*p*q*r**3*s**4 - 611718750*p**6*s**5 - 13027343750*p**3*q**2*s**5 - 1464843750*q**4*s**5 + 6492187500*p**4*r*s**5 - 65351562500*p*q**2*r*s**5 - 13476562500*p**2*r**2*s**5 - 24218750000*r**3*s**5 + 41992187500*p**2*q*s**6 + 69824218750*q*r*s**6 - 34179687500*p*s**7 + + b[2][2] = -1000*p**6*q**7 - 5150*p**3*q**9 + 10800*q**11 + 11000*p**7*q**5*r + 66450*p**4*q**7*r - 127800*p*q**9*r - 41250*p**8*q**3*r**2 - 368400*p**5*q**5*r**2 + 204200*p**2*q**7*r**2 + 54000*p**9*q*r**3 + 1040950*p**6*q**3*r**3 + 2096500*p**3*q**5*r**3 + 200000*q**7*r**3 - 1140000*p**7*q*r**4 - 7691000*p**4*q**3*r**4 - 2281000*p*q**5*r**4 + 7296000*p**5*q*r**5 + 13300000*p**2*q**3*r**5 - 14400000*p**3*q*r**6 + 14000000*q**3*r**6 - 9000*p**8*q**4*s + 52100*p**5*q**6*s + 710250*p**2*q**8*s + 67500*p**9*q**2*r*s - 256100*p**6*q**4*r*s - 5753000*p**3*q**6*r*s + 292500*q**8*r*s - 162000*p**10*r**2*s - 1432350*p**7*q**2*r**2*s + 5410000*p**4*q**4*r**2*s - 7408750*p*q**6*r**2*s + 4401000*p**8*r**3*s + 24185000*p**5*q**2*r**3*s + 20781250*p**2*q**4*r**3*s - 43012000*p**6*r**4*s - 146300000*p**3*q**2*r**4*s - 165875000*q**4*r**4*s + 182000000*p**4*r**5*s + 250000000*p*q**2*r**5*s - 280000000*p**2*r**6*s + 60750*p**10*q*s**2 + 2414250*p**7*q**3*s**2 + 15770000*p**4*q**5*s**2 + 15825000*p*q**7*s**2 - 6021000*p**8*q*r*s**2 - 62252500*p**5*q**3*r*s**2 - 74718750*p**2*q**5*r*s**2 + 90888750*p**6*q*r**2*s**2 + 471312500*p**3*q**3*r**2*s**2 + 525875000*q**5*r**2*s**2 - 539375000*p**4*q*r**3*s**2 - 1030000000*p*q**3*r**3*s**2 + 1142500000*p**2*q*r**4*s**2 + 350000000*q*r**5*s**2 - 303750*p**9*s**3 - 35943750*p**6*q**2*s**3 - 331875000*p**3*q**4*s**3 - 505937500*q**6*s**3 + 8437500*p**7*r*s**3 + 530781250*p**4*q**2*r*s**3 + 1150312500*p*q**4*r*s**3 - 154500000*p**5*r**2*s**3 - 2059062500*p**2*q**2*r**2*s**3 + 1150000000*p**3*r**3*s**3 - 1343750000*q**2*r**3*s**3 - 2900000000*p*r**4*s**3 + 30937500*p**5*q*s**4 + 1166406250*p**2*q**3*s**4 - 1496875000*p**3*q*r*s**4 + 1296875000*q**3*r*s**4 + 10640625000*p*q*r**2*s**4 - 281250000*p**4*s**5 - 9746093750*p*q**2*s**5 + 1269531250*p**2*r*s**5 - 7421875000*r**2*s**5 + 15625000000*q*s**6 + + b[2][1] = -1600*p**4*q**7 - 10800*p*q**9 + 9800*p**5*q**5*r + 80550*p**2*q**7*r - 4600*p**6*q**3*r**2 - 112700*p**3*q**5*r**2 + 40500*q**7*r**2 - 34200*p**7*q*r**3 - 279500*p**4*q**3*r**3 - 665750*p*q**5*r**3 + 632000*p**5*q*r**4 + 3200000*p**2*q**3*r**4 - 2800000*p**3*q*r**5 + 3000000*q**3*r**5 - 18600*p**6*q**4*s - 51750*p**3*q**6*s + 405000*q**8*s + 21600*p**7*q**2*r*s - 122500*p**4*q**4*r*s - 2891250*p*q**6*r*s + 156600*p**8*r**2*s + 1569750*p**5*q**2*r**2*s + 6943750*p**2*q**4*r**2*s - 3774000*p**6*r**3*s - 27100000*p**3*q**2*r**3*s - 30187500*q**4*r**3*s + 28000000*p**4*r**4*s + 52500000*p*q**2*r**4*s - 60000000*p**2*r**5*s - 81000*p**8*q*s**2 - 240000*p**5*q**3*s**2 + 937500*p**2*q**5*s**2 + 3273750*p**6*q*r*s**2 + 30406250*p**3*q**3*r*s**2 + 55687500*q**5*r*s**2 - 42187500*p**4*q*r**2*s**2 - 112812500*p*q**3*r**2*s**2 + 152500000*p**2*q*r**3*s**2 + 75000000*q*r**4*s**2 - 4218750*p**4*q**2*s**3 + 15156250*p*q**4*s**3 + 5906250*p**5*r*s**3 - 206562500*p**2*q**2*r*s**3 + 107500000*p**3*r**2*s**3 - 159375000*q**2*r**2*s**3 - 612500000*p*r**3*s**3 + 135937500*p**3*q*s**4 + 46875000*q**3*s**4 + 1175781250*p*q*r*s**4 - 292968750*p**2*s**5 - 1367187500*r*s**5 + + b[2][0] = -800*p**5*q**5 - 5400*p**2*q**7 + 6000*p**6*q**3*r + 51700*p**3*q**5*r + 27000*q**7*r - 10800*p**7*q*r**2 - 163250*p**4*q**3*r**2 - 285750*p*q**5*r**2 + 192000*p**5*q*r**3 + 1000000*p**2*q**3*r**3 - 800000*p**3*q*r**4 + 500000*q**3*r**4 - 10800*p**7*q**2*s - 57500*p**4*q**4*s + 67500*p*q**6*s + 32400*p**8*r*s + 279000*p**5*q**2*r*s - 131250*p**2*q**4*r*s - 729000*p**6*r**2*s - 4100000*p**3*q**2*r**2*s - 5343750*q**4*r**2*s + 5000000*p**4*r**3*s + 10000000*p*q**2*r**3*s - 10000000*p**2*r**4*s + 641250*p**6*q*s**2 + 5812500*p**3*q**3*s**2 + 10125000*q**5*s**2 - 7031250*p**4*q*r*s**2 - 20625000*p*q**3*r*s**2 + 17500000*p**2*q*r**2*s**2 + 12500000*q*r**3*s**2 - 843750*p**5*s**3 - 19375000*p**2*q**2*s**3 + 30000000*p**3*r*s**3 - 20312500*q**2*r*s**3 - 112500000*p*r**2*s**3 + 183593750*p*q*s**4 - 292968750*s**5 + + b[3][5] = 500*p**11*q**6 + 9875*p**8*q**8 + 42625*p**5*q**10 - 35000*p**2*q**12 - 4500*p**12*q**4*r - 108375*p**9*q**6*r - 516750*p**6*q**8*r + 1110500*p**3*q**10*r + 2730000*q**12*r + 10125*p**13*q**2*r**2 + 358250*p**10*q**4*r**2 + 1908625*p**7*q**6*r**2 - 11744250*p**4*q**8*r**2 - 43383250*p*q**10*r**2 - 313875*p**11*q**2*r**3 - 2074875*p**8*q**4*r**3 + 52094750*p**5*q**6*r**3 + 264567500*p**2*q**8*r**3 + 796125*p**9*q**2*r**4 - 92486250*p**6*q**4*r**4 - 757957500*p**3*q**6*r**4 - 29354375*q**8*r**4 + 60970000*p**7*q**2*r**5 + 1112462500*p**4*q**4*r**5 + 571094375*p*q**6*r**5 - 685290000*p**5*q**2*r**6 - 2037800000*p**2*q**4*r**6 + 2279600000*p**3*q**2*r**7 + 849000000*q**4*r**7 - 1480000000*p*q**2*r**8 + 13500*p**13*q**3*s + 363000*p**10*q**5*s + 2861250*p**7*q**7*s + 8493750*p**4*q**9*s + 17031250*p*q**11*s - 60750*p**14*q*r*s - 2319750*p**11*q**3*r*s - 22674250*p**8*q**5*r*s - 74368750*p**5*q**7*r*s - 170578125*p**2*q**9*r*s + 2760750*p**12*q*r**2*s + 46719000*p**9*q**3*r**2*s + 163356375*p**6*q**5*r**2*s + 360295625*p**3*q**7*r**2*s - 195990625*q**9*r**2*s - 37341750*p**10*q*r**3*s - 194739375*p**7*q**3*r**3*s - 105463125*p**4*q**5*r**3*s - 415825000*p*q**7*r**3*s + 90180000*p**8*q*r**4*s - 990552500*p**5*q**3*r**4*s + 3519212500*p**2*q**5*r**4*s + 1112220000*p**6*q*r**5*s - 4508750000*p**3*q**3*r**5*s - 8159500000*q**5*r**5*s - 4356000000*p**4*q*r**6*s + 14615000000*p*q**3*r**6*s - 2160000000*p**2*q*r**7*s + 91125*p**15*s**2 + 3290625*p**12*q**2*s**2 + 35100000*p**9*q**4*s**2 + 175406250*p**6*q**6*s**2 + 629062500*p**3*q**8*s**2 + 910937500*q**10*s**2 - 5710500*p**13*r*s**2 - 100423125*p**10*q**2*r*s**2 - 604743750*p**7*q**4*r*s**2 - 2954843750*p**4*q**6*r*s**2 - 4587578125*p*q**8*r*s**2 + 116194500*p**11*r**2*s**2 + 1280716250*p**8*q**2*r**2*s**2 + 7401190625*p**5*q**4*r**2*s**2 + 11619937500*p**2*q**6*r**2*s**2 - 952173125*p**9*r**3*s**2 - 6519712500*p**6*q**2*r**3*s**2 - 10238593750*p**3*q**4*r**3*s**2 + 29984609375*q**6*r**3*s**2 + 2558300000*p**7*r**4*s**2 + 16225000000*p**4*q**2*r**4*s**2 - 64994140625*p*q**4*r**4*s**2 + 4202250000*p**5*r**5*s**2 + 46925000000*p**2*q**2*r**5*s**2 - 28950000000*p**3*r**6*s**2 - 1000000000*q**2*r**6*s**2 + 37000000000*p*r**7*s**2 - 48093750*p**11*q*s**3 - 673359375*p**8*q**3*s**3 - 2170312500*p**5*q**5*s**3 - 2466796875*p**2*q**7*s**3 + 647578125*p**9*q*r*s**3 + 597031250*p**6*q**3*r*s**3 - 7542578125*p**3*q**5*r*s**3 - 41125000000*q**7*r*s**3 - 2175828125*p**7*q*r**2*s**3 - 7101562500*p**4*q**3*r**2*s**3 + 100596875000*p*q**5*r**2*s**3 - 8984687500*p**5*q*r**3*s**3 - 120070312500*p**2*q**3*r**3*s**3 + 57343750000*p**3*q*r**4*s**3 + 9500000000*q**3*r**4*s**3 - 342875000000*p*q*r**5*s**3 + 400781250*p**10*s**4 + 8531250000*p**7*q**2*s**4 + 34033203125*p**4*q**4*s**4 + 42724609375*p*q**6*s**4 - 6289453125*p**8*r*s**4 - 24037109375*p**5*q**2*r*s**4 - 62626953125*p**2*q**4*r*s**4 + 17299218750*p**6*r**2*s**4 + 108357421875*p**3*q**2*r**2*s**4 - 55380859375*q**4*r**2*s**4 + 105648437500*p**4*r**3*s**4 + 1204228515625*p*q**2*r**3*s**4 - 365000000000*p**2*r**4*s**4 + 184375000000*r**5*s**4 - 32080078125*p**6*q*s**5 - 98144531250*p**3*q**3*s**5 + 93994140625*q**5*s**5 - 178955078125*p**4*q*r*s**5 - 1299804687500*p*q**3*r*s**5 + 332421875000*p**2*q*r**2*s**5 - 1195312500000*q*r**3*s**5 + 72021484375*p**5*s**6 + 323486328125*p**2*q**2*s**6 + 682373046875*p**3*r*s**6 + 2447509765625*q**2*r*s**6 - 3011474609375*p*r**2*s**6 + 3051757812500*p*q*s**7 - 7629394531250*s**8 + + b[3][4] = 1500*p**9*q**6 + 69625*p**6*q**8 + 590375*p**3*q**10 + 1035000*q**12 - 13500*p**10*q**4*r - 760625*p**7*q**6*r - 7904500*p**4*q**8*r - 18169250*p*q**10*r + 30375*p**11*q**2*r**2 + 2628625*p**8*q**4*r**2 + 37879000*p**5*q**6*r**2 + 121367500*p**2*q**8*r**2 - 2699250*p**9*q**2*r**3 - 76776875*p**6*q**4*r**3 - 403583125*p**3*q**6*r**3 - 78865625*q**8*r**3 + 60907500*p**7*q**2*r**4 + 735291250*p**4*q**4*r**4 + 781142500*p*q**6*r**4 - 558270000*p**5*q**2*r**5 - 2150725000*p**2*q**4*r**5 + 2015400000*p**3*q**2*r**6 + 1181000000*q**4*r**6 - 2220000000*p*q**2*r**7 + 40500*p**11*q**3*s + 1376500*p**8*q**5*s + 9953125*p**5*q**7*s + 9765625*p**2*q**9*s - 182250*p**12*q*r*s - 8859000*p**9*q**3*r*s - 82854500*p**6*q**5*r*s - 71511250*p**3*q**7*r*s + 273631250*q**9*r*s + 10233000*p**10*q*r**2*s + 179627500*p**7*q**3*r**2*s + 25164375*p**4*q**5*r**2*s - 2927290625*p*q**7*r**2*s - 171305000*p**8*q*r**3*s - 544768750*p**5*q**3*r**3*s + 7583437500*p**2*q**5*r**3*s + 1139860000*p**6*q*r**4*s - 6489375000*p**3*q**3*r**4*s - 9625375000*q**5*r**4*s - 1838000000*p**4*q*r**5*s + 19835000000*p*q**3*r**5*s - 3240000000*p**2*q*r**6*s + 273375*p**13*s**2 + 9753750*p**10*q**2*s**2 + 82575000*p**7*q**4*s**2 + 202265625*p**4*q**6*s**2 + 556093750*p*q**8*s**2 - 11552625*p**11*r*s**2 - 115813125*p**8*q**2*r*s**2 + 630590625*p**5*q**4*r*s**2 + 1347015625*p**2*q**6*r*s**2 + 157578750*p**9*r**2*s**2 - 689206250*p**6*q**2*r**2*s**2 - 4299609375*p**3*q**4*r**2*s**2 + 23896171875*q**6*r**2*s**2 - 1022437500*p**7*r**3*s**2 + 6648125000*p**4*q**2*r**3*s**2 - 52895312500*p*q**4*r**3*s**2 + 4401750000*p**5*r**4*s**2 + 26500000000*p**2*q**2*r**4*s**2 - 22125000000*p**3*r**5*s**2 - 1500000000*q**2*r**5*s**2 + 55500000000*p*r**6*s**2 - 137109375*p**9*q*s**3 - 1955937500*p**6*q**3*s**3 - 6790234375*p**3*q**5*s**3 - 16996093750*q**7*s**3 + 2146218750*p**7*q*r*s**3 + 6570312500*p**4*q**3*r*s**3 + 39918750000*p*q**5*r*s**3 - 7673281250*p**5*q*r**2*s**3 - 52000000000*p**2*q**3*r**2*s**3 + 50796875000*p**3*q*r**3*s**3 + 18750000000*q**3*r**3*s**3 - 399875000000*p*q*r**4*s**3 + 780468750*p**8*s**4 + 14455078125*p**5*q**2*s**4 + 10048828125*p**2*q**4*s**4 - 15113671875*p**6*r*s**4 + 39298828125*p**3*q**2*r*s**4 - 52138671875*q**4*r*s**4 + 45964843750*p**4*r**2*s**4 + 914414062500*p*q**2*r**2*s**4 + 1953125000*p**2*r**3*s**4 + 334375000000*r**4*s**4 - 149169921875*p**4*q*s**5 - 459716796875*p*q**3*s**5 - 325585937500*p**2*q*r*s**5 - 1462890625000*q*r**2*s**5 + 296630859375*p**3*s**6 + 1324462890625*q**2*s**6 + 307617187500*p*r*s**6 + + b[3][3] = -20750*p**7*q**6 - 290125*p**4*q**8 - 993000*p*q**10 + 146125*p**8*q**4*r + 2721500*p**5*q**6*r + 11833750*p**2*q**8*r - 237375*p**9*q**2*r**2 - 8167500*p**6*q**4*r**2 - 54605625*p**3*q**6*r**2 - 23802500*q**8*r**2 + 8927500*p**7*q**2*r**3 + 131184375*p**4*q**4*r**3 + 254695000*p*q**6*r**3 - 121561250*p**5*q**2*r**4 - 728003125*p**2*q**4*r**4 + 702550000*p**3*q**2*r**5 + 597312500*q**4*r**5 - 1202500000*p*q**2*r**6 - 194625*p**9*q**3*s - 1568875*p**6*q**5*s + 9685625*p**3*q**7*s + 74662500*q**9*s + 327375*p**10*q*r*s + 1280000*p**7*q**3*r*s - 123703750*p**4*q**5*r*s - 850121875*p*q**7*r*s - 7436250*p**8*q*r**2*s + 164820000*p**5*q**3*r**2*s + 2336659375*p**2*q**5*r**2*s + 32202500*p**6*q*r**3*s - 2429765625*p**3*q**3*r**3*s - 4318609375*q**5*r**3*s + 148000000*p**4*q*r**4*s + 9902812500*p*q**3*r**4*s - 1755000000*p**2*q*r**5*s + 1154250*p**11*s**2 + 36821250*p**8*q**2*s**2 + 372825000*p**5*q**4*s**2 + 1170921875*p**2*q**6*s**2 - 38913750*p**9*r*s**2 - 797071875*p**6*q**2*r*s**2 - 2848984375*p**3*q**4*r*s**2 + 7651406250*q**6*r*s**2 + 415068750*p**7*r**2*s**2 + 3151328125*p**4*q**2*r**2*s**2 - 17696875000*p*q**4*r**2*s**2 - 725968750*p**5*r**3*s**2 + 5295312500*p**2*q**2*r**3*s**2 - 8581250000*p**3*r**4*s**2 - 812500000*q**2*r**4*s**2 + 30062500000*p*r**5*s**2 - 110109375*p**7*q*s**3 - 1976562500*p**4*q**3*s**3 - 6329296875*p*q**5*s**3 + 2256328125*p**5*q*r*s**3 + 8554687500*p**2*q**3*r*s**3 + 12947265625*p**3*q*r**2*s**3 + 7984375000*q**3*r**2*s**3 - 167039062500*p*q*r**3*s**3 + 1181250000*p**6*s**4 + 17873046875*p**3*q**2*s**4 - 20449218750*q**4*s**4 - 16265625000*p**4*r*s**4 + 260869140625*p*q**2*r*s**4 + 21025390625*p**2*r**2*s**4 + 207617187500*r**3*s**4 - 207177734375*p**2*q*s**5 - 615478515625*q*r*s**5 + 301513671875*p*s**6 + + b[3][2] = 53125*p**5*q**6 + 425000*p**2*q**8 - 394375*p**6*q**4*r - 4301875*p**3*q**6*r - 3225000*q**8*r + 851250*p**7*q**2*r**2 + 16910625*p**4*q**4*r**2 + 44210000*p*q**6*r**2 - 20474375*p**5*q**2*r**3 - 147190625*p**2*q**4*r**3 + 163975000*p**3*q**2*r**4 + 156812500*q**4*r**4 - 323750000*p*q**2*r**5 - 99375*p**7*q**3*s - 6395000*p**4*q**5*s - 49243750*p*q**7*s - 1164375*p**8*q*r*s + 4465625*p**5*q**3*r*s + 205546875*p**2*q**5*r*s + 12163750*p**6*q*r**2*s - 315546875*p**3*q**3*r**2*s - 946453125*q**5*r**2*s - 23500000*p**4*q*r**3*s + 2313437500*p*q**3*r**3*s - 472500000*p**2*q*r**4*s + 1316250*p**9*s**2 + 22715625*p**6*q**2*s**2 + 206953125*p**3*q**4*s**2 + 1220000000*q**6*s**2 - 20953125*p**7*r*s**2 - 277656250*p**4*q**2*r*s**2 - 3317187500*p*q**4*r*s**2 + 293734375*p**5*r**2*s**2 + 1351562500*p**2*q**2*r**2*s**2 - 2278125000*p**3*r**3*s**2 - 218750000*q**2*r**3*s**2 + 8093750000*p*r**4*s**2 - 9609375*p**5*q*s**3 + 240234375*p**2*q**3*s**3 + 2310546875*p**3*q*r*s**3 + 1171875000*q**3*r*s**3 - 33460937500*p*q*r**2*s**3 + 2185546875*p**4*s**4 + 32578125000*p*q**2*s**4 - 8544921875*p**2*r*s**4 + 58398437500*r**2*s**4 - 114013671875*q*s**5 + + b[3][1] = -16250*p**6*q**4 - 191875*p**3*q**6 - 495000*q**8 + 73125*p**7*q**2*r + 1437500*p**4*q**4*r + 5866250*p*q**6*r - 2043125*p**5*q**2*r**2 - 17218750*p**2*q**4*r**2 + 19106250*p**3*q**2*r**3 + 34015625*q**4*r**3 - 69375000*p*q**2*r**4 - 219375*p**8*q*s - 2846250*p**5*q**3*s - 8021875*p**2*q**5*s + 3420000*p**6*q*r*s - 1640625*p**3*q**3*r*s - 152468750*q**5*r*s + 3062500*p**4*q*r**2*s + 381171875*p*q**3*r**2*s - 101250000*p**2*q*r**3*s + 2784375*p**7*s**2 + 43515625*p**4*q**2*s**2 + 115625000*p*q**4*s**2 - 48140625*p**5*r*s**2 - 307421875*p**2*q**2*r*s**2 - 25781250*p**3*r**2*s**2 - 46875000*q**2*r**2*s**2 + 1734375000*p*r**3*s**2 - 128906250*p**3*q*s**3 + 339843750*q**3*s**3 - 4583984375*p*q*r*s**3 + 2236328125*p**2*s**4 + 12255859375*r*s**4 + + b[3][0] = 31875*p**4*q**4 + 255000*p*q**6 - 82500*p**5*q**2*r - 1106250*p**2*q**4*r + 1653125*p**3*q**2*r**2 + 5187500*q**4*r**2 - 11562500*p*q**2*r**3 - 118125*p**6*q*s - 3593750*p**3*q**3*s - 23812500*q**5*s + 4656250*p**4*q*r*s + 67109375*p*q**3*r*s - 16875000*p**2*q*r**2*s - 984375*p**5*s**2 - 19531250*p**2*q**2*s**2 - 37890625*p**3*r*s**2 - 7812500*q**2*r*s**2 + 289062500*p*r**2*s**2 - 529296875*p*q*s**3 + 2343750000*s**4 + + b[4][5] = 600*p**10*q**10 + 13850*p**7*q**12 + 106150*p**4*q**14 + 270000*p*q**16 - 9300*p**11*q**8*r - 234075*p**8*q**10*r - 1942825*p**5*q**12*r - 5319900*p**2*q**14*r + 52050*p**12*q**6*r**2 + 1481025*p**9*q**8*r**2 + 13594450*p**6*q**10*r**2 + 40062750*p**3*q**12*r**2 - 3569400*q**14*r**2 - 122175*p**13*q**4*r**3 - 4260350*p**10*q**6*r**3 - 45052375*p**7*q**8*r**3 - 142634900*p**4*q**10*r**3 + 54186350*p*q**12*r**3 + 97200*p**14*q**2*r**4 + 5284225*p**11*q**4*r**4 + 70389525*p**8*q**6*r**4 + 232732850*p**5*q**8*r**4 - 318849400*p**2*q**10*r**4 - 2046000*p**12*q**2*r**5 - 43874125*p**9*q**4*r**5 - 107411850*p**6*q**6*r**5 + 948310700*p**3*q**8*r**5 - 34763575*q**10*r**5 + 5915600*p**10*q**2*r**6 - 115887800*p**7*q**4*r**6 - 1649542400*p**4*q**6*r**6 + 224468875*p*q**8*r**6 + 120252800*p**8*q**2*r**7 + 1779902000*p**5*q**4*r**7 - 288250000*p**2*q**6*r**7 - 915200000*p**6*q**2*r**8 - 1164000000*p**3*q**4*r**8 - 444200000*q**6*r**8 + 2502400000*p**4*q**2*r**9 + 1984000000*p*q**4*r**9 - 2880000000*p**2*q**2*r**10 + 20700*p**12*q**7*s + 551475*p**9*q**9*s + 5194875*p**6*q**11*s + 18985000*p**3*q**13*s + 16875000*q**15*s - 218700*p**13*q**5*r*s - 6606475*p**10*q**7*r*s - 69770850*p**7*q**9*r*s - 285325500*p**4*q**11*r*s - 292005000*p*q**13*r*s + 694575*p**14*q**3*r**2*s + 26187750*p**11*q**5*r**2*s + 328992825*p**8*q**7*r**2*s + 1573292400*p**5*q**9*r**2*s + 1930043875*p**2*q**11*r**2*s - 583200*p**15*q*r**3*s - 37263225*p**12*q**3*r**3*s - 638579425*p**9*q**5*r**3*s - 3920212225*p**6*q**7*r**3*s - 6327336875*p**3*q**9*r**3*s + 440969375*q**11*r**3*s + 13446000*p**13*q*r**4*s + 462330325*p**10*q**3*r**4*s + 4509088275*p**7*q**5*r**4*s + 11709795625*p**4*q**7*r**4*s - 3579565625*p*q**9*r**4*s - 85033600*p**11*q*r**5*s - 2136801600*p**8*q**3*r**5*s - 12221575800*p**5*q**5*r**5*s + 9431044375*p**2*q**7*r**5*s + 10643200*p**9*q*r**6*s + 4565594000*p**6*q**3*r**6*s - 1778590000*p**3*q**5*r**6*s + 4842175000*q**7*r**6*s + 712320000*p**7*q*r**7*s - 16182000000*p**4*q**3*r**7*s - 21918000000*p*q**5*r**7*s - 742400000*p**5*q*r**8*s + 31040000000*p**2*q**3*r**8*s + 1280000000*p**3*q*r**9*s + 4800000000*q**3*r**9*s + 230850*p**14*q**4*s**2 + 7373250*p**11*q**6*s**2 + 85045625*p**8*q**8*s**2 + 399140625*p**5*q**10*s**2 + 565031250*p**2*q**12*s**2 - 1257525*p**15*q**2*r*s**2 - 52728975*p**12*q**4*r*s**2 - 743466375*p**9*q**6*r*s**2 - 4144915000*p**6*q**8*r*s**2 - 7102690625*p**3*q**10*r*s**2 - 1389937500*q**12*r*s**2 + 874800*p**16*r**2*s**2 + 89851275*p**13*q**2*r**2*s**2 + 1897236775*p**10*q**4*r**2*s**2 + 14144163000*p**7*q**6*r**2*s**2 + 31942921875*p**4*q**8*r**2*s**2 + 13305118750*p*q**10*r**2*s**2 - 23004000*p**14*r**3*s**2 - 1450715475*p**11*q**2*r**3*s**2 - 19427105000*p**8*q**4*r**3*s**2 - 70634028750*p**5*q**6*r**3*s**2 - 47854218750*p**2*q**8*r**3*s**2 + 204710400*p**12*r**4*s**2 + 10875135000*p**9*q**2*r**4*s**2 + 83618806250*p**6*q**4*r**4*s**2 + 62744500000*p**3*q**6*r**4*s**2 - 19806718750*q**8*r**4*s**2 - 757094800*p**10*r**5*s**2 - 37718030000*p**7*q**2*r**5*s**2 - 22479500000*p**4*q**4*r**5*s**2 + 91556093750*p*q**6*r**5*s**2 + 2306320000*p**8*r**6*s**2 + 55539600000*p**5*q**2*r**6*s**2 - 112851250000*p**2*q**4*r**6*s**2 - 10720000000*p**6*r**7*s**2 - 64720000000*p**3*q**2*r**7*s**2 - 59925000000*q**4*r**7*s**2 + 28000000000*p**4*r**8*s**2 + 28000000000*p*q**2*r**8*s**2 - 24000000000*p**2*r**9*s**2 + 820125*p**16*q*s**3 + 36804375*p**13*q**3*s**3 + 552225000*p**10*q**5*s**3 + 3357593750*p**7*q**7*s**3 + 7146562500*p**4*q**9*s**3 + 3851562500*p*q**11*s**3 - 92400750*p**14*q*r*s**3 - 2350175625*p**11*q**3*r*s**3 - 19470640625*p**8*q**5*r*s**3 - 52820593750*p**5*q**7*r*s**3 - 45447734375*p**2*q**9*r*s**3 + 1824363000*p**12*q*r**2*s**3 + 31435234375*p**9*q**3*r**2*s**3 + 141717537500*p**6*q**5*r**2*s**3 + 228370781250*p**3*q**7*r**2*s**3 + 34610078125*q**9*r**2*s**3 - 17591825625*p**10*q*r**3*s**3 - 188927187500*p**7*q**3*r**3*s**3 - 502088984375*p**4*q**5*r**3*s**3 - 187849296875*p*q**7*r**3*s**3 + 75577750000*p**8*q*r**4*s**3 + 342800000000*p**5*q**3*r**4*s**3 + 295384296875*p**2*q**5*r**4*s**3 - 107681250000*p**6*q*r**5*s**3 + 53330000000*p**3*q**3*r**5*s**3 + 271586875000*q**5*r**5*s**3 - 26410000000*p**4*q*r**6*s**3 - 188200000000*p*q**3*r**6*s**3 + 92000000000*p**2*q*r**7*s**3 + 120000000000*q*r**8*s**3 + 47840625*p**15*s**4 + 1150453125*p**12*q**2*s**4 + 9229453125*p**9*q**4*s**4 + 24954687500*p**6*q**6*s**4 + 22978515625*p**3*q**8*s**4 + 1367187500*q**10*s**4 - 1193737500*p**13*r*s**4 - 20817843750*p**10*q**2*r*s**4 - 98640000000*p**7*q**4*r*s**4 - 225767187500*p**4*q**6*r*s**4 - 74707031250*p*q**8*r*s**4 + 13431318750*p**11*r**2*s**4 + 188709843750*p**8*q**2*r**2*s**4 + 875157656250*p**5*q**4*r**2*s**4 + 593812890625*p**2*q**6*r**2*s**4 - 69869296875*p**9*r**3*s**4 - 854811093750*p**6*q**2*r**3*s**4 - 1730658203125*p**3*q**4*r**3*s**4 - 570867187500*q**6*r**3*s**4 + 162075625000*p**7*r**4*s**4 + 1536375000000*p**4*q**2*r**4*s**4 + 765156250000*p*q**4*r**4*s**4 - 165988750000*p**5*r**5*s**4 - 728968750000*p**2*q**2*r**5*s**4 + 121500000000*p**3*r**6*s**4 - 1039375000000*q**2*r**6*s**4 - 100000000000*p*r**7*s**4 - 379687500*p**11*q*s**5 - 11607421875*p**8*q**3*s**5 - 20830078125*p**5*q**5*s**5 - 33691406250*p**2*q**7*s**5 - 41491406250*p**9*q*r*s**5 - 419054687500*p**6*q**3*r*s**5 - 129511718750*p**3*q**5*r*s**5 + 311767578125*q**7*r*s**5 + 620116015625*p**7*q*r**2*s**5 + 1154687500000*p**4*q**3*r**2*s**5 + 36455078125*p*q**5*r**2*s**5 - 2265953125000*p**5*q*r**3*s**5 - 1509521484375*p**2*q**3*r**3*s**5 + 2530468750000*p**3*q*r**4*s**5 + 3259765625000*q**3*r**4*s**5 + 93750000000*p*q*r**5*s**5 + 23730468750*p**10*s**6 + 243603515625*p**7*q**2*s**6 + 341552734375*p**4*q**4*s**6 - 12207031250*p*q**6*s**6 - 357099609375*p**8*r*s**6 - 298193359375*p**5*q**2*r*s**6 + 406738281250*p**2*q**4*r*s**6 + 1615683593750*p**6*r**2*s**6 + 558593750000*p**3*q**2*r**2*s**6 - 2811035156250*q**4*r**2*s**6 - 2960937500000*p**4*r**3*s**6 - 3802246093750*p*q**2*r**3*s**6 + 2347656250000*p**2*r**4*s**6 - 671875000000*r**5*s**6 - 651855468750*p**6*q*s**7 - 1458740234375*p**3*q**3*s**7 - 152587890625*q**5*s**7 + 1628417968750*p**4*q*r*s**7 + 3948974609375*p*q**3*r*s**7 - 916748046875*p**2*q*r**2*s**7 + 1611328125000*q*r**3*s**7 + 640869140625*p**5*s**8 + 1068115234375*p**2*q**2*s**8 - 2044677734375*p**3*r*s**8 - 3204345703125*q**2*r*s**8 + 1739501953125*p*r**2*s**8 + + b[4][4] = -600*p**11*q**8 - 14050*p**8*q**10 - 109100*p**5*q**12 - 280800*p**2*q**14 + 7200*p**12*q**6*r + 188700*p**9*q**8*r + 1621725*p**6*q**10*r + 4577075*p**3*q**12*r + 5400*q**14*r - 28350*p**13*q**4*r**2 - 910600*p**10*q**6*r**2 - 9237975*p**7*q**8*r**2 - 30718900*p**4*q**10*r**2 - 5575950*p*q**12*r**2 + 36450*p**14*q**2*r**3 + 1848125*p**11*q**4*r**3 + 25137775*p**8*q**6*r**3 + 109591450*p**5*q**8*r**3 + 70627650*p**2*q**10*r**3 - 1317150*p**12*q**2*r**4 - 32857100*p**9*q**4*r**4 - 219125575*p**6*q**6*r**4 - 327565875*p**3*q**8*r**4 - 13011875*q**10*r**4 + 16484150*p**10*q**2*r**5 + 222242250*p**7*q**4*r**5 + 642173750*p**4*q**6*r**5 + 101263750*p*q**8*r**5 - 79345000*p**8*q**2*r**6 - 433180000*p**5*q**4*r**6 - 93731250*p**2*q**6*r**6 - 74300000*p**6*q**2*r**7 - 1057900000*p**3*q**4*r**7 - 591175000*q**6*r**7 + 1891600000*p**4*q**2*r**8 + 2796000000*p*q**4*r**8 - 4320000000*p**2*q**2*r**9 - 16200*p**13*q**5*s - 359500*p**10*q**7*s - 2603825*p**7*q**9*s - 4590375*p**4*q**11*s + 12352500*p*q**13*s + 121500*p**14*q**3*r*s + 3227400*p**11*q**5*r*s + 27301725*p**8*q**7*r*s + 59480975*p**5*q**9*r*s - 137308875*p**2*q**11*r*s - 218700*p**15*q*r**2*s - 8903925*p**12*q**3*r**2*s - 100918225*p**9*q**5*r**2*s - 325291300*p**6*q**7*r**2*s + 365705000*p**3*q**9*r**2*s + 94342500*q**11*r**2*s + 7632900*p**13*q*r**3*s + 162995400*p**10*q**3*r**3*s + 974558975*p**7*q**5*r**3*s + 930991250*p**4*q**7*r**3*s - 495368750*p*q**9*r**3*s - 97344900*p**11*q*r**4*s - 1406739250*p**8*q**3*r**4*s - 5572526250*p**5*q**5*r**4*s - 1903987500*p**2*q**7*r**4*s + 678550000*p**9*q*r**5*s + 8176215000*p**6*q**3*r**5*s + 18082050000*p**3*q**5*r**5*s + 5435843750*q**7*r**5*s - 2979800000*p**7*q*r**6*s - 29163500000*p**4*q**3*r**6*s - 27417500000*p*q**5*r**6*s + 6282400000*p**5*q*r**7*s + 48690000000*p**2*q**3*r**7*s - 2880000000*p**3*q*r**8*s + 7200000000*q**3*r**8*s - 109350*p**15*q**2*s**2 - 2405700*p**12*q**4*s**2 - 16125250*p**9*q**6*s**2 - 4930000*p**6*q**8*s**2 + 201150000*p**3*q**10*s**2 - 243000000*q**12*s**2 + 328050*p**16*r*s**2 + 10552275*p**13*q**2*r*s**2 + 88019100*p**10*q**4*r*s**2 - 4208625*p**7*q**6*r*s**2 - 1920390625*p**4*q**8*r*s**2 + 1759537500*p*q**10*r*s**2 - 11955600*p**14*r**2*s**2 - 196375050*p**11*q**2*r**2*s**2 - 555196250*p**8*q**4*r**2*s**2 + 4213270000*p**5*q**6*r**2*s**2 - 157468750*p**2*q**8*r**2*s**2 + 162656100*p**12*r**3*s**2 + 1880870000*p**9*q**2*r**3*s**2 + 753684375*p**6*q**4*r**3*s**2 - 25423062500*p**3*q**6*r**3*s**2 - 14142031250*q**8*r**3*s**2 - 1251948750*p**10*r**4*s**2 - 12524475000*p**7*q**2*r**4*s**2 + 18067656250*p**4*q**4*r**4*s**2 + 60531875000*p*q**6*r**4*s**2 + 6827725000*p**8*r**5*s**2 + 57157000000*p**5*q**2*r**5*s**2 - 75844531250*p**2*q**4*r**5*s**2 - 24452500000*p**6*r**6*s**2 - 144950000000*p**3*q**2*r**6*s**2 - 82109375000*q**4*r**6*s**2 + 46950000000*p**4*r**7*s**2 + 60000000000*p*q**2*r**7*s**2 - 36000000000*p**2*r**8*s**2 + 1549125*p**14*q*s**3 + 51873750*p**11*q**3*s**3 + 599781250*p**8*q**5*s**3 + 2421156250*p**5*q**7*s**3 - 1693515625*p**2*q**9*s**3 - 104884875*p**12*q*r*s**3 - 1937437500*p**9*q**3*r*s**3 - 11461053125*p**6*q**5*r*s**3 + 10299375000*p**3*q**7*r*s**3 + 10551250000*q**9*r*s**3 + 1336263750*p**10*q*r**2*s**3 + 23737250000*p**7*q**3*r**2*s**3 + 57136718750*p**4*q**5*r**2*s**3 - 8288906250*p*q**7*r**2*s**3 - 10907218750*p**8*q*r**3*s**3 - 160615000000*p**5*q**3*r**3*s**3 - 111134687500*p**2*q**5*r**3*s**3 + 46743125000*p**6*q*r**4*s**3 + 570509375000*p**3*q**3*r**4*s**3 + 274839843750*q**5*r**4*s**3 - 73312500000*p**4*q*r**5*s**3 - 145437500000*p*q**3*r**5*s**3 + 8750000000*p**2*q*r**6*s**3 + 180000000000*q*r**7*s**3 + 15946875*p**13*s**4 + 1265625*p**10*q**2*s**4 - 3282343750*p**7*q**4*s**4 - 38241406250*p**4*q**6*s**4 - 40136718750*p*q**8*s**4 - 113146875*p**11*r*s**4 - 2302734375*p**8*q**2*r*s**4 + 68450156250*p**5*q**4*r*s**4 + 177376562500*p**2*q**6*r*s**4 + 3164062500*p**9*r**2*s**4 + 14392890625*p**6*q**2*r**2*s**4 - 543781250000*p**3*q**4*r**2*s**4 - 319769531250*q**6*r**2*s**4 - 21048281250*p**7*r**3*s**4 - 240687500000*p**4*q**2*r**3*s**4 - 228164062500*p*q**4*r**3*s**4 + 23062500000*p**5*r**4*s**4 + 300410156250*p**2*q**2*r**4*s**4 + 93437500000*p**3*r**5*s**4 - 1141015625000*q**2*r**5*s**4 - 187500000000*p*r**6*s**4 + 1761328125*p**9*q*s**5 - 3177734375*p**6*q**3*s**5 + 60019531250*p**3*q**5*s**5 + 108398437500*q**7*s**5 + 24106640625*p**7*q*r*s**5 + 429589843750*p**4*q**3*r*s**5 + 410371093750*p*q**5*r*s**5 - 23582031250*p**5*q*r**2*s**5 + 202441406250*p**2*q**3*r**2*s**5 - 383203125000*p**3*q*r**3*s**5 + 2232910156250*q**3*r**3*s**5 + 1500000000000*p*q*r**4*s**5 - 13710937500*p**8*s**6 - 202832031250*p**5*q**2*s**6 - 531738281250*p**2*q**4*s**6 + 73330078125*p**6*r*s**6 - 3906250000*p**3*q**2*r*s**6 - 1275878906250*q**4*r*s**6 - 121093750000*p**4*r**2*s**6 - 3308593750000*p*q**2*r**2*s**6 + 18066406250*p**2*r**3*s**6 - 244140625000*r**4*s**6 + 327148437500*p**4*q*s**7 + 1672363281250*p*q**3*s**7 + 446777343750*p**2*q*r*s**7 + 1232910156250*q*r**2*s**7 - 274658203125*p**3*s**8 - 1068115234375*q**2*s**8 - 61035156250*p*r*s**8 + + b[4][3] = 200*p**9*q**8 + 7550*p**6*q**10 + 78650*p**3*q**12 + 248400*q**14 - 4800*p**10*q**6*r - 164300*p**7*q**8*r - 1709575*p**4*q**10*r - 5566500*p*q**12*r + 31050*p**11*q**4*r**2 + 1116175*p**8*q**6*r**2 + 12674650*p**5*q**8*r**2 + 45333850*p**2*q**10*r**2 - 60750*p**12*q**2*r**3 - 2872725*p**9*q**4*r**3 - 40403050*p**6*q**6*r**3 - 173564375*p**3*q**8*r**3 - 11242250*q**10*r**3 + 2174100*p**10*q**2*r**4 + 54010000*p**7*q**4*r**4 + 331074875*p**4*q**6*r**4 + 114173750*p*q**8*r**4 - 24858500*p**8*q**2*r**5 - 300875000*p**5*q**4*r**5 - 319430625*p**2*q**6*r**5 + 69810000*p**6*q**2*r**6 - 23900000*p**3*q**4*r**6 - 294662500*q**6*r**6 + 524200000*p**4*q**2*r**7 + 1432000000*p*q**4*r**7 - 2340000000*p**2*q**2*r**8 + 5400*p**11*q**5*s + 310400*p**8*q**7*s + 3591725*p**5*q**9*s + 11556750*p**2*q**11*s - 105300*p**12*q**3*r*s - 4234650*p**9*q**5*r*s - 49928875*p**6*q**7*r*s - 174078125*p**3*q**9*r*s + 18000000*q**11*r*s + 364500*p**13*q*r**2*s + 15763050*p**10*q**3*r**2*s + 220187400*p**7*q**5*r**2*s + 929609375*p**4*q**7*r**2*s - 43653125*p*q**9*r**2*s - 13427100*p**11*q*r**3*s - 346066250*p**8*q**3*r**3*s - 2287673375*p**5*q**5*r**3*s - 1403903125*p**2*q**7*r**3*s + 184586000*p**9*q*r**4*s + 2983460000*p**6*q**3*r**4*s + 8725818750*p**3*q**5*r**4*s + 2527734375*q**7*r**4*s - 1284480000*p**7*q*r**5*s - 13138250000*p**4*q**3*r**5*s - 14001625000*p*q**5*r**5*s + 4224800000*p**5*q*r**6*s + 27460000000*p**2*q**3*r**6*s - 3760000000*p**3*q*r**7*s + 3900000000*q**3*r**7*s + 36450*p**13*q**2*s**2 + 2765475*p**10*q**4*s**2 + 34027625*p**7*q**6*s**2 + 97375000*p**4*q**8*s**2 - 88275000*p*q**10*s**2 - 546750*p**14*r*s**2 - 21961125*p**11*q**2*r*s**2 - 273059375*p**8*q**4*r*s**2 - 761562500*p**5*q**6*r*s**2 + 1869656250*p**2*q**8*r*s**2 + 20545650*p**12*r**2*s**2 + 473934375*p**9*q**2*r**2*s**2 + 1758053125*p**6*q**4*r**2*s**2 - 8743359375*p**3*q**6*r**2*s**2 - 4154375000*q**8*r**2*s**2 - 296559000*p**10*r**3*s**2 - 4065056250*p**7*q**2*r**3*s**2 - 186328125*p**4*q**4*r**3*s**2 + 19419453125*p*q**6*r**3*s**2 + 2326262500*p**8*r**4*s**2 + 21189375000*p**5*q**2*r**4*s**2 - 26301953125*p**2*q**4*r**4*s**2 - 10513250000*p**6*r**5*s**2 - 69937500000*p**3*q**2*r**5*s**2 - 42257812500*q**4*r**5*s**2 + 23375000000*p**4*r**6*s**2 + 40750000000*p*q**2*r**6*s**2 - 19500000000*p**2*r**7*s**2 + 4009500*p**12*q*s**3 + 36140625*p**9*q**3*s**3 - 335459375*p**6*q**5*s**3 - 2695312500*p**3*q**7*s**3 - 1486250000*q**9*s**3 + 102515625*p**10*q*r*s**3 + 4006812500*p**7*q**3*r*s**3 + 27589609375*p**4*q**5*r*s**3 + 20195312500*p*q**7*r*s**3 - 2792812500*p**8*q*r**2*s**3 - 44115156250*p**5*q**3*r**2*s**3 - 72609453125*p**2*q**5*r**2*s**3 + 18752500000*p**6*q*r**3*s**3 + 218140625000*p**3*q**3*r**3*s**3 + 109940234375*q**5*r**3*s**3 - 21893750000*p**4*q*r**4*s**3 - 65187500000*p*q**3*r**4*s**3 - 31000000000*p**2*q*r**5*s**3 + 97500000000*q*r**6*s**3 - 86568750*p**11*s**4 - 1955390625*p**8*q**2*s**4 - 8960781250*p**5*q**4*s**4 - 1357812500*p**2*q**6*s**4 + 1657968750*p**9*r*s**4 + 10467187500*p**6*q**2*r*s**4 - 55292968750*p**3*q**4*r*s**4 - 60683593750*q**6*r*s**4 - 11473593750*p**7*r**2*s**4 - 123281250000*p**4*q**2*r**2*s**4 - 164912109375*p*q**4*r**2*s**4 + 13150000000*p**5*r**3*s**4 + 190751953125*p**2*q**2*r**3*s**4 + 61875000000*p**3*r**4*s**4 - 467773437500*q**2*r**4*s**4 - 118750000000*p*r**5*s**4 + 7583203125*p**7*q*s**5 + 54638671875*p**4*q**3*s**5 + 39423828125*p*q**5*s**5 + 32392578125*p**5*q*r*s**5 + 278515625000*p**2*q**3*r*s**5 - 298339843750*p**3*q*r**2*s**5 + 560791015625*q**3*r**2*s**5 + 720703125000*p*q*r**3*s**5 - 19687500000*p**6*s**6 - 159667968750*p**3*q**2*s**6 - 72265625000*q**4*s**6 + 116699218750*p**4*r*s**6 - 924072265625*p*q**2*r*s**6 - 156005859375*p**2*r**2*s**6 - 112304687500*r**3*s**6 + 349121093750*p**2*q*s**7 + 396728515625*q*r*s**7 - 213623046875*p*s**8 + + b[4][2] = -600*p**10*q**6 - 18450*p**7*q**8 - 174000*p**4*q**10 - 518400*p*q**12 + 5400*p**11*q**4*r + 197550*p**8*q**6*r + 2147775*p**5*q**8*r + 7219800*p**2*q**10*r - 12150*p**12*q**2*r**2 - 662200*p**9*q**4*r**2 - 9274775*p**6*q**6*r**2 - 38330625*p**3*q**8*r**2 - 5508000*q**10*r**2 + 656550*p**10*q**2*r**3 + 16233750*p**7*q**4*r**3 + 97335875*p**4*q**6*r**3 + 58271250*p*q**8*r**3 - 9845500*p**8*q**2*r**4 - 119464375*p**5*q**4*r**4 - 194431875*p**2*q**6*r**4 + 49465000*p**6*q**2*r**5 + 166000000*p**3*q**4*r**5 - 80793750*q**6*r**5 + 54400000*p**4*q**2*r**6 + 377750000*p*q**4*r**6 - 630000000*p**2*q**2*r**7 - 16200*p**12*q**3*s - 459300*p**9*q**5*s - 4207225*p**6*q**7*s - 10827500*p**3*q**9*s + 13635000*q**11*s + 72900*p**13*q*r*s + 2877300*p**10*q**3*r*s + 33239700*p**7*q**5*r*s + 107080625*p**4*q**7*r*s - 114975000*p*q**9*r*s - 3601800*p**11*q*r**2*s - 75214375*p**8*q**3*r**2*s - 387073250*p**5*q**5*r**2*s + 55540625*p**2*q**7*r**2*s + 53793000*p**9*q*r**3*s + 687176875*p**6*q**3*r**3*s + 1670018750*p**3*q**5*r**3*s + 665234375*q**7*r**3*s - 391570000*p**7*q*r**4*s - 3420125000*p**4*q**3*r**4*s - 3609625000*p*q**5*r**4*s + 1365600000*p**5*q*r**5*s + 7236250000*p**2*q**3*r**5*s - 1220000000*p**3*q*r**6*s + 1050000000*q**3*r**6*s - 109350*p**14*s**2 - 3065850*p**11*q**2*s**2 - 26908125*p**8*q**4*s**2 - 44606875*p**5*q**6*s**2 + 269812500*p**2*q**8*s**2 + 5200200*p**12*r*s**2 + 81826875*p**9*q**2*r*s**2 + 155378125*p**6*q**4*r*s**2 - 1936203125*p**3*q**6*r*s**2 - 998437500*q**8*r*s**2 - 77145750*p**10*r**2*s**2 - 745528125*p**7*q**2*r**2*s**2 + 683437500*p**4*q**4*r**2*s**2 + 4083359375*p*q**6*r**2*s**2 + 593287500*p**8*r**3*s**2 + 4799375000*p**5*q**2*r**3*s**2 - 4167578125*p**2*q**4*r**3*s**2 - 2731125000*p**6*r**4*s**2 - 18668750000*p**3*q**2*r**4*s**2 - 10480468750*q**4*r**4*s**2 + 6200000000*p**4*r**5*s**2 + 11750000000*p*q**2*r**5*s**2 - 5250000000*p**2*r**6*s**2 + 26527500*p**10*q*s**3 + 526031250*p**7*q**3*s**3 + 3160703125*p**4*q**5*s**3 + 2650312500*p*q**7*s**3 - 448031250*p**8*q*r*s**3 - 6682968750*p**5*q**3*r*s**3 - 11642812500*p**2*q**5*r*s**3 + 2553203125*p**6*q*r**2*s**3 + 37234375000*p**3*q**3*r**2*s**3 + 21871484375*q**5*r**2*s**3 + 2803125000*p**4*q*r**3*s**3 - 10796875000*p*q**3*r**3*s**3 - 16656250000*p**2*q*r**4*s**3 + 26250000000*q*r**5*s**3 - 75937500*p**9*s**4 - 704062500*p**6*q**2*s**4 - 8363281250*p**3*q**4*s**4 - 10398437500*q**6*s**4 + 197578125*p**7*r*s**4 - 16441406250*p**4*q**2*r*s**4 - 24277343750*p*q**4*r*s**4 - 5716015625*p**5*r**2*s**4 + 31728515625*p**2*q**2*r**2*s**4 + 27031250000*p**3*r**3*s**4 - 92285156250*q**2*r**3*s**4 - 33593750000*p*r**4*s**4 + 10394531250*p**5*q*s**5 + 38037109375*p**2*q**3*s**5 - 48144531250*p**3*q*r*s**5 + 74462890625*q**3*r*s**5 + 121093750000*p*q*r**2*s**5 - 2197265625*p**4*s**6 - 92529296875*p*q**2*s**6 + 15380859375*p**2*r*s**6 - 31738281250*r**2*s**6 + 54931640625*q*s**7 + + b[4][1] = 200*p**8*q**6 + 2950*p**5*q**8 + 10800*p**2*q**10 - 1800*p**9*q**4*r - 49650*p**6*q**6*r - 403375*p**3*q**8*r - 999000*q**10*r + 4050*p**10*q**2*r**2 + 236625*p**7*q**4*r**2 + 3109500*p**4*q**6*r**2 + 11463750*p*q**8*r**2 - 331500*p**8*q**2*r**3 - 7818125*p**5*q**4*r**3 - 41411250*p**2*q**6*r**3 + 4782500*p**6*q**2*r**4 + 47475000*p**3*q**4*r**4 - 16728125*q**6*r**4 - 8700000*p**4*q**2*r**5 + 81750000*p*q**4*r**5 - 135000000*p**2*q**2*r**6 + 5400*p**10*q**3*s + 144200*p**7*q**5*s + 939375*p**4*q**7*s + 1012500*p*q**9*s - 24300*p**11*q*r*s - 1169250*p**8*q**3*r*s - 14027250*p**5*q**5*r*s - 44446875*p**2*q**7*r*s + 2011500*p**9*q*r**2*s + 49330625*p**6*q**3*r**2*s + 272009375*p**3*q**5*r**2*s + 104062500*q**7*r**2*s - 34660000*p**7*q*r**3*s - 455062500*p**4*q**3*r**3*s - 625906250*p*q**5*r**3*s + 210200000*p**5*q*r**4*s + 1298750000*p**2*q**3*r**4*s - 240000000*p**3*q*r**5*s + 225000000*q**3*r**5*s + 36450*p**12*s**2 + 1231875*p**9*q**2*s**2 + 10712500*p**6*q**4*s**2 + 21718750*p**3*q**6*s**2 + 16875000*q**8*s**2 - 2814750*p**10*r*s**2 - 67612500*p**7*q**2*r*s**2 - 345156250*p**4*q**4*r*s**2 - 283125000*p*q**6*r*s**2 + 51300000*p**8*r**2*s**2 + 734531250*p**5*q**2*r**2*s**2 + 1267187500*p**2*q**4*r**2*s**2 - 384312500*p**6*r**3*s**2 - 3912500000*p**3*q**2*r**3*s**2 - 1822265625*q**4*r**3*s**2 + 1112500000*p**4*r**4*s**2 + 2437500000*p*q**2*r**4*s**2 - 1125000000*p**2*r**5*s**2 - 72578125*p**5*q**3*s**3 - 189296875*p**2*q**5*s**3 + 127265625*p**6*q*r*s**3 + 1415625000*p**3*q**3*r*s**3 + 1229687500*q**5*r*s**3 + 1448437500*p**4*q*r**2*s**3 + 2218750000*p*q**3*r**2*s**3 - 4031250000*p**2*q*r**3*s**3 + 5625000000*q*r**4*s**3 - 132890625*p**7*s**4 - 529296875*p**4*q**2*s**4 - 175781250*p*q**4*s**4 - 401953125*p**5*r*s**4 - 4482421875*p**2*q**2*r*s**4 + 4140625000*p**3*r**2*s**4 - 10498046875*q**2*r**2*s**4 - 7031250000*p*r**3*s**4 + 1220703125*p**3*q*s**5 + 1953125000*q**3*s**5 + 14160156250*p*q*r*s**5 - 1708984375*p**2*s**6 - 3662109375*r*s**6 + + b[4][0] = -4600*p**6*q**6 - 67850*p**3*q**8 - 248400*q**10 + 38900*p**7*q**4*r + 679575*p**4*q**6*r + 2866500*p*q**8*r - 81900*p**8*q**2*r**2 - 2009750*p**5*q**4*r**2 - 10783750*p**2*q**6*r**2 + 1478750*p**6*q**2*r**3 + 14165625*p**3*q**4*r**3 - 2743750*q**6*r**3 - 5450000*p**4*q**2*r**4 + 12687500*p*q**4*r**4 - 22500000*p**2*q**2*r**5 - 101700*p**8*q**3*s - 1700975*p**5*q**5*s - 7061250*p**2*q**7*s + 423900*p**9*q*r*s + 9292375*p**6*q**3*r*s + 50438750*p**3*q**5*r*s + 20475000*q**7*r*s - 7852500*p**7*q*r**2*s - 87765625*p**4*q**3*r**2*s - 121609375*p*q**5*r**2*s + 47700000*p**5*q*r**3*s + 264687500*p**2*q**3*r**3*s - 65000000*p**3*q*r**4*s + 37500000*q**3*r**4*s - 534600*p**10*s**2 - 10344375*p**7*q**2*s**2 - 54859375*p**4*q**4*s**2 - 40312500*p*q**6*s**2 + 10158750*p**8*r*s**2 + 117778125*p**5*q**2*r*s**2 + 192421875*p**2*q**4*r*s**2 - 70593750*p**6*r**2*s**2 - 685312500*p**3*q**2*r**2*s**2 - 334375000*q**4*r**2*s**2 + 193750000*p**4*r**3*s**2 + 500000000*p*q**2*r**3*s**2 - 187500000*p**2*r**4*s**2 + 8437500*p**6*q*s**3 + 159218750*p**3*q**3*s**3 + 220625000*q**5*s**3 + 353828125*p**4*q*r*s**3 + 412500000*p*q**3*r*s**3 - 1023437500*p**2*q*r**2*s**3 + 937500000*q*r**3*s**3 - 206015625*p**5*s**4 - 701171875*p**2*q**2*s**4 + 998046875*p**3*r*s**4 - 1308593750*q**2*r*s**4 - 1367187500*p*r**2*s**4 + 1708984375*p*q*s**5 - 976562500*s**6 + + return b + + @property + def o(self): + p, q, r, s = self.p, self.q, self.r, self.s + o = [0]*6 + + o[5] = -1600*p**10*q**10 - 23600*p**7*q**12 - 86400*p**4*q**14 + 24800*p**11*q**8*r + 419200*p**8*q**10*r + 1850450*p**5*q**12*r + 896400*p**2*q**14*r - 138800*p**12*q**6*r**2 - 2921900*p**9*q**8*r**2 - 17295200*p**6*q**10*r**2 - 27127750*p**3*q**12*r**2 - 26076600*q**14*r**2 + 325800*p**13*q**4*r**3 + 9993850*p**10*q**6*r**3 + 88010500*p**7*q**8*r**3 + 274047650*p**4*q**10*r**3 + 410171400*p*q**12*r**3 - 259200*p**14*q**2*r**4 - 17147100*p**11*q**4*r**4 - 254289150*p**8*q**6*r**4 - 1318548225*p**5*q**8*r**4 - 2633598475*p**2*q**10*r**4 + 12636000*p**12*q**2*r**5 + 388911000*p**9*q**4*r**5 + 3269704725*p**6*q**6*r**5 + 8791192300*p**3*q**8*r**5 + 93560575*q**10*r**5 - 228361600*p**10*q**2*r**6 - 3951199200*p**7*q**4*r**6 - 16276981100*p**4*q**6*r**6 - 1597227000*p*q**8*r**6 + 1947899200*p**8*q**2*r**7 + 17037648000*p**5*q**4*r**7 + 8919740000*p**2*q**6*r**7 - 7672160000*p**6*q**2*r**8 - 15496000000*p**3*q**4*r**8 + 4224000000*q**6*r**8 + 9968000000*p**4*q**2*r**9 - 8640000000*p*q**4*r**9 + 4800000000*p**2*q**2*r**10 - 55200*p**12*q**7*s - 685600*p**9*q**9*s + 1028250*p**6*q**11*s + 37650000*p**3*q**13*s + 111375000*q**15*s + 583200*p**13*q**5*r*s + 9075600*p**10*q**7*r*s - 883150*p**7*q**9*r*s - 506830750*p**4*q**11*r*s - 1793137500*p*q**13*r*s - 1852200*p**14*q**3*r**2*s - 41435250*p**11*q**5*r**2*s - 80566700*p**8*q**7*r**2*s + 2485673600*p**5*q**9*r**2*s + 11442286125*p**2*q**11*r**2*s + 1555200*p**15*q*r**3*s + 80846100*p**12*q**3*r**3*s + 564906800*p**9*q**5*r**3*s - 4493012400*p**6*q**7*r**3*s - 35492391250*p**3*q**9*r**3*s - 789931875*q**11*r**3*s - 71766000*p**13*q*r**4*s - 1551149200*p**10*q**3*r**4*s - 1773437900*p**7*q**5*r**4*s + 51957593125*p**4*q**7*r**4*s + 14964765625*p*q**9*r**4*s + 1231569600*p**11*q*r**5*s + 12042977600*p**8*q**3*r**5*s - 27151011200*p**5*q**5*r**5*s - 88080610000*p**2*q**7*r**5*s - 9912995200*p**9*q*r**6*s - 29448104000*p**6*q**3*r**6*s + 144954840000*p**3*q**5*r**6*s - 44601300000*q**7*r**6*s + 35453760000*p**7*q*r**7*s - 63264000000*p**4*q**3*r**7*s + 60544000000*p*q**5*r**7*s - 30048000000*p**5*q*r**8*s + 37040000000*p**2*q**3*r**8*s - 60800000000*p**3*q*r**9*s - 48000000000*q**3*r**9*s - 615600*p**14*q**4*s**2 - 10524500*p**11*q**6*s**2 - 33831250*p**8*q**8*s**2 + 222806250*p**5*q**10*s**2 + 1099687500*p**2*q**12*s**2 + 3353400*p**15*q**2*r*s**2 + 74269350*p**12*q**4*r*s**2 + 276445750*p**9*q**6*r*s**2 - 2618600000*p**6*q**8*r*s**2 - 14473243750*p**3*q**10*r*s**2 + 1383750000*q**12*r*s**2 - 2332800*p**16*r**2*s**2 - 132750900*p**13*q**2*r**2*s**2 - 900775150*p**10*q**4*r**2*s**2 + 8249244500*p**7*q**6*r**2*s**2 + 59525796875*p**4*q**8*r**2*s**2 - 40292868750*p*q**10*r**2*s**2 + 128304000*p**14*r**3*s**2 + 3160232100*p**11*q**2*r**3*s**2 + 8329580000*p**8*q**4*r**3*s**2 - 45558458750*p**5*q**6*r**3*s**2 + 297252890625*p**2*q**8*r**3*s**2 - 2769854400*p**12*r**4*s**2 - 37065970000*p**9*q**2*r**4*s**2 - 90812546875*p**6*q**4*r**4*s**2 - 627902000000*p**3*q**6*r**4*s**2 + 181347421875*q**8*r**4*s**2 + 30946932800*p**10*r**5*s**2 + 249954680000*p**7*q**2*r**5*s**2 + 802954812500*p**4*q**4*r**5*s**2 - 80900000000*p*q**6*r**5*s**2 - 192137320000*p**8*r**6*s**2 - 932641600000*p**5*q**2*r**6*s**2 - 943242500000*p**2*q**4*r**6*s**2 + 658412000000*p**6*r**7*s**2 + 1930720000000*p**3*q**2*r**7*s**2 + 593800000000*q**4*r**7*s**2 - 1162800000000*p**4*r**8*s**2 - 280000000000*p*q**2*r**8*s**2 + 840000000000*p**2*r**9*s**2 - 2187000*p**16*q*s**3 - 47418750*p**13*q**3*s**3 - 180618750*p**10*q**5*s**3 + 2231250000*p**7*q**7*s**3 + 17857734375*p**4*q**9*s**3 + 29882812500*p*q**11*s**3 + 24664500*p**14*q*r*s**3 - 853368750*p**11*q**3*r*s**3 - 25939693750*p**8*q**5*r*s**3 - 177541562500*p**5*q**7*r*s**3 - 297978828125*p**2*q**9*r*s**3 - 153468000*p**12*q*r**2*s**3 + 30188125000*p**9*q**3*r**2*s**3 + 344049821875*p**6*q**5*r**2*s**3 + 534026875000*p**3*q**7*r**2*s**3 - 340726484375*q**9*r**2*s**3 - 9056190000*p**10*q*r**3*s**3 - 322314687500*p**7*q**3*r**3*s**3 - 769632109375*p**4*q**5*r**3*s**3 - 83276875000*p*q**7*r**3*s**3 + 164061000000*p**8*q*r**4*s**3 + 1381358750000*p**5*q**3*r**4*s**3 + 3088020000000*p**2*q**5*r**4*s**3 - 1267655000000*p**6*q*r**5*s**3 - 7642630000000*p**3*q**3*r**5*s**3 - 2759877500000*q**5*r**5*s**3 + 4597760000000*p**4*q*r**6*s**3 + 1846200000000*p*q**3*r**6*s**3 - 7006000000000*p**2*q*r**7*s**3 - 1200000000000*q*r**8*s**3 + 18225000*p**15*s**4 + 1328906250*p**12*q**2*s**4 + 24729140625*p**9*q**4*s**4 + 169467187500*p**6*q**6*s**4 + 413281250000*p**3*q**8*s**4 + 223828125000*q**10*s**4 + 710775000*p**13*r*s**4 - 18611015625*p**10*q**2*r*s**4 - 314344375000*p**7*q**4*r*s**4 - 828439843750*p**4*q**6*r*s**4 + 460937500000*p*q**8*r*s**4 - 25674975000*p**11*r**2*s**4 - 52223515625*p**8*q**2*r**2*s**4 - 387160000000*p**5*q**4*r**2*s**4 - 4733680078125*p**2*q**6*r**2*s**4 + 343911875000*p**9*r**3*s**4 + 3328658359375*p**6*q**2*r**3*s**4 + 16532406250000*p**3*q**4*r**3*s**4 + 5980613281250*q**6*r**3*s**4 - 2295497500000*p**7*r**4*s**4 - 14809820312500*p**4*q**2*r**4*s**4 - 6491406250000*p*q**4*r**4*s**4 + 7768470000000*p**5*r**5*s**4 + 34192562500000*p**2*q**2*r**5*s**4 - 11859000000000*p**3*r**6*s**4 + 10530000000000*q**2*r**6*s**4 + 6000000000000*p*r**7*s**4 + 11453906250*p**11*q*s**5 + 149765625000*p**8*q**3*s**5 + 545537109375*p**5*q**5*s**5 + 527343750000*p**2*q**7*s**5 - 371313281250*p**9*q*r*s**5 - 3461455078125*p**6*q**3*r*s**5 - 7920878906250*p**3*q**5*r*s**5 - 4747314453125*q**7*r*s**5 + 2417815625000*p**7*q*r**2*s**5 + 5465576171875*p**4*q**3*r**2*s**5 + 5937128906250*p*q**5*r**2*s**5 - 10661156250000*p**5*q*r**3*s**5 - 63574218750000*p**2*q**3*r**3*s**5 + 24059375000000*p**3*q*r**4*s**5 - 33023437500000*q**3*r**4*s**5 - 43125000000000*p*q*r**5*s**5 + 94394531250*p**10*s**6 + 1097167968750*p**7*q**2*s**6 + 2829833984375*p**4*q**4*s**6 - 1525878906250*p*q**6*s**6 + 2724609375*p**8*r*s**6 + 13998535156250*p**5*q**2*r*s**6 + 57094482421875*p**2*q**4*r*s**6 - 8512509765625*p**6*r**2*s**6 - 37941406250000*p**3*q**2*r**2*s**6 + 33191894531250*q**4*r**2*s**6 + 50534179687500*p**4*r**3*s**6 + 156656250000000*p*q**2*r**3*s**6 - 85023437500000*p**2*r**4*s**6 + 10125000000000*r**5*s**6 - 2717285156250*p**6*q*s**7 - 11352539062500*p**3*q**3*s**7 - 2593994140625*q**5*s**7 - 47154541015625*p**4*q*r*s**7 - 160644531250000*p*q**3*r*s**7 + 142500000000000*p**2*q*r**2*s**7 - 26757812500000*q*r**3*s**7 - 4364013671875*p**5*s**8 - 94604492187500*p**2*q**2*s**8 + 114379882812500*p**3*r*s**8 + 51116943359375*q**2*r*s**8 - 346435546875000*p*r**2*s**8 + 476837158203125*p*q*s**9 - 476837158203125*s**10 + + o[4] = 1600*p**11*q**8 + 20800*p**8*q**10 + 45100*p**5*q**12 - 151200*p**2*q**14 - 19200*p**12*q**6*r - 293200*p**9*q**8*r - 794600*p**6*q**10*r + 2634675*p**3*q**12*r + 2640600*q**14*r + 75600*p**13*q**4*r**2 + 1529100*p**10*q**6*r**2 + 6233350*p**7*q**8*r**2 - 12013350*p**4*q**10*r**2 - 29069550*p*q**12*r**2 - 97200*p**14*q**2*r**3 - 3562500*p**11*q**4*r**3 - 26984900*p**8*q**6*r**3 - 15900325*p**5*q**8*r**3 + 76267100*p**2*q**10*r**3 + 3272400*p**12*q**2*r**4 + 59486850*p**9*q**4*r**4 + 221270075*p**6*q**6*r**4 + 74065250*p**3*q**8*r**4 - 300564375*q**10*r**4 - 45569400*p**10*q**2*r**5 - 438666000*p**7*q**4*r**5 - 444821250*p**4*q**6*r**5 + 2448256250*p*q**8*r**5 + 290640000*p**8*q**2*r**6 + 855850000*p**5*q**4*r**6 - 5741875000*p**2*q**6*r**6 - 644000000*p**6*q**2*r**7 + 5574000000*p**3*q**4*r**7 + 4643000000*q**6*r**7 - 1696000000*p**4*q**2*r**8 - 12660000000*p*q**4*r**8 + 7200000000*p**2*q**2*r**9 + 43200*p**13*q**5*s + 572000*p**10*q**7*s - 59800*p**7*q**9*s - 24174625*p**4*q**11*s - 74587500*p*q**13*s - 324000*p**14*q**3*r*s - 5531400*p**11*q**5*r*s - 3712100*p**8*q**7*r*s + 293009275*p**5*q**9*r*s + 1115548875*p**2*q**11*r*s + 583200*p**15*q*r**2*s + 18343800*p**12*q**3*r**2*s + 77911100*p**9*q**5*r**2*s - 957488825*p**6*q**7*r**2*s - 5449661250*p**3*q**9*r**2*s + 960120000*q**11*r**2*s - 23684400*p**13*q*r**3*s - 373761900*p**10*q**3*r**3*s - 27944975*p**7*q**5*r**3*s + 10375740625*p**4*q**7*r**3*s - 4649093750*p*q**9*r**3*s + 395816400*p**11*q*r**4*s + 2910968000*p**8*q**3*r**4*s - 9126162500*p**5*q**5*r**4*s - 11696118750*p**2*q**7*r**4*s - 3028640000*p**9*q*r**5*s - 3251550000*p**6*q**3*r**5*s + 47914250000*p**3*q**5*r**5*s - 30255625000*q**7*r**5*s + 9304000000*p**7*q*r**6*s - 42970000000*p**4*q**3*r**6*s + 31475000000*p*q**5*r**6*s + 2176000000*p**5*q*r**7*s + 62100000000*p**2*q**3*r**7*s - 43200000000*p**3*q*r**8*s - 72000000000*q**3*r**8*s + 291600*p**15*q**2*s**2 + 2702700*p**12*q**4*s**2 - 38692250*p**9*q**6*s**2 - 538903125*p**6*q**8*s**2 - 1613112500*p**3*q**10*s**2 + 320625000*q**12*s**2 - 874800*p**16*r*s**2 - 14166900*p**13*q**2*r*s**2 + 193284900*p**10*q**4*r*s**2 + 3688520500*p**7*q**6*r*s**2 + 11613390625*p**4*q**8*r*s**2 - 15609881250*p*q**10*r*s**2 + 44031600*p**14*r**2*s**2 + 482345550*p**11*q**2*r**2*s**2 - 2020881875*p**8*q**4*r**2*s**2 - 7407026250*p**5*q**6*r**2*s**2 + 136175750000*p**2*q**8*r**2*s**2 - 1000884600*p**12*r**3*s**2 - 8888950000*p**9*q**2*r**3*s**2 - 30101703125*p**6*q**4*r**3*s**2 - 319761000000*p**3*q**6*r**3*s**2 + 51519218750*q**8*r**3*s**2 + 12622395000*p**10*r**4*s**2 + 97032450000*p**7*q**2*r**4*s**2 + 469929218750*p**4*q**4*r**4*s**2 + 291342187500*p*q**6*r**4*s**2 - 96382000000*p**8*r**5*s**2 - 598070000000*p**5*q**2*r**5*s**2 - 1165021875000*p**2*q**4*r**5*s**2 + 446500000000*p**6*r**6*s**2 + 1651500000000*p**3*q**2*r**6*s**2 + 789375000000*q**4*r**6*s**2 - 1152000000000*p**4*r**7*s**2 - 600000000000*p*q**2*r**7*s**2 + 1260000000000*p**2*r**8*s**2 - 24786000*p**14*q*s**3 - 660487500*p**11*q**3*s**3 - 5886356250*p**8*q**5*s**3 - 18137187500*p**5*q**7*s**3 - 5120546875*p**2*q**9*s**3 + 827658000*p**12*q*r*s**3 + 13343062500*p**9*q**3*r*s**3 + 39782068750*p**6*q**5*r*s**3 - 111288437500*p**3*q**7*r*s**3 - 15438750000*q**9*r*s**3 - 14540782500*p**10*q*r**2*s**3 - 135889750000*p**7*q**3*r**2*s**3 - 176892578125*p**4*q**5*r**2*s**3 - 934462656250*p*q**7*r**2*s**3 + 171669250000*p**8*q*r**3*s**3 + 1164538125000*p**5*q**3*r**3*s**3 + 3192346406250*p**2*q**5*r**3*s**3 - 1295476250000*p**6*q*r**4*s**3 - 6540712500000*p**3*q**3*r**4*s**3 - 2957828125000*q**5*r**4*s**3 + 5366750000000*p**4*q*r**5*s**3 + 3165000000000*p*q**3*r**5*s**3 - 8862500000000*p**2*q*r**6*s**3 - 1800000000000*q*r**7*s**3 + 236925000*p**13*s**4 + 8895234375*p**10*q**2*s**4 + 106180781250*p**7*q**4*s**4 + 474221875000*p**4*q**6*s**4 + 616210937500*p*q**8*s**4 - 6995868750*p**11*r*s**4 - 184190625000*p**8*q**2*r*s**4 - 1299254453125*p**5*q**4*r*s**4 - 2475458593750*p**2*q**6*r*s**4 + 63049218750*p**9*r**2*s**4 + 1646791484375*p**6*q**2*r**2*s**4 + 9086886718750*p**3*q**4*r**2*s**4 + 4673421875000*q**6*r**2*s**4 - 215665000000*p**7*r**3*s**4 - 7864589843750*p**4*q**2*r**3*s**4 - 5987890625000*p*q**4*r**3*s**4 + 594843750000*p**5*r**4*s**4 + 27791171875000*p**2*q**2*r**4*s**4 - 3881250000000*p**3*r**5*s**4 + 12203125000000*q**2*r**5*s**4 + 10312500000000*p*r**6*s**4 - 34720312500*p**9*q*s**5 - 545126953125*p**6*q**3*s**5 - 2176425781250*p**3*q**5*s**5 - 2792968750000*q**7*s**5 - 1395703125*p**7*q*r*s**5 - 1957568359375*p**4*q**3*r*s**5 + 5122636718750*p*q**5*r*s**5 + 858210937500*p**5*q*r**2*s**5 - 42050097656250*p**2*q**3*r**2*s**5 + 7088281250000*p**3*q*r**3*s**5 - 25974609375000*q**3*r**3*s**5 - 69296875000000*p*q*r**4*s**5 + 384697265625*p**8*s**6 + 6403320312500*p**5*q**2*s**6 + 16742675781250*p**2*q**4*s**6 - 3467080078125*p**6*r*s**6 + 11009765625000*p**3*q**2*r*s**6 + 16451660156250*q**4*r*s**6 + 6979003906250*p**4*r**2*s**6 + 145403320312500*p*q**2*r**2*s**6 + 4076171875000*p**2*r**3*s**6 + 22265625000000*r**4*s**6 - 21915283203125*p**4*q*s**7 - 86608886718750*p*q**3*s**7 - 22785644531250*p**2*q*r*s**7 - 103466796875000*q*r**2*s**7 + 18798828125000*p**3*s**8 + 106048583984375*q**2*s**8 + 17761230468750*p*r*s**8 + + o[3] = 2800*p**9*q**8 + 55700*p**6*q**10 + 363600*p**3*q**12 + 777600*q**14 - 27200*p**10*q**6*r - 700200*p**7*q**8*r - 5726550*p**4*q**10*r - 15066000*p*q**12*r + 74700*p**11*q**4*r**2 + 2859575*p**8*q**6*r**2 + 31175725*p**5*q**8*r**2 + 103147650*p**2*q**10*r**2 - 40500*p**12*q**2*r**3 - 4274400*p**9*q**4*r**3 - 76065825*p**6*q**6*r**3 - 365623750*p**3*q**8*r**3 - 132264000*q**10*r**3 + 2192400*p**10*q**2*r**4 + 92562500*p**7*q**4*r**4 + 799193875*p**4*q**6*r**4 + 1188193125*p*q**8*r**4 - 41231500*p**8*q**2*r**5 - 914210000*p**5*q**4*r**5 - 3318853125*p**2*q**6*r**5 + 398850000*p**6*q**2*r**6 + 3944000000*p**3*q**4*r**6 + 2211312500*q**6*r**6 - 1817000000*p**4*q**2*r**7 - 6720000000*p*q**4*r**7 + 3900000000*p**2*q**2*r**8 + 75600*p**11*q**5*s + 1823100*p**8*q**7*s + 14534150*p**5*q**9*s + 38265750*p**2*q**11*s - 394200*p**12*q**3*r*s - 11453850*p**9*q**5*r*s - 101213000*p**6*q**7*r*s - 223565625*p**3*q**9*r*s + 415125000*q**11*r*s + 243000*p**13*q*r**2*s + 13654575*p**10*q**3*r**2*s + 163811725*p**7*q**5*r**2*s + 173461250*p**4*q**7*r**2*s - 3008671875*p*q**9*r**2*s - 2016900*p**11*q*r**3*s - 86576250*p**8*q**3*r**3*s - 324146625*p**5*q**5*r**3*s + 3378506250*p**2*q**7*r**3*s - 89211000*p**9*q*r**4*s - 55207500*p**6*q**3*r**4*s + 1493950000*p**3*q**5*r**4*s - 12573609375*q**7*r**4*s + 1140100000*p**7*q*r**5*s + 42500000*p**4*q**3*r**5*s + 21511250000*p*q**5*r**5*s - 4058000000*p**5*q*r**6*s + 6725000000*p**2*q**3*r**6*s - 1400000000*p**3*q*r**7*s - 39000000000*q**3*r**7*s + 510300*p**13*q**2*s**2 + 4814775*p**10*q**4*s**2 - 70265125*p**7*q**6*s**2 - 1016484375*p**4*q**8*s**2 - 3221100000*p*q**10*s**2 - 364500*p**14*r*s**2 + 30314250*p**11*q**2*r*s**2 + 1106765625*p**8*q**4*r*s**2 + 10984203125*p**5*q**6*r*s**2 + 33905812500*p**2*q**8*r*s**2 - 37980900*p**12*r**2*s**2 - 2142905625*p**9*q**2*r**2*s**2 - 26896125000*p**6*q**4*r**2*s**2 - 95551328125*p**3*q**6*r**2*s**2 + 11320312500*q**8*r**2*s**2 + 1743781500*p**10*r**3*s**2 + 35432262500*p**7*q**2*r**3*s**2 + 177855859375*p**4*q**4*r**3*s**2 + 121260546875*p*q**6*r**3*s**2 - 25943162500*p**8*r**4*s**2 - 249165500000*p**5*q**2*r**4*s**2 - 461739453125*p**2*q**4*r**4*s**2 + 177823750000*p**6*r**5*s**2 + 726225000000*p**3*q**2*r**5*s**2 + 404195312500*q**4*r**5*s**2 - 565875000000*p**4*r**6*s**2 - 407500000000*p*q**2*r**6*s**2 + 682500000000*p**2*r**7*s**2 - 59140125*p**12*q*s**3 - 1290515625*p**9*q**3*s**3 - 8785071875*p**6*q**5*s**3 - 15588281250*p**3*q**7*s**3 + 17505000000*q**9*s**3 + 896062500*p**10*q*r*s**3 + 2589750000*p**7*q**3*r*s**3 - 82700156250*p**4*q**5*r*s**3 - 347683593750*p*q**7*r*s**3 + 17022656250*p**8*q*r**2*s**3 + 320923593750*p**5*q**3*r**2*s**3 + 1042116875000*p**2*q**5*r**2*s**3 - 353262812500*p**6*q*r**3*s**3 - 2212664062500*p**3*q**3*r**3*s**3 - 1252408984375*q**5*r**3*s**3 + 1967362500000*p**4*q*r**4*s**3 + 1583343750000*p*q**3*r**4*s**3 - 3560625000000*p**2*q*r**5*s**3 - 975000000000*q*r**6*s**3 + 462459375*p**11*s**4 + 14210859375*p**8*q**2*s**4 + 99521718750*p**5*q**4*s**4 + 114955468750*p**2*q**6*s**4 - 17720859375*p**9*r*s**4 - 100320703125*p**6*q**2*r*s**4 + 1021943359375*p**3*q**4*r*s**4 + 1193203125000*q**6*r*s**4 + 171371250000*p**7*r**2*s**4 - 1113390625000*p**4*q**2*r**2*s**4 - 1211474609375*p*q**4*r**2*s**4 - 274056250000*p**5*r**3*s**4 + 8285166015625*p**2*q**2*r**3*s**4 - 2079375000000*p**3*r**4*s**4 + 5137304687500*q**2*r**4*s**4 + 6187500000000*p*r**5*s**4 - 135675000000*p**7*q*s**5 - 1275244140625*p**4*q**3*s**5 - 28388671875*p*q**5*s**5 + 1015166015625*p**5*q*r*s**5 - 10584423828125*p**2*q**3*r*s**5 + 3559570312500*p**3*q*r**2*s**5 - 6929931640625*q**3*r**2*s**5 - 32304687500000*p*q*r**3*s**5 + 430576171875*p**6*s**6 + 9397949218750*p**3*q**2*s**6 + 575195312500*q**4*s**6 - 4086425781250*p**4*r*s**6 + 42183837890625*p*q**2*r*s**6 + 8156494140625*p**2*r**2*s**6 + 12612304687500*r**3*s**6 - 25513916015625*p**2*q*s**7 - 37017822265625*q*r*s**7 + 18981933593750*p*s**8 + + o[2] = 1600*p**10*q**6 + 9200*p**7*q**8 - 126000*p**4*q**10 - 777600*p*q**12 - 14400*p**11*q**4*r - 119300*p**8*q**6*r + 1203225*p**5*q**8*r + 9412200*p**2*q**10*r + 32400*p**12*q**2*r**2 + 417950*p**9*q**4*r**2 - 4543725*p**6*q**6*r**2 - 49008125*p**3*q**8*r**2 - 24192000*q**10*r**2 - 292050*p**10*q**2*r**3 + 8760000*p**7*q**4*r**3 + 137506625*p**4*q**6*r**3 + 225438750*p*q**8*r**3 - 4213250*p**8*q**2*r**4 - 173595625*p**5*q**4*r**4 - 653003125*p**2*q**6*r**4 + 82575000*p**6*q**2*r**5 + 838125000*p**3*q**4*r**5 + 578562500*q**6*r**5 - 421500000*p**4*q**2*r**6 - 1796250000*p*q**4*r**6 + 1050000000*p**2*q**2*r**7 + 43200*p**12*q**3*s + 807300*p**9*q**5*s + 5328225*p**6*q**7*s + 16946250*p**3*q**9*s + 29565000*q**11*s - 194400*p**13*q*r*s - 5505300*p**10*q**3*r*s - 49886700*p**7*q**5*r*s - 178821875*p**4*q**7*r*s - 222750000*p*q**9*r*s + 6814800*p**11*q*r**2*s + 120525625*p**8*q**3*r**2*s + 526694500*p**5*q**5*r**2*s + 84065625*p**2*q**7*r**2*s - 123670500*p**9*q*r**3*s - 1106731875*p**6*q**3*r**3*s - 669556250*p**3*q**5*r**3*s - 2869265625*q**7*r**3*s + 1004350000*p**7*q*r**4*s + 3384375000*p**4*q**3*r**4*s + 5665625000*p*q**5*r**4*s - 3411000000*p**5*q*r**5*s - 418750000*p**2*q**3*r**5*s + 1700000000*p**3*q*r**6*s - 10500000000*q**3*r**6*s + 291600*p**14*s**2 + 9829350*p**11*q**2*s**2 + 114151875*p**8*q**4*s**2 + 522169375*p**5*q**6*s**2 + 716906250*p**2*q**8*s**2 - 18625950*p**12*r*s**2 - 387703125*p**9*q**2*r*s**2 - 2056109375*p**6*q**4*r*s**2 - 760203125*p**3*q**6*r*s**2 + 3071250000*q**8*r*s**2 + 512419500*p**10*r**2*s**2 + 5859053125*p**7*q**2*r**2*s**2 + 12154062500*p**4*q**4*r**2*s**2 + 15931640625*p*q**6*r**2*s**2 - 6598393750*p**8*r**3*s**2 - 43549625000*p**5*q**2*r**3*s**2 - 82011328125*p**2*q**4*r**3*s**2 + 43538125000*p**6*r**4*s**2 + 160831250000*p**3*q**2*r**4*s**2 + 99070312500*q**4*r**4*s**2 - 141812500000*p**4*r**5*s**2 - 117500000000*p*q**2*r**5*s**2 + 183750000000*p**2*r**6*s**2 - 154608750*p**10*q*s**3 - 3309468750*p**7*q**3*s**3 - 20834140625*p**4*q**5*s**3 - 34731562500*p*q**7*s**3 + 5970375000*p**8*q*r*s**3 + 68533281250*p**5*q**3*r*s**3 + 142698281250*p**2*q**5*r*s**3 - 74509140625*p**6*q*r**2*s**3 - 389148437500*p**3*q**3*r**2*s**3 - 270937890625*q**5*r**2*s**3 + 366696875000*p**4*q*r**3*s**3 + 400031250000*p*q**3*r**3*s**3 - 735156250000*p**2*q*r**4*s**3 - 262500000000*q*r**5*s**3 + 371250000*p**9*s**4 + 21315000000*p**6*q**2*s**4 + 179515625000*p**3*q**4*s**4 + 238406250000*q**6*s**4 - 9071015625*p**7*r*s**4 - 268945312500*p**4*q**2*r*s**4 - 379785156250*p*q**4*r*s**4 + 140262890625*p**5*r**2*s**4 + 1486259765625*p**2*q**2*r**2*s**4 - 806484375000*p**3*r**3*s**4 + 1066210937500*q**2*r**3*s**4 + 1722656250000*p*r**4*s**4 - 125648437500*p**5*q*s**5 - 1236279296875*p**2*q**3*s**5 + 1267871093750*p**3*q*r*s**5 - 1044677734375*q**3*r*s**5 - 6630859375000*p*q*r**2*s**5 + 160888671875*p**4*s**6 + 6352294921875*p*q**2*s**6 - 708740234375*p**2*r*s**6 + 3901367187500*r**2*s**6 - 8050537109375*q*s**7 + + o[1] = 2800*p**8*q**6 + 41300*p**5*q**8 + 151200*p**2*q**10 - 25200*p**9*q**4*r - 542600*p**6*q**6*r - 3397875*p**3*q**8*r - 5751000*q**10*r + 56700*p**10*q**2*r**2 + 1972125*p**7*q**4*r**2 + 18624250*p**4*q**6*r**2 + 50253750*p*q**8*r**2 - 1701000*p**8*q**2*r**3 - 32630625*p**5*q**4*r**3 - 139868750*p**2*q**6*r**3 + 18162500*p**6*q**2*r**4 + 177125000*p**3*q**4*r**4 + 121734375*q**6*r**4 - 100500000*p**4*q**2*r**5 - 386250000*p*q**4*r**5 + 225000000*p**2*q**2*r**6 + 75600*p**10*q**3*s + 1708800*p**7*q**5*s + 12836875*p**4*q**7*s + 32062500*p*q**9*s - 340200*p**11*q*r*s - 10185750*p**8*q**3*r*s - 97502750*p**5*q**5*r*s - 301640625*p**2*q**7*r*s + 7168500*p**9*q*r**2*s + 135960625*p**6*q**3*r**2*s + 587471875*p**3*q**5*r**2*s - 384750000*q**7*r**2*s - 29325000*p**7*q*r**3*s - 320625000*p**4*q**3*r**3*s + 523437500*p*q**5*r**3*s - 42000000*p**5*q*r**4*s + 343750000*p**2*q**3*r**4*s + 150000000*p**3*q*r**5*s - 2250000000*q**3*r**5*s + 510300*p**12*s**2 + 12808125*p**9*q**2*s**2 + 107062500*p**6*q**4*s**2 + 270312500*p**3*q**6*s**2 - 168750000*q**8*s**2 - 2551500*p**10*r*s**2 - 5062500*p**7*q**2*r*s**2 + 712343750*p**4*q**4*r*s**2 + 4788281250*p*q**6*r*s**2 - 256837500*p**8*r**2*s**2 - 3574812500*p**5*q**2*r**2*s**2 - 14967968750*p**2*q**4*r**2*s**2 + 4040937500*p**6*r**3*s**2 + 26400000000*p**3*q**2*r**3*s**2 + 17083984375*q**4*r**3*s**2 - 21812500000*p**4*r**4*s**2 - 24375000000*p*q**2*r**4*s**2 + 39375000000*p**2*r**5*s**2 - 127265625*p**5*q**3*s**3 - 680234375*p**2*q**5*s**3 - 2048203125*p**6*q*r*s**3 - 18794531250*p**3*q**3*r*s**3 - 25050000000*q**5*r*s**3 + 26621875000*p**4*q*r**2*s**3 + 37007812500*p*q**3*r**2*s**3 - 105468750000*p**2*q*r**3*s**3 - 56250000000*q*r**4*s**3 + 1124296875*p**7*s**4 + 9251953125*p**4*q**2*s**4 - 8007812500*p*q**4*s**4 - 4004296875*p**5*r*s**4 + 179931640625*p**2*q**2*r*s**4 - 75703125000*p**3*r**2*s**4 + 133447265625*q**2*r**2*s**4 + 363281250000*p*r**3*s**4 - 91552734375*p**3*q*s**5 - 19531250000*q**3*s**5 - 751953125000*p*q*r*s**5 + 157958984375*p**2*s**6 + 748291015625*r*s**6 + + o[0] = -14400*p**6*q**6 - 212400*p**3*q**8 - 777600*q**10 + 92100*p**7*q**4*r + 1689675*p**4*q**6*r + 7371000*p*q**8*r - 122850*p**8*q**2*r**2 - 3735250*p**5*q**4*r**2 - 22432500*p**2*q**6*r**2 + 2298750*p**6*q**2*r**3 + 29390625*p**3*q**4*r**3 + 18000000*q**6*r**3 - 17750000*p**4*q**2*r**4 - 62812500*p*q**4*r**4 + 37500000*p**2*q**2*r**5 - 51300*p**8*q**3*s - 768025*p**5*q**5*s - 2801250*p**2*q**7*s - 275400*p**9*q*r*s - 5479875*p**6*q**3*r*s - 35538750*p**3*q**5*r*s - 68850000*q**7*r*s + 12757500*p**7*q*r**2*s + 133640625*p**4*q**3*r**2*s + 222609375*p*q**5*r**2*s - 108500000*p**5*q*r**3*s - 290312500*p**2*q**3*r**3*s + 275000000*p**3*q*r**4*s - 375000000*q**3*r**4*s + 1931850*p**10*s**2 + 40213125*p**7*q**2*s**2 + 253921875*p**4*q**4*s**2 + 464062500*p*q**6*s**2 - 71077500*p**8*r*s**2 - 818746875*p**5*q**2*r*s**2 - 1882265625*p**2*q**4*r*s**2 + 826031250*p**6*r**2*s**2 + 4369687500*p**3*q**2*r**2*s**2 + 3107812500*q**4*r**2*s**2 - 3943750000*p**4*r**3*s**2 - 5000000000*p*q**2*r**3*s**2 + 6562500000*p**2*r**4*s**2 - 295312500*p**6*q*s**3 - 2938906250*p**3*q**3*s**3 - 4848750000*q**5*s**3 + 3791484375*p**4*q*r*s**3 + 7556250000*p*q**3*r*s**3 - 11960937500*p**2*q*r**2*s**3 - 9375000000*q*r**3*s**3 + 1668515625*p**5*s**4 + 20447265625*p**2*q**2*s**4 - 21955078125*p**3*r*s**4 + 18984375000*q**2*r*s**4 + 67382812500*p*r**2*s**4 - 120849609375*p*q*s**5 + 157226562500*s**6 + + return o + + @property + def a(self): + p, q, r, s = self.p, self.q, self.r, self.s + a = [0]*6 + + a[5] = -100*p**7*q**7 - 2175*p**4*q**9 - 10500*p*q**11 + 1100*p**8*q**5*r + 27975*p**5*q**7*r + 152950*p**2*q**9*r - 4125*p**9*q**3*r**2 - 128875*p**6*q**5*r**2 - 830525*p**3*q**7*r**2 + 59450*q**9*r**2 + 5400*p**10*q*r**3 + 243800*p**7*q**3*r**3 + 2082650*p**4*q**5*r**3 - 333925*p*q**7*r**3 - 139200*p**8*q*r**4 - 2406000*p**5*q**3*r**4 - 122600*p**2*q**5*r**4 + 1254400*p**6*q*r**5 + 3776000*p**3*q**3*r**5 + 1832000*q**5*r**5 - 4736000*p**4*q*r**6 - 6720000*p*q**3*r**6 + 6400000*p**2*q*r**7 - 900*p**9*q**4*s - 37400*p**6*q**6*s - 281625*p**3*q**8*s - 435000*q**10*s + 6750*p**10*q**2*r*s + 322300*p**7*q**4*r*s + 2718575*p**4*q**6*r*s + 4214250*p*q**8*r*s - 16200*p**11*r**2*s - 859275*p**8*q**2*r**2*s - 8925475*p**5*q**4*r**2*s - 14427875*p**2*q**6*r**2*s + 453600*p**9*r**3*s + 10038400*p**6*q**2*r**3*s + 17397500*p**3*q**4*r**3*s - 11333125*q**6*r**3*s - 4451200*p**7*r**4*s - 15850000*p**4*q**2*r**4*s + 34000000*p*q**4*r**4*s + 17984000*p**5*r**5*s - 10000000*p**2*q**2*r**5*s - 25600000*p**3*r**6*s - 8000000*q**2*r**6*s + 6075*p**11*q*s**2 - 83250*p**8*q**3*s**2 - 1282500*p**5*q**5*s**2 - 2862500*p**2*q**7*s**2 + 724275*p**9*q*r*s**2 + 9807250*p**6*q**3*r*s**2 + 28374375*p**3*q**5*r*s**2 + 22212500*q**7*r*s**2 - 8982000*p**7*q*r**2*s**2 - 39600000*p**4*q**3*r**2*s**2 - 61746875*p*q**5*r**2*s**2 - 1010000*p**5*q*r**3*s**2 - 1000000*p**2*q**3*r**3*s**2 + 78000000*p**3*q*r**4*s**2 + 30000000*q**3*r**4*s**2 + 80000000*p*q*r**5*s**2 - 759375*p**10*s**3 - 9787500*p**7*q**2*s**3 - 39062500*p**4*q**4*s**3 - 52343750*p*q**6*s**3 + 12301875*p**8*r*s**3 + 98175000*p**5*q**2*r*s**3 + 225078125*p**2*q**4*r*s**3 - 54900000*p**6*r**2*s**3 - 310000000*p**3*q**2*r**2*s**3 - 7890625*q**4*r**2*s**3 + 51250000*p**4*r**3*s**3 - 420000000*p*q**2*r**3*s**3 + 110000000*p**2*r**4*s**3 - 200000000*r**5*s**3 + 2109375*p**6*q*s**4 - 21093750*p**3*q**3*s**4 - 89843750*q**5*s**4 + 182343750*p**4*q*r*s**4 + 733203125*p*q**3*r*s**4 - 196875000*p**2*q*r**2*s**4 + 1125000000*q*r**3*s**4 - 158203125*p**5*s**5 - 566406250*p**2*q**2*s**5 + 101562500*p**3*r*s**5 - 1669921875*q**2*r*s**5 + 1250000000*p*r**2*s**5 - 1220703125*p*q*s**6 + 6103515625*s**7 + + a[4] = 1000*p**5*q**7 + 7250*p**2*q**9 - 10800*p**6*q**5*r - 96900*p**3*q**7*r - 52500*q**9*r + 37400*p**7*q**3*r**2 + 470850*p**4*q**5*r**2 + 640600*p*q**7*r**2 - 39600*p**8*q*r**3 - 983600*p**5*q**3*r**3 - 2848100*p**2*q**5*r**3 + 814400*p**6*q*r**4 + 6076000*p**3*q**3*r**4 + 2308000*q**5*r**4 - 5024000*p**4*q*r**5 - 9680000*p*q**3*r**5 + 9600000*p**2*q*r**6 + 13800*p**7*q**4*s + 94650*p**4*q**6*s - 26500*p*q**8*s - 86400*p**8*q**2*r*s - 816500*p**5*q**4*r*s - 257500*p**2*q**6*r*s + 91800*p**9*r**2*s + 1853700*p**6*q**2*r**2*s + 630000*p**3*q**4*r**2*s - 8971250*q**6*r**2*s - 2071200*p**7*r**3*s - 7240000*p**4*q**2*r**3*s + 29375000*p*q**4*r**3*s + 14416000*p**5*r**4*s - 5200000*p**2*q**2*r**4*s - 30400000*p**3*r**5*s - 12000000*q**2*r**5*s + 64800*p**9*q*s**2 + 567000*p**6*q**3*s**2 + 1655000*p**3*q**5*s**2 + 6987500*q**7*s**2 + 337500*p**7*q*r*s**2 + 8462500*p**4*q**3*r*s**2 - 5812500*p*q**5*r*s**2 - 24930000*p**5*q*r**2*s**2 - 69125000*p**2*q**3*r**2*s**2 + 103500000*p**3*q*r**3*s**2 + 30000000*q**3*r**3*s**2 + 90000000*p*q*r**4*s**2 - 708750*p**8*s**3 - 5400000*p**5*q**2*s**3 + 8906250*p**2*q**4*s**3 + 18562500*p**6*r*s**3 - 625000*p**3*q**2*r*s**3 + 29687500*q**4*r*s**3 - 75000000*p**4*r**2*s**3 - 416250000*p*q**2*r**2*s**3 + 60000000*p**2*r**3*s**3 - 300000000*r**4*s**3 + 71718750*p**4*q*s**4 + 189062500*p*q**3*s**4 + 210937500*p**2*q*r*s**4 + 1187500000*q*r**2*s**4 - 187500000*p**3*s**5 - 800781250*q**2*s**5 - 390625000*p*r*s**5 + + a[3] = -500*p**6*q**5 - 6350*p**3*q**7 - 19800*q**9 + 3750*p**7*q**3*r + 65100*p**4*q**5*r + 264950*p*q**7*r - 6750*p**8*q*r**2 - 209050*p**5*q**3*r**2 - 1217250*p**2*q**5*r**2 + 219000*p**6*q*r**3 + 2510000*p**3*q**3*r**3 + 1098500*q**5*r**3 - 2068000*p**4*q*r**4 - 5060000*p*q**3*r**4 + 5200000*p**2*q*r**5 - 6750*p**8*q**2*s - 96350*p**5*q**4*s - 346000*p**2*q**6*s + 20250*p**9*r*s + 459900*p**6*q**2*r*s + 1828750*p**3*q**4*r*s - 2930000*q**6*r*s - 594000*p**7*r**2*s - 4301250*p**4*q**2*r**2*s + 10906250*p*q**4*r**2*s + 5252000*p**5*r**3*s - 1450000*p**2*q**2*r**3*s - 12800000*p**3*r**4*s - 6500000*q**2*r**4*s + 74250*p**7*q*s**2 + 1418750*p**4*q**3*s**2 + 5956250*p*q**5*s**2 - 4297500*p**5*q*r*s**2 - 29906250*p**2*q**3*r*s**2 + 31500000*p**3*q*r**2*s**2 + 12500000*q**3*r**2*s**2 + 35000000*p*q*r**3*s**2 + 1350000*p**6*s**3 + 6093750*p**3*q**2*s**3 + 17500000*q**4*s**3 - 7031250*p**4*r*s**3 - 127812500*p*q**2*r*s**3 + 18750000*p**2*r**2*s**3 - 162500000*r**3*s**3 + 107812500*p**2*q*s**4 + 460937500*q*r*s**4 - 214843750*p*s**5 + + a[2] = 1950*p**4*q**5 + 14100*p*q**7 - 14350*p**5*q**3*r - 125600*p**2*q**5*r + 27900*p**6*q*r**2 + 402250*p**3*q**3*r**2 + 288250*q**5*r**2 - 436000*p**4*q*r**3 - 1345000*p*q**3*r**3 + 1400000*p**2*q*r**4 + 9450*p**6*q**2*s - 1250*p**3*q**4*s - 465000*q**6*s - 49950*p**7*r*s - 302500*p**4*q**2*r*s + 1718750*p*q**4*r*s + 834000*p**5*r**2*s + 437500*p**2*q**2*r**2*s - 3100000*p**3*r**3*s - 1750000*q**2*r**3*s - 292500*p**5*q*s**2 - 1937500*p**2*q**3*s**2 + 3343750*p**3*q*r*s**2 + 1875000*q**3*r*s**2 + 8125000*p*q*r**2*s**2 - 1406250*p**4*s**3 - 12343750*p*q**2*s**3 + 5312500*p**2*r*s**3 - 43750000*r**2*s**3 + 74218750*q*s**4 + + a[1] = -300*p**5*q**3 - 2150*p**2*q**5 + 1350*p**6*q*r + 21500*p**3*q**3*r + 61500*q**5*r - 42000*p**4*q*r**2 - 290000*p*q**3*r**2 + 300000*p**2*q*r**3 - 4050*p**7*s - 45000*p**4*q**2*s - 125000*p*q**4*s + 108000*p**5*r*s + 643750*p**2*q**2*r*s - 700000*p**3*r**2*s - 375000*q**2*r**2*s - 93750*p**3*q*s**2 - 312500*q**3*s**2 + 1875000*p*q*r*s**2 - 1406250*p**2*s**3 - 9375000*r*s**3 + + a[0] = 1250*p**3*q**3 + 9000*q**5 - 4500*p**4*q*r - 46250*p*q**3*r + 50000*p**2*q*r**2 + 6750*p**5*s + 43750*p**2*q**2*s - 75000*p**3*r*s - 62500*q**2*r*s + 156250*p*q*s**2 - 1562500*s**3 + + return a + + @property + def c(self): + p, q, r, s = self.p, self.q, self.r, self.s + c = [0]*6 + + c[5] = -40*p**5*q**11 - 270*p**2*q**13 + 700*p**6*q**9*r + 5165*p**3*q**11*r + 540*q**13*r - 4230*p**7*q**7*r**2 - 31845*p**4*q**9*r**2 + 20880*p*q**11*r**2 + 9645*p**8*q**5*r**3 + 57615*p**5*q**7*r**3 - 358255*p**2*q**9*r**3 - 1880*p**9*q**3*r**4 + 114020*p**6*q**5*r**4 + 2012190*p**3*q**7*r**4 - 26855*q**9*r**4 - 14400*p**10*q*r**5 - 470400*p**7*q**3*r**5 - 5088640*p**4*q**5*r**5 + 920*p*q**7*r**5 + 332800*p**8*q*r**6 + 5797120*p**5*q**3*r**6 + 1608000*p**2*q**5*r**6 - 2611200*p**6*q*r**7 - 7424000*p**3*q**3*r**7 - 2323200*q**5*r**7 + 8601600*p**4*q*r**8 + 9472000*p*q**3*r**8 - 10240000*p**2*q*r**9 - 3060*p**7*q**8*s - 39085*p**4*q**10*s - 132300*p*q**12*s + 36580*p**8*q**6*r*s + 520185*p**5*q**8*r*s + 1969860*p**2*q**10*r*s - 144045*p**9*q**4*r**2*s - 2438425*p**6*q**6*r**2*s - 10809475*p**3*q**8*r**2*s + 518850*q**10*r**2*s + 182520*p**10*q**2*r**3*s + 4533930*p**7*q**4*r**3*s + 26196770*p**4*q**6*r**3*s - 4542325*p*q**8*r**3*s + 21600*p**11*r**4*s - 2208080*p**8*q**2*r**4*s - 24787960*p**5*q**4*r**4*s + 10813900*p**2*q**6*r**4*s - 499200*p**9*r**5*s + 3827840*p**6*q**2*r**5*s + 9596000*p**3*q**4*r**5*s + 22662000*q**6*r**5*s + 3916800*p**7*r**6*s - 29952000*p**4*q**2*r**6*s - 90800000*p*q**4*r**6*s - 12902400*p**5*r**7*s + 87040000*p**2*q**2*r**7*s + 15360000*p**3*r**8*s + 12800000*q**2*r**8*s - 38070*p**9*q**5*s**2 - 566700*p**6*q**7*s**2 - 2574375*p**3*q**9*s**2 - 1822500*q**11*s**2 + 292815*p**10*q**3*r*s**2 + 5170280*p**7*q**5*r*s**2 + 27918125*p**4*q**7*r*s**2 + 21997500*p*q**9*r*s**2 - 573480*p**11*q*r**2*s**2 - 14566350*p**8*q**3*r**2*s**2 - 104851575*p**5*q**5*r**2*s**2 - 96448750*p**2*q**7*r**2*s**2 + 11001240*p**9*q*r**3*s**2 + 147798600*p**6*q**3*r**3*s**2 + 158632750*p**3*q**5*r**3*s**2 - 78222500*q**7*r**3*s**2 - 62819200*p**7*q*r**4*s**2 - 136160000*p**4*q**3*r**4*s**2 + 317555000*p*q**5*r**4*s**2 + 160224000*p**5*q*r**5*s**2 - 267600000*p**2*q**3*r**5*s**2 - 153600000*p**3*q*r**6*s**2 - 120000000*q**3*r**6*s**2 - 32000000*p*q*r**7*s**2 - 127575*p**11*q**2*s**3 - 2148750*p**8*q**4*s**3 - 13652500*p**5*q**6*s**3 - 19531250*p**2*q**8*s**3 + 495720*p**12*r*s**3 + 11856375*p**9*q**2*r*s**3 + 107807500*p**6*q**4*r*s**3 + 222334375*p**3*q**6*r*s**3 + 105062500*q**8*r*s**3 - 11566800*p**10*r**2*s**3 - 216787500*p**7*q**2*r**2*s**3 - 633437500*p**4*q**4*r**2*s**3 - 504484375*p*q**6*r**2*s**3 + 90918000*p**8*r**3*s**3 + 567080000*p**5*q**2*r**3*s**3 + 692937500*p**2*q**4*r**3*s**3 - 326640000*p**6*r**4*s**3 - 339000000*p**3*q**2*r**4*s**3 + 369250000*q**4*r**4*s**3 + 560000000*p**4*r**5*s**3 + 508000000*p*q**2*r**5*s**3 - 480000000*p**2*r**6*s**3 + 320000000*r**7*s**3 - 455625*p**10*q*s**4 - 27562500*p**7*q**3*s**4 - 120593750*p**4*q**5*s**4 - 60312500*p*q**7*s**4 + 110615625*p**8*q*r*s**4 + 662984375*p**5*q**3*r*s**4 + 528515625*p**2*q**5*r*s**4 - 541687500*p**6*q*r**2*s**4 - 1262343750*p**3*q**3*r**2*s**4 - 466406250*q**5*r**2*s**4 + 633000000*p**4*q*r**3*s**4 - 1264375000*p*q**3*r**3*s**4 + 1085000000*p**2*q*r**4*s**4 - 2700000000*q*r**5*s**4 - 68343750*p**9*s**5 - 478828125*p**6*q**2*s**5 - 355468750*p**3*q**4*s**5 - 11718750*q**6*s**5 + 718031250*p**7*r*s**5 + 1658593750*p**4*q**2*r*s**5 + 2212890625*p*q**4*r*s**5 - 2855625000*p**5*r**2*s**5 - 4273437500*p**2*q**2*r**2*s**5 + 4537500000*p**3*r**3*s**5 + 8031250000*q**2*r**3*s**5 - 1750000000*p*r**4*s**5 + 1353515625*p**5*q*s**6 + 1562500000*p**2*q**3*s**6 - 3964843750*p**3*q*r*s**6 - 7226562500*q**3*r*s**6 + 1953125000*p*q*r**2*s**6 - 1757812500*p**4*s**7 - 3173828125*p*q**2*s**7 + 6445312500*p**2*r*s**7 - 3906250000*r**2*s**7 + 6103515625*q*s**8 + + c[4] = 40*p**6*q**9 + 110*p**3*q**11 - 1080*q**13 - 560*p**7*q**7*r - 1780*p**4*q**9*r + 17370*p*q**11*r + 2850*p**8*q**5*r**2 + 10520*p**5*q**7*r**2 - 115910*p**2*q**9*r**2 - 6090*p**9*q**3*r**3 - 25330*p**6*q**5*r**3 + 448740*p**3*q**7*r**3 + 128230*q**9*r**3 + 4320*p**10*q*r**4 + 16960*p**7*q**3*r**4 - 1143600*p**4*q**5*r**4 - 1410310*p*q**7*r**4 + 3840*p**8*q*r**5 + 1744480*p**5*q**3*r**5 + 5619520*p**2*q**5*r**5 - 1198080*p**6*q*r**6 - 10579200*p**3*q**3*r**6 - 2940800*q**5*r**6 + 8294400*p**4*q*r**7 + 13568000*p*q**3*r**7 - 15360000*p**2*q*r**8 + 840*p**8*q**6*s + 7580*p**5*q**8*s + 24420*p**2*q**10*s - 8100*p**9*q**4*r*s - 94100*p**6*q**6*r*s - 473000*p**3*q**8*r*s - 473400*q**10*r*s + 22680*p**10*q**2*r**2*s + 374370*p**7*q**4*r**2*s + 2888020*p**4*q**6*r**2*s + 5561050*p*q**8*r**2*s - 12960*p**11*r**3*s - 485820*p**8*q**2*r**3*s - 6723440*p**5*q**4*r**3*s - 23561400*p**2*q**6*r**3*s + 190080*p**9*r**4*s + 5894880*p**6*q**2*r**4*s + 50882000*p**3*q**4*r**4*s + 22411500*q**6*r**4*s - 258560*p**7*r**5*s - 46248000*p**4*q**2*r**5*s - 103800000*p*q**4*r**5*s - 3737600*p**5*r**6*s + 119680000*p**2*q**2*r**6*s + 10240000*p**3*r**7*s + 19200000*q**2*r**7*s + 7290*p**10*q**3*s**2 + 117360*p**7*q**5*s**2 + 691250*p**4*q**7*s**2 - 198750*p*q**9*s**2 - 36450*p**11*q*r*s**2 - 854550*p**8*q**3*r*s**2 - 7340700*p**5*q**5*r*s**2 - 2028750*p**2*q**7*r*s**2 + 995490*p**9*q*r**2*s**2 + 18896600*p**6*q**3*r**2*s**2 + 5026500*p**3*q**5*r**2*s**2 - 52272500*q**7*r**2*s**2 - 16636800*p**7*q*r**3*s**2 - 43200000*p**4*q**3*r**3*s**2 + 223426250*p*q**5*r**3*s**2 + 112068000*p**5*q*r**4*s**2 - 177000000*p**2*q**3*r**4*s**2 - 244000000*p**3*q*r**5*s**2 - 156000000*q**3*r**5*s**2 + 43740*p**12*s**3 + 1032750*p**9*q**2*s**3 + 8602500*p**6*q**4*s**3 + 15606250*p**3*q**6*s**3 + 39625000*q**8*s**3 - 1603800*p**10*r*s**3 - 26932500*p**7*q**2*r*s**3 - 19562500*p**4*q**4*r*s**3 - 152000000*p*q**6*r*s**3 + 25555500*p**8*r**2*s**3 + 16230000*p**5*q**2*r**2*s**3 + 42187500*p**2*q**4*r**2*s**3 - 165660000*p**6*r**3*s**3 + 373500000*p**3*q**2*r**3*s**3 + 332937500*q**4*r**3*s**3 + 465000000*p**4*r**4*s**3 + 586000000*p*q**2*r**4*s**3 - 592000000*p**2*r**5*s**3 + 480000000*r**6*s**3 - 1518750*p**8*q*s**4 - 62531250*p**5*q**3*s**4 + 7656250*p**2*q**5*s**4 + 184781250*p**6*q*r*s**4 - 15781250*p**3*q**3*r*s**4 - 135156250*q**5*r*s**4 - 1148250000*p**4*q*r**2*s**4 - 2121406250*p*q**3*r**2*s**4 + 1990000000*p**2*q*r**3*s**4 - 3150000000*q*r**4*s**4 - 2531250*p**7*s**5 + 660937500*p**4*q**2*s**5 + 1339843750*p*q**4*s**5 - 33750000*p**5*r*s**5 - 679687500*p**2*q**2*r*s**5 + 6250000*p**3*r**2*s**5 + 6195312500*q**2*r**2*s**5 + 1125000000*p*r**3*s**5 - 996093750*p**3*q*s**6 - 3125000000*q**3*s**6 - 3222656250*p*q*r*s**6 + 1171875000*p**2*s**7 + 976562500*r*s**7 + + c[3] = 80*p**4*q**9 + 540*p*q**11 - 600*p**5*q**7*r - 4770*p**2*q**9*r + 1230*p**6*q**5*r**2 + 20900*p**3*q**7*r**2 + 47250*q**9*r**2 - 710*p**7*q**3*r**3 - 84950*p**4*q**5*r**3 - 526310*p*q**7*r**3 + 720*p**8*q*r**4 + 216280*p**5*q**3*r**4 + 2068020*p**2*q**5*r**4 - 198080*p**6*q*r**5 - 3703200*p**3*q**3*r**5 - 1423600*q**5*r**5 + 2860800*p**4*q*r**6 + 7056000*p*q**3*r**6 - 8320000*p**2*q*r**7 - 2720*p**6*q**6*s - 46350*p**3*q**8*s - 178200*q**10*s + 25740*p**7*q**4*r*s + 489490*p**4*q**6*r*s + 2152350*p*q**8*r*s - 61560*p**8*q**2*r**2*s - 1568150*p**5*q**4*r**2*s - 9060500*p**2*q**6*r**2*s + 24840*p**9*r**3*s + 1692380*p**6*q**2*r**3*s + 18098250*p**3*q**4*r**3*s + 9387750*q**6*r**3*s - 382560*p**7*r**4*s - 16818000*p**4*q**2*r**4*s - 49325000*p*q**4*r**4*s + 1212800*p**5*r**5*s + 64840000*p**2*q**2*r**5*s - 320000*p**3*r**6*s + 10400000*q**2*r**6*s - 36450*p**8*q**3*s**2 - 588350*p**5*q**5*s**2 - 2156250*p**2*q**7*s**2 + 123930*p**9*q*r*s**2 + 2879700*p**6*q**3*r*s**2 + 12548000*p**3*q**5*r*s**2 - 14445000*q**7*r*s**2 - 3233250*p**7*q*r**2*s**2 - 28485000*p**4*q**3*r**2*s**2 + 72231250*p*q**5*r**2*s**2 + 32093000*p**5*q*r**3*s**2 - 61275000*p**2*q**3*r**3*s**2 - 107500000*p**3*q*r**4*s**2 - 78500000*q**3*r**4*s**2 + 22000000*p*q*r**5*s**2 - 72900*p**10*s**3 - 1215000*p**7*q**2*s**3 - 2937500*p**4*q**4*s**3 + 9156250*p*q**6*s**3 + 2612250*p**8*r*s**3 + 16560000*p**5*q**2*r*s**3 - 75468750*p**2*q**4*r*s**3 - 32737500*p**6*r**2*s**3 + 169062500*p**3*q**2*r**2*s**3 + 121718750*q**4*r**2*s**3 + 160250000*p**4*r**3*s**3 + 219750000*p*q**2*r**3*s**3 - 317000000*p**2*r**4*s**3 + 260000000*r**5*s**3 + 2531250*p**6*q*s**4 + 22500000*p**3*q**3*s**4 + 39843750*q**5*s**4 - 266343750*p**4*q*r*s**4 - 776406250*p*q**3*r*s**4 + 789062500*p**2*q*r**2*s**4 - 1368750000*q*r**3*s**4 + 67500000*p**5*s**5 + 441406250*p**2*q**2*s**5 - 311718750*p**3*r*s**5 + 1785156250*q**2*r*s**5 + 546875000*p*r**2*s**5 - 1269531250*p*q*s**6 + 488281250*s**7 + + c[2] = 120*p**5*q**7 + 810*p**2*q**9 - 1280*p**6*q**5*r - 9160*p**3*q**7*r + 3780*q**9*r + 4530*p**7*q**3*r**2 + 36640*p**4*q**5*r**2 - 45270*p*q**7*r**2 - 5400*p**8*q*r**3 - 60920*p**5*q**3*r**3 + 200050*p**2*q**5*r**3 + 31200*p**6*q*r**4 - 476000*p**3*q**3*r**4 - 378200*q**5*r**4 + 521600*p**4*q*r**5 + 1872000*p*q**3*r**5 - 2240000*p**2*q*r**6 + 1440*p**7*q**4*s + 15310*p**4*q**6*s + 59400*p*q**8*s - 9180*p**8*q**2*r*s - 115240*p**5*q**4*r*s - 589650*p**2*q**6*r*s + 16200*p**9*r**2*s + 316710*p**6*q**2*r**2*s + 2547750*p**3*q**4*r**2*s + 2178000*q**6*r**2*s - 259200*p**7*r**3*s - 4123000*p**4*q**2*r**3*s - 11700000*p*q**4*r**3*s + 937600*p**5*r**4*s + 16340000*p**2*q**2*r**4*s - 640000*p**3*r**5*s + 2800000*q**2*r**5*s - 2430*p**9*q*s**2 - 54450*p**6*q**3*s**2 - 285500*p**3*q**5*s**2 - 2767500*q**7*s**2 + 43200*p**7*q*r*s**2 - 916250*p**4*q**3*r*s**2 + 14482500*p*q**5*r*s**2 + 4806000*p**5*q*r**2*s**2 - 13212500*p**2*q**3*r**2*s**2 - 25400000*p**3*q*r**3*s**2 - 18750000*q**3*r**3*s**2 + 8000000*p*q*r**4*s**2 + 121500*p**8*s**3 + 2058750*p**5*q**2*s**3 - 6656250*p**2*q**4*s**3 - 6716250*p**6*r*s**3 + 24125000*p**3*q**2*r*s**3 + 23875000*q**4*r*s**3 + 43125000*p**4*r**2*s**3 + 45750000*p*q**2*r**2*s**3 - 87500000*p**2*r**3*s**3 + 70000000*r**4*s**3 - 44437500*p**4*q*s**4 - 107968750*p*q**3*s**4 + 159531250*p**2*q*r*s**4 - 284375000*q*r**2*s**4 + 7031250*p**3*s**5 + 265625000*q**2*s**5 + 31250000*p*r*s**5 + + c[1] = 160*p**3*q**7 + 1080*q**9 - 1080*p**4*q**5*r - 8730*p*q**7*r + 1510*p**5*q**3*r**2 + 20420*p**2*q**5*r**2 + 720*p**6*q*r**3 - 23200*p**3*q**3*r**3 - 79900*q**5*r**3 + 35200*p**4*q*r**4 + 404000*p*q**3*r**4 - 480000*p**2*q*r**5 + 960*p**5*q**4*s + 2850*p**2*q**6*s + 540*p**6*q**2*r*s + 63500*p**3*q**4*r*s + 319500*q**6*r*s - 7560*p**7*r**2*s - 253500*p**4*q**2*r**2*s - 1806250*p*q**4*r**2*s + 91200*p**5*r**3*s + 2600000*p**2*q**2*r**3*s - 80000*p**3*r**4*s + 600000*q**2*r**4*s - 4050*p**7*q*s**2 - 120000*p**4*q**3*s**2 - 273750*p*q**5*s**2 + 425250*p**5*q*r*s**2 + 2325000*p**2*q**3*r*s**2 - 5400000*p**3*q*r**2*s**2 - 2875000*q**3*r**2*s**2 + 1500000*p*q*r**3*s**2 - 303750*p**6*s**3 - 843750*p**3*q**2*s**3 - 812500*q**4*s**3 + 5062500*p**4*r*s**3 + 13312500*p*q**2*r*s**3 - 14500000*p**2*r**2*s**3 + 15000000*r**3*s**3 - 3750000*p**2*q*s**4 - 35937500*q*r*s**4 + 11718750*p*s**5 + + c[0] = 80*p**4*q**5 + 540*p*q**7 - 600*p**5*q**3*r - 4770*p**2*q**5*r + 1080*p**6*q*r**2 + 11200*p**3*q**3*r**2 - 12150*q**5*r**2 - 4800*p**4*q*r**3 + 64000*p*q**3*r**3 - 80000*p**2*q*r**4 + 1080*p**6*q**2*s + 13250*p**3*q**4*s + 54000*q**6*s - 3240*p**7*r*s - 56250*p**4*q**2*r*s - 337500*p*q**4*r*s + 43200*p**5*r**2*s + 560000*p**2*q**2*r**2*s - 80000*p**3*r**3*s + 100000*q**2*r**3*s + 6750*p**5*q*s**2 + 225000*p**2*q**3*s**2 - 900000*p**3*q*r*s**2 - 562500*q**3*r*s**2 + 500000*p*q*r**2*s**2 + 843750*p**4*s**3 + 1937500*p*q**2*s**3 - 3000000*p**2*r*s**3 + 2500000*r**2*s**3 - 5468750*q*s**4 + + return c + + @property + def F(self): + p, q, r, s = self.p, self.q, self.r, self.s + F = 4*p**6*q**6 + 59*p**3*q**8 + 216*q**10 - 36*p**7*q**4*r - 623*p**4*q**6*r - 2610*p*q**8*r + 81*p**8*q**2*r**2 + 2015*p**5*q**4*r**2 + 10825*p**2*q**6*r**2 - 1800*p**6*q**2*r**3 - 17500*p**3*q**4*r**3 + 625*q**6*r**3 + 10000*p**4*q**2*r**4 + 108*p**8*q**3*s + 1584*p**5*q**5*s + 5700*p**2*q**7*s - 486*p**9*q*r*s - 9720*p**6*q**3*r*s - 45050*p**3*q**5*r*s - 9000*q**7*r*s + 10800*p**7*q*r**2*s + 92500*p**4*q**3*r**2*s + 32500*p*q**5*r**2*s - 60000*p**5*q*r**3*s - 50000*p**2*q**3*r**3*s + 729*p**10*s**2 + 12150*p**7*q**2*s**2 + 60000*p**4*q**4*s**2 + 93750*p*q**6*s**2 - 18225*p**8*r*s**2 - 175500*p**5*q**2*r*s**2 - 478125*p**2*q**4*r*s**2 + 135000*p**6*r**2*s**2 + 850000*p**3*q**2*r**2*s**2 + 15625*q**4*r**2*s**2 - 250000*p**4*r**3*s**2 + 225000*p**3*q**3*s**3 + 175000*q**5*s**3 - 1012500*p**4*q*r*s**3 - 1187500*p*q**3*r*s**3 + 1250000*p**2*q*r**2*s**3 + 928125*p**5*s**4 + 1875000*p**2*q**2*s**4 - 2812500*p**3*r*s**4 - 390625*q**2*r*s**4 - 9765625*s**6 + return F + + def l0(self, theta): + F = self.F + a = self.a + l0 = Poly(a, x).eval(theta)/F + return l0 + + def T(self, theta, d): + F = self.F + T = [0]*5 + b = self.b + # Note that the order of sublists of the b's has been reversed compared to the paper + T[1] = -Poly(b[1], x).eval(theta)/(2*F) + T[2] = Poly(b[2], x).eval(theta)/(2*d*F) + T[3] = Poly(b[3], x).eval(theta)/(2*F) + T[4] = Poly(b[4], x).eval(theta)/(2*d*F) + return T + + def order(self, theta, d): + F = self.F + o = self.o + order = Poly(o, x).eval(theta)/(d*F) + return N(order) + + def uv(self, theta, d): + c = self.c + u = self.q*Rational(-25, 2) + v = Poly(c, x).eval(theta)/(2*d*self.F) + return N(u), N(v) + + @property + def zeta(self): + return [self.zeta1, self.zeta2, self.zeta3, self.zeta4] diff --git a/phivenv/Lib/site-packages/sympy/polys/polyroots.py b/phivenv/Lib/site-packages/sympy/polys/polyroots.py new file mode 100644 index 0000000000000000000000000000000000000000..4def1312eb5b94a13e511d2d4f9b15f1d51fd63f --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/polyroots.py @@ -0,0 +1,1227 @@ +"""Algorithms for computing symbolic roots of polynomials. """ + + +import math +from functools import reduce + +from sympy.core import S, I, pi +from sympy.core.exprtools import factor_terms +from sympy.core.function import _mexpand +from sympy.core.logic import fuzzy_not +from sympy.core.mul import expand_2arg, Mul +from sympy.core.intfunc import igcd +from sympy.core.numbers import Rational, comp +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.sorting import ordered +from sympy.core.symbol import Dummy, Symbol, symbols +from sympy.core.sympify import sympify +from sympy.functions import exp, im, cos, acos, Piecewise +from sympy.functions.elementary.miscellaneous import root, sqrt +from sympy.ntheory import divisors, isprime, nextprime +from sympy.polys.domains import EX +from sympy.polys.polyerrors import (PolynomialError, GeneratorsNeeded, + DomainError, UnsolvableFactorError) +from sympy.polys.polyquinticconst import PolyQuintic +from sympy.polys.polytools import Poly, cancel, factor, gcd_list, discriminant +from sympy.polys.rationaltools import together +from sympy.polys.specialpolys import cyclotomic_poly +from sympy.utilities import public +from sympy.utilities.misc import filldedent + + + +z = Symbol('z') # importing from abc cause O to be lost as clashing symbol + + +def roots_linear(f): + """Returns a list of roots of a linear polynomial.""" + r = -f.nth(0)/f.nth(1) + dom = f.get_domain() + + if not dom.is_Numerical: + if dom.is_Composite: + r = factor(r) + else: + from sympy.simplify.simplify import simplify + r = simplify(r) + + return [r] + + +def roots_quadratic(f): + """Returns a list of roots of a quadratic polynomial. If the domain is ZZ + then the roots will be sorted with negatives coming before positives. + The ordering will be the same for any numerical coefficients as long as + the assumptions tested are correct, otherwise the ordering will not be + sorted (but will be canonical). + """ + + a, b, c = f.all_coeffs() + dom = f.get_domain() + + def _sqrt(d): + # remove squares from square root since both will be represented + # in the results; a similar thing is happening in roots() but + # must be duplicated here because not all quadratics are binomials + co = [] + other = [] + for di in Mul.make_args(d): + if di.is_Pow and di.exp.is_Integer and di.exp % 2 == 0: + co.append(Pow(di.base, di.exp//2)) + else: + other.append(di) + if co: + d = Mul(*other) + co = Mul(*co) + return co*sqrt(d) + return sqrt(d) + + def _simplify(expr): + if dom.is_Composite: + return factor(expr) + else: + from sympy.simplify.simplify import simplify + return simplify(expr) + + if c is S.Zero: + r0, r1 = S.Zero, -b/a + + if not dom.is_Numerical: + r1 = _simplify(r1) + elif r1.is_negative: + r0, r1 = r1, r0 + elif b is S.Zero: + r = -c/a + if not dom.is_Numerical: + r = _simplify(r) + + R = _sqrt(r) + r0 = -R + r1 = R + else: + d = b**2 - 4*a*c + A = 2*a + B = -b/A + + if not dom.is_Numerical: + d = _simplify(d) + B = _simplify(B) + + D = factor_terms(_sqrt(d)/A) + r0 = B - D + r1 = B + D + if a.is_negative: + r0, r1 = r1, r0 + elif not dom.is_Numerical: + r0, r1 = [expand_2arg(i) for i in (r0, r1)] + + return [r0, r1] + + +def roots_cubic(f, trig=False): + """Returns a list of roots of a cubic polynomial. + + References + ========== + [1] https://en.wikipedia.org/wiki/Cubic_function, General formula for roots, + (accessed November 17, 2014). + """ + if trig: + a, b, c, d = f.all_coeffs() + p = (3*a*c - b**2)/(3*a**2) + q = (2*b**3 - 9*a*b*c + 27*a**2*d)/(27*a**3) + D = 18*a*b*c*d - 4*b**3*d + b**2*c**2 - 4*a*c**3 - 27*a**2*d**2 + if (D > 0) == True: + rv = [] + for k in range(3): + rv.append(2*sqrt(-p/3)*cos(acos(q/p*sqrt(-3/p)*Rational(3, 2))/3 - k*pi*Rational(2, 3))) + return [i - b/3/a for i in rv] + + # a*x**3 + b*x**2 + c*x + d -> x**3 + a*x**2 + b*x + c + _, a, b, c = f.monic().all_coeffs() + + if c is S.Zero: + x1, x2 = roots([1, a, b], multiple=True) + return [x1, S.Zero, x2] + + # x**3 + a*x**2 + b*x + c -> u**3 + p*u + q + p = b - a**2/3 + q = c - a*b/3 + 2*a**3/27 + + pon3 = p/3 + aon3 = a/3 + + u1 = None + if p is S.Zero: + if q is S.Zero: + return [-aon3]*3 + u1 = -root(q, 3) if q.is_positive else root(-q, 3) + elif q is S.Zero: + y1, y2 = roots([1, 0, p], multiple=True) + return [tmp - aon3 for tmp in [y1, S.Zero, y2]] + elif q.is_real and q.is_negative: + u1 = -root(-q/2 + sqrt(q**2/4 + pon3**3), 3) + + coeff = I*sqrt(3)/2 + if u1 is None: + u1 = S.One + u2 = Rational(-1, 2) + coeff + u3 = Rational(-1, 2) - coeff + b, c, d = a, b, c # a, b, c, d = S.One, a, b, c + D0 = b**2 - 3*c # b**2 - 3*a*c + D1 = 2*b**3 - 9*b*c + 27*d # 2*b**3 - 9*a*b*c + 27*a**2*d + C = root((D1 + sqrt(D1**2 - 4*D0**3))/2, 3) + return [-(b + uk*C + D0/C/uk)/3 for uk in [u1, u2, u3]] # -(b + uk*C + D0/C/uk)/3/a + + u2 = u1*(Rational(-1, 2) + coeff) + u3 = u1*(Rational(-1, 2) - coeff) + + if p is S.Zero: + return [u1 - aon3, u2 - aon3, u3 - aon3] + + soln = [ + -u1 + pon3/u1 - aon3, + -u2 + pon3/u2 - aon3, + -u3 + pon3/u3 - aon3 + ] + + return soln + +def _roots_quartic_euler(p, q, r, a): + """ + Descartes-Euler solution of the quartic equation + + Parameters + ========== + + p, q, r: coefficients of ``x**4 + p*x**2 + q*x + r`` + a: shift of the roots + + Notes + ===== + + This is a helper function for ``roots_quartic``. + + Look for solutions of the form :: + + ``x1 = sqrt(R) - sqrt(A + B*sqrt(R))`` + ``x2 = -sqrt(R) - sqrt(A - B*sqrt(R))`` + ``x3 = -sqrt(R) + sqrt(A - B*sqrt(R))`` + ``x4 = sqrt(R) + sqrt(A + B*sqrt(R))`` + + To satisfy the quartic equation one must have + ``p = -2*(R + A); q = -4*B*R; r = (R - A)**2 - B**2*R`` + so that ``R`` must satisfy the Descartes-Euler resolvent equation + ``64*R**3 + 32*p*R**2 + (4*p**2 - 16*r)*R - q**2 = 0`` + + If the resolvent does not have a rational solution, return None; + in that case it is likely that the Ferrari method gives a simpler + solution. + + Examples + ======== + + >>> from sympy import S + >>> from sympy.polys.polyroots import _roots_quartic_euler + >>> p, q, r = -S(64)/5, -S(512)/125, -S(1024)/3125 + >>> _roots_quartic_euler(p, q, r, S(0))[0] + -sqrt(32*sqrt(5)/125 + 16/5) + 4*sqrt(5)/5 + """ + # solve the resolvent equation + x = Dummy('x') + eq = 64*x**3 + 32*p*x**2 + (4*p**2 - 16*r)*x - q**2 + xsols = list(roots(Poly(eq, x), cubics=False).keys()) + xsols = [sol for sol in xsols if sol.is_rational and sol.is_nonzero] + if not xsols: + return None + R = max(xsols) + c1 = sqrt(R) + B = -q*c1/(4*R) + A = -R - p/2 + c2 = sqrt(A + B) + c3 = sqrt(A - B) + return [c1 - c2 - a, -c1 - c3 - a, -c1 + c3 - a, c1 + c2 - a] + + +def roots_quartic(f): + r""" + Returns a list of roots of a quartic polynomial. + + There are many references for solving quartic expressions available [1-5]. + This reviewer has found that many of them require one to select from among + 2 or more possible sets of solutions and that some solutions work when one + is searching for real roots but do not work when searching for complex roots + (though this is not always stated clearly). The following routine has been + tested and found to be correct for 0, 2 or 4 complex roots. + + The quasisymmetric case solution [6] looks for quartics that have the form + `x**4 + A*x**3 + B*x**2 + C*x + D = 0` where `(C/A)**2 = D`. + + Although no general solution that is always applicable for all + coefficients is known to this reviewer, certain conditions are tested + to determine the simplest 4 expressions that can be returned: + + 1) `f = c + a*(a**2/8 - b/2) == 0` + 2) `g = d - a*(a*(3*a**2/256 - b/16) + c/4) = 0` + 3) if `f != 0` and `g != 0` and `p = -d + a*c/4 - b**2/12` then + a) `p == 0` + b) `p != 0` + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.polys.polyroots import roots_quartic + + >>> r = roots_quartic(Poly('x**4-6*x**3+17*x**2-26*x+20')) + + >>> # 4 complex roots: 1+-I*sqrt(3), 2+-I + >>> sorted(str(tmp.evalf(n=2)) for tmp in r) + ['1.0 + 1.7*I', '1.0 - 1.7*I', '2.0 + 1.0*I', '2.0 - 1.0*I'] + + References + ========== + + 1. http://mathforum.org/dr.math/faq/faq.cubic.equations.html + 2. https://en.wikipedia.org/wiki/Quartic_function#Summary_of_Ferrari.27s_method + 3. https://planetmath.org/encyclopedia/GaloisTheoreticDerivationOfTheQuarticFormula.html + 4. https://people.bath.ac.uk/masjhd/JHD-CA.pdf + 5. http://www.albmath.org/files/Math_5713.pdf + 6. https://web.archive.org/web/20171002081448/http://www.statemaster.com/encyclopedia/Quartic-equation + 7. https://eqworld.ipmnet.ru/en/solutions/ae/ae0108.pdf + """ + _, a, b, c, d = f.monic().all_coeffs() + + if not d: + return [S.Zero] + roots([1, a, b, c], multiple=True) + elif (c/a)**2 == d: + x, m = f.gen, c/a + + g = Poly(x**2 + a*x + b - 2*m, x) + + z1, z2 = roots_quadratic(g) + + h1 = Poly(x**2 - z1*x + m, x) + h2 = Poly(x**2 - z2*x + m, x) + + r1 = roots_quadratic(h1) + r2 = roots_quadratic(h2) + + return r1 + r2 + else: + a2 = a**2 + e = b - 3*a2/8 + f = _mexpand(c + a*(a2/8 - b/2)) + aon4 = a/4 + g = _mexpand(d - aon4*(a*(3*a2/64 - b/4) + c)) + + if f.is_zero: + y1, y2 = [sqrt(tmp) for tmp in + roots([1, e, g], multiple=True)] + return [tmp - aon4 for tmp in [-y1, -y2, y1, y2]] + if g.is_zero: + y = [S.Zero] + roots([1, 0, e, f], multiple=True) + return [tmp - aon4 for tmp in y] + else: + # Descartes-Euler method, see [7] + sols = _roots_quartic_euler(e, f, g, aon4) + if sols: + return sols + # Ferrari method, see [1, 2] + p = -e**2/12 - g + q = -e**3/108 + e*g/3 - f**2/8 + TH = Rational(1, 3) + + def _ans(y): + w = sqrt(e + 2*y) + arg1 = 3*e + 2*y + arg2 = 2*f/w + ans = [] + for s in [-1, 1]: + root = sqrt(-(arg1 + s*arg2)) + for t in [-1, 1]: + ans.append((s*w - t*root)/2 - aon4) + return ans + + # whether a Piecewise is returned or not + # depends on knowing p, so try to put + # in a simple form + p = _mexpand(p) + + + # p == 0 case + y1 = e*Rational(-5, 6) - q**TH + if p.is_zero: + return _ans(y1) + + # if p != 0 then u below is not 0 + root = sqrt(q**2/4 + p**3/27) + r = -q/2 + root # or -q/2 - root + u = r**TH # primary root of solve(x**3 - r, x) + y2 = e*Rational(-5, 6) + u - p/u/3 + if fuzzy_not(p.is_zero): + return _ans(y2) + + # sort it out once they know the values of the coefficients + return [Piecewise((a1, Eq(p, 0)), (a2, True)) + for a1, a2 in zip(_ans(y1), _ans(y2))] + + +def roots_binomial(f): + """Returns a list of roots of a binomial polynomial. If the domain is ZZ + then the roots will be sorted with negatives coming before positives. + The ordering will be the same for any numerical coefficients as long as + the assumptions tested are correct, otherwise the ordering will not be + sorted (but will be canonical). + """ + n = f.degree() + + a, b = f.nth(n), f.nth(0) + base = -cancel(b/a) + alpha = root(base, n) + + if alpha.is_number: + alpha = alpha.expand(complex=True) + + # define some parameters that will allow us to order the roots. + # If the domain is ZZ this is guaranteed to return roots sorted + # with reals before non-real roots and non-real sorted according + # to real part and imaginary part, e.g. -1, 1, -1 + I, 2 - I + neg = base.is_negative + even = n % 2 == 0 + if neg: + if even == True and (base + 1).is_positive: + big = True + else: + big = False + + # get the indices in the right order so the computed + # roots will be sorted when the domain is ZZ + ks = [] + imax = n//2 + if even: + ks.append(imax) + imax -= 1 + if not neg: + ks.append(0) + for i in range(imax, 0, -1): + if neg: + ks.extend([i, -i]) + else: + ks.extend([-i, i]) + if neg: + ks.append(0) + if big: + for i in range(0, len(ks), 2): + pair = ks[i: i + 2] + pair = list(reversed(pair)) + + # compute the roots + roots, d = [], 2*I*pi/n + for k in ks: + zeta = exp(k*d).expand(complex=True) + roots.append((alpha*zeta).expand(power_base=False)) + + return roots + + +def _inv_totient_estimate(m): + """ + Find ``(L, U)`` such that ``L <= phi^-1(m) <= U``. + + Examples + ======== + + >>> from sympy.polys.polyroots import _inv_totient_estimate + + >>> _inv_totient_estimate(192) + (192, 840) + >>> _inv_totient_estimate(400) + (400, 1750) + + """ + primes = [ d + 1 for d in divisors(m) if isprime(d + 1) ] + + a, b = 1, 1 + + for p in primes: + a *= p + b *= p - 1 + + L = m + U = int(math.ceil(m*(float(a)/b))) + + P = p = 2 + primes = [] + + while P <= U: + p = nextprime(p) + primes.append(p) + P *= p + + P //= p + b = 1 + + for p in primes[:-1]: + b *= p - 1 + + U = int(math.ceil(m*(float(P)/b))) + + return L, U + + +def roots_cyclotomic(f, factor=False): + """Compute roots of cyclotomic polynomials. """ + L, U = _inv_totient_estimate(f.degree()) + + for n in range(L, U + 1): + g = cyclotomic_poly(n, f.gen, polys=True) + + if f.expr == g.expr: + break + else: # pragma: no cover + raise RuntimeError("failed to find index of a cyclotomic polynomial") + + roots = [] + + if not factor: + # get the indices in the right order so the computed + # roots will be sorted + h = n//2 + ks = [i for i in range(1, n + 1) if igcd(i, n) == 1] + ks.sort(key=lambda x: (x, -1) if x <= h else (abs(x - n), 1)) + d = 2*I*pi/n + for k in reversed(ks): + roots.append(exp(k*d).expand(complex=True)) + else: + g = Poly(f, extension=root(-1, n)) + + for h, _ in ordered(g.factor_list()[1]): + roots.append(-h.TC()) + + return roots + + +def roots_quintic(f): + """ + Calculate exact roots of a solvable irreducible quintic with rational coefficients. + Return an empty list if the quintic is reducible or not solvable. + """ + result = [] + + coeff_5, coeff_4, p_, q_, r_, s_ = f.all_coeffs() + + if not all(coeff.is_Rational for coeff in (coeff_5, coeff_4, p_, q_, r_, s_)): + return result + + if coeff_5 != 1: + f = Poly(f / coeff_5) + _, coeff_4, p_, q_, r_, s_ = f.all_coeffs() + + # Cancel coeff_4 to form x^5 + px^3 + qx^2 + rx + s + if coeff_4: + p = p_ - 2*coeff_4*coeff_4/5 + q = q_ - 3*coeff_4*p_/5 + 4*coeff_4**3/25 + r = r_ - 2*coeff_4*q_/5 + 3*coeff_4**2*p_/25 - 3*coeff_4**4/125 + s = s_ - coeff_4*r_/5 + coeff_4**2*q_/25 - coeff_4**3*p_/125 + 4*coeff_4**5/3125 + x = f.gen + f = Poly(x**5 + p*x**3 + q*x**2 + r*x + s) + else: + p, q, r, s = p_, q_, r_, s_ + + quintic = PolyQuintic(f) + + # Eqn standardized. Algo for solving starts here + if not f.is_irreducible: + return result + f20 = quintic.f20 + # Check if f20 has linear factors over domain Z + if f20.is_irreducible: + return result + # Now, we know that f is solvable + for _factor in f20.factor_list()[1]: + if _factor[0].is_linear: + theta = _factor[0].root(0) + break + d = discriminant(f) + delta = sqrt(d) + # zeta = a fifth root of unity + zeta1, zeta2, zeta3, zeta4 = quintic.zeta + T = quintic.T(theta, d) + tol = S(1e-10) + alpha = T[1] + T[2]*delta + alpha_bar = T[1] - T[2]*delta + beta = T[3] + T[4]*delta + beta_bar = T[3] - T[4]*delta + + disc = alpha**2 - 4*beta + disc_bar = alpha_bar**2 - 4*beta_bar + + l0 = quintic.l0(theta) + Stwo = S(2) + l1 = _quintic_simplify((-alpha + sqrt(disc)) / Stwo) + l4 = _quintic_simplify((-alpha - sqrt(disc)) / Stwo) + + l2 = _quintic_simplify((-alpha_bar + sqrt(disc_bar)) / Stwo) + l3 = _quintic_simplify((-alpha_bar - sqrt(disc_bar)) / Stwo) + + order = quintic.order(theta, d) + test = (order*delta.n()) - ( (l1.n() - l4.n())*(l2.n() - l3.n()) ) + # Comparing floats + if not comp(test, 0, tol): + l2, l3 = l3, l2 + + # Now we have correct order of l's + R1 = l0 + l1*zeta1 + l2*zeta2 + l3*zeta3 + l4*zeta4 + R2 = l0 + l3*zeta1 + l1*zeta2 + l4*zeta3 + l2*zeta4 + R3 = l0 + l2*zeta1 + l4*zeta2 + l1*zeta3 + l3*zeta4 + R4 = l0 + l4*zeta1 + l3*zeta2 + l2*zeta3 + l1*zeta4 + + Res = [None, [None]*5, [None]*5, [None]*5, [None]*5] + Res_n = [None, [None]*5, [None]*5, [None]*5, [None]*5] + + # Simplifying improves performance a lot for exact expressions + R1 = _quintic_simplify(R1) + R2 = _quintic_simplify(R2) + R3 = _quintic_simplify(R3) + R4 = _quintic_simplify(R4) + + # hard-coded results for [factor(i) for i in _vsolve(x**5 - a - I*b, x)] + x0 = z**(S(1)/5) + x1 = sqrt(2) + x2 = sqrt(5) + x3 = sqrt(5 - x2) + x4 = I*x2 + x5 = x4 + I + x6 = I*x0/4 + x7 = x1*sqrt(x2 + 5) + sol = [x0, -x6*(x1*x3 - x5), x6*(x1*x3 + x5), -x6*(x4 + x7 - I), x6*(-x4 + x7 + I)] + + R1 = R1.as_real_imag() + R2 = R2.as_real_imag() + R3 = R3.as_real_imag() + R4 = R4.as_real_imag() + + for i, s in enumerate(sol): + Res[1][i] = _quintic_simplify(s.xreplace({z: R1[0] + I*R1[1]})) + Res[2][i] = _quintic_simplify(s.xreplace({z: R2[0] + I*R2[1]})) + Res[3][i] = _quintic_simplify(s.xreplace({z: R3[0] + I*R3[1]})) + Res[4][i] = _quintic_simplify(s.xreplace({z: R4[0] + I*R4[1]})) + + for i in range(1, 5): + for j in range(5): + Res_n[i][j] = Res[i][j].n() + Res[i][j] = _quintic_simplify(Res[i][j]) + r1 = Res[1][0] + r1_n = Res_n[1][0] + + for i in range(5): + if comp(im(r1_n*Res_n[4][i]), 0, tol): + r4 = Res[4][i] + break + + # Now we have various Res values. Each will be a list of five + # values. We have to pick one r value from those five for each Res + u, v = quintic.uv(theta, d) + testplus = (u + v*delta*sqrt(5)).n() + testminus = (u - v*delta*sqrt(5)).n() + + # Evaluated numbers suffixed with _n + # We will use evaluated numbers for calculation. Much faster. + r4_n = r4.n() + r2 = r3 = None + + for i in range(5): + r2temp_n = Res_n[2][i] + for j in range(5): + # Again storing away the exact number and using + # evaluated numbers in computations + r3temp_n = Res_n[3][j] + if (comp((r1_n*r2temp_n**2 + r4_n*r3temp_n**2 - testplus).n(), 0, tol) and + comp((r3temp_n*r1_n**2 + r2temp_n*r4_n**2 - testminus).n(), 0, tol)): + r2 = Res[2][i] + r3 = Res[3][j] + break + if r2 is not None: + break + else: + return [] # fall back to normal solve + + # Now, we have r's so we can get roots + x1 = (r1 + r2 + r3 + r4)/5 + x2 = (r1*zeta4 + r2*zeta3 + r3*zeta2 + r4*zeta1)/5 + x3 = (r1*zeta3 + r2*zeta1 + r3*zeta4 + r4*zeta2)/5 + x4 = (r1*zeta2 + r2*zeta4 + r3*zeta1 + r4*zeta3)/5 + x5 = (r1*zeta1 + r2*zeta2 + r3*zeta3 + r4*zeta4)/5 + result = [x1, x2, x3, x4, x5] + + # Now check if solutions are distinct + + saw = set() + for r in result: + r = r.n(2) + if r in saw: + # Roots were identical. Abort, return [] + # and fall back to usual solve + return [] + saw.add(r) + + # Restore to original equation where coeff_4 is nonzero + if coeff_4: + result = [x - coeff_4 / 5 for x in result] + return result + + +def _quintic_simplify(expr): + from sympy.simplify.simplify import powsimp + expr = powsimp(expr) + expr = cancel(expr) + return together(expr) + + +def _integer_basis(poly): + """Compute coefficient basis for a polynomial over integers. + + Returns the integer ``div`` such that substituting ``x = div*y`` + ``p(x) = m*q(y)`` where the coefficients of ``q`` are smaller + than those of ``p``. + + For example ``x**5 + 512*x + 1024 = 0`` + with ``div = 4`` becomes ``y**5 + 2*y + 1 = 0`` + + Returns the integer ``div`` or ``None`` if there is no possible scaling. + + Examples + ======== + + >>> from sympy.polys import Poly + >>> from sympy.abc import x + >>> from sympy.polys.polyroots import _integer_basis + >>> p = Poly(x**5 + 512*x + 1024, x, domain='ZZ') + >>> _integer_basis(p) + 4 + """ + monoms, coeffs = list(zip(*poly.terms())) + + monoms, = list(zip(*monoms)) + coeffs = list(map(abs, coeffs)) + + if coeffs[0] < coeffs[-1]: + coeffs = list(reversed(coeffs)) + n = monoms[0] + monoms = [n - i for i in reversed(monoms)] + else: + return None + + monoms = monoms[:-1] + coeffs = coeffs[:-1] + + # Special case for two-term polynominals + if len(monoms) == 1: + r = Pow(coeffs[0], S.One/monoms[0]) + if r.is_Integer: + return int(r) + else: + return None + + divs = reversed(divisors(gcd_list(coeffs))[1:]) + + try: + div = next(divs) + except StopIteration: + return None + + while True: + for monom, coeff in zip(monoms, coeffs): + if coeff % div**monom != 0: + try: + div = next(divs) + except StopIteration: + return None + else: + break + else: + return div + + +def preprocess_roots(poly): + """Try to get rid of symbolic coefficients from ``poly``. """ + coeff = S.One + + poly_func = poly.func + try: + _, poly = poly.clear_denoms(convert=True) + except DomainError: + return coeff, poly + + poly = poly.primitive()[1] + poly = poly.retract() + + # TODO: This is fragile. Figure out how to make this independent of construct_domain(). + if poly.get_domain().is_Poly and all(c.is_term for c in poly.rep.coeffs()): + poly = poly.inject() + + strips = list(zip(*poly.monoms())) + gens = list(poly.gens[1:]) + + base, strips = strips[0], strips[1:] + + for gen, strip in zip(list(gens), strips): + reverse = False + + if strip[0] < strip[-1]: + strip = reversed(strip) + reverse = True + + ratio = None + + for a, b in zip(base, strip): + if not a and not b: + continue + elif not a or not b: + break + elif b % a != 0: + break + else: + _ratio = b // a + + if ratio is None: + ratio = _ratio + elif ratio != _ratio: + break + else: + if reverse: + ratio = -ratio + + poly = poly.eval(gen, 1) + coeff *= gen**(-ratio) + gens.remove(gen) + + if gens: + poly = poly.eject(*gens) + + if poly.is_univariate and poly.get_domain().is_ZZ: + basis = _integer_basis(poly) + + if basis is not None: + n = poly.degree() + + def func(k, coeff): + return coeff//basis**(n - k[0]) + + poly = poly.termwise(func) + coeff *= basis + + if not isinstance(poly, poly_func): + poly = poly_func(poly) + return coeff, poly + + +@public +def roots(f, *gens, + auto=True, + cubics=True, + trig=False, + quartics=True, + quintics=False, + multiple=False, + filter=None, + predicate=None, + strict=False, + **flags): + """ + Computes symbolic roots of a univariate polynomial. + + Given a univariate polynomial f with symbolic coefficients (or + a list of the polynomial's coefficients), returns a dictionary + with its roots and their multiplicities. + + Only roots expressible via radicals will be returned. To get + a complete set of roots use RootOf class or numerical methods + instead. By default cubic and quartic formulas are used in + the algorithm. To disable them because of unreadable output + set ``cubics=False`` or ``quartics=False`` respectively. If cubic + roots are real but are expressed in terms of complex numbers + (casus irreducibilis [1]) the ``trig`` flag can be set to True to + have the solutions returned in terms of cosine and inverse cosine + functions. + + To get roots from a specific domain set the ``filter`` flag with + one of the following specifiers: Z, Q, R, I, C. By default all + roots are returned (this is equivalent to setting ``filter='C'``). + + By default a dictionary is returned giving a compact result in + case of multiple roots. However to get a list containing all + those roots set the ``multiple`` flag to True; the list will + have identical roots appearing next to each other in the result. + (For a given Poly, the all_roots method will give the roots in + sorted numerical order.) + + If the ``strict`` flag is True, ``UnsolvableFactorError`` will be + raised if the roots found are known to be incomplete (because + some roots are not expressible in radicals). + + Examples + ======== + + >>> from sympy import Poly, roots, degree + >>> from sympy.abc import x, y + + >>> roots(x**2 - 1, x) + {-1: 1, 1: 1} + + >>> p = Poly(x**2-1, x) + >>> roots(p) + {-1: 1, 1: 1} + + >>> p = Poly(x**2-y, x, y) + + >>> roots(Poly(p, x)) + {-sqrt(y): 1, sqrt(y): 1} + + >>> roots(x**2 - y, x) + {-sqrt(y): 1, sqrt(y): 1} + + >>> roots([1, 0, -1]) + {-1: 1, 1: 1} + + ``roots`` will only return roots expressible in radicals. If + the given polynomial has some or all of its roots inexpressible in + radicals, the result of ``roots`` will be incomplete or empty + respectively. + + Example where result is incomplete: + + >>> roots((x-1)*(x**5-x+1), x) + {1: 1} + + In this case, the polynomial has an unsolvable quintic factor + whose roots cannot be expressed by radicals. The polynomial has a + rational root (due to the factor `(x-1)`), which is returned since + ``roots`` always finds all rational roots. + + Example where result is empty: + + >>> roots(x**7-3*x**2+1, x) + {} + + Here, the polynomial has no roots expressible in radicals, so + ``roots`` returns an empty dictionary. + + The result produced by ``roots`` is complete if and only if the + sum of the multiplicity of each root is equal to the degree of + the polynomial. If strict=True, UnsolvableFactorError will be + raised if the result is incomplete. + + The result can be be checked for completeness as follows: + + >>> f = x**3-2*x**2+1 + >>> sum(roots(f, x).values()) == degree(f, x) + True + >>> f = (x-1)*(x**5-x+1) + >>> sum(roots(f, x).values()) == degree(f, x) + False + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Cubic_equation#Trigonometric_and_hyperbolic_solutions + + """ + from sympy.polys.polytools import to_rational_coeffs + flags = dict(flags) + + if isinstance(f, list): + if gens: + raise ValueError('redundant generators given') + + x = Dummy('x') + + poly, i = {}, len(f) - 1 + + for coeff in f: + poly[i], i = sympify(coeff), i - 1 + + f = Poly(poly, x, field=True) + else: + try: + F = Poly(f, *gens, **flags) + if not isinstance(f, Poly) and not F.gen.is_Symbol: + raise PolynomialError("generator must be a Symbol") + f = F + except GeneratorsNeeded: + if multiple: + return [] + else: + return {} + else: + n = f.degree() + if f.length() == 2 and n > 2: + # check for foo**n in constant if dep is c*gen**m + con, dep = f.as_expr().as_independent(*f.gens) + fcon = -(-con).factor() + if fcon != con: + con = fcon + bases = [] + for i in Mul.make_args(con): + if i.is_Pow: + b, e = i.as_base_exp() + if e.is_Integer and b.is_Add: + bases.append((b, Dummy(positive=True))) + if bases: + rv = roots(Poly((dep + con).xreplace(dict(bases)), + *f.gens), *F.gens, + auto=auto, + cubics=cubics, + trig=trig, + quartics=quartics, + quintics=quintics, + multiple=multiple, + filter=filter, + predicate=predicate, + **flags) + return {factor_terms(k.xreplace( + {v: k for k, v in bases}) + ): v for k, v in rv.items()} + + if f.is_multivariate: + raise PolynomialError('multivariate polynomials are not supported') + + def _update_dict(result, zeros, currentroot, k): + if currentroot == S.Zero: + if S.Zero in zeros: + zeros[S.Zero] += k + else: + zeros[S.Zero] = k + if currentroot in result: + result[currentroot] += k + else: + result[currentroot] = k + + def _try_decompose(f): + """Find roots using functional decomposition. """ + factors, roots = f.decompose(), [] + + for currentroot in _try_heuristics(factors[0]): + roots.append(currentroot) + + for currentfactor in factors[1:]: + previous, roots = list(roots), [] + + for currentroot in previous: + g = currentfactor - Poly(currentroot, f.gen) + + for currentroot in _try_heuristics(g): + roots.append(currentroot) + + return roots + + def _try_heuristics(f): + """Find roots using formulas and some tricks. """ + if f.is_ground: + return [] + if f.is_monomial: + return [S.Zero]*f.degree() + + if f.length() == 2: + if f.degree() == 1: + return list(map(cancel, roots_linear(f))) + else: + return roots_binomial(f) + + result = [] + + for i in [-1, 1]: + if not f.eval(i): + f = f.quo(Poly(f.gen - i, f.gen)) + result.append(i) + break + + n = f.degree() + + if n == 1: + result += list(map(cancel, roots_linear(f))) + elif n == 2: + result += list(map(cancel, roots_quadratic(f))) + elif f.is_cyclotomic: + result += roots_cyclotomic(f) + elif n == 3 and cubics: + result += roots_cubic(f, trig=trig) + elif n == 4 and quartics: + result += roots_quartic(f) + elif n == 5 and quintics: + result += roots_quintic(f) + + return result + + # Convert the generators to symbols + dumgens = symbols('x:%d' % len(f.gens), cls=Dummy) + f = f.per(f.rep, dumgens) + + (k,), f = f.terms_gcd() + + if not k: + zeros = {} + else: + zeros = {S.Zero: k} + + coeff, f = preprocess_roots(f) + + if auto and f.get_domain().is_Ring: + f = f.to_field() + + # Use EX instead of ZZ_I or QQ_I + if f.get_domain().is_QQ_I: + f = f.per(f.rep.convert(EX)) + + rescale_x = None + translate_x = None + + result = {} + + if not f.is_ground: + dom = f.get_domain() + if not dom.is_Exact and dom.is_Numerical: + for r in f.nroots(): + _update_dict(result, zeros, r, 1) + elif f.degree() == 1: + _update_dict(result, zeros, roots_linear(f)[0], 1) + elif f.length() == 2: + roots_fun = roots_quadratic if f.degree() == 2 else roots_binomial + for r in roots_fun(f): + _update_dict(result, zeros, r, 1) + else: + _, factors = Poly(f.as_expr()).factor_list() + if len(factors) == 1 and f.degree() == 2: + for r in roots_quadratic(f): + _update_dict(result, zeros, r, 1) + else: + if len(factors) == 1 and factors[0][1] == 1: + if f.get_domain().is_EX: + res = to_rational_coeffs(f) + if res: + if res[0] is None: + translate_x, f = res[2:] + else: + rescale_x, f = res[1], res[-1] + result = roots(f) + if not result: + for currentroot in _try_decompose(f): + _update_dict(result, zeros, currentroot, 1) + else: + for r in _try_heuristics(f): + _update_dict(result, zeros, r, 1) + else: + for currentroot in _try_decompose(f): + _update_dict(result, zeros, currentroot, 1) + else: + for currentfactor, k in factors: + for r in _try_heuristics(Poly(currentfactor, f.gen, field=True)): + _update_dict(result, zeros, r, k) + + if coeff is not S.One: + _result, result, = result, {} + + for currentroot, k in _result.items(): + result[coeff*currentroot] = k + + if filter not in [None, 'C']: + handlers = { + 'Z': lambda r: r.is_Integer, + 'Q': lambda r: r.is_Rational, + 'R': lambda r: all(a.is_real for a in r.as_numer_denom()), + 'I': lambda r: r.is_imaginary, + } + + try: + query = handlers[filter] + except KeyError: + raise ValueError("Invalid filter: %s" % filter) + + for zero in dict(result).keys(): + if not query(zero): + del result[zero] + + if predicate is not None: + for zero in dict(result).keys(): + if not predicate(zero): + del result[zero] + if rescale_x: + result1 = {} + for k, v in result.items(): + result1[k*rescale_x] = v + result = result1 + if translate_x: + result1 = {} + for k, v in result.items(): + result1[k + translate_x] = v + result = result1 + + # adding zero roots after non-trivial roots have been translated + result.update(zeros) + + if strict and sum(result.values()) < f.degree(): + raise UnsolvableFactorError(filldedent(''' + Strict mode: some factors cannot be solved in radicals, so + a complete list of solutions cannot be returned. Call + roots with strict=False to get solutions expressible in + radicals (if there are any). + ''')) + + if not multiple: + return result + else: + zeros = [] + + for zero in ordered(result): + zeros.extend([zero]*result[zero]) + + return zeros + + +def root_factors(f, *gens, filter=None, **args): + """ + Returns all factors of a univariate polynomial. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.polys.polyroots import root_factors + + >>> root_factors(x**2 - y, x) + [x - sqrt(y), x + sqrt(y)] + + """ + args = dict(args) + + F = Poly(f, *gens, **args) + + if not F.is_Poly: + return [f] + + if F.is_multivariate: + raise ValueError('multivariate polynomials are not supported') + + x = F.gens[0] + + zeros = roots(F, filter=filter) + + if not zeros: + factors = [F] + else: + factors, N = [], 0 + + for r, n in ordered(zeros.items()): + factors, N = factors + [Poly(x - r, x)]*n, N + n + + if N < F.degree(): + G = reduce(lambda p, q: p*q, factors) + factors.append(F.quo(G)) + + if not isinstance(f, Poly): + factors = [ f.as_expr() for f in factors ] + + return factors diff --git a/phivenv/Lib/site-packages/sympy/polys/polytools.py b/phivenv/Lib/site-packages/sympy/polys/polytools.py new file mode 100644 index 0000000000000000000000000000000000000000..11b9dd3435f8dc68ea3b0578df9fccfb07dd0f4c --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/polytools.py @@ -0,0 +1,7960 @@ +"""User-friendly public interface to polynomial functions. """ + +from __future__ import annotations + +from functools import wraps, reduce +from operator import mul +from typing import Optional +from collections import Counter, defaultdict + +from sympy.core import ( + S, Expr, Add, Tuple +) +from sympy.core.basic import Basic +from sympy.core.decorators import _sympifyit +from sympy.core.exprtools import Factors, factor_nc, factor_terms +from sympy.core.evalf import ( + pure_complex, evalf, fastlog, _evalf_with_bounded_error, quad_to_mpmath) +from sympy.core.function import Derivative +from sympy.core.mul import Mul, _keep_coeff +from sympy.core.intfunc import ilcm +from sympy.core.numbers import I, Integer, equal_valued +from sympy.core.relational import Relational, Equality +from sympy.core.sorting import ordered +from sympy.core.symbol import Dummy, Symbol +from sympy.core.sympify import sympify, _sympify +from sympy.core.traversal import preorder_traversal, bottom_up +from sympy.logic.boolalg import BooleanAtom +from sympy.polys import polyoptions as options +from sympy.polys.constructor import construct_domain +from sympy.polys.domains import FF, QQ, ZZ +from sympy.polys.domains.domainelement import DomainElement +from sympy.polys.fglmtools import matrix_fglm +from sympy.polys.groebnertools import groebner as _groebner +from sympy.polys.monomials import Monomial +from sympy.polys.orderings import monomial_key +from sympy.polys.polyclasses import DMP, DMF, ANP +from sympy.polys.polyerrors import ( + OperationNotSupported, DomainError, + CoercionFailed, UnificationFailed, + GeneratorsNeeded, PolynomialError, + MultivariatePolynomialError, + ExactQuotientFailed, + PolificationFailed, + ComputationFailed, + GeneratorsError, +) +from sympy.polys.polyutils import ( + basic_from_dict, + _sort_gens, + _unify_gens, + _dict_reorder, + _dict_from_expr, + _parallel_dict_from_expr, +) +from sympy.polys.rationaltools import together +from sympy.polys.rootisolation import dup_isolate_real_roots_list +from sympy.utilities import group, public, filldedent +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import iterable, sift + +# Required to avoid errors +import sympy.polys + +import mpmath +from mpmath.libmp.libhyper import NoConvergence + + + +def _polifyit(func): + @wraps(func) + def wrapper(f, g): + g = _sympify(g) + if isinstance(g, Poly): + return func(f, g) + elif isinstance(g, Integer): + g = f.from_expr(g, *f.gens, domain=f.domain) + return func(f, g) + elif isinstance(g, Expr): + try: + g = f.from_expr(g, *f.gens) + except PolynomialError: + if g.is_Matrix: + return NotImplemented + expr_method = getattr(f.as_expr(), func.__name__) + result = expr_method(g) + if result is not NotImplemented: + sympy_deprecation_warning( + """ + Mixing Poly with non-polynomial expressions in binary + operations is deprecated. Either explicitly convert + the non-Poly operand to a Poly with as_poly() or + convert the Poly to an Expr with as_expr(). + """, + deprecated_since_version="1.6", + active_deprecations_target="deprecated-poly-nonpoly-binary-operations", + ) + return result + else: + return func(f, g) + else: + return NotImplemented + return wrapper + + + +@public +class Poly(Basic): + """ + Generic class for representing and operating on polynomial expressions. + + See :ref:`polys-docs` for general documentation. + + Poly is a subclass of Basic rather than Expr but instances can be + converted to Expr with the :py:meth:`~.Poly.as_expr` method. + + .. deprecated:: 1.6 + + Combining Poly with non-Poly objects in binary operations is + deprecated. Explicitly convert both objects to either Poly or Expr + first. See :ref:`deprecated-poly-nonpoly-binary-operations`. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + Create a univariate polynomial: + + >>> Poly(x*(x**2 + x - 1)**2) + Poly(x**5 + 2*x**4 - x**3 - 2*x**2 + x, x, domain='ZZ') + + Create a univariate polynomial with specific domain: + + >>> from sympy import sqrt + >>> Poly(x**2 + 2*x + sqrt(3), domain='R') + Poly(1.0*x**2 + 2.0*x + 1.73205080756888, x, domain='RR') + + Create a multivariate polynomial: + + >>> Poly(y*x**2 + x*y + 1) + Poly(x**2*y + x*y + 1, x, y, domain='ZZ') + + Create a univariate polynomial, where y is a constant: + + >>> Poly(y*x**2 + x*y + 1,x) + Poly(y*x**2 + y*x + 1, x, domain='ZZ[y]') + + You can evaluate the above polynomial as a function of y: + + >>> Poly(y*x**2 + x*y + 1,x).eval(2) + 6*y + 1 + + See Also + ======== + + sympy.core.expr.Expr + + """ + + __slots__ = ('rep', 'gens') + + is_commutative = True + is_Poly = True + _op_priority = 10.001 + + rep: DMP + gens: tuple[Expr, ...] + + def __new__(cls, rep, *gens, **args) -> Poly: + """Create a new polynomial instance out of something useful. """ + opt = options.build_options(gens, args) + + if 'order' in opt: + raise NotImplementedError("'order' keyword is not implemented yet") + + if isinstance(rep, (DMP, DMF, ANP, DomainElement)): + return cls._from_domain_element(rep, opt) + elif iterable(rep, exclude=str): + if isinstance(rep, dict): + return cls._from_dict(rep, opt) + else: + return cls._from_list(list(rep), opt) + else: + rep = sympify(rep, evaluate=type(rep) is not str) # type: ignore + + if rep.is_Poly: + return cls._from_poly(rep, opt) + else: + return cls._from_expr(rep, opt) + + # Poly does not pass its args to Basic.__new__ to be stored in _args so we + # have to emulate them here with an args property that derives from rep + # and gens which are instance attributes. This also means we need to + # define _hashable_content. The _hashable_content is rep and gens but args + # uses expr instead of rep (expr is the Basic version of rep). Passing + # expr in args means that Basic methods like subs should work. Using rep + # otherwise means that Poly can remain more efficient than Basic by + # avoiding creating a Basic instance just to be hashable. + + @classmethod + def new(cls, rep, *gens): + """Construct :class:`Poly` instance from raw representation. """ + if not isinstance(rep, DMP): + raise PolynomialError( + "invalid polynomial representation: %s" % rep) + elif rep.lev != len(gens) - 1: + raise PolynomialError("invalid arguments: %s, %s" % (rep, gens)) + + obj = Basic.__new__(cls) + obj.rep = rep + obj.gens = gens + + return obj + + @property + def expr(self): + return basic_from_dict(self.rep.to_sympy_dict(), *self.gens) + + @property + def args(self): + return (self.expr,) + self.gens + + def _hashable_content(self): + return (self.rep,) + self.gens + + @classmethod + def from_dict(cls, rep, *gens, **args): + """Construct a polynomial from a ``dict``. """ + opt = options.build_options(gens, args) + return cls._from_dict(rep, opt) + + @classmethod + def from_list(cls, rep, *gens, **args): + """Construct a polynomial from a ``list``. """ + opt = options.build_options(gens, args) + return cls._from_list(rep, opt) + + @classmethod + def from_poly(cls, rep, *gens, **args): + """Construct a polynomial from a polynomial. """ + opt = options.build_options(gens, args) + return cls._from_poly(rep, opt) + + @classmethod + def from_expr(cls, rep, *gens, **args): + """Construct a polynomial from an expression. """ + opt = options.build_options(gens, args) + return cls._from_expr(rep, opt) + + @classmethod + def _from_dict(cls, rep, opt): + """Construct a polynomial from a ``dict``. """ + gens = opt.gens + + if not gens: + raise GeneratorsNeeded( + "Cannot initialize from 'dict' without generators") + + level = len(gens) - 1 + domain = opt.domain + + if domain is None: + domain, rep = construct_domain(rep, opt=opt) + else: + for monom, coeff in rep.items(): + rep[monom] = domain.convert(coeff) + + return cls.new(DMP.from_dict(rep, level, domain), *gens) + + @classmethod + def _from_list(cls, rep, opt): + """Construct a polynomial from a ``list``. """ + gens = opt.gens + + if not gens: + raise GeneratorsNeeded( + "Cannot initialize from 'list' without generators") + elif len(gens) != 1: + raise MultivariatePolynomialError( + "'list' representation not supported") + + level = len(gens) - 1 + domain = opt.domain + + if domain is None: + domain, rep = construct_domain(rep, opt=opt) + else: + rep = list(map(domain.convert, rep)) + + return cls.new(DMP.from_list(rep, level, domain), *gens) + + @classmethod + def _from_poly(cls, rep, opt): + """Construct a polynomial from a polynomial. """ + if cls != rep.__class__: + rep = cls.new(rep.rep, *rep.gens) + + gens = opt.gens + field = opt.field + domain = opt.domain + + if gens and rep.gens != gens: + if set(rep.gens) != set(gens): + return cls._from_expr(rep.as_expr(), opt) + else: + rep = rep.reorder(*gens) + + if 'domain' in opt and domain: + rep = rep.set_domain(domain) + elif field is True: + rep = rep.to_field() + + return rep + + @classmethod + def _from_expr(cls, rep, opt): + """Construct a polynomial from an expression. """ + rep, opt = _dict_from_expr(rep, opt) + return cls._from_dict(rep, opt) + + @classmethod + def _from_domain_element(cls, rep, opt): + gens = opt.gens + domain = opt.domain + + level = len(gens) - 1 + rep = [domain.convert(rep)] + + return cls.new(DMP.from_list(rep, level, domain), *gens) + + def __hash__(self): + return super().__hash__() + + @property + def free_symbols(self): + """ + Free symbols of a polynomial expression. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y, z + + >>> Poly(x**2 + 1).free_symbols + {x} + >>> Poly(x**2 + y).free_symbols + {x, y} + >>> Poly(x**2 + y, x).free_symbols + {x, y} + >>> Poly(x**2 + y, x, z).free_symbols + {x, y} + + """ + symbols = set() + gens = self.gens + for i in range(len(gens)): + for monom in self.monoms(): + if monom[i]: + symbols |= gens[i].free_symbols + break + + return symbols | self.free_symbols_in_domain + + @property + def free_symbols_in_domain(self): + """ + Free symbols of the domain of ``self``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + 1).free_symbols_in_domain + set() + >>> Poly(x**2 + y).free_symbols_in_domain + set() + >>> Poly(x**2 + y, x).free_symbols_in_domain + {y} + + """ + domain, symbols = self.rep.dom, set() + + if domain.is_Composite: + for gen in domain.symbols: + symbols |= gen.free_symbols + elif domain.is_EX: + for coeff in self.coeffs(): + symbols |= coeff.free_symbols + + return symbols + + @property + def gen(self): + """ + Return the principal generator. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, x).gen + x + + """ + return self.gens[0] + + @property + def domain(self): + """Get the ground domain of a :py:class:`~.Poly` + + Returns + ======= + + :py:class:`~.Domain`: + Ground domain of the :py:class:`~.Poly`. + + Examples + ======== + + >>> from sympy import Poly, Symbol + >>> x = Symbol('x') + >>> p = Poly(x**2 + x) + >>> p + Poly(x**2 + x, x, domain='ZZ') + >>> p.domain + ZZ + """ + return self.get_domain() + + @property + def zero(self): + """Return zero polynomial with ``self``'s properties. """ + return self.new(self.rep.zero(self.rep.lev, self.rep.dom), *self.gens) + + @property + def one(self): + """Return one polynomial with ``self``'s properties. """ + return self.new(self.rep.one(self.rep.lev, self.rep.dom), *self.gens) + + def unify(f, g): + """ + Make ``f`` and ``g`` belong to the same domain. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> f, g = Poly(x/2 + 1), Poly(2*x + 1) + + >>> f + Poly(1/2*x + 1, x, domain='QQ') + >>> g + Poly(2*x + 1, x, domain='ZZ') + + >>> F, G = f.unify(g) + + >>> F + Poly(1/2*x + 1, x, domain='QQ') + >>> G + Poly(2*x + 1, x, domain='QQ') + + """ + _, per, F, G = f._unify(g) + return per(F), per(G) + + def _unify(f, g): + g = sympify(g) + + if not g.is_Poly: + try: + g_coeff = f.rep.dom.from_sympy(g) + except CoercionFailed: + raise UnificationFailed("Cannot unify %s with %s" % (f, g)) + else: + return f.rep.dom, f.per, f.rep, f.rep.ground_new(g_coeff) + + if isinstance(f.rep, DMP) and isinstance(g.rep, DMP): + gens = _unify_gens(f.gens, g.gens) + + dom, lev = f.rep.dom.unify(g.rep.dom, gens), len(gens) - 1 + + if f.gens != gens: + f_monoms, f_coeffs = _dict_reorder( + f.rep.to_dict(), f.gens, gens) + + if f.rep.dom != dom: + f_coeffs = [dom.convert(c, f.rep.dom) for c in f_coeffs] + + F = DMP.from_dict(dict(list(zip(f_monoms, f_coeffs))), lev, dom) + else: + F = f.rep.convert(dom) + + if g.gens != gens: + g_monoms, g_coeffs = _dict_reorder( + g.rep.to_dict(), g.gens, gens) + + if g.rep.dom != dom: + g_coeffs = [dom.convert(c, g.rep.dom) for c in g_coeffs] + + G = DMP.from_dict(dict(list(zip(g_monoms, g_coeffs))), lev, dom) + else: + G = g.rep.convert(dom) + else: + raise UnificationFailed("Cannot unify %s with %s" % (f, g)) + + cls = f.__class__ + + def per(rep, dom=dom, gens=gens, remove=None): + if remove is not None: + gens = gens[:remove] + gens[remove + 1:] + + if not gens: + return dom.to_sympy(rep) + + return cls.new(rep, *gens) + + return dom, per, F, G + + def per(f, rep, gens=None, remove=None): + """ + Create a Poly out of the given representation. + + Examples + ======== + + >>> from sympy import Poly, ZZ + >>> from sympy.abc import x, y + + >>> from sympy.polys.polyclasses import DMP + + >>> a = Poly(x**2 + 1) + + >>> a.per(DMP([ZZ(1), ZZ(1)], ZZ), gens=[y]) + Poly(y + 1, y, domain='ZZ') + + """ + if gens is None: + gens = f.gens + + if remove is not None: + gens = gens[:remove] + gens[remove + 1:] + + if not gens: + return f.rep.dom.to_sympy(rep) + + return f.__class__.new(rep, *gens) + + def set_domain(f, domain): + """Set the ground domain of ``f``. """ + opt = options.build_options(f.gens, {'domain': domain}) + return f.per(f.rep.convert(opt.domain)) + + def get_domain(f): + """Get the ground domain of ``f``. """ + return f.rep.dom + + def set_modulus(f, modulus): + """ + Set the modulus of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(5*x**2 + 2*x - 1, x).set_modulus(2) + Poly(x**2 + 1, x, modulus=2) + + """ + modulus = options.Modulus.preprocess(modulus) + return f.set_domain(FF(modulus)) + + def get_modulus(f): + """ + Get the modulus of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, modulus=2).get_modulus() + 2 + + """ + domain = f.get_domain() + + if domain.is_FiniteField: + return Integer(domain.characteristic()) + else: + raise PolynomialError("not a polynomial over a Galois field") + + def _eval_subs(f, old, new): + """Internal implementation of :func:`subs`. """ + if old in f.gens: + if new.is_number: + return f.eval(old, new) + else: + try: + return f.replace(old, new) + except PolynomialError: + pass + + return f.as_expr().subs(old, new) + + def exclude(f): + """ + Remove unnecessary generators from ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import a, b, c, d, x + + >>> Poly(a + x, a, b, c, d, x).exclude() + Poly(a + x, a, x, domain='ZZ') + + """ + J, new = f.rep.exclude() + gens = [gen for j, gen in enumerate(f.gens) if j not in J] + + return f.per(new, gens=gens) + + def replace(f, x, y=None, **_ignore): + # XXX this does not match Basic's signature + """ + Replace ``x`` with ``y`` in generators list. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + 1, x).replace(x, y) + Poly(y**2 + 1, y, domain='ZZ') + + """ + if y is None: + if f.is_univariate: + x, y = f.gen, x + else: + raise PolynomialError( + "syntax supported only in univariate case") + + if x == y or x not in f.gens: + return f + + if x in f.gens and y not in f.gens: + dom = f.get_domain() + + if not dom.is_Composite or y not in dom.symbols: + gens = list(f.gens) + gens[gens.index(x)] = y + return f.per(f.rep, gens=gens) + + raise PolynomialError("Cannot replace %s with %s in %s" % (x, y, f)) + + def match(f, *args, **kwargs): + """Match expression from Poly. See Basic.match()""" + return f.as_expr().match(*args, **kwargs) + + def reorder(f, *gens, **args): + """ + Efficiently apply new order of generators. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + x*y**2, x, y).reorder(y, x) + Poly(y**2*x + x**2, y, x, domain='ZZ') + + """ + opt = options.Options((), args) + + if not gens: + gens = _sort_gens(f.gens, opt=opt) + elif set(f.gens) != set(gens): + raise PolynomialError( + "generators list can differ only up to order of elements") + + rep = dict(list(zip(*_dict_reorder(f.rep.to_dict(), f.gens, gens)))) + + return f.per(DMP.from_dict(rep, len(gens) - 1, f.rep.dom), gens=gens) + + def ltrim(f, gen): + """ + Remove dummy generators from ``f`` that are to the left of + specified ``gen`` in the generators as ordered. When ``gen`` + is an integer, it refers to the generator located at that + position within the tuple of generators of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y, z + + >>> Poly(y**2 + y*z**2, x, y, z).ltrim(y) + Poly(y**2 + y*z**2, y, z, domain='ZZ') + >>> Poly(z, x, y, z).ltrim(-1) + Poly(z, z, domain='ZZ') + + """ + rep = f.as_dict(native=True) + j = f._gen_to_level(gen) + + terms = {} + + for monom, coeff in rep.items(): + + if any(monom[:j]): + # some generator is used in the portion to be trimmed + raise PolynomialError("Cannot left trim %s" % f) + + terms[monom[j:]] = coeff + + gens = f.gens[j:] + + return f.new(DMP.from_dict(terms, len(gens) - 1, f.rep.dom), *gens) + + def has_only_gens(f, *gens): + """ + Return ``True`` if ``Poly(f, *gens)`` retains ground domain. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y, z + + >>> Poly(x*y + 1, x, y, z).has_only_gens(x, y) + True + >>> Poly(x*y + z, x, y, z).has_only_gens(x, y) + False + + """ + indices = set() + + for gen in gens: + try: + index = f.gens.index(gen) + except ValueError: + raise GeneratorsError( + "%s doesn't have %s as generator" % (f, gen)) + else: + indices.add(index) + + for monom in f.monoms(): + for i, elt in enumerate(monom): + if i not in indices and elt: + return False + + return True + + def to_ring(f): + """ + Make the ground domain a ring. + + Examples + ======== + + >>> from sympy import Poly, QQ + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, domain=QQ).to_ring() + Poly(x**2 + 1, x, domain='ZZ') + + """ + if hasattr(f.rep, 'to_ring'): + result = f.rep.to_ring() + else: # pragma: no cover + raise OperationNotSupported(f, 'to_ring') + + return f.per(result) + + def to_field(f): + """ + Make the ground domain a field. + + Examples + ======== + + >>> from sympy import Poly, ZZ + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, x, domain=ZZ).to_field() + Poly(x**2 + 1, x, domain='QQ') + + """ + if hasattr(f.rep, 'to_field'): + result = f.rep.to_field() + else: # pragma: no cover + raise OperationNotSupported(f, 'to_field') + + return f.per(result) + + def to_exact(f): + """ + Make the ground domain exact. + + Examples + ======== + + >>> from sympy import Poly, RR + >>> from sympy.abc import x + + >>> Poly(x**2 + 1.0, x, domain=RR).to_exact() + Poly(x**2 + 1, x, domain='QQ') + + """ + if hasattr(f.rep, 'to_exact'): + result = f.rep.to_exact() + else: # pragma: no cover + raise OperationNotSupported(f, 'to_exact') + + return f.per(result) + + def retract(f, field=None): + """ + Recalculate the ground domain of a polynomial. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> f = Poly(x**2 + 1, x, domain='QQ[y]') + >>> f + Poly(x**2 + 1, x, domain='QQ[y]') + + >>> f.retract() + Poly(x**2 + 1, x, domain='ZZ') + >>> f.retract(field=True) + Poly(x**2 + 1, x, domain='QQ') + + """ + dom, rep = construct_domain(f.as_dict(zero=True), + field=field, composite=f.domain.is_Composite or None) + return f.from_dict(rep, f.gens, domain=dom) + + def slice(f, x, m, n=None): + """Take a continuous subsequence of terms of ``f``. """ + if n is None: + j, m, n = 0, x, m + else: + j = f._gen_to_level(x) + + m, n = int(m), int(n) + + if hasattr(f.rep, 'slice'): + result = f.rep.slice(m, n, j) + else: # pragma: no cover + raise OperationNotSupported(f, 'slice') + + return f.per(result) + + def coeffs(f, order=None): + """ + Returns all non-zero coefficients from ``f`` in lex order. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**3 + 2*x + 3, x).coeffs() + [1, 2, 3] + + See Also + ======== + all_coeffs + coeff_monomial + nth + + """ + return [f.rep.dom.to_sympy(c) for c in f.rep.coeffs(order=order)] + + def monoms(f, order=None): + """ + Returns all non-zero monomials from ``f`` in lex order. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + 2*x*y**2 + x*y + 3*y, x, y).monoms() + [(2, 0), (1, 2), (1, 1), (0, 1)] + + See Also + ======== + all_monoms + + """ + return f.rep.monoms(order=order) + + def terms(f, order=None): + """ + Returns all non-zero terms from ``f`` in lex order. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + 2*x*y**2 + x*y + 3*y, x, y).terms() + [((2, 0), 1), ((1, 2), 2), ((1, 1), 1), ((0, 1), 3)] + + See Also + ======== + all_terms + + """ + return [(m, f.rep.dom.to_sympy(c)) for m, c in f.rep.terms(order=order)] + + def all_coeffs(f): + """ + Returns all coefficients from a univariate polynomial ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**3 + 2*x - 1, x).all_coeffs() + [1, 0, 2, -1] + + """ + return [f.rep.dom.to_sympy(c) for c in f.rep.all_coeffs()] + + def all_monoms(f): + """ + Returns all monomials from a univariate polynomial ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**3 + 2*x - 1, x).all_monoms() + [(3,), (2,), (1,), (0,)] + + See Also + ======== + all_terms + + """ + return f.rep.all_monoms() + + def all_terms(f): + """ + Returns all terms from a univariate polynomial ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**3 + 2*x - 1, x).all_terms() + [((3,), 1), ((2,), 0), ((1,), 2), ((0,), -1)] + + """ + return [(m, f.rep.dom.to_sympy(c)) for m, c in f.rep.all_terms()] + + def termwise(f, func, *gens, **args): + """ + Apply a function to all terms of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> def func(k, coeff): + ... k = k[0] + ... return coeff//10**(2-k) + + >>> Poly(x**2 + 20*x + 400).termwise(func) + Poly(x**2 + 2*x + 4, x, domain='ZZ') + + """ + terms = {} + + for monom, coeff in f.terms(): + result = func(monom, coeff) + + if isinstance(result, tuple): + monom, coeff = result + else: + coeff = result + + if coeff: + if monom not in terms: + terms[monom] = coeff + else: + raise PolynomialError( + "%s monomial was generated twice" % monom) + + return f.from_dict(terms, *(gens or f.gens), **args) + + def length(f): + """ + Returns the number of non-zero terms in ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 2*x - 1).length() + 3 + + """ + return len(f.as_dict()) + + def as_dict(f, native=False, zero=False): + """ + Switch to a ``dict`` representation. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + 2*x*y**2 - y, x, y).as_dict() + {(0, 1): -1, (1, 2): 2, (2, 0): 1} + + """ + if native: + return f.rep.to_dict(zero=zero) + else: + return f.rep.to_sympy_dict(zero=zero) + + def as_list(f, native=False): + """Switch to a ``list`` representation. """ + if native: + return f.rep.to_list() + else: + return f.rep.to_sympy_list() + + def as_expr(f, *gens): + """ + Convert a Poly instance to an Expr instance. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> f = Poly(x**2 + 2*x*y**2 - y, x, y) + + >>> f.as_expr() + x**2 + 2*x*y**2 - y + >>> f.as_expr({x: 5}) + 10*y**2 - y + 25 + >>> f.as_expr(5, 6) + 379 + + """ + if not gens: + return f.expr + + if len(gens) == 1 and isinstance(gens[0], dict): + mapping = gens[0] + gens = list(f.gens) + + for gen, value in mapping.items(): + try: + index = gens.index(gen) + except ValueError: + raise GeneratorsError( + "%s doesn't have %s as generator" % (f, gen)) + else: + gens[index] = value + + return basic_from_dict(f.rep.to_sympy_dict(), *gens) + + def as_poly(self, *gens, **args): + """Converts ``self`` to a polynomial or returns ``None``. + + >>> from sympy import sin + >>> from sympy.abc import x, y + + >>> print((x**2 + x*y).as_poly()) + Poly(x**2 + x*y, x, y, domain='ZZ') + + >>> print((x**2 + x*y).as_poly(x, y)) + Poly(x**2 + x*y, x, y, domain='ZZ') + + >>> print((x**2 + sin(y)).as_poly(x, y)) + None + + """ + try: + poly = Poly(self, *gens, **args) + + if not poly.is_Poly: + return None + else: + return poly + except PolynomialError: + return None + + def lift(f): + """ + Convert algebraic coefficients to rationals. + + Examples + ======== + + >>> from sympy import Poly, I + >>> from sympy.abc import x + + >>> Poly(x**2 + I*x + 1, x, extension=I).lift() + Poly(x**4 + 3*x**2 + 1, x, domain='QQ') + + """ + if hasattr(f.rep, 'lift'): + result = f.rep.lift() + else: # pragma: no cover + raise OperationNotSupported(f, 'lift') + + return f.per(result) + + def deflate(f): + """ + Reduce degree of ``f`` by mapping ``x_i**m`` to ``y_i``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**6*y**2 + x**3 + 1, x, y).deflate() + ((3, 2), Poly(x**2*y + x + 1, x, y, domain='ZZ')) + + """ + if hasattr(f.rep, 'deflate'): + J, result = f.rep.deflate() + else: # pragma: no cover + raise OperationNotSupported(f, 'deflate') + + return J, f.per(result) + + def inject(f, front=False): + """ + Inject ground domain generators into ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> f = Poly(x**2*y + x*y**3 + x*y + 1, x) + + >>> f.inject() + Poly(x**2*y + x*y**3 + x*y + 1, x, y, domain='ZZ') + >>> f.inject(front=True) + Poly(y**3*x + y*x**2 + y*x + 1, y, x, domain='ZZ') + + """ + dom = f.rep.dom + + if dom.is_Numerical: + return f + elif not dom.is_Poly: + raise DomainError("Cannot inject generators over %s" % dom) + + if hasattr(f.rep, 'inject'): + result = f.rep.inject(front=front) + else: # pragma: no cover + raise OperationNotSupported(f, 'inject') + + if front: + gens = dom.symbols + f.gens + else: + gens = f.gens + dom.symbols + + return f.new(result, *gens) + + def eject(f, *gens): + """ + Eject selected generators into the ground domain. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> f = Poly(x**2*y + x*y**3 + x*y + 1, x, y) + + >>> f.eject(x) + Poly(x*y**3 + (x**2 + x)*y + 1, y, domain='ZZ[x]') + >>> f.eject(y) + Poly(y*x**2 + (y**3 + y)*x + 1, x, domain='ZZ[y]') + + """ + dom = f.rep.dom + + if not dom.is_Numerical: + raise DomainError("Cannot eject generators over %s" % dom) + + k = len(gens) + + if f.gens[:k] == gens: + _gens, front = f.gens[k:], True + elif f.gens[-k:] == gens: + _gens, front = f.gens[:-k], False + else: + raise NotImplementedError( + "can only eject front or back generators") + + dom = dom.inject(*gens) + + if hasattr(f.rep, 'eject'): + result = f.rep.eject(dom, front=front) + else: # pragma: no cover + raise OperationNotSupported(f, 'eject') + + return f.new(result, *_gens) + + def terms_gcd(f): + """ + Remove GCD of terms from the polynomial ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**6*y**2 + x**3*y, x, y).terms_gcd() + ((3, 1), Poly(x**3*y + 1, x, y, domain='ZZ')) + + """ + if hasattr(f.rep, 'terms_gcd'): + J, result = f.rep.terms_gcd() + else: # pragma: no cover + raise OperationNotSupported(f, 'terms_gcd') + + return J, f.per(result) + + def add_ground(f, coeff): + """ + Add an element of the ground domain to ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x + 1).add_ground(2) + Poly(x + 3, x, domain='ZZ') + + """ + if hasattr(f.rep, 'add_ground'): + result = f.rep.add_ground(coeff) + else: # pragma: no cover + raise OperationNotSupported(f, 'add_ground') + + return f.per(result) + + def sub_ground(f, coeff): + """ + Subtract an element of the ground domain from ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x + 1).sub_ground(2) + Poly(x - 1, x, domain='ZZ') + + """ + if hasattr(f.rep, 'sub_ground'): + result = f.rep.sub_ground(coeff) + else: # pragma: no cover + raise OperationNotSupported(f, 'sub_ground') + + return f.per(result) + + def mul_ground(f, coeff): + """ + Multiply ``f`` by a an element of the ground domain. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x + 1).mul_ground(2) + Poly(2*x + 2, x, domain='ZZ') + + """ + if hasattr(f.rep, 'mul_ground'): + result = f.rep.mul_ground(coeff) + else: # pragma: no cover + raise OperationNotSupported(f, 'mul_ground') + + return f.per(result) + + def quo_ground(f, coeff): + """ + Quotient of ``f`` by a an element of the ground domain. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(2*x + 4).quo_ground(2) + Poly(x + 2, x, domain='ZZ') + + >>> Poly(2*x + 3).quo_ground(2) + Poly(x + 1, x, domain='ZZ') + + """ + if hasattr(f.rep, 'quo_ground'): + result = f.rep.quo_ground(coeff) + else: # pragma: no cover + raise OperationNotSupported(f, 'quo_ground') + + return f.per(result) + + def exquo_ground(f, coeff): + """ + Exact quotient of ``f`` by a an element of the ground domain. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(2*x + 4).exquo_ground(2) + Poly(x + 2, x, domain='ZZ') + + >>> Poly(2*x + 3).exquo_ground(2) + Traceback (most recent call last): + ... + ExactQuotientFailed: 2 does not divide 3 in ZZ + + """ + if hasattr(f.rep, 'exquo_ground'): + result = f.rep.exquo_ground(coeff) + else: # pragma: no cover + raise OperationNotSupported(f, 'exquo_ground') + + return f.per(result) + + def abs(f): + """ + Make all coefficients in ``f`` positive. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 1, x).abs() + Poly(x**2 + 1, x, domain='ZZ') + + """ + if hasattr(f.rep, 'abs'): + result = f.rep.abs() + else: # pragma: no cover + raise OperationNotSupported(f, 'abs') + + return f.per(result) + + def neg(f): + """ + Negate all coefficients in ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 1, x).neg() + Poly(-x**2 + 1, x, domain='ZZ') + + >>> -Poly(x**2 - 1, x) + Poly(-x**2 + 1, x, domain='ZZ') + + """ + if hasattr(f.rep, 'neg'): + result = f.rep.neg() + else: # pragma: no cover + raise OperationNotSupported(f, 'neg') + + return f.per(result) + + def add(f, g): + """ + Add two polynomials ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, x).add(Poly(x - 2, x)) + Poly(x**2 + x - 1, x, domain='ZZ') + + >>> Poly(x**2 + 1, x) + Poly(x - 2, x) + Poly(x**2 + x - 1, x, domain='ZZ') + + """ + g = sympify(g) + + if not g.is_Poly: + return f.add_ground(g) + + _, per, F, G = f._unify(g) + + if hasattr(f.rep, 'add'): + result = F.add(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'add') + + return per(result) + + def sub(f, g): + """ + Subtract two polynomials ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, x).sub(Poly(x - 2, x)) + Poly(x**2 - x + 3, x, domain='ZZ') + + >>> Poly(x**2 + 1, x) - Poly(x - 2, x) + Poly(x**2 - x + 3, x, domain='ZZ') + + """ + g = sympify(g) + + if not g.is_Poly: + return f.sub_ground(g) + + _, per, F, G = f._unify(g) + + if hasattr(f.rep, 'sub'): + result = F.sub(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'sub') + + return per(result) + + def mul(f, g): + """ + Multiply two polynomials ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, x).mul(Poly(x - 2, x)) + Poly(x**3 - 2*x**2 + x - 2, x, domain='ZZ') + + >>> Poly(x**2 + 1, x)*Poly(x - 2, x) + Poly(x**3 - 2*x**2 + x - 2, x, domain='ZZ') + + """ + g = sympify(g) + + if not g.is_Poly: + return f.mul_ground(g) + + _, per, F, G = f._unify(g) + + if hasattr(f.rep, 'mul'): + result = F.mul(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'mul') + + return per(result) + + def sqr(f): + """ + Square a polynomial ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x - 2, x).sqr() + Poly(x**2 - 4*x + 4, x, domain='ZZ') + + >>> Poly(x - 2, x)**2 + Poly(x**2 - 4*x + 4, x, domain='ZZ') + + """ + if hasattr(f.rep, 'sqr'): + result = f.rep.sqr() + else: # pragma: no cover + raise OperationNotSupported(f, 'sqr') + + return f.per(result) + + def pow(f, n): + """ + Raise ``f`` to a non-negative power ``n``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x - 2, x).pow(3) + Poly(x**3 - 6*x**2 + 12*x - 8, x, domain='ZZ') + + >>> Poly(x - 2, x)**3 + Poly(x**3 - 6*x**2 + 12*x - 8, x, domain='ZZ') + + """ + n = int(n) + + if hasattr(f.rep, 'pow'): + result = f.rep.pow(n) + else: # pragma: no cover + raise OperationNotSupported(f, 'pow') + + return f.per(result) + + def pdiv(f, g): + """ + Polynomial pseudo-division of ``f`` by ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, x).pdiv(Poly(2*x - 4, x)) + (Poly(2*x + 4, x, domain='ZZ'), Poly(20, x, domain='ZZ')) + + """ + _, per, F, G = f._unify(g) + + if hasattr(f.rep, 'pdiv'): + q, r = F.pdiv(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'pdiv') + + return per(q), per(r) + + def prem(f, g): + """ + Polynomial pseudo-remainder of ``f`` by ``g``. + + Caveat: The function prem(f, g, x) can be safely used to compute + in Z[x] _only_ subresultant polynomial remainder sequences (prs's). + + To safely compute Euclidean and Sturmian prs's in Z[x] + employ anyone of the corresponding functions found in + the module sympy.polys.subresultants_qq_zz. The functions + in the module with suffix _pg compute prs's in Z[x] employing + rem(f, g, x), whereas the functions with suffix _amv + compute prs's in Z[x] employing rem_z(f, g, x). + + The function rem_z(f, g, x) differs from prem(f, g, x) in that + to compute the remainder polynomials in Z[x] it premultiplies + the divident times the absolute value of the leading coefficient + of the divisor raised to the power degree(f, x) - degree(g, x) + 1. + + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, x).prem(Poly(2*x - 4, x)) + Poly(20, x, domain='ZZ') + + """ + _, per, F, G = f._unify(g) + + if hasattr(f.rep, 'prem'): + result = F.prem(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'prem') + + return per(result) + + def pquo(f, g): + """ + Polynomial pseudo-quotient of ``f`` by ``g``. + + See the Caveat note in the function prem(f, g). + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, x).pquo(Poly(2*x - 4, x)) + Poly(2*x + 4, x, domain='ZZ') + + >>> Poly(x**2 - 1, x).pquo(Poly(2*x - 2, x)) + Poly(2*x + 2, x, domain='ZZ') + + """ + _, per, F, G = f._unify(g) + + if hasattr(f.rep, 'pquo'): + result = F.pquo(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'pquo') + + return per(result) + + def pexquo(f, g): + """ + Polynomial exact pseudo-quotient of ``f`` by ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 1, x).pexquo(Poly(2*x - 2, x)) + Poly(2*x + 2, x, domain='ZZ') + + >>> Poly(x**2 + 1, x).pexquo(Poly(2*x - 4, x)) + Traceback (most recent call last): + ... + ExactQuotientFailed: 2*x - 4 does not divide x**2 + 1 + + """ + _, per, F, G = f._unify(g) + + if hasattr(f.rep, 'pexquo'): + try: + result = F.pexquo(G) + except ExactQuotientFailed as exc: + raise exc.new(f.as_expr(), g.as_expr()) + else: # pragma: no cover + raise OperationNotSupported(f, 'pexquo') + + return per(result) + + def div(f, g, auto=True): + """ + Polynomial division with remainder of ``f`` by ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, x).div(Poly(2*x - 4, x)) + (Poly(1/2*x + 1, x, domain='QQ'), Poly(5, x, domain='QQ')) + + >>> Poly(x**2 + 1, x).div(Poly(2*x - 4, x), auto=False) + (Poly(0, x, domain='ZZ'), Poly(x**2 + 1, x, domain='ZZ')) + + """ + dom, per, F, G = f._unify(g) + retract = False + + if auto and dom.is_Ring and not dom.is_Field: + F, G = F.to_field(), G.to_field() + retract = True + + if hasattr(f.rep, 'div'): + q, r = F.div(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'div') + + if retract: + try: + Q, R = q.to_ring(), r.to_ring() + except CoercionFailed: + pass + else: + q, r = Q, R + + return per(q), per(r) + + def rem(f, g, auto=True): + """ + Computes the polynomial remainder of ``f`` by ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, x).rem(Poly(2*x - 4, x)) + Poly(5, x, domain='ZZ') + + >>> Poly(x**2 + 1, x).rem(Poly(2*x - 4, x), auto=False) + Poly(x**2 + 1, x, domain='ZZ') + + """ + dom, per, F, G = f._unify(g) + retract = False + + if auto and dom.is_Ring and not dom.is_Field: + F, G = F.to_field(), G.to_field() + retract = True + + if hasattr(f.rep, 'rem'): + r = F.rem(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'rem') + + if retract: + try: + r = r.to_ring() + except CoercionFailed: + pass + + return per(r) + + def quo(f, g, auto=True): + """ + Computes polynomial quotient of ``f`` by ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, x).quo(Poly(2*x - 4, x)) + Poly(1/2*x + 1, x, domain='QQ') + + >>> Poly(x**2 - 1, x).quo(Poly(x - 1, x)) + Poly(x + 1, x, domain='ZZ') + + """ + dom, per, F, G = f._unify(g) + retract = False + + if auto and dom.is_Ring and not dom.is_Field: + F, G = F.to_field(), G.to_field() + retract = True + + if hasattr(f.rep, 'quo'): + q = F.quo(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'quo') + + if retract: + try: + q = q.to_ring() + except CoercionFailed: + pass + + return per(q) + + def exquo(f, g, auto=True): + """ + Computes polynomial exact quotient of ``f`` by ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 1, x).exquo(Poly(x - 1, x)) + Poly(x + 1, x, domain='ZZ') + + >>> Poly(x**2 + 1, x).exquo(Poly(2*x - 4, x)) + Traceback (most recent call last): + ... + ExactQuotientFailed: 2*x - 4 does not divide x**2 + 1 + + """ + dom, per, F, G = f._unify(g) + retract = False + + if auto and dom.is_Ring and not dom.is_Field: + F, G = F.to_field(), G.to_field() + retract = True + + if hasattr(f.rep, 'exquo'): + try: + q = F.exquo(G) + except ExactQuotientFailed as exc: + raise exc.new(f.as_expr(), g.as_expr()) + else: # pragma: no cover + raise OperationNotSupported(f, 'exquo') + + if retract: + try: + q = q.to_ring() + except CoercionFailed: + pass + + return per(q) + + def _gen_to_level(f, gen): + """Returns level associated with the given generator. """ + if isinstance(gen, int): + length = len(f.gens) + + if -length <= gen < length: + if gen < 0: + return length + gen + else: + return gen + else: + raise PolynomialError("-%s <= gen < %s expected, got %s" % + (length, length, gen)) + else: + try: + return f.gens.index(sympify(gen)) + except ValueError: + raise PolynomialError( + "a valid generator expected, got %s" % gen) + + def degree(f, gen=0): + """ + Returns degree of ``f`` in ``x_j``. + + The degree of 0 is negative infinity. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + y*x + 1, x, y).degree() + 2 + >>> Poly(x**2 + y*x + y, x, y).degree(y) + 1 + >>> Poly(0, x).degree() + -oo + + """ + j = f._gen_to_level(gen) + + if hasattr(f.rep, 'degree'): + d = f.rep.degree(j) + if d < 0: + d = S.NegativeInfinity + return d + else: # pragma: no cover + raise OperationNotSupported(f, 'degree') + + def degree_list(f): + """ + Returns a list of degrees of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + y*x + 1, x, y).degree_list() + (2, 1) + + """ + if hasattr(f.rep, 'degree_list'): + return f.rep.degree_list() + else: # pragma: no cover + raise OperationNotSupported(f, 'degree_list') + + def total_degree(f): + """ + Returns the total degree of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + y*x + 1, x, y).total_degree() + 2 + >>> Poly(x + y**5, x, y).total_degree() + 5 + + """ + if hasattr(f.rep, 'total_degree'): + return f.rep.total_degree() + else: # pragma: no cover + raise OperationNotSupported(f, 'total_degree') + + def homogenize(f, s): + """ + Returns the homogeneous polynomial of ``f``. + + A homogeneous polynomial is a polynomial whose all monomials with + non-zero coefficients have the same total degree. If you only + want to check if a polynomial is homogeneous, then use + :func:`Poly.is_homogeneous`. If you want not only to check if a + polynomial is homogeneous but also compute its homogeneous order, + then use :func:`Poly.homogeneous_order`. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y, z + + >>> f = Poly(x**5 + 2*x**2*y**2 + 9*x*y**3) + >>> f.homogenize(z) + Poly(x**5 + 2*x**2*y**2*z + 9*x*y**3*z, x, y, z, domain='ZZ') + + """ + if not isinstance(s, Symbol): + raise TypeError("``Symbol`` expected, got %s" % type(s)) + if s in f.gens: + i = f.gens.index(s) + gens = f.gens + else: + i = len(f.gens) + gens = f.gens + (s,) + if hasattr(f.rep, 'homogenize'): + return f.per(f.rep.homogenize(i), gens=gens) + raise OperationNotSupported(f, 'homogeneous_order') + + def homogeneous_order(f): + """ + Returns the homogeneous order of ``f``. + + A homogeneous polynomial is a polynomial whose all monomials with + non-zero coefficients have the same total degree. This degree is + the homogeneous order of ``f``. If you only want to check if a + polynomial is homogeneous, then use :func:`Poly.is_homogeneous`. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> f = Poly(x**5 + 2*x**3*y**2 + 9*x*y**4) + >>> f.homogeneous_order() + 5 + + """ + if hasattr(f.rep, 'homogeneous_order'): + return f.rep.homogeneous_order() + else: # pragma: no cover + raise OperationNotSupported(f, 'homogeneous_order') + + def LC(f, order=None): + """ + Returns the leading coefficient of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(4*x**3 + 2*x**2 + 3*x, x).LC() + 4 + + """ + if order is not None: + return f.coeffs(order)[0] + + if hasattr(f.rep, 'LC'): + result = f.rep.LC() + else: # pragma: no cover + raise OperationNotSupported(f, 'LC') + + return f.rep.dom.to_sympy(result) + + def TC(f): + """ + Returns the trailing coefficient of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**3 + 2*x**2 + 3*x, x).TC() + 0 + + """ + if hasattr(f.rep, 'TC'): + result = f.rep.TC() + else: # pragma: no cover + raise OperationNotSupported(f, 'TC') + + return f.rep.dom.to_sympy(result) + + def EC(f, order=None): + """ + Returns the last non-zero coefficient of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**3 + 2*x**2 + 3*x, x).EC() + 3 + + """ + if hasattr(f.rep, 'coeffs'): + return f.coeffs(order)[-1] + else: # pragma: no cover + raise OperationNotSupported(f, 'EC') + + def coeff_monomial(f, monom): + """ + Returns the coefficient of ``monom`` in ``f`` if there, else None. + + Examples + ======== + + >>> from sympy import Poly, exp + >>> from sympy.abc import x, y + + >>> p = Poly(24*x*y*exp(8) + 23*x, x, y) + + >>> p.coeff_monomial(x) + 23 + >>> p.coeff_monomial(y) + 0 + >>> p.coeff_monomial(x*y) + 24*exp(8) + + Note that ``Expr.coeff()`` behaves differently, collecting terms + if possible; the Poly must be converted to an Expr to use that + method, however: + + >>> p.as_expr().coeff(x) + 24*y*exp(8) + 23 + >>> p.as_expr().coeff(y) + 24*x*exp(8) + >>> p.as_expr().coeff(x*y) + 24*exp(8) + + See Also + ======== + nth: more efficient query using exponents of the monomial's generators + + """ + return f.nth(*Monomial(monom, f.gens).exponents) + + def nth(f, *N): + """ + Returns the ``n``-th coefficient of ``f`` where ``N`` are the + exponents of the generators in the term of interest. + + Examples + ======== + + >>> from sympy import Poly, sqrt + >>> from sympy.abc import x, y + + >>> Poly(x**3 + 2*x**2 + 3*x, x).nth(2) + 2 + >>> Poly(x**3 + 2*x*y**2 + y**2, x, y).nth(1, 2) + 2 + >>> Poly(4*sqrt(x)*y) + Poly(4*y*(sqrt(x)), y, sqrt(x), domain='ZZ') + >>> _.nth(1, 1) + 4 + + See Also + ======== + coeff_monomial + + """ + if hasattr(f.rep, 'nth'): + if len(N) != len(f.gens): + raise ValueError('exponent of each generator must be specified') + result = f.rep.nth(*list(map(int, N))) + else: # pragma: no cover + raise OperationNotSupported(f, 'nth') + + return f.rep.dom.to_sympy(result) + + def coeff(f, x, n=1, right=False): + # the semantics of coeff_monomial and Expr.coeff are different; + # if someone is working with a Poly, they should be aware of the + # differences and chose the method best suited for the query. + # Alternatively, a pure-polys method could be written here but + # at this time the ``right`` keyword would be ignored because Poly + # doesn't work with non-commutatives. + raise NotImplementedError( + 'Either convert to Expr with `as_expr` method ' + 'to use Expr\'s coeff method or else use the ' + '`coeff_monomial` method of Polys.') + + def LM(f, order=None): + """ + Returns the leading monomial of ``f``. + + The Leading monomial signifies the monomial having + the highest power of the principal generator in the + expression f. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(4*x**2 + 2*x*y**2 + x*y + 3*y, x, y).LM() + x**2*y**0 + + """ + return Monomial(f.monoms(order)[0], f.gens) + + def EM(f, order=None): + """ + Returns the last non-zero monomial of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(4*x**2 + 2*x*y**2 + x*y + 3*y, x, y).EM() + x**0*y**1 + + """ + return Monomial(f.monoms(order)[-1], f.gens) + + def LT(f, order=None): + """ + Returns the leading term of ``f``. + + The Leading term signifies the term having + the highest power of the principal generator in the + expression f along with its coefficient. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(4*x**2 + 2*x*y**2 + x*y + 3*y, x, y).LT() + (x**2*y**0, 4) + + """ + monom, coeff = f.terms(order)[0] + return Monomial(monom, f.gens), coeff + + def ET(f, order=None): + """ + Returns the last non-zero term of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(4*x**2 + 2*x*y**2 + x*y + 3*y, x, y).ET() + (x**0*y**1, 3) + + """ + monom, coeff = f.terms(order)[-1] + return Monomial(monom, f.gens), coeff + + def max_norm(f): + """ + Returns maximum norm of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(-x**2 + 2*x - 3, x).max_norm() + 3 + + """ + if hasattr(f.rep, 'max_norm'): + result = f.rep.max_norm() + else: # pragma: no cover + raise OperationNotSupported(f, 'max_norm') + + return f.rep.dom.to_sympy(result) + + def l1_norm(f): + """ + Returns l1 norm of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(-x**2 + 2*x - 3, x).l1_norm() + 6 + + """ + if hasattr(f.rep, 'l1_norm'): + result = f.rep.l1_norm() + else: # pragma: no cover + raise OperationNotSupported(f, 'l1_norm') + + return f.rep.dom.to_sympy(result) + + def clear_denoms(self, convert=False): + """ + Clear denominators, but keep the ground domain. + + Examples + ======== + + >>> from sympy import Poly, S, QQ + >>> from sympy.abc import x + + >>> f = Poly(x/2 + S(1)/3, x, domain=QQ) + + >>> f.clear_denoms() + (6, Poly(3*x + 2, x, domain='QQ')) + >>> f.clear_denoms(convert=True) + (6, Poly(3*x + 2, x, domain='ZZ')) + + """ + f = self + + if not f.rep.dom.is_Field: + return S.One, f + + dom = f.get_domain() + if dom.has_assoc_Ring: + dom = f.rep.dom.get_ring() + + if hasattr(f.rep, 'clear_denoms'): + coeff, result = f.rep.clear_denoms() + else: # pragma: no cover + raise OperationNotSupported(f, 'clear_denoms') + + coeff, f = dom.to_sympy(coeff), f.per(result) + + if not convert or not dom.has_assoc_Ring: + return coeff, f + else: + return coeff, f.to_ring() + + def rat_clear_denoms(self, g): + """ + Clear denominators in a rational function ``f/g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> f = Poly(x**2/y + 1, x) + >>> g = Poly(x**3 + y, x) + + >>> p, q = f.rat_clear_denoms(g) + + >>> p + Poly(x**2 + y, x, domain='ZZ[y]') + >>> q + Poly(y*x**3 + y**2, x, domain='ZZ[y]') + + """ + f = self + + dom, per, f, g = f._unify(g) + + f = per(f) + g = per(g) + + if not (dom.is_Field and dom.has_assoc_Ring): + return f, g + + a, f = f.clear_denoms(convert=True) + b, g = g.clear_denoms(convert=True) + + f = f.mul_ground(b) + g = g.mul_ground(a) + + return f, g + + def integrate(self, *specs, **args): + """ + Computes indefinite integral of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + 2*x + 1, x).integrate() + Poly(1/3*x**3 + x**2 + x, x, domain='QQ') + + >>> Poly(x*y**2 + x, x, y).integrate((0, 1), (1, 0)) + Poly(1/2*x**2*y**2 + 1/2*x**2, x, y, domain='QQ') + + """ + f = self + + if args.get('auto', True) and f.rep.dom.is_Ring: + f = f.to_field() + + if hasattr(f.rep, 'integrate'): + if not specs: + return f.per(f.rep.integrate(m=1)) + + rep = f.rep + + for spec in specs: + if isinstance(spec, tuple): + gen, m = spec + else: + gen, m = spec, 1 + + rep = rep.integrate(int(m), f._gen_to_level(gen)) + + return f.per(rep) + else: # pragma: no cover + raise OperationNotSupported(f, 'integrate') + + def diff(f, *specs, **kwargs): + """ + Computes partial derivative of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + 2*x + 1, x).diff() + Poly(2*x + 2, x, domain='ZZ') + + >>> Poly(x*y**2 + x, x, y).diff((0, 0), (1, 1)) + Poly(2*x*y, x, y, domain='ZZ') + + """ + if not kwargs.get('evaluate', True): + return Derivative(f, *specs, **kwargs) + + if hasattr(f.rep, 'diff'): + if not specs: + return f.per(f.rep.diff(m=1)) + + rep = f.rep + + for spec in specs: + if isinstance(spec, tuple): + gen, m = spec + else: + gen, m = spec, 1 + + rep = rep.diff(int(m), f._gen_to_level(gen)) + + return f.per(rep) + else: # pragma: no cover + raise OperationNotSupported(f, 'diff') + + _eval_derivative = diff + + def eval(self, x, a=None, auto=True): + """ + Evaluate ``f`` at ``a`` in the given variable. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y, z + + >>> Poly(x**2 + 2*x + 3, x).eval(2) + 11 + + >>> Poly(2*x*y + 3*x + y + 2, x, y).eval(x, 2) + Poly(5*y + 8, y, domain='ZZ') + + >>> f = Poly(2*x*y + 3*x + y + 2*z, x, y, z) + + >>> f.eval({x: 2}) + Poly(5*y + 2*z + 6, y, z, domain='ZZ') + >>> f.eval({x: 2, y: 5}) + Poly(2*z + 31, z, domain='ZZ') + >>> f.eval({x: 2, y: 5, z: 7}) + 45 + + >>> f.eval((2, 5)) + Poly(2*z + 31, z, domain='ZZ') + >>> f(2, 5) + Poly(2*z + 31, z, domain='ZZ') + + """ + f = self + + if a is None: + if isinstance(x, dict): + mapping = x + + for gen, value in mapping.items(): + f = f.eval(gen, value) + + return f + elif isinstance(x, (tuple, list)): + values = x + + if len(values) > len(f.gens): + raise ValueError("too many values provided") + + for gen, value in zip(f.gens, values): + f = f.eval(gen, value) + + return f + else: + j, a = 0, x + else: + j = f._gen_to_level(x) + + if not hasattr(f.rep, 'eval'): # pragma: no cover + raise OperationNotSupported(f, 'eval') + + try: + result = f.rep.eval(a, j) + except CoercionFailed: + if not auto: + raise DomainError("Cannot evaluate at %s in %s" % (a, f.rep.dom)) + else: + a_domain, [a] = construct_domain([a]) + new_domain = f.get_domain().unify_with_symbols(a_domain, f.gens) + + f = f.set_domain(new_domain) + a = new_domain.convert(a, a_domain) + + result = f.rep.eval(a, j) + + return f.per(result, remove=j) + + def __call__(f, *values): + """ + Evaluate ``f`` at the give values. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y, z + + >>> f = Poly(2*x*y + 3*x + y + 2*z, x, y, z) + + >>> f(2) + Poly(5*y + 2*z + 6, y, z, domain='ZZ') + >>> f(2, 5) + Poly(2*z + 31, z, domain='ZZ') + >>> f(2, 5, 7) + 45 + + """ + return f.eval(values) + + def half_gcdex(f, g, auto=True): + """ + Half extended Euclidean algorithm of ``f`` and ``g``. + + Returns ``(s, h)`` such that ``h = gcd(f, g)`` and ``s*f = h (mod g)``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> f = x**4 - 2*x**3 - 6*x**2 + 12*x + 15 + >>> g = x**3 + x**2 - 4*x - 4 + + >>> Poly(f).half_gcdex(Poly(g)) + (Poly(-1/5*x + 3/5, x, domain='QQ'), Poly(x + 1, x, domain='QQ')) + + """ + dom, per, F, G = f._unify(g) + + if auto and dom.is_Ring: + F, G = F.to_field(), G.to_field() + + if hasattr(f.rep, 'half_gcdex'): + s, h = F.half_gcdex(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'half_gcdex') + + return per(s), per(h) + + def gcdex(f, g, auto=True): + """ + Extended Euclidean algorithm of ``f`` and ``g``. + + Returns ``(s, t, h)`` such that ``h = gcd(f, g)`` and ``s*f + t*g = h``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> f = x**4 - 2*x**3 - 6*x**2 + 12*x + 15 + >>> g = x**3 + x**2 - 4*x - 4 + + >>> Poly(f).gcdex(Poly(g)) + (Poly(-1/5*x + 3/5, x, domain='QQ'), + Poly(1/5*x**2 - 6/5*x + 2, x, domain='QQ'), + Poly(x + 1, x, domain='QQ')) + + """ + dom, per, F, G = f._unify(g) + + if auto and dom.is_Ring: + F, G = F.to_field(), G.to_field() + + if hasattr(f.rep, 'gcdex'): + s, t, h = F.gcdex(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'gcdex') + + return per(s), per(t), per(h) + + def invert(f, g, auto=True): + """ + Invert ``f`` modulo ``g`` when possible. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 1, x).invert(Poly(2*x - 1, x)) + Poly(-4/3, x, domain='QQ') + + >>> Poly(x**2 - 1, x).invert(Poly(x - 1, x)) + Traceback (most recent call last): + ... + NotInvertible: zero divisor + + """ + dom, per, F, G = f._unify(g) + + if auto and dom.is_Ring: + F, G = F.to_field(), G.to_field() + + if hasattr(f.rep, 'invert'): + result = F.invert(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'invert') + + return per(result) + + def revert(f, n): + """ + Compute ``f**(-1)`` mod ``x**n``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(1, x).revert(2) + Poly(1, x, domain='ZZ') + + >>> Poly(1 + x, x).revert(1) + Poly(1, x, domain='ZZ') + + >>> Poly(x**2 - 2, x).revert(2) + Traceback (most recent call last): + ... + NotReversible: only units are reversible in a ring + + >>> Poly(1/x, x).revert(1) + Traceback (most recent call last): + ... + PolynomialError: 1/x contains an element of the generators set + + """ + if hasattr(f.rep, 'revert'): + result = f.rep.revert(int(n)) + else: # pragma: no cover + raise OperationNotSupported(f, 'revert') + + return f.per(result) + + def subresultants(f, g): + """ + Computes the subresultant PRS of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 1, x).subresultants(Poly(x**2 - 1, x)) + [Poly(x**2 + 1, x, domain='ZZ'), + Poly(x**2 - 1, x, domain='ZZ'), + Poly(-2, x, domain='ZZ')] + + """ + _, per, F, G = f._unify(g) + + if hasattr(f.rep, 'subresultants'): + result = F.subresultants(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'subresultants') + + return list(map(per, result)) + + def resultant(f, g, includePRS=False): + """ + Computes the resultant of ``f`` and ``g`` via PRS. + + If includePRS=True, it includes the subresultant PRS in the result. + Because the PRS is used to calculate the resultant, this is more + efficient than calling :func:`subresultants` separately. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> f = Poly(x**2 + 1, x) + + >>> f.resultant(Poly(x**2 - 1, x)) + 4 + >>> f.resultant(Poly(x**2 - 1, x), includePRS=True) + (4, [Poly(x**2 + 1, x, domain='ZZ'), Poly(x**2 - 1, x, domain='ZZ'), + Poly(-2, x, domain='ZZ')]) + + """ + _, per, F, G = f._unify(g) + + if hasattr(f.rep, 'resultant'): + if includePRS: + result, R = F.resultant(G, includePRS=includePRS) + else: + result = F.resultant(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'resultant') + + if includePRS: + return (per(result, remove=0), list(map(per, R))) + return per(result, remove=0) + + def discriminant(f): + """ + Computes the discriminant of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + 2*x + 3, x).discriminant() + -8 + + """ + if hasattr(f.rep, 'discriminant'): + result = f.rep.discriminant() + else: # pragma: no cover + raise OperationNotSupported(f, 'discriminant') + + return f.per(result, remove=0) + + def dispersionset(f, g=None): + r"""Compute the *dispersion set* of two polynomials. + + For two polynomials `f(x)` and `g(x)` with `\deg f > 0` + and `\deg g > 0` the dispersion set `\operatorname{J}(f, g)` is defined as: + + .. math:: + \operatorname{J}(f, g) + & := \{a \in \mathbb{N}_0 | \gcd(f(x), g(x+a)) \neq 1\} \\ + & = \{a \in \mathbb{N}_0 | \deg \gcd(f(x), g(x+a)) \geq 1\} + + For a single polynomial one defines `\operatorname{J}(f) := \operatorname{J}(f, f)`. + + Examples + ======== + + >>> from sympy import poly + >>> from sympy.polys.dispersion import dispersion, dispersionset + >>> from sympy.abc import x + + Dispersion set and dispersion of a simple polynomial: + + >>> fp = poly((x - 3)*(x + 3), x) + >>> sorted(dispersionset(fp)) + [0, 6] + >>> dispersion(fp) + 6 + + Note that the definition of the dispersion is not symmetric: + + >>> fp = poly(x**4 - 3*x**2 + 1, x) + >>> gp = fp.shift(-3) + >>> sorted(dispersionset(fp, gp)) + [2, 3, 4] + >>> dispersion(fp, gp) + 4 + >>> sorted(dispersionset(gp, fp)) + [] + >>> dispersion(gp, fp) + -oo + + Computing the dispersion also works over field extensions: + + >>> from sympy import sqrt + >>> fp = poly(x**2 + sqrt(5)*x - 1, x, domain='QQ') + >>> gp = poly(x**2 + (2 + sqrt(5))*x + sqrt(5), x, domain='QQ') + >>> sorted(dispersionset(fp, gp)) + [2] + >>> sorted(dispersionset(gp, fp)) + [1, 4] + + We can even perform the computations for polynomials + having symbolic coefficients: + + >>> from sympy.abc import a + >>> fp = poly(4*x**4 + (4*a + 8)*x**3 + (a**2 + 6*a + 4)*x**2 + (a**2 + 2*a)*x, x) + >>> sorted(dispersionset(fp)) + [0, 1] + + See Also + ======== + + dispersion + + References + ========== + + 1. [ManWright94]_ + 2. [Koepf98]_ + 3. [Abramov71]_ + 4. [Man93]_ + """ + from sympy.polys.dispersion import dispersionset + return dispersionset(f, g) + + def dispersion(f, g=None): + r"""Compute the *dispersion* of polynomials. + + For two polynomials `f(x)` and `g(x)` with `\deg f > 0` + and `\deg g > 0` the dispersion `\operatorname{dis}(f, g)` is defined as: + + .. math:: + \operatorname{dis}(f, g) + & := \max\{ J(f,g) \cup \{0\} \} \\ + & = \max\{ \{a \in \mathbb{N} | \gcd(f(x), g(x+a)) \neq 1\} \cup \{0\} \} + + and for a single polynomial `\operatorname{dis}(f) := \operatorname{dis}(f, f)`. + + Examples + ======== + + >>> from sympy import poly + >>> from sympy.polys.dispersion import dispersion, dispersionset + >>> from sympy.abc import x + + Dispersion set and dispersion of a simple polynomial: + + >>> fp = poly((x - 3)*(x + 3), x) + >>> sorted(dispersionset(fp)) + [0, 6] + >>> dispersion(fp) + 6 + + Note that the definition of the dispersion is not symmetric: + + >>> fp = poly(x**4 - 3*x**2 + 1, x) + >>> gp = fp.shift(-3) + >>> sorted(dispersionset(fp, gp)) + [2, 3, 4] + >>> dispersion(fp, gp) + 4 + >>> sorted(dispersionset(gp, fp)) + [] + >>> dispersion(gp, fp) + -oo + + Computing the dispersion also works over field extensions: + + >>> from sympy import sqrt + >>> fp = poly(x**2 + sqrt(5)*x - 1, x, domain='QQ') + >>> gp = poly(x**2 + (2 + sqrt(5))*x + sqrt(5), x, domain='QQ') + >>> sorted(dispersionset(fp, gp)) + [2] + >>> sorted(dispersionset(gp, fp)) + [1, 4] + + We can even perform the computations for polynomials + having symbolic coefficients: + + >>> from sympy.abc import a + >>> fp = poly(4*x**4 + (4*a + 8)*x**3 + (a**2 + 6*a + 4)*x**2 + (a**2 + 2*a)*x, x) + >>> sorted(dispersionset(fp)) + [0, 1] + + See Also + ======== + + dispersionset + + References + ========== + + 1. [ManWright94]_ + 2. [Koepf98]_ + 3. [Abramov71]_ + 4. [Man93]_ + """ + from sympy.polys.dispersion import dispersion + return dispersion(f, g) + + def cofactors(f, g): + """ + Returns the GCD of ``f`` and ``g`` and their cofactors. + + Returns polynomials ``(h, cff, cfg)`` such that ``h = gcd(f, g)``, and + ``cff = quo(f, h)`` and ``cfg = quo(g, h)`` are, so called, cofactors + of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 1, x).cofactors(Poly(x**2 - 3*x + 2, x)) + (Poly(x - 1, x, domain='ZZ'), + Poly(x + 1, x, domain='ZZ'), + Poly(x - 2, x, domain='ZZ')) + + """ + _, per, F, G = f._unify(g) + + if hasattr(f.rep, 'cofactors'): + h, cff, cfg = F.cofactors(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'cofactors') + + return per(h), per(cff), per(cfg) + + def gcd(f, g): + """ + Returns the polynomial GCD of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 1, x).gcd(Poly(x**2 - 3*x + 2, x)) + Poly(x - 1, x, domain='ZZ') + + """ + _, per, F, G = f._unify(g) + + if hasattr(f.rep, 'gcd'): + result = F.gcd(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'gcd') + + return per(result) + + def lcm(f, g): + """ + Returns polynomial LCM of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 1, x).lcm(Poly(x**2 - 3*x + 2, x)) + Poly(x**3 - 2*x**2 - x + 2, x, domain='ZZ') + + """ + _, per, F, G = f._unify(g) + + if hasattr(f.rep, 'lcm'): + result = F.lcm(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'lcm') + + return per(result) + + def trunc(f, p): + """ + Reduce ``f`` modulo a constant ``p``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(2*x**3 + 3*x**2 + 5*x + 7, x).trunc(3) + Poly(-x**3 - x + 1, x, domain='ZZ') + + """ + p = f.rep.dom.convert(p) + + if hasattr(f.rep, 'trunc'): + result = f.rep.trunc(p) + else: # pragma: no cover + raise OperationNotSupported(f, 'trunc') + + return f.per(result) + + def monic(self, auto=True): + """ + Divides all coefficients by ``LC(f)``. + + Examples + ======== + + >>> from sympy import Poly, ZZ + >>> from sympy.abc import x + + >>> Poly(3*x**2 + 6*x + 9, x, domain=ZZ).monic() + Poly(x**2 + 2*x + 3, x, domain='QQ') + + >>> Poly(3*x**2 + 4*x + 2, x, domain=ZZ).monic() + Poly(x**2 + 4/3*x + 2/3, x, domain='QQ') + + """ + f = self + + if auto and f.rep.dom.is_Ring: + f = f.to_field() + + if hasattr(f.rep, 'monic'): + result = f.rep.monic() + else: # pragma: no cover + raise OperationNotSupported(f, 'monic') + + return f.per(result) + + def content(f): + """ + Returns the GCD of polynomial coefficients. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(6*x**2 + 8*x + 12, x).content() + 2 + + """ + if hasattr(f.rep, 'content'): + result = f.rep.content() + else: # pragma: no cover + raise OperationNotSupported(f, 'content') + + return f.rep.dom.to_sympy(result) + + def primitive(f): + """ + Returns the content and a primitive form of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(2*x**2 + 8*x + 12, x).primitive() + (2, Poly(x**2 + 4*x + 6, x, domain='ZZ')) + + """ + if hasattr(f.rep, 'primitive'): + cont, result = f.rep.primitive() + else: # pragma: no cover + raise OperationNotSupported(f, 'primitive') + + return f.rep.dom.to_sympy(cont), f.per(result) + + def compose(f, g): + """ + Computes the functional composition of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + x, x).compose(Poly(x - 1, x)) + Poly(x**2 - x, x, domain='ZZ') + + """ + _, per, F, G = f._unify(g) + + if hasattr(f.rep, 'compose'): + result = F.compose(G) + else: # pragma: no cover + raise OperationNotSupported(f, 'compose') + + return per(result) + + def decompose(f): + """ + Computes a functional decomposition of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**4 + 2*x**3 - x - 1, x, domain='ZZ').decompose() + [Poly(x**2 - x - 1, x, domain='ZZ'), Poly(x**2 + x, x, domain='ZZ')] + + """ + if hasattr(f.rep, 'decompose'): + result = f.rep.decompose() + else: # pragma: no cover + raise OperationNotSupported(f, 'decompose') + + return list(map(f.per, result)) + + def shift(f, a): + """ + Efficiently compute Taylor shift ``f(x + a)``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 2*x + 1, x).shift(2) + Poly(x**2 + 2*x + 1, x, domain='ZZ') + + See Also + ======== + + shift_list: Analogous method for multivariate polynomials. + """ + return f.per(f.rep.shift(a)) + + def shift_list(f, a): + """ + Efficiently compute Taylor shift ``f(X + A)``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x*y, [x,y]).shift_list([1, 2]) == Poly((x+1)*(y+2), [x,y]) + True + + See Also + ======== + + shift: Analogous method for univariate polynomials. + """ + return f.per(f.rep.shift_list(a)) + + def transform(f, p, q): + """ + Efficiently evaluate the functional transformation ``q**n * f(p/q)``. + + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 2*x + 1, x).transform(Poly(x + 1, x), Poly(x - 1, x)) + Poly(4, x, domain='ZZ') + + """ + P, Q = p.unify(q) + F, P = f.unify(P) + F, Q = F.unify(Q) + + if hasattr(F.rep, 'transform'): + result = F.rep.transform(P.rep, Q.rep) + else: # pragma: no cover + raise OperationNotSupported(F, 'transform') + + return F.per(result) + + def sturm(self, auto=True): + """ + Computes the Sturm sequence of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**3 - 2*x**2 + x - 3, x).sturm() + [Poly(x**3 - 2*x**2 + x - 3, x, domain='QQ'), + Poly(3*x**2 - 4*x + 1, x, domain='QQ'), + Poly(2/9*x + 25/9, x, domain='QQ'), + Poly(-2079/4, x, domain='QQ')] + + """ + f = self + + if auto and f.rep.dom.is_Ring: + f = f.to_field() + + if hasattr(f.rep, 'sturm'): + result = f.rep.sturm() + else: # pragma: no cover + raise OperationNotSupported(f, 'sturm') + + return list(map(f.per, result)) + + def gff_list(f): + """ + Computes greatest factorial factorization of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> f = x**5 + 2*x**4 - x**3 - 2*x**2 + + >>> Poly(f).gff_list() + [(Poly(x, x, domain='ZZ'), 1), (Poly(x + 2, x, domain='ZZ'), 4)] + + """ + if hasattr(f.rep, 'gff_list'): + result = f.rep.gff_list() + else: # pragma: no cover + raise OperationNotSupported(f, 'gff_list') + + return [(f.per(g), k) for g, k in result] + + def norm(f): + """ + Computes the product, ``Norm(f)``, of the conjugates of + a polynomial ``f`` defined over a number field ``K``. + + Examples + ======== + + >>> from sympy import Poly, sqrt + >>> from sympy.abc import x + + >>> a, b = sqrt(2), sqrt(3) + + A polynomial over a quadratic extension. + Two conjugates x - a and x + a. + + >>> f = Poly(x - a, x, extension=a) + >>> f.norm() + Poly(x**2 - 2, x, domain='QQ') + + A polynomial over a quartic extension. + Four conjugates x - a, x - a, x + a and x + a. + + >>> f = Poly(x - a, x, extension=(a, b)) + >>> f.norm() + Poly(x**4 - 4*x**2 + 4, x, domain='QQ') + + """ + if hasattr(f.rep, 'norm'): + r = f.rep.norm() + else: # pragma: no cover + raise OperationNotSupported(f, 'norm') + + return f.per(r) + + def sqf_norm(f): + """ + Computes square-free norm of ``f``. + + Returns ``s``, ``f``, ``r``, such that ``g(x) = f(x-sa)`` and + ``r(x) = Norm(g(x))`` is a square-free polynomial over ``K``, + where ``a`` is the algebraic extension of the ground domain. + + Examples + ======== + + >>> from sympy import Poly, sqrt + >>> from sympy.abc import x + + >>> s, f, r = Poly(x**2 + 1, x, extension=[sqrt(3)]).sqf_norm() + + >>> s + [1] + >>> f + Poly(x**2 - 2*sqrt(3)*x + 4, x, domain='QQ') + >>> r + Poly(x**4 - 4*x**2 + 16, x, domain='QQ') + + """ + if hasattr(f.rep, 'sqf_norm'): + s, g, r = f.rep.sqf_norm() + else: # pragma: no cover + raise OperationNotSupported(f, 'sqf_norm') + + return s, f.per(g), f.per(r) + + def sqf_part(f): + """ + Computes square-free part of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**3 - 3*x - 2, x).sqf_part() + Poly(x**2 - x - 2, x, domain='ZZ') + + """ + if hasattr(f.rep, 'sqf_part'): + result = f.rep.sqf_part() + else: # pragma: no cover + raise OperationNotSupported(f, 'sqf_part') + + return f.per(result) + + def sqf_list(f, all=False): + """ + Returns a list of square-free factors of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> f = 2*x**5 + 16*x**4 + 50*x**3 + 76*x**2 + 56*x + 16 + + >>> Poly(f).sqf_list() + (2, [(Poly(x + 1, x, domain='ZZ'), 2), + (Poly(x + 2, x, domain='ZZ'), 3)]) + + >>> Poly(f).sqf_list(all=True) + (2, [(Poly(1, x, domain='ZZ'), 1), + (Poly(x + 1, x, domain='ZZ'), 2), + (Poly(x + 2, x, domain='ZZ'), 3)]) + + """ + if hasattr(f.rep, 'sqf_list'): + coeff, factors = f.rep.sqf_list(all) + else: # pragma: no cover + raise OperationNotSupported(f, 'sqf_list') + + return f.rep.dom.to_sympy(coeff), [(f.per(g), k) for g, k in factors] + + def sqf_list_include(f, all=False): + """ + Returns a list of square-free factors of ``f``. + + Examples + ======== + + >>> from sympy import Poly, expand + >>> from sympy.abc import x + + >>> f = expand(2*(x + 1)**3*x**4) + >>> f + 2*x**7 + 6*x**6 + 6*x**5 + 2*x**4 + + >>> Poly(f).sqf_list_include() + [(Poly(2, x, domain='ZZ'), 1), + (Poly(x + 1, x, domain='ZZ'), 3), + (Poly(x, x, domain='ZZ'), 4)] + + >>> Poly(f).sqf_list_include(all=True) + [(Poly(2, x, domain='ZZ'), 1), + (Poly(1, x, domain='ZZ'), 2), + (Poly(x + 1, x, domain='ZZ'), 3), + (Poly(x, x, domain='ZZ'), 4)] + + """ + if hasattr(f.rep, 'sqf_list_include'): + factors = f.rep.sqf_list_include(all) + else: # pragma: no cover + raise OperationNotSupported(f, 'sqf_list_include') + + return [(f.per(g), k) for g, k in factors] + + def factor_list(f) -> tuple[Expr, list[tuple[Poly, int]]]: + """ + Returns a list of irreducible factors of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> f = 2*x**5 + 2*x**4*y + 4*x**3 + 4*x**2*y + 2*x + 2*y + + >>> Poly(f).factor_list() + (2, [(Poly(x + y, x, y, domain='ZZ'), 1), + (Poly(x**2 + 1, x, y, domain='ZZ'), 2)]) + + """ + if hasattr(f.rep, 'factor_list'): + try: + coeff, factors = f.rep.factor_list() + except DomainError: + if f.degree() == 0: + return f.as_expr(), [] + else: + return S.One, [(f, 1)] + else: # pragma: no cover + raise OperationNotSupported(f, 'factor_list') + + return f.rep.dom.to_sympy(coeff), [(f.per(g), k) for g, k in factors] + + def factor_list_include(f): + """ + Returns a list of irreducible factors of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> f = 2*x**5 + 2*x**4*y + 4*x**3 + 4*x**2*y + 2*x + 2*y + + >>> Poly(f).factor_list_include() + [(Poly(2*x + 2*y, x, y, domain='ZZ'), 1), + (Poly(x**2 + 1, x, y, domain='ZZ'), 2)] + + """ + if hasattr(f.rep, 'factor_list_include'): + try: + factors = f.rep.factor_list_include() + except DomainError: + return [(f, 1)] + else: # pragma: no cover + raise OperationNotSupported(f, 'factor_list_include') + + return [(f.per(g), k) for g, k in factors] + + def intervals(f, all=False, eps=None, inf=None, sup=None, fast=False, sqf=False): + """ + Compute isolating intervals for roots of ``f``. + + For real roots the Vincent-Akritas-Strzebonski (VAS) continued fractions method is used. + + References + ========== + .. [#] Alkiviadis G. Akritas and Adam W. Strzebonski: A Comparative Study of Two Real Root + Isolation Methods . Nonlinear Analysis: Modelling and Control, Vol. 10, No. 4, 297-304, 2005. + .. [#] Alkiviadis G. Akritas, Adam W. Strzebonski and Panagiotis S. Vigklas: Improving the + Performance of the Continued Fractions Method Using new Bounds of Positive Roots. Nonlinear + Analysis: Modelling and Control, Vol. 13, No. 3, 265-279, 2008. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 3, x).intervals() + [((-2, -1), 1), ((1, 2), 1)] + >>> Poly(x**2 - 3, x).intervals(eps=1e-2) + [((-26/15, -19/11), 1), ((19/11, 26/15), 1)] + + """ + if eps is not None: + eps = QQ.convert(eps) + + if eps <= 0: + raise ValueError("'eps' must be a positive rational") + + if inf is not None: + inf = QQ.convert(inf) + if sup is not None: + sup = QQ.convert(sup) + + if hasattr(f.rep, 'intervals'): + result = f.rep.intervals( + all=all, eps=eps, inf=inf, sup=sup, fast=fast, sqf=sqf) + else: # pragma: no cover + raise OperationNotSupported(f, 'intervals') + + if sqf: + def _real(interval): + s, t = interval + return (QQ.to_sympy(s), QQ.to_sympy(t)) + + if not all: + return list(map(_real, result)) + + def _complex(rectangle): + (u, v), (s, t) = rectangle + return (QQ.to_sympy(u) + I*QQ.to_sympy(v), + QQ.to_sympy(s) + I*QQ.to_sympy(t)) + + real_part, complex_part = result + + return list(map(_real, real_part)), list(map(_complex, complex_part)) + else: + def _real(interval): + (s, t), k = interval + return ((QQ.to_sympy(s), QQ.to_sympy(t)), k) + + if not all: + return list(map(_real, result)) + + def _complex(rectangle): + ((u, v), (s, t)), k = rectangle + return ((QQ.to_sympy(u) + I*QQ.to_sympy(v), + QQ.to_sympy(s) + I*QQ.to_sympy(t)), k) + + real_part, complex_part = result + + return list(map(_real, real_part)), list(map(_complex, complex_part)) + + def refine_root(f, s, t, eps=None, steps=None, fast=False, check_sqf=False): + """ + Refine an isolating interval of a root to the given precision. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 3, x).refine_root(1, 2, eps=1e-2) + (19/11, 26/15) + + """ + if check_sqf and not f.is_sqf: + raise PolynomialError("only square-free polynomials supported") + + s, t = QQ.convert(s), QQ.convert(t) + + if eps is not None: + eps = QQ.convert(eps) + + if eps <= 0: + raise ValueError("'eps' must be a positive rational") + + if steps is not None: + steps = int(steps) + elif eps is None: + steps = 1 + + if hasattr(f.rep, 'refine_root'): + S, T = f.rep.refine_root(s, t, eps=eps, steps=steps, fast=fast) + else: # pragma: no cover + raise OperationNotSupported(f, 'refine_root') + + return QQ.to_sympy(S), QQ.to_sympy(T) + + def count_roots(f, inf=None, sup=None): + """ + Return the number of roots of ``f`` in ``[inf, sup]`` interval. + + Examples + ======== + + >>> from sympy import Poly, I + >>> from sympy.abc import x + + >>> Poly(x**4 - 4, x).count_roots(-3, 3) + 2 + >>> Poly(x**4 - 4, x).count_roots(0, 1 + 3*I) + 1 + + """ + inf_real, sup_real = True, True + + if inf is not None: + inf = sympify(inf) + + if inf is S.NegativeInfinity: + inf = None + else: + re, im = inf.as_real_imag() + + if not im: + inf = QQ.convert(inf) + else: + inf, inf_real = list(map(QQ.convert, (re, im))), False + + if sup is not None: + sup = sympify(sup) + + if sup is S.Infinity: + sup = None + else: + re, im = sup.as_real_imag() + + if not im: + sup = QQ.convert(sup) + else: + sup, sup_real = list(map(QQ.convert, (re, im))), False + + if inf_real and sup_real: + if hasattr(f.rep, 'count_real_roots'): + count = f.rep.count_real_roots(inf=inf, sup=sup) + else: # pragma: no cover + raise OperationNotSupported(f, 'count_real_roots') + else: + if inf_real and inf is not None: + inf = (inf, QQ.zero) + + if sup_real and sup is not None: + sup = (sup, QQ.zero) + + if hasattr(f.rep, 'count_complex_roots'): + count = f.rep.count_complex_roots(inf=inf, sup=sup) + else: # pragma: no cover + raise OperationNotSupported(f, 'count_complex_roots') + + return Integer(count) + + def root(f, index, radicals=True): + """ + Get an indexed root of a polynomial. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> f = Poly(2*x**3 - 7*x**2 + 4*x + 4) + + >>> f.root(0) + -1/2 + >>> f.root(1) + 2 + >>> f.root(2) + 2 + >>> f.root(3) + Traceback (most recent call last): + ... + IndexError: root index out of [-3, 2] range, got 3 + + >>> Poly(x**5 + x + 1).root(0) + CRootOf(x**3 - x**2 + 1, 0) + + """ + return sympy.polys.rootoftools.rootof(f, index, radicals=radicals) + + def real_roots(f, multiple=True, radicals=True): + """ + Return a list of real roots with multiplicities. + + See :func:`real_roots` for more explanation. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(2*x**3 - 7*x**2 + 4*x + 4).real_roots() + [-1/2, 2, 2] + >>> Poly(x**3 + x + 1).real_roots() + [CRootOf(x**3 + x + 1, 0)] + """ + reals = sympy.polys.rootoftools.CRootOf.real_roots(f, radicals=radicals) + + if multiple: + return reals + else: + return group(reals, multiple=False) + + def all_roots(f, multiple=True, radicals=True): + """ + Return a list of real and complex roots with multiplicities. + + See :func:`all_roots` for more explanation. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(2*x**3 - 7*x**2 + 4*x + 4).all_roots() + [-1/2, 2, 2] + >>> Poly(x**3 + x + 1).all_roots() + [CRootOf(x**3 + x + 1, 0), + CRootOf(x**3 + x + 1, 1), + CRootOf(x**3 + x + 1, 2)] + + """ + roots = sympy.polys.rootoftools.CRootOf.all_roots(f, radicals=radicals) + + if multiple: + return roots + else: + return group(roots, multiple=False) + + def nroots(f, n=15, maxsteps=50, cleanup=True): + """ + Compute numerical approximations of roots of ``f``. + + Parameters + ========== + + n ... the number of digits to calculate + maxsteps ... the maximum number of iterations to do + + If the accuracy `n` cannot be reached in `maxsteps`, it will raise an + exception. You need to rerun with higher maxsteps. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 3).nroots(n=15) + [-1.73205080756888, 1.73205080756888] + >>> Poly(x**2 - 3).nroots(n=30) + [-1.73205080756887729352744634151, 1.73205080756887729352744634151] + + """ + if f.is_multivariate: + raise MultivariatePolynomialError( + "Cannot compute numerical roots of %s" % f) + + if f.degree() <= 0: + return [] + + # For integer and rational coefficients, convert them to integers only + # (for accuracy). Otherwise just try to convert the coefficients to + # mpmath.mpc and raise an exception if the conversion fails. + if f.rep.dom is ZZ: + coeffs = [int(coeff) for coeff in f.all_coeffs()] + elif f.rep.dom is QQ: + denoms = [coeff.q for coeff in f.all_coeffs()] + fac = ilcm(*denoms) + coeffs = [int(coeff*fac) for coeff in f.all_coeffs()] + else: + coeffs = [coeff.evalf(n=n).as_real_imag() + for coeff in f.all_coeffs()] + with mpmath.workdps(n): + try: + coeffs = [mpmath.mpc(*coeff) for coeff in coeffs] + except TypeError: + raise DomainError("Numerical domain expected, got %s" % \ + f.rep.dom) + + dps = mpmath.mp.dps + mpmath.mp.dps = n + + from sympy.functions.elementary.complexes import sign + try: + # We need to add extra precision to guard against losing accuracy. + # 10 times the degree of the polynomial seems to work well. + roots = mpmath.polyroots(coeffs, maxsteps=maxsteps, + cleanup=cleanup, error=False, extraprec=f.degree()*10) + + # Mpmath puts real roots first, then complex ones (as does all_roots) + # so we make sure this convention holds here, too. + roots = list(map(sympify, + sorted(roots, key=lambda r: (1 if r.imag else 0, r.real, abs(r.imag), sign(r.imag))))) + except NoConvergence: + try: + # If roots did not converge try again with more extra precision. + roots = mpmath.polyroots(coeffs, maxsteps=maxsteps, + cleanup=cleanup, error=False, extraprec=f.degree()*15) + roots = list(map(sympify, + sorted(roots, key=lambda r: (1 if r.imag else 0, r.real, abs(r.imag), sign(r.imag))))) + except NoConvergence: + raise NoConvergence( + 'convergence to root failed; try n < %s or maxsteps > %s' % ( + n, maxsteps)) + finally: + mpmath.mp.dps = dps + + return roots + + def ground_roots(f): + """ + Compute roots of ``f`` by factorization in the ground domain. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**6 - 4*x**4 + 4*x**3 - x**2).ground_roots() + {0: 2, 1: 2} + + """ + if f.is_multivariate: + raise MultivariatePolynomialError( + "Cannot compute ground roots of %s" % f) + + roots = {} + + for factor, k in f.factor_list()[1]: + if factor.is_linear: + a, b = factor.all_coeffs() + roots[-b/a] = k + + return roots + + def nth_power_roots_poly(f, n): + """ + Construct a polynomial with n-th powers of roots of ``f``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> f = Poly(x**4 - x**2 + 1) + + >>> f.nth_power_roots_poly(2) + Poly(x**4 - 2*x**3 + 3*x**2 - 2*x + 1, x, domain='ZZ') + >>> f.nth_power_roots_poly(3) + Poly(x**4 + 2*x**2 + 1, x, domain='ZZ') + >>> f.nth_power_roots_poly(4) + Poly(x**4 + 2*x**3 + 3*x**2 + 2*x + 1, x, domain='ZZ') + >>> f.nth_power_roots_poly(12) + Poly(x**4 - 4*x**3 + 6*x**2 - 4*x + 1, x, domain='ZZ') + + """ + if f.is_multivariate: + raise MultivariatePolynomialError( + "must be a univariate polynomial") + + N = sympify(n) + + if N.is_Integer and N >= 1: + n = int(N) + else: + raise ValueError("'n' must an integer and n >= 1, got %s" % n) + + x = f.gen + t = Dummy('t') + + r = f.resultant(f.__class__.from_expr(x**n - t, x, t)) + + return r.replace(t, x) + + def which_real_roots(f, candidates): + """ + Find roots of a square-free polynomial ``f`` from ``candidates``. + + Explanation + =========== + + If ``f`` is a square-free polynomial and ``candidates`` is a superset + of the roots of ``f``, then ``f.which_real_roots(candidates)`` returns a + list containing exactly the set of roots of ``f``. The domain must be + :ref:`ZZ`, :ref:`QQ`, or :ref:`QQ(a)` and``f`` must be univariate and + square-free. + + The list ``candidates`` must be a superset of the real roots of ``f`` + and ``f.which_real_roots(candidates)`` returns the set of real roots + of ``f``. The output preserves the order of the order of ``candidates``. + + Examples + ======== + + >>> from sympy import Poly, sqrt + >>> from sympy.abc import x + + >>> f = Poly(x**4 - 1) + >>> f.which_real_roots([-1, 1, 0, -2, 2]) + [-1, 1] + >>> f.which_real_roots([-1, 1, 1, 1, 1]) + [-1, 1] + + This method is useful as lifting to rational coefficients + produced extraneous roots, which we can filter out with + this method. + + >>> f = Poly(sqrt(2)*x**3 + x**2 - 1, x, extension=True) + >>> f.lift() + Poly(-2*x**6 + x**4 - 2*x**2 + 1, x, domain='QQ') + >>> f.lift().real_roots() + [-sqrt(2)/2, sqrt(2)/2] + >>> f.which_real_roots(f.lift().real_roots()) + [sqrt(2)/2] + + This procedure is already done internally when calling + `.real_roots()` on a polynomial with algebraic coefficients. + + >>> f.real_roots() + [sqrt(2)/2] + + See Also + ======== + + same_root + which_all_roots + """ + if f.is_multivariate: + raise MultivariatePolynomialError( + "Must be a univariate polynomial") + + dom = f.get_domain() + + if not (dom.is_ZZ or dom.is_QQ or dom.is_AlgebraicField): + raise NotImplementedError( + "root counting not supported over %s" % dom) + + return f._which_roots(candidates, f.count_roots()) + + def which_all_roots(f, candidates): + """ + Find roots of a square-free polynomial ``f`` from ``candidates``. + + Explanation + =========== + + If ``f`` is a square-free polynomial and ``candidates`` is a superset + of the roots of ``f``, then ``f.which_all_roots(candidates)`` returns a + list containing exactly the set of roots of ``f``. The polynomial``f`` + must be univariate and square-free. + + The list ``candidates`` must be a superset of the complex roots of + ``f`` and ``f.which_all_roots(candidates)`` returns exactly the + set of all complex roots of ``f``. The output preserves the order of + the order of ``candidates``. + + Examples + ======== + + >>> from sympy import Poly, I + >>> from sympy.abc import x + + >>> f = Poly(x**4 - 1) + >>> f.which_all_roots([-1, 1, -I, I, 0]) + [-1, 1, -I, I] + >>> f.which_all_roots([-1, 1, -I, I, I, I]) + [-1, 1, -I, I] + + This method is useful as lifting to rational coefficients + produced extraneous roots, which we can filter out with + this method. + + >>> f = Poly(x**2 + I*x - 1, x, extension=True) + >>> f.lift() + Poly(x**4 - x**2 + 1, x, domain='ZZ') + >>> f.lift().all_roots() + [CRootOf(x**4 - x**2 + 1, 0), + CRootOf(x**4 - x**2 + 1, 1), + CRootOf(x**4 - x**2 + 1, 2), + CRootOf(x**4 - x**2 + 1, 3)] + >>> f.which_all_roots(f.lift().all_roots()) + [CRootOf(x**4 - x**2 + 1, 0), CRootOf(x**4 - x**2 + 1, 2)] + + This procedure is already done internally when calling + `.all_roots()` on a polynomial with algebraic coefficients, + or polynomials with Gaussian domains. + + >>> f.all_roots() + [CRootOf(x**4 - x**2 + 1, 0), CRootOf(x**4 - x**2 + 1, 2)] + + See Also + ======== + + same_root + which_real_roots + """ + if f.is_multivariate: + raise MultivariatePolynomialError( + "Must be a univariate polynomial") + + return f._which_roots(candidates, f.degree()) + + def _which_roots(f, candidates, num_roots): + prec = 10 + # using Counter bc its like an ordered set + root_counts = Counter(candidates) + + while len(root_counts) > num_roots: + for r in list(root_counts.keys()): + # If f(r) != 0 then f(r).evalf() gives a float/complex with precision. + f_r = f(r).evalf(prec, maxn=2*prec) + if abs(f_r)._prec >= 2: + root_counts.pop(r) + + prec *= 2 + + return list(root_counts.keys()) + + def same_root(f, a, b): + """ + Decide whether two roots of this polynomial are equal. + + Examples + ======== + + >>> from sympy import Poly, cyclotomic_poly, exp, I, pi + >>> f = Poly(cyclotomic_poly(5)) + >>> r0 = exp(2*I*pi/5) + >>> indices = [i for i, r in enumerate(f.all_roots()) if f.same_root(r, r0)] + >>> print(indices) + [3] + + Raises + ====== + + DomainError + If the domain of the polynomial is not :ref:`ZZ`, :ref:`QQ`, + :ref:`RR`, or :ref:`CC`. + MultivariatePolynomialError + If the polynomial is not univariate. + PolynomialError + If the polynomial is of degree < 2. + + See Also + ======== + + which_real_roots + which_all_roots + """ + if f.is_multivariate: + raise MultivariatePolynomialError( + "Must be a univariate polynomial") + + dom_delta_sq = f.rep.mignotte_sep_bound_squared() + delta_sq = f.domain.get_field().to_sympy(dom_delta_sq) + # We have delta_sq = delta**2, where delta is a lower bound on the + # minimum separation between any two roots of this polynomial. + # Let eps = delta/3, and define eps_sq = eps**2 = delta**2/9. + eps_sq = delta_sq / 9 + + r, _, _, _ = evalf(1/eps_sq, 1, {}) + n = fastlog(r) + # Then 2^n > 1/eps**2. + m = (n // 2) + (n % 2) + # Then 2^(-m) < eps. + ev = lambda x: quad_to_mpmath(_evalf_with_bounded_error(x, m=m)) + + # Then for any complex numbers a, b we will have + # |a - ev(a)| < eps and |b - ev(b)| < eps. + # So if |ev(a) - ev(b)|**2 < eps**2, then + # |ev(a) - ev(b)| < eps, hence |a - b| < 3*eps = delta. + A, B = ev(a), ev(b) + return (A.real - B.real)**2 + (A.imag - B.imag)**2 < eps_sq + + def cancel(f, g, include=False): + """ + Cancel common factors in a rational function ``f/g``. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(2*x**2 - 2, x).cancel(Poly(x**2 - 2*x + 1, x)) + (1, Poly(2*x + 2, x, domain='ZZ'), Poly(x - 1, x, domain='ZZ')) + + >>> Poly(2*x**2 - 2, x).cancel(Poly(x**2 - 2*x + 1, x), include=True) + (Poly(2*x + 2, x, domain='ZZ'), Poly(x - 1, x, domain='ZZ')) + + """ + dom, per, F, G = f._unify(g) + + if hasattr(F, 'cancel'): + result = F.cancel(G, include=include) + else: # pragma: no cover + raise OperationNotSupported(f, 'cancel') + + if not include: + if dom.has_assoc_Ring: + dom = dom.get_ring() + + cp, cq, p, q = result + + cp = dom.to_sympy(cp) + cq = dom.to_sympy(cq) + + return cp/cq, per(p), per(q) + else: + return tuple(map(per, result)) + + def make_monic_over_integers_by_scaling_roots(f): + """ + Turn any univariate polynomial over :ref:`QQ` or :ref:`ZZ` into a monic + polynomial over :ref:`ZZ`, by scaling the roots as necessary. + + Explanation + =========== + + This operation can be performed whether or not *f* is irreducible; when + it is, this can be understood as determining an algebraic integer + generating the same field as a root of *f*. + + Examples + ======== + + >>> from sympy import Poly, S + >>> from sympy.abc import x + >>> f = Poly(x**2/2 + S(1)/4 * x + S(1)/8, x, domain='QQ') + >>> f.make_monic_over_integers_by_scaling_roots() + (Poly(x**2 + 2*x + 4, x, domain='ZZ'), 4) + + Returns + ======= + + Pair ``(g, c)`` + g is the polynomial + + c is the integer by which the roots had to be scaled + + """ + if not f.is_univariate or f.domain not in [ZZ, QQ]: + raise ValueError('Polynomial must be univariate over ZZ or QQ.') + if f.is_monic and f.domain == ZZ: + return f, ZZ.one + else: + fm = f.monic() + c, _ = fm.clear_denoms() + return fm.transform(Poly(fm.gen), c).to_ring(), c + + def galois_group(f, by_name=False, max_tries=30, randomize=False): + """ + Compute the Galois group of this polynomial. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + >>> f = Poly(x**4 - 2) + >>> G, _ = f.galois_group(by_name=True) + >>> print(G) + S4TransitiveSubgroups.D4 + + See Also + ======== + + sympy.polys.numberfields.galoisgroups.galois_group + + """ + from sympy.polys.numberfields.galoisgroups import ( + _galois_group_degree_3, _galois_group_degree_4_lookup, + _galois_group_degree_5_lookup_ext_factor, + _galois_group_degree_6_lookup, + ) + if (not f.is_univariate + or not f.is_irreducible + or f.domain not in [ZZ, QQ] + ): + raise ValueError('Polynomial must be irreducible and univariate over ZZ or QQ.') + gg = { + 3: _galois_group_degree_3, + 4: _galois_group_degree_4_lookup, + 5: _galois_group_degree_5_lookup_ext_factor, + 6: _galois_group_degree_6_lookup, + } + max_supported = max(gg.keys()) + n = f.degree() + if n > max_supported: + raise ValueError(f"Only polynomials up to degree {max_supported} are supported.") + elif n < 1: + raise ValueError("Constant polynomial has no Galois group.") + elif n == 1: + from sympy.combinatorics.galois import S1TransitiveSubgroups + name, alt = S1TransitiveSubgroups.S1, True + elif n == 2: + from sympy.combinatorics.galois import S2TransitiveSubgroups + name, alt = S2TransitiveSubgroups.S2, False + else: + g, _ = f.make_monic_over_integers_by_scaling_roots() + name, alt = gg[n](g, max_tries=max_tries, randomize=randomize) + G = name if by_name else name.get_perm_group() + return G, alt + + @property + def is_zero(f): + """ + Returns ``True`` if ``f`` is a zero polynomial. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(0, x).is_zero + True + >>> Poly(1, x).is_zero + False + + """ + return f.rep.is_zero + + @property + def is_one(f): + """ + Returns ``True`` if ``f`` is a unit polynomial. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(0, x).is_one + False + >>> Poly(1, x).is_one + True + + """ + return f.rep.is_one + + @property + def is_sqf(f): + """ + Returns ``True`` if ``f`` is a square-free polynomial. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 - 2*x + 1, x).is_sqf + False + >>> Poly(x**2 - 1, x).is_sqf + True + + """ + return f.rep.is_sqf + + @property + def is_monic(f): + """ + Returns ``True`` if the leading coefficient of ``f`` is one. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x + 2, x).is_monic + True + >>> Poly(2*x + 2, x).is_monic + False + + """ + return f.rep.is_monic + + @property + def is_primitive(f): + """ + Returns ``True`` if GCD of the coefficients of ``f`` is one. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(2*x**2 + 6*x + 12, x).is_primitive + False + >>> Poly(x**2 + 3*x + 6, x).is_primitive + True + + """ + return f.rep.is_primitive + + @property + def is_ground(f): + """ + Returns ``True`` if ``f`` is an element of the ground domain. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x, x).is_ground + False + >>> Poly(2, x).is_ground + True + >>> Poly(y, x).is_ground + True + + """ + return f.rep.is_ground + + @property + def is_linear(f): + """ + Returns ``True`` if ``f`` is linear in all its variables. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x + y + 2, x, y).is_linear + True + >>> Poly(x*y + 2, x, y).is_linear + False + + """ + return f.rep.is_linear + + @property + def is_quadratic(f): + """ + Returns ``True`` if ``f`` is quadratic in all its variables. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x*y + 2, x, y).is_quadratic + True + >>> Poly(x*y**2 + 2, x, y).is_quadratic + False + + """ + return f.rep.is_quadratic + + @property + def is_monomial(f): + """ + Returns ``True`` if ``f`` is zero or has only one term. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(3*x**2, x).is_monomial + True + >>> Poly(3*x**2 + 1, x).is_monomial + False + + """ + return f.rep.is_monomial + + @property + def is_homogeneous(f): + """ + Returns ``True`` if ``f`` is a homogeneous polynomial. + + A homogeneous polynomial is a polynomial whose all monomials with + non-zero coefficients have the same total degree. If you want not + only to check if a polynomial is homogeneous but also compute its + homogeneous order, then use :func:`Poly.homogeneous_order`. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + x*y, x, y).is_homogeneous + True + >>> Poly(x**3 + x*y, x, y).is_homogeneous + False + + """ + return f.rep.is_homogeneous + + @property + def is_irreducible(f): + """ + Returns ``True`` if ``f`` has no factors over its domain. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> Poly(x**2 + x + 1, x, modulus=2).is_irreducible + True + >>> Poly(x**2 + 1, x, modulus=2).is_irreducible + False + + """ + return f.rep.is_irreducible + + @property + def is_univariate(f): + """ + Returns ``True`` if ``f`` is a univariate polynomial. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + x + 1, x).is_univariate + True + >>> Poly(x*y**2 + x*y + 1, x, y).is_univariate + False + >>> Poly(x*y**2 + x*y + 1, x).is_univariate + True + >>> Poly(x**2 + x + 1, x, y).is_univariate + False + + """ + return len(f.gens) == 1 + + @property + def is_multivariate(f): + """ + Returns ``True`` if ``f`` is a multivariate polynomial. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x, y + + >>> Poly(x**2 + x + 1, x).is_multivariate + False + >>> Poly(x*y**2 + x*y + 1, x, y).is_multivariate + True + >>> Poly(x*y**2 + x*y + 1, x).is_multivariate + False + >>> Poly(x**2 + x + 1, x, y).is_multivariate + True + + """ + return len(f.gens) != 1 + + @property + def is_cyclotomic(f): + """ + Returns ``True`` if ``f`` is a cyclotomic polnomial. + + Examples + ======== + + >>> from sympy import Poly + >>> from sympy.abc import x + + >>> f = x**16 + x**14 - x**10 + x**8 - x**6 + x**2 + 1 + + >>> Poly(f).is_cyclotomic + False + + >>> g = x**16 + x**14 - x**10 - x**8 - x**6 + x**2 + 1 + + >>> Poly(g).is_cyclotomic + True + + """ + return f.rep.is_cyclotomic + + def __abs__(f): + return f.abs() + + def __neg__(f): + return f.neg() + + @_polifyit + def __add__(f, g): + return f.add(g) + + @_polifyit + def __radd__(f, g): + return g.add(f) + + @_polifyit + def __sub__(f, g): + return f.sub(g) + + @_polifyit + def __rsub__(f, g): + return g.sub(f) + + @_polifyit + def __mul__(f, g): + return f.mul(g) + + @_polifyit + def __rmul__(f, g): + return g.mul(f) + + @_sympifyit('n', NotImplemented) + def __pow__(f, n): + if n.is_Integer and n >= 0: + return f.pow(n) + else: + return NotImplemented + + @_polifyit + def __divmod__(f, g): + return f.div(g) + + @_polifyit + def __rdivmod__(f, g): + return g.div(f) + + @_polifyit + def __mod__(f, g): + return f.rem(g) + + @_polifyit + def __rmod__(f, g): + return g.rem(f) + + @_polifyit + def __floordiv__(f, g): + return f.quo(g) + + @_polifyit + def __rfloordiv__(f, g): + return g.quo(f) + + @_sympifyit('g', NotImplemented) + def __truediv__(f, g): + return f.as_expr()/g.as_expr() + + @_sympifyit('g', NotImplemented) + def __rtruediv__(f, g): + return g.as_expr()/f.as_expr() + + @_sympifyit('other', NotImplemented) + def __eq__(self, other): + f, g = self, other + + if not g.is_Poly: + try: + g = f.__class__(g, f.gens, domain=f.get_domain()) + except (PolynomialError, DomainError, CoercionFailed): + return False + + if f.gens != g.gens: + return False + + if f.rep.dom != g.rep.dom: + return False + + return f.rep == g.rep + + @_sympifyit('g', NotImplemented) + def __ne__(f, g): + return not f == g + + def __bool__(f): + return not f.is_zero + + def eq(f, g, strict=False): + if not strict: + return f == g + else: + return f._strict_eq(sympify(g)) + + def ne(f, g, strict=False): + return not f.eq(g, strict=strict) + + def _strict_eq(f, g): + return isinstance(g, f.__class__) and f.gens == g.gens and f.rep.eq(g.rep, strict=True) + + +@public +class PurePoly(Poly): + """Class for representing pure polynomials. """ + + def _hashable_content(self): + """Allow SymPy to hash Poly instances. """ + return (self.rep,) + + def __hash__(self): + return super().__hash__() + + @property + def free_symbols(self): + """ + Free symbols of a polynomial. + + Examples + ======== + + >>> from sympy import PurePoly + >>> from sympy.abc import x, y + + >>> PurePoly(x**2 + 1).free_symbols + set() + >>> PurePoly(x**2 + y).free_symbols + set() + >>> PurePoly(x**2 + y, x).free_symbols + {y} + + """ + return self.free_symbols_in_domain + + @_sympifyit('other', NotImplemented) + def __eq__(self, other): + f, g = self, other + + if not g.is_Poly: + try: + g = f.__class__(g, f.gens, domain=f.get_domain()) + except (PolynomialError, DomainError, CoercionFailed): + return False + + if len(f.gens) != len(g.gens): + return False + + if f.rep.dom != g.rep.dom: + try: + dom = f.rep.dom.unify(g.rep.dom, f.gens) + except UnificationFailed: + return False + + f = f.set_domain(dom) + g = g.set_domain(dom) + + return f.rep == g.rep + + def _strict_eq(f, g): + return isinstance(g, f.__class__) and f.rep.eq(g.rep, strict=True) + + def _unify(f, g): + g = sympify(g) + + if not g.is_Poly: + try: + return f.rep.dom, f.per, f.rep, f.rep.per(f.rep.dom.from_sympy(g)) + except CoercionFailed: + raise UnificationFailed("Cannot unify %s with %s" % (f, g)) + + if len(f.gens) != len(g.gens): + raise UnificationFailed("Cannot unify %s with %s" % (f, g)) + + if not (isinstance(f.rep, DMP) and isinstance(g.rep, DMP)): + raise UnificationFailed("Cannot unify %s with %s" % (f, g)) + + cls = f.__class__ + gens = f.gens + + dom = f.rep.dom.unify(g.rep.dom, gens) + + F = f.rep.convert(dom) + G = g.rep.convert(dom) + + def per(rep, dom=dom, gens=gens, remove=None): + if remove is not None: + gens = gens[:remove] + gens[remove + 1:] + + if not gens: + return dom.to_sympy(rep) + + return cls.new(rep, *gens) + + return dom, per, F, G + + +@public +def poly_from_expr(expr, *gens, **args): + """Construct a polynomial from an expression. """ + opt = options.build_options(gens, args) + return _poly_from_expr(expr, opt) + + +def _poly_from_expr(expr, opt): + """Construct a polynomial from an expression. """ + orig, expr = expr, sympify(expr) + + if not isinstance(expr, Basic): + raise PolificationFailed(opt, orig, expr) + elif expr.is_Poly: + poly = expr.__class__._from_poly(expr, opt) + + opt.gens = poly.gens + opt.domain = poly.domain + + if opt.polys is None: + opt.polys = True + + return poly, opt + elif opt.expand: + expr = expr.expand() + + rep, opt = _dict_from_expr(expr, opt) + if not opt.gens: + raise PolificationFailed(opt, orig, expr) + + monoms, coeffs = list(zip(*list(rep.items()))) + domain = opt.domain + + if domain is None: + opt.domain, coeffs = construct_domain(coeffs, opt=opt) + else: + coeffs = list(map(domain.from_sympy, coeffs)) + + rep = dict(list(zip(monoms, coeffs))) + poly = Poly._from_dict(rep, opt) + + if opt.polys is None: + opt.polys = False + + return poly, opt + + +@public +def parallel_poly_from_expr(exprs, *gens, **args): + """Construct polynomials from expressions. """ + opt = options.build_options(gens, args) + return _parallel_poly_from_expr(exprs, opt) + + +def _parallel_poly_from_expr(exprs, opt): + """Construct polynomials from expressions. """ + if len(exprs) == 2: + f, g = exprs + + if isinstance(f, Poly) and isinstance(g, Poly): + f = f.__class__._from_poly(f, opt) + g = g.__class__._from_poly(g, opt) + + f, g = f.unify(g) + + opt.gens = f.gens + opt.domain = f.domain + + if opt.polys is None: + opt.polys = True + + return [f, g], opt + + origs, exprs = list(exprs), [] + _exprs, _polys = [], [] + + failed = False + + for i, expr in enumerate(origs): + expr = sympify(expr) + + if isinstance(expr, Basic): + if expr.is_Poly: + _polys.append(i) + else: + _exprs.append(i) + + if opt.expand: + expr = expr.expand() + else: + failed = True + + exprs.append(expr) + + if failed: + raise PolificationFailed(opt, origs, exprs, True) + + if _polys: + # XXX: this is a temporary solution + for i in _polys: + exprs[i] = exprs[i].as_expr() + + reps, opt = _parallel_dict_from_expr(exprs, opt) + if not opt.gens: + raise PolificationFailed(opt, origs, exprs, True) + + from sympy.functions.elementary.piecewise import Piecewise + for k in opt.gens: + if isinstance(k, Piecewise): + raise PolynomialError("Piecewise generators do not make sense") + + coeffs_list, lengths = [], [] + + all_monoms = [] + all_coeffs = [] + + for rep in reps: + monoms, coeffs = list(zip(*list(rep.items()))) + + coeffs_list.extend(coeffs) + all_monoms.append(monoms) + + lengths.append(len(coeffs)) + + domain = opt.domain + + if domain is None: + opt.domain, coeffs_list = construct_domain(coeffs_list, opt=opt) + else: + coeffs_list = list(map(domain.from_sympy, coeffs_list)) + + for k in lengths: + all_coeffs.append(coeffs_list[:k]) + coeffs_list = coeffs_list[k:] + + polys = [] + + for monoms, coeffs in zip(all_monoms, all_coeffs): + rep = dict(list(zip(monoms, coeffs))) + poly = Poly._from_dict(rep, opt) + polys.append(poly) + + if opt.polys is None: + opt.polys = bool(_polys) + + return polys, opt + + +def _update_args(args, key, value): + """Add a new ``(key, value)`` pair to arguments ``dict``. """ + args = dict(args) + + if key not in args: + args[key] = value + + return args + + +@public +def degree(f, gen=0): + """ + Return the degree of ``f`` in the given variable. + + The degree of 0 is negative infinity. + + Examples + ======== + + >>> from sympy import degree + >>> from sympy.abc import x, y + + >>> degree(x**2 + y*x + 1, gen=x) + 2 + >>> degree(x**2 + y*x + 1, gen=y) + 1 + >>> degree(0, x) + -oo + + See also + ======== + + sympy.polys.polytools.Poly.total_degree + degree_list + """ + + f = sympify(f, strict=True) + gen_is_Num = sympify(gen, strict=True).is_Number + if f.is_Poly: + p = f + isNum = p.as_expr().is_Number + else: + isNum = f.is_Number + if not isNum: + if gen_is_Num: + p, _ = poly_from_expr(f) + else: + p, _ = poly_from_expr(f, gen) + + if isNum: + return S.Zero if f else S.NegativeInfinity + + if not gen_is_Num: + if f.is_Poly and gen not in p.gens: + # try recast without explicit gens + p, _ = poly_from_expr(f.as_expr()) + if gen not in p.gens: + return S.Zero + elif not f.is_Poly and len(f.free_symbols) > 1: + raise TypeError(filldedent(''' + A symbolic generator of interest is required for a multivariate + expression like func = %s, e.g. degree(func, gen = %s) instead of + degree(func, gen = %s). + ''' % (f, next(ordered(f.free_symbols)), gen))) + result = p.degree(gen) + return Integer(result) if isinstance(result, int) else S.NegativeInfinity + + +@public +def total_degree(f, *gens): + """ + Return the total_degree of ``f`` in the given variables. + + Examples + ======== + >>> from sympy import total_degree, Poly + >>> from sympy.abc import x, y + + >>> total_degree(1) + 0 + >>> total_degree(x + x*y) + 2 + >>> total_degree(x + x*y, x) + 1 + + If the expression is a Poly and no variables are given + then the generators of the Poly will be used: + + >>> p = Poly(x + x*y, y) + >>> total_degree(p) + 1 + + To deal with the underlying expression of the Poly, convert + it to an Expr: + + >>> total_degree(p.as_expr()) + 2 + + This is done automatically if any variables are given: + + >>> total_degree(p, x) + 1 + + See also + ======== + degree + """ + + p = sympify(f) + if p.is_Poly: + p = p.as_expr() + if p.is_Number: + rv = 0 + else: + if f.is_Poly: + gens = gens or f.gens + rv = Poly(p, gens).total_degree() + + return Integer(rv) + + +@public +def degree_list(f, *gens, **args): + """ + Return a list of degrees of ``f`` in all variables. + + Examples + ======== + + >>> from sympy import degree_list + >>> from sympy.abc import x, y + + >>> degree_list(x**2 + y*x + 1) + (2, 1) + + """ + options.allowed_flags(args, ['polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('degree_list', 1, exc) + + degrees = F.degree_list() + + return tuple(map(Integer, degrees)) + + +@public +def LC(f, *gens, **args): + """ + Return the leading coefficient of ``f``. + + Examples + ======== + + >>> from sympy import LC + >>> from sympy.abc import x, y + + >>> LC(4*x**2 + 2*x*y**2 + x*y + 3*y) + 4 + + """ + options.allowed_flags(args, ['polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('LC', 1, exc) + + return F.LC(order=opt.order) + + +@public +def LM(f, *gens, **args): + """ + Return the leading monomial of ``f``. + + Examples + ======== + + >>> from sympy import LM + >>> from sympy.abc import x, y + + >>> LM(4*x**2 + 2*x*y**2 + x*y + 3*y) + x**2 + + """ + options.allowed_flags(args, ['polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('LM', 1, exc) + + monom = F.LM(order=opt.order) + return monom.as_expr() + + +@public +def LT(f, *gens, **args): + """ + Return the leading term of ``f``. + + Examples + ======== + + >>> from sympy import LT + >>> from sympy.abc import x, y + + >>> LT(4*x**2 + 2*x*y**2 + x*y + 3*y) + 4*x**2 + + """ + options.allowed_flags(args, ['polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('LT', 1, exc) + + monom, coeff = F.LT(order=opt.order) + return coeff*monom.as_expr() + + +@public +def pdiv(f, g, *gens, **args): + """ + Compute polynomial pseudo-division of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import pdiv + >>> from sympy.abc import x + + >>> pdiv(x**2 + 1, 2*x - 4) + (2*x + 4, 20) + + """ + options.allowed_flags(args, ['polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('pdiv', 2, exc) + + q, r = F.pdiv(G) + + if not opt.polys: + return q.as_expr(), r.as_expr() + else: + return q, r + + +@public +def prem(f, g, *gens, **args): + """ + Compute polynomial pseudo-remainder of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import prem + >>> from sympy.abc import x + + >>> prem(x**2 + 1, 2*x - 4) + 20 + + """ + options.allowed_flags(args, ['polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('prem', 2, exc) + + r = F.prem(G) + + if not opt.polys: + return r.as_expr() + else: + return r + + +@public +def pquo(f, g, *gens, **args): + """ + Compute polynomial pseudo-quotient of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import pquo + >>> from sympy.abc import x + + >>> pquo(x**2 + 1, 2*x - 4) + 2*x + 4 + >>> pquo(x**2 - 1, 2*x - 1) + 2*x + 1 + + """ + options.allowed_flags(args, ['polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('pquo', 2, exc) + + try: + q = F.pquo(G) + except ExactQuotientFailed: + raise ExactQuotientFailed(f, g) + + if not opt.polys: + return q.as_expr() + else: + return q + + +@public +def pexquo(f, g, *gens, **args): + """ + Compute polynomial exact pseudo-quotient of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import pexquo + >>> from sympy.abc import x + + >>> pexquo(x**2 - 1, 2*x - 2) + 2*x + 2 + + >>> pexquo(x**2 + 1, 2*x - 4) + Traceback (most recent call last): + ... + ExactQuotientFailed: 2*x - 4 does not divide x**2 + 1 + + """ + options.allowed_flags(args, ['polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('pexquo', 2, exc) + + q = F.pexquo(G) + + if not opt.polys: + return q.as_expr() + else: + return q + + +@public +def div(f, g, *gens, **args): + """ + Compute polynomial division of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import div, ZZ, QQ + >>> from sympy.abc import x + + >>> div(x**2 + 1, 2*x - 4, domain=ZZ) + (0, x**2 + 1) + >>> div(x**2 + 1, 2*x - 4, domain=QQ) + (x/2 + 1, 5) + + """ + options.allowed_flags(args, ['auto', 'polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('div', 2, exc) + + q, r = F.div(G, auto=opt.auto) + + if not opt.polys: + return q.as_expr(), r.as_expr() + else: + return q, r + + +@public +def rem(f, g, *gens, **args): + """ + Compute polynomial remainder of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import rem, ZZ, QQ + >>> from sympy.abc import x + + >>> rem(x**2 + 1, 2*x - 4, domain=ZZ) + x**2 + 1 + >>> rem(x**2 + 1, 2*x - 4, domain=QQ) + 5 + + """ + options.allowed_flags(args, ['auto', 'polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('rem', 2, exc) + + r = F.rem(G, auto=opt.auto) + + if not opt.polys: + return r.as_expr() + else: + return r + + +@public +def quo(f, g, *gens, **args): + """ + Compute polynomial quotient of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import quo + >>> from sympy.abc import x + + >>> quo(x**2 + 1, 2*x - 4) + x/2 + 1 + >>> quo(x**2 - 1, x - 1) + x + 1 + + """ + options.allowed_flags(args, ['auto', 'polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('quo', 2, exc) + + q = F.quo(G, auto=opt.auto) + + if not opt.polys: + return q.as_expr() + else: + return q + + +@public +def exquo(f, g, *gens, **args): + """ + Compute polynomial exact quotient of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import exquo + >>> from sympy.abc import x + + >>> exquo(x**2 - 1, x - 1) + x + 1 + + >>> exquo(x**2 + 1, 2*x - 4) + Traceback (most recent call last): + ... + ExactQuotientFailed: 2*x - 4 does not divide x**2 + 1 + + """ + options.allowed_flags(args, ['auto', 'polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('exquo', 2, exc) + + q = F.exquo(G, auto=opt.auto) + + if not opt.polys: + return q.as_expr() + else: + return q + + +@public +def half_gcdex(f, g, *gens, **args): + """ + Half extended Euclidean algorithm of ``f`` and ``g``. + + Returns ``(s, h)`` such that ``h = gcd(f, g)`` and ``s*f = h (mod g)``. + + Examples + ======== + + >>> from sympy import half_gcdex + >>> from sympy.abc import x + + >>> half_gcdex(x**4 - 2*x**3 - 6*x**2 + 12*x + 15, x**3 + x**2 - 4*x - 4) + (3/5 - x/5, x + 1) + + """ + options.allowed_flags(args, ['auto', 'polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + domain, (a, b) = construct_domain(exc.exprs) + + try: + s, h = domain.half_gcdex(a, b) + except NotImplementedError: + raise ComputationFailed('half_gcdex', 2, exc) + else: + return domain.to_sympy(s), domain.to_sympy(h) + + s, h = F.half_gcdex(G, auto=opt.auto) + + if not opt.polys: + return s.as_expr(), h.as_expr() + else: + return s, h + + +@public +def gcdex(f, g, *gens, **args): + """ + Extended Euclidean algorithm of ``f`` and ``g``. + + Returns ``(s, t, h)`` such that ``h = gcd(f, g)`` and ``s*f + t*g = h``. + + Examples + ======== + + >>> from sympy import gcdex + >>> from sympy.abc import x + + >>> gcdex(x**4 - 2*x**3 - 6*x**2 + 12*x + 15, x**3 + x**2 - 4*x - 4) + (3/5 - x/5, x**2/5 - 6*x/5 + 2, x + 1) + + """ + options.allowed_flags(args, ['auto', 'polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + domain, (a, b) = construct_domain(exc.exprs) + + try: + s, t, h = domain.gcdex(a, b) + except NotImplementedError: + raise ComputationFailed('gcdex', 2, exc) + else: + return domain.to_sympy(s), domain.to_sympy(t), domain.to_sympy(h) + + s, t, h = F.gcdex(G, auto=opt.auto) + + if not opt.polys: + return s.as_expr(), t.as_expr(), h.as_expr() + else: + return s, t, h + + +@public +def invert(f, g, *gens, **args): + """ + Invert ``f`` modulo ``g`` when possible. + + Examples + ======== + + >>> from sympy import invert, S, mod_inverse + >>> from sympy.abc import x + + >>> invert(x**2 - 1, 2*x - 1) + -4/3 + + >>> invert(x**2 - 1, x - 1) + Traceback (most recent call last): + ... + NotInvertible: zero divisor + + For more efficient inversion of Rationals, + use the :obj:`sympy.core.intfunc.mod_inverse` function: + + >>> mod_inverse(3, 5) + 2 + >>> (S(2)/5).invert(S(7)/3) + 5/2 + + See Also + ======== + sympy.core.intfunc.mod_inverse + + """ + options.allowed_flags(args, ['auto', 'polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + domain, (a, b) = construct_domain(exc.exprs) + + try: + return domain.to_sympy(domain.invert(a, b)) + except NotImplementedError: + raise ComputationFailed('invert', 2, exc) + + h = F.invert(G, auto=opt.auto) + + if not opt.polys: + return h.as_expr() + else: + return h + + +@public +def subresultants(f, g, *gens, **args): + """ + Compute subresultant PRS of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import subresultants + >>> from sympy.abc import x + + >>> subresultants(x**2 + 1, x**2 - 1) + [x**2 + 1, x**2 - 1, -2] + + """ + options.allowed_flags(args, ['polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('subresultants', 2, exc) + + result = F.subresultants(G) + + if not opt.polys: + return [r.as_expr() for r in result] + else: + return result + + +@public +def resultant(f, g, *gens, includePRS=False, **args): + """ + Compute resultant of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import resultant + >>> from sympy.abc import x + + >>> resultant(x**2 + 1, x**2 - 1) + 4 + + """ + options.allowed_flags(args, ['polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('resultant', 2, exc) + + if includePRS: + result, R = F.resultant(G, includePRS=includePRS) + else: + result = F.resultant(G) + + if not opt.polys: + if includePRS: + return result.as_expr(), [r.as_expr() for r in R] + return result.as_expr() + else: + if includePRS: + return result, R + return result + + +@public +def discriminant(f, *gens, **args): + """ + Compute discriminant of ``f``. + + Examples + ======== + + >>> from sympy import discriminant + >>> from sympy.abc import x + + >>> discriminant(x**2 + 2*x + 3) + -8 + + """ + options.allowed_flags(args, ['polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('discriminant', 1, exc) + + result = F.discriminant() + + if not opt.polys: + return result.as_expr() + else: + return result + + +@public +def cofactors(f, g, *gens, **args): + """ + Compute GCD and cofactors of ``f`` and ``g``. + + Returns polynomials ``(h, cff, cfg)`` such that ``h = gcd(f, g)``, and + ``cff = quo(f, h)`` and ``cfg = quo(g, h)`` are, so called, cofactors + of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import cofactors + >>> from sympy.abc import x + + >>> cofactors(x**2 - 1, x**2 - 3*x + 2) + (x - 1, x + 1, x - 2) + + """ + options.allowed_flags(args, ['polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + domain, (a, b) = construct_domain(exc.exprs) + + try: + h, cff, cfg = domain.cofactors(a, b) + except NotImplementedError: + raise ComputationFailed('cofactors', 2, exc) + else: + return domain.to_sympy(h), domain.to_sympy(cff), domain.to_sympy(cfg) + + h, cff, cfg = F.cofactors(G) + + if not opt.polys: + return h.as_expr(), cff.as_expr(), cfg.as_expr() + else: + return h, cff, cfg + + +@public +def gcd_list(seq, *gens, **args): + """ + Compute GCD of a list of polynomials. + + Examples + ======== + + >>> from sympy import gcd_list + >>> from sympy.abc import x + + >>> gcd_list([x**3 - 1, x**2 - 1, x**2 - 3*x + 2]) + x - 1 + + """ + seq = sympify(seq) + + def try_non_polynomial_gcd(seq): + if not gens and not args: + domain, numbers = construct_domain(seq) + + if not numbers: + return domain.zero + elif domain.is_Numerical: + result, numbers = numbers[0], numbers[1:] + + for number in numbers: + result = domain.gcd(result, number) + + if domain.is_one(result): + break + + return domain.to_sympy(result) + + return None + + result = try_non_polynomial_gcd(seq) + + if result is not None: + return result + + options.allowed_flags(args, ['polys']) + + try: + polys, opt = parallel_poly_from_expr(seq, *gens, **args) + + # gcd for domain Q[irrational] (purely algebraic irrational) + if len(seq) > 1 and all(elt.is_algebraic and elt.is_irrational for elt in seq): + a = seq[-1] + lst = [ (a/elt).ratsimp() for elt in seq[:-1] ] + if all(frc.is_rational for frc in lst): + lc = 1 + for frc in lst: + lc = lcm(lc, frc.as_numer_denom()[0]) + # abs ensures that the gcd is always non-negative + return abs(a/lc) + + except PolificationFailed as exc: + result = try_non_polynomial_gcd(exc.exprs) + + if result is not None: + return result + else: + raise ComputationFailed('gcd_list', len(seq), exc) + + if not polys: + if not opt.polys: + return S.Zero + else: + return Poly(0, opt=opt) + + result, polys = polys[0], polys[1:] + + for poly in polys: + result = result.gcd(poly) + + if result.is_one: + break + + if not opt.polys: + return result.as_expr() + else: + return result + + +@public +def gcd(f, g=None, *gens, **args): + """ + Compute GCD of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import gcd + >>> from sympy.abc import x + + >>> gcd(x**2 - 1, x**2 - 3*x + 2) + x - 1 + + """ + if hasattr(f, '__iter__'): + if g is not None: + gens = (g,) + gens + + return gcd_list(f, *gens, **args) + elif g is None: + raise TypeError("gcd() takes 2 arguments or a sequence of arguments") + + options.allowed_flags(args, ['polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + + # gcd for domain Q[irrational] (purely algebraic irrational) + a, b = map(sympify, (f, g)) + if a.is_algebraic and a.is_irrational and b.is_algebraic and b.is_irrational: + frc = (a/b).ratsimp() + if frc.is_rational: + # abs ensures that the returned gcd is always non-negative + return abs(a/frc.as_numer_denom()[0]) + + except PolificationFailed as exc: + domain, (a, b) = construct_domain(exc.exprs) + + try: + return domain.to_sympy(domain.gcd(a, b)) + except NotImplementedError: + raise ComputationFailed('gcd', 2, exc) + + result = F.gcd(G) + + if not opt.polys: + return result.as_expr() + else: + return result + + +@public +def lcm_list(seq, *gens, **args): + """ + Compute LCM of a list of polynomials. + + Examples + ======== + + >>> from sympy import lcm_list + >>> from sympy.abc import x + + >>> lcm_list([x**3 - 1, x**2 - 1, x**2 - 3*x + 2]) + x**5 - x**4 - 2*x**3 - x**2 + x + 2 + + """ + seq = sympify(seq) + + def try_non_polynomial_lcm(seq) -> Optional[Expr]: + if not gens and not args: + domain, numbers = construct_domain(seq) + + if not numbers: + return domain.to_sympy(domain.one) + elif domain.is_Numerical: + result, numbers = numbers[0], numbers[1:] + + for number in numbers: + result = domain.lcm(result, number) + + return domain.to_sympy(result) + + return None + + result = try_non_polynomial_lcm(seq) + + if result is not None: + return result + + options.allowed_flags(args, ['polys']) + + try: + polys, opt = parallel_poly_from_expr(seq, *gens, **args) + + # lcm for domain Q[irrational] (purely algebraic irrational) + if len(seq) > 1 and all(elt.is_algebraic and elt.is_irrational for elt in seq): + a = seq[-1] + lst = [ (a/elt).ratsimp() for elt in seq[:-1] ] + if all(frc.is_rational for frc in lst): + lc = 1 + for frc in lst: + lc = lcm(lc, frc.as_numer_denom()[1]) + return a*lc + + except PolificationFailed as exc: + result = try_non_polynomial_lcm(exc.exprs) + + if result is not None: + return result + else: + raise ComputationFailed('lcm_list', len(seq), exc) + + if not polys: + if not opt.polys: + return S.One + else: + return Poly(1, opt=opt) + + result, polys = polys[0], polys[1:] + + for poly in polys: + result = result.lcm(poly) + + if not opt.polys: + return result.as_expr() + else: + return result + + +@public +def lcm(f, g=None, *gens, **args): + """ + Compute LCM of ``f`` and ``g``. + + Examples + ======== + + >>> from sympy import lcm + >>> from sympy.abc import x + + >>> lcm(x**2 - 1, x**2 - 3*x + 2) + x**3 - 2*x**2 - x + 2 + + """ + if hasattr(f, '__iter__'): + if g is not None: + gens = (g,) + gens + + return lcm_list(f, *gens, **args) + elif g is None: + raise TypeError("lcm() takes 2 arguments or a sequence of arguments") + + options.allowed_flags(args, ['polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + + # lcm for domain Q[irrational] (purely algebraic irrational) + a, b = map(sympify, (f, g)) + if a.is_algebraic and a.is_irrational and b.is_algebraic and b.is_irrational: + frc = (a/b).ratsimp() + if frc.is_rational: + return a*frc.as_numer_denom()[1] + + except PolificationFailed as exc: + domain, (a, b) = construct_domain(exc.exprs) + + try: + return domain.to_sympy(domain.lcm(a, b)) + except NotImplementedError: + raise ComputationFailed('lcm', 2, exc) + + result = F.lcm(G) + + if not opt.polys: + return result.as_expr() + else: + return result + + +@public +def terms_gcd(f, *gens, **args): + """ + Remove GCD of terms from ``f``. + + If the ``deep`` flag is True, then the arguments of ``f`` will have + terms_gcd applied to them. + + If a fraction is factored out of ``f`` and ``f`` is an Add, then + an unevaluated Mul will be returned so that automatic simplification + does not redistribute it. The hint ``clear``, when set to False, can be + used to prevent such factoring when all coefficients are not fractions. + + Examples + ======== + + >>> from sympy import terms_gcd, cos + >>> from sympy.abc import x, y + >>> terms_gcd(x**6*y**2 + x**3*y, x, y) + x**3*y*(x**3*y + 1) + + The default action of polys routines is to expand the expression + given to them. terms_gcd follows this behavior: + + >>> terms_gcd((3+3*x)*(x+x*y)) + 3*x*(x*y + x + y + 1) + + If this is not desired then the hint ``expand`` can be set to False. + In this case the expression will be treated as though it were comprised + of one or more terms: + + >>> terms_gcd((3+3*x)*(x+x*y), expand=False) + (3*x + 3)*(x*y + x) + + In order to traverse factors of a Mul or the arguments of other + functions, the ``deep`` hint can be used: + + >>> terms_gcd((3 + 3*x)*(x + x*y), expand=False, deep=True) + 3*x*(x + 1)*(y + 1) + >>> terms_gcd(cos(x + x*y), deep=True) + cos(x*(y + 1)) + + Rationals are factored out by default: + + >>> terms_gcd(x + y/2) + (2*x + y)/2 + + Only the y-term had a coefficient that was a fraction; if one + does not want to factor out the 1/2 in cases like this, the + flag ``clear`` can be set to False: + + >>> terms_gcd(x + y/2, clear=False) + x + y/2 + >>> terms_gcd(x*y/2 + y**2, clear=False) + y*(x/2 + y) + + The ``clear`` flag is ignored if all coefficients are fractions: + + >>> terms_gcd(x/3 + y/2, clear=False) + (2*x + 3*y)/6 + + See Also + ======== + sympy.core.exprtools.gcd_terms, sympy.core.exprtools.factor_terms + + """ + + orig = sympify(f) + + if isinstance(f, Equality): + return Equality(*(terms_gcd(s, *gens, **args) for s in [f.lhs, f.rhs])) + elif isinstance(f, Relational): + raise TypeError("Inequalities cannot be used with terms_gcd. Found: %s" %(f,)) + + if not isinstance(f, Expr) or f.is_Atom: + return orig + + if args.get('deep', False): + new = f.func(*[terms_gcd(a, *gens, **args) for a in f.args]) + args.pop('deep') + args['expand'] = False + return terms_gcd(new, *gens, **args) + + clear = args.pop('clear', True) + options.allowed_flags(args, ['polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + return exc.expr + + J, f = F.terms_gcd() + + if opt.domain.is_Ring: + if opt.domain.is_Field: + denom, f = f.clear_denoms(convert=True) + + coeff, f = f.primitive() + + if opt.domain.is_Field: + coeff /= denom + else: + coeff = S.One + + term = Mul(*[x**j for x, j in zip(f.gens, J)]) + if equal_valued(coeff, 1): + coeff = S.One + if term == 1: + return orig + + if clear: + return _keep_coeff(coeff, term*f.as_expr()) + # base the clearing on the form of the original expression, not + # the (perhaps) Mul that we have now + coeff, f = _keep_coeff(coeff, f.as_expr(), clear=False).as_coeff_Mul() + return _keep_coeff(coeff, term*f, clear=False) + + +@public +def trunc(f, p, *gens, **args): + """ + Reduce ``f`` modulo a constant ``p``. + + Examples + ======== + + >>> from sympy import trunc + >>> from sympy.abc import x + + >>> trunc(2*x**3 + 3*x**2 + 5*x + 7, 3) + -x**3 - x + 1 + + """ + options.allowed_flags(args, ['auto', 'polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('trunc', 1, exc) + + result = F.trunc(sympify(p)) + + if not opt.polys: + return result.as_expr() + else: + return result + + +@public +def monic(f, *gens, **args): + """ + Divide all coefficients of ``f`` by ``LC(f)``. + + Examples + ======== + + >>> from sympy import monic + >>> from sympy.abc import x + + >>> monic(3*x**2 + 4*x + 2) + x**2 + 4*x/3 + 2/3 + + """ + options.allowed_flags(args, ['auto', 'polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('monic', 1, exc) + + result = F.monic(auto=opt.auto) + + if not opt.polys: + return result.as_expr() + else: + return result + + +@public +def content(f, *gens, **args): + """ + Compute GCD of coefficients of ``f``. + + Examples + ======== + + >>> from sympy import content + >>> from sympy.abc import x + + >>> content(6*x**2 + 8*x + 12) + 2 + + """ + options.allowed_flags(args, ['polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('content', 1, exc) + + return F.content() + + +@public +def primitive(f, *gens, **args): + """ + Compute content and the primitive form of ``f``. + + Examples + ======== + + >>> from sympy.polys.polytools import primitive + >>> from sympy.abc import x + + >>> primitive(6*x**2 + 8*x + 12) + (2, 3*x**2 + 4*x + 6) + + >>> eq = (2 + 2*x)*x + 2 + + Expansion is performed by default: + + >>> primitive(eq) + (2, x**2 + x + 1) + + Set ``expand`` to False to shut this off. Note that the + extraction will not be recursive; use the as_content_primitive method + for recursive, non-destructive Rational extraction. + + >>> primitive(eq, expand=False) + (1, x*(2*x + 2) + 2) + + >>> eq.as_content_primitive() + (2, x*(x + 1) + 1) + + """ + options.allowed_flags(args, ['polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('primitive', 1, exc) + + cont, result = F.primitive() + if not opt.polys: + return cont, result.as_expr() + else: + return cont, result + + +@public +def compose(f, g, *gens, **args): + """ + Compute functional composition ``f(g)``. + + Examples + ======== + + >>> from sympy import compose + >>> from sympy.abc import x + + >>> compose(x**2 + x, x - 1) + x**2 - x + + """ + options.allowed_flags(args, ['polys']) + + try: + (F, G), opt = parallel_poly_from_expr((f, g), *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('compose', 2, exc) + + result = F.compose(G) + + if not opt.polys: + return result.as_expr() + else: + return result + + +@public +def decompose(f, *gens, **args): + """ + Compute functional decomposition of ``f``. + + Examples + ======== + + >>> from sympy import decompose + >>> from sympy.abc import x + + >>> decompose(x**4 + 2*x**3 - x - 1) + [x**2 - x - 1, x**2 + x] + + """ + options.allowed_flags(args, ['polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('decompose', 1, exc) + + result = F.decompose() + + if not opt.polys: + return [r.as_expr() for r in result] + else: + return result + + +@public +def sturm(f, *gens, **args): + """ + Compute Sturm sequence of ``f``. + + Examples + ======== + + >>> from sympy import sturm + >>> from sympy.abc import x + + >>> sturm(x**3 - 2*x**2 + x - 3) + [x**3 - 2*x**2 + x - 3, 3*x**2 - 4*x + 1, 2*x/9 + 25/9, -2079/4] + + """ + options.allowed_flags(args, ['auto', 'polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('sturm', 1, exc) + + result = F.sturm(auto=opt.auto) + + if not opt.polys: + return [r.as_expr() for r in result] + else: + return result + + +@public +def gff_list(f, *gens, **args): + """ + Compute a list of greatest factorial factors of ``f``. + + Note that the input to ff() and rf() should be Poly instances to use the + definitions here. + + Examples + ======== + + >>> from sympy import gff_list, ff, Poly + >>> from sympy.abc import x + + >>> f = Poly(x**5 + 2*x**4 - x**3 - 2*x**2, x) + + >>> gff_list(f) + [(Poly(x, x, domain='ZZ'), 1), (Poly(x + 2, x, domain='ZZ'), 4)] + + >>> (ff(Poly(x), 1)*ff(Poly(x + 2), 4)) == f + True + + >>> f = Poly(x**12 + 6*x**11 - 11*x**10 - 56*x**9 + 220*x**8 + 208*x**7 - \ + 1401*x**6 + 1090*x**5 + 2715*x**4 - 6720*x**3 - 1092*x**2 + 5040*x, x) + + >>> gff_list(f) + [(Poly(x**3 + 7, x, domain='ZZ'), 2), (Poly(x**2 + 5*x, x, domain='ZZ'), 3)] + + >>> ff(Poly(x**3 + 7, x), 2)*ff(Poly(x**2 + 5*x, x), 3) == f + True + + """ + options.allowed_flags(args, ['polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('gff_list', 1, exc) + + factors = F.gff_list() + + if not opt.polys: + return [(g.as_expr(), k) for g, k in factors] + else: + return factors + + +@public +def gff(f, *gens, **args): + """Compute greatest factorial factorization of ``f``. """ + raise NotImplementedError('symbolic falling factorial') + + +@public +def sqf_norm(f, *gens, **args): + """ + Compute square-free norm of ``f``. + + Returns ``s``, ``f``, ``r``, such that ``g(x) = f(x-sa)`` and + ``r(x) = Norm(g(x))`` is a square-free polynomial over ``K``, + where ``a`` is the algebraic extension of the ground domain. + + Examples + ======== + + >>> from sympy import sqf_norm, sqrt + >>> from sympy.abc import x + + >>> sqf_norm(x**2 + 1, extension=[sqrt(3)]) + ([1], x**2 - 2*sqrt(3)*x + 4, x**4 - 4*x**2 + 16) + + """ + options.allowed_flags(args, ['polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('sqf_norm', 1, exc) + + s, g, r = F.sqf_norm() + + s_expr = [Integer(si) for si in s] + + if not opt.polys: + return s_expr, g.as_expr(), r.as_expr() + else: + return s_expr, g, r + + +@public +def sqf_part(f, *gens, **args): + """ + Compute square-free part of ``f``. + + Examples + ======== + + >>> from sympy import sqf_part + >>> from sympy.abc import x + + >>> sqf_part(x**3 - 3*x - 2) + x**2 - x - 2 + + """ + options.allowed_flags(args, ['polys']) + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('sqf_part', 1, exc) + + result = F.sqf_part() + + if not opt.polys: + return result.as_expr() + else: + return result + + +def _poly_sort_key(poly): + """Sort a list of polys.""" + rep = poly.rep.to_list() + return (len(rep), len(poly.gens), str(poly.domain), rep) + + +def _sorted_factors(factors, method): + """Sort a list of ``(expr, exp)`` pairs. """ + if method == 'sqf': + def key(obj): + poly, exp = obj + rep = poly.rep.to_list() + return (exp, len(rep), len(poly.gens), str(poly.domain), rep) + else: + def key(obj): + poly, exp = obj + rep = poly.rep.to_list() + return (len(rep), len(poly.gens), exp, str(poly.domain), rep) + + return sorted(factors, key=key) + + +def _factors_product(factors): + """Multiply a list of ``(expr, exp)`` pairs. """ + return Mul(*[f.as_expr()**k for f, k in factors]) + + +def _symbolic_factor_list(expr, opt, method): + """Helper function for :func:`_symbolic_factor`. """ + coeff, factors = S.One, [] + + args = [i._eval_factor() if hasattr(i, '_eval_factor') else i + for i in Mul.make_args(expr)] + for arg in args: + if arg.is_Number or (isinstance(arg, Expr) and pure_complex(arg)): + coeff *= arg + continue + elif arg.is_Pow and arg.base != S.Exp1: + base, exp = arg.args + if base.is_Number and exp.is_Number: + coeff *= arg + continue + if base.is_Number: + factors.append((base, exp)) + continue + else: + base, exp = arg, S.One + + try: + poly, _ = _poly_from_expr(base, opt) + except PolificationFailed as exc: + factors.append((exc.expr, exp)) + else: + func = getattr(poly, method + '_list') + + _coeff, _factors = func() + if _coeff is not S.One: + if exp.is_Integer: + coeff *= _coeff**exp + elif _coeff.is_positive: + factors.append((_coeff, exp)) + else: + _factors.append((_coeff, S.One)) + + if exp is S.One: + factors.extend(_factors) + elif exp.is_integer: + factors.extend([(f, k*exp) for f, k in _factors]) + else: + other = [] + + for f, k in _factors: + if f.as_expr().is_positive: + factors.append((f, k*exp)) + else: + other.append((f, k)) + + factors.append((_factors_product(other), exp)) + if method == 'sqf': + factors = [(reduce(mul, (f for f, _ in factors if _ == k)), k) + for k in {i for _, i in factors}] + #collect duplicates + rv = defaultdict(int) + for k, v in factors: + rv[k] += v + return coeff, list(rv.items()) + + +def _symbolic_factor(expr, opt, method): + """Helper function for :func:`_factor`. """ + if isinstance(expr, Expr): + if hasattr(expr,'_eval_factor'): + return expr._eval_factor() + coeff, factors = _symbolic_factor_list(together(expr, fraction=opt['fraction']), opt, method) + return _keep_coeff(coeff, _factors_product(factors)) + elif hasattr(expr, 'args'): + return expr.func(*[_symbolic_factor(arg, opt, method) for arg in expr.args]) + elif hasattr(expr, '__iter__'): + return expr.__class__([_symbolic_factor(arg, opt, method) for arg in expr]) + else: + return expr + + +def _generic_factor_list(expr, gens, args, method): + """Helper function for :func:`sqf_list` and :func:`factor_list`. """ + options.allowed_flags(args, ['frac', 'polys']) + opt = options.build_options(gens, args) + + expr = sympify(expr) + + if isinstance(expr, (Expr, Poly)): + if isinstance(expr, Poly): + numer, denom = expr, 1 + else: + numer, denom = together(expr).as_numer_denom() + + cp, fp = _symbolic_factor_list(numer, opt, method) + cq, fq = _symbolic_factor_list(denom, opt, method) + + if fq and not opt.frac: + raise PolynomialError("a polynomial expected, got %s" % expr) + + _opt = opt.clone({"expand": True}) + + for factors in (fp, fq): + for i, (f, k) in enumerate(factors): + if not f.is_Poly: + f, _ = _poly_from_expr(f, _opt) + factors[i] = (f, k) + + fp = _sorted_factors(fp, method) + fq = _sorted_factors(fq, method) + + if not opt.polys: + fp = [(f.as_expr(), k) for f, k in fp] + fq = [(f.as_expr(), k) for f, k in fq] + + coeff = cp/cq + + if not opt.frac: + return coeff, fp + else: + return coeff, fp, fq + else: + raise PolynomialError("a polynomial expected, got %s" % expr) + + +def _generic_factor(expr, gens, args, method): + """Helper function for :func:`sqf` and :func:`factor`. """ + fraction = args.pop('fraction', True) + options.allowed_flags(args, []) + opt = options.build_options(gens, args) + opt['fraction'] = fraction + return _symbolic_factor(sympify(expr), opt, method) + + +def to_rational_coeffs(f): + """ + try to transform a polynomial to have rational coefficients + + try to find a transformation ``x = alpha*y`` + + ``f(x) = lc*alpha**n * g(y)`` where ``g`` is a polynomial with + rational coefficients, ``lc`` the leading coefficient. + + If this fails, try ``x = y + beta`` + ``f(x) = g(y)`` + + Returns ``None`` if ``g`` not found; + ``(lc, alpha, None, g)`` in case of rescaling + ``(None, None, beta, g)`` in case of translation + + Notes + ===== + + Currently it transforms only polynomials without roots larger than 2. + + Examples + ======== + + >>> from sympy import sqrt, Poly, simplify + >>> from sympy.polys.polytools import to_rational_coeffs + >>> from sympy.abc import x + >>> p = Poly(((x**2-1)*(x-2)).subs({x:x*(1 + sqrt(2))}), x, domain='EX') + >>> lc, r, _, g = to_rational_coeffs(p) + >>> lc, r + (7 + 5*sqrt(2), 2 - 2*sqrt(2)) + >>> g + Poly(x**3 + x**2 - 1/4*x - 1/4, x, domain='QQ') + >>> r1 = simplify(1/r) + >>> Poly(lc*r**3*(g.as_expr()).subs({x:x*r1}), x, domain='EX') == p + True + + """ + from sympy.simplify.simplify import simplify + + def _try_rescale(f, f1=None): + """ + try rescaling ``x -> alpha*x`` to convert f to a polynomial + with rational coefficients. + Returns ``alpha, f``; if the rescaling is successful, + ``alpha`` is the rescaling factor, and ``f`` is the rescaled + polynomial; else ``alpha`` is ``None``. + """ + if not len(f.gens) == 1 or not (f.gens[0]).is_Atom: + return None, f + n = f.degree() + lc = f.LC() + f1 = f1 or f1.monic() + coeffs = f1.all_coeffs()[1:] + coeffs = [simplify(coeffx) for coeffx in coeffs] + if len(coeffs) > 1 and coeffs[-2]: + rescale1_x = simplify(coeffs[-2]/coeffs[-1]) + coeffs1 = [] + for i in range(len(coeffs)): + coeffx = simplify(coeffs[i]*rescale1_x**(i + 1)) + if not coeffx.is_rational: + break + coeffs1.append(coeffx) + else: + rescale_x = simplify(1/rescale1_x) + x = f.gens[0] + v = [x**n] + for i in range(1, n + 1): + v.append(coeffs1[i - 1]*x**(n - i)) + f = Add(*v) + f = Poly(f) + return lc, rescale_x, f + return None + + def _try_translate(f, f1=None): + """ + try translating ``x -> x + alpha`` to convert f to a polynomial + with rational coefficients. + Returns ``alpha, f``; if the translating is successful, + ``alpha`` is the translating factor, and ``f`` is the shifted + polynomial; else ``alpha`` is ``None``. + """ + if not len(f.gens) == 1 or not (f.gens[0]).is_Atom: + return None, f + n = f.degree() + f1 = f1 or f1.monic() + coeffs = f1.all_coeffs()[1:] + c = simplify(coeffs[0]) + if c.is_Add and not c.is_rational: + rat, nonrat = sift(c.args, + lambda z: z.is_rational is True, binary=True) + alpha = -c.func(*nonrat)/n + f2 = f1.shift(alpha) + return alpha, f2 + return None + + def _has_square_roots(p): + """ + Return True if ``f`` is a sum with square roots but no other root + """ + coeffs = p.coeffs() + has_sq = False + for y in coeffs: + for x in Add.make_args(y): + f = Factors(x).factors + r = [wx.q for b, wx in f.items() if + b.is_number and wx.is_Rational and wx.q >= 2] + if not r: + continue + if min(r) == 2: + has_sq = True + if max(r) > 2: + return False + return has_sq + + if f.get_domain().is_EX and _has_square_roots(f): + f1 = f.monic() + r = _try_rescale(f, f1) + if r: + return r[0], r[1], None, r[2] + else: + r = _try_translate(f, f1) + if r: + return None, None, r[0], r[1] + return None + + +def _torational_factor_list(p, x): + """ + helper function to factor polynomial using to_rational_coeffs + + Examples + ======== + + >>> from sympy.polys.polytools import _torational_factor_list + >>> from sympy.abc import x + >>> from sympy import sqrt, expand, Mul + >>> p = expand(((x**2-1)*(x-2)).subs({x:x*(1 + sqrt(2))})) + >>> factors = _torational_factor_list(p, x); factors + (-2, [(-x*(1 + sqrt(2))/2 + 1, 1), (-x*(1 + sqrt(2)) - 1, 1), (-x*(1 + sqrt(2)) + 1, 1)]) + >>> expand(factors[0]*Mul(*[z[0] for z in factors[1]])) == p + True + >>> p = expand(((x**2-1)*(x-2)).subs({x:x + sqrt(2)})) + >>> factors = _torational_factor_list(p, x); factors + (1, [(x - 2 + sqrt(2), 1), (x - 1 + sqrt(2), 1), (x + 1 + sqrt(2), 1)]) + >>> expand(factors[0]*Mul(*[z[0] for z in factors[1]])) == p + True + + """ + from sympy.simplify.simplify import simplify + p1 = Poly(p, x, domain='EX') + n = p1.degree() + res = to_rational_coeffs(p1) + if not res: + return None + lc, r, t, g = res + factors = factor_list(g.as_expr()) + if lc: + c = simplify(factors[0]*lc*r**n) + r1 = simplify(1/r) + a = [] + for z in factors[1:][0]: + a.append((simplify(z[0].subs({x: x*r1})), z[1])) + else: + c = factors[0] + a = [] + for z in factors[1:][0]: + a.append((z[0].subs({x: x - t}), z[1])) + return (c, a) + + +@public +def sqf_list(f, *gens, **args): + """ + Compute a list of square-free factors of ``f``. + + Examples + ======== + + >>> from sympy import sqf_list + >>> from sympy.abc import x + + >>> sqf_list(2*x**5 + 16*x**4 + 50*x**3 + 76*x**2 + 56*x + 16) + (2, [(x + 1, 2), (x + 2, 3)]) + + """ + return _generic_factor_list(f, gens, args, method='sqf') + + +@public +def sqf(f, *gens, **args): + """ + Compute square-free factorization of ``f``. + + Examples + ======== + + >>> from sympy import sqf + >>> from sympy.abc import x + + >>> sqf(2*x**5 + 16*x**4 + 50*x**3 + 76*x**2 + 56*x + 16) + 2*(x + 1)**2*(x + 2)**3 + + """ + return _generic_factor(f, gens, args, method='sqf') + + +@public +def factor_list(f, *gens, **args): + """ + Compute a list of irreducible factors of ``f``. + + Examples + ======== + + >>> from sympy import factor_list + >>> from sympy.abc import x, y + + >>> factor_list(2*x**5 + 2*x**4*y + 4*x**3 + 4*x**2*y + 2*x + 2*y) + (2, [(x + y, 1), (x**2 + 1, 2)]) + + """ + return _generic_factor_list(f, gens, args, method='factor') + + +@public +def factor(f, *gens, deep=False, **args): + """ + Compute the factorization of expression, ``f``, into irreducibles. (To + factor an integer into primes, use ``factorint``.) + + There two modes implemented: symbolic and formal. If ``f`` is not an + instance of :class:`Poly` and generators are not specified, then the + former mode is used. Otherwise, the formal mode is used. + + In symbolic mode, :func:`factor` will traverse the expression tree and + factor its components without any prior expansion, unless an instance + of :class:`~.Add` is encountered (in this case formal factorization is + used). This way :func:`factor` can handle large or symbolic exponents. + + By default, the factorization is computed over the rationals. To factor + over other domain, e.g. an algebraic or finite field, use appropriate + options: ``extension``, ``modulus`` or ``domain``. + + Examples + ======== + + >>> from sympy import factor, sqrt, exp + >>> from sympy.abc import x, y + + >>> factor(2*x**5 + 2*x**4*y + 4*x**3 + 4*x**2*y + 2*x + 2*y) + 2*(x + y)*(x**2 + 1)**2 + + >>> factor(x**2 + 1) + x**2 + 1 + >>> factor(x**2 + 1, modulus=2) + (x + 1)**2 + >>> factor(x**2 + 1, gaussian=True) + (x - I)*(x + I) + + >>> factor(x**2 - 2, extension=sqrt(2)) + (x - sqrt(2))*(x + sqrt(2)) + + >>> factor((x**2 - 1)/(x**2 + 4*x + 4)) + (x - 1)*(x + 1)/(x + 2)**2 + >>> factor((x**2 + 4*x + 4)**10000000*(x**2 + 1)) + (x + 2)**20000000*(x**2 + 1) + + By default, factor deals with an expression as a whole: + + >>> eq = 2**(x**2 + 2*x + 1) + >>> factor(eq) + 2**(x**2 + 2*x + 1) + + If the ``deep`` flag is True then subexpressions will + be factored: + + >>> factor(eq, deep=True) + 2**((x + 1)**2) + + If the ``fraction`` flag is False then rational expressions + will not be combined. By default it is True. + + >>> factor(5*x + 3*exp(2 - 7*x), deep=True) + (5*x*exp(7*x) + 3*exp(2))*exp(-7*x) + >>> factor(5*x + 3*exp(2 - 7*x), deep=True, fraction=False) + 5*x + 3*exp(2)*exp(-7*x) + + See Also + ======== + sympy.ntheory.factor_.factorint + + """ + f = sympify(f) + if deep: + def _try_factor(expr): + """ + Factor, but avoid changing the expression when unable to. + """ + fac = factor(expr, *gens, **args) + if fac.is_Mul or fac.is_Pow: + return fac + return expr + + f = bottom_up(f, _try_factor) + # clean up any subexpressions that may have been expanded + # while factoring out a larger expression + partials = {} + muladd = f.atoms(Mul, Add) + for p in muladd: + fac = factor(p, *gens, **args) + if (fac.is_Mul or fac.is_Pow) and fac != p: + partials[p] = fac + return f.xreplace(partials) + + try: + return _generic_factor(f, gens, args, method='factor') + except PolynomialError: + if not f.is_commutative: + return factor_nc(f) + else: + raise + + +@public +def intervals(F, all=False, eps=None, inf=None, sup=None, strict=False, fast=False, sqf=False): + """ + Compute isolating intervals for roots of ``f``. + + Examples + ======== + + >>> from sympy import intervals + >>> from sympy.abc import x + + >>> intervals(x**2 - 3) + [((-2, -1), 1), ((1, 2), 1)] + >>> intervals(x**2 - 3, eps=1e-2) + [((-26/15, -19/11), 1), ((19/11, 26/15), 1)] + + """ + if not hasattr(F, '__iter__'): + try: + F = Poly(F) + except GeneratorsNeeded: + return [] + + return F.intervals(all=all, eps=eps, inf=inf, sup=sup, fast=fast, sqf=sqf) + else: + polys, opt = parallel_poly_from_expr(F, domain='QQ') + + if len(opt.gens) > 1: + raise MultivariatePolynomialError + + for i, poly in enumerate(polys): + polys[i] = poly.rep.to_list() + + if eps is not None: + eps = opt.domain.convert(eps) + + if eps <= 0: + raise ValueError("'eps' must be a positive rational") + + if inf is not None: + inf = opt.domain.convert(inf) + if sup is not None: + sup = opt.domain.convert(sup) + + intervals = dup_isolate_real_roots_list(polys, opt.domain, + eps=eps, inf=inf, sup=sup, strict=strict, fast=fast) + + result = [] + + for (s, t), indices in intervals: + s, t = opt.domain.to_sympy(s), opt.domain.to_sympy(t) + result.append(((s, t), indices)) + + return result + + +@public +def refine_root(f, s, t, eps=None, steps=None, fast=False, check_sqf=False): + """ + Refine an isolating interval of a root to the given precision. + + Examples + ======== + + >>> from sympy import refine_root + >>> from sympy.abc import x + + >>> refine_root(x**2 - 3, 1, 2, eps=1e-2) + (19/11, 26/15) + + """ + try: + F = Poly(f) + if not isinstance(f, Poly) and not F.gen.is_Symbol: + # root of sin(x) + 1 is -1 but when someone + # passes an Expr instead of Poly they may not expect + # that the generator will be sin(x), not x + raise PolynomialError("generator must be a Symbol") + except GeneratorsNeeded: + raise PolynomialError( + "Cannot refine a root of %s, not a polynomial" % f) + + return F.refine_root(s, t, eps=eps, steps=steps, fast=fast, check_sqf=check_sqf) + + +@public +def count_roots(f, inf=None, sup=None): + """ + Return the number of roots of ``f`` in ``[inf, sup]`` interval. + + If one of ``inf`` or ``sup`` is complex, it will return the number of roots + in the complex rectangle with corners at ``inf`` and ``sup``. + + Examples + ======== + + >>> from sympy import count_roots, I + >>> from sympy.abc import x + + >>> count_roots(x**4 - 4, -3, 3) + 2 + >>> count_roots(x**4 - 4, 0, 1 + 3*I) + 1 + + """ + try: + F = Poly(f, greedy=False) + if not isinstance(f, Poly) and not F.gen.is_Symbol: + # root of sin(x) + 1 is -1 but when someone + # passes an Expr instead of Poly they may not expect + # that the generator will be sin(x), not x + raise PolynomialError("generator must be a Symbol") + except GeneratorsNeeded: + raise PolynomialError("Cannot count roots of %s, not a polynomial" % f) + + return F.count_roots(inf=inf, sup=sup) + + +@public +def all_roots(f, multiple=True, radicals=True, extension=False): + """ + Returns the real and complex roots of ``f`` with multiplicities. + + Explanation + =========== + + Finds all real and complex roots of a univariate polynomial with rational + coefficients of any degree exactly. The roots are represented in the form + given by :func:`~.rootof`. This is equivalent to using :func:`~.rootof` to + find each of the indexed roots. + + Examples + ======== + + >>> from sympy import all_roots + >>> from sympy.abc import x, y + + >>> print(all_roots(x**3 + 1)) + [-1, 1/2 - sqrt(3)*I/2, 1/2 + sqrt(3)*I/2] + + Simple radical formulae are used in some cases but the cubic and quartic + formulae are avoided. Instead most non-rational roots will be represented + as :class:`~.ComplexRootOf`: + + >>> print(all_roots(x**3 + x + 1)) + [CRootOf(x**3 + x + 1, 0), CRootOf(x**3 + x + 1, 1), CRootOf(x**3 + x + 1, 2)] + + All roots of any polynomial with rational coefficients of any degree can be + represented using :py:class:`~.ComplexRootOf`. The use of + :py:class:`~.ComplexRootOf` bypasses limitations on the availability of + radical formulae for quintic and higher degree polynomials _[1]: + + >>> p = x**5 - x - 1 + >>> for r in all_roots(p): print(r) + CRootOf(x**5 - x - 1, 0) + CRootOf(x**5 - x - 1, 1) + CRootOf(x**5 - x - 1, 2) + CRootOf(x**5 - x - 1, 3) + CRootOf(x**5 - x - 1, 4) + >>> [r.evalf(3) for r in all_roots(p)] + [1.17, -0.765 - 0.352*I, -0.765 + 0.352*I, 0.181 - 1.08*I, 0.181 + 1.08*I] + + Irrational algebraic coefficients are handled by :func:`all_roots` + if `extension=True` is set. + + >>> from sympy import sqrt, expand + >>> p = expand((x - sqrt(2))*(x - sqrt(3))) + >>> print(p) + x**2 - sqrt(3)*x - sqrt(2)*x + sqrt(6) + >>> all_roots(p) + Traceback (most recent call last): + ... + NotImplementedError: sorted roots not supported over EX + >>> all_roots(p, extension=True) + [sqrt(2), sqrt(3)] + + Algebraic coefficients can be complex as well. + + >>> from sympy import I + >>> all_roots(x**2 - I, extension=True) + [-sqrt(2)/2 - sqrt(2)*I/2, sqrt(2)/2 + sqrt(2)*I/2] + >>> all_roots(x**2 - sqrt(2)*I, extension=True) + [-2**(3/4)/2 - 2**(3/4)*I/2, 2**(3/4)/2 + 2**(3/4)*I/2] + + Transcendental coefficients cannot currently be handled by + :func:`all_roots`. In the case of algebraic or transcendental coefficients + :func:`~.ground_roots` might be able to find some roots by factorisation: + + >>> from sympy import ground_roots + >>> ground_roots(p, x, extension=True) + {sqrt(2): 1, sqrt(3): 1} + + If the coefficients are numeric then :func:`~.nroots` can be used to find + all roots approximately: + + >>> from sympy import nroots + >>> nroots(p, 5) + [1.4142, 1.732] + + If the coefficients are symbolic then :func:`sympy.polys.polyroots.roots` + or :func:`~.ground_roots` should be used instead: + + >>> from sympy import roots, ground_roots + >>> p = x**2 - 3*x*y + 2*y**2 + >>> roots(p, x) + {y: 1, 2*y: 1} + >>> ground_roots(p, x) + {y: 1, 2*y: 1} + + Parameters + ========== + + f : :class:`~.Expr` or :class:`~.Poly` + A univariate polynomial with rational (or ``Float``) coefficients. + multiple : ``bool`` (default ``True``). + Whether to return a ``list`` of roots or a list of root/multiplicity + pairs. + radicals : ``bool`` (default ``True``) + Use simple radical formulae rather than :py:class:`~.ComplexRootOf` for + some irrational roots. + extension: ``bool`` (default ``False``) + Whether to construct an algebraic extension domain before computing + the roots. Setting to ``True`` is necessary for finding roots of a + polynomial with (irrational) algebraic coefficients but can be slow. + + Returns + ======= + + A list of :class:`~.Expr` (usually :class:`~.ComplexRootOf`) representing + the roots is returned with each root repeated according to its multiplicity + as a root of ``f``. The roots are always uniquely ordered with real roots + coming before complex roots. The real roots are in increasing order. + Complex roots are ordered by increasing real part and then increasing + imaginary part. + + If ``multiple=False`` is passed then a list of root/multiplicity pairs is + returned instead. + + If ``radicals=False`` is passed then all roots will be represented as + either rational numbers or :class:`~.ComplexRootOf`. + + See also + ======== + + Poly.all_roots: + The underlying :class:`Poly` method used by :func:`~.all_roots`. + rootof: + Compute a single numbered root of a univariate polynomial. + real_roots: + Compute all the real roots using :func:`~.rootof`. + ground_roots: + Compute some roots in the ground domain by factorisation. + nroots: + Compute all roots using approximate numerical techniques. + sympy.polys.polyroots.roots: + Compute symbolic expressions for roots using radical formulae. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Abel%E2%80%93Ruffini_theorem + """ + try: + if isinstance(f, Poly): + if extension and not f.domain.is_AlgebraicField: + F = Poly(f.expr, extension=True) + else: + F = f + else: + if extension: + F = Poly(f, extension=True) + else: + F = Poly(f, greedy=False) + + if not isinstance(f, Poly) and not F.gen.is_Symbol: + # root of sin(x) + 1 is -1 but when someone + # passes an Expr instead of Poly they may not expect + # that the generator will be sin(x), not x + raise PolynomialError("generator must be a Symbol") + except GeneratorsNeeded: + raise PolynomialError( + "Cannot compute real roots of %s, not a polynomial" % f) + + return F.all_roots(multiple=multiple, radicals=radicals) + + +@public +def real_roots(f, multiple=True, radicals=True, extension=False): + """ + Returns the real roots of ``f`` with multiplicities. + + Explanation + =========== + + Finds all real roots of a univariate polynomial with rational coefficients + of any degree exactly. The roots are represented in the form given by + :func:`~.rootof`. This is equivalent to using :func:`~.rootof` or + :func:`~.all_roots` and filtering out only the real roots. However if only + the real roots are needed then :func:`real_roots` is more efficient than + :func:`~.all_roots` because it computes only the real roots and avoids + costly complex root isolation routines. + + Examples + ======== + + >>> from sympy import real_roots + >>> from sympy.abc import x, y + + >>> real_roots(2*x**3 - 7*x**2 + 4*x + 4) + [-1/2, 2, 2] + >>> real_roots(2*x**3 - 7*x**2 + 4*x + 4, multiple=False) + [(-1/2, 1), (2, 2)] + + Real roots of any polynomial with rational coefficients of any degree can + be represented using :py:class:`~.ComplexRootOf`: + + >>> p = x**9 + 2*x + 2 + >>> print(real_roots(p)) + [CRootOf(x**9 + 2*x + 2, 0)] + >>> [r.evalf(3) for r in real_roots(p)] + [-0.865] + + All rational roots will be returned as rational numbers. Roots of some + simple factors will be expressed using radical or other formulae (unless + ``radicals=False`` is passed). All other roots will be expressed as + :class:`~.ComplexRootOf`. + + >>> p = (x + 7)*(x**2 - 2)*(x**3 + x + 1) + >>> print(real_roots(p)) + [-7, -sqrt(2), CRootOf(x**3 + x + 1, 0), sqrt(2)] + >>> print(real_roots(p, radicals=False)) + [-7, CRootOf(x**2 - 2, 0), CRootOf(x**3 + x + 1, 0), CRootOf(x**2 - 2, 1)] + + All returned root expressions will numerically evaluate to real numbers + with no imaginary part. This is in contrast to the expressions generated by + the cubic or quartic formulae as used by :func:`~.roots` which suffer from + casus irreducibilis [1]_: + + >>> from sympy import roots + >>> p = 2*x**3 - 9*x**2 - 6*x + 3 + >>> [r.evalf(5) for r in roots(p, multiple=True)] + [5.0365 - 0.e-11*I, 0.33984 + 0.e-13*I, -0.87636 + 0.e-10*I] + >>> [r.evalf(5) for r in real_roots(p, x)] + [-0.87636, 0.33984, 5.0365] + >>> [r.is_real for r in roots(p, multiple=True)] + [None, None, None] + >>> [r.is_real for r in real_roots(p)] + [True, True, True] + + Using :func:`real_roots` is equivalent to using :func:`~.all_roots` (or + :func:`~.rootof`) and filtering out only the real roots: + + >>> from sympy import all_roots + >>> r = [r for r in all_roots(p) if r.is_real] + >>> real_roots(p) == r + True + + If only the real roots are wanted then using :func:`real_roots` is faster + than using :func:`~.all_roots`. Using :func:`real_roots` avoids complex root + isolation which can be a lot slower than real root isolation especially for + polynomials of high degree which typically have many more complex roots + than real roots. + + Irrational algebraic coefficients are handled by :func:`real_roots` + if `extension=True` is set. + + >>> from sympy import sqrt, expand + >>> p = expand((x - sqrt(2))*(x - sqrt(3))) + >>> print(p) + x**2 - sqrt(3)*x - sqrt(2)*x + sqrt(6) + >>> real_roots(p) + Traceback (most recent call last): + ... + NotImplementedError: sorted roots not supported over EX + >>> real_roots(p, extension=True) + [sqrt(2), sqrt(3)] + + Transcendental coefficients cannot currently be handled by + :func:`real_roots`. In the case of algebraic or transcendental coefficients + :func:`~.ground_roots` might be able to find some roots by factorisation: + + >>> from sympy import ground_roots + >>> ground_roots(p, x, extension=True) + {sqrt(2): 1, sqrt(3): 1} + + If the coefficients are numeric then :func:`~.nroots` can be used to find + all roots approximately: + + >>> from sympy import nroots + >>> nroots(p, 5) + [1.4142, 1.732] + + If the coefficients are symbolic then :func:`sympy.polys.polyroots.roots` + or :func:`~.ground_roots` should be used instead. + + >>> from sympy import roots, ground_roots + >>> p = x**2 - 3*x*y + 2*y**2 + >>> roots(p, x) + {y: 1, 2*y: 1} + >>> ground_roots(p, x) + {y: 1, 2*y: 1} + + Parameters + ========== + + f : :class:`~.Expr` or :class:`~.Poly` + A univariate polynomial with rational (or ``Float``) coefficients. + multiple : ``bool`` (default ``True``). + Whether to return a ``list`` of roots or a list of root/multiplicity + pairs. + radicals : ``bool`` (default ``True``) + Use simple radical formulae rather than :py:class:`~.ComplexRootOf` for + some irrational roots. + extension: ``bool`` (default ``False``) + Whether to construct an algebraic extension domain before computing + the roots. Setting to ``True`` is necessary for finding roots of a + polynomial with (irrational) algebraic coefficients but can be slow. + + Returns + ======= + + A list of :class:`~.Expr` (usually :class:`~.ComplexRootOf`) representing + the real roots is returned. The roots are arranged in increasing order and + are repeated according to their multiplicities as roots of ``f``. + + If ``multiple=False`` is passed then a list of root/multiplicity pairs is + returned instead. + + If ``radicals=False`` is passed then all roots will be represented as + either rational numbers or :class:`~.ComplexRootOf`. + + See also + ======== + + Poly.real_roots: + The underlying :class:`Poly` method used by :func:`real_roots`. + rootof: + Compute a single numbered root of a univariate polynomial. + all_roots: + Compute all real and non-real roots using :func:`~.rootof`. + ground_roots: + Compute some roots in the ground domain by factorisation. + nroots: + Compute all roots using approximate numerical techniques. + sympy.polys.polyroots.roots: + Compute symbolic expressions for roots using radical formulae. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Casus_irreducibilis + """ + try: + if isinstance(f, Poly): + if extension and not f.domain.is_AlgebraicField: + F = Poly(f.expr, extension=True) + else: + F = f + else: + if extension: + F = Poly(f, extension=True) + else: + F = Poly(f, greedy=False) + + if not isinstance(f, Poly) and not F.gen.is_Symbol: + # root of sin(x) + 1 is -1 but when someone + # passes an Expr instead of Poly they may not expect + # that the generator will be sin(x), not x + raise PolynomialError("generator must be a Symbol") + except GeneratorsNeeded: + raise PolynomialError( + "Cannot compute real roots of %s, not a polynomial" % f) + + return F.real_roots(multiple=multiple, radicals=radicals) + + +@public +def nroots(f, n=15, maxsteps=50, cleanup=True): + """ + Compute numerical approximations of roots of ``f``. + + Examples + ======== + + >>> from sympy import nroots + >>> from sympy.abc import x + + >>> nroots(x**2 - 3, n=15) + [-1.73205080756888, 1.73205080756888] + >>> nroots(x**2 - 3, n=30) + [-1.73205080756887729352744634151, 1.73205080756887729352744634151] + + """ + try: + F = Poly(f, greedy=False) + if not isinstance(f, Poly) and not F.gen.is_Symbol: + # root of sin(x) + 1 is -1 but when someone + # passes an Expr instead of Poly they may not expect + # that the generator will be sin(x), not x + raise PolynomialError("generator must be a Symbol") + except GeneratorsNeeded: + raise PolynomialError( + "Cannot compute numerical roots of %s, not a polynomial" % f) + + return F.nroots(n=n, maxsteps=maxsteps, cleanup=cleanup) + + +@public +def ground_roots(f, *gens, **args): + """ + Compute roots of ``f`` by factorization in the ground domain. + + Examples + ======== + + >>> from sympy import ground_roots + >>> from sympy.abc import x + + >>> ground_roots(x**6 - 4*x**4 + 4*x**3 - x**2) + {0: 2, 1: 2} + + """ + options.allowed_flags(args, []) + + try: + F, opt = poly_from_expr(f, *gens, **args) + if not isinstance(f, Poly) and not F.gen.is_Symbol: + # root of sin(x) + 1 is -1 but when someone + # passes an Expr instead of Poly they may not expect + # that the generator will be sin(x), not x + raise PolynomialError("generator must be a Symbol") + except PolificationFailed as exc: + raise ComputationFailed('ground_roots', 1, exc) + + return F.ground_roots() + + +@public +def nth_power_roots_poly(f, n, *gens, **args): + """ + Construct a polynomial with n-th powers of roots of ``f``. + + Examples + ======== + + >>> from sympy import nth_power_roots_poly, factor, roots + >>> from sympy.abc import x + + >>> f = x**4 - x**2 + 1 + >>> g = factor(nth_power_roots_poly(f, 2)) + + >>> g + (x**2 - x + 1)**2 + + >>> R_f = [ (r**2).expand() for r in roots(f) ] + >>> R_g = roots(g).keys() + + >>> set(R_f) == set(R_g) + True + + """ + options.allowed_flags(args, []) + + try: + F, opt = poly_from_expr(f, *gens, **args) + if not isinstance(f, Poly) and not F.gen.is_Symbol: + # root of sin(x) + 1 is -1 but when someone + # passes an Expr instead of Poly they may not expect + # that the generator will be sin(x), not x + raise PolynomialError("generator must be a Symbol") + except PolificationFailed as exc: + raise ComputationFailed('nth_power_roots_poly', 1, exc) + + result = F.nth_power_roots_poly(n) + + if not opt.polys: + return result.as_expr() + else: + return result + + +@public +def cancel(f, *gens, _signsimp=True, **args): + """ + Cancel common factors in a rational function ``f``. + + Examples + ======== + + >>> from sympy import cancel, sqrt, Symbol, together + >>> from sympy.abc import x + >>> A = Symbol('A', commutative=False) + + >>> cancel((2*x**2 - 2)/(x**2 - 2*x + 1)) + (2*x + 2)/(x - 1) + >>> cancel((sqrt(3) + sqrt(15)*A)/(sqrt(2) + sqrt(10)*A)) + sqrt(6)/2 + + Note: due to automatic distribution of Rationals, a sum divided by an integer + will appear as a sum. To recover a rational form use `together` on the result: + + >>> cancel(x/2 + 1) + x/2 + 1 + >>> together(_) + (x + 2)/2 + """ + from sympy.simplify.simplify import signsimp + from sympy.polys.rings import sring + options.allowed_flags(args, ['polys']) + + f = sympify(f) + if _signsimp: + f = signsimp(f) + opt = {} + if 'polys' in args: + opt['polys'] = args['polys'] + + if not isinstance(f, Tuple): + if f.is_Number or isinstance(f, Relational) or not isinstance(f, Expr): + return f + f = factor_terms(f, radical=True) + p, q = f.as_numer_denom() + + elif len(f) == 2: + p, q = f + if isinstance(p, Poly) and isinstance(q, Poly): + opt['gens'] = p.gens + opt['domain'] = p.domain + opt['polys'] = opt.get('polys', True) + p, q = p.as_expr(), q.as_expr() + else: + raise ValueError('unexpected argument: %s' % f) + + from sympy.functions.elementary.piecewise import Piecewise + try: + if f.has(Piecewise): + raise PolynomialError() + R, (F, G) = sring((p, q), *gens, **args) + if not R.ngens: + if not isinstance(f, Tuple): + return f.expand() + else: + return S.One, p, q + except PolynomialError as msg: + if f.is_commutative and not f.has(Piecewise): + raise PolynomialError(msg) + # Handling of noncommutative and/or piecewise expressions + if f.is_Add or f.is_Mul: + c, nc = sift(f.args, lambda x: + x.is_commutative is True and not x.has(Piecewise), + binary=True) + nc = [cancel(i) for i in nc] + return f.func(cancel(f.func(*c)), *nc) + else: + reps = [] + pot = preorder_traversal(f) + next(pot) + for e in pot: + if isinstance(e, BooleanAtom) or not isinstance(e, Expr): + continue + try: + reps.append((e, cancel(e))) + pot.skip() # this was handled successfully + except NotImplementedError: + pass + return f.xreplace(dict(reps)) + + c, (P, Q) = 1, F.cancel(G) + if opt.get('polys', False) and 'gens' not in opt: + opt['gens'] = R.symbols + + if not isinstance(f, Tuple): + return c*(P.as_expr()/Q.as_expr()) + else: + P, Q = P.as_expr(), Q.as_expr() + if not opt.get('polys', False): + return c, P, Q + else: + return c, Poly(P, *gens, **opt), Poly(Q, *gens, **opt) + + +@public +def reduced(f, G, *gens, **args): + """ + Reduces a polynomial ``f`` modulo a set of polynomials ``G``. + + Given a polynomial ``f`` and a set of polynomials ``G = (g_1, ..., g_n)``, + computes a set of quotients ``q = (q_1, ..., q_n)`` and the remainder ``r`` + such that ``f = q_1*g_1 + ... + q_n*g_n + r``, where ``r`` vanishes or ``r`` + is a completely reduced polynomial with respect to ``G``. + + Examples + ======== + + >>> from sympy import reduced + >>> from sympy.abc import x, y + + >>> reduced(2*x**4 + y**2 - x**2 + y**3, [x**3 - x, y**3 - y]) + ([2*x, 1], x**2 + y**2 + y) + + """ + options.allowed_flags(args, ['polys', 'auto']) + + try: + polys, opt = parallel_poly_from_expr([f] + list(G), *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('reduced', 0, exc) + + domain = opt.domain + retract = False + + if opt.auto and domain.is_Ring and not domain.is_Field: + opt = opt.clone({"domain": domain.get_field()}) + retract = True + + from sympy.polys.rings import xring + _ring, _ = xring(opt.gens, opt.domain, opt.order) + + for i, poly in enumerate(polys): + poly = poly.set_domain(opt.domain).rep.to_dict() + polys[i] = _ring.from_dict(poly) + + Q, r = polys[0].div(polys[1:]) + + Q = [Poly._from_dict(dict(q), opt) for q in Q] + r = Poly._from_dict(dict(r), opt) + + if retract: + try: + _Q, _r = [q.to_ring() for q in Q], r.to_ring() + except CoercionFailed: + pass + else: + Q, r = _Q, _r + + if not opt.polys: + return [q.as_expr() for q in Q], r.as_expr() + else: + return Q, r + + +@public +def groebner(F, *gens, **args): + """ + Computes the reduced Groebner basis for a set of polynomials. + + Use the ``order`` argument to set the monomial ordering that will be + used to compute the basis. Allowed orders are ``lex``, ``grlex`` and + ``grevlex``. If no order is specified, it defaults to ``lex``. + + For more information on Groebner bases, see the references and the docstring + of :func:`~.solve_poly_system`. + + Examples + ======== + + Example taken from [1]. + + >>> from sympy import groebner + >>> from sympy.abc import x, y + + >>> F = [x*y - 2*y, 2*y**2 - x**2] + + >>> groebner(F, x, y, order='lex') + GroebnerBasis([x**2 - 2*y**2, x*y - 2*y, y**3 - 2*y], x, y, + domain='ZZ', order='lex') + >>> groebner(F, x, y, order='grlex') + GroebnerBasis([y**3 - 2*y, x**2 - 2*y**2, x*y - 2*y], x, y, + domain='ZZ', order='grlex') + >>> groebner(F, x, y, order='grevlex') + GroebnerBasis([y**3 - 2*y, x**2 - 2*y**2, x*y - 2*y], x, y, + domain='ZZ', order='grevlex') + + By default, an improved implementation of the Buchberger algorithm is + used. Optionally, an implementation of the F5B algorithm can be used. The + algorithm can be set using the ``method`` flag or with the + :func:`sympy.polys.polyconfig.setup` function. + + >>> F = [x**2 - x - 1, (2*x - 1) * y - (x**10 - (1 - x)**10)] + + >>> groebner(F, x, y, method='buchberger') + GroebnerBasis([x**2 - x - 1, y - 55], x, y, domain='ZZ', order='lex') + >>> groebner(F, x, y, method='f5b') + GroebnerBasis([x**2 - x - 1, y - 55], x, y, domain='ZZ', order='lex') + + References + ========== + + 1. [Buchberger01]_ + 2. [Cox97]_ + + """ + return GroebnerBasis(F, *gens, **args) + + +@public +def is_zero_dimensional(F, *gens, **args): + """ + Checks if the ideal generated by a Groebner basis is zero-dimensional. + + The algorithm checks if the set of monomials not divisible by the + leading monomial of any element of ``F`` is bounded. + + References + ========== + + David A. Cox, John B. Little, Donal O'Shea. Ideals, Varieties and + Algorithms, 3rd edition, p. 230 + + """ + return GroebnerBasis(F, *gens, **args).is_zero_dimensional + + +@public +class GroebnerBasis(Basic): + """Represents a reduced Groebner basis. """ + + def __new__(cls, F, *gens, **args): + """Compute a reduced Groebner basis for a system of polynomials. """ + options.allowed_flags(args, ['polys', 'method']) + + try: + polys, opt = parallel_poly_from_expr(F, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('groebner', len(F), exc) + + from sympy.polys.rings import PolyRing + ring = PolyRing(opt.gens, opt.domain, opt.order) + + polys = [ring.from_dict(poly.rep.to_dict()) for poly in polys if poly] + + G = _groebner(polys, ring, method=opt.method) + G = [Poly._from_dict(g, opt) for g in G] + + return cls._new(G, opt) + + @classmethod + def _new(cls, basis, options): + obj = Basic.__new__(cls) + + obj._basis = tuple(basis) + obj._options = options + + return obj + + @property + def args(self): + basis = (p.as_expr() for p in self._basis) + return (Tuple(*basis), Tuple(*self._options.gens)) + + @property + def exprs(self): + return [poly.as_expr() for poly in self._basis] + + @property + def polys(self): + return list(self._basis) + + @property + def gens(self): + return self._options.gens + + @property + def domain(self): + return self._options.domain + + @property + def order(self): + return self._options.order + + def __len__(self): + return len(self._basis) + + def __iter__(self): + if self._options.polys: + return iter(self.polys) + else: + return iter(self.exprs) + + def __getitem__(self, item): + if self._options.polys: + basis = self.polys + else: + basis = self.exprs + + return basis[item] + + def __hash__(self): + return hash((self._basis, tuple(self._options.items()))) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self._basis == other._basis and self._options == other._options + elif iterable(other): + return self.polys == list(other) or self.exprs == list(other) + else: + return False + + def __ne__(self, other): + return not self == other + + @property + def is_zero_dimensional(self): + """ + Checks if the ideal generated by a Groebner basis is zero-dimensional. + + The algorithm checks if the set of monomials not divisible by the + leading monomial of any element of ``F`` is bounded. + + References + ========== + + David A. Cox, John B. Little, Donal O'Shea. Ideals, Varieties and + Algorithms, 3rd edition, p. 230 + + """ + def single_var(monomial): + return sum(map(bool, monomial)) == 1 + + exponents = Monomial([0]*len(self.gens)) + order = self._options.order + + for poly in self.polys: + monomial = poly.LM(order=order) + + if single_var(monomial): + exponents *= monomial + + # If any element of the exponents vector is zero, then there's + # a variable for which there's no degree bound and the ideal + # generated by this Groebner basis isn't zero-dimensional. + return all(exponents) + + def fglm(self, order): + """ + Convert a Groebner basis from one ordering to another. + + The FGLM algorithm converts reduced Groebner bases of zero-dimensional + ideals from one ordering to another. This method is often used when it + is infeasible to compute a Groebner basis with respect to a particular + ordering directly. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy import groebner + + >>> F = [x**2 - 3*y - x + 1, y**2 - 2*x + y - 1] + >>> G = groebner(F, x, y, order='grlex') + + >>> list(G.fglm('lex')) + [2*x - y**2 - y + 1, y**4 + 2*y**3 - 3*y**2 - 16*y + 7] + >>> list(groebner(F, x, y, order='lex')) + [2*x - y**2 - y + 1, y**4 + 2*y**3 - 3*y**2 - 16*y + 7] + + References + ========== + + .. [1] J.C. Faugere, P. Gianni, D. Lazard, T. Mora (1994). Efficient + Computation of Zero-dimensional Groebner Bases by Change of + Ordering + + """ + opt = self._options + + src_order = opt.order + dst_order = monomial_key(order) + + if src_order == dst_order: + return self + + if not self.is_zero_dimensional: + raise NotImplementedError("Cannot convert Groebner bases of ideals with positive dimension") + + polys = list(self._basis) + domain = opt.domain + + opt = opt.clone({ + "domain": domain.get_field(), + "order": dst_order, + }) + + from sympy.polys.rings import xring + _ring, _ = xring(opt.gens, opt.domain, src_order) + + for i, poly in enumerate(polys): + poly = poly.set_domain(opt.domain).rep.to_dict() + polys[i] = _ring.from_dict(poly) + + G = matrix_fglm(polys, _ring, dst_order) + G = [Poly._from_dict(dict(g), opt) for g in G] + + if not domain.is_Field: + G = [g.clear_denoms(convert=True)[1] for g in G] + opt.domain = domain + + return self._new(G, opt) + + def reduce(self, expr, auto=True): + """ + Reduces a polynomial modulo a Groebner basis. + + Given a polynomial ``f`` and a set of polynomials ``G = (g_1, ..., g_n)``, + computes a set of quotients ``q = (q_1, ..., q_n)`` and the remainder ``r`` + such that ``f = q_1*f_1 + ... + q_n*f_n + r``, where ``r`` vanishes or ``r`` + is a completely reduced polynomial with respect to ``G``. + + Examples + ======== + + >>> from sympy import groebner, expand, Poly + >>> from sympy.abc import x, y + + >>> f = 2*x**4 - x**2 + y**3 + y**2 + >>> G = groebner([x**3 - x, y**3 - y]) + + >>> G.reduce(f) + ([2*x, 1], x**2 + y**2 + y) + >>> Q, r = _ + + >>> expand(sum(q*g for q, g in zip(Q, G)) + r) + 2*x**4 - x**2 + y**3 + y**2 + >>> _ == f + True + + # Using Poly input + >>> f_poly = Poly(f, x, y) + >>> G = groebner([Poly(x**3 - x), Poly(y**3 - y)]) + + >>> G.reduce(f_poly) + ([Poly(2*x, x, y, domain='ZZ'), Poly(1, x, y, domain='ZZ')], Poly(x**2 + y**2 + y, x, y, domain='ZZ')) + + """ + if isinstance(expr, Poly): + + if expr.gens != self._options.gens: + raise ValueError("Polynomial generators don't match Groebner basis generators") + poly = expr.set_domain(self._options.domain) + else: + + poly = Poly._from_expr(expr, self._options) + + polys = [poly] + list(self._basis) + + opt = self._options + domain = opt.domain + + retract = False + + if auto and domain.is_Ring and not domain.is_Field: + opt = opt.clone({"domain": domain.get_field()}) + retract = True + + from sympy.polys.rings import xring + _ring, _ = xring(opt.gens, opt.domain, opt.order) + + for i, poly in enumerate(polys): + poly = poly.set_domain(opt.domain).rep.to_dict() + polys[i] = _ring.from_dict(poly) + + Q, r = polys[0].div(polys[1:]) + + Q = [Poly._from_dict(dict(q), opt) for q in Q] + r = Poly._from_dict(dict(r), opt) + + if retract: + try: + _Q, _r = [q.to_ring() for q in Q], r.to_ring() + except CoercionFailed: + pass + else: + Q, r = _Q, _r + + if not opt.polys: + return [q.as_expr() for q in Q], r.as_expr() + else: + return Q, r + + def contains(self, poly): + """ + Check if ``poly`` belongs the ideal generated by ``self``. + + Examples + ======== + + >>> from sympy import groebner + >>> from sympy.abc import x, y + + >>> f = 2*x**3 + y**3 + 3*y + >>> G = groebner([x**2 + y**2 - 1, x*y - 2]) + + >>> G.contains(f) + True + >>> G.contains(f + 1) + False + + """ + return self.reduce(poly)[1] == 0 + + +@public +def poly(expr, *gens, **args): + """ + Efficiently transform an expression into a polynomial. + + Examples + ======== + + >>> from sympy import poly + >>> from sympy.abc import x + + >>> poly(x*(x**2 + x - 1)**2) + Poly(x**5 + 2*x**4 - x**3 - 2*x**2 + x, x, domain='ZZ') + + """ + options.allowed_flags(args, []) + + def _poly(expr, opt): + terms, poly_terms = [], [] + + for term in Add.make_args(expr): + factors, poly_factors = [], [] + + for factor in Mul.make_args(term): + if factor.is_Add: + poly_factors.append(_poly(factor, opt)) + elif factor.is_Pow and factor.base.is_Add and \ + factor.exp.is_Integer and factor.exp >= 0: + poly_factors.append( + _poly(factor.base, opt).pow(factor.exp)) + else: + factors.append(factor) + + if not poly_factors: + terms.append(term) + else: + product = poly_factors[0] + + for factor in poly_factors[1:]: + product = product.mul(factor) + + if factors: + factor = Mul(*factors) + + if factor.is_Number: + product *= factor + else: + product = product.mul(Poly._from_expr(factor, opt)) + + poly_terms.append(product) + + if not poly_terms: + result = Poly._from_expr(expr, opt) + else: + result = poly_terms[0] + + for term in poly_terms[1:]: + result = result.add(term) + + if terms: + term = Add(*terms) + + if term.is_Number: + result += term + else: + result = result.add(Poly._from_expr(term, opt)) + + return result.reorder(*opt.get('gens', ()), **args) + + expr = sympify(expr) + + if expr.is_Poly: + return Poly(expr, *gens, **args) + + if 'expand' not in args: + args['expand'] = False + + opt = options.build_options(gens, args) + + return _poly(expr, opt) + + +def named_poly(n, f, K, name, x, polys): + r"""Common interface to the low-level polynomial generating functions + in orthopolys and appellseqs. + + Parameters + ========== + + n : int + Index of the polynomial, which may or may not equal its degree. + f : callable + Low-level generating function to use. + K : Domain or None + Domain in which to perform the computations. If None, use the smallest + field containing the rationals and the extra parameters of x (see below). + name : str + Name of an arbitrary individual polynomial in the sequence generated + by f, only used in the error message for invalid n. + x : seq + The first element of this argument is the main variable of all + polynomials in this sequence. Any further elements are extra + parameters required by f. + polys : bool, optional + If True, return a Poly, otherwise (default) return an expression. + """ + if n < 0: + raise ValueError("Cannot generate %s of index %s" % (name, n)) + head, tail = x[0], x[1:] + if K is None: + K, tail = construct_domain(tail, field=True) + poly = DMP(f(int(n), *tail, K), K) + if head is None: + poly = PurePoly.new(poly, Dummy('x')) + else: + poly = Poly.new(poly, head) + return poly if polys else poly.as_expr() diff --git a/phivenv/Lib/site-packages/sympy/polys/polyutils.py b/phivenv/Lib/site-packages/sympy/polys/polyutils.py new file mode 100644 index 0000000000000000000000000000000000000000..6a2019d3b195891d84ce8e0b368f6bdc5f45d4b3 --- /dev/null +++ b/phivenv/Lib/site-packages/sympy/polys/polyutils.py @@ -0,0 +1,584 @@ +"""Useful utilities for higher level polynomial classes. """ + +from __future__ import annotations + +from sympy.external.gmpy import GROUND_TYPES + +from sympy.core import (S, Add, Mul, Pow, Eq, Expr, + expand_mul, expand_multinomial) +from sympy.core.exprtools import decompose_power, decompose_power_rat +from sympy.core.numbers import _illegal +from sympy.polys.polyerrors import PolynomialError, GeneratorsError +from sympy.polys.polyoptions import build_options + +import re + + +_gens_order = { + 'a': 301, 'b': 302, 'c': 303, 'd': 304, + 'e': 305, 'f': 306, 'g': 307, 'h': 308, + 'i': 309, 'j': 310, 'k': 311, 'l': 312, + 'm': 313, 'n': 314, 'o': 315, 'p': 216, + 'q': 217, 'r': 218, 's': 219, 't': 220, + 'u': 221, 'v': 222, 'w': 223, 'x': 124, + 'y': 125, 'z': 126, +} + +_max_order = 1000 +_re_gen = re.compile(r"^(.*?)(\d*)$", re.MULTILINE) + + +def _nsort(roots, separated=False): + """Sort the numerical roots putting the real roots first, then sorting + according to real and imaginary parts. If ``separated`` is True, then + the real and imaginary roots will be returned in two lists, respectively. + + This routine tries to avoid issue 6137 by separating the roots into real + and imaginary parts before evaluation. In addition, the sorting will raise + an error if any computation cannot be done with precision. + """ + if not all(r.is_number for r in roots): + raise NotImplementedError + if not len(roots): + return [] if not separated else ([], []) + # see issue 6137: + # get the real part of the evaluated real and imaginary parts of each root + key = [[i.n(2).as_real_imag()[0] for i in r.as_real_imag()] for r in roots] + # make sure the parts were computed with precision + if len(roots) > 1 and any(i._prec == 1 for k in key for i in k): + raise NotImplementedError("could not compute root with precision") + # insert a key to indicate if the root has an imaginary part + key = [(1 if i else 0, r, i) for r, i in key] + key = sorted(zip(key, roots)) + # return the real and imaginary roots separately if desired + if separated: + r = [] + i = [] + for (im, _, _), v in key: + if im: + i.append(v) + else: + r.append(v) + return r, i + _, roots = zip(*key) + return list(roots) + + +def _sort_gens(gens, **args): + """Sort generators in a reasonably intelligent way. """ + opt = build_options(args) + + gens_order, wrt = {}, None + + if opt is not None: + gens_order, wrt = {}, opt.wrt + + for i, gen in enumerate(opt.sort): + gens_order[gen] = i + 1 + + def order_key(gen): + gen = str(gen) + + if wrt is not None: + try: + return (-len(wrt) + wrt.index(gen), gen, 0) + except ValueError: + pass + + name, index = _re_gen.match(gen).groups() + + if index: + index = int(index) + else: + index = 0 + + try: + return ( gens_order[name], name, index) + except KeyError: + pass + + try: + return (_gens_order[name], name, index) + except KeyError: + pass + + return (_max_order, name, index) + + try: + gens = sorted(gens, key=order_key) + except TypeError: # pragma: no cover + pass + + return tuple(gens) + + +def _unify_gens(f_gens, g_gens): + """Unify generators in a reasonably intelligent way. """ + f_gens = list(f_gens) + g_gens = list(g_gens) + + if f_gens == g_gens: + return tuple(f_gens) + + gens, common, k = [], [], 0 + + for gen in f_gens: + if gen in g_gens: + common.append(gen) + + for i, gen in enumerate(g_gens): + if gen in common: + g_gens[i], k = common[k], k + 1 + + for gen in common: + i = f_gens.index(gen) + + gens.extend(f_gens[:i]) + f_gens = f_gens[i + 1:] + + i = g_gens.index(gen) + + gens.extend(g_gens[:i]) + g_gens = g_gens[i + 1:] + + gens.append(gen) + + gens.extend(f_gens) + gens.extend(g_gens) + + return tuple(gens) + + +def _analyze_gens(gens): + """Support for passing generators as `*gens` and `[gens]`. """ + if len(gens) == 1 and hasattr(gens[0], '__iter__'): + return tuple(gens[0]) + else: + return tuple(gens) + + +def _sort_factors(factors, **args): + """Sort low-level factors in increasing 'complexity' order. """ + + # XXX: GF(p) does not support comparisons so we need a key function to sort + # the factors if python-flint is being used. A better solution might be to + # add a sort key method to each domain. + def order_key(factor): + if isinstance(factor, _GF_types): + return int(factor) + elif isinstance(factor, list): + return [order_key(f) for f in factor] + else: + return factor + + def order_if_multiple_key(factor): + (f, n) = factor + return (len(f), n, order_key(f)) + + def order_no_multiple_key(f): + return (len(f), order_key(f)) + + if args.get('multiple', True): + return sorted(factors, key=order_if_multiple_key) + else: + return sorted(factors, key=order_no_multiple_key) + + +illegal_types = [type(obj) for obj in _illegal] +finf = [float(i) for i in _illegal[1:3]] + + +def _not_a_coeff(expr): + """Do not treat NaN and infinities as valid polynomial coefficients. """ + if type(expr) in illegal_types or expr in finf: + return True + if isinstance(expr, float) and float(expr) != expr: + return True # nan + return # could be + + +def _parallel_dict_from_expr_if_gens(exprs, opt): + """Transform expressions into a multinomial form given generators. """ + k, indices = len(opt.gens), {} + + for i, g in enumerate(opt.gens): + indices[g] = i + + polys = [] + + for expr in exprs: + poly = {} + + if expr.is_Equality: + expr = expr.lhs - expr.rhs + + for term in Add.make_args(expr): + coeff, monom = [], [0]*k + + for factor in Mul.make_args(term): + if not _not_a_coeff(factor) and factor.is_Number: + coeff.append(factor) + else: + try: + if opt.series is False: + base, exp = decompose_power(factor) + + if exp < 0: + exp, base = -exp, Pow(base, -S.One) + else: + base, exp = decompose_power_rat(factor) + + monom[indices[base]] = exp + except KeyError: + if not factor.has_free(*opt.gens): + coeff.append(factor) + else: + raise PolynomialError("%s contains an element of " + "the set of generators." % factor) + + monom = tuple(monom) + + if monom in poly: + poly[monom] += Mul(*coeff) + else: + poly[monom] = Mul(*coeff) + + polys.append(poly) + + return polys, opt.gens + + +def _parallel_dict_from_expr_no_gens(exprs, opt): + """Transform expressions into a multinomial form and figure out generators. """ + if opt.domain is not None: + def _is_coeff(factor): + return factor in opt.domain + elif opt.extension is True: + def _is_coeff(factor): + return factor.is_algebraic + elif opt.greedy is not False: + def _is_coeff(factor): + return factor is S.ImaginaryUnit + else: + def _is_coeff(factor): + return factor.is_number + + gens, reprs = set(), [] + + for expr in exprs: + terms = [] + + if expr.is_Equality: + expr = expr.lhs - expr.rhs + + for term in Add.make_args(expr): + coeff, elements = [], {} + + for factor in Mul.make_args(term): + if not _not_a_coeff(factor) and (factor.is_Number or _is_coeff(factor)): + coeff.append(factor) + else: + if opt.series is False: + base, exp = decompose_power(factor) + + if exp < 0: + exp, base = -exp, Pow(base, -S.One) + else: + base, exp = decompose_power_rat(factor) + + elements[base] = elements.setdefault(base, 0) + exp + gens.add(base) + + terms.append((coeff, elements)) + + reprs.append(terms) + + gens = _sort_gens(gens, opt=opt) + k, indices = len(gens), {} + + for i, g in enumerate(gens): + indices[g] = i + + polys = [] + + for terms in reprs: + poly = {} + + for coeff, term in terms: + monom = [0]*k + + for base, exp in term.items(): + monom[indices[base]] = exp + + monom = tuple(monom) + + if monom in poly: + poly[monom] += Mul(*coeff) + else: + poly[monom] = Mul(*coeff) + + polys.append(poly) + + return polys, tuple(gens) + + +def _dict_from_expr_if_gens(expr, opt): + """Transform an expression into a multinomial form given generators. """ + (poly,), gens = _parallel_dict_from_expr_if_gens((expr,), opt) + return poly, gens + + +def _dict_from_expr_no_gens(expr, opt): + """Transform an expression into a multinomial form and figure out generators. """ + (poly,), gens = _parallel_dict_from_expr_no_gens((expr,), opt) + return poly, gens + + +def parallel_dict_from_expr(exprs, **args): + """Transform expressions into a multinomial form. """ + reps, opt = _parallel_dict_from_expr(exprs, build_options(args)) + return reps, opt.gens + + +def _parallel_dict_from_expr(exprs, opt): + """Transform expressions into a multinomial form. """ + if opt.expand is not False: + exprs = [ expr.expand() for expr in exprs ] + + if any(expr.is_commutative is False for expr in exprs): + raise PolynomialError('non-commutative expressions are not supported') + + if opt.gens: + reps, gens = _parallel_dict_from_expr_if_gens(exprs, opt) + else: + reps, gens = _parallel_dict_from_expr_no_gens(exprs, opt) + + return reps, opt.clone({'gens': gens}) + + +def dict_from_expr(expr, **args): + """Transform an expression into a multinomial form. """ + rep, opt = _dict_from_expr(expr, build_options(args)) + return rep, opt.gens + + +def _dict_from_expr(expr, opt): + """Transform an expression into a multinomial form. """ + if expr.is_commutative is False: + raise PolynomialError('non-commutative expressions are not supported') + + def _is_expandable_pow(expr): + return (expr.is_Pow and expr.exp.is_positive and expr.exp.is_Integer + and expr.base.is_Add) + + if opt.expand is not False: + if not isinstance(expr, (Expr, Eq)): + raise PolynomialError('expression must be of type Expr') + expr = expr.expand() + # TODO: Integrate this into expand() itself + while any(_is_expandable_pow(i) or i.is_Mul and + any(_is_expandable_pow(j) for j in i.args) for i in + Add.make_args(expr)): + + expr = expand_multinomial(expr) + while any(i.is_Mul and any(j.is_Add for j in i.args) for i in Add.make_args(expr)): + expr = expand_mul(expr) + + if opt.gens: + rep, gens = _dict_from_expr_if_gens(expr, opt) + else: + rep, gens = _dict_from_expr_no_gens(expr, opt) + + return rep, opt.clone({'gens': gens}) + + +def expr_from_dict(rep, *gens): + """Convert a multinomial form into an expression. """ + result = [] + + for monom, coeff in rep.items(): + term = [coeff] + for g, m in zip(gens, monom): + if m: + term.append(Pow(g, m)) + + result.append(Mul(*term)) + + return Add(*result) + +parallel_dict_from_basic = parallel_dict_from_expr +dict_from_basic = dict_from_expr +basic_from_dict = expr_from_dict + + +def _dict_reorder(rep, gens, new_gens): + """Reorder levels using dict representation. """ + gens = list(gens) + + monoms = rep.keys() + coeffs = rep.values() + + new_monoms = [ [] for _ in range(len(rep)) ] + used_indices = set() + + for gen in new_gens: + try: + j = gens.index(gen) + used_indices.add(j) + + for M, new_M in zip(monoms, new_monoms): + new_M.append(M[j]) + except ValueError: + for new_M in new_monoms: + new_M.append(0) + + for i, _ in enumerate(gens): + if i not in used_indices: + for monom in monoms: + if monom[i]: + raise GeneratorsError("unable to drop generators") + + return map(tuple, new_monoms), coeffs + + +class PicklableWithSlots: + """ + Mixin class that allows to pickle objects with ``__slots__``. + + Examples + ======== + + First define a class that mixes :class:`PicklableWithSlots` in:: + + >>> from sympy.polys.polyutils import PicklableWithSlots + >>> class Some(PicklableWithSlots): + ... __slots__ = ('foo', 'bar') + ... + ... def __init__(self, foo, bar): + ... self.foo = foo + ... self.bar = bar + + To make :mod:`pickle` happy in doctest we have to use these hacks:: + + >>> import builtins + >>> builtins.Some = Some + >>> from sympy.polys import polyutils + >>> polyutils.Some = Some + + Next lets see if we can create an instance, pickle it and unpickle:: + + >>> some = Some('abc', 10) + >>> some.foo, some.bar + ('abc', 10) + + >>> from pickle import dumps, loads + >>> some2 = loads(dumps(some)) + + >>> some2.foo, some2.bar + ('abc', 10) + + """ + + __slots__ = () + + def __getstate__(self, cls=None): + if cls is None: + # This is the case for the instance that gets pickled + cls = self.__class__ + + d = {} + + # Get all data that should be stored from super classes + for c in cls.__bases__: + # XXX: Python 3.11 defines object.__getstate__ and it does not + # accept any arguments so we need to make sure not to call it with + # an argument here. To be compatible with Python < 3.11 we need to + # be careful not to assume that c or object has a __getstate__ + # method though. + getstate = getattr(c, "__getstate__", None) + objstate = getattr(object, "__getstate__", None) + if getstate is not None and getstate is not objstate: + d.update(getstate(self, c)) + + # Get all information that should be stored from cls and return the dict + for name in cls.__slots__: + if hasattr(self, name): + d[name] = getattr(self, name) + + return d + + def __setstate__(self, d): + # All values that were pickled are now assigned to a fresh instance + for name, value in d.items(): + setattr(self, name, value) + + +class IntegerPowerable: + r""" + Mixin class for classes that define a `__mul__` method, and want to be + raised to integer powers in the natural way that follows. Implements + powering via binary expansion, for efficiency. + + By default, only integer powers $\geq 2$ are supported. To support the + first, zeroth, or negative powers, override the corresponding methods, + `_first_power`, `_zeroth_power`, `_negative_power`, below. + """ + + def __pow__(self, e, modulo=None): + if e < 2: + try: + if e == 1: + return self._first_power() + elif e == 0: + return self._zeroth_power() + else: + return self._negative_power(e, modulo=modulo) + except NotImplementedError: + return NotImplemented + else: + bits = [int(d) for d in reversed(bin(e)[2:])] + n = len(bits) + p = self + first = True + for i in range(n): + if bits[i]: + if first: + r = p + first = False + else: + r *= p + if modulo is not None: + r %= modulo + if i < n - 1: + p *= p + if modulo is not None: + p %= modulo + return r + + def _negative_power(self, e, modulo=None): + """ + Compute inverse of self, then raise that to the abs(e) power. + For example, if the class has an `inv()` method, + return self.inv() ** abs(e) % modulo + """ + raise NotImplementedError + + def _zeroth_power(self): + """Return unity element of algebraic struct to which self belongs.""" + raise NotImplementedError + + def _first_power(self): + """Return a copy of self.""" + raise NotImplementedError + + +_GF_types: tuple[type, ...] + + +if GROUND_TYPES == 'flint': + import flint + _GF_types = (flint.nmod, flint.fmpz_mod) +else: + from sympy.polys.domains.modularinteger import ModularInteger + flint = None + _GF_types = (ModularInteger,)